In [12]:
import random
import torch

import numpy as np
from pathlib import Path
from PIL import Image

from torchvision.transforms import v2
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
import torch.nn.functional as F
from torchmetrics import Accuracy

from ba_dev.dataset import MammaliaDataSequence, MammaliaDataImage
from ba_dev.datamodule import MammaliaDataModule
from ba_dev.transform import ImagePipeline, BatchImagePipeline
from ba_dev.model import LightningModelImage
from ba_dev.utils import load_config_yaml

paths = load_config_yaml('../path_config.yml')


### Running Tests

In [2]:
stats = torch.load(paths['feature_stats'])

image_pipeline = ImagePipeline(
        pre_ops=[
            ('to_rgb', {}),
            ('crop_by_bb', {'crop_shape': 1.0})
            ],
        transform=v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((224, 224)),
            v2.Normalize(
                mean=stats['mean'],
                std=stats['std']
                )
            ])
        )

dataset_kwargs = {
        'path_labelfiles': paths['test_labels'],
        'path_to_dataset': paths['dataset'],
        'path_to_detector_output': paths['md_output'],
        }

datamodule = MammaliaDataModule(
                dataset_cls=MammaliaDataImage,
                dataset_kwargs=dataset_kwargs,
                n_folds=5,
                test_fold=0,
                image_pipeline=image_pipeline,
                augmented_image_pipeline=None,
                batch_size=32,
                num_workers=1,
                pin_memory=True,
                )

model = LightningModelImage(
            num_classes=datamodule.num_classes,
            class_weights=datamodule.class_weights,
            backbone_name='efficientnet_b0',
            backbone_pretrained=True,
            backbone_weights='DEFAULT',
            optimizer_name='AdamW',
            optimizer_kwargs={
                'lr': 1e-3,
                'weight_decay': 1e-5,
                'amsgrad': False
                },
            scheduler_name='CosineAnnealingLR',
            scheduler_kwargs={'T_max': 5},
            )


8 sequences had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953, 6000186]


In [28]:
train_acc = Accuracy(task="multiclass", num_classes=4)

In [29]:
train_acc

MulticlassAccuracy()

In [16]:
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

In [34]:
batch

{'sample': tensor([[[[-7.8289e-01, -7.8781e-01, -7.7689e-01,  ..., -1.0733e-01,
            -1.0673e-01, -1.3620e-01],
           [-7.8289e-01, -8.0010e-01, -7.9269e-01,  ..., -1.3293e-01,
            -1.3253e-01, -1.0673e-01],
           [-7.8205e-01, -7.9927e-01, -7.9186e-01,  ..., -1.4623e-01,
            -1.3884e-01, -9.8167e-02],
           ...,
           [-1.3600e+00, -1.3515e+00, -1.3615e+00,  ...,  5.6469e-01,
             3.2700e-01,  1.6107e-01],
           [-1.3893e+00, -1.3893e+00, -1.3944e+00,  ...,  7.7370e-01,
             6.2795e-01,  4.9603e-01],
           [-1.4281e+00, -1.4281e+00, -1.4281e+00,  ...,  8.3627e-01,
             7.8916e-01,  7.1357e-01]],
 
          [[-5.6069e-01, -5.6069e-01, -5.7003e-01,  ...,  1.3886e-01,
             1.3950e-01,  1.0939e-01],
           [-5.1852e-01, -5.1852e-01, -5.1805e-01,  ...,  1.1096e-01,
             1.1139e-01,  1.0469e-01],
           [-5.1761e-01, -5.1761e-01, -5.1761e-01,  ...,  9.6470e-02,
             1.0451e-01,  1.0

In [23]:
logits = model(batch['sample'])        # shape [B, num_classes]
probs  = F.softmax(logits, dim=1)
preds = torch.argmax(probs, dim=1)    # shape [B]

In [24]:
preds

tensor([2, 0, 3, 0, 3, 3, 3, 2, 2, 2, 1, 2, 3, 2, 0, 1, 2, 0, 0, 0, 1, 2, 3, 2,
        0, 0, 3, 2, 2, 0, 0, 0])

In [26]:
batch['class_id']

tensor([0, 2, 0, 0, 3, 0, 1, 0, 3, 3, 2, 0, 2, 0, 2, 2, 2, 2, 0, 3, 2, 3, 0, 1,
        0, 0, 2, 0, 2, 0, 0, 2])

In [27]:
correct = (preds == batch['class_id']).sum()
total = batch['class_id'].numel()
correct / total

tensor(0.2812)

In [32]:
batch_acc = train_acc(logits, batch['class_id'])

In [33]:
batch_acc

tensor(0.2812)

In [11]:
output = model.training_step(batch, batch_idx=0)
print(output)            # should be a tensor == the loss
print(output.item())     # the scalar loss valueprobs

tensor(1.3109, grad_fn=<NllLossBackward0>)
1.3109310865402222


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [9]:
# 2) Check configure_optimizers output
cfg = model.configure_optimizers()
print(cfg)
# You should see a dict of the form {'optimizer': <AdamW>, 'lr_scheduler': {…}}

assert isinstance(cfg, dict)
assert 'optimizer' in cfg
assert 'lr_scheduler' in cfg
print("✅ Scheduler hook-up looks good")

# 3) Run a single training+validation batch through Lightning
#    to make sure the scheduler.step() call doesn’t error out.
trainer = Trainer(
    fast_dev_run=True,
    logger=False,
    enable_checkpointing=False,
    # devices=1, accelerator='gpu'   # add if you have a GPU
)
trainer.fit(model, datamodule)
print("✅ fast_dev_run with scheduler completed without error")

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | backbone  | ResNet           | 23.5 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.065    Total estimated model params size (MB)
152       Modu

{'optimizer': AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 1e-05
), 'lr_scheduler': {'scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x15543264e740>, 'monitor': 'val_loss', 'interval': 'epoch'}}
✅ Scheduler hook-up looks good


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


✅ fast_dev_run with scheduler completed without error


In [10]:
# 2) Grab one batch from train
batch = next(iter(datamodule.train_dataloader()))
x, y = batch['sample'], batch['class_id']

print("Sample tensor shape:", x.shape)   # expect (32, 3, 224, 224)
print("Label tensor shape: ", y.shape)   # expect (32,)

# 3) Forward pass
logits = model(x)
print("Logits shape:       ", logits.shape)  # expect (32, num_classes)

# 4) Optimizer/scheduler check
opt_cfg = model.configure_optimizers()
print("configure_optimizers() returned:", opt_cfg)

# 5) One‐step Lightning run
trainer = Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
trainer.fit(model, datamodule)

Sample tensor shape: torch.Size([32, 3, 224, 224])
Label tensor shape:  torch.Size([32])


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | backbone  | ResNet           | 23.5 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.065    Total estimated model params size (MB)
152       Modu

Logits shape:        torch.Size([32, 4])
configure_optimizers() returned: {'optimizer': AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 1e-05
), 'lr_scheduler': {'scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x1554350abc40>, 'monitor': 'val_loss', 'interval': 'epoch'}}


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [11]:
batch = next(iter(datamodule.train_dataloader()))
x, y_true = batch['sample'], batch['class_id']

logits = model(x)                        # (B, num_classes)
probs  = torch.softmax(logits, dim=1)    # (B, num_classes)
preds  = torch.argmax(probs, dim=1)      # (B,)

# map indices → names
decoder = datamodule.get_label_decoder()
pred_names = [decoder[int(i)] for i in preds]

print("Predicted classes for this batch:")
for i, name in enumerate(pred_names):
    print(f"  sample {i:>2d}: {name} (true: {decoder[int(y_true[i])]})")

Predicted classes for this batch:
  sample  0: apodemus_sp (true: apodemus_sp)
  sample  1: apodemus_sp (true: apodemus_sp)
  sample  2: soricidae (true: soricidae)
  sample  3: cricetidae (true: cricetidae)
  sample  4: cricetidae (true: cricetidae)
  sample  5: mustela_erminea (true: mustela_erminea)
  sample  6: apodemus_sp (true: apodemus_sp)
  sample  7: apodemus_sp (true: apodemus_sp)
  sample  8: apodemus_sp (true: apodemus_sp)
  sample  9: apodemus_sp (true: apodemus_sp)
  sample 10: soricidae (true: soricidae)
  sample 11: cricetidae (true: cricetidae)
  sample 12: apodemus_sp (true: mustela_erminea)
  sample 13: cricetidae (true: cricetidae)
  sample 14: cricetidae (true: cricetidae)
  sample 15: apodemus_sp (true: apodemus_sp)
  sample 16: cricetidae (true: cricetidae)
  sample 17: soricidae (true: soricidae)
  sample 18: mustela_erminea (true: mustela_erminea)
  sample 19: apodemus_sp (true: apodemus_sp)
  sample 20: cricetidae (true: cricetidae)
  sample 21: apodemus_sp (t