
**Train a multi-class StarDist model (pre-trained encoder; the rest of the network is trained from scratch)**

In [None]:
import pytorch_lightning as pl
import cellseg_models_pytorch as csmp

# Set training, validation and test datasets.
def _make_seg_dataset(path, mask_path, img_transforms):
    """Build one SegmentationFolderDataset split with the shared target settings.

    All splits share the same target / normalization configuration; only the
    folder paths and the image augmentations differ between splits.

    Args:
        path: Directory holding the image patches of this split.
        mask_path: Directory holding the mask patches of this split.
        img_transforms: Names of the image augmentations to apply. An empty
            list disables augmentation (used for validation and test).
            NOTE(review): assumes the dataset accepts an empty list and still
            applies `normalization` on its own — confirm for the installed
            cellseg_models_pytorch version.

    Returns:
        A ``csmp.datasets.SegmentationFolderDataset`` for the given folders.
    """
    return csmp.datasets.SegmentationFolderDataset(
        path=path,
        mask_path=mask_path,
        img_transforms=img_transforms,
        # A fresh list is built on every call, so splits never share a mutable
        # argument. These presumably generate the "stardist"/"dist" target maps
        # that the branch losses consume, so every split keeps them.
        inst_transforms=["stardist", "dist"],
        return_sem=False,
        return_type=True,
        return_inst=False,
        return_weight=False,
        normalization="percentile",
    )


# Augment the training split only. Augmentations such as blur and hue/saturation
# jitter are typically randomized. Applying them to the validation/test splits
# would make val_loss (monitored for checkpointing) and test metrics noisy and
# non-reproducible.
train_ds = _make_seg_dataset(
    "/path/to/train/im_patches",
    "/path/to/train/mask_patches",
    img_transforms=["blur", "hue_sat"],
)
valid_ds = _make_seg_dataset(
    "/path/to/valid/im_patches",
    "/path/to/valid/mask_patches",
    img_transforms=[],
)
test_ds = _make_seg_dataset(
    "/path/to/test/im_patches",
    "/path/to/test/mask_patches",
    img_transforms=[],
)

# define a lightning datamodule
# The three splits are handed over as one list in train / valid / test order
# (presumably consumed positionally by CustomDataModule — confirm).
splits = [train_ds, valid_ds, test_ds]
datamodule = csmp.datamodules.CustomDataModule(splits, num_workers=8, batch_size=8)


# Define the model and lightning experiment (LightningModule).
model = csmp.models.stardist_base_multiclass(n_rays=32, type_classes=5)

# Per-branch configuration. The keys presumably have to match the model's
# output branches and the target maps produced by the datasets — confirm.
# A value of [None] presumably means "no metric" for that branch, so only the
# "type" branch is scored (mIoU).
branch_metrics = {
    "dist": [None],
    "stardist": [None],
    "type": ["miou"],
}
branch_losses = {
    "dist": "ssim_mse",
    "stardist": "ssim_mse",
    "type": "ce_dice",
}

experiment = csmp.training.SegmentationExperiment(
    model=model,
    branch_metrics=branch_metrics,
    branch_losses=branch_losses,
    optimizer="adamp",
    lookahead=False,
    scheduler="cosine_annealing",
)

# Lightning callbacks
# Keep the single best checkpoint (lowest val_loss) and also save the latest one.
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath="/path/to/checkpoint_dir/",
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True,
    verbose=True,
)
callbacks = [ckpt_callback]

# Lightning training
# NOTE(review): `gpus` and `move_metrics_to_cpu` are deprecated in later
# PyTorch Lightning 1.x releases and removed in 2.0. On newer versions use
# `accelerator="gpu", devices=1` instead. Verify against the installed version.
trainer = pl.Trainer(
    callbacks=callbacks,
    gpus=1,  # single-GPU training
    max_epochs=15,
    move_metrics_to_cpu=True,  # keep logged metric tensors on the CPU
    profiler="simple",  # SimpleProfiler timing summary
)

In [None]:
# Train: fit the experiment on the datamodule's splits using the trainer built above.
trainer.fit(experiment, datamodule=datamodule)