In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

In [None]:
from psutil import virtual_memory

# Total machine memory in (decimal) gigabytes; 20 GB is the threshold used to call it "high-RAM".
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

if ram_gb >= 20:
    print('You are using a high-RAM runtime!')
else:
    print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
    print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
    print('re-execute this cell.')

In [None]:
from __future__ import print_function, division
import os
import pickle
import sys
import time
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.metrics import auc as auc_area_calc
from torchvision import datasets, transforms, models
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, confusion_matrix, roc_curve
import copy
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import tensorflow as tf
import datetime, os

# Dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import time
import copy
from torchvision import models, transforms
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

In [None]:
# !pip install googledrivedownloader
# from google_drive_downloader import GoogleDriveDownloader as gdd
# gdd.download_file_from_google_drive(file_id='13vFqjgBqTWLuBwAVSOmZ8HotckfIZjor',
#                                     dest_path='/tmp/tcga/filtered.zip',
#                                     unzip=True)
# !rm /tmp/tcga/filtered.zip

# https://drive.google.com/file/d/1SmsFlYp2CHndQ4Jx_ngHtQfzIT6_4wfO/view?usp=sharing
!pip install googledrivedownloader
from google_drive_downloader import GoogleDriveDownloader as gdd
gdd.download_file_from_google_drive(file_id='1SmsFlYp2CHndQ4Jx_ngHtQfzIT6_4wfO',
                                    dest_path='/tmp/tcga/filtered.zip',
                                    unzip=True)
!rm /tmp/tcga/filtered.zip

In [None]:
from google_drive_downloader import GoogleDriveDownloader as gdd

# Fetch (without unzipping) the plain-text file that the next cell reads as a list of slide folders.
gdd.download_file_from_google_drive(
    file_id='1ae3IfTucMSsUNEnC-HNbAaycDjIkI1T0',
    dest_path='/tmp/tcga/has_been_moved_and_filtered',
    unzip=False,
)

In [None]:
# filtered_tiles_output_folder = /tmp/tcga/filtered_tiles/
# One slide-folder path per line (paths are re-rooted under /tmp/tcga in the next cell).
has_been_filtered_filename = "/tmp/tcga/has_been_moved_and_filtered"
# `with` closes the handle; the previous bare open(...).read() leaked it.
with open(has_been_filtered_filename, 'r') as manifest_file:
    data_folders = manifest_file.read().splitlines()

In [None]:
# Re-root the folder paths recorded under the original output folder onto /tmp/tcga/.
# NOTE: os.path.join discards "/tmp/tcga/" whenever the replaced path is already absolute.
slides_folders = [
    os.path.join("/tmp/tcga/",
                 folder.replace("/Users/shunshao/Documents/GitHub/tcga_segmentation/output_folder/", "/tmp/tcga/"))
    for folder in data_folders
]

In [None]:
# Number of slide folders listed in the manifest.
num_slides = len(slides_folders)
num_slides

In [None]:
# Preview only the first few paths; the full list can be long and bloats the saved notebook.
slides_folders[:5]

In [None]:
# Dataset
def pil_loader(path):
    """Load the image file at ``path`` and return it converted to an RGB PIL image.

    The file is opened explicitly so the handle is closed by ``with`` (avoids a
    ResourceWarning, see https://github.com/python-pillow/Pillow/issues/835); the
    conversion happens before the file is closed.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image


class Dataset(torch.utils.data.dataset.Dataset):
    """Slide-level multiple-instance dataset: one item per (kept) slide folder.

    Each item is a bag of tile images (all tiles, or ``max_bag_size`` tiles sampled with
    replacement) stacked into one tensor, plus the slide's integer label. While
    ``retrieve_tiles_ids_with_images`` is True (the default), items also carry the tile file
    names, the slide's summary lines and the slide folder path.

    Every slide folder must contain exactly one ``label.txt``, ``case_id.txt`` and
    ``summary.txt``; tiles are the ``.jpeg``/``.jpg``/``.png`` files next to them.
    """

    def __init__(self, slides_folders, model_input_size, is_training, max_bag_size, logger, max_dataset_size=None,
                 with_data_augmentation=True, seed=123, normalization_mean=None, normalization_std=None):
        """
        :param slides_folders: list of abs paths of slide folders (each containing images and
            label/case_id/summary files)
        :param model_input_size: expected model input size (stored only; no cropping is applied here)
        :param is_training: True if is training, else False (stored only; augmentation is driven by
            ``with_data_augmentation``)
        :param max_bag_size: number of tiles sampled with replacement per bag; None returns every tile
        :param logger: object with a ``warning`` method used to report discarded slides; None prints instead
        :param max_dataset_size: if not None, only the first ``max_dataset_size`` folders are scanned
        :param with_data_augmentation: add random flips and color jitter before normalization
        :param seed: stored on the instance; tile sampling uses the global ``random`` state
        :param normalization_mean: per-channel mean for ``transforms.Normalize`` (default (0, 0, 0))
        :param normalization_std: per-channel std for ``transforms.Normalize`` (default (1, 1, 1))
        :raises FileExistsError: if a slide folder does not exist
        """
        for slide_folder in slides_folders:
            if not os.path.exists(slide_folder):
                # FileNotFoundError would describe this better; FileExistsError is kept so that
                # existing `except FileExistsError` handlers keep working.
                raise FileExistsError('parent dataset folder %s does not exist' % slide_folder)

        self.slides_folders = slides_folders
        self.model_input_size = model_input_size
        self.max_bag_size = max_bag_size
        self.max_dataset_size = max_dataset_size
        self.is_training = is_training
        # Must be stored: load_data() reports discarded slides through it (it used to be
        # commented out, which turned the warning into an AttributeError).
        self.logger = logger

        self.with_data_augmentation = with_data_augmentation
        normalization_mean = (0, 0, 0) if normalization_mean is None else normalization_mean
        normalization_std = (1, 1, 1) if normalization_std is None else normalization_std
        self.transform = self._define_data_transforms(normalization_mean, normalization_std)

        self.seed = seed

        loaded = self.load_data()
        self.slides_ids = loaded[0]  # slide ids (folder basenames)
        self.slides_labels = loaded[1]  # int labels
        self.slides_summaries = loaded[2]  # summary.txt lines of each slide
        self.slides_cases = loaded[3]  # case IDs
        self.slides_images_filepaths = loaded[4]  # absolute tile paths of each slide

        containers_lengths = [len(container) for container in loaded]
        assert len(set(containers_lengths)) == 1, 'mismatch in slides containers lengths %s' % (
            ' '.join(str(length) for length in containers_lengths))

        self.retrieve_tiles_ids_with_images = True  # True will return bag of images and associated tiles ids

    def _define_data_transforms(self, mean, std):
        """Build the tensor transform: optional flips/color jitter, then ToTensor and Normalize."""
        if self.with_data_augmentation:
            return transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.01),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    def _warn(self, message):
        """Send ``message`` to ``self.logger.warning``, or print it when no logger is set."""
        logger = getattr(self, 'logger', None)
        if logger is None:
            print(message)
        else:
            logger.warning(message)

    def load_data(self):
        """Scan ``self.slides_folders`` and load each slide's label, case id, summary and tile paths.

        Slides without any image file are skipped with a warning. As a side effect,
        ``self.slides_kept_folders`` is set to the folders that were kept, aligned index-by-index
        with the returned containers.

        :return: (slides_ids, slides_labels, slides_summaries, slides_cases, slides_images_filepaths);
            ids and labels are numpy arrays, the other three are lists
        :raises AssertionError: if a folder does not hold exactly one label/case_id/summary file
        """
        slides_ids, slides_labels, slides_summaries, slides_cases, slides_images_filepaths = [], [], [], [], []
        kept_folders = []

        # Name of expected non-image files for all slides folders
        label_filename = 'label.txt'
        case_id_filename = 'case_id.txt'
        summary_filename = 'summary.txt'

        for i, slide_folder in enumerate(tqdm(self.slides_folders)):
            if self.max_dataset_size is not None and i + 1 > self.max_dataset_size:
                break
            all_slide_files = [f for f in os.listdir(slide_folder)
                               if os.path.isfile(os.path.join(slide_folder, f))]

            # Expect 1 and only 1 of each metadata file
            for data_filename in (label_filename, case_id_filename, summary_filename):
                found = all_slide_files.count(data_filename)
                assert found == 1, 'slide %s: found %d files for %s, expected 1' % (slide_folder, found,
                                                                                     data_filename)

            with open(os.path.join(slide_folder, label_filename), 'r') as f:
                slide_label = int(f.read())
            with open(os.path.join(slide_folder, case_id_filename), 'r') as f:
                slide_case_id = f.read()
            with open(os.path.join(slide_folder, summary_filename), 'r') as f:
                slide_original_tiles = f.read().splitlines()

            # All filtered (not-background) images of the slide
            slide_images_filenames = [f for f in all_slide_files if f.endswith(('.jpeg', '.jpg', '.png'))]
            if not slide_images_filenames:
                self._warn('Discarding slide %s of class %d because there are no images' %
                           (slide_folder, slide_label))
                continue

            slides_ids.append(os.path.basename(slide_folder))
            slides_labels.append(slide_label)
            slides_summaries.append(slide_original_tiles)
            slides_cases.append(slide_case_id)
            slides_images_filepaths.append(
                [os.path.abspath(os.path.join(slide_folder, f)) for f in slide_images_filenames])
            kept_folders.append(slide_folder)

        self.slides_kept_folders = kept_folders
        print("end\n")

        return np.asarray(slides_ids), np.asarray(slides_labels), slides_summaries, slides_cases, \
            slides_images_filepaths

    def _get_slide_instances(self, item):
        """Load all tiles, or ``max_bag_size`` tiles sampled with replacement, of slide ``item``.

        :return: list of RGB PIL images; when ``retrieve_tiles_ids_with_images`` is True,
            (images, tile file names, summary lines, slide folder) instead
        """
        slide_images_filepaths = self.slides_images_filepaths[item]
        if self.max_bag_size is not None:
            slide_images_filepaths = random.choices(slide_images_filepaths, k=self.max_bag_size)

        bag_images = [pil_loader(filepath) for filepath in slide_images_filepaths]
        if self.retrieve_tiles_ids_with_images:
            # slides_kept_folders (not the raw slides_folders) stays aligned with `item`
            # even after slides without images were discarded.
            return (bag_images, [os.path.basename(filepath) for filepath in slide_images_filepaths],
                    self.slides_summaries[item], self.slides_kept_folders[item])
        return bag_images

    def __getitem__(self, item):
        """Return (bag tensor, label) or, with tile ids enabled, (bag, label, tiles_ids, summary, folder)."""
        with_ids = self.retrieve_tiles_ids_with_images
        instances = self._get_slide_instances(item)
        bag_images = instances[0] if with_ids else instances

        slide_instances = torch.stack([self.transform(image) for image in bag_images])
        slide_label = self.slides_labels[item]
        if with_ids:
            _, tiles_ids, slide_summary, slide_folder = instances
            return slide_instances, slide_label, tiles_ids, slide_summary, slide_folder
        return slide_instances, slide_label

    def __len__(self):
        return len(self.slides_labels)

In [None]:
import random

def split_svs_samples_casewise(svs_files, associated_cases_ids, val_size, test_size, seed=123):
    """Split samples into train/val/test so that all samples of a case land in the same set.

    Unique case IDs are shuffled after ``random.seed(seed)`` (this reseeds the global ``random``
    module) and cut into fractions ``1 - val_size - test_size``, ``val_size`` and the remainder.

    :param svs_files: samples to split (e.g. file paths or dataset indexes)
    :param associated_cases_ids: hashable case ID of each sample, same length as ``svs_files``
    :param val_size: fraction of cases used for validation
    :param test_size: fraction of cases used for testing (the test set also receives the cases
        left over by int truncation)
    :param seed: seed for the shuffle
    :return: (train_svs_files, val_svs_files, test_svs_files), each keeping input order
    """
    assert len(svs_files) == len(associated_cases_ids), 'Expected same number of SVS files than associated case ID'
    random.seed(seed)
    train_size = 1. - val_size - test_size

    # Sort before shuffling: the iteration order of a set of str depends on the per-process
    # hash seed (PYTHONHASHSEED), so shuffling raw set order produced a different split in every
    # interpreter session despite the fixed seed. key=repr keeps mixed hashable ID types sortable.
    unique_cases_ids = sorted(set(associated_cases_ids), key=repr)
    random.shuffle(unique_cases_ids)
    total_unique_cases_ids = len(unique_cases_ids)

    train_end = int(train_size * total_unique_cases_ids)
    val_end = train_end + int(val_size * total_unique_cases_ids)
    # Sets make the per-sample membership tests below O(1) instead of list scans.
    train_cases_ids = set(unique_cases_ids[:train_end])
    val_cases_ids = set(unique_cases_ids[train_end:val_end])
    # Every remaining case (unique_cases_ids[val_end:]) goes to the test set.

    train_svs_files, val_svs_files, test_svs_files = [], [], []
    for svs_file, associated_case_id in zip(svs_files, associated_cases_ids):
        if associated_case_id in train_cases_ids:
            train_svs_files.append(svs_file)
        elif associated_case_id in val_cases_ids:
            val_svs_files.append(svs_file)
        else:
            test_svs_files.append(svs_file)

    return train_svs_files, val_svs_files, test_svs_files

def build_datasets(source_slides_folders, model_input_width, hyper_parameters, logger):
    """Load every slide folder once and split it case-wise into train/val/test subsets.

    :param source_slides_folders: slide folder paths passed to ``Dataset``
    :param model_input_width: forwarded to ``Dataset`` as ``model_input_size``
    :param hyper_parameters: dict providing 'max_bag_size', 'dataset_max_size',
        'with_data_augmentation', 'seed', 'val_size' and 'test_size'
    :param logger: forwarded to ``Dataset``
    :return: (train_dataset, val_dataset, test_dataset, whole_cases_ids, whole_indexes, whole_dataset);
        the first three are ``torch.utils.data.Subset`` views. Only ``train_dataset`` applies data
        augmentation (when enabled); val/test and ``whole_dataset`` never augment.
    """
    normalization_channels_mean = (0.6387467, 0.51136744, 0.6061169)
    normalization_channels_std = (0.31200314, 0.3260718, 0.30386254)

    # Load all data once, without augmentation: this instance backs val/test and is returned.
    whole_dataset = Dataset(slides_folders=source_slides_folders, model_input_size=model_input_width,
                            is_training=False, max_bag_size=hyper_parameters['max_bag_size'],
                            logger=logger, max_dataset_size=hyper_parameters['dataset_max_size'],
                            with_data_augmentation=False,
                            seed=hyper_parameters['seed'],
                            normalization_mean=normalization_channels_mean,
                            normalization_std=normalization_channels_std)
    whole_cases_ids = whole_dataset.slides_cases
    whole_indexes = list(range(len(whole_dataset)))

    train_idx, val_idx, test_idx = split_svs_samples_casewise(whole_indexes, whole_cases_ids,
                                                              val_size=hyper_parameters['val_size'],
                                                              test_size=hyper_parameters['test_size'],
                                                              seed=hyper_parameters['seed'])

    # The three Subsets used to wrap the same Dataset, so assigning the training transform to
    # train_dataset.dataset also augmented val/test. A shallow copy reuses the already-loaded
    # slide lists (no disk re-read) but owns its own flags and transform.
    train_source = copy.copy(whole_dataset)
    train_source.is_training = True
    train_source.with_data_augmentation = hyper_parameters['with_data_augmentation']
    train_source.transform = train_source._define_data_transforms(normalization_channels_mean,
                                                                   normalization_channels_std)

    train_dataset = torch.utils.data.Subset(train_source, train_idx)
    val_dataset = torch.utils.data.Subset(whole_dataset, val_idx)
    test_dataset = torch.utils.data.Subset(whole_dataset, test_idx)

    return train_dataset, val_dataset, test_dataset, whole_cases_ids, whole_indexes, whole_dataset

# Model

In [None]:
# True when PyTorch can see a CUDA device (the cMIL methods move the model there when it can).
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
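# Debug snippet (disabled): print per-batch sizes. While retrieve_tiles_ids_with_images is True,
# each Dataset item is (data, bag_label, tiles_ids, slide_summary, slides_folder).
# for batch_idx, (data, bag_label, *_) in enumerate(train_dataloader):
#   print(f"batch_idx is {batch_idx}, bag size: {data.size(1)}, input size: {data.size(-1)}")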

In [None]:
class ResNet50(nn.Module):
    """torchvision ResNet-50 whose final layer is replaced by a 2-way linear head.

    ``forward`` returns ``(log_probabilities, logits)``; the log-probabilities are
    ``LogSoftmax`` over dim 1 of the logits.
    """

    def __init__(self, pretrained=False):
        super().__init__()

        self.pretrained = pretrained
        backbone = models.resnet50(pretrained=pretrained)
        # Keep the nn.Sequential wrapper so saved state_dicts keep the 'model.fc.0.*' keys.
        # in_features of the stock ResNet-50 head is 2048.
        backbone.fc = nn.Sequential(nn.Linear(backbone.fc.in_features, 2),)
        self.model = backbone
        self.output_activation = nn.LogSoftmax(dim=1)

    def forward(self, x):
        logit = self.model(x)
        return self.output_activation(logit), logit

## cMIL

In [None]:
class cMIL(object):
    """
    criterion: 
              max-max: select the data with highest response as the representative of the image, regardless of the image class.
              max-min: select the data with highest response as the representative of the image for class 1, and those with lowest
                       response to represent class 0.

    """
    purpose_set = {'train', 'valid', 'infer'}
    criterion_set = {'max-max', 'max-min'}

    def __init__(self, writer, criterion, train_loader, val_loader, full_loader, save_path=None, pretrained=False, save_interval=1000):
        # self.data_dict_path = data_dict_path
        self.criterion = criterion
        self.save_path = save_path

        self.pretrained = pretrained
        # self.bag_size = bag_size
        # self.batch_size = batch_size
        # self.valid_ratio = valid_ratio

        # self.gpu_num = gpu_num
        # self.worker_ratio = worker_ratio

        self.save_interval = save_interval

        # Intialize dataloader
        # self.kwargs = {'num_workers': self.gpu_num * self.worker_ratio, 'pin_memory': False} if torch.cuda.is_available() else {}
        # self.kwargs = {} if torch.cuda.is_available() else {}

        # self.dataset_train = ImageBagDataset(self.data_dict_path, bag_size=self.bag_size, purpose='train', valid_ratio=self.valid_ratio)
        # self.dataset_val = ImageBagDataset(self.data_dict_path, bag_size=self.bag_size, purpose='valid', valid_ratio=self.valid_ratio)
        # self.dataset_full = ImageBagDataset(self.data_dict_path, bag_size=self.bag_size, purpose='full')

        # self.training_set_size = train_loader.dataset.__len__()
        # self.valid_set_size = val_loader.dataset.__len__()
        # self.full_set_size = full_loader.dataset.__len__()
        
        self.training_set_size = 0
        self.valid_set_size = 0
        self.full_set_size = 0

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.full_loader = full_loader

        self.best_acc = 0
        self.best_model_wts = 0

        # Initialize model
        self.model = ResNet50(pretrained=self.pretrained)
        self.loss_func = nn.NLLLoss()

        self.writer = writer

    def load_model(self, best_model_wts):
        self.model = ResNet50(pretrained=False)
        self.model.load_state_dict(best_model_wts)

    def train(self, epoch_num, lr=1e-4):
        self.training_set_size = self.train_loader.dataset.__len__()

        # Instantiate optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        if torch.cuda.is_available():
            self.model = self.model.cuda()

            # if self.multi_gpu:
            #     self.model = nn.DataParallel(self.model)
        best_acc = 0
        for epoch in range(1, epoch_num + 1):
            epoch_start_time = time.time()

            self.forward_and_backward(purpose='train', epoch=epoch)

            print("Epoch took {} secs.".format(time.time() - epoch_start_time))
        
        print(f'best_acc is {self.best_acc}')
        
        return self.best_model_wts
        # ?
        # if isinstance(self.model, nn.DataParallel):
        #     model = self.model.module
        # else:
        #     model = self.model

        # ?
        # model_cpu = model.cpu()

    def validate(self):
        self.valid_set_size = self.val_loader.dataset.__len__()
        if torch.cuda.is_available():
            self.model = self.model.cuda()

        with torch.no_grad():
            loss, auc, acc, sens, prec = self.forward_and_backward(purpose='valid')

        return loss, auc, acc, sens, prec

    def infer(self, data_loader):
        if torch.cuda.is_available():
            self.model = self.model.cuda()

        if torch.cuda.is_available():
            self.model = self.model.cuda()

            # if self.multi_gpu:
            #     self.model = nn.DataParallel(self.model)

        with torch.no_grad():
            infer_total_list = self.forward_and_backward(purpose='infer', data_loader=data_loader)

        # return infer_total_list
        return infer_total_list
        
    def forward_and_backward(self, purpose='train', epoch=1, data_loader=None):
        init_time = time.time()

        if purpose not in self.purpose_set:
            raise ValueError("Invlid purpose given!")

        auc_list = []
        correct_list = []

        nc_total = 0
        c_total = 0
        correct = 0
        num_rows = 0
        test_loss = 0
        rep_list = []

        if purpose == 'train':
            data_loader = self.train_loader
            self.model.train()

            print()
            print("Training...")
            print("Training set size: {}".format(len(data_loader.dataset)))
            loss_tensorboard = 0
            acc_tensorboard = 0
        elif purpose == 'valid':
            data_loader = self.val_loader
            self.model.eval()

            print()
            print("Validating...")
            print("Test set size: {}".format(len(data_loader.dataset)))

            pred_total_list = []
            target_total_list = []
        elif purpose == 'infer':
            infer_total_list = []

            self.model.eval()

            print()
            print("Inferring...")
            print("Test set size: {}".format(len(data_loader.dataset)))

            pbar = tqdm(total=len(data_loader.dataset))

        # for batch_idx, (data, target, index) in enumerate(data_loader):
        # for batch_idx, (data, target, _, _) in enumerate(data_loader):
        for batch_idx, (data, target, tiles_ids, slide_summary, slides_folder) in enumerate(data_loader):
            # print(f"the slides_folder is {slides_folder}")
            # print(f"the number of tiles is {len(slide_summary)}")
            index = tiles_ids
            train_start_time = time.time()

            if torch.cuda.is_available():
                # print(data)
                # print(target)
                data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
                # data, target = data.cuda(), target.cuda()
            
            if purpose == 'train':
                data, target = data.requires_grad_(), target
                self.optimizer.zero_grad()
            
            batch_size, bag_size, input_size = data.size(0), data.size(1), data.size(-1)
            num_rows += batch_size
            # print(f"the batch size is {batch_size}, the input size is {input_size}")
            # data = data.view(batch_size * self.bag_size, 3, input_size, input_size)
            # print(f"the data before squeeze size is {data.size()}")
            data = torch.squeeze(data, 0)
            # print(f"the data after squeeze size is {data.size()}")
            output, logit = self.model(data)
            # print(f"output size is {output.size()}")
            output = output.view(batch_size, bag_size, 2)

            # index_list = index.cpu().tolist()
            selected_idx = 0
            selected_idx_list = []

            if self.criterion == 'max-max':
                # selected_idx_list = [bag[:, 1].max(0)[1] for _, bag in enumerate(output)]
                # print(f"the output have size {output.size()}")
                for _, bag in enumerate(output):
                    # print(f"the bag has size {bag.size()}")
                    # print(f"the bag is {bag}")
                    maxmax_indices = torch.nonzero(bag[:, 1] == bag[:, 1].max(0)[0])
                    # print(f"the bag bag[:, 1].max(0) {bag[:, 1].max(0)}")
                    # print(f"the bag bag[:, 1].max(0)[0] {bag[:, 1].max(0)[0]}")
                    # print(f"the bag bag filtered {bag[:, 1] == bag[:, 1].max(0)[0]}")
                    # print(f"the maxmax_indices is {maxmax_indices}, the length the of maxmax_indices is {len(maxmax_indices)}")
                    pick_idx = np.random.randint(len(maxmax_indices))
                    # print(f"the pick_idx is {pick_idx}")
                    selected_idx = maxmax_indices[pick_idx].squeeze(0)
                    # print(f"the selected_idx is {selected_idx}")
                    selected_idx_list.append(selected_idx)
            else:
                # selected_idx_list = [bag[:, 1].max(0)[1] if target[i].item() == 1 else bag[:, 1].min(0)[1] \
                #                      for i, bag in enumerate(output)]
                # print(f"the output have size {output.size()}")
                for i, bag in enumerate(output):
                    # print(f"the bag has size {bag.size()}")
                    maxmin_indices = torch.nonzero(bag[:, 1] == bag[:, 1].max(0)[0]) if target[i].item() == 1 \
                                else torch.nonzero(bag[:, 1] == bag[:, 1].min(0)[0])

                    pick_idx = np.random.randint(len(maxmin_indices))
                    selected_idx = maxmin_indices[pick_idx].squeeze(0)
                    selected_idx_list.append(selected_idx)


            selected_torch = torch.stack([output[i][idx] for i, idx in enumerate(selected_idx_list)])

            prediction = selected_torch.max(1)[1]
            correct = (prediction == target).sum().cpu().item()

            pred_cpu = prediction.cpu().detach().numpy()
            # print(f"pred_cpu is {pred_cpu}")
            target_cpu = target.cpu().detach().numpy()
            # print(f"target_cpu is {target_cpu}")
            selected_exp_cpu = torch.exp(selected_torch[:, 1]).cpu().detach().numpy()
            # print(f"selected_idx_list is {selected_idx_list}")
            # index_list = [index[i][0] for i in range(len(index))]
            # print(index_list)
            # return index_list, selected_exp_cpu
            if purpose == 'infer':
                pick_list = list(zip(slides_folder, tiles_ids[selected_idx], pred_cpu.tolist(), target_cpu.tolist()))
                # print(f"The selected_idx is {selected_idx}")
                # print(f"the image is {tiles_ids[selected_idx]}")
                # print(f"the slides_foler is {slides_folder}")
                # print(f"the pick_list is {pick_list}")
                print()
                file_path = slides_folder[0] 
                file_path += '/'
                file_path += tiles_ids[selected_idx][0]
                print(file_path)
                print(f"The most probable cancer slides predicted by {purpose} in the case {slides_folder} is {tiles_ids[selected_idx]}, display as below")
                display(pil_loader(file_path))
                infer_total_list += pick_list
                pbar.update(batch_size)
            else:
                loss_sum = self.loss_func(selected_torch, target)

            if purpose == 'valid':
                try:
                    auc = roc_auc_score(target_cpu, selected_exp_cpu)
                except:
                    pass
                else:
                    auc_list.append(auc)

                test_loss += loss_sum
                correct_list.append(correct)

                target_list = target_cpu.tolist()
                pred_list = pred_cpu.tolist()

                non_cancer_count, cancer_count = target_list.count(0), target_list.count(1)

                nc_total += non_cancer_count
                c_total += cancer_count

                target_total_list += target_list
                pred_total_list += pred_list

            if purpose == 'train':
                loss_sum.backward()
                self.optimizer.step()

                print('Epoch({}-{}) [{}/{} ({:.0f}%)] Loss: {:.6f} Acc: {:.3f} Output: {} Prediction: {} Label: {} ({} sec/step)'.format(
                    epoch, batch_idx,
                    batch_idx * batch_size, self.training_set_size, 100. * batch_idx * batch_size / self.training_set_size, 
                    loss_sum.data.item(),
                    correct / batch_size,
                    list(map(lambda x: round(x, 2), selected_exp_cpu.tolist()[:4])),
                    list(map(lambda x: round(x, 2), pred_cpu.tolist()[:4])),
                    target_cpu.tolist()[:4],
                    round(time.time() - train_start_time, 3)
                    )
                )

                loss_tensorboard += loss_sum.data.item()
                acc_tensorboard += correct

            elif purpose == 'valid':
                print('Val/Infer [{}/{} ({:.0f}%)] Loss: {:.6f} Acc: {:.6f} Output: {} Prediction: {} Label: {} ({} sec/step)'.format(
                    batch_idx * batch_size, self.valid_set_size, 100. * batch_idx * batch_size / self.valid_set_size, 
                    loss_sum.data.item(),
                    correct / batch_size,
                    list(map(lambda x: round(x, 2), selected_exp_cpu.tolist()[:4])),
                    list(map(lambda x: round(x, 2), pred_cpu.tolist()[:4])),
                    target_cpu.tolist()[:4],
                    round(time.time() - train_start_time, 3)
                    )
                )
            else:
                pass

        if purpose == 'valid':
            test_loss /= num_rows
            auc = np.mean(auc_list)
            num_correct = np.sum(correct_list)
            acc = (100. * num_correct) / num_rows
            # print(f"target is {np.array(target_total_list)}")
            # print(f"pred is {np.array(pred_total_list)}")
            sens = precision_score(np.array(target_total_list), np.array(pred_total_list))
            cm = confusion_matrix(np.array(target_total_list), np.array(pred_total_list))
            tn = cm[0, 0]
            tp = cm[1, 1]
            fn = cm[1, 0]
            fp = cm[0, 1]
            spec = tn / (tn + fp)
            prec = precision_score(np.array(target_total_list), np.array(pred_total_list))

            # fpr = fp/(fp + tn)
            # tpr = sens

            # roc_auc = auc(fpr, tpr)
            # plt.figure()
            # lw = 2
            # plt.plot(fpr, tpr, color='darkorange',lw=lw, label='ROC curve (area = %0.2f)' % auc)
            # plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
            # plt.xlim([0.0, 1.0])
            # plt.ylim([0.0, 1.05])
            # plt.xlabel('False Positive Rate')
            # plt.ylabel('True Positive Rate')
            # plt.title('Receiver operating characteristic example')
            # plt.legend(loc="lower right")
            # plt.show()
            print(f'tn is {tn}, tp is {tp}, fn is {fn}, fp is {fp}')
            print("Validation: Average loss: {:.4f}, AUC: {:.3f}, Acc: {}/{} ({:.3f}%) Recall: {:.3f} Prec: {:.3f} Spec: {}\n".format(
                  test_loss, 
                  auc, 
                  num_correct, num_rows, 
                  acc,
                  sens,
                  prec,
                  spec
                  ))
            # try:
            fpr, tpr, _ = roc_curve(np.array(target_total_list), np.array(pred_total_list))
            # print(f"Roc calculation result: fpr is {fpr}, tpr is {tpr}, {type(fpr)}, {type(tpr)}")
            # roc_fpr[:, epoch] = fpr
            # roc_tpr[:, epoch] = tpr
            roc_auc = auc_area_calc(fpr, tpr)
            plt.figure()
            lw = 2
            plt.plot(fpr, tpr, color='darkorange',lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
            plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiver operating characteristic example')
            plt.legend(loc="lower right")
            plt.show()
            # except TypeError:
            #     pass
            acc /= 100
            return test_loss.cpu().detach().numpy(), auc, acc, sens, prec

        if purpose == 'infer':
            return infer_total_list
        
        if purpose == 'train':
            val_loss, val_auc, val_acc, val_sens, val_prec = self.validate()
            with self.writer.as_default():
                loss_tensorboard /= num_rows
                error = 1 - (acc_tensorboard / num_rows)
                print(f"summary val acc is {val_acc}")
                tf.summary.scalar('Loss/Train', loss_tensorboard, step=epoch)
                tf.summary.scalar('Error/Train', error, step=epoch)
                tf.summary.scalar('Loss/Val', val_loss, step=epoch)
                tf.summary.scalar('Error/Val', (1 - val_acc), step=epoch)

            if val_acc > self.best_acc:
                print(f"acc update {val_acc}")
                self.best_acc = val_acc
                self.best_model_wts = copy.deepcopy(self.model.state_dict())

            # print("{} secs elapsed.".format(time.time() - init_time))

            # val_auc, val_acc, val_sens, val_prec = self.validate()

            # if self.save_path is not None:
            #     if isinstance(self.model, nn.DataParallel):
            #         model = self.model.module
            #     else:
            #         model = self.model
                
            #     model_cpu = model.cpu()

            #     torch.save(model_cpu.state_dict(), self.save_path + f'{epoch}_{batch_idx}_{str(val_auc)[:5]}_{str(val_acc)[:5]}_{str(val_sens)[:5]}_{str(val_prec)[:5]}.pth')

            #     if torch.cuda.is_available():
            #         self.model = self.model.cuda()

# (self, data_dict_path, criterion, save_path=None, pretrained=False, bag_size=16, batch_size=1, valid_ratio=0.01, gpu_num=1, worker_ratio=4, save_interval=1000)

#     max_max = cMIL(DATA_DICT_PATH, criterion=CRITERION, save_path=SAVE_PATH, pretrained=True, bag_size=BAG_SIZE, 
#                    batch_size=BATCH_SIZE, gpu_num=GPU_NUM, worker_ratio=WORKER_RATIO, valid_ratio=VALID_RATIO,
#                    save_interval=SAVE_INTERVAL)

# (self, criterion, train_loader, val_loader, full_loader, save_path=None, pretrained=False, bag_size=16, batch_size=1, worker_ratio=4, save_interval=1000)

#     max_max = cMIL(criterion=CRITERION, train_dataloader, val_dataloader, test_dataloader, save_path=SAVE_PATH, pretrained=True, bag_size=BAG_SIZE, 
#                 batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL)

In [None]:
# class cMIL(object):
#     """
#     criterion: 
#               max-max: select the data with highest response as the representative of the image, regardless of the image class.
#               max-min: select the data with highest response as the representative of the image for class 1, and those with lowest
#                        response to represent class 0.

#     """
#     purpose_set = {'train', 'valid', 'infer'}
#     criterion_set = {'max-max', 'max-min'}

#     def __init__(self, writer, criterion, train_loader, val_loader, full_loader, save_path=None, pretrained=False, save_interval=1000):
#         # self.data_dict_path = data_dict_path
#         self.criterion = criterion
#         self.save_path = save_path

#         self.pretrained = pretrained
#         # self.bag_size = bag_size
#         # self.batch_size = batch_size
#         # self.valid_ratio = valid_ratio

#         # self.gpu_num = gpu_num
#         # self.worker_ratio = worker_ratio

#         self.save_interval = save_interval

#         # Initialize dataloader
#         # self.kwargs = {'num_workers': self.gpu_num * self.worker_ratio, 'pin_memory': False} if torch.cuda.is_available() else {}
#         # self.kwargs = {} if torch.cuda.is_available() else {}

#         # self.dataset_train = ImageBagDataset(self.data_dict_path, bag_size=self.bag_size, purpose='train', valid_ratio=self.valid_ratio)
#         # self.dataset_val = ImageBagDataset(self.data_dict_path, bag_size=self.bag_size, purpose='valid', valid_ratio=self.valid_ratio)
#         # self.dataset_full = ImageBagDataset(self.data_dict_path, bag_size=self.bag_size, purpose='full')

#         # self.training_set_size = train_loader.dataset.__len__()
#         # self.valid_set_size = val_loader.dataset.__len__()
#         # self.full_set_size = full_loader.dataset.__len__()
        
#         self.training_set_size = 0
#         self.valid_set_size = 0
#         self.full_set_size = 0

#         self.train_loader = train_loader
#         self.val_loader = val_loader
#         self.full_loader = full_loader

#         self.best_acc = 0
#         self.best_model_wts = 0

#         # Initialize model
#         self.model = ResNet50(pretrained=self.pretrained)
#         self.loss_func = nn.NLLLoss()

#         self.writer = writer

#     def load_model(self, best_model_wts):
#         self.model = ResNet50(pretrained=False)
#         self.model.load_state_dict(best_model_wts)

#     def train(self, epoch_num, lr=1e-4):
#         self.training_set_size = self.train_loader.dataset.__len__()

#         # Instantiate optimizer
#         self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

#         if torch.cuda.is_available():
#             self.model = self.model.cuda()

#             # if self.multi_gpu:
#             #     self.model = nn.DataParallel(self.model)
#         best_acc = 0
#         for epoch in range(1, epoch_num + 1):
#             epoch_start_time = time.time()

#             self.forward_and_backward(purpose='train', epoch=epoch)

#             print("Epoch took {} secs.".format(time.time() - epoch_start_time))
        
#         print(f'best_acc is {self.best_acc}')
        
#         return self.best_model_wts
#         # ?
#         # if isinstance(self.model, nn.DataParallel):
#         #     model = self.model.module
#         # else:
#         #     model = self.model

#         # ?
#         # model_cpu = model.cpu()

#     def validate(self):
#         self.valid_set_size = self.val_loader.dataset.__len__()
#         if torch.cuda.is_available():
#             self.model = self.model.cuda()

#         with torch.no_grad():
#             loss, auc, acc, sens, prec = self.forward_and_backward(purpose='valid')

#         return loss, auc, acc, sens, prec

#     def infer(self, data_loader):
#         if torch.cuda.is_available():
#             self.model = self.model.cuda()

#         if torch.cuda.is_available():
#             self.model = self.model.cuda()

#             if self.multi_gpu:
#                 self.model = nn.DataParallel(self.model)

#         with torch.no_grad():
#             infer_total_list = self.forward_and_backward(purpose='infer', data_loader=data_loader)

#         return infer_total_list
        
#     def forward_and_backward(self, purpose='train', epoch=1, data_loader=None):
#         init_time = time.time()

#         if purpose not in self.purpose_set:
#             raise ValueError("Invlid purpose given!")

#         auc_list = []
#         correct_list = []

#         nc_total = 0
#         c_total = 0
#         correct = 0
#         num_rows = 0
#         test_loss = 0
#         rep_list = []

#         if purpose == 'train':
#             data_loader = self.train_loader
#             self.model.train()

#             print()
#             print("Training...")
#             print("Training set size: {}".format(len(data_loader.dataset)))
#             loss_tensorboard = 0
#             acc_tensorboard = 0
#         elif purpose == 'valid':
#             data_loader = self.val_loader
#             self.model.eval()

#             print()
#             print("Validating...")
#             print("Test set size: {}".format(len(data_loader.dataset)))

#             pred_total_list = []
#             target_total_list = []
#         elif purpose == 'infer':
#             infer_total_list = []

#             self.model.eval()

#             print()
#             print("Inferring...")
#             print("Test set size: {}".format(len(data_loader.dataset)))

#             pbar = tqdm(total=len(data_loader.dataset))

#         # for batch_idx, (data, target, index) in enumerate(data_loader):
#         for batch_idx, (data, target, _, _) in enumerate(data_loader):
#             train_start_time = time.time()

#             if torch.cuda.is_available():
#                 # print(data)
#                 # print(target)
#                 data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
#                 # data, target = data.cuda(), target.cuda()
            
#             if purpose == 'train':
#                 data, target = data.requires_grad_(), target
#                 self.optimizer.zero_grad()
            
#             batch_size, bag_size, input_size = data.size(0), data.size(1), data.size(-1)
#             num_rows += batch_size
#             # print(f"the batch size is {batch_size}, the input size is {input_size}")
#             # data = data.view(batch_size * self.bag_size, 3, input_size, input_size)
#             data = torch.squeeze(data, 0)
#             output, logit = self.model(data)
#             # print(f"output size is {output.size()}")
#             output = output.view(batch_size, bag_size, 2)

#             # index_list = index.cpu().tolist()
#             selected_idx_list = []

#             if self.criterion == 'max-max':
#                 # selected_idx_list = [bag[:, 1].max(0)[1] for _, bag in enumerate(output)]
#                 # print(f"the output have size {output.size()}")
#                 for _, bag in enumerate(output):
#                     print(f"the bag has size {bag.size()}")
#                     print(f"the bag is {bag}")
#                     maxmax_indices = torch.nonzero(bag[:, 1] == bag[:, 1].max(0)[0])

#                     pick_idx = np.random.randint(len(maxmax_indices))
#                     selected_idx = maxmax_indices[pick_idx].squeeze(0)

#                     selected_idx_list.append(selected_idx)
#             else:
#                 # selected_idx_list = [bag[:, 1].max(0)[1] if target[i].item() == 1 else bag[:, 1].min(0)[1] \
#                 #                      for i, bag in enumerate(output)]
#                 # print(f"the output have size {output.size()}")
#                 for i, bag in enumerate(output):
#                     # print(f"the bag has size {bag.size()}")
#                     maxmin_indices = torch.nonzero(bag[:, 1] == bag[:, 1].max(0)[0]) if target[i].item() == 1 \
#                                 else torch.nonzero(bag[:, 1] == bag[:, 1].min(0)[0])

#                     pick_idx = np.random.randint(len(maxmin_indices))
#                     selected_idx = maxmin_indices[pick_idx].squeeze(0)

#                     selected_idx_list.append(selected_idx)

#             selected_torch = torch.stack([output[i][idx] for i, idx in enumerate(selected_idx_list)])

#             prediction = selected_torch.max(1)[1]
#             correct = (prediction == target).sum().cpu().item()

#             pred_cpu = prediction.cpu().detach().numpy()
#             # print(f"pred_cpu is {pred_cpu}")
#             target_cpu = target.cpu().detach().numpy()
#             # print(f"target_cpu is {target_cpu}")
#             selected_exp_cpu = torch.exp(selected_torch[:, 1]).cpu().detach().numpy()
#             print(f"selected_idx_list is {selected_idx_list}")

#             if purpose == 'infer':
#                 pick_list = list(zip(index_list, [idx.cpu().item() for idx in selected_idx_list], pred_cpu.tolist(), target_cpu.tolist()))
#                 infer_total_list += pick_list
#                 pbar.update(batch_size)
#             else:
#                 loss_sum = self.loss_func(selected_torch, target)

#             if purpose == 'valid':
#                 try:
#                     auc = roc_auc_score(target_cpu, selected_exp_cpu)
#                 except:
#                     pass
#                 else:
#                     auc_list.append(auc)

#                 test_loss += loss_sum
#                 correct_list.append(correct)

#                 target_list = target_cpu.tolist()
#                 pred_list = pred_cpu.tolist()

#                 non_cancer_count, cancer_count = target_list.count(0), target_list.count(1)

#                 nc_total += non_cancer_count
#                 c_total += cancer_count

#                 target_total_list += target_list
#                 pred_total_list += pred_list

#             if purpose == 'train':
#                 loss_sum.backward()
#                 self.optimizer.step()

#                 print('Epoch({}-{}) [{}/{} ({:.0f}%)] Loss: {:.6f} Acc: {:.3f} Output: {} Prediction: {} Label: {} ({} sec/step)'.format(
#                     epoch, batch_idx,
#                     batch_idx * batch_size, self.training_set_size, 100. * batch_idx * batch_size / self.training_set_size, 
#                     loss_sum.data.item(),
#                     correct / batch_size,
#                     list(map(lambda x: round(x, 2), selected_exp_cpu.tolist()[:4])),
#                     list(map(lambda x: round(x, 2), pred_cpu.tolist()[:4])),
#                     target_cpu.tolist()[:4],
#                     round(time.time() - train_start_time, 3)
#                     )
#                 )

#                 loss_tensorboard += loss_sum.data.item()
#                 acc_tensorboard += correct

#             elif purpose == 'valid':
#                 print('Val/Infer [{}/{} ({:.0f}%)] Loss: {:.6f} Acc: {:.6f} Output: {} Prediction: {} Label: {} ({} sec/step)'.format(
#                     batch_idx * batch_size, self.valid_set_size, 100. * batch_idx * batch_size / self.valid_set_size, 
#                     loss_sum.data.item(),
#                     correct / batch_size,
#                     list(map(lambda x: round(x, 2), selected_exp_cpu.tolist()[:4])),
#                     list(map(lambda x: round(x, 2), pred_cpu.tolist()[:4])),
#                     target_cpu.tolist()[:4],
#                     round(time.time() - train_start_time, 3)
#                     )
#                 )
#             else:
#                 pass

#         if purpose == 'valid':
#             test_loss /= num_rows
#             auc = np.mean(auc_list)
#             num_correct = np.sum(correct_list)
#             acc = (100. * num_correct) / num_rows
#             # print(f"target is {np.array(target_total_list)}")
#             # print(f"pred is {np.array(pred_total_list)}")
#             sens = precision_score(np.array(target_total_list), np.array(pred_total_list))
#             cm = confusion_matrix(np.array(target_total_list), np.array(pred_total_list))
#             tn = cm[0, 0]
#             tp = cm[1, 1]
#             fn = cm[1, 0]
#             fp = cm[0, 1]
#             spec = tn / (tn + fp)
#             prec = precision_score(np.array(target_total_list), np.array(pred_total_list))

#             # fpr = fp/(fp + tn)
#             # tpr = sens

#             # roc_auc = auc(fpr, tpr)
#             # plt.figure()
#             # lw = 2
#             # plt.plot(fpr, tpr, color='darkorange',lw=lw, label='ROC curve (area = %0.2f)' % auc)
#             # plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
#             # plt.xlim([0.0, 1.0])
#             # plt.ylim([0.0, 1.05])
#             # plt.xlabel('False Positive Rate')
#             # plt.ylabel('True Positive Rate')
#             # plt.title('Receiver operating characteristic example')
#             # plt.legend(loc="lower right")
#             # plt.show()
#             print(f'tn is {tn}, tp is {tp}, fn is {fn}, fp is {fp}')
#             print("Validation: Average loss: {:.4f}, AUC: {:.3f}, Acc: {}/{} ({:.3f}%) Recall: {:.3f} Prec: {:.3f} Spec: {}\n".format(
#                   test_loss, 
#                   auc, 
#                   num_correct, num_rows, 
#                   acc,
#                   sens,
#                   prec,
#                   spec
#                   ))
#             # try:
#             fpr, tpr, _ = roc_curve(np.array(target_total_list), np.array(pred_total_list))
#             # print(f"Roc calculation result: fpr is {fpr}, tpr is {tpr}, {type(fpr)}, {type(tpr)}")
#             # roc_fpr[:, epoch] = fpr
#             # roc_tpr[:, epoch] = tpr
#             roc_auc = auc_area_calc(fpr, tpr)
#             plt.figure()
#             lw = 2
#             plt.plot(fpr, tpr, color='darkorange',lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
#             plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
#             plt.xlim([0.0, 1.0])
#             plt.ylim([0.0, 1.05])
#             plt.xlabel('False Positive Rate')
#             plt.ylabel('True Positive Rate')
#             plt.title('Receiver operating characteristic example')
#             plt.legend(loc="lower right")
#             plt.show()
#             # except TypeError:
#             #     pass
#             acc /= 100
#             return test_loss.cpu().detach().numpy(), auc, acc, sens, prec

#         if purpose == 'infer':
#             return infer_total_list
        
#         if purpose == 'train':
#             val_loss, val_auc, val_acc, val_sens, val_prec = self.validate()
#             with self.writer.as_default():
#                 loss_tensorboard /= num_rows
#                 error = 1 - (acc_tensorboard / num_rows)
#                 print(f"summary val acc is {val_acc}")
#                 tf.summary.scalar('Loss/Train', loss_tensorboard, step=epoch)
#                 tf.summary.scalar('Error/Train', error, step=epoch)
#                 tf.summary.scalar('Loss/Val', val_loss, step=epoch)
#                 tf.summary.scalar('Error/Val', (1 - val_acc), step=epoch)

#             if val_acc > self.best_acc:
#                 print(f"acc update {val_acc}")
#                 self.best_acc = val_acc
#                 self.best_model_wts = copy.deepcopy(self.model.state_dict())

#             # print("{} secs elapsed.".format(time.time() - init_time))

#             # val_auc, val_acc, val_sens, val_prec = self.validate()

#             # if self.save_path is not None:
#             #     if isinstance(self.model, nn.DataParallel):
#             #         model = self.model.module
#             #     else:
#             #         model = self.model
                
#             #     model_cpu = model.cpu()

#             #     torch.save(model_cpu.state_dict(), self.save_path + f'{epoch}_{batch_idx}_{str(val_auc)[:5]}_{str(val_acc)[:5]}_{str(val_sens)[:5]}_{str(val_prec)[:5]}.pth')

#             #     if torch.cuda.is_available():
#             #         self.model = self.model.cuda()

# # (self, data_dict_path, criterion, save_path=None, pretrained=False, bag_size=16, batch_size=1, valid_ratio=0.01, gpu_num=1, worker_ratio=4, save_interval=1000)

# #     max_max = cMIL(DATA_DICT_PATH, criterion=CRITERION, save_path=SAVE_PATH, pretrained=True, bag_size=BAG_SIZE, 
# #                    batch_size=BATCH_SIZE, gpu_num=GPU_NUM, worker_ratio=WORKER_RATIO, valid_ratio=VALID_RATIO,
# #                    save_interval=SAVE_INTERVAL)

# # (self, criterion, train_loader, val_loader, full_loader, save_path=None, pretrained=False, bag_size=16, batch_size=1, worker_ratio=4, save_interval=1000)

# #     max_max = cMIL(criterion=CRITERION, train_dataloader, val_dataloader, test_dataloader, save_path=SAVE_PATH, pretrained=True, bag_size=BAG_SIZE, 
# #                 batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL)

In [None]:
# current_time = datetime.datetime.now().strftime("%m%d-%H%M")
# print(current_time)
# log_dir = 'logs/learning_rate/' + current_time +'/Camel/max-max/lr=5e-8'
# summary_writer = tf.summary.create_file_writer(log_dir)

# EPOCHS = 2
# LR = 5e-8
# SAVE_INTERVAL = 2
# CRITERION = 'max-max'
# SAVE_PATH = None
    
# max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader, save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
# best_model_wts = max_max.train(EPOCHS, lr=LR)
# # max_max.load_model(best_model_wts)
# # max.max.validate()


# Max-Max

## Learning rate determination

In [None]:
# Dataset split used by the max-max learning-rate sweep: val_size=0.15, test_size=0
# (presumably fractions of the data; the remainder is used for training).
hyper_parameters = dict(
    # Training Control Parameters
    # Dataset Parameters
    max_bag_size=100,
    dataset_max_size=None,
    with_data_augmentation=False,
    # with_tensorboard=not args.no_tensorboard,
    seed=123,
    val_size=0.15,
    test_size=0,
)

logger = None
input_width = 224
(train_dataset, val_dataset, test_dataset,
 whole_cases_ids, whole_indexes, whole_dataset) = build_datasets(
    source_slides_folders=slides_folders,
    model_input_width=input_width,
    hyper_parameters=hyper_parameters,
    logger=logger,
)

N_PROCESSES = 5


def to_dataloader(dataset, for_training):
    """Wrap ``dataset`` in a DataLoader with batch_size=1 and N_PROCESSES workers.

    The loader shuffles only when ``for_training`` is true.
    """
    assert isinstance(dataset, (Dataset, torch.utils.data.Subset))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=for_training,
        num_workers=N_PROCESSES,
    )


train_dataloader = to_dataloader(train_dataset, True)
# Empty splits get no loader at all (None), which later cells must handle.
val_dataloader = None
if len(val_dataset):
    val_dataloader = to_dataloader(val_dataset, False)
test_dataloader = None
if len(test_dataset):
    test_dataloader = to_dataloader(test_dataset, False)

In [None]:
# Count carcinoma (label 1) and non-carcinoma bags in each split.
# A split without a dataloader (None, i.e. an empty split) is reported explicitly, rather
# than by catching the TypeError from iterating over None, which also hid genuine
# TypeErrors raised while loading data.
def _count_bag_labels(dataloader):
    """Return ``(n_carcinoma, n_non_carcinoma)`` for ``dataloader``, or ``None`` if it is None.

    Assumes batch_size=1 (as built by ``to_dataloader``), so each ``bag_label`` holds a
    single value. Iterating the loader fetches every bag, so this costs a full data pass.
    """
    if dataloader is None:
        return None
    n_carcinoma = n_non_carcinoma = 0
    for _, bag_label, _, _, _ in dataloader:
        if bag_label.item() == 1:
            n_carcinoma += 1
        else:
            n_non_carcinoma += 1
    return n_carcinoma, n_non_carcinoma


def _report_bag_labels(split_name, dataloader):
    """Print the label counts for one split and return them; (0, 0) when it has no loader."""
    counts = _count_bag_labels(dataloader)
    if counts is None:
        print(f"The {split_name} set is empty; nothing to count.")
        return 0, 0
    print("There are %d carcinoma and %d non-carcinoma samples in the %s set" % (counts[0], counts[1], split_name))
    return counts


train_carcinoma, train_non_carcinoma = _report_bag_labels('training', train_dataloader)
val_carcinoma, val_non_carcinoma = _report_bag_labels('validation', val_dataloader)
test_carcinoma, test_non_carcinoma = _report_bag_labels('test', test_dataloader)

In [None]:
# for batch_idx, (data, bag_label, tiles_ids, slide_summary, _) in enumerate(val_dataloader):
#   # print(len(tiles_ids))
#   print(f"There are {len(tiles_ids)} instances in the {batch_idx}-th bag")

In [None]:
# Max-max learning-rate sweep, lr = 5e-8: train from pretrained weights and log per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-8'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 50
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/learning_rate/' + current_time + '/Camel/max-max/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Max-max learning-rate sweep, lr = 5e-7: train from pretrained weights and log per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-7'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 50
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/learning_rate/' + current_time + '/Camel/max-max/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Max-max learning-rate sweep, lr = 5e-6: train from pretrained weights and log per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-6'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 50
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/learning_rate/' + current_time + '/Camel/max-max/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Max-max learning-rate sweep, lr = 5e-5: train from pretrained weights and log per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-5'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 50
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/learning_rate/' + current_time + '/Camel/max-max/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Max-max learning-rate sweep, lr = 5e-4: train from pretrained weights and log per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-4'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 50
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/learning_rate/' + current_time + '/Camel/max-max/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Max-max learning-rate sweep, lr = 5e-3: train from pretrained weights and log per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-3'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 50
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/learning_rate/' + current_time + '/Camel/max-max/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Optional clean-up of ALL TensorBoard logs under ./logs.
# Off by default: this cell sits between the training cells above and the %tensorboard
# cell below, so deleting unconditionally here would wipe the logs that were just
# written before they can be viewed on Restart & Run All.
import shutil

CLEAR_OLD_LOGS = False
if CLEAR_OLD_LOGS:
    shutil.rmtree('logs', ignore_errors=True)

In [None]:
# Load the TensorBoard notebook extension and show every run logged under
# logs/learning_rate/ (the learning-rate sweep cells write there).
%load_ext tensorboard
# %reload_ext tensorboard
%tensorboard --logdir logs/learning_rate/

## Full Training Run

In [None]:
# Full max-max training run at the chosen lr = 5e-4 for 70 epochs, logging per-epoch
# train/val loss and error to TensorBoard.
CRITERION = 'max-max'
LR_TAG = '5e-4'  # single source for both the optimizer lr and the log-dir name
LR = float(LR_TAG)
EPOCHS = 70
SAVE_INTERVAL = 2
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/Camel/max-max/full_train/' + current_time + '/lr=' + LR_TAG
summary_writer = tf.summary.create_file_writer(log_dir)

max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)
# To evaluate the returned weights:
# max_max.load_model(best_model_wts)
# max_max.validate()


In [None]:
# Load the TensorBoard notebook extension (if it was already loaded by an earlier cell,
# IPython only prints a notice) and show the full-training runs.
%load_ext tensorboard
# %reload_ext tensorboard
%tensorboard --logdir logs/Camel/max-max/full_train/

## Test and Inference

In [None]:
# Rebuild the splits for evaluation: val_size=0.02, test_size=0.95 (presumably fractions;
# the remainder goes to training).
# NOTE(review): possible train/test leakage. The weights evaluated below were trained on the
# split from the learning-rate cell (val_size=0.15, test_size=0, same seed=123), i.e. with
# no held-out test data, so this 95% "test" split presumably overlaps slides the model was
# trained on and test metrics would be optimistic. Confirm how build_datasets assigns
# slides; a clean evaluation needs the test split fixed before training.
hyper_parameters = {
    # Training Control Parameters
    
    # Dataset Parameters
    'max_bag_size': 100,
    'dataset_max_size': None,
    'with_data_augmentation': False,
    # 'with_tensorboard': not args.no_tensorboard,
    'seed': 123,
    'val_size': 0.02,
    'test_size': 0.95,
}

logger = None
input_width = 224
train_dataset, val_dataset, test_dataset, whole_cases_ids, whole_indexes, whole_dataset = build_datasets(source_slides_folders=slides_folders,
                                                              model_input_width=input_width,
                                                              hyper_parameters=hyper_parameters,
                                                              logger=logger)
N_PROCESSES = 5
# Redefined here (identically to the earlier split cells) so this section can run on its own.
def to_dataloader(dataset, for_training):
    """Wrap ``dataset`` in a DataLoader (batch_size=1, N_PROCESSES workers); shuffle only for training."""
    assert isinstance(dataset, Dataset) or isinstance(dataset, torch.utils.data.Subset)
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=for_training, num_workers=N_PROCESSES)

train_dataloader = to_dataloader(train_dataset, True)
# Empty splits get no loader (None).
val_dataloader = to_dataloader(val_dataset, False) if len(val_dataset) else None
test_dataloader = to_dataloader(test_dataset, False) if len(test_dataset) else None

In [None]:
# Count carcinoma (label 1) and non-carcinoma bags in each split.
# A split without a dataloader (None, i.e. an empty split) is reported explicitly, rather
# than by catching the TypeError from iterating over None, which also hid genuine
# TypeErrors raised while loading data.
def _count_bag_labels(dataloader):
    """Return ``(n_carcinoma, n_non_carcinoma)`` for ``dataloader``, or ``None`` if it is None.

    Assumes batch_size=1 (as built by ``to_dataloader``), so each ``bag_label`` holds a
    single value. Iterating the loader fetches every bag, so this costs a full data pass.
    """
    if dataloader is None:
        return None
    n_carcinoma = n_non_carcinoma = 0
    for _, bag_label, _, _, _ in dataloader:
        if bag_label.item() == 1:
            n_carcinoma += 1
        else:
            n_non_carcinoma += 1
    return n_carcinoma, n_non_carcinoma


def _report_bag_labels(split_name, dataloader):
    """Print the label counts for one split and return them; (0, 0) when it has no loader."""
    counts = _count_bag_labels(dataloader)
    if counts is None:
        print(f"The {split_name} set is empty; nothing to count.")
        return 0, 0
    print("There are %d carcinoma and %d non-carcinoma samples in the %s set" % (counts[0], counts[1], split_name))
    return counts


train_carcinoma, train_non_carcinoma = _report_bag_labels('training', train_dataloader)
val_carcinoma, val_non_carcinoma = _report_bag_labels('validation', val_dataloader)
test_carcinoma, test_non_carcinoma = _report_bag_labels('test', test_dataloader)

In [None]:
# Evaluate the best max-max weights on the test split.
# cMIL has no dedicated test method, so test_dataloader is passed in the validation slot
# and validate() is called directly; it returns (loss, auc, acc, sens, prec).
# NOTE(review): in the visible forward_and_backward code, 'sens' is computed with
# precision_score (identical to 'prec'), so val_sens is NOT recall; it should presumably
# use recall_score. Also, the training cells used test_size=0, so this test split
# presumably overlaps the training data; confirm before trusting these metrics.
CRITERION = 'max-max'
SAVE_INTERVAL = 2
SAVE_PATH = None

# validate() does not write to TensorBoard, but cMIL's constructor takes a writer.
# Keep it out of logs/learning_rate/ so this evaluation does not appear as a spurious
# run in the learning-rate sweep view.
current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = 'logs/Camel/max-max/test/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)

# Fail early with a clear message instead of deep inside load_model()/validate().
assert test_dataloader is not None, "test split is empty; set a non-zero 'test_size'"
# The cMIL code in this notebook starts best_model_wts at 0 and only replaces it with a
# state_dict copy when validation accuracy improves.
assert isinstance(best_model_wts, dict), "best_model_wts is not a state_dict; train a model first"

cMIL_maxmax_test = cMIL(summary_writer, CRITERION, None, test_dataloader, None,
                        save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
cMIL_maxmax_test.load_model(best_model_wts)
val_loss, val_auc, val_acc, val_sens, val_prec = cMIL_maxmax_test.validate()


In [None]:
# Inference: load the best max-max weights into a fresh cMIL and run it over the test
# split. No loaders are given to the constructor; infer() receives the loader directly.
# NOTE(review): in the commented-out copy of cMIL in this notebook, infer() reads
# self.multi_gpu on CUDA and the 'infer' branch of forward_and_backward uses index_list,
# and neither is defined there. Check that the live class defines both.
cMIL_maxmax_infer = cMIL(
    summary_writer, CRITERION, None, None, None,
    save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL,
)
cMIL_maxmax_infer.load_model(best_model_wts)
infer_total_list = cMIL_maxmax_infer.infer(test_dataloader)

# Max-Min

## Learning rate determination

In [None]:
# Dataset split used by the max-min learning-rate sweep: val_size=0.3, test_size=0
# (presumably fractions of the data; the remainder is used for training).
hyper_parameters = dict(
    # Training Control Parameters
    # Dataset Parameters
    max_bag_size=100,
    dataset_max_size=None,
    with_data_augmentation=False,
    # with_tensorboard=not args.no_tensorboard,
    seed=123,
    val_size=0.3,
    test_size=0,
)

logger = None
input_width = 224
(train_dataset, val_dataset, test_dataset,
 whole_cases_ids, whole_indexes, whole_dataset) = build_datasets(
    source_slides_folders=slides_folders,
    model_input_width=input_width,
    hyper_parameters=hyper_parameters,
    logger=logger,
)

N_PROCESSES = 5


def to_dataloader(dataset, for_training):
    """Wrap ``dataset`` in a DataLoader with batch_size=1 and N_PROCESSES workers.

    The loader shuffles only when ``for_training`` is true.
    """
    assert isinstance(dataset, (Dataset, torch.utils.data.Subset))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=for_training,
        num_workers=N_PROCESSES,
    )


train_dataloader = to_dataloader(train_dataset, True)
# Empty splits get no loader at all (None), which later cells must handle.
val_dataloader = None
if len(val_dataset):
    val_dataloader = to_dataloader(val_dataset, False)
test_dataloader = None
if len(test_dataset):
    test_dataloader = to_dataloader(test_dataset, False)

In [None]:
# Count carcinoma (bag label 1) vs non-carcinoma bags in each split.
# A split whose dataloader is None (e.g. test_size=0) is reported explicitly. The previous
# `except TypeError` around the whole loop also hid any real TypeError raised while loading
# a bag, and discarded the partial counts.
# Note: this iterates every dataloader, so every bag is actually loaded once.
label_counts = {}
for split_name, split_loader in (('training', train_dataloader),
                                 ('validation', val_dataloader),
                                 ('test', test_dataloader)):
  n_carcinoma = 0
  n_non_carcinoma = 0
  if split_loader is None:
    print("There is no %s set (its dataloader is None)" % split_name)
  else:
    for _, bag_label, _, _, _ in split_loader:
      # batch_size=1 (see to_dataloader), so bag_label holds a single label
      if bag_label == torch.Tensor([1]):
        n_carcinoma += 1
      else:
        n_non_carcinoma += 1
    print("There are %d carcinoma and %d non-carcinoma samples in the %s set"
          % (n_carcinoma, n_non_carcinoma, split_name))
  label_counts[split_name] = (n_carcinoma, n_non_carcinoma)

# Keep the per-split globals that the original cell defined.
train_carcinoma, train_non_carcinoma = label_counts['training']
val_carcinoma, val_non_carcinoma = label_counts['validation']
test_carcinoma, test_non_carcinoma = label_counts['test']

In [None]:
# Learning-rate sweep, max-min criterion: 30 epochs at lr=5e-8.
# Scalars are written to logs/learning_rate/Camel/max-min/<MMDD-HHMM>/lr=5e-8.
EPOCHS = 30
LR = 5e-8
SAVE_INTERVAL = 2
CRITERION = 'max-min'
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = f'logs/learning_rate/Camel/max-min/{current_time}/lr=5e-8'
summary_writer = tf.summary.create_file_writer(log_dir)

# The variable keeps the name `max_max` from the max-max section, but it holds the max-min model.
# train() returns the weights that later cells pass to load_model().
max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)


In [None]:
# Learning-rate sweep, max-min criterion: 30 epochs at lr=5e-7.
# Scalars are written to logs/learning_rate/Camel/max-min/<MMDD-HHMM>/lr=5e-7.
EPOCHS = 30
LR = 5e-7
SAVE_INTERVAL = 2
CRITERION = 'max-min'
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = f'logs/learning_rate/Camel/max-min/{current_time}/lr=5e-7'
summary_writer = tf.summary.create_file_writer(log_dir)

# The variable keeps the name `max_max` from the max-max section, but it holds the max-min model.
# train() returns the weights that later cells pass to load_model().
max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)


In [None]:
# Learning-rate sweep, max-min criterion: 30 epochs at lr=5e-5.
# Scalars are written to logs/learning_rate/Camel/max-min/<MMDD-HHMM>/lr=5e-5.
EPOCHS = 30
LR = 5e-5
SAVE_INTERVAL = 2
CRITERION = 'max-min'
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = f'logs/learning_rate/Camel/max-min/{current_time}/lr=5e-5'
summary_writer = tf.summary.create_file_writer(log_dir)

# The variable keeps the name `max_max` from the max-max section, but it holds the max-min model.
# train() returns the weights that later cells pass to load_model().
max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)


In [None]:
# Learning-rate sweep, max-min criterion: 30 epochs at lr=5e-4.
# Scalars are written to logs/learning_rate/Camel/max-min/<MMDD-HHMM>/lr=5e-4.
EPOCHS = 30
LR = 5e-4
SAVE_INTERVAL = 2
CRITERION = 'max-min'
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = f'logs/learning_rate/Camel/max-min/{current_time}/lr=5e-4'
summary_writer = tf.summary.create_file_writer(log_dir)

# The variable keeps the name `max_max` from the max-max section, but it holds the max-min model.
# train() returns the weights that later cells pass to load_model().
max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)


In [None]:
# Learning-rate sweep, max-min criterion: 30 epochs at lr=5e-3.
# Scalars are written to logs/learning_rate/Camel/max-min/<MMDD-HHMM>/lr=5e-3.
EPOCHS = 30
LR = 5e-3
SAVE_INTERVAL = 2
CRITERION = 'max-min'
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = f'logs/learning_rate/Camel/max-min/{current_time}/lr=5e-3'
summary_writer = tf.summary.create_file_writer(log_dir)

# The variable keeps the name `max_max` from the max-max section, but it holds the max-min model.
# train() returns the weights that later cells pass to load_model().
max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)


In [None]:
# Opt-in reset of every TensorBoard log under ./logs.
# Disabled by default: in a top-to-bottom run this cell sits between the learning-rate sweep
# and the TensorBoard cell that displays logs/learning_rate/Camel/max-min/, so deleting
# unconditionally would discard the sweep results before they are viewed.
# Set CLEAR_LOGS = True only when you deliberately want to start from an empty log directory.
import shutil

CLEAR_LOGS = False
if CLEAR_LOGS:
  # ignore_errors keeps the cell re-runnable when ./logs does not exist (rm -r would fail)
  shutil.rmtree('logs', ignore_errors=True)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# %reload_ext tensorboard  # alternative if the extension was already loaded in this kernel
# Show every max-min learning-rate sweep run written under this directory
# (the sweep cells log to logs/learning_rate/Camel/max-min/<MMDD-HHMM>/lr=<value>).
%tensorboard --logdir logs/learning_rate/Camel/max-min/

## Full-epoch training

In [None]:
# Full training run, max-min criterion: 70 epochs at lr=5e-4.
# Scalars are written to logs/Camel/max-min/<MMDD-HHMM>/lr=5e-4.
EPOCHS = 70
LR = 5e-4
SAVE_INTERVAL = 2
CRITERION = 'max-min'
SAVE_PATH = None

current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
log_dir = f'logs/Camel/max-min/{current_time}/lr=5e-4'
summary_writer = tf.summary.create_file_writer(log_dir)

# The variable keeps the name `max_max` from the max-max section, but it holds the max-min model.
# train() returns the weights that the test and inference cells pass to load_model().
max_max = cMIL(summary_writer, CRITERION, train_dataloader, val_dataloader, test_dataloader,
               save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
best_model_wts = max_max.train(EPOCHS, lr=LR)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# %reload_ext tensorboard  # alternative if the extension was already loaded in this kernel
# View the full-epoch max-min runs. Point at the parent directory rather than a single
# timestamped run: the training cell names its run directory after datetime.now(), so a
# hard-coded timestamp stops matching as soon as that cell is re-run.
%tensorboard --logdir logs/Camel/max-min/

## Test and Inference

In [None]:
# Re-split the slides for testing: val_size=0.02, test_size=0.95, same seed as the sweep split.
# NOTE(review): the model evaluated below was trained on a split built with val_size=0.3 and
# test_size=0. Unless build_datasets guarantees otherwise, this 95% "test" split overlaps the
# bags the model was trained on, which would inflate test metrics. TODO confirm and, if so,
# carve out a fixed held-out test split before training.
hyper_parameters = dict(
    # Dataset parameters (no training-control parameters are needed here)
    max_bag_size=100,
    dataset_max_size=None,
    with_data_augmentation=False,
    seed=123,
    val_size=0.02,
    test_size=0.95,
)

logger = None
input_width = 224
(train_dataset, val_dataset, test_dataset,
 whole_cases_ids, whole_indexes, whole_dataset) = build_datasets(
    source_slides_folders=slides_folders,
    model_input_width=input_width,
    hyper_parameters=hyper_parameters,
    logger=logger,
)

N_PROCESSES = 5


def to_dataloader(dataset, for_training):
    """Wrap `dataset` in a DataLoader with batch_size=1.

    Args:
        dataset: a `Dataset` or a `torch.utils.data.Subset` of one.
        for_training: shuffle the items when True; keep dataset order otherwise.

    Returns:
        A torch DataLoader using N_PROCESSES worker processes.
    """
    allowed_types = (Dataset, torch.utils.data.Subset)
    assert isinstance(dataset, allowed_types)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=for_training,
        num_workers=N_PROCESSES,
    )


train_dataloader = to_dataloader(train_dataset, True)
val_dataloader = None if len(val_dataset) == 0 else to_dataloader(val_dataset, False)
test_dataloader = None if len(test_dataset) == 0 else to_dataloader(test_dataset, False)

In [None]:
# Count carcinoma (bag label 1) vs non-carcinoma bags in each split.
# A split whose dataloader is None is reported explicitly. The previous `except TypeError`
# around the whole loop also hid any real TypeError raised while loading a bag, and
# discarded the partial counts.
# Note: this iterates every dataloader, so every bag is actually loaded once.
label_counts = {}
for split_name, split_loader in (('training', train_dataloader),
                                 ('validation', val_dataloader),
                                 ('test', test_dataloader)):
  n_carcinoma = 0
  n_non_carcinoma = 0
  if split_loader is None:
    print("There is no %s set (its dataloader is None)" % split_name)
  else:
    for _, bag_label, _, _, _ in split_loader:
      # batch_size=1 (see to_dataloader), so bag_label holds a single label
      if bag_label == torch.Tensor([1]):
        n_carcinoma += 1
      else:
        n_non_carcinoma += 1
    print("There are %d carcinoma and %d non-carcinoma samples in the %s set"
          % (n_carcinoma, n_non_carcinoma, split_name))
  label_counts[split_name] = (n_carcinoma, n_non_carcinoma)

# Keep the per-split globals that the original cell defined.
train_carcinoma, train_non_carcinoma = label_counts['training']
val_carcinoma, val_non_carcinoma = label_counts['validation']
test_carcinoma, test_non_carcinoma = label_counts['test']

In [None]:
# Evaluate the max-min model from the full-epoch training cell on the test split.
# To test on the test set, the test dataloader is handed to cMIL in the validation slot
# and validate() is called directly.
current_time = datetime.datetime.now().strftime("%m%d-%H%M")
print(current_time)
# Label the run with the learning rate the evaluated weights were trained with
# (the full-epoch max-min cell uses LR = 5e-4); the previous label said lr=5e-7.
log_dir = 'logs/Camel/max-min/test/' + current_time + '/lr=5e-4'
summary_writer = tf.summary.create_file_writer(log_dir)

EPOCHS = 50  # not passed to cMIL or validate() in this cell
LR = 5e-4    # only documents the training LR; not passed to cMIL or validate() here
SAVE_INTERVAL = 2
# Must match the criterion the weights were trained with ('max-min'); 'max-max' here was a
# copy-paste from the max-max section and would evaluate the model with the wrong criterion.
CRITERION = 'max-min'
SAVE_PATH = None

# The variable keeps its historical `maxmax` name so later cells that reference it still work;
# it holds the max-min model.
cMIL_maxmax_test = cMIL(summary_writer, CRITERION, None, test_dataloader, None, save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
cMIL_maxmax_test.load_model(best_model_wts)
# Five metrics, unpacked (per the variable names) as loss, AUC, accuracy, sensitivity, precision.
val_loss, val_auc, val_acc, val_sens, val_prec = cMIL_maxmax_test.validate()

In [None]:
# Inference: restore best_model_wts (the max-min model) and run infer() over the test bags.
# The criterion is spelled out instead of reading the global CRITERION, so this cell does not
# silently inherit whichever criterion an earlier cell last assigned.
# The variable keeps its historical `maxmax` name so later cells that reference it still work.
cMIL_maxmax_infer = cMIL(summary_writer, 'max-min', None, None, None, save_path=SAVE_PATH, pretrained=True, save_interval=SAVE_INTERVAL)
cMIL_maxmax_infer.load_model(best_model_wts)
infer_total_list = cMIL_maxmax_infer.infer(test_dataloader)

# Infer_maxmax