# In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import json
import random

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import models


def display_images(img_ids, annotations_df, img_rootdir, label2name_json, rows=3, cols=3):
    """
    Displays random images in a grid of (rows)*(cols) with their labels

    At most rows*cols images are drawn at random from img_ids. The caller's
    img_ids is left untouched. If fewer images than grid cells are available,
    the remaining cells are hidden.

    Args:
        img_ids (list): A list of the IDs of the images (= the file name of the images)
        annotations_df (pandas.DataFrame): Must provide the columns "image_id"
            and "label"; every displayed ID has to match exactly one row
        img_rootdir (string): Root directory of the image files
        label2name_json (string): Path to a JSON file holding the mapping between
            each disease code (as string key) and the real disease name
        rows (int): Number of rows of the image grid (default=3)
        cols (int): Number of columns of the image grid (default=3)
    """
    with open(label2name_json, "r", encoding="utf-8") as file:
        label2name = json.load(file)

    # Sample from a copy: shuffling in place would reorder the caller's list.
    candidates = list(img_ids)
    chosen = random.sample(candidates, min(rows * cols, len(candidates)))

    # squeeze=False guarantees a 2D array of Axes, even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
    axes = axs.ravel()
    for ax, img_id in zip(axes, chosen):
        label = annotations_df.loc[annotations_df.image_id == img_id].label.item()
        label_name = label2name[str(label)]
        # imshow copies the pixel data, so the file can be closed right after.
        with Image.open(os.path.join(img_rootdir, img_id)) as img:
            ax.imshow(img)
        ax.set_title(" ".join([str(label), label_name]))
        ax.set_xlabel(img_id)
        ax.get_yaxis().set_visible(False)

    # Hide grid cells that received no image.
    for ax in axes[len(chosen):]:
        ax.axis("off")
    plt.show()

    
class SimpleNet(nn.Module):
    """
    Small LeNet-style CNN: two conv/max-pool stages followed by three
    fully connected layers.

    The first fully connected layer expects 16 * 53 * 53 features, which is
    what the two unpadded 5x5 conv + 2x2 pool stages produce for
    3 x 224 x 224 inputs. Inputs with a different feature-map size raise a
    shape error instead of being silently re-split across the batch dimension.

    Args:
        num_classes (int): Number of output logits (default=5)
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 53 * 53, 1000)
        self.fc2 = nn.Linear(1000, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Batch of images, [Batch x 3 x Height x Width]

        Returns:
            torch.Tensor: Unnormalized class scores, [Batch x num_classes]
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything but the batch axis. view(-1, ...) would move
        # surplus features into the batch dimension for wrongly sized inputs.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    
def create_model(device, pretrained=True):
    """
    Builds a Resnet50 classifier with 5 outputs in which only the newly
    created fc layer is trainable.

    Args:
        device (string): Device, where to put the model
        pretrained (boolean): If the model should be loaded pretrained on ImageNet
            (default=True)

    Returns:
        torch.nn.Module
    """
    net = models.resnet50(pretrained=pretrained)

    # switch off gradients for every parameter of the loaded network
    for weight in net.parameters():
        weight.requires_grad_(False)

    # swap the classification head for a fresh 5-way linear layer;
    # it is created after the freeze, so its parameters stay trainable
    net.fc = nn.Linear(net.fc.in_features, 5)

    return net.to(device)

def create_model_layer4(device, pretrained=True):
    """
    Builds a Resnet50 classifier with 5 outputs in which only layer4 and the
    newly created fc layer are trainable.

    Args:
        device (string): Device, where to put the model
        pretrained (boolean): If the model should be loaded pretrained on ImageNet
            (default=True)

    Returns:
        torch.nn.Module
    """
    net = models.resnet50(pretrained=pretrained)

    # a parameter keeps its gradient only if it belongs to layer4
    trainable_ids = {id(weight) for weight in net.layer4.parameters()}
    for weight in net.parameters():
        weight.requires_grad = id(weight) in trainable_ids

    # swap the classification head for a fresh 5-way linear layer;
    # it is created after the freeze, so its parameters stay trainable
    net.fc = nn.Linear(net.fc.in_features, 5)

    return net.to(device)
    

class CassavaDataset(Dataset):
    """Labelled image dataset: each item is an (image, label) pair."""

    def __init__(self, annotations_df, root_dir, transforms=None, albums=None):
        """
        Args:
            annotations_df (pandas.DataFrame): Contains the img_ids in first
                column and the associated annotation in the second column
            root_dir (string): Root directory of the image files
            transforms (callable, optional): Optional torchvision.transforms
                to be applied on a sample
            albums (callable, optional): Optional albumentations to be applied
                on a sample
        """
        self.root_dir = root_dir
        self.transforms = transforms
        self.albums = albums
        self.annotations_df = annotations_df
        # column 0 -> file names, column 1 -> class labels
        id_column = annotations_df.iloc[:, 0]
        label_column = annotations_df.iloc[:, 1]
        self.img_ids = np.asarray(id_column)
        self.labels = np.asarray(label_column)

    def __len__(self):
        """Number of samples (one per annotation row)."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Loads the idx-th image and returns it together with its label
        as a torch.long tensor."""
        file_name = str(self.img_ids[idx])
        sample = Image.open(os.path.join(self.root_dir, file_name))

        target = torch.tensor(self.labels[idx], dtype=torch.long)

        # torchvision pipeline is called with the image as-is
        if self.transforms is not None:
            sample = self.transforms(sample)

        # albumentations pipeline is called with a numpy array via image=
        if self.albums is not None:
            as_array = np.array(sample)
            sample = self.albums(image=as_array)["image"]

        return (sample, target)
    

class CassavaTestDataset(Dataset):
    """Unlabelled image dataset: each item is an (image ID, image) pair."""

    def __init__(self, img_ids, root_dir, transforms=None, albums=None):
        """
        Args:
            img_ids (list): A list of the IDs of the images (= the file name of the images)
            root_dir (string): Root directory of the image files
            transforms (callable, optional): Optional torchvision.transforms
                to be applied on a sample
            albums (callable, optional): Optional albumentations to be applied
                on a sample
        """
        self.img_ids = img_ids
        self.root_dir = root_dir
        self.albums = albums
        self.transforms = transforms

    def __len__(self):
        """Number of samples (one per image ID)."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Loads the idx-th image and returns it together with its ID."""
        sample_id = self.img_ids[idx]
        sample = Image.open(os.path.join(self.root_dir, str(sample_id)))

        # torchvision pipeline is called with the image as-is
        if self.transforms is not None:
            sample = self.transforms(sample)

        # albumentations pipeline is called with a numpy array via image=
        if self.albums is not None:
            as_array = np.array(sample)
            sample = self.albums(image=as_array)["image"]

        return (sample_id, sample)

    
def check_dataloaders(dataloaders, dataset_type="train", index=0, std=1, mean=0):
    """
    Grabs one element of the first batch of a dataloader and...
        ...for CassavaDataset-dataloader: prints its annotation and plots the image
        ...for CassavaTestDataset-dataloader: prints its ID and plots the image

    Args:
        dataloaders (dictionary): Contains DataLoader object(s) named "train",
            "val" or "test"
        dataset_type (string): Type of the dataloader ("train",
            "val" or "test") (default="train")
        index (int): Defines which element of the batch should be plotted
            (requirement: index < batch size) (default=0)
        std (float): Standard deviation chosen in image transform or albumentation
            of the dataloader; anything that broadcasts against an
            [Height x Width x Channel] array works (default=1)
        mean (float): Mean chosen in image transform or albumentation
            of the dataloader; same broadcasting rule as std (default=0)
    """
    # next(...) is the portable way to advance an iterator
    batch = next(iter(dataloaders[dataset_type]))
    if dataset_type == "test":
        img_id, img = batch
    else:
        img, label = batch

    # Pick the requested sample along the batch axis -> [Channel x Height x Width].
    # Indexing the batch directly works for every batch size. (len(dataloader)
    # is the number of batches and says nothing about the batch size.)
    im = img[index].numpy()

    # permute to [Height x Width x Channel], the layout imshow expects
    im = im.transpose((1, 2, 0))

    # unnormalize (reverse normalization); the arithmetic creates a new array,
    # so the tensor of the batch is not modified by the clamping below
    im = im * std
    im = im + mean
    im[im < 0] = 0

    if dataset_type == "test":
        print ("ID:", str(img_id[index]))
    else:
        print("Label:", label[index])

    plt.imshow(im)
    
     
class Engine():
    """Bundles a model, its optimizer and the target device and offers
    train / evaluate / predict loops on top of them."""

    def __init__(self, model, optimizer, device):
        """
        Args:
            model (torch.nn.module): Prepared model for task
            optimizer (torch.optim): Optimizer for task
            device (string): Device, where the model is stored
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    @staticmethod
    def loss_fn(outputs, targets):
        """Cross-entropy loss between the model outputs and the target classes."""
        return nn.CrossEntropyLoss()(outputs, targets)

    def train(self, dataloader):
        """
        Trains the model for one epoch on train dataloader

        Args:
            dataloader (torch.utils.data.DataLoader): Train dataloader created
                with CassavaDataset

        Returns:
            float: Training loss (mean of the per-batch losses)
        """
        self.model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            self.optimizer.zero_grad()
            inputs = inputs.to(self.device, dtype=torch.float)
            targets = targets.long().to(self.device)
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        return running_loss / len(dataloader)

    def evaluate(self, dataloader):
        """
        Evaluates the model on validation dataloader

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader
                created with CassavaDataset

        Returns:
            float: Validation loss (mean of the per-batch losses)
            torch.Tensor: Accuracy as 0-dim tensor (correct predictions
                divided by the number of samples)
        """
        self.model.eval()
        running_loss = 0.0
        running_corrects = 0
        dataset_size = 0
        # no gradients are needed for evaluation; skip building the autograd graph
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = inputs.to(self.device, dtype=torch.float)
                targets = targets.long().to(self.device)
                outputs = self.model(inputs)
                preds = outputs.argmax(dim=1)
                running_corrects += torch.sum(preds == targets)
                running_loss += self.loss_fn(outputs, targets).item()
                dataset_size += inputs.size(0)
        return running_loss / len(dataloader), running_corrects / dataset_size

    def predict(self, dataloader):
        """
        Returns predictions by one model on validation or test dataloader

        Args:
            dataloader (torch.utils.data.DataLoader): Validation or test dataloader
                created with CassavaTestDataset

        Returns:
            pandas.DataFrame: Dataframe with image IDs as first column
                ("image_id") and predicted classes as second column ("label")
        """
        self.model.eval()
        image_ids, labels = [], []
        with torch.no_grad():
            for img_id, inputs in dataloader:
                # same float cast as in train/evaluate
                inputs = inputs.to(self.device, dtype=torch.float)
                preds = self.model(inputs).argmax(dim=1)
                image_ids.extend(img_id)
                labels.extend(preds.tolist())
        # Build the frame once: DataFrame.append is gone in pandas 2.0 and
        # copied the whole frame on every batch.
        return pd.DataFrame({"image_id": image_ids, "label": labels})

    def ensemble_predict(self, dataloader, models):
        """
        Returns ensembled predictions by multiple models on validation or test dataloader

        The outputs of all models are averaged per sample and the class with
        the highest mean score is predicted. Every model is switched to eval
        mode. The models are not moved; they are expected to live on the
        device of this engine already.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation or test dataloader
                created with CassavaTestDataset
            models (dictionary of torch.nn.module): Dictionary of multiple models

        Returns:
            pandas.DataFrame: Dataframe with image IDs as first column
                ("image_id") and predicted classes as second column ("label")

        Raises:
            ValueError: If models is empty
        """
        if not models:
            raise ValueError("ensemble_predict needs at least one model")

        members = list(models.values())
        # eval mode: otherwise e.g. BatchNorm normalizes with batch statistics
        # and updates its running statistics during inference
        for member in members:
            member.eval()

        image_ids, labels = [], []
        with torch.no_grad():
            for img_id, img in dataloader:
                inputs = img.to(self.device, dtype=torch.float)
                mean_outputs = torch.stack([member(inputs) for member in members]).mean(dim=0)
                preds = mean_outputs.argmax(dim=1)
                image_ids.extend(img_id)
                labels.extend(preds.tolist())
        return pd.DataFrame({"image_id": image_ids, "label": labels})