Import libraries

In [9]:
from dataset.dataset import load_data
from models import MRnet
from config import config
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from utils.utils import _train_model, _evaluate_model, _get_lr
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import torch.utils.data as data
import torch.nn as nn
import os

Fine-tuning routine: adapts the pretrained binary ACL model to 4-class severity grading

In [13]:
# Checkpoint of the pretrained binary ACL model that fine-tuning starts from.
# Can be overridden with config["pretrained_weights"].
_DEFAULT_PRETRAINED_WEIGHTS = "weights/acl/model_test_acl_val_auc_0.9677_train_auc_0.9903_epoch_20.pth"

# Substrings of parameter names that stay trainable ("later convolutional
# layers for each axis"). They must match MRnet's attribute names exactly,
# including the "saggital" spelling. NOTE(review): indices 8/10 are assumed to
# be the last conv layers of each view's backbone -- confirm against MRnet.
_TRAINABLE_LAYER_KEYS = (
    'axial.10', 'coronal.10', 'saggital.10',
    'axial.8', 'coronal.8', 'saggital.8',
)

# Epochs without validation-loss improvement before training stops.
# Can be overridden with config["early_stopping_patience"].
_DEFAULT_EARLY_STOPPING_PATIENCE = 10


def _build_finetune_model(checkpoint_path: str) -> nn.Module:
    """Load the pretrained MRnet and prepare it for 4-class fine-tuning.

    Args:
        checkpoint_path (str): Path to a checkpoint dict with a
            ``"model_state_dict"`` entry.

    Returns:
        nn.Module: The model with most layers frozen and a new, trainable
        4-way classification head.
    """
    model = MRnet()

    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
    # machine; the caller moves the model to the target device afterwards.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    # Freeze every parameter except those whose name matches one of the
    # selected later layers.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in _TRAINABLE_LAYER_KEYS)

    # Replace the final fully connected layer with a 4-class head that uses
    # LayerNorm and Dropout.
    if isinstance(model.fc, nn.Sequential):
        num_features = model.fc[0].in_features
    else:
        num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.LayerNorm(128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, 4),  # 4-class severity grading
    )

    # Fresh modules are trainable by default; set it explicitly for clarity.
    for param in model.fc.parameters():
        param.requires_grad = True

    return model


def _snapshot_state(model: nn.Module) -> dict:
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live tensors, so later
    optimizer steps would silently change an uncopied snapshot.
    """
    return {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}


def train(config: dict):
    """
    Fine-tune the pretrained binary ACL MRnet for 4-class severity grading.

    Steps:
        1. Load the pretrained checkpoint.
        2. Freeze all but the later conv layers of each view.
        3. Attach a new 4-class head.
        4. Train with class-weighted cross-entropy, ReduceLROnPlateau on the
           validation loss, and early stopping.

    Args:
        config (dict): Configuration to train with.
            Required keys: ``lr``, ``weight_decay``, ``starting_epoch``,
            ``max_epoch``, ``log_train``, ``log_val``, ``save_model``,
            ``exp_name``.
            Optional keys: ``pretrained_weights`` (checkpoint path) and
            ``early_stopping_patience`` (default 10).

    Returns:
        The fine-tuned model, with the weights from the epoch with the
        lowest validation loss restored (if any epoch ran).
    """

    print('Starting to Train Model...')

    # The test split is not used during fine-tuning.
    train_loader, val_loader, test_loader, train_wts, val_wts, test_wts = load_data()

    print('Initializing Model...')
    model = _build_finetune_model(
        config.get('pretrained_weights', _DEFAULT_PRETRAINED_WEIGHTS)
    )

    # Move the model and the class weights to the GPU if available.
    if torch.cuda.is_available():
        model = model.cuda()
        train_wts = train_wts.cuda()
        val_wts = val_wts.cuda()

    print('Initializing Loss Method...')
    # The weight buffers are already on the model's device, so the criteria
    # need no separate .cuda() call.
    criterion = nn.CrossEntropyLoss(weight=train_wts)
    val_criterion = nn.CrossEntropyLoss(weight=val_wts)

    print('Setup the Optimizer')
    # Only the unfrozen parameters are optimized.
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=config['lr'],
        weight_decay=config['weight_decay'],
    )

    scheduler = ReduceLROnPlateau(
        optimizer, patience=3, factor=0.3, threshold=1e-4, verbose=True
    )

    starting_epoch = config['starting_epoch']
    num_epochs = config['max_epoch']
    log_train = config['log_train']
    log_val = config['log_val']
    early_stopping_patience = config.get(
        'early_stopping_patience', _DEFAULT_EARLY_STOPPING_PATIENCE
    )

    best_val_loss = float('inf')
    best_val_acc = 0.0
    best_model_state = None  # independent copy of the best weights seen so far
    patience_counter = 0

    print('Starting Training')

    writer = SummaryWriter(comment=f"lr={config['lr']} task=acl-grading")
    t_start_training = time.time()

    # try/finally so the TensorBoard writer is closed even if training raises.
    try:
        for epoch in range(starting_epoch, num_epochs):

            current_lr = _get_lr(optimizer)
            epoch_start_time = time.time()

            print(f'Starting Epoch {epoch + 1}/{num_epochs}')
            train_loss, train_acc = _train_model(
                model, train_loader, epoch, num_epochs, optimizer, criterion, writer, current_lr, log_every=log_train
            )

            print('Train loop ended, now evaluating on validation set...')
            val_loss, val_acc = _evaluate_model(
                model, val_loader, val_criterion, epoch, num_epochs, writer, current_lr, log_val
            )

            writer.add_scalar('Train/Avg Loss', train_loss, epoch)
            writer.add_scalar('Val/Avg Loss', val_loss, epoch)
            writer.add_scalar('Train/Avg Accuracy', train_acc, epoch)
            writer.add_scalar('Val/Avg Accuracy', val_acc, epoch)

            scheduler.step(val_loss)

            epoch_time = time.time() - epoch_start_time

            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | "
                  f"Epoch Time: {epoch_time:.2f}s")

            print('-' * 50)

            writer.flush()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_acc = val_acc
                # Copy the tensors; a bare state_dict() would keep tracking
                # the live weights as training continues.
                best_model_state = _snapshot_state(model)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print(f"Early stopping triggered at epoch {epoch + 1}")
                    break

            # Periodic checkpoint of the current (not necessarily best) weights.
            if bool(config['save_model']) and (epoch + 1) % 10 == 0:
                file_name = f"model_{config['exp_name']}_acl_val_acc_{val_acc:.4f}_train_acc_{train_acc:.4f}_epoch_{epoch + 1}.pth"
                save_path = os.path.join('weights', "acl", file_name)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save({'model_state_dict': model.state_dict()}, save_path)

        t_end_training = time.time()
        print(f'Training completed in {t_end_training - t_start_training:.2f}s')
    finally:
        writer.close()

    # Restore the best weights whether training stopped early or ran to the end.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f'Restored best weights (val loss {best_val_loss:.4f}, val acc {best_val_acc:.4f})')

    return model

Run fine-tuning with the settings from `config`

In [None]:
# Guard the training run so that importing this code (for example, after
# exporting the notebook to a .py module) does not start training. In a
# notebook kernel __name__ is "__main__", so the cell behaves as before.
if __name__ == "__main__":
    print('Training Configuration')
    print(config)

    train(config=config)

    print('Training Ended...')