In [118]:
import numpy
import matplotlib.pyplot

import yaml 

import wandb

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

import torchmetrics

import lightning as L
from pytorch_lightning.loggers import WandbLogger

In [119]:
# test
backbone_weights = torchvision.models.ResNet34_Weights.DEFAULT
preprocessor = backbone_weights.transforms()
backbone = torchvision.models.resnet34(weights=backbone_weights)

In [153]:
# set parameters and write to yaml file
parameters = {
    'Training': {
        'limit_train_batches': 0.10,
        'max_epochs': 10,
        'batch_size': 32,
        'early_stopping_patience': 2,
        'log_every_n_steps': 25,
    },
    'Optimizer': {
        'optimizer': 'adam',
        'weight_decay': 0.9,
        'lr': 6e-6,  
    },
    'Loss_Function': {
        'loss_function': 'qwk',
    },
    'Model': {
        'backbone': 'resnet34',
        'features': 512,
        'outputs': 3,
   },
}

parameters

{'Training': {'limit_train_batches': 1.0,
  'max_epochs': 10,
  'batch_size': 32,
  'early_stopping_patience': 2,
  'log_every_n_steps': 50},
 'Optimizer': {'optimizer': 'adam', 'weight_decay': 0.9, 'lr': 6e-06},
 'Loss_Function': {'loss_function': 'qwk'},
 'Model': {'backbone': 'resnet34', 'features': 512, 'outputs': 3}}

In [154]:
with open("parameters.yaml", 'w') as yaml_file:
    data = yaml.dump(parameters, yaml_file)

In [123]:
with open("parameters.yaml", "r") as yaml_file:
    parameters = yaml.load(yaml_file, Loader=yaml.FullLoader)
print(parameters)

{'Loss_Function': {'loss_function': 'qwk'}, 'Model': {'backbone': 'resnet34', 'features': 512, 'outputs': 3}, 'Optimizer': {'lr': 6e-06, 'optimizer': 'adam', 'weight_decay': 0.9}, 'Training': {'batch_size': 32, 'early_stopping_patience': 2, 'limit_train_batches': 1.0, 'max_epochs': 10}}


In [144]:
# fully-supervised fine-tuning
class Backbone(L.LightningModule):
    
    def __init__(self, n_classes, user_parameters):
        super().__init__()
        
        self.n_classes = n_classes
        self.user_parameters = user_parameters
        
        self.backbone_weights = torchvision.models.ResNet34_Weights.DEFAULT
        self.preprocessor = self.backbone_weights.transforms()
        self.backbone = torchvision.models.resnet34(weights=self.backbone_weights)
        self.n_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(self.n_features, self.n_classes)
        
        match self.user_parameters['Loss_Function']['loss_function']:
            case 'cross_entropy':
                self.loss_function = F.cross_entropy
            case 'qwk':
                from WeightedKappaLoss import WeightedKappaLoss
                self.loss_function = WeightedKappaLoss(self.n_classes, mode='quadratic')
            case _:
                self.loss_function = F.cross_entropy  # defaults to cross entropy

        self.save_hyperparameters()  # wandb
        
    def forward(self, x):
        x_processed = self.preprocessor(x)
        return self.backbone(x_processed)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.backbone.parameters(),
            lr=self.user_parameters['Optimizer']['lr'],
            weight_decay=self.user_parameters['Optimizer']['weight_decay'],
        )
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        probas = F.softmax(logits, dim=1)
        y_hat = probas.argmax(dim=1)
        
        accuracy = torchmetrics.functional.accuracy(
            y_hat, y, task='multiclass', num_classes=self.n_classes
        )
        
        qwk = torchmetrics.functional.cohen_kappa(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
            weights='quadratic'
        )   
        
        recall = torchmetrics.functional.recall(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
        )
        
        self.log("train/loss", loss)  # wandb
        self.log("train/accuracy", accuracy)  # wandb
        self.log("train/recall", recall)  # wandb
        self.log("train/qwk", qwk)  # wandb
        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        probas = F.softmax(logits, dim=1)
        y_hat = probas.argmax(dim=1)
        
        accuracy = torchmetrics.functional.accuracy(
            y_hat, y, task='multiclass', num_classes=self.n_classes
        )
        
        qwk = torchmetrics.functional.cohen_kappa(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
            weights='quadratic'
        )
        
        recall = torchmetrics.functional.recall(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
        )
        
        self.log("val/loss", loss)  # wandb
        self.log("val/accuracy", accuracy)  # wandb
        self.log("val/recall", recall)  # wandb
        self.log("val/qwk", qwk)  # wandb

    
        

In [145]:
wandb_logger = WandbLogger(
    project='debug-runs',
    save_dir='wandb-outputs',
    config=parameters,
)

In [146]:
fake_train_data = torchvision.datasets.FakeData(
    size=1000,
    image_size=(3, 512, 512),
    num_classes=3,
    transform=torchvision.transforms.ToTensor(),
)

fake_val_data = torchvision.datasets.FakeData(
    size=1000,
    image_size=(3, 512, 512),
    num_classes=3,
    transform=torchvision.transforms.ToTensor()
)

fake_train_data

Dataset FakeData
    Number of datapoints: 1000
    StandardTransform
Transform: ToTensor()

In [147]:
fake_train_dataloader = torch.utils.data.DataLoader(
    batch_size=parameters['Training']['batch_size'],
    dataset=fake_train_data,
    shuffle=True,
    num_workers=4,
)

fake_val_dataloader = torch.utils.data.DataLoader(
    batch_size=parameters['Training']['batch_size'],
    dataset=fake_val_data,
    num_workers=4,
)

In [149]:
callback_model_checkpoint = L.pytorch.callbacks.ModelCheckpoint(
    dirpath='checkpoints',
    filename='{epoch}',
    monitor='val/loss',
    save_last=True,
    save_top_k=1,
)

callback_early_stopping = L.pytorch.callbacks.EarlyStopping(
    monitor='val/loss',
    patience=parameters['Training']['early_stopping_patience'],
)


callbacks = [
    callback_model_checkpoint,
    callback_early_stopping,
]

In [150]:
model = Backbone(n_classes=3, user_parameters=parameters)
trainer = L.Trainer(
    limit_train_batches=parameters['Training']['limit_train_batches'], 
    max_epochs=parameters['Training']['max_epochs'],
    logger=wandb_logger,
    callbacks=callbacks,
    log_every_n_steps=parameters['Training']['log_every_n_steps'],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..


In [151]:
for batch in fake_train_dataloader:
    x, y = batch[0], batch[1]
    print(x.shape, y.shape)
    break

torch.Size([32, 3, 512, 512]) torch.Size([32])


In [152]:
trainer.fit(model, fake_train_dataloader, fake_val_dataloader)
wandb.finish()

/home/felipe/Projects/multiple-instance-learning/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/felipe/Projects/multiple-instance-learning/src/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                | Params | Mode 
--------------------------------------------------------------
0 | preprocessor  | ImageClassification | 0      | train
1 | backbone      | ResNet              | 21.3 M | train
2 | loss_function | WeightedKappaLoss   | 0      | train
--------------------------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.145    Total estimated model params size (MB)
118       Modules in train mode
0         Modules in eval mode


                                                                           

/home/felipe/Projects/multiple-instance-learning/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 32/32 [00:10<00:00,  3.15it/s, v_num=lbfe]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/32 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/32 [00:00<?, ?it/s][A
Validation DataLoader 0:   3%|▎         | 1/32 [00:00<00:02, 10.91it/s][A
Validation DataLoader 0:   6%|▋         | 2/32 [00:00<00:03,  8.90it/s][A
Validation DataLoader 0:   9%|▉         | 3/32 [00:00<00:03,  8.34it/s][A
Validation DataLoader 0:  12%|█▎        | 4/32 [00:00<00:03,  8.28it/s][A
Validation DataLoader 0:  16%|█▌        | 5/32 [00:00<00:03,  8.13it/s][A
Validation DataLoader 0:  19%|█▉        | 6/32 [00:00<00:03,  8.16it/s][A
Validation DataLoader 0:  22%|██▏       | 7/32 [00:00<00:03,  8.05it/s][A
Validation DataLoader 0:  25%|██▌       | 8/32 [00:00<00:02,  8.06it/s][A
Validation DataLoader 0:  28%|██▊       | 9/32 [00:01<00:02,  8.07it/s][A
Validation DataLoader 0:  31%|███▏      | 10/32 [00:01<00:02,  8.10it/s][A
Validation

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 32/32 [00:14<00:00,  2.16it/s, v_num=lbfe]


0,1
epoch,▁▂▂▃▃▃▄▄▅▆▆▆▆▇██
train/accuracy,▄▁▃▄▅█
train/loss,▄█▄▃▃▁
train/qwk,▅▁▁▃▅█
train/recall,▄▁▃▄▅█
trainer/global_step,▁▁▂▃▃▃▄▄▅▅▆▆▆▇██
val/accuracy,▁▂▂▃▄▅▆▆▇█
val/loss,█▇▆▅▄▄▃▂▂▁
val/qwk,▁▂▃▄▄▅▆▆▇█
val/recall,▁▂▂▃▄▅▆▆▇█

0,1
epoch,9.0
train/accuracy,0.71875
train/loss,0.89111
train/qwk,0.48624
train/recall,0.71875
trainer/global_step,319.0
val/accuracy,0.791
val/loss,0.82822
val/qwk,0.66596
val/recall,0.791


In [91]:
for batch in fake_train_dataloader:
    x, y = batch[0], batch[1]
    y_hat = model(x)
    acc = torch.metrics.accuracy(y_hat, y)
    print(x.shape, y.shape)
    break

Backbone(
  (preprocessor): ImageClassification(
      crop_size=[224]
      resize_size=[256]
      mean=[0.485, 0.456, 0.406]
      std=[0.229, 0.224, 0.225]
      interpolation=InterpolationMode.BILINEAR
  )
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 