In [None]:
# Author: Bonaventure F. P. Dossou - bonaventure.dossou@mila.quebec (bonaventuredossou.github.io)
# Data transformation, Models Configurations and Training (more details on Solution.md)
# Check License under LICENSE.md
from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, transforms, models
from torchvision.models import resnet18, resnet50, resnet152, efficientnet_v2_m, convnext_base, wide_resnet101_2, vgg19_bn, regnet_x_32gf, swin_b, maxvit_t
import matplotlib.pyplot as plt
import time
import os
import copy
# Point TORCH_HOME at the project folder.
project_root = os.path.join('/', 'home', 'ngsci', 'project')
os.environ['TORCH_HOME'] = project_root

In [None]:
# --- Experiment configuration -------------------------------------------------
# Folder expected to contain the `train/` and `val/` ImageFolder sub-directories.
data_dir = os.path.join('/', 'home', 'ngsci', 'project', 'breast_cancer')

# Size of the classification head and the mini-batch size.
num_classes, batch_size = 5, 32

# Epoch count; also embedded in checkpoint / prediction file names.
num_epochs = 50

# False -> fine-tune every layer; True -> freeze the backbone and train the head only.
feature_extract = False

In [None]:
# adapted from Pytorch Tutorial on Pretrained CV models
def train_model(model, model_name, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, lr=1e-5):
    """Train ``model`` and return it loaded with the lowest-validation-loss weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. Whenever the
    validation loss improves, the weights are kept in memory and also written to
    ``/home/ngsci/project/final_weights/breast_cancer_<model_name>_<batch_size>_<num_epochs>_<lr>.pt``.

    Args:
        model: network to train; it is run on the device its parameters already live on.
        model_name: short architecture key, used only in the checkpoint file name.
        dataloaders: mapping with 'train' and 'val' entries; each is iterated for
            ``(inputs, labels)`` batches and must expose ``.dataset`` with a length.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        num_epochs: number of train/val epochs (also part of the checkpoint name).
        is_inception: when True, the training forward pass is expected to return
            ``(outputs, aux_outputs)`` and the auxiliary loss is added with weight 0.4.
        lr: learning rate, used only to tag the checkpoint file name.

    Returns:
        ``model`` with the best (lowest validation loss) weights loaded.

    Note:
        ``batch_size`` is read from the notebook configuration for the file name.
    """
    since = time.time()

    # Use the model's own device instead of relying on a notebook-level `device` global.
    device = next(model.parameters()).device

    # Make sure the checkpoint folder exists before the first save.
    weights_dir = os.path.join('/', 'home', 'ngsci', 'project', 'final_weights')
    os.makedirs(weights_dir, exist_ok=True)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')  # any finite validation loss counts as an improvement

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_train):
                    if is_inception and is_train:
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                # criterion is assumed to return a batch mean, hence the re-weighting by batch size.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            n_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_loss < best_loss:
                # selecting the weights with lowest eval loss
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                # Tag the file with the `lr` argument (not a notebook global).
                path = os.path.join(weights_dir, 'breast_cancer_{}_{}_{}_{}.pt'.format(model_name, batch_size, num_epochs, lr))
                torch.save(best_model_wts, path)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))
    model.load_state_dict(best_model_wts)
    return model

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    With a falsy flag the model is left untouched. Nothing is returned; the
    parameters are modified in place.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [None]:
import re
# You need to download every pretrained weights of each model. Place it in the `pretrained_weights` folder in your directory.
# They are mostly on ImageNET V2 (sometimes V1) dataset
#
# architecture key -> (builder name in `torchvision.models`,
#                      checkpoint file inside `pretrained_weights/`,
#                      attribute that holds the classification head,
#                      index of the final Linear inside that attribute, or None
#                      when the attribute itself is the Linear layer)
_PRETRAINED_MODEL_SPECS = {
    "resnet18": ("resnet18", "resnet18-f37072fd.pth", "fc", None),
    "resnet50": ("resnet50", "resnet50-11ad3fa6.pth", "fc", None),
    "wide_resnet101": ("wide_resnet101_2", "wide_resnet101_2-d733dc28.pth", "fc", None),
    "resnet152": ("resnet152", "resnet152-394f9c45.pth", "fc", None),
    "efficientnet": ("efficientnet_v2_m", "efficientnet_v2_m-dc08266a.pth", "classifier", 1),
    "convnext": ("convnext_base", "convnext_base-6075fbad.pth", "classifier", 2),
    "vgg": ("vgg19_bn", "vgg19_bn-c79401a0.pth", "classifier", 6),
    "regnet": ("regnet_x_32gf", "regnet_x_32gf-6eb8fdc6.pth", "fc", None),
    "swin": ("swin_b", "swin_b-68c6b09e.pth", "head", None),
    "maxvit": ("maxvit_t", "maxvit_t-bc5ab103.pth", "classifier", 5),
}


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision backbone with a fresh ``num_classes``-way classification head.

    Args:
        model_name: one of the keys of ``_PRETRAINED_MODEL_SPECS``.
        num_classes: output size of the new final ``nn.Linear`` layer.
        feature_extract: forwarded to ``set_parameter_requires_grad``; when truthy
            the backbone is frozen before the new head is attached, so only the
            new head keeps ``requires_grad=True``.
        use_pretrained: when True (default) the local checkpoint from
            ``pretrained_weights/`` is loaded into the backbone; when False the
            backbone keeps its random initialisation.

    Returns:
        ``(model, input_size)`` where ``input_size`` is 224 for every supported model.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture key.
    """
    if model_name not in _PRETRAINED_MODEL_SPECS:
        raise ValueError("Unknown model name {!r}; expected one of {}".format(
            model_name, sorted(_PRETRAINED_MODEL_SPECS)))

    builder_name, checkpoint_file, head_attr, head_index = _PRETRAINED_MODEL_SPECS[model_name]

    # Resolve the builder through the `models` namespace rather than a bare notebook
    # global, so rebinding a name such as `resnet50` elsewhere cannot break this call.
    model_ft = getattr(models, builder_name)(weights=None)

    if use_pretrained:
        # map_location='cpu' lets the checkpoint load on machines without a GPU;
        # the caller moves the model to its target device afterwards.
        checkpoints = torch.load('pretrained_weights/{}'.format(checkpoint_file), map_location='cpu')
        model_ft.load_state_dict(checkpoints)

    # Freeze (or not) the backbone BEFORE attaching the new head so the head stays trainable.
    set_parameter_requires_grad(model_ft, feature_extract)

    head = getattr(model_ft, head_attr)
    if head_index is None:
        setattr(model_ft, head_attr, nn.Linear(head.in_features, num_classes))
    else:
        head[head_index] = nn.Linear(head[head_index].in_features, num_classes)

    input_size = 224
    return model_ft, input_size

In [None]:
# Channel statistics passed to Normalize for both the train and val pipelines.
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]

# using best learning rate for each model (for full report, see more on table of results on Solution.md)
# this will produce best weights for each model
model_lr_map = {'resnet18': 1e-5, 'resnet50': 1e-4, 'resnet152': 1e-5,
                'efficientnet': 4e-4, 'convnext': 1e-5, "wide_resnet101": 1e-4,
                "vgg": 1e-4, "regnet": 1e-5, "swin": 1e-5, "maxvit": 1e-4}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    # The previous version wrapped the model in DataParallel and immediately took
    # `.module` back out, so training always ran on a single device. That behaviour is
    # kept so checkpoint keys carry no "module." prefix and build_model() can load them.
    print('{} GPUs visible; training on a single device ({})'.format(torch.cuda.device_count(), device))

criterion = nn.CrossEntropyLoss()

for model_name in ["resnet50", "wide_resnet101", "efficientnet", "convnext", "vgg", "regnet", "swin", "maxvit"]:
    model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(input_size),
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std)
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std)
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
    # Shuffle only the training split.
    dataloaders_dict = {
        split: torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size, shuffle=(split == 'train'), num_workers=4)
        for split in ['train', 'val']
    }

    model_ft = model_ft.to(device)

    # Only hand trainable parameters to the optimizer. With feature_extract=False this is
    # every parameter; with feature_extract=True it is just the new head.
    params_to_update = [param for param in model_ft.parameters() if param.requires_grad]

    lr_ = model_lr_map[model_name]

    print('Training {} with lr: {}'.format(model_name, lr_))
    optimizer_ft = optim.AdamW(params_to_update, lr=lr_)
    # Use the configured epoch count so checkpoint names match the ones build_model() looks for.
    model_ft = train_model(model_ft, model_name, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=False, lr=lr_)

In [None]:
# Evaluation
import glob
import pandas as pd

# Folder with the holdout images to score.
test_data = os.path.join('/','home','ngsci','datasets', 'brca-psj-path', 'basic-downsampling', 'holdout')
test_slides_fp = os.path.join(test_data,'*')
# glob returns entries in arbitrary order; sort so runs (and the CSV row order) are reproducible.
test_slides_list = sorted(glob.glob(test_slides_fp))
print('Test Images: {}'.format(len(test_slides_list)))

In [None]:
from PIL import Image
from tqdm import tqdm
import csv
import pandas as pd

def run_inference_image(path, model):
    """Classify a single image file with ``model``.

    Args:
        path: path of the image; the file name up to its first '.' is used as the identifier.
        model: classifier called on a ``(1, 3, 224, 224)`` float tensor.

    Returns:
        ``(identifier, class_id)`` where ``class_id`` is the arg-max class index as an int.
    """
    model.eval()
    biopsy_id = os.path.basename(path).split('.')[0]

    # NOTE(review): Resize((224, 224)) squashes the image to a square, whereas the
    # validation pipeline uses Resize(224) + CenterCrop(224) (aspect ratio preserved,
    # borders cropped). Left unchanged so predictions stay comparable - confirm intent.
    transform_data = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Close the file deterministically and force 3 channels: the 3-value Normalize
    # above cannot be applied to grayscale / RGBA / palette images.
    with Image.open(path) as image:
        img_t = transform_data(image.convert('RGB'))

    # Put the batch on the same device as the model's weights.
    model_device = next(model.parameters()).device
    img_t = img_t.float().unsqueeze(0).to(model_device)
    with torch.no_grad():
        output = model(img_t)

    prediction = output.squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    return biopsy_id, class_id

def run_inference(paths, model):
    """Run ``run_inference_image`` over ``paths`` and collect the results.

    Returns a dict mapping each returned identifier to its predicted class id;
    a progress bar labelled "Evaluation Progress" is shown while iterating.
    """
    pred_dict = {}
    for image_path in tqdm(paths, desc="Evaluation Progress"):
        identifier, class_id = run_inference_image(image_path, model)
        pred_dict[identifier] = class_id
    return pred_dict

# architecture key -> (builder name in `torchvision.models`,
#                      learning rate that tags the fine-tuned checkpoint file name,
#                      attribute that holds the classification head,
#                      index of the final Linear inside that attribute, or None
#                      when the attribute itself is the Linear layer)
_FINETUNED_MODEL_SPECS = {
    "resnet18": ("resnet18", 1e-5, "fc", None),
    "resnet50": ("resnet50", 1e-4, "fc", None),
    "resnet152": ("resnet152", 1e-5, "fc", None),
    "wide_resnet101": ("wide_resnet101_2", 1e-4, "fc", None),
    "vgg": ("vgg19_bn", 1e-4, "classifier", 6),
    "efficientnet": ("efficientnet_v2_m", 4e-4, "classifier", 1),
    "convnext": ("convnext_base", 1e-5, "classifier", 2),
    "regnet": ("regnet_x_32gf", 1e-5, "fc", None),
    "swin": ("swin_b", 1e-5, "head", None),
    "maxvit": ("maxvit_t", 1e-4, "classifier", 5),
}


def build_model(model_name):
    """Rebuild a fine-tuned 5-class classifier from its saved checkpoint.

    The checkpoint is read from
    ``best_final_weights/breast_cancer_<model_name>_<batch_size>_<num_epochs>_<lr>.pt``,
    where ``batch_size`` and ``num_epochs`` come from the notebook configuration
    and ``lr`` from ``_FINETUNED_MODEL_SPECS``.

    Args:
        model_name: one of the keys of ``_FINETUNED_MODEL_SPECS``.

    Returns:
        The model with the fine-tuned weights loaded (on CPU).

    Raises:
        ValueError: if ``model_name`` is not a supported architecture key.
    """
    if model_name not in _FINETUNED_MODEL_SPECS:
        raise ValueError("Unknown model name {!r}; expected one of {}".format(
            model_name, sorted(_FINETUNED_MODEL_SPECS)))

    builder_name, lr, head_attr, head_index = _FINETUNED_MODEL_SPECS[model_name]
    n_classes = 5

    # Resolve the builder through the `models` namespace rather than a bare notebook
    # global, so rebinding a name such as `resnet50` elsewhere cannot break this call.
    model_ft = getattr(models, builder_name)(weights=None)

    # The head must have the fine-tuned shape before the state dict is loaded.
    head = getattr(model_ft, head_attr)
    if head_index is None:
        setattr(model_ft, head_attr, nn.Linear(head.in_features, n_classes))
    else:
        head[head_index] = nn.Linear(head[head_index].in_features, n_classes)

    # map_location='cpu' lets checkpoints written from a GPU load on CPU-only machines.
    checkpoints = torch.load(
        'best_final_weights/breast_cancer_{}_{}_{}_{}.pt'.format(model_name, batch_size, num_epochs, lr),
        map_location='cpu')
    model_ft.load_state_dict(checkpoints)
    return model_ft

def save_predictions(pred_dict, name_model):
    """Write ``pred_dict`` to ``predictions/predictions_<name_model>_<batch_size>_<num_epochs>.csv``.

    Keys go to the ``slide_id`` column and values to the ``cancer_stage`` column;
    the index is not written. ``batch_size`` and ``num_epochs`` are read from the
    notebook configuration.
    """
    out_path = 'predictions/predictions_{}_{}_{}.csv'.format(name_model, batch_size, num_epochs)
    table = pd.DataFrame({
        'slide_id': list(pred_dict.keys()),
        'cancer_stage': list(pred_dict.values()),
    })
    table.to_csv(out_path, index=False)

# Score the test slides with each fine-tuned architecture and save one CSV per model.
for model_ in ["resnet50", "wide_resnet101", "efficientnet", "convnext", "vgg", "regnet", "swin", "maxvit"]:
    print('Predicting for {}'.format(model_))
    finetuned_model = build_model(model_)
    predictions_dict = run_inference(test_slides_list, finetuned_model)
    save_predictions(predictions_dict, model_)

In [None]:
# Load the CSV used below to map slide-level rows (`slide_id`) to `biopsy_id`.
holdout_dir = os.path.join("/", "home", "ngsci", "datasets", "brca-psj-path", "holdout")
slide_biopsy_map_path = os.path.join(holdout_dir, "v2", "slide-biopsy-map.csv")
slide_biopsy_map = pd.read_csv(slide_biopsy_map_path)
slide_biopsy_map.head()

In [None]:
# Collapse slide-level predictions into one value per biopsy, for every model.
for model_ in ["resnet50", "wide_resnet101", "efficientnet", "convnext", "vgg", "regnet", "swin", "maxvit"]:
    # Read the file name that save_predictions() writes (no "_accuracy" suffix).
    file_ = 'predictions/predictions_{}_{}_{}.csv'.format(model_, batch_size, num_epochs)
    predictions_model = pd.read_csv(file_)
    # Attach the biopsy id of every slide; slides missing from the map get NaN and
    # are dropped by the groupby below.
    predictions_model = predictions_model.merge(slide_biopsy_map, on="slide_id", how="left")

    # A biopsy's stage is the mean of the predicted stages of its slides.
    biopsy_stage_prediction = (
        predictions_model.groupby("biopsy_id")
        .agg({"cancer_stage": "mean"})
        .reset_index()
    )
    print(biopsy_stage_prediction.head())
    # Write the name the ensembling cell reads back: final_predictions_<model>_<bs>_<epochs>.csv
    biopsy_stage_prediction.to_csv('predictions/final_predictions_{}_{}_{}.csv'.format(model_, batch_size, num_epochs), index=False)

In [None]:
# Using only models that have loss < 1
ensemble_model_names = ["resnet50", "efficientnet", "convnext", "wide_resnet101",
                        "vgg", "regnet", "swin", "maxvit"]

# One Series of biopsy-level predictions per model, indexed by biopsy_id.
# Kept in a dict so the frames never rebind names such as `resnet50`, which are
# the torchvision model builders imported at the top of the notebook.
stage_by_model = {}
for ensemble_member in ensemble_model_names:
    member_frame = pd.read_csv('predictions/final_predictions_{}_{}_{}.csv'.format(ensemble_member, batch_size, num_epochs))
    stage_by_model[ensemble_member] = member_frame.set_index('biopsy_id')['cancer_stage']

# Align on biopsy_id instead of relying on every file having the same row order.
stage_table = pd.DataFrame(stage_by_model)
if stage_table.isna().any().any():
    raise ValueError('Some models are missing predictions for some biopsies; '
                     'cannot average a partial ensemble.')
# Keep the row order of the first model's file.
stage_table = stage_table.reindex(stage_by_model[ensemble_model_names[0]].index)

final_frame = pd.DataFrame()
final_frame['biopsy_id'] = stage_table.index.tolist()
# Unweighted mean over the ensemble members.
final_frame['cancer_stage'] = stage_table.mean(axis=1).tolist()

final_frame.to_csv('predictions/final_predictions_deep_ensemble_{}_{}_with_AdamW_31.csv'.format(batch_size, num_epochs), index=False, header=False)
final_frame.head()

In [None]:
# Submit the ensembled biopsy-level predictions as a contest entry.
import ngsci
# NOTE(review): this path ends in `_AdamW_28.csv`, while the ensembling cell above writes
# `..._AdamW_31.csv`. On a fresh run the `_28` file will not exist in `predictions/`
# unless it was copied there by hand - confirm which file is meant to be submitted.
submission_file = 'predictions/final_predictions_deep_ensemble_{}_{}_with_AdamW_28.csv'.format(batch_size, num_epochs)
# Best file: project/best_predictions/final_predictions_deep_ensemble_32_50_with_AdamW_28.csv,
# which averages only the predictions of models with losses < 1.
ngsci.submit_contest_entry(submission_file, description="Your description")