In [None]:
import os
import time
from sklearn.metrics import f1_score, precision_score, recall_score
# , classification_report, confusion_matrix
from models.net import WideOrthoResNet, OrthoVGG
from layers.blocks import BasicBlock, HadamardBlock, HarmonicBlock, SlantBlock
from loader import LoaderSmall
import pandas as pd
import warnings
import torch
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
from torch import nn
from torch.nn import functional as F
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
import gc
from typing import List  # pylint: ignore
from tqdm import tqdm_notebook as tqdm
# import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from albumentations.augmentations import transforms as T
import random
import argparse
import matplotlib
from matplotlib import pyplot as plt

# Use the non-interactive Agg backend so figures render off-screen.
matplotlib.use(backend='AGG')

# Suppress every warning category for the remainder of the session.
warnings.simplefilter('ignore')


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    ``PYTHONHASHSEED`` is exported as well; it only influences Python
    processes started afterwards, because the running interpreter's hash
    seed is fixed at start-up. ``torch.cuda.manual_seed`` seeds the current
    GPU only. ``cudnn.benchmark`` is left untouched.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed=42)

# Experiment configuration.
weights = None    # forwarded to LoaderSmall(weights=...); presumably sample/class weights -- TODO confirm
dataset = 'skin'  # not in the CIFAR/ImageNet/ISIC2019 list, so the HAM10000 branch below runs

def _balanced_sampler(split_labels):
    """Build a WeightedRandomSampler that draws every class equally often.

    Each sample gets weight 1 / (count of its class in ``split_labels``), and
    ``len(split_labels)`` indices are drawn (WeightedRandomSampler's default
    is sampling with replacement).

    Args:
        split_labels: 1-D sequence/array of non-negative integer class ids.

    Returns:
        torch.utils.data.WeightedRandomSampler over ``range(len(split_labels))``.
    """
    target = torch.as_tensor(split_labels, dtype=torch.long)
    # bincount is indexed by the label value itself, so a class missing from
    # this split cannot shift the other classes' weights (a torch.unique-based
    # table would, and would raise IndexError for the largest label).
    class_sample_count = torch.bincount(target)
    weight = 1. / class_sample_count.float()  # inf for absent classes; never looked up
    samples_weight = weight[target]
    return torch.utils.data.WeightedRandomSampler(samples_weight,
                                                  len(samples_weight))


if dataset not in ['cifar10', 'cifar100', 'imagenet', 'isic2019']:
    # HAM10000 pipeline: metadata -> train/val split -> class-balanced samplers
    # -> LoaderSmall datasets -> DataLoaders -> model.
    # NOTE(review): these mean/std values look like the common CIFAR-10 channel
    # statistics; confirm they are appropriate for HAM10000 images.
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    # Disabled augmentation steps, kept for reference:
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
    #                                       (4, 4, 4, 4),
    #                                       mode='reflect').squeeze()),
    #     transforms.ToPILImage(),
    #     transforms.RandomCrop(64, padding=4),
    #     transforms.RandomHorizontalFlip(),
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    print("Loading metadata...")
    metadata = pd.read_csv('metadata/HAM10000_metadata.csv')
    # NOTE(review): one LabelEncoder is re-fitted per column, so afterwards
    # `enc` only holds the lesion_id classes (dx names cannot be decoded with it).
    enc = LabelEncoder()
    metadata['dx'] = enc.fit_transform(metadata['dx'])
    metadata['dx_type'] = enc.fit_transform(metadata['dx_type'])
    metadata['sex'] = enc.fit_transform(metadata['sex'])
    metadata['localization'] = enc.fit_transform(metadata['localization'])
    metadata['lesion_id'] = enc.fit_transform(metadata['lesion_id'])
    labels = metadata.dx.values
    imageid_path_dict = {x: f'./HAM10000_small/{x}.jpg' for x in metadata.image_id}

    print("Loading Images...")
    # NOTE(review): the split is per image. If a lesion_id has several images,
    # the same lesion can land in both splits; consider a grouped split
    # (e.g. sklearn GroupShuffleSplit on metadata.lesion_id) -- TODO confirm.
    train_names, val_names, \
    train_labels, val_labels = train_test_split(
        np.asarray(list(imageid_path_dict.keys())),
        labels,
        test_size=0.15)
    train_sampler = _balanced_sampler(train_labels)
    # NOTE(review): validation is also resampled with replacement, so reported
    # accuracy is on a class-balanced resample, not on each val image once.
    test_sampler = _balanced_sampler(val_labels)
    trainset = LoaderSmall(imageid_path_dict,
                           labels = train_labels,
                           names = train_names,
                           weights=weights,
                           weighting=False,
                           transform=transform_train,
                           color_space=None)
    testset = LoaderSmall(imageid_path_dict,
                          labels=val_labels,
                          names=val_names,
                          weighting=False,
                          transform=transform_test,
                          color_space=None)

    bs=32

    # shuffle must stay False because an explicit sampler is supplied.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs,
                                              sampler=train_sampler,
                                              shuffle=False,
                                              num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=bs,
                                             shuffle=False,
                                             sampler=test_sampler,
                                             num_workers=2)
    arch='wrn'
    if arch == 'wrn':
        net = WideOrthoResNet(in_channels=3,
                              block=HadamardBlock,
                              alpha_root=None,
                              kernel_size=4,
                              depth=22,
                              widen_factor=3,
                              num_classes=7,
                              lmbda=None,
                              diag=False)

gc.collect()
# NOTE(review): `net` is only defined when the dataset/arch branches above ran.
print("Number of trainable parameters:", net._num_parameters()[0])

# Move the model to the GPU, then replicate it over every visible device.
net = torch.nn.DataParallel(
    net.cuda(), device_ids=list(range(torch.cuda.device_count())))

# Optimisation hyper-parameters and run state.
base_lr = 0.1
start_epoch = 0
best_acc = 0

# Per-epoch histories appended to by train() and test().
train_losses, test_losses = [], []  # type: List[float], List[float]
train_accs, test_accs = [], []  # type: List[float], List[float]
train_error, test_error = [], []

# Unweighted loss: class balance comes from the WeightedRandomSampler.
# Weighted alternative: nn.CrossEntropyLoss(weight=torch.from_numpy(weights).cuda())
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(
    net.parameters(),
    lr=base_lr,
    momentum=0.9,
    dampening=0,
    weight_decay=5e-4,
    nesterov=True,
)

gc.collect()


def get_lr(optimizer=optimizer):
    """Return the learning rate of ``optimizer``'s first parameter group.

    The default argument binds the module-level ``optimizer`` at definition
    time. Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']


def one_hot_enc(output, target, num_classes=7):
    """One-hot encode integer class labels.

    Args:
        output: Unused; kept for call-site compatibility.
        target: Tensor of integer class ids in ``[0, num_classes)``; any
            shape, flattened to one row per element.
        num_classes: Width of the encoding.

    Returns:
        float32 tensor of shape ``(target.numel(), num_classes)`` on the same
        device as ``target`` (previously hard-coded to ``'cuda'``, which made
        CPU targets fail inside ``scatter_``).
    """
    # reshape (unlike view) also accepts non-contiguous input; scatter_ needs int64 indices.
    labels = target.reshape(-1, 1).long()
    labels_one_hot = torch.zeros(labels.size(0), num_classes,
                                 dtype=torch.float32, device=target.device)
    labels_one_hot.scatter_(1, labels, 1.0)
    return labels_one_hot


# Training (https://github.com/kuangliu/pytorch-cifar/blob/master/main.py)
def train(epoch, stop_after=1):
    """Train ``net`` for one epoch on ``trainloader`` and record the metrics.

    Args:
        epoch: Current epoch number (only printed).
        stop_after: Index of the last batch to train on; ``None`` trains on
            the whole loader. With the default of 1, batches 0 and 1 are used.

    Side effects:
        Updates ``net`` via ``optimizer`` and appends to ``train_accs``,
        ``train_error`` and ``train_losses`` (the summed batch loss).

    Raises:
        RuntimeError: if no batch was processed (empty loader).
    """
    # Ask the optimizer directly instead of reading a global `lr` that only
    # exists once the training-loop cell has assigned it.
    lr = get_lr(optimizer)
    print(f"\nEpoch: {epoch}, learning rate = {lr:1.1e};")
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    n_batches = 0
    last_pred = last_truth = None
    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):
        if stop_after is not None and batch_idx > stop_after:
            # Stop rather than loading and discarding the remaining batches.
            break
        inputs, targets = inputs.to('cuda'), targets.to('cuda')
        outputs = net(inputs)

        optimizer.zero_grad()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        n_batches += 1
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        # Keep the last *processed* batch for the debug print below.
        last_pred, last_truth = predicted, targets
    if total == 0:
        raise RuntimeError("train(): no batches were processed")
    print("Pred: ", last_pred)
    print("Truth: ", last_truth)
    # Mean over the batches actually processed, not len(trainloader).
    print(
        'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss / n_batches, 100. * correct / total, correct, total))
    train_accs.append(100. * correct / total)
    train_error.append(100. - 100. * correct / total)
    train_losses.append(train_loss)


def test(epoch, stop_after=1):
    """Evaluate ``net`` on ``testloader`` and record the epoch's metrics.

    Args:
        epoch: Current epoch number; only referenced by the disabled
            checkpoint code below.
        stop_after: Index of the last batch to evaluate; ``None`` evaluates
            the whole loader. With the default of 1, batches 0 and 1 are used.

    Side effects:
        Appends to ``test_accs``, ``test_error`` and ``test_losses`` (summed
        batch loss) and raises the global ``best_acc`` when beaten.

    Raises:
        RuntimeError: if no batch was evaluated (empty loader).
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    n_batches = 0
    last_pred = last_truth = None
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):
            if stop_after is not None and batch_idx > stop_after:
                # Stop rather than loading and discarding the remaining batches.
                break
            inputs, targets = inputs.to('cuda'), targets.to('cuda')
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)

            n_batches += 1
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # Keep the last *evaluated* batch so Pred and Truth below match.
            last_pred, last_truth = predicted, targets
    if total == 0:
        raise RuntimeError("test(): no batches were evaluated")
    print("Pred: ", last_pred)
    print(f"Truth: {last_truth}\n")
    # Mean over the batches actually evaluated, not len(testloader).
    print('Loss: %.3f | Acc: %.3f%% (%d/%d) | Error: %.3f%%' % (
        test_loss / n_batches, 100. * correct / total, correct, total, 100. * (1. - correct / total)))
    test_accs.append(100. * correct / total)
    test_error.append(100. - 100. * correct / total)
    # Record the loss value itself (the list used to be appended to itself).
    test_losses.append(test_loss)
    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        # Checkpointing is disabled: the string below is a no-op, and it refers
        # to an `args` namespace that this notebook does not define.
        '''print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'lr': get_lr(optimizer),
            # 'optimizer': optimizer.state_dict()

        }
        if not os.path.isdir(f'checkpoint_{args.dataset}_{time.strftime("%Y_%m_%d")}'):
            os.mkdir(f'checkpoint_{args.dataset}_{time.strftime("%Y_%m_%d")}')
        a = ''
        if args.alpha_root is not None:
            a = f'alpha_{args.alpha_root}_'

        torch.save(state,
                   f'./checkpoint_{args.dataset}_{time.strftime("%Y_%m_%d")}' +
                   f'/ckpt_{args.block}_{args.arch}_{args.depth}_{args.widen}_{args.kernel_size}x{args.kernel_size}_' +
                   a + f'_{acc:.2f}.t7')'''
        best_acc = acc


def adjust_learning_rate(optimizer,
                         epoch,
                         update_list=(25, 75),
                         factor=10.,
                         lim=1.):
    """Step-schedule the learning rate in place.

    When ``epoch`` is in ``update_list``, every parameter group's LR becomes
    ``max(lr * factor, lim)``, so ``lim`` acts as a floor. On other epochs
    nothing changes. Note that the defaults (factor=10, lim=1) *raise* the LR
    to at least 1.0; decay callers pass ``factor < 1`` explicitly.
    Other milestone sets tried: [60, 120, 160], [2, 5, 8, 11, 14, 17, 20].
    """
    if epoch not in update_list:
        return
    for group in optimizer.param_groups:
        scaled = group['lr'] * factor
        group['lr'] = max(scaled, lim)


def save_state(model, best_acc, path='harmonic_network.tar'):
    """Save ``model``'s weights and ``best_acc`` to ``path`` via ``torch.save``.

    A leading ``module.`` prefix (as added by ``torch.nn.DataParallel``) is
    stripped from each state-dict key so the file loads into the unwrapped
    model. Only that prefix is removed: a plain ``key.replace('module.', '')``
    would also mangle inner names such as ``submodule.weight``.

    Args:
        model: Module (possibly DataParallel-wrapped) whose weights are saved.
        best_acc: Accuracy value stored alongside the weights.
        path: Destination file; defaults to the previous hard-coded name.
    """
    print('\n==> Saving model ...\n')
    prefix = 'module.'
    state_dict = model.state_dict()
    # Rename in place so the state dict keeps its own type and attributes.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    state = {'best_acc': best_acc,
             'state_dict': state_dict}
    torch.save(state, path)




In [None]:
# Release unoccupied cached GPU memory and run Python's garbage collector.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Inspect the first validation item; presumably an (image_tensor, label) pair,
# as unpacked by the training/eval loops -- TODO confirm against LoaderSmall.
testset[0]

In [None]:
# Schedule: multiply the LR by LR_DECAY at each milestone, floored at MIN_LR.
NUM_EPOCHS = 200
LR_MILESTONES = [60, 120, 160]
LR_DECAY = 0.2
MIN_LR = 1e-6
# None evaluates the whole validation loader every epoch; stop_after=0 would
# score only the first batch, making best_acc a single-batch statistic.
EVAL_STOP_AFTER = None

for epoch in range(start_epoch, NUM_EPOCHS):
    adjust_learning_rate(optimizer, epoch, LR_MILESTONES, factor=LR_DECAY, lim=MIN_LR)
    lr = get_lr()  # current LR, kept in a global for logging
    train(epoch, None)
    test(epoch, EVAL_STOP_AFTER)
    print(f"Best Accuracy: {best_acc:.3f}")
    gc.collect()
    torch.cuda.empty_cache()


In [None]:
# Samples per class present in the training split (classes in sorted order).
# torch.as_tensor accepts trainset.labels whether LoaderSmall stores a numpy
# array or a tensor; a new name avoids clobbering the setup's globals.
train_class_counts = torch.unique(
    torch.as_tensor(trainset.labels), sorted=True, return_counts=True)[1]
train_class_counts

In [None]:
# Samples per class present in the validation split (classes in sorted order).
# torch.as_tensor accepts testset.labels whether LoaderSmall stores a numpy
# array or a tensor; a new name avoids clobbering the setup's globals.
val_class_counts = torch.unique(
    torch.as_tensor(testset.labels), sorted=True, return_counts=True)[1]
val_class_counts

In [None]:
# Re-load and label-encode the HAM10000 metadata (same steps as the setup
# cell). The single encoder is re-fitted per column, so `enc` ends up fitted
# on lesion_id only.
metadata = pd.read_csv('metadata/HAM10000_metadata.csv')
enc = LabelEncoder()
for column in ('dx', 'dx_type', 'sex', 'localization', 'lesion_id'):
    metadata[column] = enc.fit_transform(metadata[column])
labels = metadata.dx.values
imageid_path_dict = {
    image_id: f'./HAM10000_small/{image_id}.jpg'
    for image_id in metadata.image_id
}

In [None]:
# Number of images per (label-encoded) diagnosis class.
metadata['dx'].value_counts()

In [None]:
# First rows of the label-encoded dx_type column.
metadata.dx_type.head()

In [None]:
# Class histogram (present classes, sorted) of one batch from the
# class-balanced validation loader. A dedicated name keeps the global
# `labels` array from the setup cell intact.
first_val_batch = next(iter(testloader))
batch_labels = first_val_batch[1]
torch.unique(batch_labels, sorted=True, return_counts=True)[1]