# Imports

In [11]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from scipy.spatial import KDTree
import pandas as pd
import numpy as np
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
import math
from itertools import compress, combinations, product
from collections import Counter
from sklearn.metrics import balanced_accuracy_score
from statistics import mean
import matplotlib.pyplot as plt
import warnings

In [35]:
import os
import fastprogress
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader


from IPython.display import Markdown, display
def printmd(string):
    """Render ``string`` as Markdown in the notebook output."""
    rendered = Markdown(string)
    display(rendered)


# checks for GPU
def get_device(cuda_preference=True):
    """Report CUDA/cuDNN availability and return the torch device to use.

    :param cuda_preference: when True, use ``cuda:0`` if CUDA is available;
        when False, always use the CPU
    :return: the selected ``torch.device``
    """
    cuda_ok = torch.cuda.is_available()
    print('cuda available:', cuda_ok,
          '; cudnn available:', torch.backends.cudnn.is_available(),
          '; num devices:', torch.cuda.device_count())

    if cuda_preference and cuda_ok:
        device = torch.device('cuda:0')
        device_name = torch.cuda.get_device_name(device)
    else:
        device = torch.device('cpu')
        device_name = 'cpu'

    print('Using device', device_name)
    return device

# trains target domain network
def train_target(X_target, y_target, source_emb, optimizer, model, loss_fn,
                 device, num_epochs, master_bar):
    """Train the target network with full-batch steps and early stopping.

    Each epoch runs one forward pass on the whole of ``X_target``, evaluates
    ``loss_fn(target_emb, source_emb, output, y_target)`` and takes one
    optimizer step. The learning rate is reduced by a factor of 0.1 (down to
    1e-4) when the total loss plateaus for 10 epochs, and training stops once
    the total loss has failed to improve for more than 50 consecutive epochs.

    :param X_target: input tensor for the target network
    :param y_target: label tensor passed to ``loss_fn``
    :param source_emb: source embeddings passed to ``loss_fn``
    :param optimizer: optimizer over ``model``'s parameters
    :param model: network returning ``(embedding, output)``
    :param loss_fn: callable returning ``(loss_total, loss_Reg, loss_CE)``
    :param device: device the inputs are moved to
    :param num_epochs: maximum number of epochs (must be >= 1)
    :param master_bar: parent fastprogress bar
    :return: embeddings of ``X_target`` computed in eval mode after training,
        followed by the three loss tensors of the last executed forward pass
    :raises ValueError: if ``num_epochs`` is smaller than 1
    """
    if num_epochs < 1:
        # Without a single epoch there are no loss values to return.
        raise ValueError("num_epochs must be at least 1")

    patience = 50
    stale_epochs = 0
    best_loss = math.inf
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1,
                                                     patience=10, min_lr=0.0001)

    # The inputs do not change between epochs, so move them to the device once.
    X_target = X_target.to(device)
    y_target = y_target.to(device)

    model.train()
    for epoch in fastprogress.progress_bar(range(num_epochs), parent=master_bar):
        optimizer.zero_grad()
        target_emb, output = model(X_target)
        loss_total, loss_Reg, loss_CE = loss_fn(target_emb, source_emb,
                                                output, y_target)

        # Track the best loss as a plain float so no autograd graph is kept
        # alive through the early-stopping bookkeeping.
        current_loss = loss_total.item()
        if current_loss >= best_loss:
            stale_epochs += 1
        else:
            best_loss = current_loss
            stale_epochs = 0
        if stale_epochs > patience:
            break

        scheduler.step(current_loss)
        loss_total.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        target_emb, output = model(X_target)
    return target_emb, loss_total, loss_Reg, loss_CE

# trains source domain network
def train_source(X_source, target_emb, optimizer, model, loss_fn, num_epochs,
                 device, master_bar):
    """Train the source network with full-batch steps and early stopping.

    Each epoch embeds the whole of ``X_source``, evaluates
    ``loss_fn(target_emb, source_emb)`` and takes one optimizer step. The
    learning rate is reduced by a factor of 0.1 (down to 1e-4) when the loss
    plateaus for 10 epochs, and training stops once the loss has failed to
    improve for more than 50 consecutive epochs.

    :param X_source: input tensor for the source network
    :param target_emb: target embeddings the source embeddings are fitted to
    :param optimizer: optimizer over ``model``'s parameters
    :param model: network returning the source embedding
    :param loss_fn: callable taking ``(target_emb, source_emb)``
    :param num_epochs: maximum number of epochs (must be >= 1)
    :param device: device the input is moved to
    :param master_bar: parent fastprogress bar
    :return: embeddings of ``X_source`` computed in eval mode after training,
        and the loss tensor of the last executed forward pass
    :raises ValueError: if ``num_epochs`` is smaller than 1
    """
    if num_epochs < 1:
        # Without a single epoch there is no loss value to return.
        raise ValueError("num_epochs must be at least 1")

    patience = 50
    stale_epochs = 0
    best_loss = math.inf
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1,
                                                     patience=10, min_lr=0.0001)

    # The input does not change between epochs, so move it to the device once.
    X_source = X_source.to(device)

    model.train()
    for epoch in fastprogress.progress_bar(range(num_epochs), parent=master_bar):
        optimizer.zero_grad()
        source_emb = model(X_source)
        loss = loss_fn(target_emb, source_emb)

        # Plain float for the bookkeeping, so no autograd graph is retained.
        current_loss = loss.item()
        if current_loss >= best_loss:
            stale_epochs += 1
        else:
            best_loss = current_loss
            stale_epochs = 0
        if stale_epochs > patience:
            break

        scheduler.step(current_loss)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        source_emb = model(X_source)
    return source_emb, loss

# predicts class
def predict_class(model_target, model_source, X_test, avg_source, device=None):
    """Assign each test sample to its nearest class in embedding space.

    ``avg_source`` is embedded with ``model_source`` and ``X_test`` with
    ``model_target``; each target embedding is then matched to its single
    nearest source embedding with a KD-tree.

    :param model_target: network returning ``(embedding, output)``
    :param model_source: network returning the source embedding
    :param X_test: tensor of samples to classify
    :param avg_source: tensor with one row per candidate class
    :param device: device to run on; when None (default) each input is moved
        to the device of the corresponding model's parameters, so the function
        no longer depends on a module-level ``device`` variable
    :return: ``(distances, predictions)``, where ``predictions`` holds, for
        each test sample, the row index into ``avg_source`` of its nearest
        class and ``distances`` the distance to it
    """
    def _to_model_device(model, tensor):
        # Explicit device wins; otherwise follow the model's own parameters.
        if device is not None:
            return tensor.to(device)
        first_param = next(model.parameters(), None)
        if first_param is None:
            # Parameter-free model: leave the tensor where it is.
            return tensor
        return tensor.to(first_param.device)

    model_source.eval()
    model_target.eval()
    with torch.no_grad():
        source_emb = model_source(_to_model_device(model_source, avg_source))
        target_emb, output = model_target(_to_model_device(model_target, X_test))

    tree = KDTree(source_emb.cpu().numpy())
    distances, predictions = tree.query(target_emb.cpu().numpy(), k=1)
    return distances, predictions

# runs training
def run_training(model_target,  loss_fn_target, model_source, loss_fn_source,
                 lr, eta, X_target, X_source, y_target, device, num_iteration,
                 num_epochs, verbose=True):
    """Alternate target- and source-network training for ``num_iteration`` rounds.

    The source embeddings are first computed once in eval mode. Every round
    creates fresh Adam optimizers (learning rate ``lr``, weight decay ``eta``)
    for both networks, trains the target network against the current source
    embeddings and then the source network against the new target embeddings.

    :param verbose: when True, write the losses of each round to the progress bar
    :return: None; the models are updated in place
    """
    started_at = time.time()
    outer_bar = fastprogress.master_bar(range(num_iteration))

    model_source.eval()
    with torch.no_grad():
        source_emb = model_source(X_source.to(device))

    for _ in outer_bar:
        optimizer_target = optim.Adam(model_target.parameters(), lr, weight_decay=eta)
        optimizer_source = optim.Adam(model_source.parameters(), lr, weight_decay=eta)

        target_emb, loss_total, loss_Reg, loss_CE = train_target(
            X_target, y_target, source_emb, optimizer_target,
            model_target, loss_fn_target, device, num_epochs, outer_bar)
        source_emb, loss = train_source(
            X_source, target_emb, optimizer_source,
            model_source, loss_fn_source, num_epochs, device, outer_bar)

        if not verbose:
            continue
        outer_bar.write(f'target loss: {loss_total.detach().cpu().numpy():.2f}, reg loss: {loss_Reg.detach().cpu().numpy():.2f}, ce loss: {loss_CE.detach().cpu().numpy():.3f}, source loss: {loss.detach().cpu().numpy():.3f}')

    elapsed = time.time() - started_at
    time_elapsed = np.round(elapsed, 0).astype(int)
    print(f'Finished training after {time_elapsed} seconds.')


# runs semi-supervised
def fit_semi_supervised(model_target, loss_fn_target,
                        model_source, loss_fn_source, lr, eta,
                        X_norm_tens, X_source_avg_tens, X_test_norm_tens, y_seen_tens, avg_source_tens, device, num_iteration,
                        num_epochs, missing, K_seen=10, K_unseen=10, part=0.5):
    """Refine both networks with pseudo-labelled test samples.

    Phase 1 (``K_unseen`` rounds): test samples predicted as one of the
    ``missing`` classes are added to the training data, closest-to-class
    first; per class the share grows with the round up to ``part`` of the
    samples predicted as that class. A round is skipped unless more than one
    sample is predicted as missing, and a class with fewer than two predicted
    samples contributes nothing.

    Phase 2 (``K_seen`` rounds): the closest test samples overall are added,
    the share growing linearly to all of them in the last round.

    After every selection the number of added samples is printed and both
    networks are retrained with ``run_training`` on the labelled data plus the
    selected samples (labelled with their predicted class).

    :param missing: labels of the masked (unseen) classes
    :param K_seen: number of phase-2 rounds
    :param K_unseen: number of phase-1 rounds
    :param part: maximal fraction of a missing class's predicted samples used
    :return: None; the models are updated in place
    """

    def retrain(selected, labels):
        # Boolean mask over the test set for the chosen pseudo-labelled rows.
        seen = np.zeros(len(labels), dtype=bool)
        seen[selected] = True
        print(sum(seen))
        X_target_new = torch.cat((X_norm_tens, X_test_norm_tens[seen]))
        y_new = torch.cat((y_seen_tens, torch.LongTensor(labels[seen])))
        X_source_new = torch.cat((X_source_avg_tens, avg_source_tens[labels[seen]]))
        run_training(model_target, loss_fn_target,
                     model_source, loss_fn_source, lr, eta,
                     X_target_new, X_source_new, y_new, device, num_iteration,
                     num_epochs, verbose=True)

    for k in range(K_unseen):
        distances, labels = predict_class(model_target, model_source, X_test_norm_tens, avg_source_tens)
        if np.count_nonzero(np.isin(labels, missing)) <= 1:
            continue
        chosen = []
        for cls in missing:
            members = np.flatnonzero(labels == cls)
            if len(members) < 2:
                continue
            n_take = math.ceil(len(members) * ((k + 1) / K_unseen) * part)
            closest = np.argsort(distances[members])[:n_take]
            chosen.append(members[closest])
        # Keep an integer dtype even when nothing was chosen.
        selected = np.concatenate(chosen) if chosen else np.array([], dtype=np.intp)
        retrain(selected, labels)

    for k in range(K_seen):
        distances, labels = predict_class(model_target, model_source, X_test_norm_tens, avg_source_tens)
        n_take = math.ceil(len(distances) * ((k + 1) / K_seen))
        selected = np.argsort(distances)[:n_take]
        retrain(selected, labels)

    return

# Specific loss for target network
class CombLoss(nn.Module):
    """Summed embedding MSE plus ``lam``-weighted summed cross-entropy."""

    def __init__(self, lam):
        super().__init__()
        # Weight of the cross-entropy term in the total loss.
        self.lam = lam

    def forward(self, target_emb, source_emb, y_pred, y_true):
        """Return ``(total, regression, cross_entropy)`` losses.

        ``regression`` is the sum-reduced squared error between the two
        embeddings, ``cross_entropy`` the sum-reduced cross-entropy of
        ``y_pred`` against ``y_true``, and
        ``total = regression + lam * cross_entropy``.
        """
        regression = F.mse_loss(target_emb, source_emb, reduction="sum")
        cross_entropy = F.cross_entropy(y_pred, y_true, reduction="sum")
        total = regression + self.lam * cross_entropy
        return total, regression, cross_entropy

# Builds visual (target) network
class TAEM_target(nn.Module):
    """One-hidden-layer network that returns its embedding and class logits."""

    def __init__(self, dim_in, dim_out, n_classes):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_out)      # input -> embedding
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(dim_out, n_classes)   # embedding -> logits

    def forward(self, x):
        """Return ``(embedding, logits)`` for the batch ``x``."""
        embedding = self.act1(self.linear1(x))
        logits = self.linear2(embedding)
        return embedding, logits


# Builds semantic (source) network
class TAEM_source(nn.Module):
    """Two-layer ReLU network mapping source attributes to the embedding space."""

    def __init__(self, dim_in, dim_out, n_hidden):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, n_hidden)     # input -> hidden
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(n_hidden, dim_out)    # hidden -> embedding
        self.act2 = nn.ReLU()

    def forward(self, x):
        """Return the (non-negative) embedding of the batch ``x``."""
        hidden = self.act1(self.linear1(x))
        return self.act2(self.linear2(hidden))


# creates source means for each class
def get_class_attributes(X_source, y_source, y_target):
    """Build min-max-scaled class means of the source data.

    :param X_source: source feature table supporting ``groupby(...).mean()``
    :param y_source: source labels used as grouping key
    :param y_target: target labels; each is used as a row index into the
        scaled class-mean matrix (assumes labels are the integers 0..K-1 in
        the sorted order produced by ``groupby`` — TODO confirm against the data)
    :return: ``(X_source_avg, avg_source)`` where ``avg_source`` has one scaled
        mean row per source class and ``X_source_avg`` has, for every entry of
        ``y_target``, the row of its class
    """
    avg_source = X_source.groupby(y_source).mean()
    avg_source = MinMaxScaler().fit_transform(avg_source)
    # Positional integer array: avoids label-based lookup on a pandas Series
    # whose index has gaps, and works for any number of features.
    class_rows = np.asarray(y_target).astype(np.intp, copy=False)
    X_source_avg = avg_source[class_rows]
    return X_source_avg, avg_source


# balancing seen classes
def _resample_subset(sampler_cls, X, y, classes, n):
    """Resample only the rows belonging to ``classes`` to ``n`` samples each."""
    mask = np.isin(y, classes)
    strategy = {label: n for label in classes}
    if len(classes) > 1:
        return sampler_cls(sampling_strategy=strategy).fit_resample(X[mask], y[mask])

    # A single class alone is not accepted by the samplers, so a constant
    # pseudo class is appended for the resampling and dropped again afterwards.
    # NOTE: 10000 must not be a real label.
    pseudo_label = 10000
    strategy[pseudo_label] = n
    pseudo_X = np.full((n, X.shape[1]), pseudo_label)
    pseudo_y = np.full(n, pseudo_label)
    X_res, y_res = sampler_cls(sampling_strategy=strategy).fit_resample(
        np.concatenate((X[mask], pseudo_X)),
        np.concatenate((y[mask], pseudo_y)))
    keep = y_res == classes[0]
    return X_res[keep], y_res[keep]


def balance_sampling(X, y, n=100):
    """Resample ``(X, y)`` so that every class has ``n`` samples.

    Classes with at most ``n`` samples are oversampled with SMOTE, classes
    with more than ``n`` samples are randomly undersampled. When both kinds
    are present they are resampled separately and concatenated
    (oversampled classes first).

    Warnings raised during resampling are suppressed only for the duration of
    this call instead of being silenced for the whole process.

    :param X: feature matrix
    :param y: labels
    :param n: desired number of samples per class
    :return: ``(X_resampled, y_resampled)``
    """
    counts = Counter(y)
    over = [label for label, count in counts.items() if count <= n]
    under = [label for label, count in counts.items() if count > n]

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        if not over:
            sampler = RandomUnderSampler(sampling_strategy={label: n for label in under})
            X_under, y_under = sampler.fit_resample(X, y)
            return X_under, y_under
        if not under:
            sampler = SMOTE(sampling_strategy={label: n for label in over})
            X_over, y_over = sampler.fit_resample(X, y)
            return X_over, y_over

        X_over, y_over = _resample_subset(SMOTE, X, y, over, n)
        X_under, y_under = _resample_subset(RandomUnderSampler, X, y, under, n)

    X_combined_sampling = np.concatenate((X_over, X_under))
    y_combined_sampling = np.concatenate((y_over, y_under))
    return X_combined_sampling, y_combined_sampling


def split_masked_cells(X_t, y_t, masked_cells, balance=False, n=500):
    """
    Masks cells for generalized zero-shot learning
    :param X_t: feature matrix of target data
    :param y_t: labels of target data
    :param masked_cells: list of cells to be masked from data
    :param balance: whether to balance seen train data
    :param n: desired number of samples per class
    :return: features of seen classes, features of unseen classes, labels seen classes, labels unseen classes
    """
    # np.isin replaces the deprecated np.in1d; identical result for 1-D labels.
    keep = np.isin(y_t, masked_cells, invert=True)
    X_t_seen = X_t[keep]
    X_t_unseen = X_t[~keep]
    y_seen = y_t[keep]
    y_unseen = y_t[~keep]
    if balance:
        X_t_seen, y_seen = balance_sampling(X_t_seen, y_seen, n)
    return X_t_seen, X_t_unseen, y_seen, y_unseen


def h_score(y_true, y_pred, masked_cells):
    """
    H score (harmonic mean of the balanced accuracies on known and unknown classes)
    :param y_true: true values
    :param y_pred: predictions
    :param masked_cells: list of masked cells
    :return: h-score, acc. of known classes, acc of unknown classes
    """
    known = np.isin(y_true, masked_cells, invert=True)
    # Silence metric warnings only while scoring, not for the whole process.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        acc_known = balanced_accuracy_score(y_true[known], y_pred[known])
        acc_unknown = balanced_accuracy_score(y_true[~known], y_pred[~known])
    total = acc_known + acc_unknown
    if total == 0:
        # Both accuracies are zero: define the harmonic mean as 0 rather than
        # dividing 0 by 0.
        return 0.0, acc_known, acc_unknown
    h = (2 * acc_known * acc_unknown) / total
    return h, acc_known, acc_unknown


In [44]:
# Hyperparameter selection and run of one masked combination, takes very, very long!!!
def run_comb(X_source, y_source, X_train, y_train, X_test, y_test, combination):
    """Tune, train and evaluate the model for one set of masked classes.

    Steps:
      1. Remove the classes in ``combination`` from the training data and
         min-max scale the remaining features (the scaler is reused for
         ``X_test``).
      2. Grid-search ``eta`` (weight decay) and ``lam`` (cross-entropy weight)
         with 5-fold cross-validation; inside every fold each seen class is
         masked in turn and the H score on the validation fold is recorded.
      3. Retrain on all seen data with the best parameters and score on the
         test set.
      4. Run the semi-supervised refinement and score again.

    Uses the module-level ``device`` for model placement and training.

    :param combination: list of class labels to mask
    :return: ``(labels, h, acc_known, acc_unknown, h_ssl, acc_known_ssl,
        acc_unknown_ssl)`` — ``labels`` are the test predictions after the
        semi-supervised step, the first three scores are from before it and
        the ``*_ssl`` scores from after it
    """
    X_seen, _, y_seen, _ = split_masked_cells(X_train, y_train, combination)
    X_source_avg, avg_source = get_class_attributes(X_source, y_source, y_seen)
    scaler = MinMaxScaler(feature_range=(0, 1))
    X_norm = scaler.fit_transform(X_seen)
    X_norm_tens = torch.FloatTensor(X_norm)
    y_seen_tens = torch.LongTensor(y_seen)
    X_test_norm = scaler.transform(X_test)
    X_test_norm_tens = torch.FloatTensor(X_test_norm)
    avg_source_tens = torch.FloatTensor(avg_source)
    X_source_avg_tens = torch.FloatTensor(X_source_avg)

    eta = [0.001, 0.01]
    lam = [0.2, 0.35, 0.5]
    lr = .001
    num_epochs = 5000
    num_iteration = 5
    n_common = 50

    def build(lam_value):
        # Fresh losses and networks. The sizes 50 / 11 / 32 are fixed here;
        # presumably 50 input features and 11 classes — TODO confirm.
        loss_fn_target = CombLoss(lam=lam_value)
        loss_fn_source = nn.MSELoss(reduction="sum")
        model_target = TAEM_target(50, n_common, 11)
        model_target.to(device)
        model_source = TAEM_source(50, n_common, 32)
        model_source.to(device)
        return loss_fn_target, loss_fn_source, model_target, model_source

    current_best = 0
    params = dict()
    for e in eta:
        for lam_value in lam:
            score = []
            print("Testing eta="+str(e)+", lambda="+str(lam_value))
            for train_index, test_index in KFold(shuffle=True, n_splits=5).split(X_seen):
                X_train_val, X_val = X_norm_tens[train_index], X_norm_tens[test_index]
                y_train_val, y_val = y_seen_tens[train_index], y_seen_tens[test_index]
                X_source_avg_tens_val = X_source_avg_tens[train_index]
                for i in set(y_seen):
                    X_seen_val, _, y_seen_val, _ = split_masked_cells(
                        X_train_val, y_train_val, masked_cells=[i])
                    X_source_avg_seen, _, _, _ = split_masked_cells(
                        X_source_avg_tens_val, y_train_val, masked_cells=[i])
                    loss_fn_target, loss_fn_source, model_target, model_source = build(lam_value)

                    run_training(model_target, loss_fn_target,
                                 model_source, loss_fn_source, lr, e,
                                 X_seen_val, X_source_avg_seen, y_seen_val, device, num_iteration,
                                 num_epochs)

                    distances, labels = predict_class(model_target, model_source, X_val, avg_source_tens)
                    h, acc_known, acc_unknown = h_score(y_val.numpy(), labels, [i])
                    score.append(h)
            score = mean(score)
            # Always accept the first grid point so that `params` is populated
            # even when no combination scores above the initial 0.
            if not params or score > current_best:
                current_best = score
                params["eta"] = e
                params["lam"] = lam_value

    print(params)
    loss_fn_target, loss_fn_source, model_target, model_source = build(params["lam"])
    run_training(model_target, loss_fn_target,
                 model_source, loss_fn_source, lr, params["eta"],
                 X_norm_tens, X_source_avg_tens, y_seen_tens, device, num_iteration,
                 num_epochs, verbose=True)

    distances, labels = predict_class(model_target, model_source, X_test_norm_tens, avg_source_tens)
    h, acc_known, acc_unknown = h_score(y_test, labels, combination)

    fit_semi_supervised(model_target, loss_fn_target,
                        model_source, loss_fn_source, lr, params["eta"],
                        X_norm_tens, X_source_avg_tens, X_test_norm_tens, y_seen_tens, avg_source_tens, device, num_iteration,
                        num_epochs, combination, K_seen=10, K_unseen=10, part=0.5)

    distances, labels = predict_class(model_target, model_source, X_test_norm_tens, avg_source_tens)
    h_ssl, acc_known_ssl, acc_unknown_ssl = h_score(y_test, labels, combination)
    return labels, h, acc_known, acc_unknown, h_ssl, acc_known_ssl, acc_unknown_ssl

# Preparations

In [6]:
# Mount Google Drive at /content/drive (Colab only); the data files read
# further down live under /content/drive/MyDrive/data.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [7]:
# Select the compute device once for the notebook; run_comb reads this
# module-level name when placing and training the models.
device = get_device()

cuda available: True ; cudnn available: True ; num devices: 1
Using device Tesla T4


Play around...

In [None]:
# Load the source data (file names say mouse) and target train/test data
# (file names say human) with their labels.
X_source = pd.read_csv("/content/drive/MyDrive/data/brain_mouse_red_scetm.csv", index_col=0)
y_source = pd.read_csv("/content/drive/MyDrive/data/brain_mouse_red_label.csv", index_col=0)["label"]
X_train = pd.read_csv("/content/drive/MyDrive/data/brain_human_red_train_scetm.csv", index_col=0)
y_train = pd.read_csv("/content/drive/MyDrive/data/brain_human_red_train_label.csv", index_col=0)["label"]
X_test = pd.read_csv("/content/drive/MyDrive/data/brain_human_red_test_scetm.csv", index_col=0)
y_test = pd.read_csv("/content/drive/MyDrive/data/brain_human_red_test_label.csv", index_col=0)["label"]
# Balance the training classes to 300 samples each before masking.
X_train, y_train = balance_sampling(X_train, y_train, 300)
# Classes 6 and 9 are masked (treated as unseen) in this run.
combination = [6, 9]
result = run_comb(X_source, y_source, X_train, y_train, X_test, y_test, combination)
# Display the outcome; the original referenced the undefined name `results`.
result