In [None]:
!pip install torchsummary
!pip install pretrainedmodels

In [None]:
# import for train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models as tv_models
from torch.utils.data import DataLoader
from torchsummary import summary
import numpy as np
from scipy import io
import threading
import pickle
from pathlib import Path
import math
import os
import sys
from glob import glob
import re
import gc
import importlib
import time
import sklearn.preprocessing
from sklearn.utils import class_weight
import psutil

In [None]:
!pip install efficientnet_pytorch

In [None]:
# import for models.py
import torch
import numbers
import numpy as np
import functools
import h5py
import math
from torchvision import models
import pretrainedmodels
import torch.nn.functional as F
import types
import torch
from efficientnet_pytorch import EfficientNet
from collections import OrderedDict
import torch.nn as nn

def Dense121(config):
    """Return an ImageNet-pretrained torchvision DenseNet-121; ``config`` is unused."""
    net = models.densenet121(pretrained=True)
    return net


def Dense161(config):
    """Return an ImageNet-pretrained torchvision DenseNet-169; ``config`` is unused.

    NOTE(review): despite its name this factory builds DenseNet-*169*.
    ``model_map`` registers it under the 'Dense169' key, so a lookup by key
    still yields the architecture the key names. Rename both together or not
    at all.
    """
    net = models.densenet169(pretrained=True)
    return net


def Dense169(config):
    """Return an ImageNet-pretrained torchvision DenseNet-161; ``config`` is unused.

    NOTE(review): despite its name this factory builds DenseNet-*161*;
    ``model_map`` registers it under the 'Dense161' key (see ``Dense161``).
    """
    net = models.densenet161(pretrained=True)
    return net


def Dense201(config):
    """Return an ImageNet-pretrained torchvision DenseNet-201; ``config`` is unused."""
    net = models.densenet201(pretrained=True)
    return net

def Resnet50(config):
    """Return ResNet-50 from the ``pretrainedmodels`` zoo (ImageNet weights).

    Unlike ``Resnet101`` this one does not come from torchvision.
    ``config`` is unused.
    """
    factory = pretrainedmodels.__dict__['resnet50']
    net = factory(num_classes=1000, pretrained='imagenet')
    return net


def Resnet101(config):
    """Return an ImageNet-pretrained torchvision ResNet-101; ``config`` is unused."""
    net = models.resnet101(pretrained=True)
    return net


def InceptionV3(config):
    """Return an ImageNet-pretrained torchvision Inception-v3; ``config`` is unused."""
    net = models.inception_v3(pretrained=True)
    return net

def _pretrainedmodels_net(arch, weights='imagenet'):
    """Build ``arch`` from the ``pretrainedmodels`` zoo with a 1000-class head.

    ``weights`` is forwarded as the ``pretrained=`` weight-set name.
    """
    return pretrainedmodels.__dict__[arch](num_classes=1000, pretrained=weights)


def se_resnext50(config):
    """SE-ResNeXt-50 (32x4d) with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('se_resnext50_32x4d')


def se_resnext101(config):
    """SE-ResNeXt-101 (32x4d) with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('se_resnext101_32x4d')


def se_resnet50(config):
    """SE-ResNet-50 with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('se_resnet50')


def se_resnet101(config):
    """SE-ResNet-101 with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('se_resnet101')


def se_resnet152(config):
    """SE-ResNet-152 with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('se_resnet152')


def resnext101(config):
    """ResNeXt-101 (32x4d) with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('resnext101_32x4d')


def resnext101_64(config):
    """ResNeXt-101 (64x4d) with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('resnext101_64x4d')


def senet154(config):
    """SENet-154 with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('senet154')


def polynet(config):
    """PolyNet with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('polynet')


def dpn92(config):
    """DPN-92 with 'imagenet+5k' weights; ``config`` is unused."""
    return _pretrainedmodels_net('dpn92', weights='imagenet+5k')


def dpn68b(config):
    """DPN-68b with 'imagenet+5k' weights; ``config`` is unused."""
    return _pretrainedmodels_net('dpn68b', weights='imagenet+5k')


def nasnetamobile(config):
    """NASNet-A Mobile with ImageNet weights; ``config`` is unused."""
    return _pretrainedmodels_net('nasnetamobile')

def _wsl_hub(entrypoint):
    """Load ``entrypoint`` from the 'facebookresearch/WSL-Images' torch.hub repo."""
    return torch.hub.load('facebookresearch/WSL-Images', entrypoint)


def resnext101_32_8_wsl(config):
    """ResNeXt-101 32x8d WSL model via torch.hub; ``config`` is unused."""
    return _wsl_hub('resnext101_32x8d_wsl')


def resnext101_32_16_wsl(config):
    """ResNeXt-101 32x16d WSL model via torch.hub; ``config`` is unused."""
    return _wsl_hub('resnext101_32x16d_wsl')


def resnext101_32_32_wsl(config):
    """ResNeXt-101 32x32d WSL model via torch.hub; ``config`` is unused."""
    return _wsl_hub('resnext101_32x32d_wsl')


def resnext101_32_48_wsl(config):
    """ResNeXt-101 32x48d WSL model via torch.hub; ``config`` is unused."""
    return _wsl_hub('resnext101_32x48d_wsl')

def _efficientnet(variant, config):
    """Load pretrained EfficientNet ``variant`` with a ``config['numClasses']``-way head."""
    return EfficientNet.from_pretrained(variant, num_classes=config['numClasses'])


def efficientnet_b0(config):
    """Pretrained EfficientNet-B0 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b0', config)


def efficientnet_b1(config):
    """Pretrained EfficientNet-B1 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b1', config)


def efficientnet_b2(config):
    """Pretrained EfficientNet-B2 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b2', config)


def efficientnet_b3(config):
    """Pretrained EfficientNet-B3 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b3', config)


def efficientnet_b4(config):
    """Pretrained EfficientNet-B4 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b4', config)


def efficientnet_b5(config):
    """Pretrained EfficientNet-B5 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b5', config)


def efficientnet_b6(config):
    """Pretrained EfficientNet-B6 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b6', config)


def efficientnet_b7(config):
    """Pretrained EfficientNet-B7 sized for ``config['numClasses']`` outputs."""
    return _efficientnet('efficientnet-b7', config)

def isic2019_efficientnet_b0(config):
    """Build an EfficientNet-B0 initialised from a local ISIC 2019 checkpoint.

    An ImageNet-pretrained B0 with a ``config['numClasses']``-way head is
    created, its weights are replaced by the checkpoint's ``'state_dict'``
    entry, and parameters whose names contain a frozen-layer substring get
    ``requires_grad=False``. All other parameters get ``requires_grad=True``.

    Args:
        config: dict. Reads ``'numClasses'``. Optional keys:
            ``'isic2019_checkpoint'``: checkpoint path (defaults to the
            original Kaggle location);
            ``'frozen_layers'``: iterable of parameter-name substrings to
            freeze (defaults to ``['_conv_stem']``).

    Returns:
        The initialised EfficientNet model.
    """
    model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=config['numClasses'])

    checkpoint_path = config.get('isic2019_checkpoint',
                                 '/kaggle/input/extrafiles/checkpoint_best-60.pt')
    # Deserialize onto the CPU: the freshly built model is on the CPU anyway, and
    # this avoids failing when a GPU-saved checkpoint is loaded without a GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    del checkpoint  # release the duplicate copy of the weights

    frozen = tuple(config.get('frozen_layers', ('_conv_stem',)))
    for name, param in model.named_parameters():
        # Substring match, e.g. '_conv_stem' matches '_conv_stem.weight'.
        param.requires_grad = not any(layer in name for layer in frozen)

    return model
    
def modify_meta(mdlParams, model):
    """Attach a metadata MLP to a CNN backbone and replace its forward pass.

    The patched model is called with a ``(images, meta_data)`` pair:

    1. Image features are the backbone's pooled penultimate activations.
    2. ``meta_data`` goes through ``model.meta_before``, which has one
       Linear -> BatchNorm1d -> ReLU -> Dropout group per width in
       ``mdlParams['fc_layers_before']``.
    3. Both feature sets are concatenated.
    4. If ``fc_layers_after`` is non-empty, the result goes through
       ``model.meta_after`` (Linear -> BatchNorm1d -> ReLU of width
       ``fc_layers_after[0]``).
    5. A new linear classifier with ``mdlParams['numClasses']`` outputs is
       applied last.

    The backbone family is picked from ``mdlParams['model_type']``:

    - names containing 'efficient' use ``extract_features`` / ``_dropout`` /
      ``_fc``;
    - names containing 'wsl' use the ResNet-style ``conv1`` .. ``layer4`` /
      ``avgpool`` / ``fc``;
    - anything else is assumed to expose ``layer0`` .. ``layer4`` /
      ``avg_pool`` / ``dropout`` / ``last_linear``.

    Args:
        mdlParams: dict with the following keys:
            'fc_layers_before': non-empty list of widths.
            'fc_layers_after': list; only its first width is used.
            'meta_array': 2-D array whose column count is the metadata size.
            'dropout_meta', 'model_type' and 'numClasses'.
        model: backbone module; modified in place.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``mdlParams['fc_layers_before']`` is empty.
    """
    model_type = mdlParams['model_type']
    if 'efficient' in model_type:
        family, head_name = 'efficient', '_fc'
    elif 'wsl' in model_type:
        family, head_name = 'wsl', 'fc'
    else:
        family, head_name = 'other', 'last_linear'
    # The original classifier's input width equals the pooled CNN feature width.
    num_cnn_features = getattr(model, head_name).in_features

    fc_before = list(mdlParams['fc_layers_before'])
    if not fc_before:
        raise ValueError("mdlParams['fc_layers_before'] must contain at least one width")
    # One Linear/BN/ReLU/Dropout group per requested width. Sub-module indices
    # (meta_before.0, meta_before.4, ...) match the former hand-written
    # one- and two-layer versions, so saved state dicts keep loading.
    meta_layers = []
    in_features = mdlParams['meta_array'].shape[1]
    for width in fc_before:
        meta_layers += [nn.Linear(in_features, width),
                        nn.BatchNorm1d(width),
                        nn.ReLU(),
                        nn.Dropout(p=mdlParams['dropout_meta'])]
        in_features = width
    model.meta_before = nn.Sequential(*meta_layers)

    combined_features = fc_before[-1] + num_cnn_features
    fc_after = mdlParams['fc_layers_after']
    if len(fc_after) > 0:
        model.meta_after = nn.Sequential(nn.Linear(combined_features, fc_after[0]),
                                         nn.BatchNorm1d(fc_after[0]),
                                         nn.ReLU())
        classifier_in_features = fc_after[0]
    else:
        model.meta_after = None
        classifier_in_features = combined_features
    # Replace the backbone's own head with the fused-feature classifier.
    setattr(model, head_name, nn.Linear(classifier_in_features, mdlParams['numClasses']))

    def new_forward(self, inputs):
        x, meta_data = inputs
        if family == 'efficient':
            cnn_features = self.extract_features(x)
            cnn_features = F.adaptive_avg_pool2d(cnn_features, 1).flatten(1)
            dropout = self._dropout
            if isinstance(dropout, nn.Module):
                # Recent efficientnet_pytorch releases store an nn.Dropout here;
                # passing it as ``p=`` to F.dropout raises TypeError.
                cnn_features = dropout(cnn_features)
            elif dropout:
                # Early releases stored the bare dropout rate as a float.
                cnn_features = F.dropout(cnn_features, p=dropout, training=self.training)
        elif family == 'wsl':
            cnn_features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                cnn_features = stage(cnn_features)
            cnn_features = torch.flatten(self.avgpool(cnn_features), 1)
        else:
            cnn_features = self.layer0(x)
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                cnn_features = stage(cnn_features)
            cnn_features = self.avg_pool(cnn_features)
            if self.dropout is not None:
                cnn_features = self.dropout(cnn_features)
            cnn_features = cnn_features.view(cnn_features.size(0), -1)

        meta_features = self.meta_before(meta_data)
        features = torch.cat((cnn_features, meta_features), dim=1)
        if self.meta_after is not None:
            features = self.meta_after(features)
        return getattr(self, head_name)(features)

    model.forward = types.MethodType(new_forward, model)
    return model

# Registry of backbone factories: config['model_type'] -> callable(config).
# NOTE(review): the DenseNet-161/169 factory *function names* are swapped
# relative to what they build (Dense161() builds DenseNet-169, Dense169()
# builds DenseNet-161). The keys below are cross-wired to compensate, so each
# key yields the architecture it names.
_model_registry_entries = [
    # torchvision / pretrainedmodels classics
    ('Dense121', Dense121),
    ('Dense169', Dense161),
    ('Dense161', Dense169),
    ('Dense201', Dense201),
    ('Resnet50', Resnet50),
    ('Resnet101', Resnet101),
    ('InceptionV3', InceptionV3),
    # pretrainedmodels zoo
    ('se_resnext50', se_resnext50),
    ('se_resnext101', se_resnext101),
    ('se_resnet50', se_resnet50),
    ('se_resnet101', se_resnet101),
    ('se_resnet152', se_resnet152),
    ('resnext101', resnext101),
    ('resnext101_64', resnext101_64),
    ('senet154', senet154),
    ('polynet', polynet),
    ('dpn92', dpn92),
    ('dpn68b', dpn68b),
    ('nasnetamobile', nasnetamobile),
    # torch.hub WSL ResNeXt
    ('resnext101_32_8_wsl', resnext101_32_8_wsl),
    ('resnext101_32_16_wsl', resnext101_32_16_wsl),
    ('resnext101_32_32_wsl', resnext101_32_32_wsl),
    ('resnext101_32_48_wsl', resnext101_32_48_wsl),
    # efficientnet_pytorch
    ('efficientnet-b0', efficientnet_b0),
    ('efficientnet-b1', efficientnet_b1),
    ('efficientnet-b2', efficientnet_b2),
    ('efficientnet-b3', efficientnet_b3),
    ('efficientnet-b4', efficientnet_b4),
    ('efficientnet-b5', efficientnet_b5),
    ('efficientnet-b6', efficientnet_b6),
    ('efficientnet-b7', efficientnet_b7),
    ('isic2019-efficientnet-b0', isic2019_efficientnet_b0),
]
model_map = OrderedDict(_model_registry_entries)

def getModel(config):
    """Return a zero-argument builder for the backbone named by ``config``.

    Args:
        config: dict; ``config['model_type']`` must be a key of ``model_map``.
            The whole dict is forwarded to the selected factory.

    Returns:
        A callable ``model()`` that builds and returns the network. It carries
        the factory's ``__name__``/``__doc__`` via ``functools.wraps``.

    Raises:
        ValueError: if ``config['model_type']`` is not a registered model.
    """
    model_type = config['model_type']
    if model_type not in model_map:
        # Report the key that was actually looked up. Reading a separate
        # 'model_name' key here raised KeyError whenever it was absent.
        raise ValueError('Name of model unknown %s' % model_type)
    func = model_map[model_type]

    @functools.wraps(func)
    def model():
        return func(config)

    return model

In [None]:
!pip install --upgrade torchvision

In [None]:
# import for autoaugment.py
import random
import numpy as np
import scipy
from scipy import ndimage
from PIL import Image, ImageEnhance, ImageOps

#See: https://github.com/4uiiurz1/pytorch-auto-augment
class AutoAugment:
    """Apply one uniformly chosen AutoAugment sub-policy to an image.

    Each entry of ``self.policies`` is ``[op1, prob1, mag1, op2, prob2, mag2]``
    and is executed by ``apply_policy``.
    """

    def __init__(self):
        # 25 two-step sub-policies: (name, probability, magnitude bin) x 2.
        self.policies = [
            ['Invert', 0.1, 7, 'Contrast', 0.2, 6],
            ['Rotate', 0.7, 2, 'TranslateX', 0.3, 9],
            ['Sharpness', 0.8, 1, 'Sharpness', 0.9, 3],
            ['ShearY', 0.5, 8, 'TranslateY', 0.7, 9],
            ['AutoContrast', 0.5, 8, 'Equalize', 0.9, 2],
            ['ShearY', 0.2, 7, 'Posterize', 0.3, 7],
            ['Color', 0.4, 3, 'Brightness', 0.6, 7],
            ['Sharpness', 0.3, 9, 'Brightness', 0.7, 9],
            ['Equalize', 0.6, 5, 'Equalize', 0.5, 1],
            ['Contrast', 0.6, 7, 'Sharpness', 0.6, 5],
            ['Color', 0.7, 7, 'TranslateX', 0.5, 8],
            ['Equalize', 0.3, 7, 'AutoContrast', 0.4, 8],
            ['TranslateY', 0.4, 3, 'Sharpness', 0.2, 6],
            ['Brightness', 0.9, 6, 'Color', 0.2, 8],
            # Second step has probability 0, so this Invert never fires.
            ['Solarize', 0.5, 2, 'Invert', 0, 0.3],
            ['Equalize', 0.2, 0, 'AutoContrast', 0.6, 0],
            ['Equalize', 0.2, 8, 'Equalize', 0.6, 4],
            ['Color', 0.9, 9, 'Equalize', 0.6, 6],
            ['AutoContrast', 0.8, 4, 'Solarize', 0.2, 8],
            ['Brightness', 0.1, 3, 'Color', 0.7, 0],
            ['Solarize', 0.4, 5, 'AutoContrast', 0.9, 3],
            ['TranslateY', 0.9, 9, 'TranslateY', 0.7, 9],
            ['AutoContrast', 0.9, 2, 'Solarize', 0.8, 3],
            ['Equalize', 0.8, 8, 'Invert', 0.1, 3],
            ['TranslateY', 0.7, 9, 'AutoContrast', 0.9, 1],
        ]

    def __call__(self, img):
        chosen = random.choice(self.policies)
        return apply_policy(img, chosen)


# Dispatch table: policy op name -> callable(img, magnitude).
# The lambdas defer the global name lookup until call time, because the
# target functions are defined further down in this cell.
operations = dict(
    ShearX=lambda img, magnitude: shear_x(img, magnitude),
    ShearY=lambda img, magnitude: shear_y(img, magnitude),
    TranslateX=lambda img, magnitude: translate_x(img, magnitude),
    TranslateY=lambda img, magnitude: translate_y(img, magnitude),
    Rotate=lambda img, magnitude: rotate(img, magnitude),
    AutoContrast=lambda img, magnitude: auto_contrast(img, magnitude),
    Invert=lambda img, magnitude: invert(img, magnitude),
    Equalize=lambda img, magnitude: equalize(img, magnitude),
    Solarize=lambda img, magnitude: solarize(img, magnitude),
    Posterize=lambda img, magnitude: posterize(img, magnitude),
    Contrast=lambda img, magnitude: contrast(img, magnitude),
    Color=lambda img, magnitude: color(img, magnitude),
    Brightness=lambda img, magnitude: brightness(img, magnitude),
    Sharpness=lambda img, magnitude: sharpness(img, magnitude),
    Cutout=lambda img, magnitude: cutout(img, magnitude),
)


def apply_policy(img, policy):
    """Run the two probability-gated steps of one sub-policy on ``img``.

    ``policy`` is ``[op1, prob1, mag1, op2, prob2, mag2]``. For each step a
    uniform draw below its probability applies ``operations[op](img, mag)``.
    Steps run in order, so step two sees step one's output.
    """
    for start in (0, 3):
        if random.random() < policy[start + 1]:
            op_name, op_magnitude = policy[start], policy[start + 2]
            img = operations[op_name](img, op_magnitude)
    return img


def transform_matrix_offset_center(matrix, x, y):
    """Make a 3x3 homogeneous transform act about the point (x/2 + 0.5, y/2 + 0.5).

    Returns ``T(+c) @ matrix @ T(-c)``, where ``T`` is a translation and
    ``c`` that point. The callers in this cell pass ``img.shape[0]`` and
    ``img.shape[1]`` as ``x`` and ``y``.
    """
    cx = float(x) / 2 + 0.5
    cy = float(y) / 2 + 0.5
    shift_to_center = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    shift_from_center = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
    return shift_to_center @ matrix @ shift_from_center


def shear_x(img, magnitude):
    """Shear the image by a random factor from magnitude bin ``magnitude``.

    The factor is drawn uniformly from bin ``[m[magnitude], m[magnitude + 1]]``
    of ``m = linspace(-0.3, 0.3, 11)``. It is placed at entry (0, 1) of the
    affine matrix, which is applied about the image centre channel by channel.

    Args:
        img: PIL image or array-like; must convert to an (H, W, C) array.
        magnitude: int bin index in 0..9.

    Returns:
        The transformed image as a PIL image.
    """
    img = np.array(img)
    magnitudes = np.linspace(-0.3, 0.3, 11)
    factor = random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])

    transform_matrix = np.array([[1, factor, 0],
                                 [0, 1, 0],
                                 [0, 0, 1]])
    transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1])
    affine_matrix = transform_matrix[:2, :2]
    offset = transform_matrix[:2, 2]
    # Public scipy.ndimage API; the scipy.ndimage.interpolation namespace is deprecated.
    img = np.stack([ndimage.affine_transform(img[:, :, c], affine_matrix, offset)
                    for c in range(img.shape[2])], axis=2)
    return Image.fromarray(img)


def shear_y(img, magnitude):
    """Shear the image by a random factor from magnitude bin ``magnitude``.

    The factor is drawn uniformly from bin ``[m[magnitude], m[magnitude + 1]]``
    of ``m = linspace(-0.3, 0.3, 11)``. It is placed at entry (1, 0) of the
    affine matrix, which is applied about the image centre channel by channel.

    Args:
        img: PIL image or array-like; must convert to an (H, W, C) array.
        magnitude: int bin index in 0..9.

    Returns:
        The transformed image as a PIL image.
    """
    img = np.array(img)
    magnitudes = np.linspace(-0.3, 0.3, 11)
    factor = random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])

    transform_matrix = np.array([[1, 0, 0],
                                 [factor, 1, 0],
                                 [0, 0, 1]])
    transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1])
    affine_matrix = transform_matrix[:2, :2]
    offset = transform_matrix[:2, 2]
    # Public scipy.ndimage API; the scipy.ndimage.interpolation namespace is deprecated.
    img = np.stack([ndimage.affine_transform(img[:, :, c], affine_matrix, offset)
                    for c in range(img.shape[2])], axis=2)
    return Image.fromarray(img)


def translate_x(img, magnitude):
    """Translate the image along axis 1 (width) by a random fraction of its width.

    The fraction is drawn uniformly from bin ``[m[magnitude], m[magnitude + 1]]``
    of ``m = linspace(-150/331, 150/331, 11)`` and multiplied by
    ``img.shape[1]``. The shift is applied channel by channel.

    Args:
        img: PIL image or array-like; must convert to an (H, W, C) array.
        magnitude: int bin index in 0..9.

    Returns:
        The transformed image as a PIL image.
    """
    img = np.array(img)
    magnitudes = np.linspace(-150/331, 150/331, 11)
    shift = img.shape[1] * random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])

    transform_matrix = np.array([[1, 0, 0],
                                 [0, 1, shift],
                                 [0, 0, 1]])
    transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1])
    affine_matrix = transform_matrix[:2, :2]
    offset = transform_matrix[:2, 2]
    # Public scipy.ndimage API; the scipy.ndimage.interpolation namespace is deprecated.
    img = np.stack([ndimage.affine_transform(img[:, :, c], affine_matrix, offset)
                    for c in range(img.shape[2])], axis=2)
    return Image.fromarray(img)


def translate_y(img, magnitude):
    """Translate the image along axis 0 (height) by a random fraction of its height.

    The fraction is drawn uniformly from bin ``[m[magnitude], m[magnitude + 1]]``
    of ``m = linspace(-150/331, 150/331, 11)`` and multiplied by
    ``img.shape[0]``. The shift is applied channel by channel.

    Args:
        img: PIL image or array-like; must convert to an (H, W, C) array.
        magnitude: int bin index in 0..9.

    Returns:
        The transformed image as a PIL image.
    """
    img = np.array(img)
    magnitudes = np.linspace(-150/331, 150/331, 11)
    shift = img.shape[0] * random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])

    transform_matrix = np.array([[1, 0, shift],
                                 [0, 1, 0],
                                 [0, 0, 1]])
    transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1])
    affine_matrix = transform_matrix[:2, :2]
    offset = transform_matrix[:2, 2]
    # Public scipy.ndimage API; the scipy.ndimage.interpolation namespace is deprecated.
    img = np.stack([ndimage.affine_transform(img[:, :, c], affine_matrix, offset)
                    for c in range(img.shape[2])], axis=2)
    return Image.fromarray(img)


def rotate(img, magnitude):
    """Rotate the image about its centre by a random angle from magnitude bin ``magnitude``.

    The angle (degrees) is drawn uniformly from bin
    ``[m[magnitude], m[magnitude + 1]]`` of ``m = linspace(-30, 30, 11)``.
    The rotation is applied channel by channel.

    Args:
        img: PIL image or array-like; must convert to an (H, W, C) array.
        magnitude: int bin index in 0..9.

    Returns:
        The transformed image as a PIL image.
    """
    img = np.array(img)
    magnitudes = np.linspace(-30, 30, 11)
    theta = np.deg2rad(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1]))
    transform_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                 [np.sin(theta), np.cos(theta), 0],
                                 [0, 0, 1]])
    transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1])
    affine_matrix = transform_matrix[:2, :2]
    offset = transform_matrix[:2, 2]
    # Public scipy.ndimage API; the scipy.ndimage.interpolation namespace is deprecated.
    img = np.stack([ndimage.affine_transform(img[:, :, c], affine_matrix, offset)
                    for c in range(img.shape[2])], axis=2)
    return Image.fromarray(img)


def auto_contrast(img, magnitude):
    """Stretch contrast with ``ImageOps.autocontrast``; ``magnitude`` is ignored."""
    return ImageOps.autocontrast(img)


def invert(img, magnitude):
    """Invert pixel values with ``ImageOps.invert``; ``magnitude`` is ignored."""
    return ImageOps.invert(img)


def equalize(img, magnitude):
    """Histogram-equalize with ``ImageOps.equalize``; ``magnitude`` is ignored."""
    return ImageOps.equalize(img)


def solarize(img, magnitude):
    """Solarize with a threshold drawn from bin ``magnitude`` of ``linspace(0, 256, 11)``."""
    bins = np.linspace(0, 256, 11)
    threshold = random.uniform(bins[magnitude], bins[magnitude + 1])
    return ImageOps.solarize(img, threshold)


def posterize(img, magnitude):
    """Posterize to a bit depth drawn from bin ``magnitude`` of ``linspace(4, 8, 11)``.

    The drawn value is rounded to the nearest integer number of bits.
    """
    bins = np.linspace(4, 8, 11)
    bits = int(round(random.uniform(bins[magnitude], bins[magnitude + 1])))
    return ImageOps.posterize(img, bits)


def contrast(img, magnitude):
    """Scale contrast by a factor drawn from bin ``magnitude`` of ``linspace(0.1, 1.9, 11)``."""
    enhancer = ImageEnhance.Contrast(img)
    bins = np.linspace(0.1, 1.9, 11)
    factor = random.uniform(bins[magnitude], bins[magnitude + 1])
    return enhancer.enhance(factor)


def color(img, magnitude):
    """Scale saturation by a factor drawn from bin ``magnitude`` of ``linspace(0.1, 1.9, 11)``."""
    enhancer = ImageEnhance.Color(img)
    bins = np.linspace(0.1, 1.9, 11)
    factor = random.uniform(bins[magnitude], bins[magnitude + 1])
    return enhancer.enhance(factor)


def brightness(img, magnitude):
    """Scale brightness by a factor drawn from bin ``magnitude`` of ``linspace(0.1, 1.9, 11)``."""
    enhancer = ImageEnhance.Brightness(img)
    bins = np.linspace(0.1, 1.9, 11)
    factor = random.uniform(bins[magnitude], bins[magnitude + 1])
    return enhancer.enhance(factor)


def sharpness(img, magnitude):
    """Scale sharpness by a factor drawn from bin ``magnitude`` of ``linspace(0.1, 1.9, 11)``."""
    enhancer = ImageEnhance.Sharpness(img)
    bins = np.linspace(0.1, 1.9, 11)
    factor = random.uniform(bins[magnitude], bins[magnitude + 1])
    return enhancer.enhance(factor)


def cutout(org_img, magnitude=None):
    """Fill one random square patch of a copy of the image with its global mean.

    Args:
        org_img: PIL image or array-like convertible to an (H, W, C) array.
            It is copied, not modified.
        magnitude: optional int bin index in 0..9. The patch side is
            ``round(H * u)``, with ``u`` drawn uniformly from bin ``magnitude``
            of ``linspace(0, 60/331, 11)``. ``None`` means a fixed 16-pixel side.

    Returns:
        The masked image as a PIL image.
    """
    img = np.copy(org_img)
    magnitudes = np.linspace(0, 60/331, 11)
    mask_val = img.mean()

    if magnitude is None:
        mask_size = 16
    else:
        mask_size = int(round(img.shape[0] * random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])))
    # The start may lie up to half a patch before the top/left edge. The end is
    # computed before clipping, so such a patch is truncated rather than moved.
    top = np.random.randint(0 - mask_size // 2, img.shape[0] - mask_size)
    left = np.random.randint(0 - mask_size // 2, img.shape[1] - mask_size)
    bottom = top + mask_size
    right = left + mask_size
    top = max(top, 0)
    left = max(left, 0)

    img[top:bottom, left:right, :].fill(mask_val)
    return Image.fromarray(img)



class Cutout(object):
    """Transform that masks one random ``length`` x ``length`` square with the image mean.

    Args:
        length: side of the square patch in pixels (default 16).
    """

    def __init__(self, length=16):
        self.length = length

    def __call__(self, img):
        """Return a PIL copy of ``img`` (convertible to an (H, W, C) array) with one patch filled."""
        img = np.array(img)
        mask_val = img.mean()

        # The start may lie up to half a patch before the top/left edge. The end
        # is computed before clipping, so such a patch is truncated.
        top = np.random.randint(0 - self.length // 2, img.shape[0] - self.length)
        left = np.random.randint(0 - self.length // 2, img.shape[1] - self.length)
        bottom = top + self.length
        right = left + self.length

        # Clip each start coordinate independently; ``left`` must never take
        # the value of ``top``.
        top = max(top, 0)
        left = max(left, 0)

        img[top:bottom, left:right, :] = mask_val
        return Image.fromarray(img)

In [None]:
# import for utils.py

import os
import torch
#import pandas as pd
from skimage import io, transform
import scipy
import numpy as np
#import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, utils
import sklearn.metrics
from sklearn.metrics import confusion_matrix, auc, roc_curve, f1_score, classification_report
from sklearn.model_selection import StratifiedShuffleSplit
import math
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
import types

# Define ISIC Dataset Class
class ISICDataset(Dataset):
    """ISIC dataset."""

    def __init__(self, mdlParams, indSet):
        """
        Args:
            mdlParams (dict): Configuration for loading
            indSet (string): Indicates train, val, test
        """
        # Mdlparams
        self.mdlParams = mdlParams
        # Number of classes
        self.numClasses = mdlParams['numClasses']
        # Model input size
        self.input_size = (np.int32(mdlParams['input_size'][0]),np.int32(mdlParams['input_size'][1]))      
        # Whether or not to use ordered cropping 
        self.orderedCrop = mdlParams['orderedCrop']   
        # Number of crops for multi crop eval
        self.multiCropEval = mdlParams['multiCropEval']   
        # Whether during training same-sized crops should be used
        self.same_sized_crop = mdlParams['same_sized_crops']    
        # Only downsample
        self.only_downsmaple = mdlParams.get('only_downsmaple',False)   
        # Potential class balancing option 
        self.balancing = mdlParams['balance_classes']
        # Whether data should be preloaded
        self.preload = mdlParams['preload']
        # Potentially subtract a mean
        self.subtract_set_mean = mdlParams['subtract_set_mean']
        # Potential switch for evaluation on the training set
        self.train_eval_state = mdlParams['trainSetState']   
        # Potential setMean to deduce from channels
        self.setMean = mdlParams['setMean'].astype(np.float32)
        # Current indSet = 'trainInd'/'valInd'/'testInd'
        self.indices = mdlParams[indSet]  
        self.indSet = indSet
        # feature scaling for meta
        if mdlParams.get('meta_features',None) is not None and mdlParams['scale_features']:
            self.feature_scaler = mdlParams['feature_scaler_meta']
        if self.balancing == 3 and indSet == 'trainInd':
            # Sample classes equally for each batch
            # First, split set by classes
            not_one_hot = np.argmax(mdlParams['labels_array'],1)
            self.class_indices = []
            for i in range(mdlParams['numClasses']):
                self.class_indices.append(np.where(not_one_hot==i)[0])
                # Kick out non-trainind indices
                self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['valInd'])
                # And test indices
                if 'testInd' in mdlParams:
                    self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['testInd'])
            # Now sample indices equally for each batch by repeating all of them to have the same amount as the max number
            indices = []
            max_num = np.max([len(x) for x in self.class_indices])
            # Go thourgh all classes
            for i in range(mdlParams['numClasses']):
                count = 0
                class_count = 0
                max_num_curr_class = len(self.class_indices[i])
                # Add examples until we reach the maximum
                while(count < max_num):
                    # Start at the beginning, if we are through all available examples
                    if class_count == max_num_curr_class:
                        class_count = 0
                    indices.append(self.class_indices[i][class_count])
                    count += 1
                    class_count += 1
            print("Largest class",max_num,"Indices len",len(indices))
            print("Intersect val",np.intersect1d(indices,mdlParams['valInd']),"Intersect Testind",np.intersect1d(indices,mdlParams['testInd']))
            # Set labels/inputs
            self.labels = mdlParams['labels_array'][indices,:]
            self.im_paths = np.array(mdlParams['im_paths'])[indices].tolist()     
            # Normal train proc
            if self.same_sized_crop:
                cropping = transforms.RandomCrop(self.input_size)
            elif self.only_downsmaple:
                cropping = transforms.Resize(self.input_size)
            else:
                cropping = transforms.RandomResizedCrop(self.input_size[0])
            # All transforms
            self.composed = transforms.Compose([
                    cropping,
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomVerticalFlip(),
                    transforms.ColorJitter(brightness=32. / 255.,saturation=0.5),
                    transforms.ToTensor(),
                    transforms.Normalize(torch.from_numpy(self.setMean).float(),torch.from_numpy(np.array([1.,1.,1.])).float())
                    ])                                
        elif self.orderedCrop and (indSet == 'valInd' or self.train_eval_state  == 'eval' or indSet == 'testInd'):
            # Also flip on top            
            if mdlParams.get('eval_flipping',0) > 1:
                # Complete labels array, only for current indSet, repeat for multiordercrop
                inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval']*mdlParams['eval_flipping'])
                self.labels = mdlParams['labels_array'][inds_rep,:]
                # meta
                if mdlParams.get('meta_features',None) is not None:
                    self.meta_data = mdlParams['meta_array'][inds_rep,:]    
                # Path to images for loading, only for current indSet, repeat for multiordercrop
                self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
                print("len im path",len(self.im_paths))                
                if self.mdlParams.get('var_im_size',False):
                    self.cropPositions = np.tile(mdlParams['cropPositions'][mdlParams[indSet],:,:],(1,mdlParams['eval_flipping'],1))
                    self.cropPositions = np.reshape(self.cropPositions,[mdlParams['multiCropEval']*mdlParams['eval_flipping']*mdlParams[indSet].shape[0],2])
                    #self.cropPositions = np.repeat(self.cropPositions, (mdlParams['eval_flipping'],1))
                    #print("CP examples",self.cropPositions[:50,:])
                else:
                    self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams['eval_flipping']*mdlParams[indSet].shape[0],1))
                # Flip states
                if mdlParams['eval_flipping'] == 2:
                    self.flipPositions = np.array([0,1])
                elif mdlParams['eval_flipping'] == 3:
                    self.flipPositions = np.array([0,1,2])
                elif mdlParams['eval_flipping'] == 4:
                    self.flipPositions = np.array([0,1,2,3])                    
                self.flipPositions = np.repeat(self.flipPositions, mdlParams['multiCropEval'])
                self.flipPositions = np.tile(self.flipPositions, mdlParams[indSet].shape[0])
                print("Crop positions shape",self.cropPositions.shape,"flip pos shape",self.flipPositions.shape)
                print("Flip example",self.flipPositions[:30])
            else:
                # Complete labels array, only for current indSet, repeat for multiordercrop
                inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
                self.labels = mdlParams['labels_array'][inds_rep,:]
                # meta
                if mdlParams.get('meta_features',None) is not None:
                    self.meta_data = mdlParams['meta_array'][inds_rep,:]                 
                # Path to images for loading, only for current indSet, repeat for multiordercrop
                self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
                print("len im path",len(self.im_paths))
                # Set up crop positions for every sample                
                if self.mdlParams.get('var_im_size',False):
                    self.cropPositions = np.reshape(mdlParams['cropPositions'][mdlParams[indSet],:,:],[mdlParams['multiCropEval']*mdlParams[indSet].shape[0],2])
                    #print("CP examples",self.cropPositions[:50,:])
                else:
                    self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams[indSet].shape[0],1))
                print("CP",self.cropPositions.shape)
            #print("CP Example",self.cropPositions[0:len(mdlParams['cropPositions']),:])          
            # Set up transforms
            self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
            self.trans = transforms.ToTensor()
        elif indSet == 'valInd' or indSet == 'testInd':
            if self.multiCropEval == 0:
                if self.only_downsmaple:
                    self.cropping = transforms.Resize(self.input_size)
                else:
                    self.cropping = transforms.Compose([transforms.CenterCrop(np.int32(self.input_size[0]*1.5)),transforms.Resize(self.input_size)])
                # Complete labels array, only for current indSet
                self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
                # meta
                if mdlParams.get('meta_features',None) is not None:
                    self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]                 
                # Path to images for loading, only for current indSet
                self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()                   
            else:
                # Deterministic processing
                if self.mdlParams.get('deterministic_eval',False):
                    total_len_per_im = mdlParams['numCropPositions']*len(mdlParams['cropScales'])*mdlParams['cropFlipping']                    
                    # Actual transforms are functionally applied at forward pass
                    self.cropPositions = np.zeros([total_len_per_im,3])
                    ind = 0
                    for i in range(mdlParams['numCropPositions']):
                        for j in range(len(mdlParams['cropScales'])):
                            for k in range(mdlParams['cropFlipping']):
                                self.cropPositions[ind,0] = i
                                self.cropPositions[ind,1] = mdlParams['cropScales'][j]
                                self.cropPositions[ind,2] = k
                                ind += 1
                    # Complete labels array, only for current indSet, repeat for multiordercrop
                    print("crops per image",total_len_per_im)
                    self.cropPositions = np.tile(self.cropPositions, (mdlParams[indSet].shape[0],1))
                    inds_rep = np.repeat(mdlParams[indSet], total_len_per_im)
                    self.labels = mdlParams['labels_array'][inds_rep,:]
                    # meta
                    if mdlParams.get('meta_features',None) is not None:
                        self.meta_data = mdlParams['meta_array'][inds_rep,:]                     
                    # Path to images for loading, only for current indSet, repeat for multiordercrop
                    self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
                else:
                    self.cropping = transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0))
                    # Complete labels array, only for current indSet, repeat for multiordercrop
                    inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
                    self.labels = mdlParams['labels_array'][inds_rep,:]
                    # meta
                    if mdlParams.get('meta_features',None) is not None:
                        self.meta_data = mdlParams['meta_array'][inds_rep,:]                    
                    # Path to images for loading, only for current indSet, repeat for multiordercrop
                    self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
            print(len(self.im_paths))  
            # Set up transforms
            self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
            self.trans = transforms.ToTensor()                   
        else:
            all_transforms = []
            # Normal train proc
            if self.same_sized_crop:
                all_transforms.append(transforms.RandomCrop(self.input_size))
            elif self.only_downsmaple:
                all_transforms.append(transforms.Resize(self.input_size))
            else:
                all_transforms.append(transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0)))
            if mdlParams.get('flip_lr_ud',False):
                all_transforms.append(transforms.RandomHorizontalFlip())
                all_transforms.append(transforms.RandomVerticalFlip())
            # Full rot
            if mdlParams.get('full_rot',0) > 0:
                if mdlParams.get('scale',False):
                    all_transforms.append(transforms.RandomChoice([transforms.RandomAffine(mdlParams['full_rot'], scale=mdlParams['scale'], shear=mdlParams.get('shear',0), resample=Image.NEAREST),
                                                                transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BICUBIC),
                                                                transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BILINEAR)])) 
                else:
                    all_transforms.append(transforms.RandomChoice([transforms.RandomRotation(mdlParams['full_rot'], resample=Image.NEAREST),
                                                                transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BICUBIC),
                                                                transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BILINEAR)]))    
            # Color distortion
            if mdlParams.get('full_color_distort') is not None:
                all_transforms.append(transforms.ColorJitter(brightness=mdlParams.get('brightness_aug',32. / 255.),saturation=mdlParams.get('saturation_aug',0.5), contrast = mdlParams.get('contrast_aug',0.5), hue = mdlParams.get('hue_aug',0.2)))
            else:
                all_transforms.append(transforms.ColorJitter(brightness=32. / 255.,saturation=0.5))   
            # Autoaugment
            if self.mdlParams.get('autoaugment',False):
                all_transforms.append(AutoAugment())             
            # Cutout
            if self.mdlParams.get('cutout',0) > 0:
                all_transforms.append(Cutout_v0(n_holes=1,length=self.mdlParams['cutout']))                             
            # Normalize
            all_transforms.append(transforms.ToTensor())
            all_transforms.append(transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd'])))            
            # All transforms
            self.composed = transforms.Compose(all_transforms)                  
            # Complete labels array, only for current indSet
            self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
            # meta
            if mdlParams.get('meta_features',None) is not None:
                self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]            
            # Path to images for loading, only for current indSet
            self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()
        # Potentially preload
        if self.preload:
            self.im_list = []
            for i in range(len(self.im_paths)):
                self.im_list.append(Image.open(self.im_paths[i]))
    def __len__(self):
        """Return the number of samples, i.e. the number of label rows.

        For multi-crop evaluation the label rows are repeated per crop, so
        this counts crops rather than distinct images in that case.
        """
        n_rows = np.shape(self.labels)[0]
        return n_rows

    def __getitem__(self, idx):
        """Load, preprocess and return one sample.

        The preprocessing path depends on the split and configuration:
          * ordered cropping (``self.orderedCrop``) for val/test sets or when
            ``self.train_eval_state == 'eval'``: a fixed crop centred at
            ``self.cropPositions[idx]`` is cut out;
          * other val/test samples: deterministic crop/scale/flip
            (``deterministic_eval``) or ``self.cropping``;
          * training: the composed augmentation pipeline ``self.composed``
            plus optional random dropping of meta-feature groups.

        Returns:
            ``(x, y, idx, flip_position)`` when ``eval_flipping > 1``;
            otherwise ``((x, x_meta), y, idx)`` when meta features are used,
            else ``(x, y, idx)``. ``y`` is the int64 argmax of the label row.
        """
        has_meta = self.mdlParams.get('meta_features',None) is not None
        # Load image
        if self.preload:
            x = self.im_list[idx]
        else:
            x = Image.open(self.im_paths[idx])
            # Optionally shrink images that are exactly large_size x large_size
            if self.mdlParams.get('resize_large_ones',0) > 0 and (x.size[0] == self.mdlParams['large_size'] and x.size[1] == self.mdlParams['large_size']):
                width = self.mdlParams['resize_large_ones']
                height = self.mdlParams['resize_large_ones']
                x = x.resize((width,height),Image.BILINEAR)
            # Upscale images smaller than the input size while keeping the aspect
            # ratio (PIL size is (width, height)). The ratio must be applied before
            # truncating to int, otherwise e.g. 150x300 would become 224x300.
            if self.mdlParams['input_size'][0] >= 224 and self.mdlParams['orderedCrop']:
                if x.size[0] < self.mdlParams['input_size'][0]:
                    new_height = int(round(self.mdlParams['input_size'][0]/float(x.size[0])*x.size[1]))
                    new_width = self.mdlParams['input_size'][0]
                    x = x.resize((new_width,new_height),Image.BILINEAR)
                if x.size[1] < self.mdlParams['input_size'][0]:
                    new_width = int(round(self.mdlParams['input_size'][0]/float(x.size[1])*x.size[0]))
                    new_height = self.mdlParams['input_size'][0]
                    x = x.resize((new_width,new_height),Image.BILINEAR)
        # Get label
        y = self.labels[idx,:]
        # meta
        if has_meta:
            x_meta = self.meta_data[idx,:].copy()
        # Transform data based on whether train or not train. If train, also check if its train train or train inference
        if self.orderedCrop and (self.indSet == 'valInd' or self.indSet == 'testInd' or self.train_eval_state == 'eval'):
            # Apply ordered cropping to validation or test set
            # Get current crop position (crop centre)
            x_loc = self.cropPositions[idx,0]
            y_loc = self.cropPositions[idx,1]
            # scale
            if has_meta and self.mdlParams['scale_features']:
                x_meta = np.squeeze(self.feature_scaler.transform(np.expand_dims(x_meta,0)))
            if self.mdlParams.get('trans_norm_first',False):
                # First, to pytorch tensor (0.0-1.0)
                x = self.trans(x)
                # Normalize
                x = self.norm(x)
                # Optional test-time flipping selected by self.flipPositions[idx]
                if self.mdlParams.get('eval_flipping',0) > 1:
                    if self.flipPositions[idx] == 1:
                        x = torch.flip(x,(1,))
                    elif self.flipPositions[idx] == 2:
                        x = torch.flip(x,(2,))
                    elif self.flipPositions[idx] == 3:
                        x = torch.flip(x,(1,2))
                # Then, apply current crop on the (C, H, W) tensor
                x = x[:,np.int32(x_loc-(self.input_size[0]/2.)):np.int32(x_loc-(self.input_size[0]/2.))+self.input_size[0],
                        np.int32(y_loc-(self.input_size[1]/2.)):np.int32(y_loc-(self.input_size[1]/2.))+self.input_size[1]]
            else:
                # Then, apply current crop; positions are cast to int because
                # numpy slicing rejects float indices
                x_start = int(x_loc) - np.int32(self.input_size[0]/2.)
                y_start = int(y_loc) - np.int32(self.input_size[1]/2.)
                x = Image.fromarray(np.array(x)[x_start:x_start+self.input_size[0],
                        y_start:y_start+self.input_size[1],:])
                # First, to pytorch tensor (0.0-1.0)
                x = self.trans(x)
                # Normalize
                x = self.norm(x)
        elif self.indSet == 'valInd' or self.indSet == 'testInd':
            if self.mdlParams.get('deterministic_eval',False):
                crop = self.cropPositions[idx,0]
                scale = self.cropPositions[idx,1]
                flipping = self.cropPositions[idx,2]
                if flipping == 1:
                    # Horizontal flip
                    x = transforms.functional.hflip(x)
                elif flipping == 2:
                    # Vertical flip
                    x = transforms.functional.vflip(x)
                elif flipping == 3:
                    # Both flips
                    x = transforms.functional.hflip(x)
                    x = transforms.functional.vflip(x)
                # Scale. PIL size is (width, height) while functional.resize
                # expects (height, width); keep the aspect ratio.
                width, height = x.size
                new_height, new_width = int(scale*height), int(scale*width)
                if new_height > self.input_size[0] and new_width > self.input_size[1]:
                    x = transforms.functional.resize(x,(new_height,new_width))
                else:
                    x = transforms.functional.resize(x,(self.input_size[0],self.input_size[1]))
                # Crop. functional.crop takes (top, left, height, width), so the
                # vertical offset is derived from the height and the horizontal
                # one from the width.
                width, height = x.size
                offset = self.mdlParams['offset_crop']
                if crop == 0:
                    # Center
                    x = transforms.functional.center_crop(x,self.input_size[0])
                elif crop == 1:
                    # upper left
                    x = transforms.functional.crop(x, offset*height, offset*width, self.input_size[0],self.input_size[1])
                elif crop == 2:
                    # lower left
                    x = transforms.functional.crop(x, (1.0-offset)*height-self.input_size[0], offset*width, self.input_size[0],self.input_size[1])
                elif crop == 3:
                    # upper right
                    x = transforms.functional.crop(x, offset*height, (1.0-offset)*width-self.input_size[1], self.input_size[0],self.input_size[1])
                elif crop == 4:
                    # lower right
                    x = transforms.functional.crop(x, (1.0-offset)*height-self.input_size[0], (1.0-offset)*width-self.input_size[1], self.input_size[0],self.input_size[1])
            else:
                x = self.cropping(x)
            # To pytorch tensor (0.0-1.0)
            x = self.trans(x)
            x = self.norm(x)
            # scale
            if has_meta and self.mdlParams['scale_features']:
                x_meta = np.squeeze(self.feature_scaler.transform(np.expand_dims(x_meta,0)))
        else:
            # Apply the training augmentation pipeline
            x = self.composed(x)
            # meta augment
            if has_meta:
                if self.mdlParams['drop_augment'] > 0:
                    # Looked up lazily so a missing key only matters when it is used
                    sizes = self.mdlParams.get('meta_feature_sizes')
                    # randomly deactivate a feature group
                    # age
                    if torch.rand(1) < self.mdlParams['drop_augment']:
                        if 'age_oh' in self.mdlParams['meta_features']:
                            x_meta[0:sizes[0]] = np.zeros([sizes[0]])
                        else:
                            x_meta[0] = -5
                    # loc
                    if torch.rand(1) < self.mdlParams['drop_augment']:
                        if 'loc_oh' in self.mdlParams['meta_features']:
                            x_meta[sizes[0]:sizes[0]+sizes[1]] = np.zeros([sizes[1]])
                    # sex
                    if torch.rand(1) < self.mdlParams['drop_augment']:
                        if 'sex_oh' in self.mdlParams['meta_features']:
                            x_meta[sizes[0]+sizes[1]:sizes[0]+sizes[1]+sizes[2]] = np.zeros([sizes[2]])
                # scale
                if self.mdlParams['scale_features']:
                    x_meta = np.squeeze(self.feature_scaler.transform(np.expand_dims(x_meta,0)))
        # Transform y: label row -> class index
        y = np.int64(np.argmax(y))
        if has_meta:
            x_meta = np.float32(x_meta)
        if self.mdlParams.get('eval_flipping',0) > 1:
            return x, y, idx, self.flipPositions[idx]
        if has_meta:
            return (x, x_meta), y, idx
        return x, y, idx




class Cutout_v0(object):
    """Randomly mask out one or more square patches from an image.

    Operates on a PIL image (or anything ``np.array`` converts to an
    ``(H, W)`` or ``(H, W, C)`` array) and returns a PIL image, so it is meant
    to sit before ``transforms.ToTensor()`` in a pipeline.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Side length (in pixels) of each square patch. Patches
            span ``length // 2`` pixels on each side of a random centre and are
            clipped at the image border (so odd lengths give ``length - 1``).
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img: PIL image (or array-like) of shape (H, W) or (H, W, C).
        Returns:
            PIL.Image: Image with n_holes patches of size length x length set to zero.
        """
        img = np.array(img)
        h = img.shape[0]
        w = img.shape[1]
        half = self.length // 2

        mask = np.ones((h, w), np.uint8)
        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y + half, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x + half, 0, w)

            mask[y1:y2, x1:x2] = 0

        if img.ndim == 3:
            # Broadcast the 2-D mask over the channel axis; grayscale (H, W)
            # images use the mask directly.
            mask = mask[:, :, np.newaxis]
        img = img * mask
        return Image.fromarray(img)

# Sampler for balanced sampling
class StratifiedSampler(torch.utils.data.sampler.Sampler):
    """Round-robin class-balanced sampler over the training split.

    Consecutive yielded indices cycle through the classes (0, 1, ..., C-1, 0, ...),
    each time taking the next sample of that class from a shuffled per-class list
    that is reshuffled when exhausted. Classes without any training sample are
    skipped. The per-class cursors persist across epochs. Yielded values are
    positions within ``mdlParams['trainInd']`` (not raw dataset indices).
    """
    def __init__(self, mdlParams):
        """
        Arguments
        ---------
        mdlParams : dict
            Configuration providing 'trainInd' (indices into 'labels_array'),
            'numClasses' and 'labels_array' (per-sample label rows; the class is
            taken as the argmax of each row).
        """
        self.dataset_len = len(mdlParams['trainInd'])
        self.numClasses = mdlParams['numClasses']
        self.trainInd = mdlParams['trainInd']
        # Split the training set by class (positions within trainInd)
        not_one_hot = np.argmax(mdlParams['labels_array'][mdlParams['trainInd'],:],1)
        self.class_indices = [np.where(not_one_hot == i)[0] for i in range(self.numClasses)]
        self.current_class_ind = 0
        self.current_in_class_ind = np.zeros([self.numClasses],dtype=int)

    def gen_sample_array(self):
        """Return ``dataset_len`` indices drawn class by class in round-robin order."""
        # Shuffle all classes first
        for i in range(self.numClasses):
            np.random.shuffle(self.class_indices[i])
        indices = np.zeros([self.dataset_len],dtype=np.int32)
        # Without any non-empty class there is nothing to draw from.
        if not any(len(members) > 0 for members in self.class_indices):
            return indices[:0]
        ind = 0
        while ind < self.dataset_len:
            cls = self.current_class_ind
            members = self.class_indices[cls]
            # Empty classes (e.g. a label with no training images) are skipped.
            if len(members) > 0:
                pos = self.current_in_class_ind[cls]
                indices[ind] = members[pos]
                # Take care of in-class index; reshuffle once the class is exhausted
                if pos == len(members)-1:
                    self.current_in_class_ind[cls] = 0
                    np.random.shuffle(members)
                else:
                    self.current_in_class_ind[cls] += 1
                ind += 1
            # Take care of overall class ind
            self.current_class_ind = (cls + 1) % self.numClasses
        return indices

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return self.dataset_len

class FocalLoss(nn.Module):
    """Focal loss for multi-class classification.

    Computes ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the
    softmax probability of the target class.

    Args:
        gamma: focusing parameter; ``gamma=0`` without ``alpha`` equals cross entropy.
        alpha: optional class weights. A Python scalar ``a`` is expanded to
            ``[a, 1 - a]`` (binary case); a list, tuple, numpy array or tensor is
            used as one weight per class. ``None`` disables weighting.
        size_average: if True return the mean over all elements, else the sum.
    """

    def __init__(self, gamma=2.0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if alpha is None:
            self.alpha = None
        elif isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        else:
            # Accept any array-like (e.g. sklearn's compute_class_weight output).
            alpha_t = torch.as_tensor(alpha, dtype=torch.float32)
            if alpha_t.dim() == 0:
                alpha_t = torch.stack([alpha_t, 1 - alpha_t])
            self.alpha = alpha_t
        self.size_average = size_average

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` w.r.t. integer class ``target``."""
        if input.dim()>2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        # p_t comes from the unweighted log-probability; alpha only scales the loss.
        pt = logpt.exp()

        if self.alpha is not None:
            # alpha is not a registered buffer, so move/cast it to match the input.
            if self.alpha.dtype != input.dtype or self.alpha.device != input.device:
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt)**self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()

def _eval_batch(mdlParams, modelVars, inputs, labels):
    """Run one evaluation batch through modelVars['model'] without gradients.

    Returns (preds, loss, labels): preds = modelVars['softmax'](outputs),
    loss = modelVars['criterion'](outputs, labels), labels moved to the device.
    """
    if mdlParams.get('meta_features',None) is not None:
        # Meta-feature inputs arrive as a two-element (image, meta) container
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
    else:
        inputs = inputs.to(modelVars['device'])
    labels = labels.to(modelVars['device'])
    # Not sure if thats necessary
    modelVars['optimizer'].zero_grad()
    with torch.set_grad_enabled(False):
        if mdlParams.get('aux_classifier',False):
            outputs, outputs_aux = modelVars['model'](inputs)
            if mdlParams['eval_aux_classifier']:
                outputs = outputs_aux
        else:
            outputs = modelVars['model'](inputs)
        preds = modelVars['softmax'](outputs)
        loss = modelVars['criterion'](outputs, labels)
    return preds, loss, labels


def _to_one_hot(labels, num_classes):
    """Convert a tensor of integer class labels into an (N, num_classes) one-hot array."""
    tar_not_one_hot = labels.data.cpu().numpy()
    tar = np.zeros((tar_not_one_hot.shape[0], num_classes))
    tar[np.arange(tar_not_one_hot.shape[0]), tar_not_one_hot] = 1
    return tar


def _combine_crops(mdlParams, predictions, predictions_mc):
    """Aggregate per-crop predictions (N, C, K) into (N, C) per mdlParams['voting_scheme'].

    'vote' counts argmax votes per class and replaces predictions_mc by the
    (N, K) argmax array; 'average' takes the mean over crops. Any other scheme
    leaves predictions unchanged. Returns (predictions, predictions_mc).
    """
    if mdlParams['voting_scheme'] == 'vote':
        print("Pred Shape",predictions_mc.shape)
        predictions_mc = np.argmax(predictions_mc,1)
        print("Pred Shape",predictions_mc.shape)
        for j in range(predictions_mc.shape[0]):
            predictions[j,:] = np.bincount(predictions_mc[j,:],minlength=mdlParams['numClasses'])
        print("Pred Shape",predictions.shape)
    elif mdlParams['voting_scheme'] == 'average':
        predictions = np.mean(predictions_mc,2)
    return predictions, predictions_mc


def getErrClassification_mgpu(mdlParams, indices, modelVars, exclude_class=None):
    """Evaluate the model on one split and compute classification metrics.

    Args:
      mdlParams: dictionary, configuration file
      indices: string, either "trainInd", "valInd" or "testInd"; selects
        mdlParams[indices] and modelVars['dataloader_' + indices]
      modelVars: dictionary with 'model', 'device', 'optimizer', 'softmax',
        'criterion' and the dataloaders
      exclude_class: optional class index removed from predictions/targets
        before the metrics are computed
    Returns:
      loss: float, mean of the per-batch losses
      acc: float, accuracy
      sensitivity: per-class array (more than two classes) or scalar (binary)
      spec: per-class array (more than two classes) or scalar (binary)
      conf: confusion matrix (all ones if some class never appears)
      f1: F1 score (weighted average for more than two classes)
      roc_auc: per-class ROC AUC
      wacc: per-class recall (diagonal over row sums of conf)
      predictions, targets: (N, num_classes) arrays the metrics are based on
      predictions_mc: per-crop predictions (equal to predictions without multi-crop)
    """
    num_samples = len(mdlParams[indices])
    num_cls = mdlParams['numClasses']
    dataloader = modelVars['dataloader_'+indices]
    # Collected in every branch so the returned loss is never a placeholder.
    batch_losses = []
    if mdlParams.get('eval_flipping',0) > 1 and mdlParams.get('multiCropEval',0) > 0:
        predictions = np.zeros([num_samples,num_cls])
        predictions_mc = np.zeros([num_samples,num_cls,mdlParams['multiCropEval'],mdlParams['eval_flipping']])
        targets_mc = np.zeros_like(predictions_mc)
        # Each batch is assumed to hold all multiCropEval crops of one image under
        # one flip; a new image starts whenever the flip index is 0.
        ind = -1
        for i, (inputs, labels, inds, flip_ind) in enumerate(dataloader):
            if flip_ind[0] != np.mean(np.array(flip_ind)):
                print("Problem with flipping",flip_ind)
            if flip_ind[0] == 0:
                ind += 1
            preds, loss, labels = _eval_batch(mdlParams, modelVars, inputs, labels)
            batch_losses.append(np.mean(loss.cpu().numpy()))
            predictions_mc[ind,:,:,flip_ind[0]] = np.transpose(preds.cpu().numpy())
            targets_mc[ind,:,:,flip_ind[0]] = np.transpose(_to_one_hot(labels, num_cls))
        # Targets stay the same
        targets = targets_mc[:,:,0,0]
        # reshape preds
        predictions_mc = np.reshape(predictions_mc,[predictions_mc.shape[0],predictions_mc.shape[1],mdlParams['multiCropEval']*mdlParams['eval_flipping']])
        predictions, predictions_mc = _combine_crops(mdlParams, predictions, predictions_mc)
    elif mdlParams.get('multiCropEval',0) > 0:
        predictions = np.zeros([num_samples,num_cls])
        predictions_mc = np.zeros([num_samples,num_cls,mdlParams['multiCropEval']])
        targets_mc = np.zeros_like(predictions_mc)
        # Batch i holds all crops of image i
        for i, (inputs, labels, inds) in enumerate(dataloader):
            preds, loss, labels = _eval_batch(mdlParams, modelVars, inputs, labels)
            batch_losses.append(np.mean(loss.cpu().numpy()))
            predictions_mc[i,:,:] = np.transpose(preds.cpu().numpy())
            targets_mc[i,:,:] = np.transpose(_to_one_hot(labels, num_cls))
        # Targets stay the same
        targets = targets_mc[:,:,0]
        predictions, predictions_mc = _combine_crops(mdlParams, predictions, predictions_mc)
    elif mdlParams.get('model_type_cnn') is not None and mdlParams['numRandValSeq'] > 0:
        predictions = np.zeros([num_samples,num_cls])
        predictions_mc = np.zeros([num_samples,num_cls,mdlParams['numRandValSeq']])
        targets_mc = np.zeros_like(predictions_mc)
        for i, (inputs, labels, inds) in enumerate(dataloader):
            preds, loss, labels = _eval_batch(mdlParams, modelVars, inputs, labels)
            batch_losses.append(np.mean(loss.cpu().numpy()))
            # preds may live on the GPU: move to host memory before numpy ops
            predictions_mc[i,:,:] = np.transpose(preds.cpu().numpy())
            targets_mc[i,:,:] = np.transpose(_to_one_hot(labels, num_cls))
        # Targets stay the same
        targets = targets_mc[:,:,0]
        predictions, predictions_mc = _combine_crops(mdlParams, predictions, predictions_mc)
    else:
        all_preds = []
        all_targets = []
        for i, (inputs, labels, inds) in enumerate(dataloader):
            preds, loss, labels = _eval_batch(mdlParams, modelVars, inputs, labels)
            batch_losses.append(np.mean(loss.cpu().numpy()))
            all_preds.append(preds.cpu().numpy())
            all_targets.append(_to_one_hot(labels, num_cls))
        predictions = np.concatenate(all_preds,0)
        targets = np.concatenate(all_targets,0)
        predictions_mc = predictions
    # Calculate metrics
    if exclude_class is not None:
        predictions = np.concatenate((predictions[:,:exclude_class],predictions[:,exclude_class+1:]),1)
        targets = np.concatenate((targets[:,:exclude_class],targets[:,exclude_class+1:]),1)
        num_classes = num_cls-1
    elif num_cls == 9 and mdlParams.get('no_c9_eval',False):
        predictions = predictions[:,:num_cls-1]
        targets = targets[:,:num_cls-1]
        num_classes = num_cls-1
    else:
        num_classes = num_cls
    y_true = np.argmax(targets,1)
    y_pred = np.argmax(predictions,1)
    # Accuracy
    acc = np.mean(np.equal(y_pred,y_true))
    # Confusion matrix
    conf = confusion_matrix(y_true,y_pred)
    if conf.shape[0] < num_classes:
        conf = np.ones([num_classes,num_classes])
    # Class weighted accuracy
    wacc = conf.diagonal()/conf.sum(axis=1)
    # Sensitivity / Specificity
    if num_classes > 2:
        sensitivity = np.zeros([num_classes])
        specificity = np.zeros([num_classes])
        for k in range(num_classes):
            sensitivity[k] = conf[k,k]/(np.sum(conf[k,:]))
            true_negative = np.sum(np.delete(np.delete(conf,[k],0),[k],1))
            false_positive = np.sum(np.delete(conf,[k],0)[:,k])
            specificity[k] = true_negative/(true_negative+false_positive)
        # F1 score (sklearn signature is f1_score(y_true, y_pred))
        f1 = f1_score(y_true,y_pred,average='weighted')
    else:
        # Reuse conf: it is already padded when a class is missing, so the
        # four-way unpack does not fail on single-class splits.
        tn, fp, fn, tp = np.asarray(conf).ravel()
        sensitivity = tp/(tp+fn)
        specificity = tn/(tn+fp)
        # F1 score
        f1 = f1_score(y_true,y_pred)
    # AUC
    fpr = {}
    tpr = {}
    roc_auc = np.zeros([num_classes])
    if num_classes > 9:
        print(predictions)
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(targets[:, i], predictions[:, i])
        roc_auc[i] = sklearn.metrics.auc(fpr[i], tpr[i])
    return np.mean(batch_losses), acc, sensitivity, specificity, conf, f1, roc_auc, wacc, predictions, targets, predictions_mc


def modify_densenet_avg_pool(model):
    """Patch a DenseNet-style model so its head averages feature maps before classifying.

    Installs bound ``logits`` and ``forward`` methods on ``model``:
    ``forward`` runs ``model.features`` and hands the result to ``logits``,
    which applies ReLU, averages over dimensions 2 and 3, and feeds the pooled
    tensor to ``model.classifier``. Averaging over whole dimensions does not
    depend on their size (assumes NCHW feature maps, i.e. dims 2/3 are spatial
    -- TODO confirm for every model this is applied to).

    Args:
        model: module exposing ``features`` and ``classifier`` attributes.

    Returns:
        The same ``model`` object, modified in place.
    """

    def pooled_logits(self, features):
        # In-place ReLU overwrites the tensor returned by self.features.
        activated = F.relu(features, inplace=True)
        # Reduce the two trailing axes one after the other.
        pooled = activated.mean(2).mean(2)
        return self.classifier(pooled)

    def pooled_forward(self, input):
        return self.logits(self.features(input))

    model.logits = types.MethodType(pooled_logits, model)
    model.forward = types.MethodType(pooled_forward, model)
    return model

In [None]:
!pip install imagesize

In [None]:
# import config file

import os
import sys
import h5py
import re
import csv
import numpy as np
from glob import glob
import scipy
import pickle
import imagesize
import pdb

mdlParams = {}
# Output directory for summaries and checkpoints
mdlParams['saveDir'] = '/kaggle/working/'
# Data root (unused in this notebook):
# mdlParams['dataDir'] = mdlParams_['pathBase']+"/"

### Model selection ###
mdlParams.update(
    model_type='isic2019-efficientnet-b0',
    dataset_names=['official'],  # further sets (e.g. 'sevenpoint_rez3_ll') could be appended
    file_ending='.jpg',
    exclude_inds=False,
    same_sized_crops=True,
    multiCropEval=9,  # number of evaluation crops; should be a perfect square
    var_im_size=True,
    orderedCrop=True,
    voting_scheme='average',
    classification=True,
    balance_classes=2,
    extra_fac=1.0,
    numClasses=2,
    no_c9_eval=True,
)
# The network predicts one output per class.
mdlParams['numOut'] = mdlParams['numClasses']
mdlParams.update(
    numCV=5,
    trans_norm_first=True,
    input_size=[224, 224, 3],  # enlarge for efficientnet b1-b7
)

### Training parameters ###
mdlParams.update(
    batchSize=20,
    learning_rate=1.5e-5,
    # Used as the StepLR step size: epochs between learning-rate drops.
    lowerLRAfter=25,
    # Scheduler stepping begins at epoch lowerLRat - lowerLRAfter.
    lowerLRat=50,
    # Learning rate is divided by this factor at each drop.
    LRstep=5,
    # Number of training epochs
    training_steps=60,
    display_step=10,
    scale_targets=False,
    # Peek at test error during training? (generally, don't)
    peak_at_testerr=False,
    print_trainerr=False,
    # Per-channel normalization; identity unless subtract_set_mean is enabled.
    subtract_set_mean=False,
    setMean=np.zeros(3),
    setStd=np.ones(3),
)

### Data augmentation ###
# mdlParams['full_color_distort'] = True
mdlParams.update(
    autoaugment=False,
    flip_lr_ud=True,
    full_rot=180,
    scale=(0.8, 1.2),
    shear=10,
    cutout=16,
)

### Data ###
mdlParams['preload'] = False
# Labels first
# Targets, as dictionary, indexed by image id (the CSV's first column)
mdlParams['labels_dict'] = {}
# All sets
allSets = ['official']


def _parse_label_row(row, num_classes):
    """Convert one label-CSV row into an integer label vector of length ``num_classes``.

    Column 0 is the image id; columns 1..num_classes hold the per-class labels,
    parsed with ``int(float(...))`` so values such as ``"1.0"`` are accepted.
    The first ``min(num_classes, 7)`` label columns are required; label columns
    8 and 9 (only read when ``num_classes`` is 8 or 9) may be missing or empty
    and then default to 0.

    Raises:
        ValueError: if ``num_classes`` is not one of 1, 2, 7, 8, 9 (such configs
            previously produced an empty label dict without any error).
        IndexError: if a required label column is missing.
    """
    if num_classes not in (1, 2, 7, 8, 9):
        raise ValueError("Unsupported numClasses for label parsing: %r" % (num_classes,))
    labels = [int(float(row[col])) for col in range(1, min(num_classes, 7) + 1)]
    for col in range(8, num_classes + 1):
        cell = row[col] if len(row) > col else ''
        labels.append(int(float(cell)) if cell != '' else 0)
    return np.array(labels)


# Go through all sets
for set_name in allSets:
    # Check if want to include this dataset
    if not any(name in set_name for name in mdlParams['dataset_names']):
        continue
    # Find csv file
    file_loc = '/kaggle/input/labels/labels2024.csv'
    # Load csv file
    with open(file_loc, newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=',', quotechar='|'):
            # csv yields [] for blank lines; these used to raise IndexError.
            if not row:
                continue
            # Presumably the header row (its second column is named 'Malignant').
            if row[1] == 'Malignant':
                continue
            # If an '<id>_downsampled' entry is already stored, skip this row so
            # the downsampled entry is the one kept.
            downsampled_key = row[0] + '_downsampled'
            if downsampled_key in mdlParams['labels_dict']:
                print("removed", downsampled_key)
                continue
            mdlParams['labels_dict'][row[0]] = _parse_label_row(row, mdlParams['numClasses'])
# Save all image paths here; im_paths, labels_list and key_list stay index-aligned.
mdlParams['im_paths'] = []
mdlParams['labels_list'] = []
# Define the sets
# All sets
allSets = ['official']
# Ids which name the folders
# Make official first dataset
for i in range(len(allSets)):
    if mdlParams['dataset_names'][0] in allSets[i]:
        temp = allSets[i]
        allSets.remove(allSets[i])
        allSets.insert(0, temp)
print(allSets)
# Set of keys, for marking old HAM10000
mdlParams['key_list'] = []
if mdlParams['exclude_inds']:
    with open(mdlParams['saveDir'] + 'indices_exclude.pkl', 'rb') as f:
        indices_exclude = pickle.load(f)
    exclude_list = []

# The image directory does not depend on the set, so list it only once.
files = sorted(glob('/kaggle/input/isic-2024-challenge/train-image/image/*'))
image_exts = ('.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG')
file_ending = mdlParams['file_ending']
for set_name in allSets:
    # Check if there is something in there, if not, discard
    if len(files) == 0:
        continue
    # Check if want to include this dataset
    if not any(name in set_name for name in mdlParams['dataset_names']):
        continue
    for path in files:
        if not any(ext in path for ext in image_exts):
            continue
        # Look the label up by file stem (<key><file_ending>) in O(1). The old code
        # tested every label key as a substring of every path, which is
        # O(#images x #labels) and could append several keys for one image,
        # misaligning labels_list with im_paths.
        file_name = os.path.basename(path)
        if not file_name.endswith(file_ending):
            continue
        key = file_name[:len(file_name) - len(file_ending)]
        if key not in mdlParams['labels_dict']:
            continue
        mdlParams['key_list'].append(key)
        mdlParams['labels_list'].append(mdlParams['labels_dict'][key])
        mdlParams['im_paths'].append(path)
        if mdlParams['exclude_inds']:
            for exclude_key in indices_exclude:
                if exclude_key in path:
                    exclude_list.append(indices_exclude[exclude_key])
# Convert label list to array
mdlParams['labels_array'] = np.array(mdlParams['labels_list'])
# Per-class label frequency, as a quick sanity check
print(np.mean(mdlParams['labels_array'], axis=0))

# Create indices list with HAM10000 only
# mdlParams['HAM10000_inds'] = []
# HAM_START = 24306
# HAM_END = 34320
# for j in range(len(mdlParams['key_list'])):
#     try:
#         curr_id = [int(s) for s in re.findall(r'\d+',mdlParams['key_list'][j])][-1]
#     except:
#         continue
#     if curr_id >= HAM_START and curr_id <= HAM_END:
#         mdlParams['HAM10000_inds'].append(j)
# mdlParams['HAM10000_inds'] = np.array(mdlParams['HAM10000_inds'])    
# print("Len ham",len(mdlParams['HAM10000_inds']))

# Perhaps preload images and/or compute per-image channel means
# (both are off unless 'preload' / 'subtract_set_mean' are enabled in the config).


def _load_rgb(path):
    """Decode an image file into an (H, W, 3) uint8 numpy array.

    Replaces ``scipy.ndimage.imread``, which was removed in SciPy 1.2 and made
    these code paths fail with AttributeError. Decoding is forced to RGB.
    """
    from torchvision.io import ImageReadMode, read_image
    return read_image(path, mode=ImageReadMode.RGB).permute(1, 2, 0).numpy()


if mdlParams['preload']:
    # NOTE(review): 'input_size_load' is not set by this config, so enabling
    # preload also requires defining it; every image must decode to that shape.
    load_size = mdlParams['input_size_load']
    mdlParams['images_array'] = np.zeros([len(mdlParams['im_paths']), load_size[0], load_size[1], load_size[2]], dtype=np.uint8)
    for i, im_path in enumerate(mdlParams['im_paths']):
        mdlParams['images_array'][i, :, :, :] = _load_rgb(im_path)
        if i % 1000 == 0:
            print(i+1, "images loaded...")
if mdlParams['subtract_set_mean']:
    mdlParams['images_means'] = np.zeros([len(mdlParams['im_paths']), 3])
    for i, im_path in enumerate(mdlParams['im_paths']):
        x = _load_rgb(im_path).astype(np.float32)
        # Min-max scale to [0, 1]. A constant image would divide by zero
        # (producing NaN means), so it is mapped to zeros instead.
        min_x = np.min(x)
        value_range = np.max(x) - min_x
        x = (x - min_x) / value_range if value_range > 0 else np.zeros_like(x)
        mdlParams['images_means'][i, :] = np.mean(x, (0, 1))
        if i % 1000 == 0:
            print(i+1, "images processed for mean...")

### Define Indices ###
# Pre-computed cross-validation split, presumably one index array per fold.
indices_path = '/kaggle/input/extrafiles/indices_isic2024.pkl'

with open(indices_path, 'rb') as f:
    indices = pickle.load(f)
mdlParams['trainIndCV'] = indices['trainIndCV']
mdlParams['valIndCV'] = indices['valIndCV']

# Optionally drop excluded images from every fold; exclude_list is used as a
# boolean mask over im_paths.
if mdlParams['exclude_inds']:
    exclude_list = np.array(exclude_list)
    all_inds = np.arange(len(mdlParams['im_paths']))
    exclude_inds = all_inds[exclude_list.astype(bool)]
    for split_key in ('trainIndCV', 'valIndCV'):
        folds = mdlParams[split_key]
        for fold in range(len(folds)):
            folds[fold] = np.setdiff1d(folds[fold], exclude_inds)

# With more than one dataset, all images from index 25331 on are added to every training fold.
if len(mdlParams['dataset_names']) > 1:
    # NOTE(review): 25331 looks like the size of the official ISIC 2019 training
    # set; for ISIC 2024 data this offset presumably needs updating.
    restInds = np.arange(25331, mdlParams['labels_array'].shape[0])
    for fold in range(mdlParams['numCV']):
        mdlParams['trainIndCV'][fold] = np.concatenate((mdlParams['trainIndCV'][fold], restInds))

# Report fold sizes
for split_title, split_key in (("Train", 'trainIndCV'), ("Val", 'valIndCV')):
    print(split_title)
    folds = mdlParams[split_key]
    for fold in range(len(folds)):
        print(folds[fold].shape)

# Use this for ordered multi crops
# For every image, precompute a regular sqrt(multiCropEval) x sqrt(multiCropEval) grid of
# crop centres used for multi-crop evaluation, then check that each crop slice has the
# expected size.
if mdlParams['orderedCrop']:
    # Crop positions, always choose multiCropEval to be 4, 9, 16, 25, etc.
    # (a perfect square > 1: the spacing below divides by sqrt(multiCropEval) - 1).
    # cropPositions[u, k] holds the (axis-0, axis-1) centre of crop k of image u; the float
    # positions computed below are truncated when stored into this int64 array.
    mdlParams['cropPositions'] = np.zeros([len(mdlParams['im_paths']),mdlParams['multiCropEval'],2],dtype=np.int64)
    #mdlParams['imSizes'] = np.zeros([len(mdlParams['im_paths']),mdlParams['multiCropEval'],2],dtype=np.int64)
    for u in range(len(mdlParams['im_paths'])):
        # NOTE(review): imagesize.get is documented to return (width, height), so `height`
        # here presumably holds the pixel width and `width` the pixel height (rows). The
        # names are swapped but used consistently: cropPositions[..., 0] is derived from
        # `width` and indexes axis 0 of test_im = np.zeros([width_test, height_test]) below.
        height, width = imagesize.get(mdlParams['im_paths'][u])
        # A dimension below input_size[0] is raised to input_size[0] and the other one is
        # scaled by int(input_size[0] / dim). NOTE(review): int() truncates the scale
        # factor *before* multiplying; presumably this must mirror the resize done by the
        # data loader, so confirm there before changing it.
        if width < mdlParams['input_size'][0]:
            height = int(mdlParams['input_size'][0]/float(width))*height
            width = mdlParams['input_size'][0]
        if height < mdlParams['input_size'][0]:
            width = int(mdlParams['input_size'][0]/float(height))*width
            height = mdlParams['input_size'][0]            
        ind = 0
        # Centres run evenly from input/2 to dim - input/2 along each axis (i: axis 0,
        # j: axis 1). Mixing input_size[0] and [1] here only matters for non-square inputs.
        for i in range(np.int32(np.sqrt(mdlParams['multiCropEval']))):
            for j in range(np.int32(np.sqrt(mdlParams['multiCropEval']))):
                mdlParams['cropPositions'][u,ind,0] = mdlParams['input_size'][0]/2+i*((width-mdlParams['input_size'][1])/(np.sqrt(mdlParams['multiCropEval'])-1))
                mdlParams['cropPositions'][u,ind,1] = mdlParams['input_size'][1]/2+j*((height-mdlParams['input_size'][0])/(np.sqrt(mdlParams['multiCropEval'])-1))
                #mdlParams['imSizes'][u,ind,0] = curr_im_size[0]

                ind += 1
    # Sanity checks
    #print("Positions",mdlParams['cropPositions'])
    # Test image sizes: slice every crop out of a zero array with the (possibly enlarged)
    # image size and report any crop whose shape differs from input_size.
    height = mdlParams['input_size'][0]
    width = mdlParams['input_size'][1]
    for u in range(len(mdlParams['im_paths'])):
        height_test, width_test = imagesize.get(mdlParams['im_paths'][u])
        # Same size adjustment as above, so the check uses the sizes the positions assumed.
        if width_test < mdlParams['input_size'][0]:
            height_test = int(mdlParams['input_size'][0]/float(width_test))*height_test
            width_test = mdlParams['input_size'][0]
        if height_test < mdlParams['input_size'][0]:
            width_test = int(mdlParams['input_size'][0]/float(height_test))*width_test
            height_test = mdlParams['input_size'][0]                
        test_im = np.zeros([width_test,height_test]) 
        for i in range(mdlParams['multiCropEval']):
            im_crop = test_im[np.int32(mdlParams['cropPositions'][u,i,0]-height/2):np.int32(mdlParams['cropPositions'][u,i,0]-height/2)+height,np.int32(mdlParams['cropPositions'][u,i,1]-width/2):np.int32(mdlParams['cropPositions'][u,i,1]-width/2)+width]
            if im_crop.shape[0] != mdlParams['input_size'][0]:
                print("Wrong shape",im_crop.shape[0],mdlParams['im_paths'][u])    
            if im_crop.shape[1] != mdlParams['input_size'][1]:
                print("Wrong shape",im_crop.shape[1],mdlParams['im_paths'][u])

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models as tv_models
from torch.utils.data import DataLoader
from torchsummary import summary
import numpy as np
from scipy import io
import threading
import pickle
from pathlib import Path
import math
import os
import sys
from glob import glob
import re
import gc
import importlib
import time
import sklearn.preprocessing
from sklearn.utils import class_weight
import psutil

# Import machine config
# pc_cfg = importlib.import_module('pc_cfgs.'+sys.argv[1])
# mdlParams.update(pc_cfg.mdlParams)
# example. py

mdlParams['saveDirBase'] = '/kaggle/working'

# Indicate training
mdlParams['trainSetState'] = 'train'

# Set visible devices. The device spec is hard-coded for this notebook; the last
# integer found in it is used as the single visible GPU id.
gpu_spec = 'gpu0'
if 'gpu' in gpu_spec:
    mdlParams['numGPUs'] = [[int(s) for s in re.findall(r'\d+', gpu_spec)][-1]]
    # Comma-join the ids. The previous loop compared ints by identity
    # (`is not`), which only works by accident of CPython's small-int cache.
    cuda_str = ",".join(str(gpu_id) for gpu_id in mdlParams['numGPUs'])
    print("Devices to use:", cuda_str)
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_str

# Evaluate on the validation split when one is defined; otherwise fall back to
# reporting the training-split error.
has_validation = 'valIndCV' in mdlParams or 'valInd' in mdlParams
if not has_validation:
    eval_set = 'trainInd'
    print("No validation set, evaluating on training set during training.")
else:
    eval_set = 'valInd'
    print("Evaluating on validation set during training.")

# Check whether earlier CV folds have already been trained: their results are
# stored in CV.pkl; otherwise start with an empty dict per tracked metric.
prevFile = Path(mdlParams['saveDirBase']) / 'CV.pkl'
if not prevFile.exists():
    allData = {
        metric: {}
        for metric in (
            'f1Best', 'sensBest', 'specBest', 'accBest', 'waccBest',
            'aucBest', 'convergeTime', 'bestPred', 'targets',
        )
    }
else:
    print("Part of CV already done")
    with prevFile.open('rb') as f:
        allData = pickle.load(f)

# Folds to train: an explicit 'cv_subset' if configured, otherwise all numCV folds.
cv_set = mdlParams.get('cv_subset')
if cv_set is None:
    cv_set = range(mdlParams['numCV'])
for cv in cv_set:  
    # Check if this fold was already trained
    already_trained = False
    if 'valIndCV' in mdlParams:
        mdlParams['saveDir'] = '/kaggle/input/output' + '/CVSet' + str(cv)
        if os.path.isdir('/kaggle/input/output'):
            if os.path.isdir(mdlParams['saveDir']):
                all_max_iter = []
                for name in os.listdir(mdlParams['saveDir']):
                    int_list = [int(s) for s in re.findall(r'\d+',name)]
                    if len(int_list) > 0:
                        all_max_iter.append(int_list[-1])
                    #if '-' + str(mdlParams['training_steps'])+ '.pt' in name:
                    #    print("Fold %d already fully trained"%(cv))
                    #    already_trained = True
                all_max_iter = np.array(all_max_iter)
                if len(all_max_iter) > 0 and np.max(all_max_iter) >= mdlParams['training_steps']:
                    print("Fold %d already fully trained with %d iterations"%(cv,np.max(all_max_iter)))
                    already_trained = True
    if already_trained:
        continue        
    print("CV set",cv)
    # Reset model graph 
    # importlib.reload(models)
    # importlib.reload(torchvision)
    # Collect model variables
    modelVars = {}
    #print("here")
    modelVars['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(modelVars['device'])
    # Def current CV set
    mdlParams['trainInd'] = mdlParams['trainIndCV'][cv]
    if 'valIndCV' in mdlParams:
        mdlParams['valInd'] = mdlParams['valIndCV'][cv]
    # Def current path for saving stuff
    if 'valIndCV' in mdlParams:
        mdlParams['saveDir'] = mdlParams['saveDirBase'] + '/CVSet' + str(cv)
    else:
        mdlParams['saveDir'] = mdlParams['saveDirBase']
    # Create basepath if it doesnt exist yet
    if not os.path.isdir(mdlParams['saveDirBase']):
        os.mkdir(mdlParams['saveDirBase'])
    # Check if there is something to load
    load_old = 0
    if os.path.isdir(mdlParams['saveDir']):
        # Check if a checkpoint is in there
        if len([name for name in os.listdir(mdlParams['saveDir'])]) > 0:
            load_old = 1
            print("Loading old model")
        else:
            # Delete whatever is in there (nothing happens)
            filelist = [os.remove(mdlParams['saveDir'] +'/'+f) for f in os.listdir(mdlParams['saveDir'])]
    else:
        os.mkdir(mdlParams['saveDir'])
    # Save training progress in here
    save_dict = {}
    save_dict['acc'] = []
    save_dict['loss'] = []
    save_dict['wacc'] = []
    save_dict['auc'] = []
    save_dict['sens'] = []
    save_dict['spec'] = []
    save_dict['f1'] = []
    save_dict['step_num'] = []
    if mdlParams['print_trainerr']:
        save_dict_train = {}
        save_dict_train['acc'] = []
        save_dict_train['loss'] = []
        save_dict_train['wacc'] = []
        save_dict_train['auc'] = []
        save_dict_train['sens'] = []
        save_dict_train['spec'] = []
        save_dict_train['f1'] = []
        save_dict_train['step_num'] = []        
    # Potentially calculate setMean to subtract
    if mdlParams['subtract_set_mean'] == 1:
        mdlParams['setMean'] = np.mean(mdlParams['images_means'][mdlParams['trainInd'],:],(0))
        print("Set Mean",mdlParams['setMean']) 

    # balance classes
    if mdlParams['balance_classes'] < 3 or mdlParams['balance_classes'] == 7 or mdlParams['balance_classes'] == 11:
        # class_weights = class_weight.compute_class_weight('balanced',np.unique(np.argmax(mdlParams['labels_array'][mdlParams['trainInd'],:],1)),np.argmax(mdlParams['labels_array'][mdlParams['trainInd'],:],1)) 
        class_weights = class_weight.compute_class_weight(
            class_weight='balanced',  # Use keyword argument
            classes=np.unique(np.argmax(mdlParams['labels_array'][mdlParams['trainInd'], :], axis=1)),  # Unique classes
            y=np.argmax(mdlParams['labels_array'][mdlParams['trainInd'], :], axis=1)  # Class labels
        )
        print("Current class weights",class_weights)
        class_weights = class_weights*mdlParams['extra_fac']
        print("Current class weights with extra",class_weights)             
    elif mdlParams['balance_classes'] == 3 or mdlParams['balance_classes'] == 4:
        # Split training set by classes
        not_one_hot = np.argmax(mdlParams['labels_array'],1)
        mdlParams['class_indices'] = []
        for i in range(mdlParams['numClasses']):
            mdlParams['class_indices'].append(np.where(not_one_hot==i)[0])
            # Kick out non-trainind indices
            mdlParams['class_indices'][i] = np.setdiff1d(mdlParams['class_indices'][i],mdlParams['valInd'])
            #print("Class",i,mdlParams['class_indices'][i].shape,np.min(mdlParams['class_indices'][i]),np.max(mdlParams['class_indices'][i]),np.sum(mdlParams['labels_array'][np.int64(mdlParams['class_indices'][i]),:],0))        
    elif mdlParams['balance_classes'] == 5 or mdlParams['balance_classes'] == 6 or mdlParams['balance_classes'] == 13:
        # Other class balancing loss
        class_weights = 1.0/np.mean(mdlParams['labels_array'][mdlParams['trainInd'],:],axis=0)
        print("Current class weights",class_weights)
        if isinstance(mdlParams['extra_fac'], float):
            class_weights = np.power(class_weights,mdlParams['extra_fac'])
        else:
            class_weights = class_weights*mdlParams['extra_fac']
        print("Current class weights with extra",class_weights) 
    elif mdlParams['balance_classes'] == 9:
        # Only use official indicies for calculation
        print("Balance 9")
        indices_ham = mdlParams['trainInd'][mdlParams['trainInd'] < 25331]
        if mdlParams['numClasses'] == 9:
            class_weights_ = 1.0/np.mean(mdlParams['labels_array'][indices_ham,:8],axis=0)
            #print("class before",class_weights_)
            class_weights = np.zeros([mdlParams['numClasses']])
            class_weights[:8] = class_weights_
            class_weights[-1] = np.max(class_weights_)
        else:
            class_weights = 1.0/np.mean(mdlParams['labels_array'][indices_ham,:],axis=0)
        print("Current class weights",class_weights)             
        if isinstance(mdlParams['extra_fac'], float):
            class_weights = np.power(class_weights,mdlParams['extra_fac'])
        else:
            class_weights = class_weights*mdlParams['extra_fac']
        print("Current class weights with extra",class_weights)             

    # Meta scaler
    if mdlParams.get('meta_features',None) is not None and mdlParams['scale_features']:
        mdlParams['feature_scaler_meta'] = sklearn.preprocessing.StandardScaler().fit(mdlParams['meta_array'][mdlParams['trainInd'],:])  
        print("scaler mean",mdlParams['feature_scaler_meta'].mean_,"var",mdlParams['feature_scaler_meta'].var_)  

    # Set up dataloaders
    num_workers = psutil.cpu_count(logical=False)
    # For train
    dataset_train = ISICDataset(mdlParams, 'trainInd')
    # For val
    dataset_val = ISICDataset(mdlParams, 'valInd')
    if mdlParams['multiCropEval'] > 0:
        modelVars['dataloader_valInd'] = DataLoader(dataset_val, batch_size=mdlParams['multiCropEval'], shuffle=False, num_workers=num_workers, pin_memory=True)  
    else:
        modelVars['dataloader_valInd'] = DataLoader(dataset_val, batch_size=mdlParams['batchSize'], shuffle=False, num_workers=num_workers, pin_memory=True)               

    if mdlParams['balance_classes'] == 12 or mdlParams['balance_classes'] == 13:
        #print(np.argmax(mdlParams['labels_array'][mdlParams['trainInd'],:],1).size(0))
        strat_sampler = StratifiedSampler(mdlParams)
        modelVars['dataloader_trainInd'] = DataLoader(dataset_train, batch_size=mdlParams['batchSize'], sampler=strat_sampler, num_workers=num_workers, pin_memory=True) 
    else:
        modelVars['dataloader_trainInd'] = DataLoader(dataset_train, batch_size=mdlParams['batchSize'], shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True) 
    #print("Setdiff",np.setdiff1d(mdlParams['trainInd'],mdlParams['trainInd']))
    # Define model 
    modelVars['model'] = getModel(mdlParams)()  
    # Load trained model
    if mdlParams.get('meta_features',None) is not None:
        # Find best checkpoint
        files = glob(mdlParams['model_load_path'] + '/CVSet' + str(cv) + '/*')
        global_steps = np.zeros([len(files)])
        #print("files",files)
        for i in range(len(files)):
            # Use meta files to find the highest index
            if 'best' not in files[i]:
                continue
            if 'checkpoint' not in files[i]:
                continue                
            # Extract global step
            nums = [int(s) for s in re.findall(r'\d+',files[i])]
            global_steps[i] = nums[-1]
        # Create path with maximum global step found
        chkPath = mdlParams['model_load_path'] + '/CVSet' + str(cv) + '/checkpoint_best-' + str(int(np.max(global_steps))) + '.pt'
        print("Restoring lesion-trained CNN for meta data training: ",chkPath)
        # Load
        state = torch.load(chkPath)
        # Initialize model
        curr_model_dict = modelVars['model'].state_dict()
        for name, param in state['state_dict'].items():
            #print(name,param.shape)
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            if curr_model_dict[name].shape == param.shape:
                curr_model_dict[name].copy_(param)
            else:
                print("not restored",name,param.shape)
        #modelVars['model'].load_state_dict(state['state_dict'])        
    # Original input size
    #if 'Dense' not in mdlParams['model_type']:
    #    print("Original input size",modelVars['model'].input_size)
    #print(modelVars['model'])
    if 'Dense' in mdlParams['model_type']:
        if mdlParams['input_size'][0] != 224:
            modelVars['model'] = modify_densenet_avg_pool(modelVars['model'])
            #print(modelVars['model'])
        num_ftrs = modelVars['model'].classifier.in_features
        modelVars['model'].classifier = nn.Linear(num_ftrs, mdlParams['numClasses'])
        #print(modelVars['model'])
    elif 'dpn' in mdlParams['model_type']:
        num_ftrs = modelVars['model'].classifier.in_channels
        modelVars['model'].classifier = nn.Conv2d(num_ftrs,mdlParams['numClasses'],[1,1])
        #modelVars['model'].add_module('real_classifier',nn.Linear(num_ftrs, mdlParams['numClasses']))
        #print(modelVars['model'])
    elif 'efficient' in mdlParams['model_type']:
        # Do nothing, output is prepared
        num_ftrs = modelVars['model']._fc.in_features
        modelVars['model']._fc = nn.Linear(num_ftrs, mdlParams['numClasses'])    
    elif 'wsl' in mdlParams['model_type']:
        num_ftrs = modelVars['model'].fc.in_features
        modelVars['model'].fc = nn.Linear(num_ftrs, mdlParams['numClasses'])          
    else:
        num_ftrs = modelVars['model'].last_linear.in_features
        modelVars['model'].last_linear = nn.Linear(num_ftrs, mdlParams['numClasses'])    
    # Take care of meta case
    if mdlParams.get('meta_features',None) is not None:
        # freeze cnn first
        if mdlParams['freeze_cnn']:
            # deactivate all
            for param in modelVars['model'].parameters():
                param.requires_grad = False            
            if 'efficient' in mdlParams['model_type']:
                # Activate fc
                for param in modelVars['model']._fc.parameters():
                    param.requires_grad = True
            elif 'wsl' in mdlParams['model_type']:
                # Activate fc
                for param in modelVars['model'].fc.parameters():
                    param.requires_grad = True
            else:
                # Activate fc
                for param in modelVars['model'].last_linear.parameters():
                    param.requires_grad = True                                
        else:
            # mark cnn parameters
            for param in modelVars['model'].parameters():
                param.is_cnn_param = True
            # unmark fc
            for param in modelVars['model']._fc.parameters():
                param.is_cnn_param = False                              
        # modify model
        print(mdlParams)
        modelVars['model'] = models.modify_meta(mdlParams,modelVars['model'])  
        # Mark new parameters
        for param in modelVars['model'].parameters():
            if not hasattr(param, 'is_cnn_param'):
                param.is_cnn_param = False                 
    # multi gpu support
    if len(mdlParams['numGPUs']) > 1:
        modelVars['model'] = nn.DataParallel(modelVars['model']) 
    modelVars['model'] = modelVars['model'].cuda()
    #summary(modelVars['model'], modelVars['model'].input_size)# (mdlParams['input_size'][2], mdlParams['input_size'][0], mdlParams['input_size'][1]))
    # Loss, with class weighting
    if mdlParams.get('focal_loss',False):
        modelVars['criterion'] = FocalLoss(alpha=class_weights.tolist())
    elif mdlParams['balance_classes'] == 3 or mdlParams['balance_classes'] == 0 or mdlParams['balance_classes'] == 12:
        modelVars['criterion'] = nn.CrossEntropyLoss()
    elif mdlParams['balance_classes'] == 8:
        modelVars['criterion'] = nn.CrossEntropyLoss(reduce=False)
    elif mdlParams['balance_classes'] == 6 or mdlParams['balance_classes'] == 7:
        modelVars['criterion'] = nn.CrossEntropyLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)),reduce=False)
    elif mdlParams['balance_classes'] == 10:
        modelVars['criterion'] = FocalLoss(mdlParams['numClasses'])
    elif mdlParams['balance_classes'] == 11:
        modelVars['criterion'] = FocalLoss(mdlParams['numClasses'],alpha=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
    else:
        modelVars['criterion'] = nn.CrossEntropyLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))

    if mdlParams.get('meta_features',None) is not None:
        if mdlParams['freeze_cnn']:
            modelVars['optimizer'] = optim.Adam(filter(lambda p: p.requires_grad, modelVars['model'].parameters()), lr=mdlParams['learning_rate_meta'])
            # sanity check
            for param in filter(lambda p: p.requires_grad, modelVars['model'].parameters()):
                print(param.name,param.shape)
        else:
            modelVars['optimizer'] = optim.Adam([
                                                {'params': filter(lambda p: not p.is_cnn_param, modelVars['model'].parameters()), 'lr': mdlParams['learning_rate_meta']},
                                                {'params': filter(lambda p: p.is_cnn_param, modelVars['model'].parameters()), 'lr': mdlParams['learning_rate']}
                                                ], lr=mdlParams['learning_rate'])
    else:
        modelVars['optimizer'] = optim.Adam(modelVars['model'].parameters(), lr=mdlParams['learning_rate'])

    # Decay LR by a factor of 0.1 every 7 epochs
    modelVars['scheduler'] = lr_scheduler.StepLR(modelVars['optimizer'], step_size=mdlParams['lowerLRAfter'], gamma=1/np.float32(mdlParams['LRstep']))

    # Define softmax
    modelVars['softmax'] = nn.Softmax(dim=1)

    
    # Set up training
    # loading from checkpoint
    if load_old:
        # Resume from the most recent regular checkpoint, i.e. the file named
        # "checkpoint-<step>.pt" with the largest <step>. The separately kept
        # "checkpoint_best-<step>.pt" files do not match and are ignored.
        ckpt_steps = []
        for ckpt_file in glob(mdlParams['saveDir'] + '/*'):
            # Match on the file name only: digits or the word "best" appearing in the
            # directory path must not influence which checkpoint is picked.
            ckpt_match = re.fullmatch(r'checkpoint-(\d+)\.pt', os.path.basename(ckpt_file))
            if ckpt_match is not None:
                ckpt_steps.append(int(ckpt_match.group(1)))
        if not ckpt_steps:
            # Fail loudly instead of crashing on max() of nothing or loading checkpoint-0.pt
            raise FileNotFoundError("No checkpoint-<step>.pt file found in " + mdlParams['saveDir'])
        last_ckpt_step = max(ckpt_steps)
        # Create path with maximum global step found
        chkPath = mdlParams['saveDir'] + '/checkpoint-' + str(last_ckpt_step) + '.pt'
        print("Restoring: ",chkPath)
        # Load the stored training state (dict written by torch.save in the epoch loop)
        state = torch.load(chkPath)
        # Restore model weights and optimizer state
        modelVars['model'].load_state_dict(state['state_dict'])
        modelVars['optimizer'].load_state_dict(state['optimizer'])
        # Continue with the epoch after the one stored in the checkpoint
        start_epoch = state['epoch']+1
        # Fall back to defaults if the checkpoint does not store these keys
        mdlParams['valBest'] = state.get('valBest',1000)
        mdlParams['lastBestInd'] = state.get('lastBestInd',last_ckpt_step)
    else:
        start_epoch = 1
        mdlParams['lastBestInd'] = -1
        # Track metrics for saving best model (lower is better; 1000 means "nothing evaluated yet")
        mdlParams['valBest'] = 1000

    # Number of full training batches per epoch (reported for information)
    numBatchesTrain = int(math.floor(len(mdlParams['trainInd'])/mdlParams['batchSize']))
    print("Train batches",numBatchesTrain)

    # Run training: one iteration of this loop is one epoch. Every display_step epochs
    # (and after epoch 1) the model is evaluated on eval_set, the best model so far and
    # the latest regular checkpoint are written to saveDir, and progress is printed.
    start_time = time.time()
    print("Start training...")
    for step in range(start_epoch, mdlParams['training_steps']+1):
        # Start advancing the LR scheduler only from epoch lowerLRat - lowerLRAfter on.
        # NOTE(review): the scheduler state is not part of the saved checkpoints, so after
        # resuming, the StepLR counter restarts from a freshly constructed scheduler.
        if step >= mdlParams['lowerLRat']-mdlParams['lowerLRAfter']:
            modelVars['scheduler'].step()
        modelVars['model'].train()
        for inputs, labels, indices in modelVars['dataloader_trainInd']:
            # Move the batch to the GPU. With meta features the input is a two-element
            # container (presumably image batch and meta-feature batch — confirm).
            if mdlParams.get('meta_features',None) is not None:
                inputs[0] = inputs[0].cuda()
                inputs[1] = inputs[1].cuda()
            else:
                inputs = inputs.cuda()
            labels = labels.cuda()
            # zero the parameter gradients
            modelVars['optimizer'].zero_grad()
            # forward; gradients are always tracked here since this is the training phase
            with torch.set_grad_enabled(True):
                if mdlParams.get('aux_classifier',False):
                    # Main + auxiliary outputs; the labels are tiled multiCropTrain times
                    # to line up with outputs_aux.
                    outputs, outputs_aux = modelVars['model'](inputs)
                    loss1 = modelVars['criterion'](outputs, labels)
                    labels_aux = labels.repeat(mdlParams['multiCropTrain'])
                    loss2 = modelVars['criterion'](outputs_aux, labels_aux)
                    loss = loss1 + mdlParams['aux_classifier_loss_fac']*loss2
                else:
                    outputs = modelVars['model'](inputs)
                    loss = modelVars['criterion'](outputs, labels)
                # For these balancing modes the criterion output is weighted per example
                # (factor looked up by dataset index) and then averaged. This presumably
                # requires a criterion that returns unreduced per-example losses — confirm.
                if mdlParams['balance_classes'] in (6, 7, 8):
                    indices = indices.numpy()
                    loss = loss*torch.cuda.FloatTensor(mdlParams['loss_fac_per_example'][indices].astype(np.float32))
                    loss = torch.mean(loss)
                # backward + optimize
                loss.backward()
                modelVars['optimizer'].step()
        if step % mdlParams['display_step'] == 0 or step == 1:
            # Calculate evaluation metrics
            if mdlParams['classification']:
                # Switch to evaluation mode for the metric computation
                modelVars['model'].eval()
                # Get metrics
                loss, accuracy, sensitivity, specificity, conf_matrix, f1, auc, waccuracy, predictions, targets, _ = getErrClassification_mgpu(mdlParams, eval_set, modelVars)
                # Append to the metric history and rewrite the .mat file
                save_dict['loss'].append(loss)
                save_dict['acc'].append(accuracy)
                save_dict['wacc'].append(waccuracy)
                save_dict['auc'].append(auc)
                save_dict['sens'].append(sensitivity)
                save_dict['spec'].append(specificity)
                save_dict['f1'].append(f1)
                save_dict['step_num'].append(step)
                if os.path.isfile(mdlParams['saveDir'] + '/progression_'+eval_set+'.mat'):
                    os.remove(mdlParams['saveDir'] + '/progression_'+eval_set+'.mat')
                io.savemat(mdlParams['saveDir'] + '/progression_'+eval_set+'.mat',save_dict)
            # Model selection criterion: negated mean of the per-class accuracies (lower is better).
            # NOTE(review): waccuracy is only computed in the classification branch above.
            eval_metric = -np.mean(waccuracy)
            # Check if we have a new best value
            if eval_metric < mdlParams['valBest']:
                mdlParams['valBest'] = eval_metric
                if mdlParams['classification']:
                    allData['f1Best'][cv] = f1
                    allData['sensBest'][cv] = sensitivity
                    allData['specBest'][cv] = specificity
                    allData['accBest'][cv] = accuracy
                    allData['waccBest'][cv] = waccuracy
                    allData['aucBest'][cv] = auc
                oldBestInd = mdlParams['lastBestInd']
                mdlParams['lastBestInd'] = step
                allData['convergeTime'][cv] = step
                # Save best predictions
                allData['bestPred'][cv] = predictions
                allData['targets'][cv] = targets
                # Persist the cross-validation results collected so far
                with open(mdlParams['saveDirBase'] + '/CV.pkl', 'wb') as f:
                    pickle.dump(allData, f, pickle.HIGHEST_PROTOCOL)
                # Delete the previously best model
                if os.path.isfile(mdlParams['saveDir'] + '/checkpoint_best-' + str(oldBestInd) + '.pt'):
                    os.remove(mdlParams['saveDir'] + '/checkpoint_best-' + str(oldBestInd) + '.pt')
                # Save currently best model
                state = {'epoch': step, 'valBest': mdlParams['valBest'], 'lastBestInd': mdlParams['lastBestInd'], 'state_dict': modelVars['model'].state_dict(),'optimizer': modelVars['optimizer'].state_dict()}
                torch.save(state, mdlParams['saveDir'] + '/checkpoint_best-' + str(step) + '.pt')

            # Whether or not it is the best, save the current model as the latest regular checkpoint
            state = {'epoch': step, 'valBest': mdlParams['valBest'], 'lastBestInd': mdlParams['lastBestInd'], 'state_dict': modelVars['model'].state_dict(),'optimizer': modelVars['optimizer'].state_dict()}
            torch.save(state, mdlParams['saveDir'] + '/checkpoint-' + str(step) + '.pt')
            # Delete the previous regular checkpoint (checkpoint-1 is written after epoch 1,
            # later ones every display_step epochs)
            if step == mdlParams['display_step']:
                lastInd = 1
            else:
                lastInd = step-mdlParams['display_step']
            # Never remove the checkpoint that was just written (happens when display_step == 1)
            if lastInd != step and os.path.isfile(mdlParams['saveDir'] + '/checkpoint-' + str(lastInd) + '.pt'):
                os.remove(mdlParams['saveDir'] + '/checkpoint-' + str(lastInd) + '.pt')
            # Duration so far
            duration = time.time() - start_time
            # Print
            if mdlParams['classification']:
                print("\n")
                # NOTE(review): sys.argv[2] is the config name when run as a script; inside a
                # notebook kernel it is whatever the kernel launcher passed — confirm.
                print("Config:",sys.argv[2])
                print('Fold: %d Epoch: %d/%d (%d h %d m %d s)' % (cv,step,mdlParams['training_steps'], int(duration/3600), int(np.mod(duration,3600)/60), int(np.mod(np.mod(duration,3600),60))) + time.strftime("%d.%m.-%H:%M:%S", time.localtime()))
                print("Loss on ",eval_set,"set: ",loss," Accuracy: ",accuracy," F1: ",f1," (best WACC: ",-mdlParams['valBest']," at Epoch ",mdlParams['lastBestInd'],")")
                print("Auc",auc,"Mean AUC",np.mean(auc))
                print("Per Class Acc",waccuracy,"Weighted Accuracy",np.mean(waccuracy))
                print("Sensitivity: ",sensitivity,"Specificity",specificity)
                print("Confusion Matrix")
                print(conf_matrix)
                # Potentially peek at test error
                if mdlParams['peak_at_testerr']:
                    loss, accuracy, sensitivity, specificity, _, f1, _, _, _, _, _ = getErrClassification_mgpu(mdlParams, 'testInd', modelVars)
                    print("Test loss: ",loss," Accuracy: ",accuracy," F1: ",f1)
                    print("Sensitivity: ",sensitivity,"Specificity",specificity)
                # Potentially print train err
                if mdlParams['print_trainerr'] and 'train' not in eval_set:
                    loss, accuracy, sensitivity, specificity, conf_matrix, f1, auc, waccuracy, predictions, targets, _ = getErrClassification_mgpu(mdlParams, 'trainInd', modelVars)
                    # Append to the training metric history and rewrite the .mat file
                    save_dict_train['loss'].append(loss)
                    save_dict_train['acc'].append(accuracy)
                    save_dict_train['wacc'].append(waccuracy)
                    save_dict_train['auc'].append(auc)
                    save_dict_train['sens'].append(sensitivity)
                    save_dict_train['spec'].append(specificity)
                    save_dict_train['f1'].append(f1)
                    save_dict_train['step_num'].append(step)
                    if os.path.isfile(mdlParams['saveDir'] + '/progression_trainInd.mat'):
                        os.remove(mdlParams['saveDir'] + '/progression_trainInd.mat')
                    # Only `io` (from `from scipy import io`) is bound in this file, not `scipy`
                    io.savemat(mdlParams['saveDir'] + '/progression_trainInd.mat',save_dict_train)
                    print("Train loss: ",loss," Accuracy: ",accuracy," F1: ",f1)
                    print("Sensitivity: ",sensitivity,"Specificity",specificity)
    # Drop every entry of modelVars (model, optimizer, scheduler, loaders, ...) for this fold
    modelVars.clear()
    # Report the best validation metrics recorded for this fold. Each row is
    # (caption, allData key, optional reduction applied before printing); rows are
    # evaluated and printed one at a time, in order.
    fold_report = (
        ("Best F1:", 'f1Best', None),
        ("Best Sens:", 'sensBest', None),
        ("Best Spec:", 'specBest', None),
        ("Best Acc:", 'accBest', None),
        ("Best Per Class Accuracy:", 'waccBest', None),
        ("Best Weighted Acc:", 'waccBest', np.mean),
        ("Best AUC:", 'aucBest', None),
        ("Best Mean AUC:", 'aucBest', np.mean),
        ("Convergence Steps:", 'convergeTime', None),
    )
    for caption, metric_key, reduce_fn in fold_report:
        metric_value = allData[metric_key][cv]
        if reduce_fn is not None:
            metric_value = reduce_fn(metric_value)
        print(caption, metric_value)