In [None]:
!pip install --upgrade pip -q
!pip install torchmetrics -q

In [None]:
# `from __future__` imports must be the very first statements of the cell;
# placing them after another import raises
# "SyntaxError: from __future__ imports must occur at the beginning of the file".
from __future__ import print_function
from __future__ import division
from torchmetrics import Dice
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
# NOTE(review): the name `models` imported here is rebound to a list of model
# names in the configuration cell further down, so `torchvision.models` is not
# reachable through this name afterwards.
from torchvision import datasets, transforms, models
from torchvision.models import resnet50, efficientnet_v2_m, convnext_base, wide_resnet101_2, vgg19_bn, regnet_x_32gf, swin_b, maxvit_t
import matplotlib.pyplot as plt
import time
import os
import copy
# Point the torch hub/model cache at the project directory.
os.environ['TORCH_HOME'] = os.path.join('/','home','ngsci','project')

In [None]:
# The Dice metric measures the overlap (similarity) between two label sets
def dice_loss(predictions, targets, num_classes=5):
    """Return ``1 - Dice`` (macro-averaged) for a batch of predictions.

    Args:
        predictions: predictions in a form accepted by ``torchmetrics.Dice``
            (the training loop passes hard class ids obtained with ``torch.max``).
        targets: ground-truth class ids; cast to ``int`` before scoring.
        num_classes: number of classes used for the macro average. Defaults to
            5, the value that was previously hard-coded.

    Returns:
        A tensor holding ``1 - dice_score``.

    NOTE(review): when called with hard class ids (as ``train_model`` does),
    the returned value has no gradient path back to the network, so it shifts
    the reported loss but presumably does not influence the weight updates --
    confirm whether a differentiable (soft) Dice term was intended.
    """
    # The metric object is created on the same device as the targets so no
    # host/device transfer is needed when scoring.
    dice_ = Dice(average='macro', num_classes=num_classes, zero_division=1).to(targets.device)
    dice_score = dice_(predictions, targets.int())
    return 1 - dice_score


In [None]:
# Per-channel normalisation statistics, listed separately for each split.
MEANS_TRAIN = [0.91049168, 0.87755836, 0.92714607]
STDS_TRAIN = [0.03925299, 0.09144832, 0.05845553]
MEANS_VALID = [0.91001739, 0.87733131, 0.92679886]
STDS_VALID = [0.04010004, 0.09253081, 0.05938738]

# The statistics actually used for normalisation are the channel-wise average
# of the training and validation values.
MEANS = [(train_value + valid_value)/2 for train_value, valid_value in zip(MEANS_TRAIN, MEANS_VALID)]
STDS = [(train_value + valid_value)/2 for train_value, valid_value in zip(STDS_TRAIN, STDS_VALID)]
print(MEANS, STDS)

In [None]:
# --- Experiment configuration -------------------------------------------
# Root folder of the dataset (its 'training' / 'validation' sub-folders are
# read by the training cell).
data_dir = os.path.join('/', 'home', 'ngsci', 'project', 'ami-ahead-wombcare')

# Training hyper-parameters shared by every backbone.
num_classes, batch_size, num_epochs = 5, 32, 50

# False -> every layer is fine-tuned; True -> only the new head is trained.
feature_extract = False

# Backbones to fine-tune, in training order.
models = [
    "efficientnet",
    "swin",
    "maxvit",
    "wide_resnet101",
    "convnext",
    "resnet50",
    "regnet",
    "vgg",
]

In [None]:
# adapted from Pytorch Tutorial on Pretrained CV models
def train_model(model, model_name, dataloaders, criterion, optimizer, num_epochs, is_inception=False, lr=1e-5):
    """Fine-tune ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'training' pass followed by a 'validation' pass. The
    reported loss is ``criterion(outputs, labels) + dice_loss(preds, labels)``.
    Whenever the validation loss improves, the weights are kept in memory and
    written to the ``final_weights`` directory under the project folder.

    Args:
        model: network to train; batches are moved to the device its
            parameters live on.
        model_name: identifier used in the checkpoint file name.
        dataloaders: mapping with 'training' and 'validation' data loaders.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the trainable parameters.
        num_epochs: number of epochs to run.
        is_inception: unused; kept so existing call sites keep working.
        lr: learning rate, used only to build the checkpoint file name.

    Returns:
        ``model`` with the weights of the best validation epoch loaded.
    """
    since = time.time()
    # Take the device from the model itself rather than from a notebook global.
    device = next(model.parameters()).device
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    # The file name is built from the `lr` argument (it previously read the
    # notebook-global `lr_`, which only worked by accident of cell ordering).
    checkpoint_dir = os.path.join('/', 'home', 'ngsci', 'project', 'final_weights')
    checkpoint_path = os.path.join(
        checkpoint_dir,
        'ami-ahead-wombcare_{}_{}_{}_{}.pt'.format(model_name, batch_size, num_epochs, lr),
    )

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['training', 'validation']:
            is_training = phase == 'training'
            if is_training:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)

                    # cce + dice loss
                    # NOTE(review): `preds` are hard class ids, so the dice term
                    # presumably contributes no gradient -- confirm intent.
                    loss = criterion(outputs, labels) + dice_loss(preds, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size
            print('{} (Dice + CCE) Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'validation' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                # Make sure the target directory exists before writing.
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(best_model_wts, checkpoint_path)

        print()

    time_elapsed = time.time() - since

    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation (Dice + CCE) loss: {:.4f}'.format(best_loss))

    model.load_state_dict(best_model_wts)

    return model

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched.
    Returns ``None``; the model is modified in place.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
import re
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a backbone and replace its classification head.

    Args:
        model_name: one of "resnet50", "wide_resnet101", "efficientnet",
            "convnext", "vgg", "regnet", "swin", "maxvit".
        num_classes: output size of the new linear head.
        feature_extract: when truthy, all backbone parameters are frozen
            before the new (trainable) head is attached.
        use_pretrained: when truthy (default) the local checkpoint under
            ``pretrained_weights/`` is loaded into the backbone; when falsy
            the backbone keeps its random initialisation.

    Returns:
        ``(model, input_size)`` where ``input_size`` is 224 for every backbone.

    Raises:
        ValueError: if ``model_name`` is not a supported backbone.
    """
    # name -> (constructor, checkpoint file, head attribute, index inside the
    # head container or None when the attribute itself is the linear layer).
    # Constructors are wrapped in lambdas so they are only resolved when used.
    specs = {
        "resnet50": (lambda: resnet50(weights=None), 'pretrained_weights/resnet50-11ad3fa6.pth', 'fc', None),
        "wide_resnet101": (lambda: wide_resnet101_2(weights=None), 'pretrained_weights/wide_resnet101_2-d733dc28.pth', 'fc', None),
        "efficientnet": (lambda: efficientnet_v2_m(weights=None), 'pretrained_weights/efficientnet_v2_m-dc08266a.pth', 'classifier', 1),
        "convnext": (lambda: convnext_base(weights=None), 'pretrained_weights/convnext_base-6075fbad.pth', 'classifier', 2),
        "vgg": (lambda: vgg19_bn(weights=None), 'pretrained_weights/vgg19_bn-c79401a0.pth', 'classifier', 6),
        "regnet": (lambda: regnet_x_32gf(weights=None), 'pretrained_weights/regnet_x_32gf-6eb8fdc6.pth', 'fc', None),
        "swin": (lambda: swin_b(weights=None), 'pretrained_weights/swin_b-68c6b09e.pth', 'head', None),
        "maxvit": (lambda: maxvit_t(weights=None), 'pretrained_weights/maxvit_t-bc5ab103.pth', 'classifier', 5),
    }
    if model_name not in specs:
        # Previously an unknown name silently returned (None, 0).
        raise ValueError(
            "Unknown model name {!r}; expected one of {}".format(model_name, sorted(specs))
        )
    make_backbone, checkpoint_file, head_attr, head_index = specs[model_name]

    model_ft = make_backbone()
    if use_pretrained:
        # map_location keeps loading independent of the device the file was saved from.
        checkpoints = torch.load(checkpoint_file, map_location='cpu')
        model_ft.load_state_dict(checkpoints)

    # Freeze first, so that the freshly created head below stays trainable.
    set_parameter_requires_grad(model_ft, feature_extract)

    head = getattr(model_ft, head_attr)
    if head_index is None:
        setattr(model_ft, head_attr, nn.Linear(head.in_features, num_classes))
    else:
        head[head_index] = nn.Linear(head[head_index].in_features, num_classes)

    input_size = 224
    return model_ft, input_size

In [None]:
# Building an inherent deep ensemble model that learns and combines representations
# from different pretrained models into a joint representation

class JoinRepresentationEnsembleModel(nn.Module):
    """Ensemble that feeds the same input to every member and sums their outputs."""

    def __init__(self, models):
        """Register ``models`` as sub-modules of the ensemble."""
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, x):
        """Return the element-wise sum of all member outputs for ``x``."""
        member_outputs = [member(x) for member in self.models]
        # Stack along a new leading "model" axis, then collapse it by summing.
        return torch.stack(member_outputs, dim=0).sum(dim=0)

In [None]:
all_models = []
models_dict = {}

# Per-model learning rates (defined once instead of on every loop iteration).
model_lr_map = {'resnet50': 1e-4, 'efficientnet': 4e-4, 'convnext': 1e-5, "wide_resnet101": 1e-4,
                "vgg": 1e-4, "regnet": 1e-5, "swin": 1e-5, "maxvit": 1e-4, 'shared_model': 1e-4}

# Training runs on a single device. The previous code wrapped the model in
# DataParallel and immediately unwrapped it again via `.module`, which had no
# effect, and set CUDA_VISIBLE_DEVICES after CUDA was already initialised.
# `num_gpus` is kept as a notebook-level name.
num_gpus = list(range(torch.cuda.device_count()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

models_ = models + ['shared_model']
for model_name in models_:
    if model_name != 'shared_model':
        model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
        all_models.append(model_ft)
    else:
        # The ensemble is built from the backbones fine-tuned in the earlier iterations.
        input_size = 224
        model_ft = JoinRepresentationEnsembleModel(all_models)

    # Data augmentation and normalization for training
    # Other transformations are avoided as they could induce noise in the data samples
    # Just normalization for validation
    data_transforms = {
        'training': transforms.Compose([
            transforms.Resize(input_size),
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(MEANS, STDS)
        ]),
        'validation': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(MEANS, STDS)
        ]),
    }

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
        for split in ['training', 'validation']
    }
    # Only the training split is shuffled.
    dataloaders_dict = {
        split: torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size, shuffle=shuffle, num_workers=4)
        for split, shuffle in [('training', True), ('validation', False)]
    }

    model_ft = model_ft.to(device)

    # Optimise only the parameters that are still trainable (all of them
    # unless `feature_extract` froze the backbone).
    params_to_update = [param for param in model_ft.parameters() if param.requires_grad]

    lr_ = model_lr_map[model_name]

    print('Training {} with lr: {}'.format(model_name, lr_))
    optimizer_ft = optim.AdamW(params_to_update, lr=lr_)
    criterion = nn.CrossEntropyLoss()
    model_ft = train_model(model_ft, model_name, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=False, lr=lr_)

    models_dict[model_name] = model_ft

In [None]:
import glob
import pandas as pd

# Directory holding the test images, and a wildcard pattern matching every
# entry directly inside it.
test_data = os.path.join('/', 'home', 'ngsci', 'project', 'ami-ahead-wombcare', 'testing')
test_slides_fp = os.path.join(test_data, '*')

# Materialise the matching paths and report how many were found.
test_slides_list = list(glob.iglob(test_slides_fp))
print('Test Images: {}'.format(len(test_slides_list)))

In [None]:
from PIL import Image
from tqdm import tqdm
import csv
import pandas as pd

def run_inference_image(path, model):
    """Predict class probabilities for a single image file.

    Args:
        path: path of the image; the slide id is the file name up to its
            first '.'.
        model: network used for prediction; switched to eval mode here.

    Returns:
        ``(slide_id, values)`` where ``values`` is a NumPy array holding the
        softmax probabilities followed by the arg-max class id.
    """
    model.eval()
    slide_id = os.path.basename(path).split('.')[0]

    transform_data = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(MEANS, STDS)])

    # Force 3-channel RGB so grayscale / RGBA / palette files match the
    # 3-channel normalisation statistics, and close the file handle promptly.
    with Image.open(path) as image:
        img_t = transform_data(image.convert('RGB'))

    img_t = img_t.float().unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        output = model(img_t.to(device))

    prediction = output.squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    all_proba_class_id = prediction.cpu().numpy().tolist() + [class_id]
    return slide_id, np.array(all_proba_class_id)

def run_inference(paths, model):
    """Run ``run_inference_image`` on every path and collect the results.

    Returns a dict mapping each returned id to its prediction array; progress
    is shown with a tqdm bar.
    """
    progress = tqdm(range(len(paths)), desc ="Evaluation Progress")
    per_image = [run_inference_image(paths[position], model) for position in progress]
    return dict(per_image)

def build_model(model_name):
    """Rebuild a fine-tuned model for inference.

    For the individual backbones the architecture is re-created with a
    5-class head and the checkpoint written during training is loaded.
    For "shared_model" the in-memory ensemble from ``models_dict`` is returned.

    Args:
        model_name: "resnet50", "wide_resnet101", "vgg", "efficientnet",
            "convnext", "regnet", "swin", "maxvit" or "shared_model".

    Returns:
        The model with its fine-tuned weights.

    Raises:
        ValueError: if ``model_name`` is not supported (previously ``None``
            was returned implicitly).
    """
    if model_name == "shared_model":
        return models_dict[model_name]

    # name -> (constructor, head attribute, index inside the head container or
    # None, learning rate used in the checkpoint file name). Constructors are
    # wrapped in lambdas so they are only resolved when actually used.
    specs = {
        "resnet50": (lambda: resnet50(weights=None), 'fc', None, 1e-4),
        "wide_resnet101": (lambda: wide_resnet101_2(weights=None), 'fc', None, 1e-4),
        "vgg": (lambda: vgg19_bn(weights=None), 'classifier', 6, 1e-4),
        "efficientnet": (lambda: efficientnet_v2_m(weights=None), 'classifier', 1, 4e-4),
        "convnext": (lambda: convnext_base(weights=None), 'classifier', 2, 1e-5),
        "regnet": (lambda: regnet_x_32gf(weights=None), 'fc', None, 1e-5),
        "swin": (lambda: swin_b(weights=None), 'head', None, 1e-5),
        "maxvit": (lambda: maxvit_t(weights=None), 'classifier', 5, 1e-4),
    }
    if model_name not in specs:
        raise ValueError(
            "Unknown model name {!r}; expected 'shared_model' or one of {}".format(model_name, sorted(specs))
        )
    make_backbone, head_attr, head_index, lr = specs[model_name]

    model_ft = make_backbone()
    # The head must have its fine-tuned shape before the state dict is loaded.
    head = getattr(model_ft, head_attr)
    if head_index is None:
        setattr(model_ft, head_attr, nn.Linear(head.in_features, 5))
    else:
        head[head_index] = nn.Linear(head[head_index].in_features, 5)

    filename = 'ami-ahead-wombcare_{}_{}_{}_{}.pt'.format(model_name, batch_size, num_epochs, lr)
    checkpoint_path = os.path.join('final_weights', filename)
    if not os.path.exists(checkpoint_path):
        # train_model writes checkpoints to this absolute location; fall back
        # to it when the notebook is not run from the project directory.
        checkpoint_path = os.path.join('/', 'home', 'ngsci', 'project', 'final_weights', filename)

    # map_location lets GPU-saved checkpoints load on any machine; the caller
    # moves the model to its target device afterwards.
    checkpoints = torch.load(checkpoint_path, map_location='cpu')
    model_ft.load_state_dict(checkpoints)
    return model_ft
    
def to_device(model):
    """Move ``model`` to the notebook-level ``device`` and return it.

    The previous implementation wrapped the model in ``DataParallel`` and
    immediately took ``.module`` again, which handed back the very same
    module; that no-op round trip has been removed.
    """
    return model.to(device)

def save_predictions(pred_dict, name_model):
    """Write slide-level predictions to ``predictions/predictions_<model>_<bs>_<epochs>.csv``.

    Args:
        pred_dict: mapping ``slide_id -> array`` of five class probabilities
            followed by the predicted class id.
        name_model: model identifier used in the output file name.

    The CSV has the columns ``slide_id``, ``prob_stage_0`` .. ``prob_stage_4``
    and ``stage_pred``. An empty ``pred_dict`` produces a header-only file
    instead of raising an IndexError.
    """
    value_columns = ['prob_stage_0', 'prob_stage_1', 'prob_stage_2',
                     'prob_stage_3', 'prob_stage_4', 'stage_pred']
    # reshape keeps the 2-D layout even when there are no predictions at all.
    preds = np.array(list(pred_dict.values()), dtype=float).reshape(-1, len(value_columns))

    frame = pd.DataFrame(preds, columns=value_columns)
    frame.insert(0, 'slide_id', list(pred_dict.keys()))

    # Create the output directory on first use.
    os.makedirs('predictions', exist_ok=True)
    frame.to_csv('predictions/predictions_{}_{}_{}.csv'.format(name_model, batch_size, num_epochs), index=False)

# Score the test slides with each backbone listed in `models` and store one
# predictions CSV per backbone.
for model_ in models:
    print('Predicting for {}'.format(model_))
    predictions_dict = run_inference(
        test_slides_list,
        to_device(build_model(model_)),
    )
    save_predictions(predictions_dict, model_)

In [None]:
home = os.getenv("HOME")
contest_dir = os.path.join(home, "datasets", "brca-psj-path", "contest-phase-2")
# Manifest linking every hold-out slide to its biopsy.
slide_manifest = pd.read_csv(os.path.join(contest_dir, "slide-manifest-holdout.csv"))

# Iterate over `models`: the prediction cell writes slide-level CSVs only for
# these backbones, so looping over `models_` (which also contains
# 'shared_model') failed on a missing file.
for model_ in models:
    file_ = 'predictions/predictions_{}_{}_{}.csv'.format(model_, batch_size, num_epochs)
    predictions_model = pd.read_csv(file_)

    # Attach the biopsy id to each slide, drop the slide-level identifiers and
    # average all remaining columns per biopsy.
    biopsy_stage_prediction = (
        predictions_model
        .merge(slide_manifest, on='slide_id')
        .drop(columns=['slide_id','slide_path','patient_ngsci_id'])
        .groupby("biopsy_id")
        .mean()
        .reset_index()
    )
    biopsy_stage_prediction.to_csv('predictions/final_predictions_{}_{}_{}.csv'.format(model_, batch_size, num_epochs), index=False)

In [None]:
from scipy.stats import gmean, tmean
import numpy as np

columns = ['prob_stage_0', 'prob_stage_1', 'prob_stage_2', 'prob_stage_3', 'prob_stage_4']


def _read_final_predictions(model_name):
    """Load the biopsy-level prediction table stored for ``model_name``."""
    return pd.read_csv('predictions/final_predictions_{}_{}_{}.csv'.format(model_name, batch_size, num_epochs))


def get_model_column(model_name, column):
    """Return one column of a model's biopsy-level predictions as a list."""
    return _read_final_predictions(model_name)[column].tolist()


def column_to_index(x):
    """Return the position of probability column ``x`` within ``columns``."""
    return columns.index(x)


# The efficientnet table defines the reference order of the biopsies.
data = _read_final_predictions("efficientnet")
biopsy_ids = data.biopsy_id.tolist()

# One (n_biopsies, n_classes) matrix per backbone, rows aligned on biopsy_id.
# Only `models` is used: the previous loop over `models_` yielded nine entries
# while the aggregation unpacked exactly eight, which raised a ValueError.
stacked = np.stack([
    _read_final_predictions(name_model)
    .set_index('biopsy_id')
    .loc[biopsy_ids, columns]
    .to_numpy(dtype=float)
    for name_model in models
])

# Ensemble score = geometric mean * arithmetic mean across the backbones
# (scipy's tmean without limits is the plain arithmetic mean).
combined = gmean(stacked, axis=0) * stacked.mean(axis=0)

final_frame = pd.DataFrame(combined, columns=columns)
# Re-normalise so the five stage probabilities of each biopsy sum to one.
final_frame = final_frame.div(final_frame.sum(axis=1), axis=0)
final_frame.insert(0, 'biopsy_id', biopsy_ids)
final_frame['stage_pred'] = final_frame[columns].to_numpy().argmax(axis=1)

final_frame.to_csv('predictions/filename.csv', index=False, header=False)
final_frame.head()

In [None]:
import ngsci
# Path of the CSV handed to the contest submission API.
# NOTE(review): the ensemble cell writes 'predictions/filename.csv'; no cell
# visible in this notebook writes 'predictions/wombcare_final.csv' -- confirm
# that this file exists and is the intended submission before running.
submission_file = 'predictions/wombcare_final.csv'
ngsci.submit_contest_entry(submission_file, description="wombcare_best")