Transfer Learning

Data:

- https://www.kaggle.com/datasets/shonenkov/isic2018
- https://www.kaggle.com/datasets/wenewone/isic2018task3validation
- https://www.kaggle.com/datasets/wenewone/isic-groundtruth-for-classification

In [None]:
import torch
import numpy as np
import pandas as pd
import os
from PIL import Image

import warnings
warnings.filterwarnings('ignore')

In [None]:
import time, logging

class Logger:
    """Configure the root logger to write to a timestamped file and, optionally, the console.

    The log file is created at ``<cwd>/logs/<title>/<yy-mm-dd-HHMM>.log``.
    The configured root logger is exposed as ``self.logger``.
    """

    # attribute used to tag handlers installed by this class, so that a later
    # instance can find and remove them
    _HANDLER_TAG = '_installed_by_logger'

    def __init__(self, mode='w', verbose=True, title=""):
        """
        Args:
            mode (str): file mode passed to ``logging.FileHandler``.
            verbose (bool): if True, also echo messages to the console.
            title (str): sub-directory of ``<cwd>/logs`` that receives the log file.
        """
        # create logger (root logger, level INFO)
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.INFO)

        # The root logger is shared process-wide: constructing a second Logger
        # (e.g. re-running the notebook cell) would otherwise stack handlers and
        # emit every message several times. Drop the ones we installed earlier.
        for handler in list(self.logger.handlers):
            if getattr(handler, self._HANDLER_TAG, False):
                self.logger.removeHandler(handler)
                handler.close()

        formatter = logging.Formatter("[%(asctime)s] %(message)s")
        console_formatter = logging.Formatter("%(message)s")

        # create file handler; the log file is named after the start time
        start_time = time.strftime('%y-%m-%d-%H%M', time.localtime(time.time()))
        log_path = os.path.join(os.getcwd(), 'logs', title)
        os.makedirs(log_path, exist_ok=True)
        log_name = os.path.join(log_path, start_time + '.log')

        fh = logging.FileHandler(log_name, mode=mode)
        fh.setLevel(logging.INFO)  # level for records written to the log file
        fh.setFormatter(formatter)
        setattr(fh, self._HANDLER_TAG, True)
        self.logger.addHandler(fh)

        if verbose:
            # create console handler (messages only, no timestamp)
            ch = logging.StreamHandler()
            ch.setLevel(logging.INFO)
            ch.setFormatter(console_formatter)
            setattr(ch, self._HANDLER_TAG, True)
            self.logger.addHandler(ch)

In [None]:
# log file
# `model_id` and `title` identify this run; `model_id` is reused later to
# build the saved model's filename.
model_id = 1
title = 'DenseNet 201'
# log records go to <cwd>/logs/densenet/<timestamp>.log and to the console
log = Logger(verbose=True, title='densenet')
log.logger.info("Transfer Learning | {} - {}".format(title, model_id))

## Data Preparation

In [None]:
from torch.utils import data

class Annotation(object):
    """Annotations of ISIC 2018 read from a one-hot ground-truth csv.

    The csv's first column holds the image name; every remaining column is
    treated as one category. After construction the one-hot columns are
    collapsed into a single integer ``label`` column.

    Attributes:
        df(pd.DataFrame): df.columns=['image_id', 'label']
        categories(list): dermatological types (csv column names after the first)
        class_dict(dict): class name -> index
        label_dict(dict): index -> class name
        class_num(int): the number of classes

    Usages:
        count_samples(): get numbers of samples in each class
        to_names(nums): map label indices back to category names
    """

    def __init__(self, ann_file: str) -> None:
        """
        Args:
            ann_file (str): csv file path
        """
        self.df = pd.read_csv(ann_file, header=0)
        # every column except the first (image name) is a category
        self.categories = list(self.df.columns)[1:]
        self.class_num = len(self.categories)
        self.class_dict, self.label_dict = self._make_dicts()
        self.df = self._relabel()

    def _make_dicts(self):
        """Build the name -> index and index -> name dicts from the category names."""
        class_dict = {name: i for i, name in enumerate(self.categories)}
        label_dict = {i: name for i, name in enumerate(self.categories)}
        return class_dict, label_dict

    def _relabel(self) -> pd.DataFrame:
        """Replace the one-hot category columns with a single integer ``label`` column."""
        self.df.rename(columns={'image': 'image_id'}, inplace=True)
        # name of the numeric column holding the row-wise maximum, i.e. the active class
        winners = self.df.select_dtypes(['number']).idxmax(axis=1)
        self.df['label'] = winners.apply(lambda name: self.class_dict[name])
        self.df.drop(columns=self.categories, inplace=True)
        return self.df

    def count_samples(self) -> list:
        """Count the samples of each class.

        Returns:
            list of ``class_num`` ints, where entry ``i`` is the number of rows
            labelled ``i``. Classes without any sample are reported as 0
            instead of raising or shifting the remaining entries.
        """
        counts = self.df['label'].value_counts()
        return [int(counts.get(i, 0)) for i in range(self.class_num)]

    def to_names(self, nums):
        """Convert a group of indices to string names.

        Args:
            nums(torch.Tensor): an iterable of number labels; each element is
                converted with ``int()`` before lookup

        Return:
            a list of dermatological names
        """
        return [self.label_dict[int(num)] for num in nums]


class Data(data.Dataset):
    """Image dataset backed by an annotation DataFrame.

    Args:
        annotations (pd.DataFrame): table with an ``image_id`` and a ``label`` column.
        img_dir (str): directory holding ``<image_id>.jpg`` files.
        transform (callable, optional): applied to the loaded PIL image.
        target_transform (callable, optional): applied to the label.
    """

    def __init__(self, annotations, img_dir, transform=None, target_transform=None):
        self.img_labels = annotations
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx: int):
        """Return the ``(image, target)`` pair stored at row position ``idx``."""
        # Both columns are read positionally. The previous label-based lookup of
        # image_id could pick a different row than the label (or raise) whenever
        # the DataFrame index is not 0..n-1, e.g. after filtering or shuffling.
        image_id = self.img_labels['image_id'].iloc[idx]
        target = self.img_labels['label'].iloc[idx]

        img_path = os.path.join(self.img_dir, image_id + '.jpg')
        image = Image.open(img_path)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        return image, target

In [None]:
# path of img data
# NOTE(review): pth_test points at the same folder as pth_train; presumably the
# test split was carved out of the training images (see the "Splited" ground
# truth files below) — confirm.
pth_train = '../input/isic2018/ISIC2018_Task3_Training_Input/ISIC2018_Task3_Training_Input'
pth_test = '../input/isic2018/ISIC2018_Task3_Training_Input/ISIC2018_Task3_Training_Input'
pth_valid = '../input/isic2018task3validation/ISIC2018_Task3_Validation_Input'


# ground-truth csv files for each split
ann_train = '../input/isic-groundtruth-for-classification/ISIC2018_Splited_Training_GroundTruth.csv'
ann_test = '../input/isic-groundtruth-for-classification/ISIC2018_Splited_Test_GroundTruth.csv'
ann_valid = '../input/isic-groundtruth-for-classification/ISIC2018_Task3_Validation_GroundTruth.csv'

# the ann_* names are rebound from csv paths (str) to parsed Annotation objects
ann_train = Annotation(ann_train)
ann_valid = Annotation(ann_valid)
ann_test = Annotation(ann_test)

In [None]:
from torchvision import transforms

# standard transform
# deterministic pipeline used for validation/test: resize the shorter side to
# 224, take the central 224x224 crop, convert to tensor and normalize per channel
# (the mean/std values are presumably the ImageNet statistics expected by the
# pretrained backbone — confirm)
transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],
                                                     [0.229, 0.224, 0.225])
                                ])

# augmentation transform
# random pipeline used for training: flips, rotation of up to 30 degrees and a
# random resized crop to 224, followed by the same tensor conversion/normalization
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.RandomVerticalFlip(p=0.5),
                                      transforms.RandomRotation(30),
                                      transforms.RandomResizedCrop(
                                          224, scale=(0.4, 1), ratio=(3/4, 4/3)),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])
                                      ])

In [None]:
# create train dataset
# training uses the augmentation pipeline; drop_last=True discards the final
# incomplete batch, so up to batch_size - 1 samples are skipped in each epoch
train_data = Data(ann_train.df, pth_train, transform=transform_train)
train_loader = data.DataLoader(train_data, batch_size=50, shuffle=True, drop_last=True, num_workers=8)

# create validation dataset
# evaluation uses the deterministic pipeline and keeps the sample order
valid_data = Data(ann_valid.df, pth_valid, transform=transform)
valid_loader = data.DataLoader(valid_data, batch_size=200, shuffle=False, num_workers=4)

# create test dataset
# only the dataset is built here; its loader is created inside performance_evaluation()
test_data = Data(ann_test.df, pth_test, transform=transform)
# test_loader = data.DataLoader(test_data, batch_size=200, shuffle=False)

## Network Design

In [None]:
import torch.nn as nn

# GPU setting
# use CUDA when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
def load_model(device, path='.', name='model.pkl'):
    """Load a whole serialized model from ``<path>/model/<name>``.

    Args:
        device: passed to ``torch.load`` as ``map_location``.
        path (str): parent directory of the ``model`` folder.
        name (str): file name of the serialized model.

    Returns:
        the object stored in the file.

    Raises:
        AssertionError: if the file does not exist.
    """
    model_file = os.path.join(path, 'model', name)
    assert os.path.exists(model_file), "Model file doesn't exist!"

    restored = torch.load(model_file, map_location=device)
    print('Load {} on {} successfully.'.format(name, device))
    return restored
    

def save_model(model, path='.', name='model.pkl'):
    """Serialize the whole model object to ``<path>/model/<name>``.

    The ``model`` directory is created when it does not exist yet.

    Args:
        model: object handed to ``torch.save``.
        path (str): parent directory of the ``model`` folder.
        name (str): file name to write.
    """
    target_dir = os.path.join(path, 'model')
    target_file = os.path.join(target_dir, name)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    torch.save(model, target_file)
    print('Model has been saved to {}'.format(target_file))


def save_state_dict(model, path='.', name='state_dict.pth'):
    """Write ``model.state_dict()`` to ``<path>/model/temp/<name>``.

    The ``model/temp`` directory is created when it does not exist yet.

    Args:
        model: object providing ``state_dict()``.
        path (str): parent directory of the ``model/temp`` folder.
        name (str): file name to write.
    """
    target_dir = os.path.join(path, 'model', 'temp')
    target_file = os.path.join(target_dir, name)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    torch.save(model.state_dict(), target_file)
    print('State dict has been saved to {}'.format(target_file))
    
    
def load_state_dict(model, device, path='.', name='state_dict.pth'):
    """Load model parameters from ``<path>/model/temp/<name>`` into ``model``.

    Args:
        model: object providing ``load_state_dict()``.
        device: passed to ``torch.load`` as ``map_location``.
        path (str): parent directory of the ``model/temp`` folder.
        name (str): file name of the stored state dict.

    Returns:
        the same ``model`` object, with the loaded parameters.

    Raises:
        AssertionError: if the file does not exist.
    """
    state_file = os.path.join(path, 'model', 'temp', name)
    assert os.path.exists(state_file), "State dict file doesn't exist!"

    state = torch.load(state_file, map_location=device)
    model.load_state_dict(state)
    return model

In [None]:
from torchvision import models
from collections import OrderedDict
import torch.nn as nn

# DenseNet-201 backbone with pretrained weights
model = models.densenet201(pretrained=True)

# freeze layers
# every pretrained parameter is excluded from gradient computation; the new
# classifier created below is not affected because it is built afterwards
for param in model.parameters():
    param.requires_grad = False


# replacement head: 1920 -> 256 -> class_num
# (1920 is presumably the backbone's feature width — confirm against
# model.classifier.in_features)
classifier = nn.Sequential(OrderedDict([
    ('fc0', nn.Linear(1920, 256)),
    ('norm0', nn.BatchNorm1d(256)),
    ('relu0', nn.ReLU(inplace=True)),
    ('fc1', nn.Linear(256, ann_train.class_num))
]))

model.classifier = classifier

# move the model to the selected device (as the last expression of the cell it
# also displays the model in a notebook)
model.to(device)

## Train
### Criterion

Focal Loss
$$
{\text{FL}(p_{t}) = - \alpha_t (1 - p_{t})^\gamma \log\left(p_{t}\right)} 
$$

In [None]:
import torch.nn.functional as F
from sklearn import metrics
from sklearn.preprocessing import label_binarize

class FocalLoss(nn.Module):
    """Multi-class focal loss: ``FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)``."""

    def __init__(self, alpha: list, gamma=2, num_classes: int = 7, reduction='mean'):
        """ Focal Loss

        Args:
            alpha (list): class weight, one entry per class
            gamma (int/float): focusing parameter that down-weights easy samples
            num_classes (int): number of classes
            reduction (string): 'mean', 'sum', 'none'
        """
        super(FocalLoss, self).__init__()
        assert len(alpha) == num_classes, "alpha size doesn't match with class number"
        self.alpha = torch.Tensor(alpha)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, preds, labels):
        """Compute the focal loss of raw scores ``preds`` against integer ``labels``.

        Shape:
            preds: [B, N, C] or [B, C]
            labels: [B, N] or [B]

        Returns:
            a scalar tensor for reduction 'mean' / 'sum'; otherwise the
            per-sample losses with shape [1, M], M being the number of labels.
        """
        logits = preds.view(-1, preds.size(-1))
        targets = labels.view(-1, 1)

        # log_softmax is numerically safer than log(softmax(x)), which yields
        # -inf as soon as a probability underflows to 0
        log_pt = F.log_softmax(logits, dim=1).gather(1, targets)  # [M, 1]
        pt = log_pt.exp()

        # Per-sample class weight. It is kept in a local on purpose: the class
        # weights in self.alpha must survive the call, otherwise the next batch
        # would index into the previous batch's per-sample weights.
        alpha_t = self.alpha.to(logits.device).gather(0, targets.view(-1))  # [M]

        loss = -torch.mul(torch.pow(1 - pt, self.gamma), log_pt)  # (1 - p_t)**gamma * -log(p_t)
        loss = torch.mul(alpha_t, loss.t())  # [1, M]

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

    def __str__(self) -> str:
        return "Criterion: Focal Loss\n α = {}\n γ = {}".format(self.alpha, self.gamma)



class Evaluation:
    """Collect predictions of a model on a data loader and derive metrics from them.

    Typical flows:
        * ``get_probs()`` followed by ``complete_scores()`` computes every metric
          (report, accuracies, F1, ROC/AUC, confusion matrix);
        * ``accuracies()`` only needs hard predictions and computes accuracy,
          balanced accuracy and per-class F1.

    Results are stored on the instance (``acc``, ``b_acc``, ``f1_score``,
    ``roc_auc``, ``report``, ``CMatrix``, ...).
    """

    def __init__(self, device, categories, best_score=0) -> None:
        """
        Args:
            device: device the input batches are moved to before the forward pass
            categories (list): class names, ordered by class index
            best_score: initial value of ``best_score`` (updated by the training loop)
        """
        self.device = device
        self.categories = categories
        self.class_num = len(categories)
        self.best_score = best_score

    @torch.no_grad()
    def get_probs(self, model, data_loader):
        """ get predicted probabilities

        Returns:
            y: one-hot labels
            prob: predicted probabilities, one row per sample
        """
        model.eval()
        batch_labels = []
        batch_probs = []

        for x, y in data_loader:
            x = x.to(self.device)
            p = torch.softmax(model(x), dim=-1)
            batch_probs.append(p.cpu().numpy())
            batch_labels.append(y.cpu().numpy())

        # concatenate once at the end instead of growing an array batch by batch
        if batch_probs:
            self.prob = np.concatenate(batch_probs, axis=0).astype(np.float64)
            label = np.concatenate(batch_labels, axis=0)
        else:
            self.prob = np.empty((0, self.class_num), dtype=np.float64)
            label = np.empty((0,), dtype=np.int64)

        self.y = label_binarize(label, classes=list(range(self.class_num)))

        return self.y, self.prob

    @torch.no_grad()
    def make_predictions(self, model, data_loader):
        """ make predictions on datasets

        Returns:
            arrays of labels and predictions (class indices)
        """
        model.eval()
        batch_labels = []
        batch_preds = []

        for x, y in data_loader:
            x = x.to(self.device)
            yhat = model(x).argmax(dim=1)
            batch_labels.append(y.cpu().numpy())
            batch_preds.append(yhat.cpu().numpy())

        # results are stored on the instance: self.label / self.pred are what
        # the metric helpers below read
        if batch_labels:
            self.label = np.concatenate(batch_labels, axis=0)
            self.pred = np.concatenate(batch_preds, axis=0)
        else:
            self.label = np.empty((0,), dtype=np.int64)
            self.pred = np.empty((0,), dtype=np.int64)

        return self.label, self.pred

    def get_acc(self):
        """Accuracy of ``self.pred`` against ``self.label``."""
        self.acc = metrics.accuracy_score(self.label, self.pred)
        return self.acc

    def get_bacc(self):
        """Balanced accuracy of ``self.pred`` against ``self.label``."""
        self.b_acc = metrics.balanced_accuracy_score(self.label, self.pred)
        return self.b_acc

    def get_f1(self):
        """Per-class F1 scores as a list."""
        self.f1_score = list(metrics.f1_score(self.label, self.pred, average=None))
        return self.f1_score

    def accuracies(self, model, data_loader):
        """Predict on ``data_loader`` and store accuracy, balanced accuracy and F1.

        Nothing is returned; read ``self.acc``, ``self.b_acc`` and ``self.f1_score``.
        """
        self.make_predictions(model, data_loader)
        self.get_acc()
        self.get_bacc()
        self.get_f1()

    def auc_scores(self):
        """Compute ROC curves and AUC per class plus micro and macro averages.

        Requires ``self.y`` and ``self.prob`` from ``get_probs()``.

        Returns:
            fpr, tpr, roc_auc: dicts keyed by class index, "micro" and "macro"
        """
        self.fpr = dict()
        self.tpr = dict()
        self.roc_auc = dict()

        # Compute ROC curve and ROC area for each class
        for i in range(self.class_num):
            self.fpr[i], self.tpr[i], _ = metrics.roc_curve(self.y[:, i], self.prob[:, i])
            self.roc_auc[i] = metrics.auc(self.fpr[i], self.tpr[i])

        # Compute micro-average ROC curve and ROC area (computed globally)
        self.fpr["micro"], self.tpr["micro"], _ = metrics.roc_curve(self.y.ravel(), self.prob.ravel())
        self.roc_auc["micro"] = metrics.auc(self.fpr["micro"], self.tpr["micro"])

        # Compute macro-average ROC curve and ROC area (simply average on each label)
        # aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([self.fpr[i] for i in range(self.class_num)]))
        # interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(self.class_num):
            mean_tpr += np.interp(all_fpr, self.fpr[i], self.tpr[i])
        # average it and compute AUC
        mean_tpr /= self.class_num

        self.fpr["macro"] = all_fpr
        self.tpr["macro"] = mean_tpr
        self.roc_auc["macro"] = metrics.auc(self.fpr["macro"], self.tpr["macro"])

        return self.fpr, self.tpr, self.roc_auc

    def get_report(self):
        """Text classification report with the category names as row labels."""
        self.report = metrics.classification_report(self.label, self.pred, target_names=self.categories)
        return self.report

    def get_confusion(self):
        """ calculate the confusion matrix

        Returns:
            DataFrame of confusion matrix: (i, j) - the number of samples with true label being i-th class and predicted label being j-th class.
        """
        c_matrix = metrics.confusion_matrix(self.label, self.pred)
        self.CMatrix = pd.DataFrame(c_matrix, columns=self.categories, index=self.categories)
        return self.CMatrix

    def complete_scores(self):
        """Derive hard labels/predictions from ``self.y`` / ``self.prob`` and compute every metric.

        Must be called after ``get_probs()``.
        """
        self.label = np.argmax(self.y, axis=1)
        self.pred = np.argmax(self.prob, axis=1)
        self.get_report()
        self.get_acc()
        self.get_bacc()
        self.get_f1()
        self.auc_scores()
        self.get_confusion()

In [None]:
# weight balancing
# every class starts with weight 1; the class with index 1 is down-weighted to
# 0.2 (presumably the dominant class of the training set — confirm)
class_weight = [1 for _ in range(ann_train.class_num)]
class_weight[1] /= 5

# Focal Loss
criterion = FocalLoss(alpha=class_weight, gamma=2, num_classes=ann_train.class_num)
# logging the criterion records its alpha and gamma through FocalLoss.__str__
log.logger.info(criterion)

In [None]:
# shared evaluator; it is reused by train() and performance_evaluation()
eval_metrics = Evaluation(device, ann_train.categories)

# initial test
# baseline scores on the validation set before any training of the new classifier
eval_metrics.get_probs(model, valid_loader)
eval_metrics.complete_scores()
log.logger.info("Initial Test: valid_acc = {:.4f}, valid_bacc = {:.4f}, f1_score = {}\nauc: {}".format(eval_metrics.acc, eval_metrics.b_acc, eval_metrics.f1_score, eval_metrics.roc_auc))


### Training

In [None]:
from torch import optim

In [None]:
# checkpoint
def load_train(log, model, optimizer, scheduler=None, pth_check=None, map_location=None):
    """ initialize or load training process from checkpoint

    Restores the training state from ``checkpoint/<pth_check>``. When
    ``pth_check`` is None nothing is loaded and training starts from scratch.

    Args:
        log (Logger)
        model: receives ``checkpoint['model_state_dict']``
        optimizer: receives ``checkpoint['optimizer_state_dict']``
        scheduler (optional): receives ``checkpoint['scheduler_state_dict']``
            when the checkpoint contains one
        pth_check (str): file name of the training checkpoint inside the
            ``checkpoint`` directory. e.g. 'ch_training.pth'. (Default: None - initialize)
        map_location (optional): forwarded to ``torch.load``; lets a checkpoint
            written on one device be restored on another

    Returns:
        start epoch: 0 without a checkpoint, otherwise the stored epoch + 1
    """
    if pth_check is None:
        return 0

    pth_check = os.path.join('checkpoint', pth_check)
    log.logger.info("Reloading training checkpoint from {}".format(pth_check))
    checkpoint = torch.load(pth_check, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1

    if scheduler:
        # check_train() only stores the scheduler state when it was given a
        # scheduler, so the key may legitimately be missing
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            log.logger.info("No scheduler state in {}; scheduler left unchanged".format(pth_check))

    return start_epoch

def load_eval(log, pth_check=None):
    """ initialize or load evaluation from checkpoint

    Restores the metric histories of a previous training run from
    ``checkpoint/<pth_check>``. When ``pth_check`` is None, six fresh empty
    lists are returned.

    Args:
        log (Logger)
        pth_check (str): file name of the eval checkpoint inside the
            ``checkpoint`` directory. e.g. 'ch_eval.pth'

    Returns:
        costs, train_accs, test_accs, b_accs, f1_scores, auces
    """
    keys = ('costs', 'train_accs', 'test_accs', 'b_accs', 'f1_scores', 'auces')

    if pth_check is None:
        # a new list per history, so the caller can append to each independently
        return tuple([] for _ in keys)

    pth_check = os.path.join('checkpoint', pth_check)
    log.logger.info("Reloading evaluation checkpoint from {}".format(pth_check))
    checkpoint = torch.load(pth_check)

    return tuple(checkpoint[key] for key in keys)



def check_train(log, model, optimizer, epoch, scheduler=None, pth_check='ch_training.pth', verbose=False):
    """ save training checkpoint

    Stores the model state, optimizer state and epoch (plus the scheduler
    state when a scheduler is given) in ``checkpoint/<pth_check>``. The
    ``checkpoint`` directory is created when missing.

    Args:
        log (Logger)
        model: provides ``state_dict()``
        optimizer: provides ``state_dict()``
        epoch: value stored under the 'epoch' key
        scheduler (optional): provides ``state_dict()``
        pth_check (str): file name of the checkpoint.
        verbose (bool): log the destination before saving
    """
    folder = 'checkpoint'
    if not os.path.exists(folder):
        os.makedirs(folder)
    destination = os.path.join(folder, pth_check)

    if verbose:
        log.logger.info("Saving training checkpoint at {}".format(destination))

    state = dict()
    state['model_state_dict'] = model.state_dict()
    state['optimizer_state_dict'] = optimizer.state_dict()
    state['epoch'] = epoch
    if scheduler:
        state['scheduler_state_dict'] = scheduler.state_dict()

    torch.save(state, destination)


def check_eval(log, costs, train_accs, test_accs, b_accs, f1_scores, auces, pth_check='ch_eval.pth', verbose=True):
    """ saving evaluation checkpoint

    Stores the metric histories of the training process (cost, accuracies,
    F1 scores, AUC) in ``checkpoint/<pth_check>``. The ``checkpoint``
    directory is created when missing.

    Args:
        log (Logger)
        costs, train_accs, test_accs, b_accs, f1_scores, auces: histories to
            store, each under the key of the same name
        pth_check (str): file name of the checkpoint.
        verbose: whether to log the destination and every stored history
    """
    check_dir = 'checkpoint'
    if not os.path.exists(check_dir):
        os.makedirs(check_dir)
    pth_check = os.path.join(check_dir, pth_check)

    checkpoint = {
        'costs': costs,
        'train_accs': train_accs,
        'test_accs': test_accs,
        'b_accs': b_accs,
        'f1_scores': f1_scores,
        'auces': auces
    }

    if verbose:
        # this is the evaluation checkpoint; the message used to claim "training"
        log.logger.info("Saving evaluation checkpoint at {}".format(pth_check))
        for key, value in checkpoint.items():
            log.logger.info('{} = {}\n'.format(key, value))

    torch.save(checkpoint, pth_check)

In [None]:
def train(model, train_loader, test_loader, max_epoch=100, test_period=5, early_threshold=5):
    """ train with a scheduler on learning rate

    Relies on module-level state: ``device``, ``optimizer``, ``criterion``,
    ``scheduler``, ``eval_metrics`` and the history lists ``costs``,
    ``train_accs``, ``test_accs``, ``b_accs``, ``f1_scores``, ``auces``.
    The global ``epoch`` is the first epoch to run and is also the loop
    variable, so after the call it holds the last epoch that was executed.

    Args:
        model: network to train
        train_loader: loader of the training batches
        test_loader: loader evaluated every ``test_period`` epochs
        max_epoch (int): training runs for epochs ``epoch .. max_epoch - 1``
        test_period (int): period of test
        early_threshold (int): threshold for early stopping strategy, which pays attention to acc on test set
    """
    patience = early_threshold

    # make sure the model is in the training mode
    model.train()

    global epoch
    for epoch in range(epoch, max_epoch):
        cost = 0
        correct = 0
        seen = 0

        for x, y in train_loader:
            # setting GPU
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            z = model(x)
            loss = criterion(z, y)
            loss.backward()
            optimizer.step()

            cost += loss.item()
            yhat = z.detach().argmax(dim=1)
            correct += (yhat == y).sum().item()
            seen += y.size(0)

        cost /= max(len(train_loader), 1)  # average cost per batch
        costs.append(cost)

        # ! acc on train in train mode
        # Divide by the number of samples actually processed: a loader built
        # with drop_last=True skips part of the dataset, so dividing by
        # len(dataset) would systematically under-report the accuracy.
        acc = correct / seen if seen else 0.0
        train_accs.append(acc)

        # adjust learning rate
        scheduler.step()

        if epoch % test_period == 0:
            eval_metrics.get_probs(model, test_loader)
            eval_metrics.complete_scores()

            test_accs.append(eval_metrics.acc)
            b_accs.append(eval_metrics.b_acc)
            f1_scores.append(eval_metrics.f1_score)
            auces.append(eval_metrics.roc_auc)

            # report before the early-stopping check so that the evaluation
            # which triggers the stop is shown as well
            print("{:3d} cost: {:.4f}\ttrain_acc: {:.4f}\ttest_acc: {:.4f}\ttest_bacc: {:.4f}\tf1_score: {}".format(
                epoch, cost, acc, test_accs[-1], b_accs[-1], f1_scores[-1]))

            # early stopping strategy
            if eval_metrics.acc >= eval_metrics.best_score:
                eval_metrics.best_score = eval_metrics.acc
                patience = early_threshold
            else:
                patience -= 1
                if patience == 0:
                    print("Early stopping at epoch {}".format(epoch))
                    break

            # change back to training mode (evaluation switched to eval mode)
            model.train()

In [None]:
# setting
## filename
model_file = 'dense201-{}.pkl'.format(model_id)

## hyper-params
init_lr = 1e-3
weight_decay = 1e-4
max_epoch = 100
test_period = 1       # evaluate after every epoch
early_threshold = 40  # evaluations without improvement tolerated before stopping

## optimizer: https://pytorch.org/docs/stable/optim.html
# NOTE: only the parameters of the new classifier head are handed to the optimizer
optimizer = optim.AdamW(model.classifier.parameters(), lr=init_lr, betas=(0.9, 0.999), weight_decay=weight_decay)

## learning rate decay
# cosine annealing from init_lr down to 0 over max_epoch scheduler steps
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch, eta_min=0)

# no checkpoint is given, so training starts at epoch 0 with empty histories
epoch = load_train(log, model, optimizer, scheduler)
costs, train_accs, test_accs, b_accs, f1_scores, auces = load_eval(log)

In [None]:
# # load from checkpoint
# epoch = load_train(log, model, optimizer, scheduler, pth_check='ch_training.pth')
# costs, train_accs, test_accs, b_accs, f1_scores, auces = load_eval(log, pth_check='ch_eval.pth')

log.logger.info("Training...\n{}\n{}".format(optimizer, scheduler))

# phase 1 - warm up: only the new classifier head is trained
warm_up = 10
warmup_ran = epoch < warm_up
train(model, train_loader, valid_loader, warm_up, test_period, early_threshold)
# train() leaves the global `epoch` at the last epoch it executed; step past it
# so that the next phase does not repeat that epoch (and the scheduler is not
# stepped more than T_max times overall)
if warmup_ran:
    epoch += 1

# unfreeze layers
for param in model.parameters():
    param.requires_grad = True

# The optimizer was created with the classifier parameters only, so enabling
# gradients is not enough: parameters it does not hold would never be updated.
# Register them, continuing on the classifier's current learning rate.
tracked = {id(p) for group in optimizer.param_groups for p in group['params']}
backbone_params = [p for p in model.parameters() if id(p) not in tracked]
if backbone_params:
    base_lr = scheduler.base_lrs[0]
    optimizer.add_param_group({
        'params': backbone_params,
        'lr': optimizer.param_groups[0]['lr'],
        'initial_lr': base_lr,
    })
    # keep the scheduler's per-group bookkeeping aligned with the new group
    scheduler.base_lrs.append(base_lr)
log.logger.info("End warming up")

# phase 2 - fine-tune the whole network
train(model, train_loader, valid_loader, max_epoch, test_period, early_threshold)

In [None]:
# # save checkpoint
# check_train(log, model, optimizer, epoch, scheduler)
# check_eval(log, costs, train_accs, test_accs, b_accs, f1_scores, auces)

In [None]:
# save model
# the whole model object is written to ./model/<model_file>
save_model(model, name=model_file)
# record the complete metric histories of this run in the log
log.logger.info("Filename: {}\ncosts = {}\ntrain_accs = {}\ntest_acc = {}\nb_accs = {}\nf1_scores = {}\nauces = {}".format(
    model_file, costs, train_accs, test_accs, b_accs, f1_scores, auces))

## Evaluation

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

def draw_confusion(cf_matrix):
    """Draw two normalized heatmaps of a confusion matrix side by side.

    Args:
        cf_matrix: square confusion matrix (DataFrame or array-like) with true
            classes as rows and predicted classes as columns.

    The left panel normalizes every column (predicted class) to sum to 1, so
    its diagonal is the per-class precision. The right panel normalizes every
    row (true class), so its diagonal is the per-class sensitivity (recall).
    """
    cf = pd.DataFrame(cf_matrix)
    # Normalize along explicit axes. np.sum() on a DataFrame reduces over the
    # columns or over everything depending on the pandas version, which made
    # the left panel's meaning ambiguous.
    by_predicted = cf.div(cf.sum(axis=0), axis='columns')
    by_actual = cf.div(cf.sum(axis=1), axis='index')

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=100)
    panels = [(by_predicted, 'precision'), (by_actual, 'sensitivity')]
    for ax, (table, title) in zip(axes, panels):
        sns.heatmap(table, ax=ax, annot=True, fmt='.2%', cmap='Blues', annot_kws={"size": 8}, cbar=True)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_title(title)


def performance_evaluation(model, dataset, info):
    """Evaluate ``model`` on ``dataset``, log the report and plot the confusion matrix.

    Uses the module-level ``eval_metrics`` and ``log``.

    Args:
        model: network to evaluate
        dataset: dataset wrapped in a non-shuffling loader (batch size 200)
        info (str): headline placed in front of the logged report
    """
    loader = data.DataLoader(dataset, batch_size=200, shuffle=False, num_workers=4)

    # predicted probabilities first, then every derived metric
    eval_metrics.get_probs(model, loader)
    eval_metrics.complete_scores()

    summary = "{}\n{}\nAUC: {}".format(info, eval_metrics.report, eval_metrics.roc_auc)
    log.logger.info(summary)

    # confusion matrix
    draw_confusion(eval_metrics.get_confusion())


In [None]:
# the training/validation loaders are no longer needed; presumably deleted to
# release their resources before the test loader is created — confirm
del train_loader, valid_loader

# final report and confusion matrix on the held-out test set
performance_evaluation(model, test_data, info='Performance on Test Set')

In [None]:
# acc
# train accuracy has one point per epoch, the test curves one point per
# evaluation; the x axes coincide because test_period is 1 in this notebook
plt.figure(figsize=(8, 6), dpi=100)
plt.plot(train_accs, label='Train')
plt.plot(test_accs, label='Test')
plt.plot(b_accs, label='Balanced on Test')
plt.xlabel('Epoch')
plt.ylabel('Acc')
plt.ylim(0, 1.)
plt.legend()
plt.grid(axis='y')
plt.title('Accuracy')
plt.show()

# f1
# f1_scores is a list of per-class lists; as an array, column i is the history of class i
f1_scores = np.array(f1_scores)
plt.figure(figsize=(8, 6), dpi=100)
for i in range(ann_train.class_num):
    plt.plot(f1_scores[:, i], label=ann_train.categories[i])

plt.xlabel('Epoch')
plt.ylabel('F1')
plt.ylim(0, 1.)
plt.legend()
plt.grid(axis='y')
plt.title('F1 Score')
plt.show()

# AUC
# every entry of auces is a dict keyed by class index plus 'micro' and 'macro'
plt.figure(figsize=(8, 6), dpi=100)
for i in range(ann_train.class_num):
    plt.plot([auc[i] for auc in auces], label=ann_train.categories[i])
plt.plot([auc['micro'] for auc in auces], label='micro')
plt.plot([auc['macro'] for auc in auces], label='macro')

plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.ylim(0.5, 1.)
plt.legend()
plt.grid(axis='y')
plt.title('AUC')
plt.show()