# Import the libraries

In [None]:
import os
from time import sleep
from glob import glob
import random
from tqdm import tqdm
import copy
import ntpath

import numpy as np
from imageio import imread
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, plot_confusion_matrix, matthews_corrcoef, classification_report,confusion_matrix, accuracy_score, balanced_accuracy_score, cohen_kappa_score, f1_score,  precision_score, recall_score

import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torchvision.models as models
from torchvision.models import resnet18

from res_model import *

File paths of the three datasets

In [None]:
# Root folder of each image collection, keyed by the dataset identifier
# that also appears in the metadata 'dataset' column.
data_path = dict(
    Ace_20="/bigdata/haicu/venkat31/hemat/Acevedo/",  # Acevedo_20 Dataset
    Mat_19="/bigdata/haicu/venkat31/hemat/Matek/",    # Matek_19 Dataset
    WBC1="/bigdata/haicu/venkat31/hemat/WBC1/",       # WBC1 dataset
)

All the labels in the dataset

In [1]:
# Class name -> integer class id. The id is simply the position of the
# name in the tuple below, so the order here defines the model's classes.
label_map_all = {
    class_name: class_id
    for class_id, class_name in enumerate((
        'basophil',
        'eosinophil',
        'erythroblast',
        'myeloblast',
        'promyelocyte',
        'myelocyte',
        'metamyelocyte',
        'neutrophil_banded',
        'neutrophil_segmented',
        'monocyte',
        'lymphocyte_typical',
    ))
}

Read the metadata

In [None]:
# Load the metadata table from the working directory. Later cells read its
# 'dataset', 'label', 'file' and 'set' columns.
metadata = pd.read_csv(filepath_or_buffer='./metadata.csv')
print(metadata)

One dataframe for each of the three datasets

In [None]:
# One frame per dataset, each re-indexed from 0 so positional lookups work.
ace_metadata = metadata[metadata['dataset'].eq('Ace_20')].reset_index(drop=True)
mat_metadata = metadata[metadata['dataset'].eq('Mat_19')].reset_index(drop=True)
wbc_metadata = metadata[metadata['dataset'].eq('WBC1')].reset_index(drop=True)

## Data curation

In [None]:
# Side length (in pixels) of the square centre crop taken from each
# dataset's images before they are turned into tensors.
crop_Ace20, crop_Mat19, crop_WBC1 = 250, 345, 288

# Dataset identifier -> crop side length.
dataset_image_size = dict(
    Ace_20=crop_Ace20,
    Mat_19=crop_Mat19,
    WBC1=crop_WBC1,
)

In [None]:
# Restrict the working table to the two labelled source domains and
# re-index it so that the split indices below are contiguous.
source_domains = ['Ace_20', 'Mat_19']
example_metadata = metadata
source_index = example_metadata['dataset'].isin(source_domains)
example_metadata = example_metadata[source_index].copy().reset_index(drop=True)

In [None]:
# test_fraction: share of the whole source set held out for testing.
# val_fraction: share of the remaining 0.8 used for validation
#               (0.125 * 0.8 = 0.1 of the whole set).
test_fraction, val_fraction = 0.2, 0.125

In [None]:
# Hold out the test rows. The target handed to the splitter is the
# combined "<label>_<dataset>" string.
# NOTE(review): stratification uses the plain label only, not the combined
# label+dataset string that is passed as the target - confirm this is intended.
_test_split_targets = example_metadata['label'] + "_" + example_metadata['dataset']
train_index, test_index, train_label, test_label = train_test_split(
    example_metadata.index,
    _test_split_targets,
    stratify=example_metadata['label'],
    test_size=test_fraction,
    shuffle=True,
    random_state=0,
)

# Tag the held-out rows; everything else continues to the train/val split.
example_metadata.loc[test_index, 'set'] = 'test'
train_val_metadata = example_metadata.loc[train_index]

In [None]:
# Carve the validation rows out of the remaining train/val pool.
# NOTE(review): as in the test split, stratification uses the plain label
# while the target is "<label>_<dataset>" - confirm this is intended.
_val_split_targets = train_val_metadata['label'] + "_" + train_val_metadata['dataset']
train_index, val_index, train_label, val_label = train_test_split(
    train_val_metadata.index,
    _val_split_targets,
    stratify=train_val_metadata['label'],
    test_size=val_fraction,
    shuffle=True,
    random_state=0,
)
example_metadata.loc[val_index, 'set'] = 'val'

In [None]:
class DatasetGenerator(Dataset):
    """Map-style dataset yielding centre-cropped image tensors and class ids.

    Each item is read from the path in the ``file`` column of ``metadata``,
    scaled to [0, 1], centre-cropped to the size configured for the row's
    ``dataset`` value, converted to a channel-first float tensor and passed
    through ``transform`` (if any).

    Args:
        metadata: DataFrame with at least the columns ``file``, ``label`` and
            ``dataset``. It is copied and re-indexed from 0.
        reshape_size: Stored for reference only; resizing is done by
            ``transform``.
        label_map: Mapping from the ``label`` column values to integer ids.
        dataset: Dataset identifiers this instance covers (stored only).
        transform: Optional callable applied to the image tensor.
        selected_channels: Channel indices to keep; defaults to ``[0, 1, 2]``.
        dataset_image_size: Mapping dataset identifier -> crop side length.
            When ``None`` the notebook-level ``dataset_image_size`` mapping is
            looked up at item-access time (the previous behaviour).
    """

    def __init__(self,
                 metadata,
                 reshape_size=64,
                 label_map=None,
                 dataset=None,
                 transform=None,
                 selected_channels=None,
                 dataset_image_size=None):

        self.metadata = metadata.copy().reset_index(drop=True)
        self.reshape_size = reshape_size
        # None sentinels avoid sharing one mutable default between instances.
        self.label_map = {} if label_map is None else label_map
        self.dataset = [] if dataset is None else dataset
        self.transform = transform
        self.selected_channels = (
            [0, 1, 2] if selected_channels is None else list(selected_channels)
        )
        # Previously this argument was accepted but silently dropped.
        self.dataset_image_size = dataset_image_size

    def __len__(self):
        return len(self.metadata)

    def _crop_size_for(self, dataset):
        """Return the crop side length configured for ``dataset``."""
        crop_sizes = self.dataset_image_size
        if crop_sizes is None:
            # Backward-compatible fallback to the notebook-level mapping.
            crop_sizes = dataset_image_size
        return crop_sizes[dataset]

    @staticmethod
    def _center_crop(image, crop_size):
        """Centre-crop an (H, W, C) array to at most ``crop_size`` per side.

        The crop is clamped to the image extent so that an image smaller than
        ``crop_size`` is returned whole instead of producing a negative start
        index (which would slice from the wrong end of the array).
        """
        height, width = image.shape[0], image.shape[1]
        crop_h = min(crop_size, height)
        crop_w = min(crop_size, width)
        top, bottom = (height - crop_h) // 2, (height + crop_h) // 2
        left, right = (width - crop_w) // 2, (width + crop_w) // 2
        return image[top:bottom, left:right, :]

    def __getitem__(self, idx):
        """Return ``(image, label)`` as a float tensor and a long tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        dataset = self.metadata.loc[idx, "dataset"]
        crop_size = self._crop_size_for(dataset)

        # Read the image, keep the selected channels, scale to [0, 1].
        image_path = self.metadata.loc[idx, "file"]
        image = imread(image_path)[:, :, self.selected_channels]
        image = image / 255.

        image = self._center_crop(image, crop_size)
        # (H, W, C) -> (C, H, W)
        image = np.transpose(image, (2, 0, 1))

        # The transposed crop is a strided view; copy it into its own
        # contiguous buffer before handing it to torch.
        image = torch.from_numpy(np.ascontiguousarray(image)).float()

        if self.transform:
            image = self.transform(image)

        label = self.label_map[self.metadata.loc[idx, "label"]]
        label = torch.tensor(label).long()
        return image.float(), label

In [None]:
# Input geometry and loader settings.
resize = 224            # image pixel size
number_workers = 0

# Parameters for RandomResizedCrop (scale range, aspect-ratio range).
random_crop_scale, random_crop_ratio = (0.8, 1.0), (0.8, 1.2)

# Per-channel normalisation statistics (values from imagenet).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

bs = 32                 # batchsize

In [None]:
# Channel-wise normalisation with the mean/std configured above.
normalization = torchvision.transforms.Normalize(mean,std)

# Training pipeline: normalise, then resize. The augmentations below are
# currently disabled.
train_transform = transforms.Compose([ 
        normalization,
        transforms.Resize(resize)
        #transforms.RandomResizedCrop(resize, scale=random_crop_scale, ratio=random_crop_ratio),
        #transforms.RandomHorizontalFlip(),
        #transforms.RandomVerticalFlip()
        #transforms.RandomEqualize(p=0.8)
])

# Validation pipeline: deterministic normalise + resize.
val_transform = transforms.Compose([ 
        normalization,
        transforms.Resize(resize)])

# Pipeline used for the test dataset and by prediction().
# The second, identical Resize(resize) that used to follow the first one was
# redundant and has been removed.
# NOTE(review): the colour-jitter / blur / sharpness steps are random and run
# on already-normalised tensors, so evaluation through this pipeline is not
# deterministic - confirm that augmenting at test time is intended.
test_transform = transforms.Compose([ 
        normalization,
        transforms.Resize(resize),
        transforms.RandomApply(torch.nn.ModuleList([
            transforms.ColorJitter(brightness=0.5, hue=0.1),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.5, 1))
        ]), p=0.5),
        transforms.RandomAdjustSharpness(sharpness_factor=0.2, p=0.5)])

#dataset-creation

# Arguments shared by the three split datasets; only the rows and the
# transform differ between them.
_shared_dataset_kwargs = dict(
    reshape_size=resize,
    dataset=source_domains,
    label_map=label_map_all,
)

train_dataset = DatasetGenerator(
    example_metadata.loc[train_index, :],
    transform=train_transform,
    **_shared_dataset_kwargs,
)
val_dataset = DatasetGenerator(
    example_metadata.loc[val_index, :],
    transform=val_transform,
    **_shared_dataset_kwargs,
)
test_dataset = DatasetGenerator(
    example_metadata.loc[test_index, :],
    transform=test_transform,
    **_shared_dataset_kwargs,
)
# Each loader wraps the dataset of the same split. Previously train_loader
# was built from test_dataset and test_loader from train_dataset, so the
# model was optimised on the held-out rows that are later used for the
# per-dataset evaluation.
train_loader = DataLoader(
    train_dataset, batch_size=bs, shuffle=True, num_workers=number_workers)
valid_loader = DataLoader(
    val_dataset, batch_size=bs, shuffle=True, num_workers=number_workers)
test_loader = DataLoader(
    test_dataset, batch_size=bs, shuffle=False, num_workers=number_workers)

In [None]:
# epochs: max number of epochs; lr: learning rate handed to the optimizer.
epochs, lr = 10, 1e-5

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the classifier from res_model with one output per class and move it
# to the selected device.
num_classes = 11
model = EqRes(n_class=num_classes, n_filter=32, n_rot=2)
#model = torch.nn.DataParallel(model) 
model.to(device)

In [None]:
# Loss, optimizer and a one-cycle learning-rate schedule that is stepped once
# per training batch.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

# NOTE(review): the schedule is sized for epochs+1 epochs while the training
# loop runs `epochs` epochs - confirm the extra epoch is intentional.
_batches_per_epoch = len(train_loader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-2,
    epochs=epochs+1,
    steps_per_epoch=_batches_per_epoch,
    cycle_momentum=False,
)

In [None]:
model_save_path='theta2_filter32_fullaug' #path where model with best f1_macro should be stored

#running variables
epoch=0
update_frequency=5 # number of batches before viewed acc and loss get updated
counter=0 #counts batches
f1_macro_best=0 #minimum f1_macro_score of the validation set for the first model to be saved
loss_running=0
acc_running=0 #number of correctly classified validation samples in the current epoch
val_batches=0 #number of validation samples seen in the current epoch

# Predictions / targets of the current validation pass, kept on `device`.
y_pred=torch.tensor([], dtype=torch.long, device=device)
y_true=torch.tensor([], dtype=torch.long, device=device)


#Training

for epoch in range(0, epochs):
    #training
    model.train()

    with tqdm(train_loader) as tepoch:
        for i, data in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch+1}")
            counter+=1

            x, y = data
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # argmax over the raw outputs equals argmax over their softmax.
            predictions = out.detach().argmax(dim=1)
            # Plain Python float; the former `accuracy_score(...).item()`
            # fails when accuracy_score returns a built-in float.
            acc = (predictions == y).float().mean().item()

            if counter >= update_frequency:
                tepoch.set_postfix(loss=loss.item(), accuracy=acc)
                counter=0

    #validation
    model.eval()
    # no_grad: validation needs no autograd graph, which saves memory.
    with torch.no_grad(), tqdm(valid_loader) as vepoch:
        for i, data in enumerate(vepoch):
            vepoch.set_description(f"Validation {epoch+1}")

            x, y = data
            x, y = x.to(device), y.to(device)

            out = model(x)
            loss = criterion(out, y)

            predictions = out.argmax(dim=1)
            y_pred=torch.cat((y_pred, predictions), 0)
            y_true=torch.cat((y_true, y), 0)

            # Accumulate per-batch results weighted by batch size. Previously
            # the *cumulative* accuracy was re-weighted and averaged again,
            # so the displayed value was not the validation accuracy.
            loss_running+=(loss.item()*len(y))
            acc_running+=(predictions == y).sum().item()
            val_batches+=len(y)
            loss_mean=loss_running/val_batches
            acc_mean=acc_running/val_batches

            vepoch.set_postfix(loss=loss_mean, accuracy=acc_mean)

    f1_micro=f1_score(y_true.cpu(), y_pred.cpu(), average='micro')
    f1_macro=f1_score(y_true.cpu(), y_pred.cpu(), average='macro')
    print(f'f1_micro: {f1_micro}, f1_macro: {f1_macro}')
    # Keep only the checkpoint with the best validation macro-F1 so far.
    if f1_macro > f1_macro_best:
        f1_macro_best=f1_macro
        torch.save(model.state_dict(), model_save_path)
        print('model saved')

    #reseting running variables
    loss_running=0
    acc_running=0
    val_batches=0

    y_pred=torch.tensor([], dtype=torch.long, device=device)
    y_true=torch.tensor([], dtype=torch.long, device=device)


print('Finished Training')

In [None]:
# Restore the best checkpoint written by the training loop. The file name
# comes from `model_save_path` (the loop never writes
# 'theta2_filter32_fullaug_new'), and map_location lets a checkpoint saved on
# GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))

In [None]:
# Held-out test rows, plus one re-indexed frame per source dataset.
metadata_test = example_metadata.loc[test_index, :]
ace_metadata_test = metadata_test[metadata_test['dataset'].eq('Ace_20')].reset_index(drop=True)
mat_metadata_test = metadata_test[metadata_test['dataset'].eq('Mat_19')].reset_index(drop=True)

In [None]:
def prediction(metadata=metadata_test, 
               source_domains=['Ace_20', 'Mat_19'], label_map=label_map_all):
    """Run the global ``model`` over ``metadata`` and return its predictions.

    Args:
        metadata: DataFrame describing the images to classify (needs the
            columns read by ``DatasetGenerator`` plus ``label``).
        source_domains: Dataset identifiers forwarded to ``DatasetGenerator``.
        label_map: Mapping used by the dataset to encode the ``label`` column.

    Returns:
        Tuple ``(y_true, y_pred, preds)``: the ``label`` column of
        ``metadata``, the predicted class names and the predicted class ids
        as a numpy array. The ids are also written to ``preds.npy``.

    NOTE(review): the dataset is built with ``test_transform``, which contains
    random augmentations, so repeated calls may return different predictions.
    """
    pred_dataset = DatasetGenerator(metadata, 
                                 reshape_size=resize, 
                                 dataset = source_domains,
                                 label_map=label_map, 
                                 transform = test_transform,
                                 )

    pred_loader = DataLoader(pred_dataset, 
                             batch_size=1, 
                             shuffle=False, 
                             num_workers=6
                            )

    # Class id -> class name. The model's output indices follow
    # label_map_all, so the reverse map is derived from it rather than from a
    # `label_map_reverse` global that is not defined in this notebook.
    id_to_label = {class_id: name for name, class_id in label_map_all.items()}

    model.eval()
    batch_preds = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for x, _ in pred_loader:
            out = model(x.to(device))
            # argmax over the raw outputs equals argmax over their softmax.
            batch_preds.append(out.argmax(dim=1).cpu())

    # Concatenate once instead of growing a tensor on every sample.
    if batch_preds:
        preds = torch.cat(batch_preds, 0).numpy()
    else:
        preds = np.empty(0, dtype=np.int64)
    np.save('preds', preds)

    y_pred = [id_to_label[int(p)] for p in preds]
    y_true=metadata['label']
    return y_true, y_pred, preds

def classification_complete_report(y_true, y_pred ,labels = None  ): 
    """Print classification metrics and plot a row-normalised confusion matrix.

    Args:
        y_true: Ground-truth class names.
        y_pred: Predicted class names.
        labels: Optional list of class names fixing which classes are reported
            and in what order; used for both the text report and the matrix.
    """
    # `labels` is now forwarded; it used to be replaced by a literal None, so
    # the text report ignored the requested class list and order.
    print(classification_report(y_true, y_pred, labels = labels))
    print(15*"----")
    print("matthews correlation coeff: %.4f" % (matthews_corrcoef(y_true, y_pred)) )
    print("Cohen Kappa score: %.4f" % (cohen_kappa_score(y_true, y_pred)) )
    print("Accuracy: %.4f & balanced Accuracy: %.4f" % (accuracy_score(y_true, y_pred), balanced_accuracy_score(y_true, y_pred)) )
    #print("macro F1 score: %.4f & micro F1 score: %.4f" % (f1_score(y_true, y_pred, average = "macro"), f1_score(y_true, y_pred, average = "micro")) )
    print("macro Precision score: %.4f & micro Precision score: %.4f" % (precision_score(y_true, y_pred, average = "macro"), precision_score(y_true, y_pred, average = "micro")) )
    print("macro Recall score: %.4f & micro Recall score: %.4f" % (recall_score(y_true, y_pred, average = "macro"), recall_score(y_true, y_pred, average = "micro")) )
    print(labels)
    # normalize='true' scales each row (true class) to sum to 1.
    cm = confusion_matrix(y_true, y_pred,labels= labels, normalize='true')
    fig, ax = plt.subplots(figsize=(10, 10)) #plot size
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation='vertical', ax=ax, include_values=False, colorbar=False)
    
    plt.show()
    print(15*"----")

In [None]:
# Evaluate on the held-out Ace_20 rows. The class list is taken from
# label_map_all (insertion order = class id order) because `label_list_all`
# is not defined anywhere in this notebook.
y_true, y_pred, preds=prediction(metadata= ace_metadata_test, source_domains=['Ace_20'])
classification_complete_report(y_true, y_pred, labels=list(label_map_all))

In [None]:
# Evaluate on the held-out Mat_19 rows. The class list is taken from
# label_map_all (insertion order = class id order) because `label_list_all`
# is not defined anywhere in this notebook.
y_true, y_pred, preds=prediction(metadata= mat_metadata_test, source_domains=['Mat_19'])
classification_complete_report(y_true, y_pred, labels=list(label_map_all))

In [None]:
# Predict the classes of the WBC1 images; y_pred and preds feed the
# submission table built in the next cell.
# NOTE(review): `label_map_pred` is not defined in any visible cell - it must
# map every value of wbc_metadata['label'] to an integer; confirm where it
# comes from, otherwise this cell fails on a fresh kernel.
y_true, y_pred, preds=prediction(metadata=wbc_metadata, source_domains=['WBC1'], label_map=label_map_pred)

In [None]:
# Build the submission table: strip the bookkeeping columns from the WBC1
# metadata, attach the predicted class name and class id, then write it out.
_submission_dropped_columns = ['file', 'label', 'dataset', 'set', 'mean1', 'mean2', 'mean3']
outputdata = (
    wbc_metadata
    .drop(columns=_submission_dropped_columns)
    .assign(Label=y_pred, LabelID=preds)
)
outputdata.to_csv('submission_new2_aug.csv')
print(outputdata)