In [1]:
import os
import math
import datetime as dt
from datetime import datetime

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary
from torchvision.transforms import v2
import timm
import matplotlib.pyplot as plt

from data.datamodule import MultiClassImageDataModule
from models.model_facedetection import TransferLearningModuleMulticlass

### Loading Configuration

In the following step, we load the configuration settings using the `load_configuration` function. The configuration is stored in the `config` variable, which is used throughout the notebook.

In [2]:
# Load the project settings once; later cells read individual entries via config[...]
# (e.g. config['seed'], config['batch_size'], config['name_list']).
from config.load_configuration import load_configuration
# Note: this rebinds the name `config` to the loaded settings object for the rest of the notebook.
config = load_configuration()

PC Name: DESKTOP-MIKA
Loaded configuration from config/config_mika.yaml


### Setting Seeds for Reproducibility

To ensure comparable and reproducible results, we set the random seed using the `seed_everything` function from PyTorch Lightning. This helps in achieving consistent behavior across multiple runs of the notebook.

In [3]:
# Seed every RNG that Lightning manages with the configured value so that
# repeated runs of the notebook start from the same random state.
seed = config['seed']
pl.seed_everything(seed)

# Disable oneDNN optimizations for reproducibility.
os.environ.update({"TF_ENABLE_ONEDNN_OPTS": "0"})

Seed set to 42


### Checking for GPU Devices

In this step, we check for the availability of GPU devices and print the device currently being used by PyTorch. This ensures that the computations are performed on the most efficient hardware available.

In [5]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')

print('Torch Version: ', torch.__version__)
print('Using device: ', device)

if device.type == 'cuda':
    bytes_per_gib = 1024**3  # memory figures below are reported in GB, rounded to one decimal
    allocated_gb = round(torch.cuda.memory_allocated(0) / bytes_per_gib, 1)
    reserved_gb = round(torch.cuda.memory_reserved(0) / bytes_per_gib, 1)

    print('Cuda Version: ', torch.version.cuda)
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', allocated_gb, 'GB')
    print('Cached:   ', reserved_gb, 'GB')

    # Only adjusted on GPU runs: sets float32 matmul precision to 'high'.
    torch.set_float32_matmul_precision('high')

Torch Version:  2.7.1+cpu
Using device:  cpu


### Defining Transformations and Instantiating DataModule

In this step, we define the necessary data transformations and initialize the `MultiClassImageDataModule` with the provided configuration.

In [None]:
# Preprocessing applied to every image before it reaches the network.
# v2.ToTensor() is deprecated in torchvision's v2 API; the documented replacement
# is ToImage() followed by ToDtype(torch.float32, scale=True), which yields the
# same float32 tensor scaled to [0, 1].
transform = v2.Compose([
    v2.Resize((300, 300)),  # Resize every image to a fixed 300x300 model input
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard ImageNet normalization
])

# Build the data module from the configured data directory and class names.
dm = MultiClassImageDataModule(
    data_dir=config['path_to_bunnie_data_aug'],
    name_list=config['name_list'],
    transform=transform,
    batch_size=config['batch_size'],
    num_workers=2,
    persistent_workers=True,
)
dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
# test_loader = dm.test_dataloader()
test_loader = val_loader  # Use validation data (from Janik) for testing in this example

# Quick sanity check of the dataset sizes and of one sample per split.
print('Train dataset size:', len(dm.train_dataset))
print('Validation dataset size:', len(dm.val_dataset))
print('Test dataset size:', len(dm.test_dataset))
print('Example train data shape:', dm.train_dataset[0][0].shape)
print('Example train label:', dm.train_dataset[0][1])
print('Example val data shape:', dm.val_dataset[0][0].shape)
print('Example val label:', dm.val_dataset[0][1])

### Creating the Model

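In this step, we define the model architecture and print its summary using the `ModelSummary` utility from PyTorch Lightning. This provides an overview of the model's layers, parameters, and structure. The same cell also sets up the Weights & Biases logger and the `Trainer` (with early stopping) used for training.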

In [None]:
# NOTE: `import datetime` rebinds the name `datetime` to the module for all
# following cells (the import cell bound it to the class).
import datetime
from models.model_facedetection import TL_EfficientNetB4, TL_ConvNextV2

# --- Model -----------------------------------------------------------------
# One output per entry in the configured name list. Swap the two lines below
# to train the EfficientNet-B4 variant instead.
num_classes = len(config['name_list'])
# model = TL_EfficientNetB4(num_classes=num_classes)
model = TL_ConvNextV2(num_classes=num_classes)
print(ModelSummary(model, max_depth=-1))

# --- Experiment logging ----------------------------------------------------
# The run name carries the model class and a start timestamp so that runs
# are easy to tell apart.
now = datetime.datetime.now()
current_time = now.strftime("%Y-%m-%d_%H-%M-%S")
model_class_name = type(model).__name__
run_name = f"{config['wandb_experiment_name']}_{model_class_name}_{current_time}"

run_config = {
    'model': model_class_name,
    # NOTE(review): label says "binary" while a multiclass data module is used - confirm.
    'dataset': 'DwarfRabbits-binary',
    'batch_size': config['batch_size'],
    'max_epochs': config['max_epochs'],
    'learning_rate': config['learning_rate']
}
wandb_logger = WandbLogger(
    project=config['wandb_project_name'],
    name=run_name,
    config=run_config
)

# --- Trainer ---------------------------------------------------------------
# Early stopping (https://lightning.ai/docs/pytorch/stable/common/early_stopping.html):
# stop once 'val_loss' has not decreased for 5 consecutive checks.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode='min')
trainer = Trainer(
    max_epochs=config['max_epochs'],
    default_root_dir='model/checkpoint/',
    accelerator="auto",
    devices="auto",
    strategy="auto",
    callbacks=[early_stopping],
    logger=wandb_logger
)

## Train the model

In [None]:
# Training of the model
trainer.fit(model=model, datamodule=dm)

# Finish wandb
wandb.finish()

# Save the trained model checkpoint
# `dt` is the datetime *module* alias from the import cell. The bare name
# `datetime` is deliberately not used: it is the class after the import cell
# and the module after the model-setup cell, so `datetime.datetime.now()`
# could raise AttributeError here - after training has already finished -
# depending on which cell was executed last.
timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = os.path.join(config['path_to_models'], f"{model.model_name}_model_{timestamp}.ckpt")
trainer.save_checkpoint(save_path)
print(f"Model saved to: {save_path}")

# Remember the checkpoint path so the evaluation cell reloads exactly this model.
config['path_to_model'] = save_path

## Load pretrained model and perform evaluation

In [None]:
# Reload the checkpoint with the same LightningModule class that produced it.
# The checkpoint written by the training cell comes from the architecture chosen
# in the model cell, so it must not be loaded into a hard-coded EfficientNet-B4
# backbone: the weights would not match. `type(model)` keeps both cells in sync.
num_classes = len(config['name_list'])
model = type(model).load_from_checkpoint(
    config['path_to_model'],
    num_classes=num_classes,
)

# Put the model in evaluation mode; the Trainer takes care of device placement.
model.eval()
trainer.test(model=model, dataloaders=test_loader)

In [None]:
# Show information about the test_loader
# DataLoader does not store the `shuffle` constructor argument, so looking it up
# as an attribute always fell back to 'N/A'. Shuffling is inferred from the
# sampler instead (True only when the built-in RandomSampler is used).
is_shuffled = isinstance(test_loader.sampler, torch.utils.data.RandomSampler)

print("Test DataLoader Info:")
print(f"Batch size: {test_loader.batch_size}")
print(f"Number of batches: {len(test_loader)}")
print(f"Shuffle: {is_shuffled}")
print(f"Number of workers: {test_loader.num_workers}")
print(f"Dataset: {test_loader.dataset}")
print(f"Sampler: {test_loader.sampler}")
print(f"Drop last: {test_loader.drop_last}")
print(f"Pin memory: {test_loader.pin_memory}")
print(f"Persistent workers: {getattr(test_loader, 'persistent_workers', 'N/A')}")
print(f"Prefetch factor: {getattr(test_loader, 'prefetch_factor', 'N/A')}")
print(f"Timeout: {test_loader.timeout}")

# Show a batch of test images, their shapes, and model predictions
images, labels = next(iter(test_loader))
print("Test batch images shape:", images.shape)
print("Test batch labels shape:", labels.shape)

# Move images to the correct device
images = images.to(device)
model = model.to(device)
model.eval()

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(images)
    # Predicted class = index of the largest score along dim 1
    # (assumes outputs are shaped (batch, num_classes) - TODO confirm).
    preds = torch.argmax(outputs, dim=1)

print("Predicted labels:", preds.cpu().numpy())
print("True labels:     ", labels.cpu().numpy())