In [46]:
import yaml
import shutil
import torch
import pandas as pd
from torchvision.transforms import v2
from pathlib import Path
from typing import Sequence

from ba_dev.dataset import MammaliaDataImage
from ba_dev.datamodule import MammaliaDataModule
from ba_dev.transform import ImagePipeline
from ba_dev.model import LightningModelImage
from ba_dev.trainer import MammaliaTrainer
from ba_dev.utils import count_trainable_parameters

In [55]:
def print_banner(text, width=80, border_char='-'):
    inner_width = width - 4
    line = border_char * (width - 2)
    centered = text.center(inner_width)

    print(f'+{line}+')
    print(f'| {centered} |')
    print(f'+{line}+')


def read_config_yaml(config_path):
    try:
        with open(config_path) as f:
            return yaml.load(f, Loader=yaml.FullLoader)
    except FileNotFoundError:
        raise FileNotFoundError(
            f'Config file not found at {config_path}. Please provide a valid path.'
            )
    except yaml.YAMLError as e:
        raise ValueError(
            f'Error parsing YAML file at {config_path}: {e}'
            )


def set_up_image_pipeline(cfg):
    pre_ops = []
    if cfg['to_rgb']:
        pre_ops.append(('to_rgb', {}))
    if cfg['crop_by_bb']:
        pre_ops.append(('crop_by_bb', {'crop_shape': cfg['crop_by_bb']}))

    ops = []
    if cfg['to_tensor']:
        ops.append(v2.ToImage())
        ops.append(v2.ToDtype(torch.float32, scale=True))

    resize = cfg['resize']
    if resize:
        if isinstance(resize, int):
            ops.append(v2.Resize((resize, resize)))
        elif isinstance(resize, Sequence) and len(resize) == 2:
            ops.append(v2.Resize((resize)))
        else:
            raise ValueError(
                f'Invalid resize value: {resize}. Must be int or Sequence of two ints.'
                )

    norm = cfg['normalize']
    if norm:
        if isinstance(norm, dict):
            mean = norm['mean']
            std = norm['std']
        elif isinstance(norm, str):
            if norm.lower() == 'imagenet':
                mean = [0.485, 0.456, 0.406]
                std = [0.229, 0.224, 0.225]
            else:
                stats = torch.load(norm)
                mean = stats['mean']
                std = stats['std']
        ops.append(v2.Normalize(mean=mean, std=std))

    image_pipeline = ImagePipeline(
        pre_ops=pre_ops,
        transform=v2.Compose(ops)
        )

    augment = cfg['augmentation']
    if not augment:
        augmented_image_pipeline = None
    else:
        ops_aug = list(ops)

        for entry in augment:
            name, params = next(iter(entry.items()))
            Op = getattr(v2, name, None)
            if Op is None:
                raise ValueError(f'Unknown transform: {name!r}')
            ops_aug.append(Op(**(params or {})))

        augmented_image_pipeline = ImagePipeline(
                    pre_ops=pre_ops,
                    transform=v2.Compose(ops_aug)
                    )

    return image_pipeline, augmented_image_pipeline

In [56]:
args_output_dir = "/cfs/earth/scratch/kraftjul/BA/output/test"
args_config_path = "/cfs/earth/scratch/kraftjul/BA/code/run/config_template.yaml"
args_dev_run = True

In [60]:
dir = Path(args_output_dir)
if dir.exists():
    shutil.rmtree(dir)

dir.mkdir(parents=True, exist_ok=True)

In [61]:
if args_dev_run:
    print_banner('!!!   Running in dev mode   !!!', width=80)

config_path = Path(args_config_path)
output_dir = Path(args_output_dir)
experiment_info_path = output_dir / 'experiment_info.yaml'

if not output_dir.exists():
    raise FileNotFoundError(
        f'Output directory {output_dir} does not exist. Please provide a valid path.'
    )

cfg = read_config_yaml(config_path)

try:
    shutil.copy2(config_path, experiment_info_path)
except Exception as e:
    raise RuntimeError(
        f"Failed to copy config file to {experiment_info_path}: {e}"
        )

# setting up image pipeline
image_pipeline, augmented_image_pipeline = set_up_image_pipeline(cfg['image_pipeline'])

# setting up datamodule config
label_key = 'test_labels' if args_dev_run else 'labels'
dataset_raw = cfg.get('dataset') or {}
paths = cfg['paths']
dataset_kwargs = {
    'path_labelfiles': paths[label_key],
    'path_to_dataset': paths['dataset'],
    'path_to_detector_output': paths['md_output'],
    **dataset_raw
    }

datamodule_raw = cfg.get('data_module') or {}
datamodule_cfg = {
    'dataset_cls': MammaliaDataImage,
    'image_pipeline': image_pipeline,
    'augmented_image_pipeline': augmented_image_pipeline,
    **datamodule_raw
    }

# setting up model config
model_cfg = cfg['model']

# setting up trainer config
trainer_raw = cfg.get('trainer') or {}

_not_dev_defaults = {
    'limit_train_batches': 1.0,
    'limit_val_batches': 1.0,
    'limit_test_batches': 1.0,
    'limit_predict_batches': 1.0,
    'max_epochs': -1,
    'log_every_n_steps': 10,
    }

if args_dev_run:
    dev_run_args = {
        'limit_train_batches': 1,
        'limit_val_batches': 1,
        'limit_test_batches': 1,
        'limit_predict_batches': 1,
        'max_epochs': 1,
        'log_every_n_steps': 1
        }
else:
    dev_run_args = (trainer_raw.get('not_dev') or {}).copy()
    for key, value in _not_dev_defaults.items():
        if key not in dev_run_args:
            dev_run_args[key] = value

trainer_kwargs = {
    **(trainer_raw.get('trainer_kwargs') or {}),
    **dev_run_args
    }

trainer_cfg = (trainer_raw.get('base_args') or {}).copy()
trainer_cfg['trainer_kwargs'] = trainer_kwargs

trainer_do_predict = trainer_raw['do_predict']

# setting up folds or cross-validation
cross_val = cfg['cross_val']['apply']
n_folds = cfg['cross_val']['n_folds']
test_fold = cfg['cross_val']['test_fold']

if cross_val:
    if args_dev_run:
        folds = range(2)
    else:
        folds = range(n_folds)
else:
    folds = [test_fold]

log_dir = output_dir / 'logs'

run_params = {}
if cross_val:
    run_params['folds'] = {}
    
first_pass = True


+------------------------------------------------------------------------------+
|                       !!!   Running in dev mode   !!!                        |
+------------------------------------------------------------------------------+


In [62]:
# running the experiment
for fold in folds:
    if cross_val:
        trainer_log_dir = log_dir / f'fold_{fold}'
        print_statement = f'Running cross-validation fold {fold+1}/{len(folds)}'
        
    else:
        trainer_log_dir = log_dir
        print_statement = f'Running Experiment with test fold = {test_fold}'

    print_banner(print_statement, width=80)
    trainer_log_dir.mkdir(parents=True)

    dm_cfg = datamodule_cfg.copy()
    dm_cfg['dataset_kwargs'] = dataset_kwargs.copy()
    datamodule = MammaliaDataModule(
                    n_folds=5,
                    test_fold=fold,
                    **dm_cfg,
                    )

    m_cfg = model_cfg.copy()
    model = LightningModelImage(
                    num_classes=datamodule.num_classes,
                    class_weights=datamodule.class_weights,
                    **m_cfg
                    )

    t_cfg = trainer_cfg.copy()
    trainer = MammaliaTrainer(
                    log_dir=trainer_log_dir,
                    **t_cfg
                    )
    
    trainer.fit(
        model=model,
        datamodule=datamodule
        )

    test_metrics = trainer.test(
        model=model,
        datamodule=datamodule
        )


    if trainer_do_predict:
        best_ckpt = trainer_log_dir / 'checkpoints' / 'best.ckpt'
        trainer.predict(
            model=model,
            datamodule=datamodule,
            ckpt_path=best_ckpt,
            return_predictions=False
            )
    
    fold_params = {}
    fold_params['test_metrics'] = test_metrics[0]
    fold_params['class_weights'] = datamodule.class_weights.tolist()

    if first_pass:
        dataset = datamodule.get_dataset('pred')
        df = dataset.get_ds_with_folds()
        df.to_csv(log_dir / 'dataset.csv', index=False)
        run_params['trainable_params'] = count_trainable_parameters(model)
        run_params['label_decoder'] = dataset.label_decoder
        run_params['num_classes'] = datamodule.num_classes
        del dataset, df  
        first_pass = False
    
    if cross_val:
        fold_params['test_fold'] = fold
        run_params['folds'][fold] = fold_params
    else:
        run_params.update(fold_params)

run_output = {'output': run_params}
with open(experiment_info_path, "a") as f:
    f.write("\n")
    yaml.dump(run_output, f, default_flow_style=False)

print_banner('Experiment completed!', width=80)

+------------------------------------------------------------------------------+
|                      Running cross-validation fold 1/2                       |
+------------------------------------------------------------------------------+


8 sequences had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953, 6000186]
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightni

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /cfs/earth/scratch/kraftjul/BA/output/test/logs/fold_0/checkpoints/best.ckpt


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                 0.796875
      test_bal_acc              0.19921875
        test_loss           1.2343661785125732
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Loaded model weights from the checkpoint at /cfs/earth/scratch/kraftjul/BA/output/test/logs/fold_0/checkpoints/best.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]

+------------------------------------------------------------------------------+
|                      Running cross-validation fold 2/2                       |
+------------------------------------------------------------------------------+


8 sequences had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953, 6000186]
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightni

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /cfs/earth/scratch/kraftjul/BA/output/test/logs/fold_1/checkpoints/best.ckpt


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                    1.0
      test_bal_acc                  1.0
        test_loss           0.8041755557060242
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Loaded model weights from the checkpoint at /cfs/earth/scratch/kraftjul/BA/output/test/logs/fold_1/checkpoints/best.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]

+------------------------------------------------------------------------------+
|                            Experiment completed!                             |
+------------------------------------------------------------------------------+


In [54]:
datamodule.class_weights


tensor([0.6070, 2.0583, 0.7733, 1.7436])

In [1]:
import random
import torch
import shutil

import numpy as np
from pathlib import Path
from PIL import Image

from torchvision.transforms import v2
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
import torch.nn.functional as F
from torchmetrics import Accuracy

from ba_dev.dataset import MammaliaDataSequence, MammaliaDataImage
from ba_dev.datamodule import MammaliaDataModule
from ba_dev.transform import ImagePipeline, BatchImagePipeline
from ba_dev.model import LightningModelImage
from ba_dev.trainer import MammaliaTrainer
from ba_dev.utils import load_path_yaml

paths = load_path_yaml('/cfs/earth/scratch/kraftjul/BA/data/path_config.yml')


### Running Tests

In [2]:
stats = torch.load(paths['feature_stats'])

image_pipeline = ImagePipeline(
        pre_ops=[
            ('to_rgb', {}),
            ('crop_by_bb', {'crop_shape': 1.0})
            ],
        transform=v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((224, 224)),
            v2.Normalize(
                mean=stats['mean'],
                std=stats['std']
                )
            ])
        )

dataset_kwargs = {
        'path_labelfiles': paths['test_labels'],
        'path_to_dataset': paths['dataset'],
        'path_to_detector_output': paths['md_output'],
        }

datamodule = MammaliaDataModule(
                dataset_cls=MammaliaDataImage,
                dataset_kwargs=dataset_kwargs,
                n_folds=5,
                test_fold=0,
                image_pipeline=image_pipeline,
                augmented_image_pipeline=None,
                batch_size=32,
                num_workers=1,
                pin_memory=True,
                )

model = LightningModelImage(
            num_classes=datamodule.num_classes,
            class_weights=datamodule.class_weights,
            backbone_name='efficientnet_b0',
            backbone_pretrained=True,
            backbone_weights='DEFAULT',
            optimizer_name='AdamW',
            optimizer_kwargs={
                'lr': 1e-3,
                'weight_decay': 1e-5,
                'amsgrad': False
                },
            scheduler_name='CosineAnnealingLR',
            scheduler_kwargs={'T_max': 5},
            )

log_dir = Path('/cfs/earth/scratch/kraftjul/BA/output/test')
if log_dir.exists():
    shutil.rmtree(log_dir)
log_dir.mkdir(parents=True, exist_ok=True)

trainer = MammaliaTrainer(
            log_dir=log_dir,
            pred_writer_log_keys=['class_id', 'set', 'pred_id', 'probs'],
            pred_writer_prob_precision=4,
            accelerator='cpu',
            patience=5,
            trainer_kwargs={
                'log_every_n_steps': 1,
                'max_epochs': 1,
                }
            )


8 sequences had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953, 6000186]
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [3]:
trainer.fit(model, datamodule=datamodule)

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory /cfs/earth/scratch/kraftjul/BA/output/test/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | backbone      | EfficientNet     | 4.0 M  | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.051    Total estimated model params size (MB)
347       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [5]:
testmetrics = trainer.test(model, datamodule=datamodule)

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9027237296104431
      test_bal_acc          0.5089578628540039
        test_loss           0.3492245376110077
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [8]:
testmetrics

[{'test_loss': 0.3492245376110077,
  'test_acc': 0.9027237296104431,
  'test_bal_acc': 0.5089578628540039}]

In [3]:
trainer.fit(model, datamodule=datamodule)

trainer.test(model, datamodule=datamodule)

best_ckpt = trainer.checkpoint_callback.best_model_path
trainer.predict(
    model,
    datamodule=datamodule,
    ckpt_path=best_ckpt,
    return_predictions=False  # our PredictionWriter will write to CSV
)

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory /cfs/earth/scratch/kraftjul/BA/output/test/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | backbone      | EfficientNet     | 4.0 M  | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.051    Total estimated model params size (MB)
347       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[W503 22:56:28.291574881 NNPACK.cpp:62] Could not initialize NNPACK! Reason: Unsupported hardware.
Restoring states from the checkpoint path at /cfs/earth/scratch/kraftjul/BA/output/test/checkpoints/best.ckpt


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7898832559585571
      test_bal_acc          0.6755728721618652
        test_loss           0.9037746787071228
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Loaded model weights from the checkpoint at /cfs/earth/scratch/kraftjul/BA/output/test/checkpoints/best.ckpt
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [24]:
# run predict and capture the return value
predictions = trainer.predict(
    model,
    datamodule=datamodule,
    ckpt_path='best',            # or best_ckpt = trainer.checkpoint_callback.best_model_path
    return_predictions=True      # make sure you return, not False
)

# `predictions` is now a list of per‐batch outputs
# (each entry is whatever your `predict_step` returned,
# e.g. a dict with keys: 'class_id','bbox','conf','seq_id','set','file_path','preds','probs')

# inspect the first batch:
first_batch = predictions[0]


Restoring states from the checkpoint path at /cfs/earth/scratch/kraftjul/BA/output/test/checkpoints/best.ckpt
Loaded model weights from the checkpoint at /cfs/earth/scratch/kraftjul/BA/output/test/checkpoints/best.ckpt
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [21]:
first_batch

{'class_id': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'bbox': [tensor([0.0000e+00, 4.0000e-04, 6.3900e-02, 0.0000e+00, 1.1710e-01, 4.0000e-04,
          4.0000e-04, 0.0000e+00, 9.2200e-02, 4.4580e-01, 2.1900e-02, 1.3570e-01,
          2.8220e-01, 1.0440e-01, 1.3470e-01, 1.3570e-01, 1.3810e-01, 1.3370e-01,
          4.0000e-04, 2.1330e-01, 4.0000e-04, 1.4000e-03, 2.5000e-01, 4.0000e-04,
          4.0000e-04, 0.0000e+00, 1.2150e-01, 1.5960e-01, 1.4400e-01, 4.5400e-02,
          9.8100e-02, 4.0180e-01, 1.5180e-01, 6.1470e-01, 2.1480e-01, 2.3090e-01,
          2.1330e-01, 5.2920e-01, 5.3660e-01, 2.1920e-01, 1.2500e-01, 6.9720e-01,
          1.3470e-01, 9.5700e-02, 5.3270e-01, 5.3850e-01, 5.3900e-01, 5.3220e-01,
          5.3220e-01, 5.3220e-01, 5.4730e-01, 6.7720e-01, 5.3220e-01, 5.3360e-01,
          5.1120e-01,

In [29]:
reconstructed_batch = []
batch_size = len(first_batch['class_id'])

for i in range(batch_size):
    item_dict = {
        'class_id': first_batch['class_id'][i].item(),
        'bbox': [tensor[i].item() for tensor in first_batch['bbox']],
        'conf': first_batch['conf'][i].item(),
        'seq_id': first_batch['seq_id'][i].item(),
        'set': first_batch['set'][i],
        'file': first_batch['file'][i],
        'preds': first_batch['preds'][i].item(),
        'probs': first_batch['probs'][i].tolist()
    }
    reconstructed_batch.append(item_dict)

In [34]:
print(batch_size)
len(reconstructed_batch)

64


64

In [35]:
import torch

def reconstruct_batch(batch):
    batch_size = len(batch['class_id'])
    reconstructed = []

    for i in range(batch_size):
        bbox = [tensor[i].item() for tensor in batch['bbox']]
        item = {
            'class_id': batch['class_id'][i].item(),
            'bbox': bbox,
            'conf': batch['conf'][i].item(),
            'seq_id': batch['seq_id'][i].item(),
            'set': batch['set'][i],
            'file': batch['file'][i],
            'pred': batch['preds'][i].item(),
            'probs': batch['probs'][i].tolist()
        }
        reconstructed.append(item)

    return reconstructed

# Usage
reconstructed_batch = reconstruct_batch(first_batch)

In [37]:
reconstructed_batch[0]

{'class_id': 0,
 'bbox': [0.0, 0.4687, 0.4111, 0.3671],
 'conf': 0.983,
 'seq_id': 4007156,
 'set': 'train',
 'file': 'IMG_6165.JPG',
 'pred': 0,
 'probs': [0.983549952507019,
  0.00028202414978295565,
  0.015514460392296314,
  0.0006537171429954469]}

In [15]:
total = 0.

list = samples[0]['probs'].tolist()
for i in list:
    total += i
print(total)



1.0000001541920938


In [10]:
keys = ['class_id', 'bbox', 'conf', 'seq_id', 'set', 'file', 'preds', 'probs']

for key in keys:
    item = first_batch[key][0]

    print(f'{key}: has type:{type(item)}')

    print(item)

class_id: has type:<class 'torch.Tensor'>
tensor(0)
bbox: has type:<class 'torch.Tensor'>
tensor([0.0000e+00, 4.0000e-04, 6.3900e-02, 0.0000e+00, 1.1710e-01, 4.0000e-04,
        4.0000e-04, 0.0000e+00, 9.2200e-02, 4.4580e-01, 2.1900e-02, 1.3570e-01,
        2.8220e-01, 1.0440e-01, 1.3470e-01, 1.3570e-01, 1.3810e-01, 1.3370e-01,
        4.0000e-04, 2.1330e-01, 4.0000e-04, 1.4000e-03, 2.5000e-01, 4.0000e-04,
        4.0000e-04, 0.0000e+00, 1.2150e-01, 1.5960e-01, 1.4400e-01, 4.5400e-02,
        9.8100e-02, 4.0180e-01, 1.5180e-01, 6.1470e-01, 2.1480e-01, 2.3090e-01,
        2.1330e-01, 5.2920e-01, 5.3660e-01, 2.1920e-01, 1.2500e-01, 6.9720e-01,
        1.3470e-01, 9.5700e-02, 5.3270e-01, 5.3850e-01, 5.3900e-01, 5.3220e-01,
        5.3220e-01, 5.3220e-01, 5.4730e-01, 6.7720e-01, 5.3220e-01, 5.3360e-01,
        5.1120e-01, 5.5320e-01, 5.3560e-01, 5.3220e-01, 5.2680e-01, 5.3120e-01,
        5.4240e-01, 5.4000e-01, 5.3120e-01, 6.7040e-01], dtype=torch.float64)
conf: has type:<class 'torch.Ten

In [23]:
logits = model(batch['sample'])        # shape [B, num_classes]
probs  = F.softmax(logits, dim=1)
preds = torch.argmax(probs, dim=1)    # shape [B]

In [24]:
preds

tensor([2, 0, 3, 0, 3, 3, 3, 2, 2, 2, 1, 2, 3, 2, 0, 1, 2, 0, 0, 0, 1, 2, 3, 2,
        0, 0, 3, 2, 2, 0, 0, 0])

In [26]:
batch['class_id']

tensor([0, 2, 0, 0, 3, 0, 1, 0, 3, 3, 2, 0, 2, 0, 2, 2, 2, 2, 0, 3, 2, 3, 0, 1,
        0, 0, 2, 0, 2, 0, 0, 2])

In [27]:
correct = (preds == batch['class_id']).sum()
total = batch['class_id'].numel()
correct / total

tensor(0.2812)

In [32]:
batch_acc = train_acc(logits, batch['class_id'])

In [33]:
batch_acc

tensor(0.2812)

In [11]:
output = model.training_step(batch, batch_idx=0)
print(output)            # should be a tensor == the loss
print(output.item())     # the scalar loss valueprobs

tensor(1.3109, grad_fn=<NllLossBackward0>)
1.3109310865402222


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [9]:
# 2) Check configure_optimizers output
cfg = model.configure_optimizers()
print(cfg)
# You should see a dict of the form {'optimizer': <AdamW>, 'lr_scheduler': {…}}

assert isinstance(cfg, dict)
assert 'optimizer' in cfg
assert 'lr_scheduler' in cfg
print("✅ Scheduler hook-up looks good")

# 3) Run a single training+validation batch through Lightning
#    to make sure the scheduler.step() call doesn’t error out.
trainer = Trainer(
    fast_dev_run=True,
    logger=False,
    enable_checkpointing=False,
    # devices=1, accelerator='gpu'   # add if you have a GPU
)
trainer.fit(model, datamodule)
print("✅ fast_dev_run with scheduler completed without error")

/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | backbone  | ResNet           | 23.5 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.065    Total estimated model params size (MB)
152       Modu

{'optimizer': AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 1e-05
), 'lr_scheduler': {'scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x15543264e740>, 'monitor': 'val_loss', 'interval': 'epoch'}}
✅ Scheduler hook-up looks good


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


✅ fast_dev_run with scheduler completed without error


In [10]:
# 2) Grab one batch from train
batch = next(iter(datamodule.train_dataloader()))
x, y = batch['sample'], batch['class_id']

print("Sample tensor shape:", x.shape)   # expect (32, 3, 224, 224)
print("Label tensor shape: ", y.shape)   # expect (32,)

# 3) Forward pass
logits = model(x)
print("Logits shape:       ", logits.shape)  # expect (32, num_classes)

# 4) Optimizer/scheduler check
opt_cfg = model.configure_optimizers()
print("configure_optimizers() returned:", opt_cfg)

# 5) One‐step Lightning run
trainer = Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
trainer.fit(model, datamodule)

Sample tensor shape: torch.Size([32, 3, 224, 224])
Label tensor shape:  torch.Size([32])


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | backbone  | ResNet           | 23.5 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.065    Total estimated model params size (MB)
152       Modu

Logits shape:        torch.Size([32, 4])
configure_optimizers() returned: {'optimizer': AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 1e-05
), 'lr_scheduler': {'scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x1554350abc40>, 'monitor': 'val_loss', 'interval': 'epoch'}}


/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/cfs/earth/scratch/kraftjul/.conda/envs/mega/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [11]:
batch = next(iter(datamodule.train_dataloader()))
x, y_true = batch['sample'], batch['class_id']

logits = model(x)                        # (B, num_classes)
probs  = torch.softmax(logits, dim=1)    # (B, num_classes)
preds  = torch.argmax(probs, dim=1)      # (B,)

# map indices → names
decoder = datamodule.get_label_decoder()
pred_names = [decoder[int(i)] for i in preds]

print("Predicted classes for this batch:")
for i, name in enumerate(pred_names):
    print(f"  sample {i:>2d}: {name} (true: {decoder[int(y_true[i])]})")

Predicted classes for this batch:
  sample  0: apodemus_sp (true: apodemus_sp)
  sample  1: apodemus_sp (true: apodemus_sp)
  sample  2: soricidae (true: soricidae)
  sample  3: cricetidae (true: cricetidae)
  sample  4: cricetidae (true: cricetidae)
  sample  5: mustela_erminea (true: mustela_erminea)
  sample  6: apodemus_sp (true: apodemus_sp)
  sample  7: apodemus_sp (true: apodemus_sp)
  sample  8: apodemus_sp (true: apodemus_sp)
  sample  9: apodemus_sp (true: apodemus_sp)
  sample 10: soricidae (true: soricidae)
  sample 11: cricetidae (true: cricetidae)
  sample 12: apodemus_sp (true: mustela_erminea)
  sample 13: cricetidae (true: cricetidae)
  sample 14: cricetidae (true: cricetidae)
  sample 15: apodemus_sp (true: apodemus_sp)
  sample 16: cricetidae (true: cricetidae)
  sample 17: soricidae (true: soricidae)
  sample 18: mustela_erminea (true: mustela_erminea)
  sample 19: apodemus_sp (true: apodemus_sp)
  sample 20: cricetidae (true: cricetidae)
  sample 21: apodemus_sp (t