### Setup

Extract the dataset and import the required libraries.

In [None]:
!unzip "leaf_disease" -d "data/leaf_disease"

In [None]:
from functools import partial

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from timm.optim import AdaBelief
from timm.scheduler import CosineLRScheduler

from segmentator.leaf_disease_dataset import LeafDiseaseDataModule
from segmentator.model import PcsModel
from segmentator.models.unet import get_unet

### Datamodule

Lightning's preferred way of handling data is through an abstraction called a *DataModule*.

In [None]:
# Data module reading from the directory the archive was extracted into.
# random_state fixes the seed (presumably for the train/val split — confirm in
# LeafDiseaseDataModule); num_workers is presumably the DataLoader worker count.
datamodule = LeafDiseaseDataModule(
    root="data/leaf_disease",
    batch_size=16,
    random_state=1337,
    num_workers=4,
)

In [None]:
# Run the datamodule's data-preparation step up front.
# NOTE(review): Trainer.fit also drives prepare_data()/setup() on the datamodule,
# so this explicit call is presumably just to surface data problems before
# training starts — confirm that calling it twice is harmless.
datamodule.prepare_data()

### Model

The optimizer and scheduler are passed to the model as `functools.partial` objects because, in the main pipeline, Hydra handles their instantiation; here we replicate that by hand.

In [None]:
# U-Net with a ConvNeXt-Tiny backbone and a 2-class output.
model_instance = get_unet(backbone_name="convnext_tiny", num_classes=2)

# Lightning wrapper. Optimizer and scheduler are handed over as partials so
# PcsModel can bind them itself (presumably to the model parameters / the
# optimizer inside configure_optimizers — confirm in segmentator.model).
model = PcsModel(
    model_instance=model_instance,
    # AdaBelief with base LR 1e-4 and weight decay 1e-6.
    optimizer_partial=partial(AdaBelief, lr=1e-4, weight_decay=1e-6),
    # Cosine schedule over t_initial=10 steps down to lr_min=3e-7, with a
    # warmup of warmup_t=5 steps starting at warmup_lr_init=1e-5. Step units
    # depend on how PcsModel steps the scheduler (presumably epochs).
    # NOTE(review): timm's CosineLRScheduler defaults to cycle_limit=1, in which
    # case cycle_decay has no effect and the LR stays at lr_min after t_initial
    # (far short of max_epochs=100) — confirm whether restarts were intended.
    scheduler_partial=partial(
        CosineLRScheduler,
        t_initial=10,
        lr_min=3e-7,
        cycle_decay=0.8,
        warmup_t=5,
        warmup_lr_init=1e-5,
    ),
)


### Training

Once again, Lightning takes care of the entire training loop.

In [None]:
# Single-GPU trainer.
# - EarlyStopping stops once val_loss has not improved for 3 validation checks.
# - LearningRateMonitor logs the learning rate once per epoch.
# - accumulate_grad_batches=4 with the datamodule's batch_size=16 gives an
#   effective batch of 64 samples per optimizer step.
# `auto_scale_batch_size=None` / `auto_lr_find=False` are no longer passed: they
# were the disabled defaults in Lightning 1.x, and Lightning 2.x removed both
# arguments from Trainer, so passing them raises a TypeError there.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=100,
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=3, verbose=True, mode="min"),
        LearningRateMonitor(logging_interval="epoch"),
    ],
    precision=16,  # 16-bit mixed precision; Lightning 2.x also accepts "16-mixed"
    accumulate_grad_batches=4,
)

In [None]:
# Launch training; the datamodule supplies the train/validation dataloaders.
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Sanity check: shape of the mask in the first training sample. The notebook
# displays the value of the cell's final expression.
first_sample = datamodule.train_dataset[0]
first_sample["mask"].shape