# Training models with BioEngine

In [None]:
import micropip
await micropip.install('pyotritonclient')

import io
import asyncio
import os
from pyotritonclient import SequenceExcutor, execute
import numpy as np
import pickle
import imageio
from js import fetch

async def download_data(url):
    """Download a pickled dataset from *url* and return the unpickled object.

    The request goes through the browser ``fetch`` function imported from
    Pyodide's ``js`` module, so this must be awaited inside such a kernel.

    Args:
        url: Address of the pickle file to download.

    Returns:
        The object stored in the downloaded pickle file.

    Raises:
        RuntimeError: If the server answers with a non-success HTTP status.
    """
    response = await fetch(url)
    # fetch() resolves even for 404/500 answers; fail here with a clear message
    # instead of letting pickle choke on an error page.
    if not response.ok:
        raise RuntimeError(f"Failed to download {url} (HTTP {response.status})")
    array_buffer = await response.arrayBuffer()
    # Convert the JavaScript ArrayBuffer into a Python bytes-like object.
    payload = array_buffer.to_py()
    # SECURITY: unpickling can execute arbitrary code -- only pass trusted URLs.
    return pickle.load(io.BytesIO(payload))

test_samples = await download_data("https://raw.githubusercontent.com/imjoy-team/imjoy-tutorials/master/2-bioengine/test_samples_4.pkl")
train_samples = await download_data("https://raw.githubusercontent.com/imjoy-team/imjoy-tutorials/master/2-bioengine/train_samples_4.pkl")
print("Dataset downloaded", len(train_samples), len(test_samples))

(image, labels, info) = train_samples[0]

## Train a Cellpose model

In [None]:

async def train(model_id=102, epochs=1, model_token=None, pretrained_model="cyto"):
    """Train a Cellpose model through the BioEngine and save the exported package.

    Every sample of the module-level ``train_samples`` is sent to the
    ``cellpose-train`` model as one step per epoch. The sequence is then
    closed with a request for the model in ``bioimageio`` format, and the
    returned package is written to the current working directory.

    Args:
        model_id: Value passed as ``sequence_id`` to the sequence executor.
        epochs: Number of passes over ``train_samples``. With ``0`` no
            training step is sent and the sequence is closed immediately.
        model_token: Set to a string if you want to protect the model from
            being overwritten by other users.
        pretrained_model: Pretrained model to start from; set to ``None`` if
            you want to train from scratch.

    Raises:
        ValueError: If ``train_samples`` or ``test_samples`` is empty.
    """
    # Fail before any server round-trip rather than with an IndexError at
    # export time, after all training steps have already been paid for.
    if len(train_samples) == 0 or len(test_samples) == 0:
        raise ValueError("train_samples and test_samples must not be empty")

    seq = SequenceExcutor(
        server_url="https://ai.imjoy.io/triton",
        model_name="cellpose-train",
        decode_json=True,
        sequence_id=model_id,
    )
    for epoch in range(epochs):
        losses = []
        for (image, labels, info) in train_samples:
            inputs = [
                image.astype("float32"),
                labels.astype("uint16"),
                {
                    "steps": 16,
                    "pretrained_model": pretrained_model,
                    "resume": True,
                    "model_token": model_token,
                    "channels": [1, 2],
                    "diam_mean": 30,
                },
            ]
            result = await seq.step(inputs, select_outputs=["info"])
            losses.append(result["info"][0]["loss"])
        avg_loss = np.array(losses).mean()
        print(f"Epoch {epoch}  loss={avg_loss}")

    valid_image = test_samples[0][0].astype("float32")
    # All-zero placeholder shaped like the last training labels. It is taken
    # from the dataset instead of the leaked loop variable so that it is also
    # defined when the training loop did not run (epochs=0).
    valid_labels = np.zeros_like(train_samples[-1][1]).astype("uint16")
    result = await seq.end(
        [
            valid_image,
            valid_labels,
            {
                "resume": True,
                "model_token": model_token,
                "channels": [1, 2],
                "diameter": 100.0,
                "model_format": "bioimageio",
            },
        ],
        decode_json=True,
        select_outputs=["model", "info"],
    )
    # Save the weights under the file name reported by the server.
    model_package = result["model"][0]
    filename = result["info"][0]["model_files"][0]
    with open(filename, "wb") as fil:
        fil.write(model_package)
    print(f"Model package saved to {filename}")

# Run training with the defaults: one epoch, sequence id 102, "cyto" weights.
await train()

# Use the trained model for prediction

In [None]:
async def predict(model_id=102):
    """Run the ``cellpose-predict`` model on every test sample and save the masks.

    For each entry of the module-level ``test_samples`` the image is sent as
    one step of a prediction sequence. The returned mask is written to
    ``test_result_<index>.png`` and its shape and mean value are printed.

    Args:
        model_id: Value passed as ``sequence_id`` to the sequence executor.
            Defaults to the same value as the ``model_id`` default of
            ``train``. Previously this name was undefined in this function
            and the call failed with ``NameError``.
    """
    # Start the prediction
    seq = SequenceExcutor(
        server_url="https://ai.imjoy.io/triton",
        model_name="cellpose-predict",
        decode_json=True,
        sequence_id=model_id,
    )
    for i, sample in enumerate(test_samples):
        # Only the image (first element of the sample) is needed for inference.
        inputs = [sample[0].astype("float32"), {"channels": [1, 2], "diameter": 100}]
        results = await seq.step(inputs, select_outputs=["mask"])
        mask = results["mask"]
        # NOTE(review): the uint8 cast wraps label values above 255 -- confirm
        # the masks never hold more objects than that.
        imageio.imwrite(f"test_result_{i}.png", mask.astype("uint8"))
        print(mask.shape, mask.mean())

    # Close the sequence once all samples have been processed.
    await seq.end()

# Segment every test sample and write the resulting masks as PNG files.
await predict()