Import libraries

In [1]:
from dataset.dataset import load_data
from models import MRnet
from config import config
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.utils import _train_model, _evaluate_model, _get_lr
import time
import torch.utils.data as data
import torch.nn as nn
import os

Training function: fine-tune a new classification head on top of a pretrained MRnet

In [2]:
# Fine-tune a pretrained MRnet checkpoint: every loaded parameter is frozen and
# only a freshly attached classification head is trained.

# Defaults used when the corresponding config key is absent.
DEFAULT_PRETRAINED_PATH = "weights/acl/model_test_acl_val_auc_0.9677_train_auc_0.9903_epoch_20.pth"
DEFAULT_NUM_CLASSES = 4
DEFAULT_TASK = 'acl'


def _build_finetune_model(checkpoint_path: str, num_classes: int) -> nn.Module:
    """
    Load MRnet weights, freeze them, and replace ``model.fc`` with a new trainable head.

    Args:
        checkpoint_path (str) : Path to a checkpoint dict holding ``"model_state_dict"``.
        num_classes (int) : Number of logits produced by the new head.

    Returns:
        nn.Module : MRnet (on the CPU) whose only trainable parameters are in ``model.fc``.
    """
    model = MRnet()
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to the GPU afterwards if one is available.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model_state_dict"])

    # Freeze all loaded layers.
    for param in model.parameters():
        param.requires_grad = False

    # NOTE(review): assumes that when fc is a Sequential its first element is a Linear layer.
    old_head = model.fc
    num_features = old_head[0].in_features if isinstance(old_head, nn.Sequential) else old_head.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

    # Make the new head explicitly trainable (ReLU has no parameters).
    for param in model.fc.parameters():
        param.requires_grad = True

    return model


def train(config : dict):
    """
    Fine-tune a pretrained MRnet, training only its new classification head.

    Args:
        config (dict) : Configuration to train with. Required keys: ``lr``,
            ``weight_decay``, ``starting_epoch``, ``max_epoch``, ``patience``,
            ``log_train``, ``log_val``, ``save_model``, ``exp_name``.
            Optional keys: ``task`` (default ``'acl'``), ``pretrained_path``
            (default ``DEFAULT_PRETRAINED_PATH``), ``num_classes`` (default 4).
            A non-positive or falsy ``patience`` disables early stopping.
    """
    print('Starting to Train Model...')

    # The run config may not define 'task'; fall back to the ACL task this
    # notebook targets rather than raising KeyError at the first checkpoint save.
    task = config.get('task', DEFAULT_TASK)
    save_model = bool(config['save_model'])

    train_loader, val_loader, test_loader, train_wts, val_wts, test_wts = load_data()

    print('Initializing Model...')
    model = _build_finetune_model(
        config.get('pretrained_path', DEFAULT_PRETRAINED_PATH),
        config.get('num_classes', DEFAULT_NUM_CLASSES),
    )

    if torch.cuda.is_available():
        model = model.cuda()
        train_wts = train_wts.cuda()
        val_wts = val_wts.cuda()

    print('Initializing Loss Method...')
    # The class weights are already on the right device, so the losses need no extra .cuda().
    criterion = nn.CrossEntropyLoss(weight=train_wts)
    val_criterion = nn.CrossEntropyLoss(weight=val_wts)

    print('Setup the Optimizer')
    # Only the unfrozen head is optimized; frozen parameters never receive gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=config['lr'], weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=.3, threshold=1e-4, verbose=True)

    starting_epoch = config['starting_epoch']
    num_epochs = config['max_epoch']
    patience = config['patience']
    log_train = config['log_train']
    log_val = config['log_val']

    best_val_loss = float('inf')
    best_val_auc = 0.0
    epochs_without_improvement = 0

    save_dir = './weights/{}'.format(task)
    if save_model:
        # torch.save does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)

    print('Starting Training')

    writer = SummaryWriter(comment='lr={} task={}'.format(config['lr'], task))
    t_start_training = time.time()

    try:
        for epoch in range(starting_epoch, num_epochs):

            current_lr = _get_lr(optimizer)
            epoch_start_time = time.time()  # timer for entire epoch

            print('Started Training')
            train_loss, train_auc = _train_model(
                model, train_loader, epoch, num_epochs, optimizer, criterion, writer, current_lr, log_every=log_train)

            print('train loop ended, now val')
            val_loss, val_auc = _evaluate_model(
                model, val_loader, val_criterion,  epoch, num_epochs, writer, current_lr, log_val)

            writer.add_scalar('Train/Avg Loss', train_loss, epoch)
            writer.add_scalar('Val/Avg Loss', val_loss, epoch)

            scheduler.step(val_loss)

            delta = time.time() - epoch_start_time

            print("train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s".format(
                train_loss, train_auc, val_loss, val_auc, delta))

            print('-' * 30)

            writer.flush()

            if val_auc > best_val_auc:
                best_val_auc = val_auc

            # Track validation-loss improvement for early stopping.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if save_model and (epoch + 1) % 10 == 0:
                file_name = 'model_{}_{}_val_auc_{:0.4f}_train_auc_{:0.4f}_epoch_{}.pth'.format(
                    config['exp_name'], task, val_auc, train_auc, epoch + 1)
                torch.save({
                    'model_state_dict': model.state_dict()
                }, '{}/{}'.format(save_dir, file_name))

            if patience and patience > 0 and epochs_without_improvement >= patience:
                print('Early stopping after {} epochs without a decrease of the val loss'.format(
                    epochs_without_improvement))
                break
    finally:
        # Always release the TensorBoard writer, even if an epoch raised.
        writer.flush()
        writer.close()

    t_end_training = time.time()
    print(f'training took {t_end_training - t_start_training} s')
    print(f'best val auc {best_val_auc} | best val loss {best_val_loss}')

Train the model

In [None]:
# Record the exact settings of this run in the cell output, then launch training.
for header_or_settings in ('Training Configuration', config):
    print(header_or_settings)

train(config=config)
print('Training Ended...')

Training Configuration
{'max_epoch': 50, 'log_train': 100, 'lr': 1e-05, 'starting_epoch': 0, 'batch_size': 1, 'log_val': 10, 'weight_decay': 0.01, 'patience': 5, 'save_model': 1, 'exp_name': 'test'}
Starting to Train Model...
Loading Train Dataset of ACL task...
['001', '008', '015', '016', '084', '098', '101', '107', '109', '124', '129', '145', '150', '164', '172', '183', '201', '209', '225', '230', '245', '251']
Unique labels found in dataset: [0, 1, 2, 3]
Number of classes: 4
Class distribution:
Class 0: 7 samples
Class 1: 3 samples
Class 2: 8 samples
Class 3: 4 samples
Class weights for loss are: tensor([0.6713, 1.5664, 0.5874, 1.1748])
Total samples: 22 | Num classes: 4
Loading Validation Dataset of ACL task...
['037', '062', '133', '184', '207', '121']
Unique labels found in dataset: [0, 1, 2, 3]
Number of classes: 4
Class distribution:
Class 0: 2 samples
Class 1: 1 samples
Class 2: 1 samples
Class 3: 2 samples
Class weights for loss are: tensor([0.6667, 1.3333, 1.3333, 0.6667])




Initializing Loss Method...
Setup the Optimizer
Starting Training
Started Training


