In [None]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner.tuning import Tuner
import torch
import torch.nn as nn

from trailcaml import TrailCaML
from datasets.trailcamera import TrailCameraDataset

In [None]:
# Reproducibility: `deterministic=True` on the Trainer only forces deterministic
# kernels; it does not seed any RNG. Seed Python, NumPy and torch *before* anything
# random runs — the dataset split/shuffling (if any) and model construction.
# workers=True additionally lets the Trainer seed DataLoader worker processes.
SEED = 42
L.seed_everything(SEED, workers=True)

IMG_SIZE = (240, 240)  # keep in sync with TrailCaML(img_size=...)
data_set = TrailCameraDataset(size=IMG_SIZE)
train, valid, test = data_set.dataloader_splits(num_workers=4, batch_size=32)

In [None]:
# Sanity check before training: pull a single batch from the training loader
# and display the shapes of its inputs and targets.
first_batch = next(iter(train))
x, y = first_batch
(x.shape, y.shape)

In [None]:
# TensorBoard writer for the run's metrics. With the default `name`
# ("lightning_logs"), runs land in lightning_logs/lightning_logs/version_<n>/.
logger = TensorBoardLogger(
    save_dir="lightning_logs",
)

In [None]:
# Training configuration.
epochs = 10
# Checkpointing and early stopping must watch the same validation metric; name it
# once so the two callbacks cannot drift apart.
# NOTE(review): assumes TrailCaML logs exactly this key during validation — confirm.
MONITOR = "val_loss"

trainer = L.Trainer(
    max_epochs=epochs,
    logger=logger,
    # Log more often than Lightning's default of every 50 steps (presumably because
    # an epoch contains only a few batches).
    log_every_n_steps=6,
    callbacks=[
        # Keep the 3 checkpoints with the lowest monitored loss; evaluation can
        # later reload the best one with ckpt_path="best".
        ModelCheckpoint(
            monitor=MONITOR,
            mode="min",
            save_top_k=3,
            # Filename template — the `val_loss` key here must match MONITOR.
            filename="{epoch}-{val_loss:.2f}",
        ),
        # Stop after 5 consecutive validation checks without improvement.
        EarlyStopping(monitor=MONITOR, patience=5, mode="min"),
    ],
    # Clip gradients (norm-based by default) to damp unstable updates.
    gradient_clip_val=0.5,
    # Forces deterministic torch algorithms. Full reproducibility also needs the RNGs
    # seeded (L.seed_everything). Some CUDA ops have no deterministic implementation
    # and raise under this flag; deterministic="warn" is the fallback if that happens.
    deterministic=True,
)

In [None]:
# Build the classifier, one hyperparameter per line.
tcml = TrailCaML(
    img_size=(240, 240),  # must match the `size` the TrailCameraDataset was built with
    lr=1e-3,              # presumably the optimizer learning rate
    fine_tune_after=5,    # presumably the epoch at which fine-tuning starts — confirm in TrailCaML
)

In [None]:
# Train the model; the validation loader feeds the metric watched by the
# checkpoint and early-stopping callbacks.
trainer.fit(
    tcml,
    train_dataloaders=train,
    val_dataloaders=valid,
)

In [None]:
# Re-evaluate the validation split with the *best* checkpoint saved during fit
# (lowest monitored loss), not the final-epoch weights still held by `tcml` — with
# early stopping, the last epoch can be several epochs past the best one.
# Note: ckpt_path="best" loads those weights into `tcml`.
# Returns a list with one metrics dict per dataloader.
validation_result = trainer.validate(tcml, valid, ckpt_path="best")

In [None]:
# Report held-out test metrics for the *best* checkpoint from fit (the model that
# checkpointing selected), not whatever weights the last training epoch left in `tcml`.
# Note: ckpt_path="best" loads those weights into `tcml`.
# Returns a list with one metrics dict per dataloader.
test_result = trainer.test(tcml, test, ckpt_path="best")

In [None]:
# Display the hyperparameters recorded on the model (presumably stored by
# TrailCaML via save_hyperparameters() — confirm).
tcml.hparams