In [None]:
import os
import gc
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.backends import cudnn
from torch.utils.tensorboard import SummaryWriter

import src.utils.custom_transformations as CT
from src.utils import datautils, pyutils, torchutils, metrics, imgutils

# Explicitly set the global cuDNN backend flag of PyTorch to enabled.
cudnn.enabled = True

In [None]:
# import shutil
#
# os.makedirs('test/data', exist_ok=True)
# maximum = np.array([789, 156, 42, 808])
# counts = np.array([0, 0, 0, 0])
# chosen = []
# available = os.listdir(path='datasets/LUAD-HistoSeg/train')
# done = counts >= maximum
# while np.sum(done) < 4:
#     img_name = random.choice(available)
#     picked = False
#     labels = np.array([int(img_name[-12]), int(img_name[-10]), int(img_name[-8]), int(img_name[-6])])
#
#     if (not done[2]) and (labels[2] == 1) and (not picked):
#         chosen.append(img_name)
#         counts += labels
#         picked = True
#
#     if (not done[1]) and (labels[1] == 1) and (not picked):
#         chosen.append(img_name)
#         counts += labels
#         picked = True
#
#     if (not done[0]) and (labels[0] == 1) and (not picked):
#         chosen.append(img_name)
#         counts += labels
#         picked = True
#
#     if (not done[3]) and (labels[3] == 1) and (not picked):
#         chosen.append(img_name)
#         counts += labels
#         picked = True
#
#     if picked:
#         available = [x for x in available if x != img_name]
#
#     done = counts >= maximum
#
# for im in chosen:
#     shutil.copy(src=f'datasets/LUAD-HistoSeg/train/{im}', dst=f'test/data/{im}')

In [None]:
class ResBlock(nn.Module):
    """Pre-activation residual unit built from two 3x3 convolutions.

    Layout of the residual branch: BN -> ReLU -> Conv3x3 -> BN -> ReLU -> Conv3x3.
    The shortcut is the identity when the block changes neither the channel
    count nor the spatial resolution; otherwise it is a (strided) 1x1
    convolution applied to the BN+ReLU'd input.

    Args:
        in_channels: channels of the input feature map.
        mid_channels: channels produced by the first 3x3 convolution.
        out_channels: channels produced by the block.
        first_stride: stride of the first 3x3 convolution (and of the
            projection shortcut, when one is needed).
        first_dilation_padding: dilation *and* padding of the first 3x3 conv.
        dilation_padding: dilation *and* padding of the second 3x3 conv.
    """

    def __init__(self, in_channels, mid_channels, out_channels, first_stride, first_dilation_padding, dilation_padding):
        super(ResBlock, self).__init__()

        # An identity shortcut is only valid if the residual branch keeps the
        # spatial size as well. Previously only the channel counts were compared,
        # so in_channels == out_channels with first_stride > 1 failed in forward()
        # with a shape mismatch when adding the two branches.
        keeps_resolution = first_stride in (1, (1, 1))
        self.same_shape = (in_channels == out_channels) and keeps_resolution

        self.bn_branch2a = nn.BatchNorm2d(num_features=in_channels)
        self.relu_branch2a = nn.ReLU()
        self.conv_branch2a = nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=3,
            stride=first_stride,
            padding=first_dilation_padding,
            dilation=first_dilation_padding,
            bias=False,
        )

        self.bn_branch2b1 = nn.BatchNorm2d(num_features=mid_channels)
        self.relu_branch2b1 = nn.ReLU()
        self.conv_branch2b1 = nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=dilation_padding,
            dilation=dilation_padding,
            bias=False,
        )

        if not self.same_shape:
            # Projection shortcut: matches both channel count and stride.
            self.conv_branch1 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=first_stride,
                bias=False,
            )

    def forward(self, x, mu=None, gamma=None, cam=None):
        """Run the block.

        Args:
            x: input feature map.
            mu, gamma, cam: optional PDA inputs. Only when all three are given,
                the pre-activated features are replaced by the result of
                `intermediate_forward_PDA`, which is not defined in this block
                and must be available at module level.

        Returns:
            Tuple `(out, x_bn_relu)`: the block output and the BN+ReLU'd input.
        """
        branch2 = self.relu_branch2a(self.bn_branch2a(x))

        # Kept separately because callers collect it as an intermediate feature.
        x_bn_relu = branch2

        if (mu is not None) and (gamma is not None) and (cam is not None):
            branch2 = intermediate_forward_PDA(x=x, mu=mu, gamma=gamma, cam=cam)

        # Identity shortcut uses the raw input; the projection uses the pre-activated one.
        branch1 = x if self.same_shape else self.conv_branch1(branch2)

        branch2 = self.conv_branch2a(branch2)
        branch2 = self.relu_branch2b1(self.bn_branch2b1(branch2))
        branch2 = self.conv_branch2b1(branch2)

        x = branch2 + branch1

        return x, x_bn_relu

In [None]:
class ResBlockBottleneck(nn.Module):
    """Pre-activation bottleneck residual unit (1x1 -> 3x3 -> 1x1).

    Channel widths of the residual branch are
    `in_channels -> out_channels // 4 -> out_channels // 2 -> out_channels`.
    The shortcut is the identity when `in_channels == out_channels`, otherwise
    a 1x1 convolution applied to the BN+ReLU'd input.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        dilation_padding: dilation *and* padding of the 3x3 convolution.
        dropout: probability for the channel-wise dropout applied right before
            the last 1x1 convolution (active in training mode only).
    """

    def __init__(self, in_channels, out_channels, dilation_padding=1, dropout=0.0):
        super(ResBlockBottleneck, self).__init__()

        self.same_shape = (in_channels == out_channels)

        squeeze_channels = out_channels // 4
        expand_channels = out_channels // 2

        self.bn_branch2a = nn.BatchNorm2d(num_features=in_channels)
        self.relu_branch2a = nn.ReLU()
        self.conv_branch2a = nn.Conv2d(in_channels=in_channels, out_channels=squeeze_channels, kernel_size=1, bias=False)

        self.bn_branch2b1 = nn.BatchNorm2d(num_features=squeeze_channels)
        self.relu_branch2b1 = nn.ReLU()
        self.conv_branch2b1 = nn.Conv2d(
            in_channels=squeeze_channels,
            out_channels=expand_channels,
            kernel_size=3,
            padding=dilation_padding,
            dilation=dilation_padding,
            bias=False,
        )

        self.bn_branch2b2 = nn.BatchNorm2d(num_features=expand_channels)
        self.relu_branch2b2 = nn.ReLU()
        self.dropout_branch2b2 = nn.Dropout2d(p=dropout)
        self.conv_branch2b2 = nn.Conv2d(in_channels=expand_channels, out_channels=out_channels, kernel_size=1, bias=False)

        if not self.same_shape:
            self.conv_branch1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False)

    def forward(self, x, mu=None, gamma=None, cam=None):
        """Run the block.

        Args:
            x: input feature map.
            mu, gamma, cam: optional PDA inputs. Only when all three are given,
                the pre-activated features are replaced by the result of
                `intermediate_forward_PDA`, which is not defined in this block
                and must be available at module level.

        Returns:
            Tuple `(out, x_bn_relu)`: the block output and the BN+ReLU'd input.
        """
        branch2 = self.relu_branch2a(self.bn_branch2a(x))

        # Kept separately because callers collect it as an intermediate feature.
        x_bn_relu = branch2

        if (mu is not None) and (gamma is not None) and (cam is not None):
            branch2 = intermediate_forward_PDA(x=x, mu=mu, gamma=gamma, cam=cam)

        # Identity shortcut uses the raw input; the projection uses the pre-activated one.
        branch1 = x if self.same_shape else self.conv_branch1(branch2)

        branch2 = self.conv_branch2a(branch2)

        branch2 = self.relu_branch2b1(self.bn_branch2b1(branch2))
        branch2 = self.conv_branch2b1(branch2)

        branch2 = self.relu_branch2b2(self.bn_branch2b2(branch2))
        # The dropout layer was constructed but never called before, so the
        # `dropout` argument had no effect. It is a no-op in eval mode.
        branch2 = self.dropout_branch2b2(branch2)
        branch2 = self.conv_branch2b2(branch2)

        x = branch1 + branch2

        return x, x_bn_relu

In [None]:
class ResNet38ClassificationModel(nn.Module):
    """ResNet38-style multi-label classifier that can also emit class activation maps.

    The backbone is a stack of `ResBlock` / `ResBlockBottleneck` units; the
    classification head is a bias-free 1x1 convolution (`fc8`) applied after
    global average pooling. The same `fc8` weights are reused as a 1x1
    convolution over the un-pooled features to obtain CAMs.

    Args:
        num_classes: number of output classes (output channels of `fc8`).
    """

    def __init__(self, num_classes):
        super(ResNet38ClassificationModel, self).__init__()

        # PDA configuration. The PDA branches in `extract_features` are currently
        # commented out, so these only influence whether `forward` computes a CAM.
        self.enable_PDA = [False, False, False, False, False]
        self.mu = [1, 1, 1, 1, 1]
        self.gamma = [0, 0, 0, 0, 0]

        self.conv1a = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, bias=False)

        #  -------- B1 -----------#
        # Input size = (224, 224),
        # So no B1 layer

        # -------- B2 -----------#
        self.res2a = ResBlock(in_channels=64, mid_channels=128, out_channels=128, first_stride=2, first_dilation_padding=1, dilation_padding=1)
        self.res2b1 = ResBlock(in_channels=128, mid_channels=128, out_channels=128, first_stride=1, first_dilation_padding=1, dilation_padding=1)
        self.res2b2 = ResBlock(in_channels=128, mid_channels=128, out_channels=128, first_stride=1, first_dilation_padding=1, dilation_padding=1)

        # -------- B3 -----------#
        self.res3a = ResBlock(in_channels=128, mid_channels=256, out_channels=256, first_stride=2, first_dilation_padding=1, dilation_padding=1)
        self.res3b1 = ResBlock(in_channels=256, mid_channels=256, out_channels=256, first_stride=1, first_dilation_padding=1, dilation_padding=1)
        self.res3b2 = ResBlock(in_channels=256, mid_channels=256, out_channels=256, first_stride=1, first_dilation_padding=1, dilation_padding=1)

        # -------- B4 -----------#
        self.res4a = ResBlock(in_channels=256, mid_channels=512, out_channels=512, first_stride=2, first_dilation_padding=1, dilation_padding=1)
        self.res4b1 = ResBlock(in_channels=512, mid_channels=512, out_channels=512, first_stride=1, first_dilation_padding=1, dilation_padding=1)
        self.res4b2 = ResBlock(in_channels=512, mid_channels=512, out_channels=512, first_stride=1, first_dilation_padding=1, dilation_padding=1)
        self.res4b3 = ResBlock(in_channels=512, mid_channels=512, out_channels=512, first_stride=1, first_dilation_padding=1, dilation_padding=1)
        self.res4b4 = ResBlock(in_channels=512, mid_channels=512, out_channels=512, first_stride=1, first_dilation_padding=1, dilation_padding=1)
        self.res4b5 = ResBlock(in_channels=512, mid_channels=512, out_channels=512, first_stride=1, first_dilation_padding=1, dilation_padding=1)

        # -------- B5 -----------#
        self.res5a = ResBlock(in_channels=512, mid_channels=512, out_channels=1024, first_stride=1, first_dilation_padding=1, dilation_padding=2)
        self.res5b1 = ResBlock(in_channels=1024, mid_channels=512, out_channels=1024, first_stride=1, first_dilation_padding=2, dilation_padding=2)
        self.res5b2 = ResBlock(in_channels=1024, mid_channels=512, out_channels=1024, first_stride=1, first_dilation_padding=2, dilation_padding=2)

        # -------- B6 -----------#
        self.res6a = ResBlockBottleneck(in_channels=1024, out_channels=2048, dilation_padding=4, dropout=0.3)

        # -------- B7 -----------#
        self.res7a = ResBlockBottleneck(in_channels=2048, out_channels=4096, dilation_padding=4, dropout=0.5)
        self.bn7 = nn.BatchNorm2d(num_features=4096)
        self.relu7 = nn.ReLU()
        # -----------------------#

        self.dropout = nn.Dropout2d(p=0.5)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc8 = nn.Conv2d(in_channels=4096, out_channels=num_classes, kernel_size=1, bias=False)

        nn.init.xavier_uniform_(tensor=self.fc8.weight)
        # Layers frozen by `train(TL=True)`.
        self.not_training = [self.conv1a, self.res2a, self.res2b1, self.res2b2]
        # Layers that `get_parameter_groups` puts into their own groups (4 and 5).
        self.from_scratch_layers = [self.fc8]

    def extract_features(self, imgs, cam=None):
        """Run the backbone and collect the pre-activated input of every block.

        Args:
            imgs: input image batch; the inline shape comments assume (b, 3, 224, 224).
            cam: accepted for the PDA branches, which are currently commented
                out, so the argument is unused at the moment.

        Returns:
            Dict mapping feature names ('x1a' ... 'x6a') to the BN+ReLU'd input
            of the corresponding block, plus 'x7a' for the final BN+ReLU'd
            feature map.
        """

        x = self.conv1a(imgs)  # (b, 64, 224, 224)

        # -------- B2 -----------#
        # if self.enable_PDA[0] and (cam is not None):
        #     x, x1a = self.res2a(x=x, cam=cam, mu=self.mu[0], gamma=self.gamma[0])  # (b, 128, 112, 112), (b, 64, 224, 224)
        # else:
        x, x1a = self.res2a(x=x)  # (b, 128, 112, 112), (b, 64, 224, 224)
        x, x2a = self.res2b1(x=x)  # (b, 128, 112, 112), (b, 128, 112, 112)
        x, x2b1 = self.res2b2(x=x)  # (b, 128, 112, 112), (b, 128, 112, 112)

        # -------- B3 -----------#
        # NOTE(review): the disabled branch below tests enable_PDA[0] but uses mu[1]/gamma[1] -- confirm the index before re-enabling.
        # if self.enable_PDA[0] and (cam is not None):
        #     x, x2b2 = self.res3a(x=x, cam=cam, mu=self.mu[1], gamma=self.gamma[1])  # (b, 256, 56, 56), (b, 128, 112, 112)
        # else:
        x, x2b2 = self.res3a(x=x)  # (b, 256, 56, 56), (b, 128, 112, 112)
        x, x3a = self.res3b1(x=x)  # (b, 256, 56, 56), (b, 256, 56, 56)
        x, x3b1 = self.res3b2(x=x)  # (b, 256, 56, 56), (b, 256, 56, 56)

        # -------- B4 -----------#
        # NOTE(review): the disabled branch below tests enable_PDA[0] but uses mu[2]/gamma[2] -- confirm the index before re-enabling.
        # if self.enable_PDA[0] and (cam is not None):
        #     x, x3b2 = self.res4a(x=x, cam=cam, mu=self.mu[2], gamma=self.gamma[2])  # (b, 512, 28, 28), (b, 256, 56, 56)
        # else:
        x, x3b2 = self.res4a(x=x)  # (b, 512, 28, 28), (b, 256, 56, 56)
        x, x4a = self.res4b1(x=x)  # (b, 512, 28, 28), (b, 512, 28, 28)
        x, x4b1 = self.res4b2(x=x)  # (b, 512, 28, 28), (b, 512, 28, 28)
        x, x4b2 = self.res4b3(x=x)  # (b, 512, 28, 28), (b, 512, 28, 28)
        x, x4b3 = self.res4b4(x=x)  # (b, 512, 28, 28), (b, 512, 28, 28)
        x, x4b4 = self.res4b5(x=x)  # (b, 512, 28, 28), (b, 512, 28, 28)

        # -------- B5 -----------#
        # if self.enable_PDA[3] and (cam is not None):
        #     x, x4b5 = self.res5a(x=x, cam=cam, mu=self.mu[3], gamma=self.gamma[3])  # (b, 1024, 28, 28), (b, 512, 28, 28)
        # else:
        x, x4b5 = self.res5a(x=x)  # (b, 1024, 28, 28), (b, 512, 28, 28)
        x, x5a = self.res5b1(x=x)  # (b, 1024, 28, 28), (b, 1024, 28, 28)
        x, x5b1 = self.res5b2(x=x)  # (b, 1024, 28, 28), (b, 1024, 28, 28)

        # -------- B6 -----------#
        x, x5b2 = self.res6a(x=x)  # (b, 2048, 28, 28), (b, 1024, 28, 28)

        # -------- B7 -----------#
        x, x6a = self.res7a(x=x)  # (b, 4096, 28, 28), (b, 2048, 28, 28)
        x = self.bn7(x)  # (b, 4096, 28, 28)
        x = self.relu7(x)  # (b, 4096, 28, 28)
        # -----------------------#

        # if self.enable_PDA[4] and (cam is not None):
        #     x = final_forward_PDA(x=x, w=self.fc8.weight, mu=self.mu[4], gamma=self.gamma[4])  # (b, 4096, 28, 28)

        return {
            'x1a': x1a,

            'x2a': x2a,
            'x2b1': x2b1,
            'x2b2': x2b2,

            'x3a': x3a,
            'x3b1': x3b1,
            'x3b2': x3b2,

            'x4a': x4a,
            'x4b1': x4b1,
            'x4b2': x4b2,
            'x4b3': x4b3,
            'x4b4': x4b4,
            'x4b5': x4b5,

            'x5a': x5a,
            'x5b1': x5b1,
            'x5b2': x5b2,

            'x6a': x6a,

            'x7a': x
        }

    def make_cam(self, x):  # (b, 3, 224, 224)
        """Compute a gradient-free, rescaled CAM for the given images.

        The final features are convolved with the `fc8` weights, passed through
        ReLU and then through `torchutils.standard_scale(..., dims=-3)`
        (project helper; presumably normalises over the class dimension --
        TODO confirm).
        """
        with torch.no_grad():
            x7a = self.extract_features(imgs=x, cam=None)['x7a']  # (b, 4096, 28, 28)

            cam = torch.conv2d(input=x7a, weight=self.fc8.weight)  # (b, 4, 28, 28)
            cam = torch.relu(input=cam)  # (b, 4, 28, 28)

            cam = torchutils.standard_scale(x=cam, dims=-3)
        return cam  # (b, 4, 28, 28)

    def forward(self, x):  # (b, 3, 224, 224)
        """Return raw classification logits of shape (b, num_classes, 1, 1)."""
        ##########################################
        # A CAM is only computed when at least one PDA stage is enabled.
        cam = self.make_cam(x=x) if (sum(self.enable_PDA) > 0) else None  # (b, 4, 28, 28)
        ##########################################
        x = self.extract_features(imgs=x, cam=cam)['x7a']  # (b, 4096, 28, 28)
        x = self.dropout(x)  # (b, 4096, 28, 28)
        x = self.avgpool(x)  # (b, 4096, 1, 1)
        x = self.fc8(x)  # (b, 4, 1, 1)
        return x  # (b, 4, 1, 1)

    def forward_cam(self, x):  # (b, 3, 224, 224)
        """Return `(cam, z)`: the ReLU'd CAM and the sigmoid class probabilities.

        Unlike `forward`, no dropout is applied before pooling and the CAM is
        not rescaled.
        """
        x = self.extract_features(imgs=x, cam=None)['x7a']  # (b, 4096, 28, 28)

        cam = torch.conv2d(input=x, weight=self.fc8.weight)  # (b, 4, 28, 28)
        cam = torch.relu(input=cam)  # (b, 4, 28, 28)

        z = self.avgpool(x)  # (b, 4096, 1, 1)
        z = self.fc8(z)  # (b, 4, 1, 1)
        z = torch.sigmoid(input=z)  # (b, 4, 1, 1)

        return cam, z  # (b, 4, 28, 28), (b, 4, 1, 1)

    def get_parameter_groups(self):
        """Split all Conv2d parameters into six lists for per-group optimiser settings.

        Returns:
            Tuple `(g0, g1, g2, g3, g4, g5)` where even indices hold weights and
            odd indices hold biases of: modules listed in `not_training` (0/1),
            all other convolutions (2/3), and `from_scratch_layers` (4/5).
            Membership is tested on the Conv2d module itself, so convolutions
            nested inside a block listed in `not_training` land in groups 2/3.
        """
        groups = ([], [], [], [], [], [])

        for m in self.modules():

            if isinstance(m, nn.Conv2d):

                if m in self.not_training:
                    groups[0].append(m.weight)
                elif m in self.from_scratch_layers:
                    groups[4].append(m.weight)
                else:
                    groups[2].append(m.weight)

                if m.bias is not None:
                    if m in self.not_training:
                        groups[1].append(m.bias)
                    elif m in self.from_scratch_layers:
                        groups[5].append(m.bias)
                    else:
                        groups[3].append(m.bias)

        return groups

    def train(self, mode=True, TL=True):
        """Switch train/eval mode while keeping selected layers frozen.

        Args:
            mode: forwarded to `nn.Module.train`.
            TL: when True, parameters of the layers in `not_training` get
                `requires_grad = False`; when False they are made trainable
                again.

        All BatchNorm2d layers are always put in eval mode and their affine
        parameters frozen, regardless of `mode` and `TL`.

        Returns:
            self, like `nn.Module.train`. The override used to return None,
            which also made `model.eval()` return None.
        """
        super(ResNet38ClassificationModel, self).train(mode)

        trainable = not TL

        for layer in self.not_training:

            if isinstance(layer, nn.Conv2d):
                layer.weight.requires_grad = trainable
                if layer.bias is not None:
                    layer.bias.requires_grad = trainable

            elif isinstance(layer, nn.Module):
                # Only direct children are visited (BN / ReLU / Conv of a block).
                for c in layer.children():
                    if hasattr(c, 'weight'):
                        c.weight.requires_grad = trainable
                    if hasattr(c, 'bias'):
                        if c.bias is not None:
                            c.bias.requires_grad = trainable

        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                layer.bias.requires_grad = False
                layer.weight.requires_grad = False

        return self

In [None]:
class Engine:
    def __init__(
            self,
            train_cls_log_path,
            val_cls_log_path,
            val_cam_log_path,
            test_cam_log_path,
            train_data_path,
            val_data_path,
            test_data_path,
            init_weights,
            init_weights_path,
            checkpoints_path,
            device,
            val_size,
            batch_size,
            num_classes,
            session_name,
            max_epochs,
            lr,
            wt_dec,
            start_PDA,
            sigma,
            l,
            gamma,
            init_mu=None
    ):
        """Set up logging, data loaders, model, optimiser, scheduler and metrics.

        Args:
            train_cls_log_path, val_cls_log_path, val_cam_log_path,
            test_cam_log_path: TensorBoard log directories (created if missing).
            train_data_path: directory of classification images; it is split
                into a training and a validation part using `val_size`.
            val_data_path, test_data_path: directories that contain `img` and
                `mask` sub-directories for the CAM evaluation loaders.
            init_weights: weight file name ending in '.params' (MXNet, loaded
                non-strictly) or '.pth' (PyTorch state dict, loaded strictly),
                or None for random initialisation.
            init_weights_path: directory containing `init_weights`.
            checkpoints_path, device, session_name, max_epochs, start_PDA,
            sigma, l, gamma: stored on the instance for use by other methods.
            val_size: forwarded as `split_size` to the train/val split helper.
            batch_size: batch size of all four data loaders.
            num_classes: number of classes of the model and the CAM evaluator.
            lr: base learning rate; each parameter group uses a multiple of it.
            wt_dec: weight decay applied to the weight groups (biases use 0).
            init_mu: stored as `self.init_mu`. When omitted, a notebook-level
                global named `init_mu` is used if one exists, else None.

        Raises:
            NotImplementedError: if `init_weights` has an unsupported extension.
        """
        # Global step counters for batch-level logging: [train, validation].
        self.gs = [1, 1]
        self.checkpoints_path = checkpoints_path
        self.device = device
        self.session_name = session_name
        self.max_epochs = max_epochs
        self.lr = lr
        self.wt_dec = wt_dec
        self.start_PDA = start_PDA
        # `init_mu` used to be read without being a parameter, which raised a
        # NameError unless a global of that name happened to exist. It is now an
        # explicit keyword; the global is still honoured as a fallback so that
        # notebooks relying on it keep their behaviour.
        if init_mu is None:
            init_mu = globals().get('init_mu')
        self.init_mu = init_mu
        self.sigma = sigma
        self.l = l
        self.gamma = gamma

        for log_path in (train_cls_log_path, val_cls_log_path, val_cam_log_path, test_cam_log_path):
            os.makedirs(name=log_path, exist_ok=True)

        self.train_cls_writer = SummaryWriter(log_dir=train_cls_log_path)
        self.val_cls_writer = SummaryWriter(log_dir=val_cls_log_path)
        self.val_cam_writer = SummaryWriter(log_dir=val_cam_log_path)
        self.test_cam_writer = SummaryWriter(log_dir=test_cam_log_path)

        train_cls_data, val_cls_data = datautils.split_classification_data(
            images_path=train_data_path,
            split_size=val_size,
        )

        # Training loader: random flips / rotations as augmentation.
        self.train_cls_loader = datautils.get_LUAD_HistoSeg_classification_dataloader(
            images_path=train_data_path,
            image_names=train_cls_data,
            trans=CT.Compose([
                CT.RandomHorizontalFlip(p=0.5),
                CT.RandomVerticalFlip(p=0.5),
                CT.Random90Rotation(p=0.5),
                CT.Random180Rotation(p=0.5),
                CT.Random270Rotation(p=0.5),
                CT.ToTensor(),
            ]),
            batch_size=batch_size,
            shuffle=True,
        )

        # Validation split lives in the same directory as the training images.
        self.val_cls_loader = datautils.get_LUAD_HistoSeg_classification_dataloader(
            images_path=train_data_path,
            image_names=val_cls_data,
            trans=CT.Compose([
                CT.ToTensor(),
            ]),
            batch_size=batch_size,
            shuffle=False,
        )

        self.val_cam_loader = datautils.get_LUAD_HistoSeg_segmentation_dataloader(
            images_path=os.path.join(val_data_path, 'img'),
            masks_path=os.path.join(val_data_path, 'mask'),
            trans=CT.Compose([
                CT.ToTensor(),
            ]),
            batch_size=batch_size,
            shuffle=False,
        )

        self.test_cam_loader = datautils.get_LUAD_HistoSeg_segmentation_dataloader(
            images_path=os.path.join(test_data_path, 'img'),
            masks_path=os.path.join(test_data_path, 'mask'),
            trans=CT.Compose([
                CT.ToTensor(),
            ]),
            batch_size=batch_size,
            shuffle=False,
        )

        self.model = ResNet38ClassificationModel(num_classes=num_classes)
        # Dummy input used only to trace the graph for TensorBoard.
        self.train_cls_writer.add_graph(model=self.model, input_to_model=torch.zeros(size=(1, 3, 224, 224)))

        if init_weights is not None:
            wp = os.path.join(init_weights_path, init_weights)

            if init_weights.endswith('.params'):
                weights_dict = torchutils.convert_mxnet_weights_to_torch(weights_path=wp)

                # Strict=False because of linear1000 and fc8
                self.model.load_state_dict(state_dict=weights_dict, strict=False)
                print('Initialize model with MXNet weights')

            elif init_weights.endswith('.pth'):
                weights_dict = torch.load(f=wp, map_location='cpu')
                self.model.load_state_dict(state_dict=weights_dict, strict=True)
                print('Initialize model with user-defined weights')

            else:
                raise NotImplementedError(f'Invalid model weights {init_weights}')
        else:
            print('Initialize model with random weights')

        param_groups = self.model.get_parameter_groups()

        # Six groups: (weights, biases) x (groups 0/1, 2/3, 4/5 of the model).
        # Weights use `wt_dec`; biases use twice the group's weight LR and no decay.
        self.optimizer = torch.optim.SGD(params=[
            {'params': param_groups[0], 'lr': self.lr * 0.1, 'weight_decay': self.wt_dec},
            {'params': param_groups[1], 'lr': self.lr * 0.2, 'weight_decay': 0},
            {'params': param_groups[2], 'lr': self.lr * 10, 'weight_decay': self.wt_dec},
            {'params': param_groups[3], 'lr': self.lr * 20, 'weight_decay': 0},
            {'params': param_groups[4], 'lr': self.lr * 100, 'weight_decay': self.wt_dec},
            {'params': param_groups[5], 'lr': self.lr * 200, 'weight_decay': 0},
        ], lr=self.lr, weight_decay=self.wt_dec)

        # Every scheduler step multiplies all learning rates by 0.85.
        self.scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer, step_size=1, gamma=0.85)

        self.cls_criterion = metrics.WeightedMultiLabelSoftMarginLoss()
        self.cls_evaluator = metrics.IoUAccuracy()
        self.cam_evaluator = metrics.PseudoMaskEvaluator(num_classes=num_classes)

    def save_checkpoint(self, epoch, history, state_dict_path):
        state_dict = {
            'session_name': self.session_name,
            'epoch': epoch,
            'global_steps': self.gs,

            'train_cls_writer_logdir': self.train_cls_writer.get_logdir(),
            'train_cls_writer_logname': self.train_cls_writer.file_writer.event_writer._file_name,
            'val_cls_writer_logdir': self.val_cls_writer.get_logdir(),
            'val_cls_writer_logname': self.val_cls_writer.file_writer.event_writer._file_name,
            'val_cam_writer_logdir': self.val_cam_writer.get_logdir(),
            'val_cam_writer_logname': self.val_cam_writer.file_writer.event_writer._file_name,
            'test_cam_writer_logdir': self.test_cam_writer.get_logdir(),
            'test_cam_writer_logname': self.test_cam_writer.file_writer.event_writer._file_name,

            'train_cls_loader': self.train_cls_loader,
            'val_cls_loader': self.val_cls_loader,
            'val_cam_loader': self.val_cam_loader,
            'test_cam_loader': self.test_cam_loader,

            'model_state_dict': self.model.state_dict(),
            'model_mu': self.model.mu,
            'model_gamma': self.model.gamma,
            'model_enable_PDA': self.model.enable_PDA,

            'optimizer': self.optimizer,
            'scheduler': self.scheduler,
            'history': history
        }

        torch.save(obj=state_dict, f=state_dict_path)

    def resume_checkpoint(self, state_dict_path=None):

        state_dict = torch.load(f=state_dict_path, map_location='cpu')
        session = state_dict['session_name']
        assert session == self.session_name, f"State dict {session} is not for session {self.session_name}"
        print('Checkpoint Loaded')

        epoch = state_dict['epoch']
        self.gs = state_dict['global_steps']

        self.train_cls_writer.log_dir = state_dict['train_cls_writer_logdir']
        self.train_cls_writer.file_writer.event_writer._file_name = state_dict['train_cls_writer_logname'],
        self.val_cls_writer.log_dir = state_dict['val_cls_writer_logdir']
        self.val_cls_writer.file_writer.event_writer._file_name = state_dict['val_cls_writer_logname'],
        self.val_cam_writer.log_dir = state_dict['val_cam_writer_logdir']
        self.val_cam_writer.file_writer.event_writer._file_name = state_dict['val_cam_writer_logname'],
        self.test_cam_writer.log_dir = state_dict['test_cam_writer_logdir']
        self.test_cam_writer.file_writer.event_writer._file_name = state_dict['test_cam_writer_logname'],

        self.train_cls_loader = state_dict['train_cls_loader']
        self.val_cls_loader = state_dict['val_cls_loader']
        self.val_cam_loader = state_dict['val_cam_loader']
        self.test_cam_loader = state_dict['test_cam_loader']

        self.model.load_state_dict(state_dict=state_dict['model_state_dict'])
        self.model.mu = state_dict['model_mu']
        self.model.gamma = state_dict['model_gamma']
        self.model.enable_PDA = state_dict['model_enable_PDA']

        self.optimizer = state_dict['optimizer']
        self.scheduler = state_dict['scheduler']
        history = state_dict['history']

        self.fit(ep=epoch + 1, epoch_history=history)

    def train_cls_one_epoch(self, ep, thresh=0.5, print_incorrects=False):
        print(f'{"#" * 20} Train Classification E{ep} {"#" * 20}')

        h = {
            'MLSMLoss': [],
            'IoUAccuracy': [],
            'ExactMatch': [],
            'TE_IoUAccuracy': [],
            'NEC_IoUAccuracy': [],
            'LYM_IoUAccuracy': [],
            'TAS_IoUAccuracy': [],
            'TE_MLSMLoss': [],
            'NEC_MLSMLoss': [],
            'LYM_MLSMLoss': [],
            'TAS_MLSMLoss': [],
        }
        for i, params in enumerate(self.optimizer.param_groups):
            h[f'LRG{i}'] = []

        t = tqdm(self.train_cls_loader)

        for iter, sample in enumerate(t):
            names = sample['name']
            images = sample['img'].to(device=self.device)
            labels = sample['label'].to(device=self.device)

            x = self.model(images)
            x = x.view(x.size(0), -1)

            loss_class = self.cls_criterion.train_call(input=x, target=labels)

            loss = loss_class.mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            gc.collect()
            torch.cuda.empty_cache()
            loss = loss.item()
            loss_class = loss_class.detach().cpu().numpy()
            probs = torch.sigmoid(input=x).detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()

            scores = self.cls_evaluator(probs=probs, y=labels)

            if print_incorrects:
                torchutils.print_incorrects(names=names, probs=probs, labels=labels, thresh=thresh)

            h['MLSMLoss'].append(loss)
            h['IoUAccuracy'].append(scores['accuracy'])
            h['ExactMatch'].append(scores['exact_match'])
            h['TE_IoUAccuracy'].append(scores['class_acc'][0])
            h['NEC_IoUAccuracy'].append(scores['class_acc'][1])
            h['LYM_IoUAccuracy'].append(scores['class_acc'][2])
            h['TAS_IoUAccuracy'].append(scores['class_acc'][3])
            h['TE_MLSMLoss'].append(loss_class[0])
            h['NEC_MLSMLoss'].append(loss_class[1])
            h['LYM_MLSMLoss'].append(loss_class[2])
            h['TAS_MLSMLoss'].append(loss_class[3])

            self.train_cls_writer.add_scalar(tag='Batch/MLSMLoss', scalar_value=loss, global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/IoUAccuracy', scalar_value=scores['accuracy'], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/ExactMatch', scalar_value=scores['exact_match'], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/TE_IoUAccuracy', scalar_value=scores['class_acc'][0], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/NEC_IoUAccuracy', scalar_value=scores['class_acc'][1], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/LYM_IoUAccuracy', scalar_value=scores['class_acc'][2], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/TAS_IoUAccuracy', scalar_value=scores['class_acc'][3], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/TE_MLSMLoss', scalar_value=loss_class[0], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/NEC_MLSMLoss', scalar_value=loss_class[1], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/LYM_MLSMLoss', scalar_value=loss_class[2], global_step=self.gs[0])
            self.train_cls_writer.add_scalar(tag='Batch/TAS_MLSMLoss', scalar_value=loss_class[3], global_step=self.gs[0])

            for i, params in enumerate(self.optimizer.param_groups):
                self.train_cls_writer.add_scalar(tag=f'Batch/LRG{i}', scalar_value=params['lr'], global_step=self.gs[0])
                h[f'LRG{i}'].append(params['lr'])

            self.gs[0] += 1

            t.set_description(
                desc=f"E{ep} Train Loss: {loss:0.4f}, "
                f"Acc: {scores['accuracy']:0.4f}, "
                f"EM: {scores['exact_match']:0.4f}, "
            )

        self.train_cls_writer.add_scalar(tag='Epoch/MLSMLoss', scalar_value=np.nanmean(h['MLSMLoss']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/IoUAccuracy', scalar_value=np.nanmean(h['IoUAccuracy']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/ExactMatch', scalar_value=np.nanmean(h['ExactMatch']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/TE_IoUAccuracy', scalar_value=np.nanmean(h['TE_IoUAccuracy']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/NEC_IoUAccuracy', scalar_value=np.nanmean(h['NEC_IoUAccuracy']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/LYM_IoUAccuracy', scalar_value=np.nanmean(h['LYM_IoUAccuracy']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/TAS_IoUAccuracy', scalar_value=np.nanmean(h['TAS_IoUAccuracy']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/TE_MLSMLoss', scalar_value=np.nanmean(h['TE_MLSMLoss']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/NEC_MLSMLoss', scalar_value=np.nanmean(h['NEC_MLSMLoss']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/LYM_MLSMLoss', scalar_value=np.nanmean(h['LYM_MLSMLoss']), global_step=ep)
        self.train_cls_writer.add_scalar(tag='Epoch/TAS_MLSMLoss', scalar_value=np.nanmean(h['TAS_MLSMLoss']), global_step=ep)

        # for i, epda in enumerate(self.model.enable_PDA):
        #     self.train_cls_writer.add_scalar(tag=f'Epoch/Enable_PDA{i}', scalar_value=epda, global_step=ep)
        #
        # for i, mu in enumerate(self.model.mu):
        #     self.train_cls_writer.add_scalar(tag=f'Epoch/Mu{i}', scalar_value=mu, global_step=ep)
        #
        # for i, gamma in enumerate(self.model.gamma):
        #     self.train_cls_writer.add_scalar(tag=f'Epoch/Gamma{i}', scalar_value=gamma, global_step=ep)

        for k, v in h.items():
            print(f'Train Classification {k}: {np.nanmean(h[k]):0.4f}')
        print(f'Train Classification model_enable_PDA: {self.model.enable_PDA}')
        print(f'Train Classification model_mu: {self.model.mu}')
        print(f'Train Classification model_gamma: {self.model.gamma}')
        return h

    def validate_cls_one_epoch(self, ep, thresh=0.5, print_incorrects=False):
        """Run one classification validation pass over ``self.val_cls_loader``.

        For every batch the model is run under ``torch.no_grad()``, its output is
        flattened to ``(batch, -1)``, the loss comes from
        ``self.cls_criterion.val_call`` and the scores from ``self.cls_evaluator``
        applied to the sigmoid of the model output. Each per-batch value is appended
        to the returned history and written to ``self.val_cls_writer`` under
        ``Batch/<name>`` at step ``self.gs[1]``, which is incremented once per batch.
        After the loop the NaN-ignoring mean of every series is written under
        ``Epoch/<name>`` at step ``ep`` and printed.

        Args:
            ep: Epoch number; used in console output and as the epoch-level step.
            thresh: Only forwarded to ``torchutils.print_incorrects``.
            print_incorrects: When true, ``torchutils.print_incorrects`` is called
                for every batch.

        Returns:
            dict mapping metric name to the list of per-batch values.
        """
        print(f'{"#" * 20} Validating Classification E{ep} {"#" * 20}')

        class_names = ('TE', 'NEC', 'LYM', 'TAS')

        # Insertion order matters: it is the order used for epoch logging and printing.
        h = {'MLSMLoss': [], 'IoUAccuracy': [], 'ExactMatch': []}
        h.update({f'{name}_IoUAccuracy': [] for name in class_names})
        h.update({f'{name}_MLSMLoss': [] for name in class_names})

        t = tqdm(self.val_cls_loader)

        for sample in t:
            names = sample['name']
            images = sample['img'].to(device=self.device)
            labels = sample['label'].to(device=self.device)

            with torch.no_grad():
                x = self.model(images)
                x = x.view(x.size(0), -1)

            loss_class = self.cls_criterion.val_call(input=x, target=labels)

            gc.collect()
            torch.cuda.empty_cache()

            loss_class = loss_class.detach().cpu().numpy()
            loss = loss_class.mean()
            probs = torch.sigmoid(input=x).detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()

            scores = self.cls_evaluator(probs=probs, y=labels)

            if print_incorrects:
                torchutils.print_incorrects(names=names, probs=probs, labels=labels, thresh=thresh)

            batch_metrics = {
                'MLSMLoss': loss,
                'IoUAccuracy': scores['accuracy'],
                'ExactMatch': scores['exact_match'],
            }
            for idx, name in enumerate(class_names):
                batch_metrics[f'{name}_IoUAccuracy'] = scores['class_acc'][idx]
            # NOTE(review): loss_class[idx] is treated as the loss of class idx, which
            # assumes val_call returns one value per class — confirm against the criterion.
            for idx, name in enumerate(class_names):
                batch_metrics[f'{name}_MLSMLoss'] = loss_class[idx]

            for tag, value in batch_metrics.items():
                h[tag].append(value)

            t.set_description(
                desc=f"E{ep} Val Loss: {loss:0.4f}, "
                f"Acc: {scores['accuracy']:0.4f}, "
                f"EM: {scores['exact_match']:0.4f}, "
            )

            for tag, value in batch_metrics.items():
                self.val_cls_writer.add_scalar(tag=f'Batch/{tag}', scalar_value=value, global_step=self.gs[1])

            self.gs[1] += 1

        for tag, values in h.items():
            self.val_cls_writer.add_scalar(tag=f'Epoch/{tag}', scalar_value=np.nanmean(values), global_step=ep)

        for tag, values in h.items():
            print(f'Validate Classification {tag}: {np.nanmean(values):0.4f}')
        print(f'Validate Classification model_enable_PDA: {self.model.enable_PDA}')
        print(f'Validate Classification model_mu: {self.model.mu}')
        # This line used to be labelled "Train Classification", a copy-paste slip.
        print(f'Validate Classification model_gamma: {self.model.gamma}')

        return h

    def validate_cam_one_epoch(self, ep=1, threshold=0.25):
        """Score the model's class activation maps on ``self.val_cam_loader``.

        For each batch, ``self.model.forward_cam`` yields CAMs and label scores.
        CAMs are bilinearly resized to the input's spatial size, channels whose
        label score is not greater than ``threshold`` are zeroed, and the per-pixel
        argmax over channels becomes the predicted mask. Pixels whose ground-truth
        value is 4 are forced to 4 in the prediction before the pair is handed to
        ``self.cam_evaluator.add_batch``.

        The evaluator is not reset here; ``fit`` calls ``reset()`` before this method.

        Args:
            ep: Epoch number; used in console output and as the TensorBoard step.
            threshold: Label scores must exceed this for a CAM channel to be kept.

        Returns:
            The dict returned by ``self.cam_evaluator.get_scores()``.
        """
        print(f'{"#" * 20} Validating CAMs E{ep} {"#" * 20}')

        for sample in tqdm(self.val_cam_loader, desc=f'E{ep} Validating CAMs'):
            images = sample['img'].to(device=self.device)
            masks = sample['mask'].numpy()

            h, w = images.size()[-2:]

            with torch.no_grad():
                cams, labels = self.model.forward_cam(images)

            cams = cams.detach().cpu()
            labels = labels.detach().cpu()

            # Boolean gate per class; multiplying below zeroes the rejected channels.
            labels = torch.greater(input=labels, other=threshold).numpy()
            cams = F.interpolate(input=cams, size=(h, w), mode='bilinear', align_corners=False).numpy()

            pred = np.argmax(a=cams * labels, axis=1).astype(dtype=np.uint8)

            # Copy ground-truth label 4 into the prediction so those pixels always agree.
            pred[masks == 4] = 4
            self.cam_evaluator.add_batch(gt_mask=masks, pred_mask=pred)

        scores = self.cam_evaluator.get_scores()

        epoch_scalars = (
            ('PixelAccuracy', scores['pa']),
            ('MeanClassAccuracy', scores['ma']),
            ('TE_IoUAccuracy', scores['iou'][0]),
            ('NEC_IoUAccuracy', scores['iou'][1]),
            ('LYM_IoUAccuracy', scores['iou'][2]),
            ('TAS_IoUAccuracy', scores['iou'][3]),
            ('MIoU', scores['miou']),
            ('FWIoU', scores['fwiou']),
        )
        for tag, value in epoch_scalars:
            self.val_cam_writer.add_scalar(tag=f'Epoch/{tag}', scalar_value=value, global_step=ep)

        for key, value in scores.items():
            print(f'Validate CAMs {key}: {np.nanmean(value):0.4f}')
        print(f'Validate CAMs model_enable_PDA: {self.model.enable_PDA}')
        print(f'Validate CAMs model_mu: {self.model.mu}')
        print(f'Validate CAMs model_gamma: {self.model.gamma}')

        return scores

    def test_cam_one_epoch(self, ep=1, threshold=0.25):
        print(f'{"#" * 20} Testing CAMs E{ep} {"#" * 20}')

        t = tqdm(self.test_cam_loader, desc=f'E{ep} Testing CAMs')

        for iter, sample in enumerate(t):
            names = sample['name']
            images = sample['img'].to(device=self.device)
            masks = sample['mask'].numpy()

            b, c, h, w = images.size()

            with torch.no_grad():
                cams, labels = self.model.forward_cam(images)

            cams = cams.detach().cpu()
            labels = labels.detach().cpu()

            labels = torch.greater(input=labels, other=threshold)
            cams = F.interpolate(input=cams, size=(h, w), mode='bilinear', align_corners=False)

            cams = cams.numpy()
            labels = labels.numpy()

            cams = cams * labels
            cams = np.argmax(a=cams, axis=1).astype(dtype=np.uint8)

            cams[masks == 4] = 4
            self.cam_evaluator.add_batch(gt_mask=masks, pred_mask=cams)

        scores = self.cam_evaluator.get_scores()

        self.test_cam_writer.add_scalar(tag='Epoch/PixelAccuracy', scalar_value=scores['pa'], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/MeanClassAccuracy', scalar_value=scores['ma'], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/TE_IoUAccuracy', scalar_value=scores['iou'][0], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/NEC_IoUAccuracy', scalar_value=scores['iou'][1], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/LYM_IoUAccuracy', scalar_value=scores['iou'][2], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/TAS_IoUAccuracy', scalar_value=scores['iou'][3], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/MIoU', scalar_value=scores['miou'], global_step=ep)
        self.test_cam_writer.add_scalar(tag='Epoch/FWIoU', scalar_value=scores['fwiou'], global_step=ep)

        for k, v in scores.items():
            print(f'Test CAMs {k}: {np.nanmean(scores[k]):0.4f}')
        print(f'Test CAMs model_enable_PDA: {self.model.enable_PDA}')
        print(f'Test CAMs model_mu: {self.model.mu}')
        print(f'Test CAMs model_gamma: {self.model.gamma}')

        return scores

    def fit(self, ep=1, epoch_history=None):
        """Train from epoch ``ep`` through ``self.max_epochs`` inclusive.

        Each epoch runs, in order: classification training, classification
        validation, CAM validation, CAM testing (the CAM evaluator is reset before
        each CAM pass), one scheduler step, and a checkpoint save into
        ``<self.checkpoints_path>/weights``.

        Args:
            ep: First epoch number to run.
            epoch_history: Optional dict with the list-valued keys ``train_cls``,
                ``val_cls``, ``val_cam`` and ``test_cam`` to keep appending to; a
                fresh one is created when ``None``.
        """
        self.model = self.model.to(device=self.device)

        weights_dir = os.path.join(self.checkpoints_path, 'weights')
        os.makedirs(name=weights_dir, exist_ok=True)

        if epoch_history is None:
            epoch_history = {stage: [] for stage in ('train_cls', 'val_cls', 'val_cam', 'test_cam')}

        for epoch in range(ep, self.max_epochs + 1):

            # ---- Transfer-learning trade-off ----
            # Epochs below 4 train with TL=True; every later epoch uses TL=False.
            self.model.train(TL=epoch < 4)

            # ---- Progressive Dropout Attention schedule (currently disabled) ----
            # if ep >= self.start_PDA:  # 7
            #     self.model.enable_PDA[0] = True
            #     if ep >= self.start_PDA + 5:
            #         self.model.enable_PDA[1] = True
            #     if ep >= self.start_PDA + 15:
            #         self.model.enable_PDA[2] = True
            #     if ep >= self.start_PDA + 20:
            #         self.model.enable_PDA[3] = True
            #     if ep >= self.start_PDA + 25:
            #         self.model.enable_PDA[4] = True
            #
            #     for i, pda in enumerate(self.model.enable_PDA):
            #         if pda:
            #             if self.model.mu[i] > self.l:
            #                 self.model.mu[i] = self.model.mu[i] * self.sigma
            #             # if self.model.gamma[i] < 0.3:
            #             #     self.model.gamma[i] = self.model.gamma[i] + self.gamma
            # else:
            #     self.model.enable_PDA = [False, False, False, False, False]

            # Training classification
            train_result = self.train_cls_one_epoch(ep=epoch)
            epoch_history['train_cls'].append(train_result)

            # Validating classification: eval mode with every PDA stage switched off
            self.model.eval()
            self.model.enable_PDA = [False] * 5
            val_cls_result = self.validate_cls_one_epoch(ep=epoch)
            epoch_history['val_cls'].append(val_cls_result)

            # Validating CAMs
            self.cam_evaluator.reset()
            val_cam_result = self.validate_cam_one_epoch(ep=epoch)
            epoch_history['val_cam'].append(val_cam_result)

            # Testing CAMs
            self.cam_evaluator.reset()
            test_cam_result = self.test_cam_one_epoch(ep=epoch)
            epoch_history['test_cam'].append(test_cam_result)

            self.scheduler.step()

            checkpoint_name = f'{self.session_name}_E{epoch}_checkpoint_trained_on_luad.pth'
            self.save_checkpoint(
                epoch=epoch,
                history=epoch_history,
                state_dict_path=os.path.join(weights_dir, checkpoint_name),
            )

In [None]:
def final_forward_PDA(x, w, mu, gamma):  # (b, 4096, 28, 28), (4, 4096, 1, 1)
    """Re-weight ``x`` by a band-passed, class-averaged activation map.

    The activation map is ``relu(conv2d(x, w))``. For every sample and class
    channel, only values strictly below ``max * mu`` and strictly above
    ``min + gamma`` (max/min over the two spatial axes) are kept; the rest become
    zero. The kept maps are averaged over the class axis and multiplied into ``x``.

    Args:
        x: Feature tensor, e.g. (b, 4096, 28, 28).
        w: Convolution weight used as the classifier, e.g. (4, 4096, 1, 1).
        mu: Multiplier applied to each channel's spatial maximum (upper cut-off).
        gamma: Offset added to each channel's spatial minimum (lower cut-off).

    Returns:
        Tensor with the same shape as ``x``.
    """
    activation = torch.relu(torch.conv2d(x, w))  # (b, 4, 28, 28)

    # Per-sample, per-class cut-offs broadcast back to the activation's shape.
    upper = (activation.amax(dim=(-2, -1), keepdim=True) * mu).expand_as(activation)
    lower = (activation.amin(dim=(-2, -1), keepdim=True) + gamma).expand_as(activation)

    # Applied by multiplication (not torch.where) to match the original arithmetic.
    in_band = (activation < upper) * (activation > lower)
    attention = (in_band * activation).mean(dim=1, keepdim=True)  # (b, 1, 28, 28)

    return x * attention  # (b, 4096, 28, 28)

In [None]:
def intermediate_forward_PDA(x, cam, mu, gamma):
    """Debug/visualisation variant of PDA applied to an intermediate feature map.

    The channel-mean of ``x`` is band-passed: values strictly below
    ``max * mu`` and strictly above ``min + gamma`` (max/min over the two spatial
    axes) are kept, everything else is zeroed, and the result is multiplied into
    ``x``. The function prints the input sizes and shows two matplotlib figures
    (the channel-mean before and after masking).

    The upper cut-off now honours ``mu``; a literal ``0.6`` used to be
    hard-coded here, which silently ignored the argument.

    Args:
        x: 4-D feature tensor; dim 1 is averaged and dims 2/3 are treated as spatial.
        cam: Class activation maps; resized to ``x``'s spatial size. Only the
            disabled per-class plots below consume the resized maps.
        mu: Multiplier applied to the spatial maximum of the channel-mean.
        gamma: Offset added to the spatial minimum of the channel-mean.

    Returns:
        Tensor with the same shape as ``x``.
    """
    print(x.size())
    print(cam.size())

    cam = F.interpolate(input=cam, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False)
    # for i, e in enumerate(['TE', 'NEC', 'LYM', 'TAS']):
    #     plt.imshow(cam[0, i].detach().cpu())
    #     plt.grid(visible=False)
    #     plt.colorbar()
    #     plt.title(label=f'Orig {e}')
    #     plt.show()

    x_mean = x.mean(dim=1, keepdim=True)  # (b, 1, h, w)
    plt.matshow(x_mean[0, 0].detach().cpu())
    plt.grid(visible=False)
    plt.colorbar()
    plt.title(label='X-Mean')
    plt.show()

    # cam = x_mean * cam
    # for i, e in enumerate(['TE', 'NEC', 'LYM', 'TAS']):
    #     plt.imshow(cam[0, i].detach().cpu())
    #     plt.grid(visible=False)
    #     plt.colorbar()
    #     plt.title(label=f'After X*CAM {e}')
    #     plt.show()

    beta = x_mean.amax(dim=(-2, -1), keepdim=True) * mu  # (b, 1, 1, 1)
    beta = beta.expand(size=x_mean.size())  # (b, 1, h, w)

    alpha = x_mean.amin(dim=(-2, -1), keepdim=True) + gamma  # (b, 1, 1, 1)
    alpha = alpha.expand(size=x_mean.size())  # (b, 1, h, w)

    x_mean = torch.less(input=x_mean, other=beta) * torch.greater(input=x_mean, other=alpha) * x_mean  # (b, 1, h, w)
    # for i, e in enumerate(['TE', 'NEC', 'LYM', 'TAS']):
    #     plt.imshow(cam[0, i].detach().cpu())
    #     plt.grid(visible=False)
    #     plt.colorbar()
    #     plt.title(label=f'After PDA {e}')
    #     plt.show()

    # x_mean already has a single channel, so this mean leaves the values unchanged.
    x_mean = x_mean.mean(dim=1, keepdim=True)
    # cam = (cam > 0) * 1.0
    plt.imshow(x_mean[0, 0].detach().cpu())
    plt.grid(visible=False)
    plt.colorbar()
    plt.title(label='Mean after PDA')
    plt.show()

    x = x * x_mean  # (b, c, h, w)
    # plt.imshow(x.mean(dim=1, keepdim=True)[0, 0].detach().cpu())
    # plt.grid(visible=False)
    # plt.colorbar()
    # plt.title(label='X-Mean-Final')
    # plt.show()

    return x

In [None]:
# ---- Session identity and output locations ----
session_name = 'Base7'
root_path = os.getcwd()
checkpoints_path = os.path.join(root_path, 'test', session_name)

# One TensorBoard log directory per stage, all under the session's checkpoint folder.
train_cls_log_path, val_cls_log_path, val_cam_log_path, test_cam_log_path = (
    os.path.join(checkpoints_path, stage) for stage in ('train_cls', 'val_cls', 'val_cam', 'test_cam')
)

# ---- Data locations (relative to the working directory) ----
train_data_path = os.path.join('test', 'data')
val_data_path, test_data_path = (
    os.path.join('datasets', 'LUAD-HistoSeg', split) for split in ('val', 'test')
)

# ---- Initial weights and optional checkpoint to resume from ----
init_weights = 'ilsvrc-cls_rna-a1_cls1000_ep-0001.params'
init_weights_path = os.path.join(root_path, 'init_weights')
resume_path = 'test/Base7/weights/Base7_E3_checkpoint_trained_on_luad.pth'
# resume_path = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# ---- Hyper-parameters ----
val_size = 200
batch_size = 4
num_classes = 4
seed = 42
max_epochs = 40
lr = 1e-5
wt_dec = 5e-4
start_PDA = 5
init_mu = 1  # NOTE(review): defined here but not passed to Engine below — confirm whether it is still needed
sigma = 0.985
l = 0.65
gamma = 0.015

pyutils.set_seed(seed=seed)

# Every argument is passed by keyword; collected in a dict so the call site stays short.
engine_kwargs = dict(
    train_cls_log_path=train_cls_log_path,
    val_cls_log_path=val_cls_log_path,
    val_cam_log_path=val_cam_log_path,
    test_cam_log_path=test_cam_log_path,
    train_data_path=train_data_path,
    val_data_path=val_data_path,
    test_data_path=test_data_path,
    init_weights=init_weights,
    init_weights_path=init_weights_path,
    checkpoints_path=checkpoints_path,
    device=device,
    val_size=val_size,
    batch_size=batch_size,
    num_classes=num_classes,
    session_name=session_name,
    max_epochs=max_epochs,
    lr=lr,
    wt_dec=wt_dec,
    start_PDA=start_PDA,
    sigma=sigma,
    l=l,
    gamma=gamma,
)
engine = Engine(**engine_kwargs)

# Resume only when a checkpoint path is configured and exists, either as given or under root_path.
resume_found = (resume_path is not None) and (
    os.path.exists(path=resume_path) or os.path.exists(path=os.path.join(root_path, resume_path))
)

if not resume_found:
    print('Training from scratch')
    engine.fit()
else:
    print('State Dict Found')
    engine.resume_checkpoint(state_dict_path=resume_path)