In [8]:
import os
import math
import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary
from torchvision import transforms
from torchinfo import summary
import timm

import matplotlib.pyplot as plt

from data.datamodule import Animal_DataModule

### Loading Configuration

In the following steps, we will load the configuration settings using the `load_configuration` function. The configuration is stored in the `config` variable which will be used throughout the script.

In [2]:
# Import the project's configuration loader.
from config.load_configuration import load_configuration
# Read the settings once; later cells index into `config`
# (e.g. config['seed'], config['path_to_data'], config['max_epochs']).
config = load_configuration()

PC Name: DESKTOP-LUKAS
Loaded configuration from config/config_default.yaml


### Setting Seeds for Reproducibility

To ensure comparable and reproducible results, we set the random seed using the `seed_everything` function from PyTorch Lightning. This helps in achieving consistent behavior across multiple runs of the notebook.

In [9]:
# Seed the random number generators through PyTorch Lightning so that
# repeated runs of the notebook are comparable.
pl.seed_everything(config['seed'])

# Disable oneDNN optimizations for reproducibility.
# NOTE(review): TF_ENABLE_ONEDNN_OPTS looks like a TensorFlow switch — confirm
# it has any effect on this PyTorch-only run.
os.environ.update({"TF_ENABLE_ONEDNN_OPTS": "0"})

Seed set to 42


### Checking for GPU Devices

In this step, we check for the availability of GPU devices and print the device currently being used by PyTorch. This ensures that the computations are performed on the most efficient hardware available.

In [10]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print(f'Torch Version:  {torch.__version__}')
print(f'Using device:  {device}')

if device.type == 'cuda':
    bytes_per_gib = 1024 ** 3  # divisor used to report memory in GB below

    print(f'Cuda Version:  {torch.version.cuda}')
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_gb = round(torch.cuda.memory_allocated(0) / bytes_per_gib, 1)
    print(f'Allocated: {allocated_gb} GB')
    reserved_gb = round(torch.cuda.memory_reserved(0) / bytes_per_gib, 1)
    print(f'Cached:    {reserved_gb} GB')
    # Only applied on the GPU path, exactly as before.
    torch.set_float32_matmul_precision('high')

Torch Version:  2.7.0+cpu
Using device:  cpu


### Defining Transformations and Instantiating DataModule

In this step, we will define the necessary data transformations and initialize the `Animal_DataModule` with the provided configuration.

In [None]:
# Preprocessing pipeline for the pretrained backbone.
# `transforms` comes from the import cell at the top of the notebook.
IMAGE_SIZE = (300, 300)                 # resize target, matches the EfficientNet input size used below
IMAGENET_MEAN = [0.485, 0.456, 0.406]   # standard ImageNet normalization
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# TODO: `transform` is defined but not handed to the data module.
# NOTE(review): confirm whether Animal_DataModule accepts a transform argument
# and wire it in; its signature is not visible from this notebook.

# The training cell passes `dm` to `trainer.fit`, so the data module must be
# bound to that name; previously only `dl` existed, which raised a NameError.
dm = Animal_DataModule(config['path_to_data'])
dl = dm  # backward-compatible alias for the earlier variable name

### Creating the Model

In this step, we load a pretrained EfficientNet-B3 from `timm`, replace its classifier head for our task, and print a summary of the model using the `summary` utility from `torchinfo`. This provides an overview of the model's layers, parameters, and structure.

In [None]:
# Load the pretrained backbone.
model = timm.create_model(
    'efficientnet_b3',      # Hardcoded for now
    pretrained=True,
)

# Freeze every pretrained parameter first ...
for param in model.parameters():
    param.requires_grad = False

# ... then attach a fresh classifier head. Doing this after the freeze loop
# keeps the new layer out of it.
num_classes = 1             # Hardcoded for now, Dwarf Rabbit OK/NOK output
model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)

# Explicitly mark the head as trainable. Assigning `module.requires_grad = True`
# only sets a plain attribute on the module and leaves its parameters frozen;
# `requires_grad_()` updates the parameters themselves.
model.classifier.requires_grad_(True)

# Sanity check: the classifier head is the only trainable part of the model.
trainable_names = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable_names, "no trainable parameters - the classifier head is still frozen"
assert all(name.startswith('classifier.') for name in trainable_names), trainable_names

# Print model summary
summary(model, input_size=(1, 3, 300, 300), depth=2)


Layer (type:depth-idx)                        Output Shape              Param #
EfficientNet                                  [1, 2]                    --
├─Conv2d: 1-1                                 [1, 40, 150, 150]         (1,080)
├─BatchNormAct2d: 1-2                         [1, 40, 150, 150]         80
│    └─Identity: 2-1                          [1, 40, 150, 150]         --
│    └─SiLU: 2-2                              [1, 40, 150, 150]         --
├─Sequential: 1-3                             [1, 384, 10, 10]          --
│    └─Sequential: 2-3                        [1, 24, 150, 150]         (3,504)
│    └─Sequential: 2-4                        [1, 32, 75, 75]           (48,118)
│    └─Sequential: 2-5                        [1, 48, 38, 38]           (110,912)
│    └─Sequential: 2-6                        [1, 96, 19, 19]           (638,700)
│    └─Sequential: 2-7                        [1, 136, 19, 19]          (1,387,760)
│    └─Sequential: 2-8                        [1, 232, 1

In [None]:
# Run identifiers and hyperparameters collected for experiment tracking.
# NOTE(review): this dict is built but the logger below receives the full
# `config` instead — confirm which of the two is meant to be logged.
wandb_config = dict(
    project_name=config['wandb_project_name'],
    experiment_name=config['wandb_experiment_name'],
    batch_size=config['batch_size'],
    max_epochs=config['max_epochs'],
    learning_rate=config['learning_rate'],
)

# Weights & Biases logger that is handed to the Lightning Trainer.
wandb_logger = WandbLogger(
    config=config,
    name=config['wandb_experiment_name'],
    project=config['wandb_project_name'],
    # save_dir=os.path.join(config['path_to_data'], 'logs')
)

In [None]:
# Initialize Trainer with wandb logger, using early stopping callback
# (https://lightning.ai/docs/pytorch/stable/common/early_stopping.html)
# NOTE(review): early stopping monitors 'val_loss', so the module being trained
# has to log a metric under exactly that name — confirm.
trainer = Trainer(
    max_epochs=config['max_epochs'],
    default_root_dir='model/checkpoint/',
    accelerator="auto",
    devices="auto",
    strategy="auto",
    callbacks=[EarlyStopping(monitor='val_loss', patience=5, mode='min')],
    logger=wandb_logger,
)

# NOTE(review): `model` is created with timm.create_model in the model cell and
# is not wrapped in a pl.LightningModule anywhere in this notebook; Trainer.fit
# expects a LightningModule (training_step / configure_optimizers) — confirm
# and add a wrapper before running.

# Training of the model.
# `dl` is the Animal_DataModule instance created in the data cell; the
# previously used name `dm` was never defined and raised a NameError.
try:
    trainer.fit(model=model, datamodule=dl)
finally:
    # Always close the wandb run, even when training fails or is interrupted,
    # so no run is left open in the kernel.
    wandb.finish()