In [1]:
import copy
import os
import pickle
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from PIL import Image
from sklearn.metrics import roc_auc_score, precision_score, recall_score
from tqdm.notebook import tqdm, trange

In [None]:
# English only: 26 integer class ids, 0..25 (presumably one per letter - confirm).
LABELS = list(range(26))

In [2]:
def seed_torch(seed=42):
    '''Fixes every random source used here so that runs are comparable.

    Sets the seed-related environment variables, seeds ``random``, NumPy and
    PyTorch (CPU and CUDA), and switches cuDNN / PyTorch to deterministic mode.

    Args:
        seed: value applied to every random number generator.
    '''

    os.environ.update({
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
        'PYTHONHASHSEED': str(seed),
        'PL_GLOBAL_SEED': str(seed),
    })

    # Same seed for the stdlib, NumPy and all torch generators.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.use_deterministic_algorithms(mode=True)

In [None]:
def recolor(picture_path):
    '''Replaces the line color with magenta and makes the background black.

    Every pixel whose alpha is non-zero becomes opaque magenta; every fully
    transparent pixel becomes opaque black. The file at ``picture_path`` is
    overwritten with the result.

    The image is converted to RGBA first, so palette and grayscale images are
    handled too. An image without any transparency is fully opaque after the
    conversion and therefore comes out entirely magenta.

    Args:
        picture_path: path of the image to recolor in place.
    '''

    # Load and convert inside the context manager so the source file handle is
    # closed before the very same path is written below.
    with Image.open(picture_path) as source:
        picture = source.convert('RGBA')

    pixels = picture.load()
    width, height = picture.size
    for x in range(width):
        for y in range(height):
            if pixels[x, y][3] != 0:
                pixels[x, y] = (255, 0, 255, 255)
            else:
                pixels[x, y] = (0, 0, 0, 255)

    picture.save(picture_path)

In [3]:
def imshow(inp, title=None):
    '''Displays a tensor as an image on the current matplotlib axes.

    The tensor axes are reordered with ``permute(1, 2, 0)`` and the values are
    clipped to the [0, 1] range before drawing.

    Args:
        inp: tensor to show (assumes channel-first layout - confirm with callers).
        title: optional title; nothing is set when it is None.
    '''

    image = np.clip(inp.permute(1, 2, 0).numpy(), 0, 1)
    plt.imshow(image)
    if title is None:
        return
    plt.title(title)

In [49]:
def train(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, device='cuda', savename=None, verbose=2, class_labels=LABELS):
    '''Trains the model and records loss, accuracy and ROC AUC for every epoch.

    Every epoch runs a 'train' phase (with optimizer updates) followed by a
    'val' phase (forward pass under ``torch.inference_mode``). The weights of
    the epoch with the highest validation accuracy are loaded back into the
    model before it is returned.

    Args:
        model: network to train; it is moved to ``device``.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped after every training batch.
        scheduler: scheduler stepped once per epoch, after the train phase.
        dataloaders: mapping with 'train' and 'val' iterables yielding ``(inputs, labels)``.
        dataset_sizes: mapping with the number of samples for 'train' and 'val'.
        num_epochs: number of epochs to run.
        device: device the model and every batch are moved to.
        savename: if truthy, the best weights are also saved to ``'models/' + savename``.
        verbose: 0 for complete absence of progress bars during training,
            1 for showing only the main pbar, 2 for full info.
            The final summary is printed regardless.
        class_labels: class ids forwarded to ``roc_auc_score`` as ``labels``.

    Returns:
        ``(model, losses, accuracies, roc_aucs)`` where the last three are dicts
        with 'train' and 'val' lists holding one value per epoch.
    '''

    losses = {'train': [], 'val': []}
    accuracies = {'train': [], 'val': []}
    roc_aucs = {'train': [], 'val': []}

    best_acc = 0
    best_roc_auc = 0

    model.to(device)
    # state_dict() hands out references to the live tensors, so a deep copy is
    # required for the snapshot to survive further optimizer steps.
    best_model_wts = copy.deepcopy(model.state_dict())

    pbar = trange(num_epochs, desc='Epoch', disable=(verbose == 0))

    start_time = time.time()
    for epoch in pbar:
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0
            running_corrects = 0
            # Batches are collected in lists and concatenated once per phase
            # instead of re-allocating a growing tensor on every step.
            batch_outputs, batch_labels = [], []

            phase_pbar = tqdm(dataloaders[phase], leave=True, desc=f'{phase} iter', disable=(verbose < 2))
            for inputs, labels in phase_pbar:
                inputs, labels = inputs.to(device), labels.to(device)

                if is_train:
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                else:
                    with torch.inference_mode():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                preds = torch.argmax(outputs, -1)

                batch_outputs.append(outputs.detach().float().cpu())
                batch_labels.append(labels.cpu())

                running_loss += loss.item()
                running_corrects += int(torch.sum(preds == labels).item())

            if is_train:
                scheduler.step()

            # NOTE(review): per-batch loss values are summed and divided by the
            # sample count; if the criterion averages over the batch, this value
            # is scaled by 1 / batch_size - confirm the intended normalisation.
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            # NOTE(review): multiclass roc_auc_score expects per-class scores that
            # sum to 1; presumably the model already outputs probabilities - confirm.
            epoch_roc_auc = roc_auc_score(
                torch.cat(batch_labels).numpy(),
                torch.cat(batch_outputs).numpy(),
                multi_class='ovo',
                labels=class_labels,
            )

            losses[phase].append(epoch_loss)
            accuracies[phase].append(epoch_acc)
            roc_aucs[phase].append(epoch_roc_auc)
            pbar.set_description(f'{phase} Loss: {epoch_loss:.4f} Accuracy: {epoch_acc:.4f} RocAuc: {epoch_roc_auc:.4f}')

            # "Best" figures are reported as validation metrics below, so the
            # train phase must not contribute to them.
            if phase == 'val':
                if epoch_roc_auc > best_roc_auc:
                    best_roc_auc = epoch_roc_auc
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    # Snapshot is taken whether or not it is also written to disk.
                    best_model_wts = copy.deepcopy(model.state_dict())
                    if savename:
                        save_path = 'models/' + savename
                        os.makedirs(os.path.dirname(save_path), exist_ok=True)
                        torch.save(best_model_wts, save_path)

    time_elapsed = time.time() - start_time
    print(f'Training complete in {time_elapsed // 60}m {(time_elapsed % 60):.4f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    print(f'Best val Roc-Auc: {best_roc_auc:.4f}')

    model.load_state_dict(best_model_wts)

    return model, losses, accuracies, roc_aucs


In [7]:
def evaluate(model, testloader, device='cuda', class_labels=LABELS):
    '''Evaluates the given model on ``testloader``.

    The model is moved to ``device``, switched to eval mode and run without
    gradient tracking.

    Args:
        model: network to evaluate.
        testloader: iterable of batches where ``batch[0]`` are the inputs and
            ``batch[1]`` the labels.
        device: device the model and every batch are moved to.
        class_labels: class ids forwarded as ``labels`` to ``roc_auc_score``,
            ``precision_score`` and ``recall_score``, so all three are computed
            over the same set of classes.

    Returns:
        Tuple ``(accuracy, roc_auc, precision, recall)``; ROC AUC is one-vs-one,
        precision and recall are macro-averaged.
    '''

    model.to(device)
    model.eval()

    objects_processed = 0
    correct = 0
    # Collected per batch and concatenated once at the end.
    score_batches, label_batches, pred_batches = [], [], []

    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.inference_mode():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            objects_processed += len(labels)

            outputs = model(inputs)
            preds = torch.argmax(outputs, -1)
            correct += int(torch.sum(preds == labels))

            pred_batches.append(preds.cpu())
            score_batches.append(outputs.detach().float().cpu())
            label_batches.append(labels.cpu())

    accuracy = correct / objects_processed

    y_true = torch.cat(label_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()
    # NOTE(review): multiclass roc_auc_score expects per-class scores that sum
    # to 1; presumably the model already outputs probabilities - confirm.
    y_score = torch.cat(score_batches).numpy()

    return (accuracy,
            roc_auc_score(y_true, y_score, multi_class='ovo', labels=class_labels),
            precision_score(y_true, y_pred, average='macro', labels=class_labels),
            recall_score(y_true, y_pred, average='macro', labels=class_labels))

In [8]:
def train_plot(losses, accuracies, roc_aucs):
    '''Plots how loss, accuracy and ROC AUC changed over the epochs.

    One figure is drawn per metric. Each figure shows the train and validation
    curves plus dashed guide lines at the best value (minimum for the loss,
    maximum for the other two) and at the epoch where it was first reached.

    Args:
        losses: dict with 'train' and 'val' lists of per-epoch loss.
        accuracies: dict with 'train' and 'val' lists of per-epoch accuracy.
        roc_aucs: dict with 'train' and 'val' lists of per-epoch ROC AUC.
    '''

    sns.set(style='darkgrid', font_scale=1.4)

    # (history, figure title, function selecting the best value, legend labels of the best-epoch lines)
    panels = (
        (losses, 'Loss', min, ('min on train', 'min on validation')),
        (accuracies, 'Accuracy', max, ('max on train', 'max on validation')),
        (roc_aucs, 'Roc Auc Score', max, ('max on train', 'max on validation')),
    )
    phase_colors = (('train', 'blue'), ('val', 'red'))
    curve_labels = ('train', 'validation')

    for history, title, pick_best, best_labels in panels:
        plt.figure(figsize=(16, 9))

        # metric curves
        for (phase, color), curve_label in zip(phase_colors, curve_labels):
            plt.plot(history[phase], label=curve_label, linewidth=4, color=color)

        # horizontal guides at the best value
        for phase, color in phase_colors:
            plt.axhline(y=pick_best(history[phase]), linewidth=3, linestyle='--', color=color)

        # vertical guides at the first epoch reaching the best value
        for (phase, color), best_label in zip(phase_colors, best_labels):
            series = history[phase]
            plt.axvline(x=series.index(pick_best(series)), label=best_label, linewidth=3, linestyle='--', color=color)

        plt.title(title)
        plt.legend()
        plt.show()

In [None]:
# Named matplotlib colors available for plotting.
colors = [
    'dimgray', 'indianred', 'red', 'sienna', 'sandybrown',
    'darkorange', 'moccasin', 'gold', 'darkkhaki', 'yellowgreen',
    'seagreen', 'turquoise', 'aqua', 'deepskyblue', 'midnightblue',
    'blue', 'darkviolet', 'violet', 'hotpink',
]
    

In [1]:
def comparison_plots(losses, accuracies, model_names, num_epochs=25, titles=['Losses on val', 'Accuracies on val'], linestyle='-'):
    '''Shows comparison plots of validation losses and accuracies for all models.

    Two figures are drawn: the first with every model's 'val' loss curve, the
    second with every model's 'val' accuracy curve. Curve colors are sampled
    from the seismic colormap, one per entry of ``losses``.

    Args:
        losses: sequence of per-model dicts holding a 'val' list of losses.
        accuracies: sequence of per-model dicts holding a 'val' list of accuracies.
        model_names: legend label for each model, in the same order.
        num_epochs: number of x ticks; ticks are labelled 1..num_epochs.
        titles: titles of the loss figure and of the accuracy figure.
        linestyle: matplotlib line style used for every curve.
    '''

    sns.set(style='darkgrid', font_scale=1.4)

    palette = plt.cm.seismic
    model_count = len(losses)
    colors = [palette(int(idx * palette.N / model_count)) for idx in range(model_count)]

    # Panel 0 shows the losses, panel 1 the accuracies.
    for panel_idx, histories in enumerate((losses, accuracies)):
        plt.figure(figsize=(16, 9))

        for model_idx, history in enumerate(histories):
            plt.plot(
                history['val'],
                label=model_names[model_idx],
                linewidth=4,
                color=colors[model_idx],
                linestyle=linestyle,
            )
            # Epoch ticks are (re)applied after every curve, shown 1-based.
            plt.xticks(
                np.arange(0, num_epochs, 1),
                labels=[str(epoch + 1) for epoch in range(num_epochs)],
            )

        plt.title(titles[panel_idx])
        plt.legend()
        plt.show()

In [None]:
def save(data, file_names):
    '''Pickles every object to ``pickled/<name>.pickle``.

    Objects and names are paired positionally; pairing stops at the shorter of
    the two sequences. The ``pickled`` directory is created when missing.

    Args:
        data: iterable of objects to pickle.
        file_names: iterable of file names (without extension), one per object.
    '''
    # Without this the first open() fails on a fresh checkout.
    os.makedirs('pickled', exist_ok=True)

    for name, item in zip(file_names, data):
        with open('pickled/' + name + '.pickle', 'wb') as handler:
            pickle.dump(item, handler, protocol=pickle.HIGHEST_PROTOCOL)


def load(file_names):
    '''Reads back objects pickled under ``pickled/<name>.pickle``.

    Only use this on files you created yourself: unpickling executes whatever
    the file instructs, so untrusted pickles are unsafe.

    Args:
        file_names: iterable of file names (without extension).

    Returns:
        List with the unpickled objects, in the order of ``file_names``.
    '''
    def _read(name):
        with open('pickled/' + name + '.pickle', 'rb') as handler:
            return pickle.load(handler)

    return [_read(name) for name in file_names]