In [1]:
import sys
sys.path.append('../../')
from cassetta.losses import LoadableMSE
from cassetta.optimizers import LoadableAdam
from cassetta.models.segmentation import SegNet
from cassetta.datasets import DummySupervisedDataset
from cassetta.training.trainers import SimpleSupervisedTrainer, TrainerConfig

# 1 Common variables
We will begin by defining some standard and commonly used variables:
1. `device`: The device that will be used to perform the data generation and training operations.
2. `experiment_dir`: The directory in which you want your trained model, associated checkpoints, and experiment parameters to live.
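The notebook hard-codes `device = "cpu"`. A common PyTorch alternative is to pick a GPU automatically whenever one is visible. The sketch below assumes only a standard PyTorch install; `auto_device` is an illustrative name, not something cassetta provides.

```python
import torch

# Prefer the first CUDA GPU if PyTorch can see one, otherwise fall back to the CPU
auto_device = "cuda" if torch.cuda.is_available() else "cpu"
print(auto_device)
```

You could then pass `auto_device` wherever the notebook uses `device`.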

In [2]:
device = "cpu"
experiment_dir = "../../models/version_1"

# 2 Dataset
Preparing the dataset for supervised training involves several steps to ensure that the data is in the right format and of sufficient quality for the model to learn effectively. Once these criteria are fulfilled, you should have a dataset that outputs two tensors:
- **Input Tensor (X)**: the features or input data fed into the model.
- **Ground Truth Tensor (y)**: the target labels or true values that the model aims to predict.

For expediency, we will use `DummySupervisedDataset`, a dataset that generates random data for both `X` and `y`. Note that we set `x_shape` equal to `y_shape`: we want to use a UNet for segmentation, so our ground truth (and therefore our predictions) should have the same size as our input.
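To make the "outputs two tensors" contract concrete, here is a minimal map-style PyTorch dataset that yields random `(x, y)` pairs of equal shape. It is a hypothetical stand-in (`RandomPairDataset` is not part of cassetta), and it assumes nothing about how `DummySupervisedDataset` is implemented internally; it only mirrors the shapes used in the next cell.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomPairDataset(Dataset):
    """Hypothetical supervised dataset: yields random (input, ground truth) pairs."""

    def __init__(self, x_shape, y_shape, length=8, device="cpu"):
        self.x_shape = x_shape
        self.y_shape = y_shape
        self.length = length
        self.device = device

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Input tensor X and ground-truth tensor y, generated on the fly
        x = torch.randn(self.x_shape, device=self.device)
        y = torch.randn(self.y_shape, device=self.device)
        return x, y


sketch_dataset = RandomPairDataset(x_shape=(1, 32, 32, 32), y_shape=(1, 32, 32, 32))
x, y = sketch_dataset[0]
print(x.shape, y.shape)  # both torch.Size([1, 32, 32, 32])

# A DataLoader stacks samples into batches of shape (batch, channel, D, H, W)
xb, yb = next(iter(DataLoader(sketch_dataset, batch_size=4)))
print(xb.shape)  # torch.Size([4, 1, 32, 32, 32])
```

The leading `1` in each shape is the channel dimension; the remaining three are the spatial dimensions of a 32³ volume.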

In [3]:
dataset = DummySupervisedDataset(
    x_shape=(1, 32, 32, 32),
    y_shape=(1, 32, 32, 32),
    n_classes=None,
    device=device
)

# 3 Model-Specific Inits
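Next, we initialize the model, the loss function, and the optimizer. Note that we convert the model parameters into a list before passing them to the optimizer. This looks unusual, but it is needed so that we can save the trainer later: `model.parameters()` returns a generator, and generators cannot be pickled, so passing it directly would cause a pickling error.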
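The pickling failure comes from Python itself rather than from cassetta. The stdlib-only sketch below uses plain floats as stand-ins for parameter tensors, and assumes (without reference to cassetta's source) that its serialization ultimately relies on `pickle`. `make_params` and `is_picklable` are hypothetical helpers.

```python
import pickle


def make_params():
    # Stand-in for model.parameters(), which also returns a generator
    return (w for w in [0.1, 0.2, 0.3])


def is_picklable(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


print(is_picklable(make_params()))        # False: generators cannot be pickled
print(is_picklable(list(make_params())))  # True: a list of the same values can
```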

In [4]:
# Init segmentation network and send to device
model = SegNet(3, 1, 1).to(device)
# Build the loss function (more in `cassetta/losses`)
loss = LoadableMSE()
# Build the optimizer (more in `cassetta/optimizers`); list() keeps the parameters pickleable
optimizer = LoadableAdam(list(model.parameters()))

# 4 Trainer-Specific Inits
Next, we build the trainer configuration object, a data container that stores all trainer-related arguments used to customize the training process. We then initialize the trainer itself from the model, optimizer, loss, and dataset defined above.
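As an illustration of what "data container" means here, the sketch below is a plain dataclass with the same fields as the `TrainerConfig` call in the next cell. It is an assumption for illustration only: `ExampleTrainerConfig` is a hypothetical name, and cassetta's actual `TrainerConfig` may define more fields, defaults, or validation.

```python
from dataclasses import asdict, dataclass


@dataclass
class ExampleTrainerConfig:
    """Hypothetical config container; fields mirror the TrainerConfig call below."""
    experiment_dir: str
    lr: float
    batch_size: int
    nb_epochs: int
    refresh_experiment_dir: bool


sketch_config = ExampleTrainerConfig(
    experiment_dir="../../models/version_1",
    lr=1e-3,
    batch_size=16,
    nb_epochs=2,
    refresh_experiment_dir=True,
)
# asdict() turns the container into a plain dict, handy for logging the experiment
print(asdict(sketch_config))
```

Keeping the hyperparameters in one object like this makes it straightforward to store them alongside checkpoints in `experiment_dir`.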

In [5]:
# Set up all training related configurations
trainer_config = TrainerConfig(
    experiment_dir=experiment_dir,
    lr=1e-3,
    batch_size=16,
    nb_epochs=2,
    refresh_experiment_dir=True
)

# Build trainer
trainer = SimpleSupervisedTrainer(
    model=model,
    optimizer=optimizer,
    loss=loss,
    dataset=dataset,
    trainer_config=trainer_config
)

# 5 Training
We have successfully set everything up! Now, we run the training loop by simply calling the `train()` method.
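For readers curious about what happens inside `train()`, below is a generic PyTorch supervised training loop. It is a sketch under the assumption that `SimpleSupervisedTrainer` does something broadly similar; cassetta's actual implementation may differ (for example in logging, checkpointing, or validation). `train_sketch` and the toy `Conv3d` model are hypothetical and used only here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, optimizer, loss_fn, dataset, batch_size, nb_epochs, device="cpu"):
    """Minimal supervised loop; returns the mean training loss of each epoch."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    epoch_losses = []
    model.train()
    for _ in range(nb_epochs):
        running = 0.0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()          # clear gradients from the previous step
            batch_loss = loss_fn(model(x), y)  # forward pass + loss
            batch_loss.backward()          # backpropagate
            optimizer.step()               # update parameters
            running += batch_loss.item() * x.shape[0]
        epoch_losses.append(running / len(dataset))
    return epoch_losses


# Toy 3D example: 1 input channel -> 1 output channel, same spatial size
toy_model = nn.Conv3d(1, 1, kernel_size=3, padding=1)
toy_data = TensorDataset(torch.randn(8, 1, 8, 8, 8), torch.randn(8, 1, 8, 8, 8))
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
losses = train_sketch(toy_model, toy_opt, nn.MSELoss(), toy_data, batch_size=4, nb_epochs=2)
print(losses)
```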

In [6]:
trainer.train()

All contents of ../../models/version_1 have been deleted.


# 6 Saving
We can save our trainer (and all of its attributes) with cassetta's custom serialization methods! This makes loading the model much easier at test time.
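For comparison, the plain-PyTorch way to checkpoint is to save state dictionaries and rebuild each object yourself before restoring its state, as sketched below. This uses standard PyTorch calls (`torch.save`, `torch.load`, `load_state_dict`), not cassetta's serialization format; the small `nn.Linear` model and the in-memory buffer are placeholders.

```python
import io

import torch
from torch import nn

net = nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Save only the state dictionaries (tensors + optimizer bookkeeping)
buffer = io.BytesIO()  # stands in for a file such as 'checkpoint.pt'
torch.save({"model": net.state_dict(), "optimizer": opt.state_dict()}, buffer)

# Later: rebuild the objects by hand, then restore their state
buffer.seek(0)
checkpoint = torch.load(buffer)
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model"])
restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
restored_opt.load_state_dict(checkpoint["optimizer"])

print(torch.equal(net.weight, restored.weight))  # True
```

The rebuild step is exactly what saving the whole trainer is meant to spare you: with plain state dicts you must know every constructor argument at load time.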

In [27]:
trainer.save(f'{experiment_dir}/model.pt')