In [None]:
import numpy as np
from os.path import join
import os
from tqdm import tqdm
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, normalized_mutual_info_score as nmi_score, completeness_score
import umap
import sys
import importlib
import torch
import anndata as ad
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.cm import rainbow
%matplotlib inline

# Init

In [None]:
# Project root. Defaults to the original author's checkout, but can be
# redirected without editing the notebook by exporting NEURIPS_PROJ_DIR.
PROJ_DIR = os.environ.get('NEURIPS_PROJ_DIR', '/home/gcgreen2/neurips_comp')
# Data lives in a `data` sub-directory of the project root.
DATA_DIR = join(PROJ_DIR, 'data')
# Outputs (config.py, cached embeddings, clustering) are read from / written to
# the directory the notebook is launched from.
OUT_DIR = os.getcwd()

In [None]:
# Make the project's `cae` directory importable. The membership guard keeps
# sys.path free of duplicate entries when this cell is re-run.
cae_dir = join(PROJ_DIR, 'cae')
if cae_dir not in sys.path:
    sys.path.append(cae_dir)
from scripts import models, utils, task3_metrics as t3

In [None]:
# importlib.reload(models)

In [None]:
# Run config.py from the output directory inside the notebook namespace so the
# names it defines become notebook globals (presumably `par` and `files`, which
# later cells read - confirm against config.py).
config_path = join(OUT_DIR, 'config.py')
with open(config_path, 'r') as fh:
    lines = fh.read()
# NOTE(review): this executes arbitrary code from config.py; only run the
# notebook against a trusted output directory.
# `exec` is the right builtin for statement-level code, and compiling with the
# real path makes tracebacks point at config.py.
exec(compile(lines, config_path, 'exec'))

In [None]:
# Display `par` (presumably defined by the config.py executed above) so the
# settings used for this run are visible in the notebook output.
par

## Load data & model

In [None]:
# Read one AnnData object per modality from the .h5ad paths listed in `par`.
adata_mod1, adata_mod2 = (
    ad.read_h5ad(par[path_key]) for path_key in ('data_mod1', 'data_mod2')
)

In [None]:
# Load the stored PCA arrays for both modalities from the paths in `files`.
mod1_pca, mod2_pca = (
    np.load(files[path_key]) for path_key in ('mod1_pca', 'mod2_pca')
)

In [None]:
# Attach each modality's PCA array to its AnnData object under obsm['X_pca'].
for adata, pca in ((adata_mod1, mod1_pca), (adata_mod2, mod2_pca)):
    adata.obsm['X_pca'] = pca

In [None]:
# Tag every observation with the modality it came from; the latent-space plots
# use this column for colouring.
for adata, modality in ((adata_mod1, 'mod1'), (adata_mod2, 'mod2')):
    adata.obs['mod'] = modality

In [None]:
# Build the model from the constructor expression returned by
# utils.model_str(par).
# NOTE(review): eval() executes that string as code; it is only safe as long as
# `par` comes from a trusted config.
model = eval(utils.model_str(par))
# The notebook only runs inference, so switch to evaluation mode (disables
# dropout and batch-norm statistic updates; a no-op if the model has neither).
model.eval()
# Map the weights onto the CPU: the outputs are converted with `.numpy()`
# later in the notebook, which requires CPU tensors, and this also lets a
# checkpoint saved on a GPU machine load on a CPU-only one.
model.load_state_dict(torch.load(files['model'], map_location='cpu'))

## Get latent space representation and perform clustering

In [None]:
# Convert each modality's PCA array to a float tensor for the model.
X, Y = (torch.FloatTensor(data.obsm['X_pca']) for data in (adata_mod1, adata_mod2))

# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    outputs = model(X, Y)

# The model returns eight tensors; the first two are discarded here
# (presumably the reconstructions - confirm against the model definition).
# NOTE(review): if Z is sampled inside the model, Z_mod1/Z_mod2 differ between
# runs unless the torch RNG is seeded - confirm.
_, _, Mu_mod1, Logvar_mod1, Mu_mod2, Logvar_mod2, Z_mod1, Z_mod2 = (
    t.detach().numpy() for t in outputs
)

In [None]:
# Store the model outputs on both AnnData objects, one obsm key per quantity.
# Note that 'Z_var' holds the log-variance arrays as returned by the model, and
# 'Z_2d' is simply the first two latent dimensions.
latent_by_key = {
    'Z': (Z_mod1, Z_mod2),
    'Z_mu': (Mu_mod1, Mu_mod2),
    'Z_var': (Logvar_mod1, Logvar_mod2),
    'Z_2d': (Z_mod1[:, :2], Z_mod2[:, :2]),
}
for obsm_key, (value_mod1, value_mod2) in latent_by_key.items():
    adata_mod1.obsm[obsm_key] = value_mod1
    adata_mod2.obsm[obsm_key] = value_mod2

# Element-wise mean of the two modalities' latent means; kept on mod1 only.
adata_mod1.obsm['Z_mu_avg'] = 0.5 * (Mu_mod1 + Mu_mod2)

In [None]:
# Joint UMAP embedding of both modalities' latent codes, cached on disk.
overwrite = False  # set to True to force recomputation
umap_path = join(OUT_DIR, 'umap.npy')
n_cells_total = len(Z_mod1) + len(Z_mod2)

Z_umap = None
if os.path.exists(umap_path) and not overwrite:
    # The cache is a plain numeric array written by np.save below, so pickle
    # support is not needed (and is best left disabled when loading files).
    Z_umap = np.load(umap_path)
    if Z_umap.shape[0] != n_cells_total:
        # Stale cache from a different dataset: discard it and recompute.
        Z_umap = None

if Z_umap is None:
    # mod1 rows first, then mod2 rows - the split in the next cell relies on this order.
    Z = np.concatenate((Z_mod1, Z_mod2), axis=0)
    # A fixed seed makes the embedding reproducible across runs (at the cost of
    # UMAP's multi-threaded optimisation).
    Z_umap = umap.UMAP(random_state=0).fit_transform(Z)
    np.save(umap_path, Z_umap)

In [None]:
# Z_umap stacks the mod1 rows first and the mod2 rows after them, so split at
# the number of mod1 cells instead of assuming both modalities are equally large.
n_mod1 = adata_mod1.n_obs
adata_mod1.obsm['Z_umap'] = Z_umap[:n_mod1]
adata_mod2.obsm['Z_umap'] = Z_umap[n_mod1:]

In [None]:
# tsne_path = join(OUT_DIR, 'tsne.npy')
# if os.path.exists(tsne_path):
#     Z_tsne = np.load(tsne_path, allow_pickle=True)
# else:
#     Z = np.concatenate((Z_mod1,Z_mod2), axis=0)
#     Z_tsne = TSNE(2).fit_transform(Z)
#     np.save(tsne_path, Z_tsne)

In [None]:
# adata_mod1.obsm['Z_tsne'] = Z_tsne[:len(Z_tsne)//2]
# adata_mod2.obsm['Z_tsne'] = Z_tsne[len(Z_tsne)//2:]

## Add test and train info

In [None]:
# idx_train = np.loadtxt(files['idx_train'], dtype=int)
# idx_test = np.loadtxt(files['idx_test'], dtype=int)

In [None]:
# is_train = np.full(len(adata_mod1), True)
# is_train[idx_test] = False

In [None]:
# adata_mod1.obs['is_train'] = is_train
# adata_mod2.obs['is_train'] = is_train

In [None]:
# mod1_train, mod2_train, mod1_test, mod2_test = \
#     [adata[idx] for idx in [idx_train, idx_test] for adata in [adata_mod1,adata_mod2]]

# Latent space plots

In [None]:
# Overlaid histograms of the flattened latent quantities for both modalities.
# The log-variances are converted to standard deviations (exp(0.5 * logvar))
# before plotting, so the legend labels them as Std_* rather than Logvar_*.
names = ['Mu_mod1', 'Mu_mod2', 'Std_mod1', 'Std_mod2', 'Z_mod1', 'Z_mod2']
xs = [Mu_mod1, Mu_mod2, np.exp(0.5 * Logvar_mod1), np.exp(0.5 * Logvar_mod2), Z_mod1, Z_mod2]

fig, ax = plt.subplots(figsize=(10, 7))
for x, n in zip(xs, names):
    ax.hist(x.flatten(), alpha=0.5, label=n)
ax.set(title='Latent value distributions', xlabel='value', ylabel='count')
ax.legend();

In [None]:
def plot_latent(mod1, mod2, col, train_test='all', clustering='Z_umap'):
    """Scatter-plot a 2-D embedding of both modalities, coloured by an obs column.

    Parameters
    ----------
    mod1, mod2
        Objects exposing ``.obs[col]`` and ``.obsm[clustering]`` (the AnnData
        objects of the two modalities in this notebook).
    col : str
        Name of the ``.obs`` column whose values define the colour groups.
    train_test : str
        Currently ignored; kept so existing calls keep working. The train/test
        filtering it used to control is disabled in this notebook.
    clustering : str
        Key in ``.obsm`` holding the embedding; its first two columns are
        plotted as x and y.
    """
    # Union of the labels seen in either modality, sorted so that the
    # label -> colour mapping is identical on every run.
    labels = sorted(set(np.unique(mod1.obs[col])) | set(np.unique(mod2.obs[col])))
    colors = rainbow(np.linspace(0, 1, len(labels), endpoint=False), alpha=0.2)

    # Look the embeddings up once and index them with boolean masks instead of
    # building a subset view of the whole object for every label.
    emb_mod1 = mod1.obsm[clustering]
    emb_mod2 = mod2.obsm[clustering]

    plt.figure(figsize=(9, 7))
    plt.title(f'{col}, latent space ({clustering})')
    for color, label in zip(colors, labels):
        mask_mod1 = (mod1.obs[col] == label).to_numpy()
        mask_mod2 = (mod2.obs[col] == label).to_numpy()
        points = np.concatenate((emb_mod1[mask_mod1], emb_mod2[mask_mod2]), axis=0)
        plt.plot(points[:, 0], points[:, 1], 'o', color=color, label=label)

    leg = plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    # Make the legend markers opaque even though the points are translucent.
    # `legend_handles` is the current matplotlib name; fall back to the older
    # `legendHandles` attribute when it is missing.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for lh in handles:
        # Older matplotlib kept the marker in a separate private artist.
        getattr(lh, '_legmarker', lh).set_alpha(1)

### Modality

In [None]:
# Colour the embedding by modality ('mod' is set to 'mod1'/'mod2' earlier in this notebook).
plot_latent(adata_mod1, adata_mod2, 'mod')

### Cell type

In [None]:
# Colour the embedding by the 'cell_type' obs column (the same column serves as
# ground truth for the clustering metrics further down).
plot_latent(adata_mod1, adata_mod2, 'cell_type')

### Batch

In [None]:
# Colour the embedding by the 'batch' obs column (presumably the experimental
# batch of each cell - confirm against the dataset).
plot_latent(adata_mod1, adata_mod2, 'batch')

# Clustering metrics

In [None]:
# Louvain clustering on the modality-averaged latent means.
clust_path = join(OUT_DIR, 'clustering.npy')
# `use_rep` alone selects the representation. `n_pcs` is left at its default
# because, per the scanpy docs, n_pcs=0 is only meaningful when use_rep is None.
# NOTE(review): newer scanpy releases appear to slice the chosen representation
# to its first n_pcs columns, which n_pcs=0 would reduce to nothing - confirm.
sc.pp.neighbors(adata_mod1, use_rep='Z_mu_avg')
sc.tl.louvain(adata_mod1)
# Save the cluster labels as a plain string array so the .npy file holds no
# pickled objects and can be read back with a default np.load.
np.save(clust_path, adata_mod1.obs['louvain'].to_numpy().astype(str))

In [None]:
# Ground-truth labels (cell types) and predicted labels (Louvain clusters).
gt, pred = (adata_mod1.obs[obs_key] for obs_key in ('cell_type', 'louvain'))

In [None]:
# Silhouette of the ground-truth labels in the averaged latent space, plus
# NMI and completeness of the Louvain clusters against the ground truth.
latent_avg = adata_mod1.obsm['Z_mu_avg']
silhouette = silhouette_score(latent_avg, gt)
nmi, completeness = nmi_score(gt, pred), completeness_score(gt, pred)
print(f'silhouette coef: {silhouette},  nmi: {nmi},  completeness: {completeness}')

# Task 3 Metrics

In [None]:
# t3.evaluation_task3()