In [None]:
import torchvision.models as models
import mil_dsmil_softmax as mil

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable

from torch.utils.tensorboard import SummaryWriter

import sys
import argparse
import os, glob
import pandas as pd
import csv
import numpy as np
import random
import math
from sklearn.utils import shuffle
from sklearn.metrics import roc_curve, roc_auc_score
from PIL import Image
from collections import OrderedDict
import matplotlib.pyplot as plt
from collections import OrderedDict

import torchvision.transforms.functional as VF
from torchvision import transforms, utils

# Restrict this process to GPU index 0 and make it the current CUDA device when
# CUDA is available. Setting the env var works here because CUDA is initialized lazily.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(0)

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Notebook configuration: an argparse.Namespace stands in for parsed CLI args.
args = argparse.Namespace(
    num_classes=2,
    num_feats=512,
    num_epochs=90,
    batch_size=400,
    num_workers=4,
    top_k=8,
    lr=0.0001,
    patch_size=224,
    img_channel=3,
    # class_weights=[1, 1, 1],
)

In [None]:
class FCLayer(nn.Module):
    """Instance classifier: one linear layer applied to every feature row.

    forward(feats) returns a pair ``(feats, logits)``: the input features
    unchanged, plus the linear layer's output.
    """

    def __init__(self, in_size, out_size=1):
        super().__init__()
        # Kept inside nn.Sequential so parameters are named 'fc.0.weight' /
        # 'fc.0.bias', matching previously saved state_dicts.
        linear = nn.Linear(in_size, out_size)
        self.fc = nn.Sequential(linear)

    def forward(self, feats):
        logits = self.fc(feats)
        return feats, logits
# Assemble the model: an instance-level classifier (FCLayer) and a bag-level
# classifier (mil.BClassifier) are combined by mil.MILNet, all placed on the GPU.
i_classifier = FCLayer(in_size=args.num_feats, out_size=args.num_classes).cuda()
b_classifier = mil.BClassifier(input_size=args.num_feats, output_class=args.num_classes).cuda()
milnet = mil.MILNet(i_classifier, b_classifier).cuda()
print(milnet)

# Loss on raw logits against float multi-hot targets (one BCE term per class).
criterion = nn.BCEWithLogitsLoss()

In [None]:
# state_dict = torch.load('tcga-ds-feats-simclr-10x.pth')
# milnet.load_state_dict(state_dict, strict=False)

In [None]:
def get_bag_feats(csv_file_df, feats_dir='data_feats_simclr_10x', num_classes=None):
    """Load one bag's instance features and its one-hot label.

    Args:
        csv_file_df: A row whose first entry is a '/'-separated path and whose
            second entry is an integer class index. The second path component
            (index 1) is used as the feature-file stem.
        feats_dir: Directory containing one '<stem>.csv' feature table per bag.
            Defaults to the directory the notebook was written against.
        num_classes: Length of the one-hot label; defaults to args.num_classes.

    Returns:
        (label, feats): a float numpy vector of length num_classes with a 1 at
        the bag's class index, and a 2-D numpy array holding the rows of the
        feature CSV in randomly shuffled order.
    """
    if num_classes is None:
        num_classes = args.num_classes
    # Index 1 (not the basename) matches the original layout; presumably paths
    # look like '<class_dir>/<slide_id>' — TODO confirm against LUAD.csv/LUSC.csv.
    slide_name = csv_file_df.iloc[0].split('/')[1]
    feats_csv_path = os.path.join(feats_dir, slide_name + '.csv')
    df = pd.read_csv(feats_csv_path)
    # Randomize instance order within the bag.
    feats = shuffle(df).reset_index(drop=True).to_numpy()
    label = np.zeros(num_classes)
    label[int(csv_file_df.iloc[1])] = 1
    return label, feats

In [None]:
def train(train_df):
    """Run one training epoch, updating milnet after every bag.

    Bags are visited once each in a freshly shuffled order (one bag per
    optimizer step). The per-bag loss is the mean of the bag-stream loss and
    the loss on the per-class max over instance predictions.

    Relies on notebook globals: milnet, optimizer, criterion, args, get_bag_feats.

    Args:
        train_df: DataFrame with one row per bag, as consumed by get_bag_feats.

    Returns:
        Mean per-bag training loss (0.0 if train_df is empty).
    """
    # Ensure training-mode behavior (e.g. dropout) even if a prior eval left it off.
    milnet.train()
    # Build inputs on whatever device the model lives on instead of relying on
    # a global Tensor alias defined later in the notebook.
    device = next(milnet.parameters()).device
    csvs = shuffle(train_df).reset_index(drop=True)
    total_loss = 0.0
    for i in range(len(csvs)):
        optimizer.zero_grad()
        label, feats = get_bag_feats(csvs.iloc[i])
        bag_label = torch.as_tensor(label, dtype=torch.float32, device=device).view(1, -1)
        bag_feats = torch.as_tensor(feats, dtype=torch.float32, device=device).view(-1, args.num_feats)
        ins_prediction, bag_prediction, _, _ = milnet(bag_feats)
        # Max over dim 0 — assumes ins_prediction is (num_instances, num_classes);
        # TODO confirm against mil.MILNet.
        max_prediction, _ = torch.max(ins_prediction, 0)
        bag_loss = criterion(bag_prediction.view(1, -1), bag_label)
        max_loss = criterion(max_prediction.view(1, -1), bag_label)
        loss = 0.5 * bag_loss + 0.5 * max_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        sys.stdout.write('\r[%d/%d] bag loss: %.4f, %.4f' % (i, len(csvs), max_loss.item(), bag_loss.item()))
    return total_loss / max(len(csvs), 1)

In [None]:
def test(test_df):
    """Evaluate milnet on every bag in `test_df` without updating weights.

    The model is put in eval mode for the pass and restored to its previous
    mode afterwards, so this is safe to call between training epochs.

    Relies on notebook globals: milnet, criterion, args, get_bag_feats,
    multi_label_roc.

    Args:
        test_df: DataFrame with one row per bag, as consumed by get_bag_feats.

    Returns:
        (mean_loss, avg_score, aucs, thresholds_optimal):
            mean_loss: mean over bags of 0.5 * bag_loss + 0.5 * max_loss.
            avg_score: fraction of bags whose thresholded prediction vector
                equals the label vector exactly.
            aucs: per-class ROC AUC from multi_label_roc.
            thresholds_optimal: per-class decision thresholds from multi_label_roc.
    """
    was_training = milnet.training
    milnet.eval()
    device = next(milnet.parameters()).device
    total_loss = 0.0
    test_labels = []
    test_predictions = []
    try:
        with torch.no_grad():
            for i in range(len(test_df)):
                label, feats = get_bag_feats(test_df.iloc[i])
                bag_label = torch.as_tensor(label, dtype=torch.float32, device=device).view(1, -1)
                bag_feats = torch.as_tensor(feats, dtype=torch.float32, device=device).view(-1, args.num_feats)
                ins_prediction, bag_prediction, _, _ = milnet(bag_feats)
                max_prediction, _ = torch.max(ins_prediction, 0)
                bag_loss = criterion(bag_prediction.view(1, -1), bag_label)
                max_loss = criterion(max_prediction.view(1, -1), bag_label)
                loss = 0.5 * bag_loss + 0.5 * max_loss
                total_loss += loss.item()
                sys.stdout.write('\r[%d/%d] bag loss: %.4f, %.4f' % (i, len(test_df), max_loss.item(), bag_loss.item()))
                test_labels.append(label)
                # Scoring uses only the bag stream: the max-instance term is weighted 0.0.
                test_predictions.append((0.0 * torch.sigmoid(max_prediction) + 1.0 * torch.sigmoid(bag_prediction)).squeeze().cpu().numpy())
    finally:
        milnet.train(was_training)
    test_labels = np.array(test_labels)
    test_predictions = np.array(test_predictions)
    auc_value, _, thresholds_optimal = multi_label_roc(test_labels, test_predictions, args.num_classes, pos_label=1)
    # Binarize each class column at its own threshold (thresholds broadcast across columns).
    binary_predictions = (test_predictions >= np.asarray(thresholds_optimal)).astype(float)
    # A bag counts as correct only if every class decision matches its label.
    avg_score = float(np.mean(np.all(binary_predictions == test_labels, axis=1)))
    return total_loss / len(test_df), avg_score, auc_value, thresholds_optimal

In [None]:
def multi_label_roc(labels, predictions, num_classes, pos_label=1):
    """Per-class (one-vs-rest) ROC analysis.

    Args:
        labels: 2-D array indexed as labels[:, c]; column c holds ground truth
            for class c.
        predictions: 2-D array of scores with the same column layout.
        num_classes: Number of leading columns to evaluate.
        pos_label: Value in `labels` treated as the positive class.

    Returns:
        (aucs, thresholds, thresholds_optimal): per-class ROC AUC, the full
        threshold array from roc_curve for each class, and the threshold chosen
        by optimal_thresh for each class.
    """
    aucs = []
    thresholds = []
    thresholds_optimal = []
    for c in range(num_classes):
        label = labels[:, c]
        prediction = predictions[:, c]
        fpr, tpr, threshold = roc_curve(label, prediction, pos_label=pos_label)
        _, _, threshold_optimal = optimal_thresh(fpr, tpr, threshold)
        # roc_auc_score has no pos_label argument; binarize so the same class
        # is treated as positive as in roc_curve.
        aucs.append(roc_auc_score(label == pos_label, prediction))
        thresholds.append(threshold)
        thresholds_optimal.append(threshold_optimal)
    return aucs, thresholds, thresholds_optimal

def optimal_thresh(fpr, tpr, thresholds, p=0):
    """Choose the ROC operating point minimizing (fpr - tpr) - p * tpr / (fpr + tpr + 1).

    With the default p=0 this picks the point maximizing Youden's J (tpr - fpr);
    p > 0 adds a bonus that favors higher tpr.

    Args:
        fpr, tpr, thresholds: Parallel numpy arrays, e.g. as returned by roc_curve.
        p: Weight of the tpr bonus term.

    Returns:
        (fpr, tpr, threshold) at the selected index (first index on ties).
    """
    bonus = p * tpr / (fpr + tpr + 1)
    objective = fpr - tpr - bonus
    best = np.argmin(objective, axis=0)
    return fpr[best], tpr[best], thresholds[best]

In [None]:
# TensorBoard log directory (the training loop also uses it as a tag prefix).
runs = 'logs/tcga-ds-feats-simclr-10x'
writer = SummaryWriter(log_dir=runs)
# Best test score seen so far; a checkpoint is saved whenever it improves.
optimal_score = 0

In [None]:
# Adam alternative. NOTE: the SGD cell that follows reassigns `optimizer`, so
# this instance is not the one used for training on a top-to-bottom run.
# Learning rate comes from args.lr instead of a duplicated literal.
optimizer = torch.optim.Adam(milnet.parameters(), lr=args.lr, betas=(0.5, 0.9), weight_decay=5e-3)

In [None]:
# Reset the best-so-far test score so a fresh training run checkpoints on its first improvement.
optimal_score = 0

In [None]:
# SGD with momentum over all milnet parameters; replaces any optimizer created earlier.
optimizer = torch.optim.SGD(
    params=milnet.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-3,
)

In [None]:
# Deterministic 80/20 split per class: the first 80% of rows in each CSV train,
# the remainder test. Then train and evaluate for args.num_epochs epochs,
# logging to TensorBoard and checkpointing on the best test avg_score.
bags_LUAD_path = pd.read_csv('LUAD.csv')
bags_LUSC_path = pd.read_csv('LUSC.csv')
cut_LUAD = int(len(bags_LUAD_path) * 0.8)
cut_LUSC = int(len(bags_LUSC_path) * 0.8)
train_bags_LUAD_path = bags_LUAD_path.iloc[0:cut_LUAD, :]
train_bags_LUSC_path = bags_LUSC_path.iloc[0:cut_LUSC, :]
test_bags_LUAD_path = bags_LUAD_path.iloc[cut_LUAD:, :]
test_bags_LUSC_path = bags_LUSC_path.iloc[cut_LUSC:, :]
# CUDA float tensor constructor alias, kept for any cell that builds inputs via Tensor(...).
Tensor = torch.cuda.FloatTensor
# Honor the configured epoch count (previously the loop ran 999 epochs while
# printing "/args.num_epochs" as the total).
for epoch in range(1, args.num_epochs + 1):
    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported equivalent.
    train_path = shuffle(pd.concat([train_bags_LUAD_path, train_bags_LUSC_path])).reset_index(drop=True)
    test_path = shuffle(pd.concat([test_bags_LUAD_path, test_bags_LUSC_path])).reset_index(drop=True)
    train_loss_bag = train(train_path)  # iterate all bags
    test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path)
    print('\r Epoch [%d/%d] train loss: %.4f, test loss: %.4f, average score: %.4f, auc_LUAD: %.4f, auc_LUSC: %.4f' %
          (epoch, args.num_epochs, train_loss_bag, test_loss_bag, avg_score, aucs[0], aucs[1]))
    print(thresholds_optimal)
    writer.add_scalar(runs + '/loss/train', train_loss_bag, epoch)
    writer.add_scalar(runs + '/loss/test', test_loss_bag, epoch)
    writer.add_scalar(runs + '/score', avg_score, epoch)
    writer.add_scalar(runs + '/auc_LUAD', aucs[0], epoch)
    writer.add_scalar(runs + '/auc_LUSC', aucs[1], epoch)
    current_score = avg_score
    if optimal_score < current_score:
        optimal_score = current_score
        torch.save(milnet.state_dict(), 'tcga-ds-feats-simclr.pth')
        print('saved')
# Make sure all logged scalars reach disk even if the kernel keeps running.
writer.flush()