In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.nn import Parameter
import numpy as np
import matplotlib.pyplot as plt

from pyro.infer import MCMC, NUTS, Predictive
import pyro
from pyro import poutine
from pyro import distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta, AutoNormal
pyro.enable_validation(True)
import tqdm

import scipy
from scipy import ndimage
from scipy.sparse import load_npz
from scipy.sparse import coo_matrix
#import skimage.color

from sklearn.neighbors import KDTree

from collections import Counter

import scanpy as sc

import seaborn as sns

In [None]:
def subsample_clusters(obj, group_label = 'celltype_subset', subsample_factor = 0.2, min_cells = 250):
    """Randomly subsample row positions separately within each group.

    For every distinct value of ``obj.obs[group_label]`` a subset of that
    group's row positions is drawn without replacement with
    ``np.random.choice`` (global NumPy RNG state, so seed it beforehand for
    repeatable results).

    For a group of ``n`` rows the number kept is
    ``max(int(n * subsample_factor), min(n, min_cells))``, capped at ``n``:
    groups with at most ``min_cells`` rows are kept whole, larger groups keep
    at least ``min_cells`` rows.

    Parameters
    ----------
    obj : object with an ``obs`` table
        ``obj.obs[group_label]`` must support ``.unique()`` and ``==``.
    group_label : str
        Column of ``obj.obs`` holding the group assignment.
    subsample_factor : float
        Fraction of each group to keep. Values above 1 keep the whole group.
    min_cells : int
        Lower bound on the number of rows kept per group (default 250).

    Returns
    -------
    numpy.ndarray
        Integer row positions, concatenated group by group in the order the
        groups are returned by ``unique()``.
    """
    group_column = obj.obs[group_label]
    chosen = []
    for group in group_column.unique():
        group_indxs = np.where(group_column == group)[0]
        n_group = len(group_indxs)
        n_keep = max(int(n_group * subsample_factor), min(n_group, min_cells))
        # Sampling without replacement cannot return more rows than exist.
        n_keep = min(n_keep, n_group)
        chosen.extend(np.random.choice(group_indxs, size=n_keep, replace=False))

    # Explicit integer dtype so an empty selection is still a valid index array.
    return np.array(chosen, dtype=np.intp)

def data_to_zero_truncated_cdf(x):
    """Map a 1-D vector to the empirical CDF of its positive entries.

    Entries that are ``<= 0`` map to 0. The ``m`` positive entries are ranked
    in ascending order and map to ``1/m, 2/m, ..., 1``, so the largest value
    always maps to 1. Tied values receive distinct consecutive ranks, in the
    order ``np.argsort`` returns them.

    Vectors without any zero and vectors without any positive entry are
    handled: the former are ranked over all entries, the latter map to all
    zeros.

    Parameters
    ----------
    x : array-like, 1-D
        Input values.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``x``. Floating-point inputs keep their dtype;
        any other dtype is returned as float64 so the fractional ranks are
        not truncated.
    """
    x = np.asarray(x)
    indx_sorted = np.argsort(x)
    x_sorted = x[indx_sorted]

    n_total = len(x_sorted)
    # Number of leading sorted entries that are <= 0 (0 when there are none).
    n_non_positive = int(np.searchsorted(x_sorted, 0, side='right'))
    n_positive = n_total - n_non_positive

    p = np.zeros(n_total, dtype=np.float64)
    if n_positive > 0:
        p[n_non_positive:] = np.arange(1, n_positive + 1) / n_positive

    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    cdfs = np.zeros(x.shape, dtype=out_dtype)
    cdfs[indx_sorted] = p
    return cdfs

In [None]:
data = sc.read_h5ad('../data_atlas/atals_processed.h5ad')
# 'raw_data' layer: expm1 of X, divided by 10000 and multiplied row-wise by nCount_RNA.
# NOTE(review): presumably this inverts a log1p / 10k-scaling normalisation — confirm.
data.layers['raw_data'] = ((data.X.expm1() / 10000).multiply(data.obs[['nCount_RNA']].values)).tocsr()

# Alternative grouping: 'celltype_subset' everywhere, except that cells whose
# major type is 'CAFs' take their 'celltype_minor' label instead.
data.obs['celltype_subset_alt'] = list(data.obs['celltype_subset'].values)
#data.obs.loc[data.obs['gene_module'] != 'no_gene_module', 'celltype_subset_alt'] = data.obs.loc[data.obs['gene_module'] != 'no_gene_module', 'gene_module']
is_caf = data.obs['celltype_major'] == 'CAFs'
data.obs.loc[is_caf, 'celltype_subset_alt'] = data.obs.loc[is_caf, 'celltype_minor']
# Building the column from a list yields plain strings; later cells use the
# `.cat` accessor on it, which needs a categorical dtype.
data.obs['celltype_subset_alt'] = data.obs['celltype_subset_alt'].astype('category')

group_today = 'celltype_subset_alt'

# subsample_clusters draws from the global NumPy RNG; seed it here so the
# subsample (and everything derived from it) is repeatable on a fresh kernel.
np.random.seed(42)
data_sub = data[subsample_clusters(data, subsample_factor=0.1, group_label=group_today)]

# Mean of X per group, one vector per group in order of first appearance.
mean_exp_vals = []
for g in data_sub.obs[group_today].unique():
    group_rows = np.where(data_sub.obs[group_today] == g)[0]
    mean_exp_vals.append(data_sub.X[group_rows, :].toarray().mean(axis=0))

In [None]:
# Per gene: largest group mean, transformed with exp(.) - 1.
group_mean_matrix = np.stack(mean_exp_vals, axis=0)
max_group_expr = np.exp(group_mean_matrix.max(axis=0)) - 1

# Cut-off: the same quantity for RPLP0, divided by 20.
rplp0_position = np.where(data_sub.var.index == 'RPLP0')[0]
expr_cutoff = max_group_expr[rplp0_position] / 20

# Keep the genes whose best group exceeds the cut-off.
expressed_genes = np.where(max_group_expr > expr_cutoff)[0]

data_sub = data_sub[:, expressed_genes].copy()

In [None]:
def multi_logist_hierarchical(D, X, label_graph, class_size, device='cuda'):
    '''
    Two-level multinomial logistic regression (Pyro model).

    D[:, 0] ~ Categorical(softmax(X w_top))
    D[:, 1] ~ Categorical(f_bottom), where for every top-level class the
    softmax over its own bottom-level classes is weighted by that top-level
    class probability.

    Both weight matrices have a Laplace(0, 0.5) prior and are divided by the
    square root of the matching class size before the softmax.

    Parameters
    ----------
    D : torch.tensor
        Observed labels; column 0 is the top level, column 1 the bottom level.
    X : torch.tensor
        2-D input matrix (rows x features; cells x genes per the local names).
    label_graph : mapping
        ``label_graph[p]`` for each top-level class index ``p`` is the index
        of the bottom-level classes that belong to it (used to select columns).
    class_size : sequence of two torch.tensor
        ``class_size[0]`` has one entry per top-level class and
        ``class_size[1]`` one entry per bottom-level class.
    device : torch.device or str
        Device on which the prior parameters and the bottom-level probability
        buffer are created (default 'cuda').
    '''
    n_cells, n_genes = X.shape
    n_top, n_bottom = (class_size[level].size()[0] for level in range(2))

    def laplace_weight_prior(n_classes):
        # Laplace(0, 0.5) over a (genes x classes) matrix, as a single event.
        loc = torch.tensor([0.]).to(device)
        scale = torch.tensor([0.5]).to(device)
        return dist.Laplace(loc, scale).expand([n_genes, n_classes]).to_event(2)

    # Top level: class probabilities straight from the scaled weights.
    w_top = pyro.sample('w_top', laplace_weight_prior(n_top))
    top_scores = torch.matmul(X, w_top / class_size[0] ** 0.5)
    f_top = torch.softmax(top_scores, dim=1)

    # Bottom level: weights for every bottom-level class.
    w_bottom = pyro.sample('w_bottom', laplace_weight_prior(n_bottom))

    # Accumulate, parent by parent, P(child | parent) * P(parent).
    f_bottom = torch.zeros((n_cells, n_bottom)).to(device)
    for parent in range(n_top):
        children = label_graph[parent]
        child_scale = class_size[1][children][None, :] ** 0.5
        child_scores = torch.matmul(X, w_bottom[:, children] / child_scale)
        within_parent = torch.softmax(child_scores, dim=1)
        f_bottom[:, children] += within_parent * f_top[:, parent, None]

    # Likelihoods of the observed labels at both levels.
    pyro.sample('likelihood_top', dist.Categorical(f_top).to_event(1), obs=D[:, 0])
    pyro.sample('likelihood_bottom', dist.Categorical(f_bottom).to_event(1), obs=D[:, 1])

In [None]:
#hierarchical labeling
np.random.seed(42)
pyro.set_rng_seed(42)
torch.manual_seed(42)

device = 'cpu'

# Integer class codes for both label levels. astype('category') is a no-op for
# columns that are already categorical and makes `.cat` safe otherwise.
major_labels = data_sub.obs['celltype_major'].astype('category')
minor_labels = data_sub.obs[group_today].astype('category')
major_codes = major_labels.cat.codes.values.astype(np.int64)
minor_codes = minor_labels.cat.codes.values.astype(np.int64)
# A code of -1 marks a missing label and would silently index the last class.
assert (major_codes >= 0).all() and (minor_codes >= 0).all(), "missing class labels in data_sub.obs"

D = torch.tensor(np.float32(np.stack([major_codes, minor_codes], axis=1))).to(device=device)
X = np.apply_along_axis(data_to_zero_truncated_cdf, 0, data_sub.X.toarray()) #/ (data_sub.layers['raw_data'].toarray().max(axis=0)**0.5)[None,:]
X = torch.tensor(np.float32(X)).to(device=device)

# Cells per class, counted from the same codes used in D so that position k
# always refers to class code k.
class_size = [
    torch.tensor(np.float32(np.bincount(codes, minlength=len(labels.cat.categories)))).to(device=device)
    for codes, labels in [(major_codes, major_labels), (minor_codes, minor_labels)]
]

# For every top-level class: codes of the bottom-level classes observed under
# it, in order of first appearance. Taken from the code arrays directly, since
# what `Categorical.unique().codes` returns differs between pandas versions.
label_graph = {
    parent: pd.unique(minor_codes[major_codes == parent])
    for parent in range(len(major_labels.cat.categories))
}

pyro.clear_param_store()

model = multi_logist_hierarchical
guide = AutoNormal(model)

adam_params = {"lr": 0.01, "betas": (0.95, 0.999)}
optimizer = Adam(adam_params)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 10000
# do gradient steps
loss_list = []
for step in tqdm.tqdm(range(n_steps)):
    # Pass the same device the tensors above were placed on.
    loss = svi.step(D, X, label_graph, class_size, device)
    loss_list.append(loss)


In [None]:
# Draw 100 samples using the fitted guide. The model receives the notebook's
# `device` variable (the one the input tensors were created on) rather than a
# hard-coded literal.
trace = Predictive(model, guide=guide, num_samples=100)(D, X, label_graph, class_size, device)

In [None]:
gene_names = data_sub.var.index.values

# Mean of the sampled top-level weights over the sample axis; computed once
# instead of once per class.
w_top_mean = trace['w_top'].mean(axis=0).cpu().numpy()

# For each top-level class, the 4 genes with the largest mean weight
# (argpartition does not order those 4 among themselves).
selected_dcit_top = {}
for i, name in enumerate(data_sub.obs['celltype_major'].cat.categories):
    weights = w_top_mean[:, i]
    top4 = np.argpartition(weights, -4)[-4:]
    selected_dcit_top[name] = gene_names[top4]
# RPLP0 is added under its own key 'hk' (presumably as a housekeeping reference).
selected_dcit_top['hk'] = np.array(['RPLP0'])

In [None]:
# Dot plot of the selected genes on the full data set, grouped by major cell type.
# NOTE(review): with default arguments sc.pl.dotplot likely returns None, so
# `fig` may not hold a figure object — confirm before using it.
fig = sc.pl.dotplot(data, var_names=selected_dcit_top, groupby='celltype_major')