# Setup for Google Colab
(Skip this step if you are running in a Jupyter notebook)

**Note: please make sure you switched on GPU runtime by choosing "Runtime -> Change runtime type -> Hardware accelerator -> GPU" and click "Save"**

In [None]:
!git clone https://github.com/imjoy-team/imjoy-interactive-segmentation.git
%cd imjoy-interactive-segmentation
!pip install -r requirements.txt
!python3 -m ipykernel install --user --name imjoy-interactive-ml --display-name "ImJoy Interactive ML"

# Setup for Jupyter notebook
(Skip this step for Google Colab)

Please make sure you installed the ImJoy extension by running `pip install imjoy imjoy-jupyter-extension`. After installing, restart your Jupyter notebook.

**Note: Before start, please make sure you see an ImJoy icon in the toolbar.**


# Download example dataset

The dataset will be saved to `./data/hpa_dataset_v2`

In [None]:
!python download_example_dataset.py

# Start the interactive segmentation interface

In [None]:
from imjoy_plugin import start_interactive_segmentation

model_config = dict( type="cellpose",
                     model_dir='./data/hpa_dataset_v2/__models__',
                     use_gpu=True,
                     channels=[2, 3],
                     style_on=0,
                     batch_size=1,
                     default_diameter=100,
                     pretrained_model=False,
                     resume=False)

start_interactive_segmentation(model_config,
                               "./data/hpa_dataset_v2",
                               ["microtubules.png", "er.png", "nuclei.png"],
                               object_name="cell",
                               scale_factor=1.0)

# Interact with the trainer

In [None]:
from interactive_trainer import InteractiveTrainer
trainer = InteractiveTrainer.get_instance()

In [None]:
reports = trainer.get_reports()
import matplotlib.pyplot as plt
loss = [report['loss'] for report in reports]
plt.plot(loss)

In [None]:
reports

In [None]:
trainer.stop()

In [None]:
image, _, info = trainer.get_test_sample()
print(str(image.shape))
image.shape,info

In [None]:
x, y, info = trainer.get_training_sample()

In [None]:
import matplotlib.pyplot as plt
plt.imshow(x)
plt.figure()
plt.imshow(y)