In [8]:
# Author: Bonaventure F. P. Dossou - bonaventure.dossou@mila.quebec (bonaventuredossou.github.io)
# Data transformation, Models Configurations and Training (more details on Tuberculosis_Solution.md)
# Check License under LICENSE.md
from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, transforms, models
from torchvision.models import vit_h_14, resnet18, resnet50, resnet152, efficientnet_v2_m, efficientnet_v2_l, convnext_base, convnext_large, wide_resnet101_2, vgg19_bn, regnet_x_32gf, swin_b, swin_v2_b, maxvit_t, regnet_y_128gf
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm
import copy
os.environ['TORCH_HOME'] = os.path.join('/','home','ngsci','project')

In [9]:
!rm -r '../tb/train/0/.ipynb_checkpoints'
# !rm -r '/home/ngsci/project/tb/train/1/*ipynb_checkpoints*'
# !rm -r '/home/ngsci/project/tb/val/0/*ipynb_checkpoints*'
# !rm -r '/home/ngsci/project/tb/val/1/*ipynb_checkpoints*'

rm: cannot remove '../tb/train/0/.ipynb_checkpoints': No such file or directory


In [10]:
data_dir = os.path.join('/','home','ngsci','project', 'tb')
num_classes = 2
batch_size = 32 # 32 gives current best results
num_epochs = 10
feature_extract = False

In [11]:
from sklearn.metrics import average_precision_score

# adapted from Pytorch Tutorial on Pretrained CV models
def train_model(model, model_name, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, lr=1e-5):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    # best_loss = 1000
    best_pr_auc = 0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
        
            running_loss = 0.0
            running_corrects = 0.0
            all_labels, all_preds = [], []
            for inputs, labels in tqdm(dataloaders[phase], desc=phase.capitalize()):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    # _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                preds = torch.softmax(outputs, dim=1)
                preds = preds.detach().cpu().numpy()[:, 1] # we only take the proba to belong to class one
                y_test = labels.data.detach().cpu().numpy()
                all_preds.extend(preds.tolist())
                all_labels.extend(y_test.tolist())
                # running_corrects += torch.sum(preds == labels.data)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # The average precision score is used to calculate PR-AUC
            # epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            epoch_acc = average_precision_score(np.array(all_labels), np.array(all_preds), average='weighted')
            
            if phase == 'val' and epoch_acc > best_pr_auc:
                # selecting the weights with lowest eval loss
                best_pr_auc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                path = os.path.join('/','home','ngsci','project', 'code', 'finetuned_weights','no_aug_tb_0_{}_{}_{}_{}.pt'.format(model_name, batch_size, num_epochs, lr_))
                torch.save({'state_dict': model.state_dict(), 'best_auc': best_pr_auc}, path)
                print('{} - {} Loss: {:.4f} PR-AUC: {:.4f}'.format(model_name, phase, epoch_loss, epoch_acc))
        
        print()
    
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('{} - Best val PR-AUC: {:4f}'.format(model_name, best_pr_auc))
    model.load_state_dict(best_model_wts)
    return model

In [12]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [17]:
import re
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    model_ft = None
    input_size = 0
    
    # You need to download every pretrained weights of each model. Place it in the `pretrained_weights` folder in your directory. 
    # They are mostly on ImageNET V2 (sometimes V1) dataset
    if model_name == "resnet18":
        """Resnet18"""
        model_ft = resnet18(weights=None)
        checkpoints = torch.load('pretrained_weights/resnet18-f37072fd.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "resnet50":
        """Resnet50"""
        model_ft = resnet50(weights=None)
        checkpoints = torch.load('pretrained_weights/resnet50-11ad3fa6.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "wide_resnet101":
        """Wide Resnet101 V2"""
        model_ft = wide_resnet101_2(weights=None)
        checkpoints = torch.load('pretrained_weights/wide_resnet101_2-d733dc28.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "resnet152":
        """Resnet152"""
        model_ft = resnet152(weights=None)
        checkpoints = torch.load('pretrained_weights/resnet152-394f9c45.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "efficientnet":
        """EFFICIENTNET_V2_M"""
        model_ft = efficientnet_v2_l(weights=None)
        checkpoints = torch.load('pretrained_weights/efficientnet_v2_l-59c71312.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "convnext":
        """convnext_base"""
        model_ft = convnext_large(weights=None)
        checkpoints = torch.load('pretrained_weights/convnext_large-ea097f82.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[2].in_features
        model_ft.classifier[2] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    
    if model_name == "vgg":
        """VGG19_bn"""
        model_ft = vgg19_bn(weights=None)
        checkpoints = torch.load('pretrained_weights/vgg19_bn-c79401a0.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    if model_name == "regnet":
        """regnet_x_32gf"""
        model_ft = regnet_y_128gf(weights="DEFAULT")
        checkpoints = torch.load('pretrained_weights/regnet_x_32gf-6eb8fdc6.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "swin":
        """swin_b"""
        model_ft = swin_v2_b(weights=None)
        checkpoints = torch.load('pretrained_weights/swin_v2_b-781e5279.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.head.in_features
        model_ft.head = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    if model_name == "maxvit":
        """maxvit"""
        model_ft = maxvit_t(weights=None)
        checkpoints = torch.load('pretrained_weights/maxvit_t-bc5ab103.pth')
        model_ft.load_state_dict(checkpoints)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[5].in_features
        model_ft.classifier[5] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    return model_ft, input_size

In [None]:
models = ["resnet50", "wide_resnet101", "efficientnet", "convnext", "vgg", "regnet", "swin", "maxvit"]

for model_name in models:
    model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(input_size),
            transforms.RandomResizedCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }
        
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
    dataloaders_dict = {x[0]: torch.utils.data.DataLoader(image_datasets[x[0]], batch_size=batch_size, shuffle=x[1], num_workers=4) for x in [('train', True), ('val', False)]}
    num_gpus = [i for i in range(torch.cuda.device_count())]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if len(num_gpus) > 1:
        print("Let's use", len(num_gpus), "GPUs!")
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in num_gpus)
        model_ft = torch.nn.DataParallel(model_ft, device_ids=num_gpus)
        model_ft = model_ft.module
    
    model_ft = model_ft.to(device)

    params_to_update = model_ft.parameters()
    if feature_extract:
        params_to_update = []
        for name,param in model_ft.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
    else:
        for name,param in model_ft.named_parameters():
            if param.requires_grad == True:
                pass
    
    # using best learning rate for each model (for full report, see more on table of results on Solution.md)
    # this will produce best weights for each model
    model_lr_map = {'resnet18': 1e-5, 'resnet50': 1e-4, 'resnet152': 1e-5, 
                    'efficientnet': 4e-4, 'convnext': 1e-5, "wide_resnet101": 1e-4,
                   "vgg": 1e-4, "regnet": 1e-5, "swin": 1e-5, "maxvit": 1e-4}    
    lr_ = model_lr_map[model_name]

    print('Training {} with lr: {}'.format(model_name, lr_))
    optimizer_ft = optim.AdamW(params_to_update, lr=lr_)
    criterion = nn.CrossEntropyLoss()
    model_ft = train_model(model_ft, model_name, dataloaders_dict, criterion, optimizer_ft, num_epochs=5, is_inception=False, lr=lr_)