In [36]:
import os, sys
sys.path.append("..")
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import itertools
%matplotlib inline 

import torch
import torch.nn as nn
import torchvision
import gc
import random
import torchvision.datasets as datasets
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset, TensorDataset
from src import distributions
from src.tools import unfreeze, freeze
from src.tools import weights_init_D
from src.tools import fig2data, fig2img
from src.guided_samplers import Sampler, PairedSubsetSampler, SubsetGuidedDataset, get_indicies_subset
from src.gaussian_utils import generate_data, plot_gaussian, build_dataloader
from tqdm import tqdm_notebook as tqdm
import ot

In [None]:
# Read the three pre-processed tables (expression matrix, class labels, domain
# labels) from the Splatter directory and bundle them as NumPy arrays.
dataset_path = "../RNAseq/Splatter/processed_data/"
print("dataset_path: ", dataset_path)

normcounts, labels, domain_labels = (
    pd.read_csv(dataset_path + csv_name)
    for csv_name in ('combine_expression.csv', 'combine_labels.csv', 'domain_labels.csv')
)

# The expression table is transposed before use (presumably so that each row is
# one sample, matching the label vectors -- TODO confirm against the CSV layout);
# only the first column of each label table is kept.
data_set = dict(
    features=normcounts.T.values,
    labels=labels.iloc[:, 0].values,
    accessions=domain_labels.iloc[:, 0].values,
)

## Init Parameters

In [22]:
# --- Experiment configuration -------------------------------------------------
NUM_LABELED = 10   # labelled samples per class (guided dataset and discrete-OT label mask)
T_ITERS = 10       # updates of the map T per single update of D
BATCH_SIZE = 32    # used by the DataLoaders and by the samplers' .sample() calls
# NOTE(review): the first letter of the next name appears to be a Cyrillic "С"
# homoglyph rather than ASCII "C". The original spelling is kept because other
# cells reference it; an ASCII alias is provided so new code can avoid it.
С_SIZE = 2         # subset size passed to PairedSubsetSampler for the T updates
C_SIZE = С_SIZE    # ASCII-only alias of the constant above
NUM_MODES = 4      # not referenced in the visible cells

D_LR, T_LR = 1e-4, 1e-4   # Adam learning rates for D and T
PLOT_INTERVAL = 100       # outer steps between progress/accuracy reports
CPKT_INTERVAL = 1000      # not referenced in the visible cells
MAX_STEPS = 1001 #increase if you need better results
SEED = 0x000001
# Seed every RNG the notebook imports. Python's `random` was previously left
# unseeded, so anything drawing from it was not reproducible across runs.
random.seed(SEED); torch.manual_seed(SEED); np.random.seed(SEED)

def test_accuracy(classifier, loader, T=None):
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in loader:
            inputs = inputs.reshape(-1, 657)
            if T:
                inputs = T(inputs.cuda())
            outputs = classifier(inputs.cuda())
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum().item()

        accuracy = 100 * correct / total
        print(f"Accuracy of the model on the test inputs: {accuracy}%")
    return accuracy

## Build ClassGuided Segerstolpe Datasets

In [8]:
kwargs = {'num_workers': 0, 'pin_memory': True}

source_name = "TM_baron_mouse_for_segerstolpe"
target_name = "segerstolpe_human"


def _subset(data, index):
    """Return a new features/labels/accessions dict restricted to ``index``.

    ``index`` may be an integer index array or a boolean mask over the rows.
    """
    return {key: data[key][index] for key in ('features', 'labels', 'accessions')}


# Split the combined table into the source and target domains by accession name.
domain_to_indices = np.where(data_set['accessions'] == source_name)[0]
source_set = _subset(data_set, domain_to_indices)

domain_to_indices = np.where(data_set['accessions'] == target_name)[0]
target_set = _subset(data_set, domain_to_indices)

common_labels = np.intersect1d(np.unique(source_set['labels']), np.unique(target_set['labels']))
print('common_labels:', common_labels)

# Keep only the classes present in BOTH domains. The target is filtered as well:
# label_mapping below only knows the common classes, so a target-only label
# would otherwise raise a KeyError during remapping.
source_set_filtered_indices = np.isin(source_set['labels'], common_labels)
source_set_filtered = _subset(source_set, source_set_filtered_indices)
target_set = _subset(target_set, np.isin(target_set['labels'], common_labels))

# Re-index the surviving class ids to the contiguous range 0..K-1.
label_mapping = {label: index for index, label in enumerate(np.unique(source_set_filtered['labels']))}
print('label_mapping:', label_mapping)

source_set_filtered['labels'] = np.array([label_mapping[label] for label in source_set_filtered['labels']])
target_set['labels'] = np.array([label_mapping[label] for label in target_set['labels']])

print('source labels:', np.unique(source_set_filtered['labels']), ' target labels:', np.unique(target_set['labels']))

# Raw target-domain rows with their ORIGINAL (unmapped, unfiltered) labels;
# domain_to_indices still holds the target indices at this point.
# Not referenced in the visible cells.
test_set_eval = _subset(data_set, domain_to_indices)

print(source_set_filtered['features'].shape)
print(target_set['features'].shape)

source_data = torch.utils.data.TensorDataset(torch.FloatTensor(source_set_filtered['features']), torch.LongTensor(source_set_filtered['labels']))
source_loader = torch.utils.data.DataLoader(source_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, **kwargs)

target_data = torch.utils.data.TensorDataset(torch.FloatTensor(target_set['features']), torch.LongTensor(target_set['labels']))
# Shuffled loader that drops the last partial batch, plus an ordered loader
# over the same target data that keeps every sample.
target_loader = torch.utils.data.DataLoader(target_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, **kwargs)
target_test_loader = torch.utils.data.DataLoader(target_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, **kwargs)

common_labels: [0 1 3 4 5 6 7 8]
label_mapping: {0: 0, 1: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}
source labels: [0 1 2 3 4 5 6 7]  target labels: [0 1 2 3 4 5 6 7]
(3329, 657)
(2108, 657)


In [9]:
# Class ids 0..K-1, where K is the number of distinct labels in the filtered
# source set, together with an identity id -> id mapping.
source_classes = torch.arange(len(np.unique(source_set_filtered['labels'])))
source_labels = {class_id: class_id for class_id in source_classes.tolist()}

# The target side reuses the very same objects (no separate relabelling).
target_classes = source_classes
new_target_labels = source_labels

In [10]:
# Per-class subsets of the source data, re-packed as a TensorDataset.
subset_samples, labels, source_class_indicies = get_indicies_subset(
    source_data,
    new_labels=source_labels,
    classes=len(source_classes),
    subset_classes=source_classes,
)
source_train = TensorDataset(torch.stack(subset_samples), torch.LongTensor(labels))

# Same extraction for the target data.
target_subset_samples, target_labels, target_class_indicies = get_indicies_subset(
    target_data,
    new_labels=new_target_labels,
    classes=len(target_classes),
    subset_classes=target_classes,
)
target_train = TensorDataset(torch.stack(target_subset_samples), torch.LongTensor(target_labels))

# Two guided datasets over the same tensors: one limited to NUM_LABELED,
# the other built with num_labeled='all'.
train_set = SubsetGuidedDataset(
    source_train, target_train,
    num_labeled=NUM_LABELED,
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)

full_set = SubsetGuidedDataset(
    source_train, target_train,
    num_labeled='all',
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)

# Sampler for the T updates draws subsets of size С_SIZE from train_set;
# sampler for the D updates draws subsets of size 1 from full_set.
T_XY_sampler = PairedSubsetSampler(train_set, subsetsize=С_SIZE)
D_XY_sampler = PairedSubsetSampler(full_set, subsetsize=1)

## Init Networks

In [23]:
class Classifier(nn.Module):
    """Single linear layer mapping ``input_size`` features to ``output_size`` scores."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return the raw output of the linear layer (no activation applied)."""
        return self.fc1(x)

class FeedForward(nn.Module):
    """MLP with three ReLU-activated hidden layers and a linear output layer.

    All hidden layers share the same width ``hidden_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in fc1 -> fc4 order, which fixes both the
        # state_dict keys and the order in which random init is drawn.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc4 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Run ``x`` through fc1..fc3 (each followed by ReLU), then through fc4."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

# T maps 657 features to 657 features; D maps 657 features to a single scalar.
# T is built before D, so it consumes the random initialisation first.
T = FeedForward(input_size=657, hidden_size=512, output_size=657).cuda()
D = FeedForward(input_size=657, hidden_size=1024, output_size=1).cuda()

# One Adam optimizer per network, each with its own learning rate and the same
# (very small) weight decay.
T_opt, D_opt = (
    torch.optim.Adam(net.parameters(), lr=net_lr, weight_decay=1e-10)
    for net, net_lr in ((T, T_LR), (D, D_LR))
)

## Train Oracle Classifier

In [24]:
# Linear classifier over 657 input features and 8 output classes, placed on
# the GPU, with its Adam optimizer and a cross-entropy training criterion.
criterion = nn.CrossEntropyLoss()
classifier = Classifier(input_size=657, output_size=8).cuda()
classifier_opt = torch.optim.Adam(params=classifier.parameters(), lr=1e-4)

In [25]:
# Train the classifier on the target-domain batches and report its accuracy on
# target_test_loader at regular intervals.
NUM_EPOCHS = 10   # full passes over target_loader
LOG_EVERY = 10    # batches between two progress reports

total_step = len(target_loader)
for epoch in range(NUM_EPOCHS):
    for i, (inputs, labels) in enumerate(target_loader):
        inputs = inputs.reshape(-1, 657)
        outputs = classifier(inputs.cuda())
        loss = criterion(outputs, labels.cuda())
        classifier_opt.zero_grad()
        loss.backward()
        classifier_opt.step()

        if (i+1) % LOG_EVERY == 0:
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}")
            # Reuse the shared helper rather than an inline copy of the same
            # evaluation loop; it prints the same accuracy line and does not
            # rebind the training loop's `inputs` / `labels`.
            test_accuracy(classifier, target_test_loader)

Epoch [1/10], Step [10/65], Loss: 1.8696
Accuracy of the model on the test inputs: 21.537001897533205%
Epoch [1/10], Step [20/65], Loss: 1.8387
Accuracy of the model on the test inputs: 34.58254269449716%
Epoch [1/10], Step [30/65], Loss: 1.7604
Accuracy of the model on the test inputs: 47.01138519924098%
Epoch [1/10], Step [40/65], Loss: 1.5038
Accuracy of the model on the test inputs: 58.96584440227704%
Epoch [1/10], Step [50/65], Loss: 1.4988
Accuracy of the model on the test inputs: 64.04174573055029%
Epoch [1/10], Step [60/65], Loss: 1.0741
Accuracy of the model on the test inputs: 66.93548387096774%
Epoch [2/10], Step [10/65], Loss: 1.1407
Accuracy of the model on the test inputs: 69.87666034155598%
Epoch [2/10], Step [20/65], Loss: 1.1629
Accuracy of the model on the test inputs: 72.34345351043643%
Epoch [2/10], Step [30/65], Loss: 0.8445
Accuracy of the model on the test inputs: 74.47817836812145%
Epoch [2/10], Step [40/65], Loss: 0.9549
Accuracy of the model on the test inputs

## Accuracy on the source

In [27]:
# Baseline: score the classifier on the source batches as-is, i.e. without
# passing a transport map T.
# NOTE(review): source_loader is built with shuffle=True and drop_last=True, so
# the set of evaluated samples can change between calls, and with it the
# reported number.
test_accuracy(classifier, source_loader)

Accuracy of the model on the test inputs: 63.01081730769231%


63.01081730769231

## Start Training

In [26]:
# Alternating optimisation: every outer step runs T_ITERS updates of the map T
# (with D frozen), followed by a single update of D (with T frozen).
for step in tqdm(range(MAX_STEPS)):
    unfreeze(T); freeze(D)
    for t_iter in range(T_ITERS): 
        T_opt.zero_grad()
        # NOTE(review): the flatten/reshape below presumes X is shaped
        # (groups, С_SIZE, 657) and that Y matches T_X -- confirm against
        # PairedSubsetSampler.sample.
        X, Y = T_XY_sampler.sample(BATCH_SIZE)
        # Merge the first two dims, apply T, then regroup the flat outputs
        # into (groups, С_SIZE, 657).
        T_X = T(X.flatten(start_dim=0, end_dim=1)).permute(1,0).reshape(657, -1, С_SIZE).permute(1,2,0)    
        # Half the mean pairwise distance inside each group. The mean also
        # counts the zero self-distances; the С_SIZE / (С_SIZE - 1) factor
        # rescales it to a mean over distinct pairs only.
        T_var = .5 * torch.cdist(T_X, T_X).mean() * С_SIZE / (С_SIZE -1)
        # Mean L2 norm of (Y - T_X) taken along the last (feature) dimension.
        cost = (Y-T_X).norm(dim=2).mean()
        # T minimises: cost - T_var - mean of D over the mapped samples.
        T_loss = cost - T_var - D(T_X.flatten(start_dim=0, end_dim=1)).mean()
        T_loss.backward(); T_opt.step()
        
    # Drop the tensors of the T phase and release cached GPU memory.
    del T_X, X, Y, T_var; gc.collect(); torch.cuda.empty_cache() 

    freeze(T); unfreeze(D)
    # Inputs to be mapped come from T_XY_sampler; the reference samples Y come
    # from D_XY_sampler (built with subsetsize=1).
    X, _ = T_XY_sampler.sample(BATCH_SIZE)
    _, Y = D_XY_sampler.sample(BATCH_SIZE)
    # T is only evaluated here, so no gradient graph is needed for it.
    with torch.no_grad():
        T_X = T(X.flatten(start_dim=0, end_dim=1)) 
    # Removes every size-1 dimension (presumably the subset dim of size 1 --
    # TODO confirm the sampler's output shape).
    Y = torch.squeeze(Y)
    D_opt.zero_grad()
    # D minimises: mean D(T(x)) - mean D(y).
    D_loss = D(T_X).mean() - D(Y).mean()
    D_loss.backward(); D_opt.step();
    
    if step % PLOT_INTERVAL == 0:
        # T_loss is the value from the last T iteration of this outer step.
        print('Loss:', T_loss.item())
        print('step:', step)
        # Accuracy of the classifier on source batches after mapping them with T.
        _ = test_accuracy(classifier, source_loader, T)
    del D_loss, Y, X, T_X; gc.collect(); torch.cuda.empty_cache()

  0%|          | 0/1001 [00:00<?, ?it/s]

Loss: 33.0207405090332
step: 0
Accuracy of the model on the test inputs: 7.782451923076923%
Loss: 11.21495246887207
step: 100
Accuracy of the model on the test inputs: 83.50360576923077%
Loss: 11.319890022277832
step: 200
Accuracy of the model on the test inputs: 90.05408653846153%
Loss: 12.746994972229004
step: 300
Accuracy of the model on the test inputs: 91.64663461538461%
Loss: 11.459309577941895
step: 400
Accuracy of the model on the test inputs: 88.64182692307692%
Loss: 11.525487899780273
step: 500
Accuracy of the model on the test inputs: 92.51802884615384%
Loss: 12.656315803527832
step: 600
Accuracy of the model on the test inputs: 90.80528846153847%
Loss: 12.93928337097168
step: 700
Accuracy of the model on the test inputs: 90.53485576923077%
Loss: 13.689186096191406
step: 800
Accuracy of the model on the test inputs: 91.22596153846153%
Loss: 11.785080909729004
step: 900
Accuracy of the model on the test inputs: 90.625%
Loss: 13.19101333618164
step: 1000
Accuracy of the model 

## Find the Best Discrete Solver

In [43]:
def _dataset_to_numpy(dataset):
    """Return ``(features, labels)`` as NumPy COPIES of a two-tensor TensorDataset.

    Features are flattened to one row per sample. Copies are returned on
    purpose: the labels are edited in place below, and ``Tensor.numpy()``
    alone would share memory with the dataset.
    """
    features, labels = dataset.tensors
    return features.reshape(len(dataset), -1).numpy().copy(), labels.numpy().copy()


# Source / target samples and labels for the discrete OT solvers (each array is
# built exactly once, without a per-sample Python loop).
Xs, ys = _dataset_to_numpy(source_loader.dataset)
Xt, yt = _dataset_to_numpy(target_loader.dataset)

unique_classes = np.unique(yt)
mask = np.full(yt.size, False)

# Mark up to NUM_LABELED randomly chosen samples per class as labelled
for c in unique_classes:
    class_indices = np.where(yt == c)[0]
    # A class may hold fewer than NUM_LABELED samples; sampling more than are
    # available without replacement would raise a ValueError.
    n_selected = min(NUM_LABELED, class_indices.size)
    selected_indices = np.random.choice(class_indices, n_selected, replace=False)
    mask[selected_indices] = True

# Set the non-selected samples to -1
yt[~mask] = -1

In [44]:
print('Model Accuracy before adaptation:')
test_accuracy(classifier, source_loader)

# One record per (reg_e, reg_cl, solver), so the best configuration can be
# picked after the sweep instead of being read off the printed log.
ot_results = []

# NOTE(review): 'EMD' and 'MapOT' take neither regulariser and 'Sinkhorn' only
# takes reg_e, so they are refitted with identical settings for many grid points.
for reg_e in [0.1, 1, 2, 5, 10, 100]:
    for reg in [0.1, 1, 2, 5, 10, 100]:
        print('Regularizetion reg_cl:',reg)
        print('Regularizetion reg_e:', reg_e)
        print('--------------------------')
        transports = {'EMD':ot.da.EMDTransport(), 
                      'Sinkhorn':ot.da.SinkhornTransport(reg_e=reg_e, verbose=False), 
                      'SinkhornLpl1':ot.da.SinkhornLpl1Transport(reg_e=reg_e, reg_cl=reg, verbose=False),
                      'SinkhornL1l2':ot.da.SinkhornL1l2Transport(reg_e=reg_e, reg_cl=reg, verbose=False),
                      'MapOT':ot.da.MappingTransport(kernel="linear", mu=1, eta=1e-0, bias=True, max_iter=20, verbose=False)
                     }

        accs = []  # accuracies of the current grid point, in `transports` order
        for ot_name, ot_mapping in transports.items():
            # yt carries -1 for the target samples treated as unlabelled.
            ot_mapping.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)

            # Xs is the same array that was passed to fit(), so this maps the
            # fitted source samples into the target domain.
            X_test_mapped_ = ot_mapping.transform(Xs=Xs)
            mapped_dataset = TensorDataset(torch.FloatTensor(X_test_mapped_), torch.LongTensor(ys))
            X_test_mapped = DataLoader(mapped_dataset, batch_size=100)
            print(ot_name)
            accuracy = test_accuracy(classifier, X_test_mapped)
            accs.append(accuracy)
            ot_results.append({'solver': ot_name, 'reg_e': reg_e, 'reg_cl': reg, 'accuracy': accuracy})
            print('--------------------------')

# Summarise the sweep: the configuration with the highest accuracy.
best_ot_result = max(ot_results, key=lambda record: record['accuracy'])
print('Best discrete solver:', best_ot_result)

Model Accuracy before adaptation:
Accuracy of the model on the test inputs: 62.98076923076923%
Regularizetion reg_cl: 2
Regularizetion reg_e: 2
--------------------------
EMD
Accuracy of the model on the test inputs: 50.435566236106936%
--------------------------
Sinkhorn
Accuracy of the model on the test inputs: 6.848903574647041%
--------------------------
SinkhornLpl1
Accuracy of the model on the test inputs: 4.025232802643436%
--------------------------
SinkhornL1l2
Accuracy of the model on the test inputs: 17.482727545809553%
--------------------------
MapOT
Accuracy of the model on the test inputs: 42.234905376990085%
--------------------------
Regularizetion reg_cl: 5
Regularizetion reg_e: 2
--------------------------
EMD
Accuracy of the model on the test inputs: 50.435566236106936%
--------------------------
Sinkhorn
Accuracy of the model on the test inputs: 6.848903574647041%
--------------------------
SinkhornLpl1
Accuracy of the model on the test inputs: 3.4244517873235205%
