Import libraries

In [None]:
from dataset.dataset import load_data
from models import MRnet
from ordinal_config import ordinal_config
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.ordinal_utils import _train_model, _evaluate_model, _get_lr
import time
import torch.utils.data as data
import torch.nn as nn
import os

Method for training a model

In [None]:
def train(config: dict):
    """
    Fine-tune a pretrained MRnet for ACL grading using an MSE loss.

    The model is restored from a pretrained ACL checkpoint, all parameters are
    frozen, parameters whose name contains 'axial.10', 'coronal.10' or
    'saggital.10' are unfrozen again, and ``model.fc`` is replaced by a small
    two-layer head with a single scalar output. The scalar is trained with
    ``MSELoss``; the helpers are called with ``use_regression=True``
    (presumably they round the prediction to the closest of the 4 classes to
    compute accuracy -- confirm in ``utils.ordinal_utils``).

    Args:
        config: Training settings. Keys read here: 'lr', 'weight_decay',
            'starting_epoch', 'max_epoch', 'log_train', 'log_val',
            'save_model', 'exp_name' and, optionally, 'task' (sub-directory
            of ``weights/`` used for saved checkpoints; defaults to 'acl').

    Returns:
        None. Scalars are logged to TensorBoard and, when
        ``config['save_model']`` is truthy, a checkpoint is written every
        10th epoch.
    """

    print('Starting to Train Model...')

    # Only the train/val loaders are used here. The class weights returned by
    # load_data() are not needed because MSELoss is unweighted.
    train_loader, val_loader, _test_loader, _train_wts, _val_wts, _test_wts = load_data()

    print('Initializing Model...')
    model = MRnet()

    # Load pretrained weights. map_location='cpu' lets a checkpoint that was
    # saved from a GPU be restored on a CPU-only machine; the model is moved
    # to the GPU further down when one is available.
    checkpoint = torch.load(
        "weights/acl/model_test_acl_val_auc_0.9677_train_auc_0.9903_epoch_20.pth",
        map_location='cpu'
    )
    model.load_state_dict(checkpoint["model_state_dict"])

    # Freeze all layers
    for param in model.parameters():
        param.requires_grad = False

    # Fine-tune the last conv block
    for name, param in model.named_parameters():
        if any(layer in name for layer in ['axial.10', 'coronal.10', 'saggital.10']):
            param.requires_grad = True

    # Replace the final layer for regression-style output
    num_features = model.fc[0].in_features if isinstance(model.fc, torch.nn.Sequential) else model.fc.in_features
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(num_features, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 1)  # Scalar output for regression-style classification
    )

    # The freshly created head must be trainable.
    for param in model.fc.parameters():
        param.requires_grad = True

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    print('Initializing Loss Method...')
    criterion = torch.nn.MSELoss()
    val_criterion = torch.nn.MSELoss()

    if use_cuda:
        criterion = criterion.cuda()
        val_criterion = val_criterion.cuda()

    print('Setup the Optimizer')
    # Only hand the optimizer the parameters that were left trainable above.
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )

    # NOTE(review): `verbose` is reported as deprecated by recent PyTorch
    # releases -- confirm against the installed version before upgrading.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, factor=0.3, threshold=1e-4, verbose=True
    )

    starting_epoch = config['starting_epoch']
    num_epochs = config['max_epoch']
    log_train = config['log_train']
    log_val = config['log_val']

    # The config may not define 'task'; fall back to the ACL directory that
    # the pretrained checkpoint above is read from.
    task = config.get('task', 'acl')

    best_val_acc = float(0)

    writer = SummaryWriter(comment=f"lr={config['lr']} task=acl-grading")
    t_start_training = time.time()

    print('Starting Training')

    # try/finally guarantees the TensorBoard writer is closed even when an
    # epoch raises (e.g. out-of-memory or a keyboard interrupt).
    try:
        for epoch in range(starting_epoch, num_epochs):
            current_lr = _get_lr(optimizer)
            epoch_start_time = time.time()

            print(f'Starting Epoch {epoch + 1}/{num_epochs}')
            train_loss, train_acc = _train_model(
                model, train_loader, epoch, num_epochs, optimizer, criterion, writer,
                current_lr, log_every=log_train, use_regression=True
            )

            print('Train loop ended, now evaluating on validation set...')
            val_loss, val_acc = _evaluate_model(
                model, val_loader, val_criterion, epoch, num_epochs, writer,
                current_lr, log_val, use_regression=True
            )

            writer.add_scalar('Train/Avg Loss', train_loss, epoch)
            writer.add_scalar('Val/Avg Loss', val_loss, epoch)

            # Reduce the learning rate when the validation loss plateaus.
            scheduler.step(val_loss)

            epoch_time = time.time() - epoch_start_time

            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | "
                  f"Epoch Time: {epoch_time:.2f}s")
            print('-' * 50)
            writer.flush()

            if val_acc > best_val_acc:
                best_val_acc = val_acc

            # Periodic checkpoint: every 10th epoch when saving is enabled.
            if bool(config['save_model']) and (epoch + 1) % 10 == 0:
                file_name = f"model_{config['exp_name']}_acl_val_acc_{val_acc:.4f}_train_acc_{train_acc:.4f}_epoch_{epoch + 1}.pth"
                save_path = os.path.join('weights', task, file_name)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save({'model_state_dict': model.state_dict()}, save_path)

        t_end_training = time.time()
        print(f'Training completed in {t_end_training - t_start_training:.2f}s')
        print(f'Best Val Acc: {best_val_acc:.4f}')
    finally:
        writer.close()

Train the model

In [None]:
# Show the settings for this run (header line, then the config dict on its
# own line), then launch fine-tuning with those same settings.
print('Training Configuration', ordinal_config, sep='\n')

train(config=ordinal_config)

print('Training Ended...')

Training Configuration
{'max_epoch': 50, 'log_train': 100, 'lr': 0.001, 'starting_epoch': 0, 'batch_size': 1, 'log_val': 10, 'weight_decay': 0.01, 'patience': 5, 'save_model': 1, 'exp_name': 'test'}
Starting to Train Model...
Loading Train Dataset of ACL task...
['001', '008', '015', '016', '084', '098', '101', '107', '109', '124', '129', '145', '150', '164', '172', '183', '201', '209', '225', '230', '245', '251']
Unique labels found in dataset: [0, 1, 2, 3]
Number of classes: 4
Class distribution:
Class 0: 7 samples
Class 1: 3 samples
Class 2: 8 samples
Class 3: 4 samples
Class weights for loss are: tensor([0.6713, 1.5664, 0.5874, 1.1748])
Total samples: 22 | Num classes: 4
Loading Validation Dataset of ACL task...
['037', '062', '133', '184', '207', '161']
Unique labels found in dataset: [0, 1, 2, 3]
Number of classes: 4
Class distribution:
Class 0: 2 samples
Class 1: 1 samples
Class 2: 1 samples
Class 3: 2 samples
Class weights for loss are: tensor([0.6667, 1.3333, 1.3333, 0.6667])




Initializing Loss Method...
Setup the Optimizer
Starting Training
Starting Epoch 1/50




[Epoch: 1 / 50 | Batch : 0 / 22 ]| Avg Train Loss: 1.2969 | Accuracy: 1.0000 | lr: 0.001
[Epoch: 1 / 50 | Batch : 2 / 22 ]| Avg Train Loss: 2.4473 | Accuracy: 0.3333 | lr: 0.001
[Epoch: 1 / 50 | Batch : 4 / 22 ]| Avg Train Loss: 1.8088 | Accuracy: 0.4000 | lr: 0.001
[Epoch: 1 / 50 | Batch : 6 / 22 ]| Avg Train Loss: 2.0987 | Accuracy: 0.2857 | lr: 0.001
[Epoch: 1 / 50 | Batch : 8 / 22 ]| Avg Train Loss: 1.9829 | Accuracy: 0.2222 | lr: 0.001
[Epoch: 1 / 50 | Batch : 10 / 22 ]| Avg Train Loss: 1.8661 | Accuracy: 0.1818 | lr: 0.001
[Epoch: 1 / 50 | Batch : 12 / 22 ]| Avg Train Loss: 1.7834 | Accuracy: 0.1538 | lr: 0.001
[Epoch: 1 / 50 | Batch : 14 / 22 ]| Avg Train Loss: 1.8083 | Accuracy: 0.2000 | lr: 0.001
[Epoch: 1 / 50 | Batch : 16 / 22 ]| Avg Train Loss: 1.8355 | Accuracy: 0.1765 | lr: 0.001
[Epoch: 1 / 50 | Batch : 18 / 22 ]| Avg Train Loss: 1.7790 | Accuracy: 0.2105 | lr: 0.001
[Epoch: 1 / 50 | Batch : 20 / 22 ]| Avg Train Loss: 1.7278 | Accuracy: 0.2381 | lr: 0.001
Train loop ende