In [1]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner.tuning import Tuner
import torch
import torch.nn as nn

from trailcaml import TrailCaML
from datasets.trailcamera import DEFAULT_LOADERS
from preprocessors import ImagePreprocessor

In [2]:
train, valid, test = DEFAULT_LOADERS()

In [3]:
tcml = TrailCaML(lr=1e-4)

In [4]:
logger = TensorBoardLogger(save_dir="lightning_logs")

In [5]:
epochs = 30
trainer = L.Trainer(
    max_epochs=epochs,
    logger=logger,
    log_every_n_steps=6,
    callbacks=[
        # Save best models
        ModelCheckpoint(
            monitor='val_loss',
            mode='min',
            save_top_k=3,
            filename='{epoch}-{val_loss:.2f}'
        ),
        # Stop if not improving
        EarlyStopping(
            monitor='val_loss',
            patience=5,
            mode='min'
        )
    ],
    gradient_clip_val=0.5,
    deterministic=True
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
trainer.fit(model=tcml, train_dataloaders=train, val_dataloaders=valid)


  | Name             | Type              | Params | Mode 
---------------------------------------------------------------
0 | loss_fn          | BCEWithLogitsLoss | 0      | train
1 | backbone         | ResNet            | 11.2 M | train
2 | accuracy_metrics | ModuleDict        | 0      | train
---------------------------------------------------------------
5.2 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.689    Total estimated model params size (MB)
73        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                     | 0/? [00:00<?, ?it/s]

/home/hayden/code/trailcaml/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/hayden/code/trailcaml/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [7]:
validation_result = trainer.validate(tcml, valid)

Validation: |                                                          | 0/? [00:00<?, ?it/s]

In [8]:
test_result = trainer.test(tcml, test)

/home/hayden/code/trailcaml/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |                                                             | 0/? [00:00<?, ?it/s]

In [9]:
tcml.hparams

"fine_tune_after": 5
"lr":              1e-06
"lr_reduction":    100.0

In [10]:
test_batch = next(iter(test))

In [15]:
x, y = test_batch
preds = tcml(x)
((torch.sigmoid(preds) > 0.5) == y.bool()).float().mean()

tensor(0.9609)