In [None]:
import os
import random
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.autograd import Variable

from matplotlib import pyplot as plt

#%matplotlib inline

In [None]:
# All input CSVs live under one directory; change DATA_DIR to point at another copy.
DATA_DIR = "data_HC"
train_data_file = DATA_DIR + "/menon_train_data.csv"
train_label_file = DATA_DIR + "/menon_train_label.csv"
valid_data_file = DATA_DIR + "/menon_validation_data.csv"
valid_label_file = DATA_DIR + "/menon_validation_label.csv"
test_data_file = DATA_DIR + "/lukowski_test_data.csv"


In [None]:
# Class hierarchy table (one column per level); blank cells are filled with "Unknown".
classes_tree = (
    pd.read_csv("data_HC/classes_tree.csv")
    .fillna("Unknown")
)
classes_tree

In [None]:
# Number of columns in the class table, i.e. hierarchy levels.
n_levels = classes_tree.shape[1]

BATCH_SIZE = 1024
N_EPOCHS = 60
LR = 2e-5

# Seed every RNG used later (weight init, random_split, DataLoader shuffling) so that
# "Restart & Run All" reproduces the same split and initial weights. GPU kernels may
# still introduce small nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def BuildTree(classes_tree):
    """Convert the level-by-level class table into a nested-dict tree.

    ``classes_tree`` must have columns 'Root', 'Level 1', ..., 'Level k' and uses
    "Unknown" for empty cells. Each non-"Unknown" entry at depth j is inserted
    under the most recent entries seen at depths 0..j-1. The table may therefore
    either spell out a full path per row, or list one node per row in outline
    form. Leaves map to empty dicts.

    Returns:
        dict: {name: {child_name: {...}, ...}} starting from the root level.
    """
    n_levels = classes_tree.shape[1]
    lv_name = ['Level ' + str(i) if i != 0 else 'Root' for i in range(n_levels)]
    label_tree = dict()
    curr_path = []
    # Iterate rows positionally, so a non-default index does not matter.
    for row in classes_tree[lv_name].itertuples(index=False):
        for depth, class_name in enumerate(row):
            if class_name == 'Unknown':
                continue
            # Keep only the ancestors above this depth, then walk down to the parent.
            curr_path = curr_path[:depth]
            curr_node = label_tree
            for ancestor in curr_path:
                curr_node = curr_node[ancestor]
            # setdefault keeps any subtree already built under this name instead of
            # replacing it with an empty dict when the name appears again.
            curr_node.setdefault(class_name, dict())
            curr_path.append(class_name)
    return label_tree
# Nested dict form of the hierarchy, used to build the classifier tree.
label_tree = BuildTree(classes_tree=classes_tree)

In [None]:
# Display the nested class hierarchy built by BuildTree.
label_tree

In [None]:
def Get_level_label(n_levels = n_levels, classes_tree = classes_tree):
    """Build a one-hot encoding of the class names found at each hierarchy level.

    Args:
        n_levels: number of leading columns of ``classes_tree`` to encode.
        classes_tree: class table, one column per level.

    Returns:
        dict: {level_index: {class_name: one_hot_vector}}. Names are the sorted
        unique values of that column, which may include the "Unknown" filler.
    """
    level_label = dict()
    for i in range(n_levels):
        labels = np.unique(classes_tree.iloc[:, i])
        # Build the identity matrix once per level instead of once per class.
        eye = np.eye(labels.shape[0])
        level_label[i] = {labels[j]: eye[j] for j in range(labels.shape[0])}
    return level_label

# One-hot encodings per level (display only; not used by later cells).
level_label = Get_level_label(n_levels=n_levels, classes_tree=classes_tree)
level_label

In [None]:
# Feature matrices and label tables, each indexed by its first CSV column.
# NOTE(review): fillna("Unknown") is also applied to the feature matrices. Any NaN
# there would become a string and break the later float32 conversion, so
# confirm the data files contain no missing values.
train_data_df, train_label_df, valid_data_df, valid_label_df, test_data_df = (
    pd.read_csv(path, index_col=0).fillna("Unknown")
    for path in (
        train_data_file,
        train_label_file,
        valid_data_file,
        valid_label_file,
        test_data_file,
    )
)

In [None]:
# Shape of the training label table: (samples, label columns).
print(train_label_df.shape)

In [None]:
# Per-column mean of the training matrix. It is not referenced by later cells;
# CellDataset computes its own statistics.
mean = train_data_df.mean(axis=0)

In [None]:
class DenseCrossEntropy(nn.Module):
    """Cross-entropy against dense (one-hot or soft) target vectors.

    ``forward(logits, labels)`` returns the unreduced per-sample loss
    ``-sum_k labels[..., k] * log_softmax(logits)[..., k]``, computed in float32.
    Rows whose label vector is all zeros contribute 0. Weighting and reduction
    are left to the caller.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        log_probs = F.log_softmax(logits.float(), dim=-1)
        return -(labels.float() * log_probs).sum(dim=-1)

In [None]:
class CellModel(nn.Module):
    """Shared MLP encoder: input_size -> 1024 -> 512 -> 512 -> output_size.

    Every layer is followed by ReLU, so the returned embedding is non-negative.

    Args:
        input_size: number of input features per sample.
        output_size: embedding width. It must equal the input width of the node
            classifiers that consume it (256 in ClassifierNode).
    """
    def __init__(self, input_size=2000, output_size=256):
        super().__init__()
        # Kaiming (He) init for the ReLU layers. Each layer is initialised exactly once.
        self.fc1 = nn.Linear(input_size, 1024)
        nn.init.kaiming_normal_(self.fc1.weight)
        self.fc2 = nn.Linear(1024, 512)
        nn.init.kaiming_normal_(self.fc2.weight)
        self.fc3 = nn.Linear(512, 512)
        nn.init.kaiming_normal_(self.fc3.weight)
        # Honour output_size, which was previously ignored in favour of a literal 256.
        self.fc4 = nn.Linear(512, output_size)
        nn.init.kaiming_normal_(self.fc4.weight)

    def forward(self, x):
        # squeeze(-1) is a no-op unless the last axis has size 1.
        x = x.squeeze(-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.relu(self.fc4(x))

In [None]:
class ClassifierNode:
    """A node of the label hierarchy that holds a linear classifier over its children.

    After ``BuildClassifier``, an internal node has an extra "Unknown <name>" child
    and an ``nn.Linear`` producing one logit per child. A leaf has
    ``classifier = None``. GetClfList, GetClfName and parameters all traverse
    pre-order and visit children in insertion order, so their outputs line up
    index by index.
    """
    def __init__(self, name):
        self.name = name
        self.children = list()
        # child name -> index into self.children (and into the classifier's logits)
        self.children_name = dict()

    def AddChild(self, child_node):
        """Append ``child_node`` and record its position under its name."""
        self.children_name[child_node.name] = len(self.children)
        self.children.append(child_node)

    def BuildClassifier(self, in_features=256):
        """Finalise this node once all real children have been added.

        Internal nodes get an "Unknown <name>" leaf child plus an
        ``nn.Linear(in_features, n_children)``. Leaves get ``classifier = None``.
        ``in_features`` must equal the width of the shared embedding.
        """
        if self.children:
            unknown_node = ClassifierNode('Unknown ' + self.name)
            unknown_node.BuildClassifier(in_features)
            self.AddChild(unknown_node)
            self.classifier = nn.Linear(in_features, len(self.children))
        else:
            self.classifier = None

    def GetClfList(self):
        """Return this subtree's classifiers in pre-order (empty list for a leaf)."""
        if self.classifier is None:
            return list()
        clf_list = [self.classifier]
        for child in self.children:
            clf_list += child.GetClfList()
        return clf_list

    def GetClfName(self):
        """Return, in GetClfList order, the list of child names each classifier predicts."""
        if self.classifier is None:
            return list()
        name_list = [list(self.children_name.keys())]
        for child in self.children:
            name_list += child.GetClfName()
        return name_list

    def to(self, device):
        """Move every classifier in this subtree to ``device``; returns self like nn.Module.to."""
        if self.classifier is not None:
            self.classifier.to(device)
            for child in self.children:
                child.to(device)
        return self

    def parameters(self):
        """Yield the parameters of every classifier in this subtree, in pre-order."""
        if self.classifier is not None:
            yield from self.classifier.parameters()
            for child in self.children:
                yield from child.parameters()

    def classify(self, x):
        """Route each row of embedding ``x`` down the tree by argmax.

        Returns:
            numpy object array of length ``len(x)``. Each entry is a string of
            comma-terminated child names, one per level visited, e.g. "A,B,".
            Leaves contribute "".
        """
        types = np.full(len(x), '', dtype=object)
        if self.classifier is None:
            return types
        with torch.no_grad():  # inference only; do not build an autograd graph
            choice = self.classifier(x).argmax(dim=1).cpu().numpy()
        for name, idx in self.children_name.items():
            mask = choice == idx
            types[mask] = name + ','
            types[mask] += self.children[idx].classify(x[mask])
        return types



def BuildClassifiers(node_name, tree):
    """Recursively turn a nested-dict label tree into a ClassifierNode tree.

    ``tree`` maps each child name to that child's own subtree dict (empty for a
    leaf). Child subtrees are built first, in dict order. This node's classifier
    is created last.
    """
    node = ClassifierNode(node_name)
    subtrees = [BuildClassifiers(child_name, child_tree)
                for child_name, child_tree in tree.items()]
    for subtree in subtrees:
        node.AddChild(subtree)
    node.BuildClassifier()
    return node

In [None]:
# Build the classifier tree from the first top-level entry of label_tree.
# Any additional top-level entries are ignored.
root = BuildClassifiers(list(label_tree.keys())[0], list(label_tree.values())[0])

In [None]:
# Class names predicted by each node classifier, in the same order as root.GetClfList().
root.GetClfName()

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CellModel().to(device)
root.to(device)
criterion = DenseCrossEntropy()
# Two parameter groups: the shared encoder and all node classifiers.
plist = [
    {'params': model.parameters(), 'lr': LR},
    {'params': root.parameters(), 'lr': LR}
]
optimizer = optim.Adam(plist, lr=LR)

In [None]:
# Feature statistics of the most recently constructed CellDataset, so a later
# (evaluation) dataset can be given the training statistics.
MEAN = 0
STD = 0
class CellDataset(Dataset):
    """Feature rows paired with per-classifier dense targets and weights.

    Args:
        data_df: numeric feature table, one row per sample.
        label_df: label table with one column per hierarchy level; "Unknown" marks
            a level with no known class.
        clf_target: per-classifier lists of class names (as from root.GetClfName()).
        is_train: if True, compute column mean/std from ``data_df``; otherwise use
            the given ``mean``/``std``.

    Each item is ``(data, [labels, weight])``. ``labels[i]`` is a one-hot vector over
    ``clf_target[i]``. ``weight[i]`` is 1 when one of the sample's labels belongs
    to classifier i, and 0 otherwise. An "Unknown" label directly below a known one
    is mapped to "Unknown <parent>".
    """
    def __init__(self, data_df, label_df, clf_target, is_train = True, mean = None, std = None):
        # Without this declaration the assignments below would bind locals and the
        # module-level MEAN/STD would never be updated.
        global MEAN, STD
        self.data_df = data_df
        self.label_df = label_df
        self.clf_target = clf_target
        if is_train:
            self.mean = np.mean(data_df, axis = 0)
            self.std = np.std(data_df, axis = 0)
        else:
            self.mean = mean
            self.std = std
        # Normalisation is currently disabled; the statistics are only recorded.
        #self.data_df = (self.data_df-self.mean)/(self.std+1e-12)
        MEAN = self.mean
        STD = self.std

    def __len__(self):
        return self.data_df.shape[0]

    def __getitem__(self, idx):
        data = self.data_df.iloc[idx, :].astype(np.float32)
        data = torch.from_numpy(np.array(data))
        str_label = self.label_df.iloc[idx, :]
        labels = [np.zeros(len(names), dtype=np.float32) for names in self.clf_target]
        weight = np.zeros(len(self.clf_target), dtype=np.float32)
        for lb_id, lb_name in enumerate(str_label):
            if lb_name == 'Unknown' and lb_id > 0:
                # Positional access: the Series is indexed by column names, not integers.
                parent = str_label.iloc[lb_id - 1]
                if parent != 'Unknown':
                    lb_name = 'Unknown ' + parent
            for clf_id, names in enumerate(self.clf_target):
                if lb_name in names:
                    labels[clf_id][names.index(lb_name)] = 1.
                    weight[clf_id] = 1.

        labels = [torch.from_numpy(lb) for lb in labels]
        weight = torch.from_numpy(weight)

        return data, [labels, weight]

In [None]:
dataset_train = CellDataset(
    train_data_df, train_label_df, root.GetClfName(), is_train=True
)
# Randomly hold out 1024 training samples to monitor accuracy during training.
training_data, validation_data = random_split(
    dataset_train, [len(dataset_train) - 1024, 1024]
)
dataloader_train = DataLoader(
    training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
dataloader_valid = DataLoader(
    validation_data, batch_size=1024, shuffle=True, num_workers=0
)

In [None]:
# Show the layers of the shared CellModel.
model

In [None]:
# Show the node classifiers that are optimised together with the model.
root.GetClfList()

In [None]:
def train(model = model, classifiers = root, criterion = criterion, optimizer = optimizer,
          dataloader = dataloader_train, validloader = dataloader_valid, n_epochs = None):
    """Jointly train the shared encoder and all node classifiers.

    Each epoch makes one optimisation pass over ``dataloader``. It then prints each
    classifier's accuracy on the ``validloader`` samples that belong to that
    classifier.

    Args:
        model: shared encoder whose output feeds every classifier.
        classifiers: root ClassifierNode. ``GetClfList()`` gives the classifiers in
            the same order as the per-classifier labels/weights built by CellDataset.
        criterion: per-sample loss, called as ``criterion(logits, dense_labels)``.
        optimizer: optimiser over the encoder and classifier parameters.
        dataloader: training batches of ``(data, [labels, weight])``.
        validloader: validation batches with the same structure.
        n_epochs: number of epochs; ``None`` means the global ``N_EPOCHS``.
    """
    clf_list = classifiers.GetClfList()
    if n_epochs is None:
        n_epochs = N_EPOCHS

    def _set_mode(training):
        # ClassifierNode is not an nn.Module, so each classifier is switched explicitly.
        model.train(training)
        for clf in clf_list:
            clf.train(training)

    def _unpack(batch):
        # Move inputs, per-classifier dense labels and the (batch, n_clf) weights to `device`.
        labels, weight = batch[1]
        return batch[0].to(device), [lb.to(device) for lb in labels], weight.to(device)

    def _batch_loss(output, labels, weight):
        # Summed weighted loss over all classifiers. weight[:, i] has shape (batch,)
        # like criterion's per-sample output, so the product is elementwise. Slicing
        # with i:i+1 instead would broadcast to a (batch, batch) matrix.
        total = 0
        for i, clf in enumerate(clf_list):
            total = total + weight[:, i] * criterion(clf(output), labels[i].squeeze(-1))
        return total.sum()

    def _evaluate():
        # Per classifier: (accuracy, count) over samples with non-zero weight, i.e.
        # samples whose label is one of that classifier's classes; nan if none.
        hits = [0] * len(clf_list)
        counts = [0] * len(clf_list)
        with torch.no_grad():
            for batch in validloader:
                data, labels, weight = _unpack(batch)
                output = model(data)
                for i, clf in enumerate(clf_list):
                    mask = weight[:, i] > 0
                    correct = clf(output).argmax(dim=1) == labels[i].argmax(dim=1)
                    hits[i] += int(correct[mask].sum())
                    counts[i] += int(mask.sum())
        return [(h / n if n else float('nan'), n) for h, n in zip(hits, counts)]

    for epoch in range(n_epochs):
        _set_mode(True)
        for step, batch in enumerate(dataloader):
            data, labels, weight = _unpack(batch)
            loss = _batch_loss(model(data), labels, weight)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("Epoch %d, step %d, loss = %.4f"%(epoch, step, loss.item()))

        _set_mode(False)
        for i, (accu, n_sample) in enumerate(_evaluate()):
            print("Epoch %d, clf%d, num_sample = %d, accu = %.4f"%(epoch, i+1, n_sample, accu))
        print("------------------------------------------------------------------")

In [None]:
# Train using the defaults that were bound when train() was defined.
train()

In [None]:
# Separate validation files. mean/std are taken from the module-level MEAN/STD,
# which are intended to hold the training statistics.
dataset_valid = CellDataset(
    valid_data_df,
    valid_label_df,
    root.GetClfName(),
    is_train=False,
    mean=MEAN,
    std=STD,
)
dataloader_valid = DataLoader(
    dataset_valid, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

In [None]:
def valid(model = model, classifiers = root, criterion = criterion, dataloader = dataloader_valid):
    """Print each node classifier's accuracy on ``dataloader``.

    Accuracy for classifier i counts only samples with non-zero weight for i, i.e.
    samples whose label is one of that classifier's classes. ``nan`` is printed
    when there are none. ``criterion`` is accepted for call compatibility but is
    not used. The former "auc" column was a hard-coded placeholder (always 1) and
    is no longer printed.
    """
    clf_list = classifiers.GetClfList()
    model.eval()
    for clf in clf_list:
        clf.eval()

    hits = [0] * len(clf_list)
    counts = [0] * len(clf_list)
    with torch.no_grad():
        for batch in dataloader:
            labels, weight = batch[1]
            output = model(batch[0].to(device))
            for i, clf in enumerate(clf_list):
                # Labels and weights are not moved to `device` here, so compare on CPU.
                mask = weight[:, i] > 0
                pred = clf(output).argmax(dim=1).cpu()
                correct = pred == labels[i].argmax(dim=1)
                hits[i] += int(correct[mask].sum())
                counts[i] += int(mask.sum())

    for i, (h, n) in enumerate(zip(hits, counts)):
        accu = h / n if n else float('nan')
        print("clf%d, num_sample = %d, accu = %.4f"%(i+1, n, accu))

In [None]:
# Evaluate on the separate validation files (the dataloader_valid bound when valid() was defined).
valid()

In [None]:
# Route every validation sample down the classifier tree and compare the
# predicted path with the ground-truth label columns.
model.eval()
with torch.no_grad():  # inference only: no autograd graph for the whole set
    valid_data_torch = torch.tensor(valid_data_df.values, dtype=torch.float32).to(device)
    out = model(valid_data_torch)
    clf_ref = root.classify(out)

# classify() yields "name1,name2,...," strings. Split them into per-level names,
# collapse every "Unknown <parent>" placeholder to "Unknown", and pad/truncate to
# the number of label columns so the rows align with valid_label_df.
clf_ref = [x.strip(',').split(',') for x in clf_ref]
clf_ref = np.array(
    [([s if not s.startswith('Unknown') else 'Unknown' for s in x]
      + ['Unknown'] * valid_label_df.shape[1])[:valid_label_df.shape[1]]
     for x in clf_ref],
    dtype=object,
)

# A sample counts as an error if any level disagrees with its label.
err_rate = np.mean((clf_ref != valid_label_df.values).any(axis=1))
print("Total err rate on validation set is %.4f"%(err_rate))

np.savetxt('validset_result.csv', clf_ref, fmt='%s', delimiter=',')

In [None]:
# Route every test sample down the classifier tree and save the predicted path.
model.eval()
with torch.no_grad():  # inference only: no autograd graph for the whole set
    test_data_torch = torch.tensor(test_data_df.values, dtype=torch.float32).to(device)
    out = model(test_data_torch)
    clf_ref = root.classify(out)

# Split "name1,name2,...," into per-level names, collapse "Unknown <parent>" to
# "Unknown", and pad/truncate to the same number of columns as the label files.
clf_ref = [x.strip(',').split(',') for x in clf_ref]
clf_ref = np.array(
    [([s if not s.startswith('Unknown') else 'Unknown' for s in x]
      + ['Unknown'] * train_label_df.shape[1])[:train_label_df.shape[1]]
     for x in clf_ref],
    dtype=object,
)

np.savetxt('testset_result.csv', clf_ref, fmt='%s', delimiter=',')