In [None]:
# various import statements
import numpy as np

import torch
import torch.nn as nn
from torch.nn.functional import softplus, softmax
from torch.distributions import constraints
from torch.optim import Adam
import seaborn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.util import broadcast_shape
from pyro.optim import MultiStepLR
from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO
from pyro.ops.indexing import Vindex
import pyro.contrib
import tqdm
import math
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import os
import sklearn
from sklearn import cluster
import pandas as pd
import re
import scanpy as sc
from torch import exp
from IPython.display import Audio, display
def allDone():
    """Autoplay a bell sound in the notebook output to signal that a long-running cell finished."""
    bell = Audio(url='https://notification-sounds.com/soundsfiles/Meditation-bell-sound.mp3', autoplay=True)
    display(bell)

# Directory scanpy uses when saving figures.
sc.settings.figdir=os.path.expanduser(
    '~/WbFigures/SpeciesDivergenceNoScaling'
)
    

In [None]:
import sys
sys.path.append('/home/matthew.schmitz/Matthew/code/scANTIPODE/')
import model_functions
from model_functions import *
import model_distributions
from model_distributions import *
import model_modules
from model_modules import *

import importlib
model_functions=importlib.reload(model_functions)
from model_functions import *

In [None]:
# Prefer the first GPU when one is available, otherwise run on CPU.
smoke_test = False  # not referenced elsewhere in this notebook
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)


In [None]:
#adata=sc.read(os.path.expanduser('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/data/taxtest/HvQvM/HvQvM_900k_RPCA_clusters.h5ad'))
adata=sc.read(os.path.expanduser('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/data/taxtest/HvQvM/HvQvMall_part1.h5ad'))

# batch_name category -> code-string mapping, built on the unfiltered data and stored in .uns.
# NOTE(review): kept exactly as before because other code may read it (presumably
# make_dataloader via origdata=adata) -- confirm before changing.
adata.uns['batch_cats']=dict(zip([str(x) for x in adata.obs['batch_name'].cat.categories],[str(x) for x in sorted(set(adata.obs['batch_name'].cat.codes))]))

adata.X=adata.raw.X[:,adata.raw.var.index.isin(adata.var.index)]#.todense()
#adata=adata[np.random.choice(adata.obs.index,100000,replace=False),:]
#sc.pp.highly_variable_genes(adata,n_top_genes=5000,flavor="seurat_v3",subset=True)

adata=adata[~adata.obs.general_region.isin(['Cultured']),:]
adata=adata[~adata.obs.batch_name.str.contains('P10'),:]
#adata=adata[adata.obs['general_region'].isin(['bn','cp','ctx','ge']),:]

adata.obs['species']=adata.obs['species'].astype('category')

# Per-cell species/batch encodings are built *after* filtering so they have exactly one
# row per cell of the final adata (species_values is later fed to the encoder together
# with adata.layers['spliced']).
species_arg=adata.obs['species'].cat.codes
species_values=torch.nn.functional.one_hot(torch.tensor(species_arg.to_numpy()).long(),num_classes=len(adata.obs['species'].cat.categories)).float()

batch_arg=adata.obs['batch_name'].replace(adata.uns['batch_cats']).astype(int)
batch_values=torch.nn.functional.one_hot(torch.tensor(batch_arg.to_numpy()).long(),
                                         num_classes=len(adata.uns['batch_cats'].keys())).float()

sc.pl.umap(adata,color=['species'])
sc.pl.umap(adata,color=['leiden'],legend_loc="on data")
sc.pl.umap(adata,color='region')
adata.obsm['X_original_umap']=adata.obsm['X_umap']
# np.asarray(...).ravel() turns the (n_cells, 1) matrix a sparse .sum(1) returns into a 1-D column.
adata.obs['n_counts']=np.asarray(adata.X.sum(1)).ravel()


In [None]:
# UMAP of the filtered cells (original embedding) coloured by region.
sc.pl.umap(adata,color='region')

In [None]:
#adata.obs['batch_name']=adata.obs['batch_name'].astype(str)
#adata.obs['batch_name']=adata.obs['batch_name'].astype('category')
#adata.uns['batch_cats']=dict(zip([str(x) for x in adata.obs['batch_name'].cat.categories],[str(x) for x in sorted(set(adata.obs['batch_name'].cat.codes))]))

In [None]:
# Collapse the dissected general_region values into coarser sampling groups
# (several regions are pooled into 'vt'), then tag each cell with region + species.
samplegroups = dict(
    FB='vt',
    bn='vt',
    ctx='ctx',
    cp='vt',
    de='de',
    ge='vt',
    h='vt',
    hb='hb',
    mb='mb',
)
adata.obs['sample_region'] = adata.obs['general_region'].replace(samplegroups)
adata.obs['region_species'] = (
    adata.obs['sample_region'].astype(str)
    + '_'
    + adata.obs['species'].astype(str)
)


In [None]:
#sc.pl.umap(adata,color='sample_region')

In [None]:
# Number of cells per general_region.
adata.obs['general_region'].value_counts().plot.bar()

In [None]:
# Number of cells per coarse sample_region.
adata.obs['sample_region'].value_counts().plot.bar()

In [None]:
# Within each species, the fraction of its cells in each sample_region (one bar group per region).
adata.obs.groupby('species')['sample_region'].value_counts(normalize=True).unstack().T.plot.bar()

In [None]:
# One-hot batch matrix from the current adata's batch_name category codes; this overwrites
# any batch_values built earlier and is what later cells index with argmax(1).
batch_values=torch.nn.functional.one_hot(torch.tensor(adata.obs['batch_name'].cat.codes).long(),num_classes=len(adata.obs['batch_name'].cat.categories)).float()
adata.obs['batch_name']

# Put everything together

In [None]:
# Packages the antipode model and guide as a PyTorch nn.Module

class ANTIPODE(nn.Module):
    """
    scANTIPODE is a variational inference model that is developed for the analysis and
    categorization of cell types across evolution based on single-cell data.

    ``model`` (generative model) and ``guide`` (variational distribution) are methods of
    this module, so both can be handed directly to ``pyro.infer.SVI``.

    Parameters:
    num_var (int): Number of variables (features, e.g. genes) in the dataset.
    l_loc (float): Initial location of the per-batch LogNormal prior over the
                   library-size latent ``l`` (initialises the ``l_mu`` Pyro param).
    l_scale (float): Initial scale of that prior (initialises the ``l_scale`` Pyro param).
    level_sizes (list of int, optional): Number of labels at each level of the model's
                                         hierarchical structure. Defaults to [1, 10, 20, 50].
    num_latent (int, optional): Number of latent dimensions. Defaults to 10.
    scale_factor (float, optional): Passed to ``poutine.scale`` around every sample site in
                                    ``model`` and ``guide``, i.e. it rescales the ELBO.
                                    Defaults to 1.0.
    num_species (int, optional): Number of species (or groups) in the dataset. Defaults to 1.
    num_batch (int, optional): Number of batches to correct for in the data. Defaults to 1.
    prior_scale (float, optional): Scale of the prior distributions, stored as
                                   ``self.prior_scale`` (presumably read by the helper
                                   modules; originally documented as Laplace priors).
                                   Defaults to 100.
    tree_approx (bool, optional): Whether tree edges are also approximated while in approx
                                  mode (see ``set_approx``). Defaults to False. Doesn't work well.
    bi_depth (int, optional): Number of tree levels, counted from the root, whose labels get
                              batch-by-identity effects. Defaults to 1.
    batch_embed (int, optional): Size of the batch embedding produced by ``be_nn``.
                                 Defaults to 10.
    """
    def __init__(self, num_var, l_loc, l_scale,level_sizes=[1,10,20,50],
                 num_latent=10, scale_factor=1.0,
                 num_species=1,num_batch=1,prior_scale=100,
                 tree_approx=False,bi_depth=1,batch_embed=10):
        # nn.Module.__init__ must run before any attribute is assigned on a subclass
        # (PyTorch requirement), so call it first.
        super().__init__()

        self.pi=3.14159265
        self.num_var = num_var
        self.num_species = num_species
        self.num_batch = num_batch
        self.num_latent = num_latent
        self.scale_factor=scale_factor
        # Copy so the instance never aliases the shared mutable default (tuples also work).
        self.level_sizes=list(level_sizes)
        self.num_labels = np.sum(self.level_sizes)
        # Start/end offsets of each level's slice within the concatenated label vector y1.
        self.level_indices=np.cumsum([0]+self.level_sizes)
        #how deep in the tree should batch by identity be accounted for
        self.bi_depth=bi_depth
        self.bi_depth_num=sum(self.level_sizes[:self.bi_depth])
        self.batch_embed=batch_embed
        
        # The next two hyperparameters determine the prior over the log_count latent variable `l`
        self.l_loc = l_loc
        self.l_scale = l_scale

        # Helper objects (presumably from model_modules) that create/sample the global
        # latents and parameters; each receives this model instance.
        self.dm=DM(self)
        self.bm=BM(self)
        self.di=DI(self)
        self.bie=BIEmbed(self)
        self.ci=CI(self)
        self.zdw=ZDW(self)
        self.dc=DC(self)
        self.bc=BC(self)
        self.zs=ZScale(self)
        self.zl=ZLoc(self)
        self.zld=ZLocDynam(self)
        self.tree_edges=TreeEdges(self,straight_through=True)
        self.tree_convergence=TreeConvergence(self)        
        self.z_transform=null_function#centered_sigmoid#torch.special.expit
        
        # Setup the various neural networks used in the model and guide
        self.z_decoder=ZDecoder(num_latent=self.num_latent, num_var=num_var, hidden_dims=[])        
        # Encoder outputs, in order: z_loc, z_scale, l_loc, l_scale (unpacked in guide).
        self.zl_encoder=ZLEncoder(num_var=num_var,hidden_dims=[6000,3000,1000],num_cat_input=self.num_species,
                    outputs=[(self.num_latent,None),(self.num_latent,softplus),(1,None),(1,softplus)])
        
        # Classifier outputs, in order: label probabilities y1, psi_loc, psi_scale.
        self.classifier=Classifier(num_latent=self.num_latent,
                    outputs=[(self.num_labels,safe_sigmoid),(1,None),(1,softplus)])

        #self.bc_nn=SimpleFFNN(in_dim=self.num_batch,hidden_dims=[200,200,50,5],
        #            out_dim=self.num_var*self.num_latent)
        
        #Too large to exactly model gene-level batch effects for all cluster x batch. so embed
        self.be_nn=SimpleFFNN(in_dim=self.num_batch,hidden_dims=[1000,500,200],
                    out_dim=self.batch_embed)
        
        # Added to every Normal/LogNormal scale so it stays strictly positive.
        self.epsilon = 0.006
        #Initialize model in approximation mode
        self.approx=False
        self.hard=False
        #Whether to approximate tree connectivity during approximation mode
        self.tree_approx=tree_approx#Approx has major integration problems, but could still be useful

        self.prior_scale=prior_scale

        
    def set_approx(self,b: bool):
        """Toggle approx mode (no RelaxedBernoulli relaxation of y1 in model/guide)."""
        self.approx=b
        
    def set_hard(self,b: bool):
        """Toggle hard mode (adds the 'z_loc' site; the guide then samples z from per-label locs)."""
        self.hard=b

    def model(self, s=None,species=None,batch=None,y1=None):
        """Generative model.

        Args:
            s: count matrix with one row per cell (the "batch" plate runs over s.shape[0]),
               observed through a NegativeBinomial with num_var-sized parameters.
            species: per-cell one-hot species encoding (index = argmax over dim 1).
                     Defaults to zeros, i.e. species index 0 for every cell.
            batch: per-cell one-hot batch encoding, same convention; defaults to zeros.
            y1: not read -- it is rebound by the "y1" sample site; kept so the signature
                matches ``guide``.
        """
        # Register various nn.Modules (i.e. the decoder/encoder networks) with Pyro
        pyro.module("antipode", self)
        if batch is None:
            batch=s.new_zeros((s.shape[0],self.num_batch))
        if species is None:
            species=s.new_zeros((s.shape[0],self.num_species))
        species_ind=species.argmax(1)
        batch_ind=batch.argmax(1)

        with poutine.scale(scale=self.scale_factor):
            # "This gene-level parameter modulates the variance of the observation distribution"
            s_theta = pyro.param("s_inverse_dispersion", 50.0 * s.new_ones(self.num_var),
                               constraint=constraints.positive)

            # Global (non-plated) latents/params, each provided by a helper object.
            locs=self.zl.model_sample(s)
            scales=self.zs.model_sample(s)
            #scales=self.zs.make_params(s)
            locs_dynam=self.zld.model_sample(s)
            species_dm=self.dm.model_sample(s)
            species_di=self.di.model_sample(s)
            batch_dm=self.bm.model_sample(s)
            cluster_intercept=self.ci.model_sample(s)
            z_decoder_weight=self.zdw.model_sample(s)
            #z_decoder_weight=self.zdw.make_params(s)
            level_edges=self.tree_edges.model_sample(s,approx=self.approx&self.tree_approx)
            species_dc=self.dc.model_sample(s)
            
            # Per-batch LogNormal prior parameters for the library-size latent l.
            l_mu = pyro.param("l_mu", self.l_loc * s.new_ones(self.num_batch,1))
            l_scale = pyro.param("l_scale",  self.l_scale * s.new_ones(self.num_batch,1),
                               constraint=constraints.positive) 
            scd=pyro.param("species_common_de", s.new_zeros((self.num_species,self.num_var)))
            bie=self.bie.model_sample(s)
            
            # We scale all sample statements by scale_factor so that the ELBO loss function
            # is normalized wrt the number of datapoints and genes.
            # This helps with numerical stability during optimization.
            with pyro.plate("batch", s.shape[0]):
                l = pyro.sample("l", dist.LogNormal(l_mu[batch_ind], l_scale[batch_ind]).to_event(1))
                # Beta prior on each label probability; label 0 gets concentration a=10.
                beta_prior_a=1.*s.new_ones(self.num_labels)
                beta_prior_a[0]=10.

                if self.approx:
                    y1 = pyro.sample("y1", dist.Beta(beta_prior_a,1.*s.new_ones(self.num_labels),validate_args=True).to_event(1))
                    l = pyro.sample("l_obs", dist.LogNormal(l.log(), s.new_ones(s.shape[0],1)).to_event(1),obs=s.sum(1).unsqueeze(-1))
                else:
                    y1 = pyro.sample("y1", dist.Beta(beta_prior_a,s.new_ones(self.num_labels),validate_args=True).to_event(1))
                    # Reparameterised relaxation of Bernoulli(y1) at temperature 0.1.
                    y1 = pyro.sample('y1_ber',dist.RelaxedBernoulli(temperature=0.1*s.new_ones(1),probs=y1).to_event(1))
                    l = pyro.sample("l_obs", dist.LogNormal(l.log(), s.new_ones(s.shape[0],1)).to_event(1),obs=s.sum(1).unsqueeze(-1))
                # NOTE(review): pyro.sample with obs= returns the observed value, so from here on
                # `l` is the observed per-cell total count, not the latent l -- confirm intended.

                self.tree_convergence.model_sample(y1,level_edges,s)
                #bc=self.bc.model_sample(s)
                # Batch-by-identity effect: embed the batch, project through bie, then weight
                # by the labels of the first bi_depth levels.
                bi=torch.einsum('bi,ijk->bjk',self.be_nn(batch),bie)
                bi=torch.einsum('bj,bjk->bk',y1[...,:self.bi_depth_num],bi)
                psi = pyro.sample('psi',dist.Normal(s.new_zeros(s.shape[0],1),1*s.new_ones(s.shape[0],1)).to_event(1))
                this_locs=oh_index(locs,y1)
                this_scales=oh_index(scales,y1)
                if self.hard:
                    z_loc=pyro.sample('z_loc', dist.Normal(this_locs,this_scales+self.epsilon,validate_args=True).to_event(1))
                z=pyro.sample('z', dist.Normal(this_locs,this_scales+self.epsilon,validate_args=True).to_event(1))
                # Add species- and batch-specific offsets plus a psi-scaled dynamic offset.
                z=z+oh_index2(species_dm[species_ind],y1) + oh_index2(batch_dm[batch_ind],y1)+(oh_index(locs_dynam,y1)*psi)
                z=self.z_transform(z)
                cur_species_di=oh_index2(species_di[species_ind],y1)
                cur_cluster_intercept=oh_index(cluster_intercept,y1)
                mu=torch.einsum('...bi,...bij->...bj',z,z_decoder_weight+species_dc[species_ind])#+bc
                spliced_mu=mu+scd[species_ind]+cur_species_di+cur_cluster_intercept+bi
                spliced_out=torch.softmax(spliced_mu,dim=-1)
                # NegativeBinomial with mean l*spliced_out and inverse dispersion s_theta:
                # logits = log(mean) - log(theta).
                log_mu = (l * spliced_out + 1e-6).log()
                s_dist = dist.NegativeBinomial(total_count=s_theta,logits=log_mu-s_theta.log(),validate_args=True)
                s_out=pyro.sample("s", s_dist.to_event(1), obs=s.int())

    
    # The guide specifies the variational distribution
    def guide(self, s, species=None,batch=None,y1=None):
        """Variational distribution; its sample-site names mirror those in ``model``.

        Takes the same arguments as ``model``; ``y1`` is likewise not read (it is rebound
        by the "y1" sample site).
        """
        pyro.module("antipode", self)
        if batch is None:
            batch=s.new_zeros((s.shape[0],self.num_batch))
        if species is None:
            species=s.new_zeros((s.shape[0],self.num_species))
        species_ind=species.argmax(1)
        batch_ind=batch.argmax(1)
        with poutine.scale(scale=self.scale_factor):
            locs=self.zl.guide_sample(s)
            scales=self.zs.guide_sample(s)
            #scales=self.zs.make_params(s)          
            locs_dynam=self.zld.guide_sample(s)
            z_decoder_weight=self.zdw.guide_sample(s)
            #Scale by vector of y to answer long standing sampling problem :)
            species_dm=self.dm.guide_sample(s)
            batch_dm=self.bm.guide_sample(s)
            species_di=self.di.guide_sample(s)
            cluster_intercept=self.ci.guide_sample(s)
            species_dc=self.dc.guide_sample(s)
            level_edges=self.tree_edges.guide_sample(s,approx=self.approx&self.tree_approx) 
            self.bie.guide_sample(s)
           
            with pyro.plate("batch", s.shape[0]):
                # Amortised per-cell posteriors for z and the library size l.
                z_loc, z_scale,l_loc, l_scale= self.zl_encoder(s,species)
                l=pyro.sample("l", dist.LogNormal(l_loc, l_scale+self.epsilon).to_event(1))
                if self.hard:
                    z=pyro.sample('z_loc',dist.Delta(z_loc).to_event(1))
                else:
                    z=pyro.sample('z',dist.Normal(z_loc,z_scale+self.epsilon).to_event(1))
                z=self.z_transform(z)
                # Label probabilities and psi posterior are predicted from z.
                y1_probs,psi_loc,psi_scale=self.classifier(z)
                pyro.sample('psi',dist.Normal(psi_loc,psi_scale).to_event(1))
                y1_dist = dist.Delta(y1_probs,validate_args=True).to_event(1)
                y1 = pyro.sample("y1", y1_dist)
                if not self.approx:
                    y1 = pyro.sample('y1_ber',dist.RelaxedBernoulli(temperature=0.1*s.new_ones(1),probs=y1).to_event(1))
                if self.hard:
                    # In hard mode z is resampled around the per-label locs selected by y1.
                    this_locs=oh_index(locs,y1)
                    this_scales=oh_index(scales,y1)
                    z=pyro.sample('z',dist.Normal(this_locs,this_scales+self.epsilon).to_event(1))
                
                self.tree_convergence.guide_sample(y1,level_edges,s)
                #self.bc.guide_sample(self.bc_nn,batch,s)

In [None]:
# Drop the model from a previous run (if any) so its GPU memory can be released.
# Only NameError (no previous model) is expected; a bare except would also swallow
# KeyboardInterrupt and real errors.
try:
    del antipode_model
except NameError:
    pass
torch.cuda.empty_cache()

In [None]:
import sys
sys.path.append('/home/matthew.schmitz/Matthew/code/scANTIPODE/')
import model_functions
from model_functions import *
import model_distributions
from model_distributions import *
import model_modules
from model_modules import *

import importlib
model_modules=importlib.reload(model_modules)
from model_modules import *

model_functions=importlib.reload(model_functions)
from model_functions import *


In [None]:
# Use the GPU when present; a hard-coded 'cuda' fails on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_var=adata.shape[1]
# Library-size prior hyperparameters: mean / std of per-cell log total counts.
l_mean=np.log(adata.X.sum(1)).mean()
l_scale=np.log(adata.X.sum(1)).std()
batch_size=32
# Label-tree level sizes (level 0 is a single root label).
level_sizes=[1,25,75,225]
num_latent=200
num_labels=sum(level_sizes)
steps=0
max_steps=400000
print_every=5000

# Clear Pyro param store so we don't conflict with previous run
pyro.clear_param_store()
# Fix random number seed to a lucky number
pyro.util.set_rng_seed(13)
# Enable optional validation warnings
pyro.enable_validation(True)

# Instantiate instance of model/guide and various neural networks
antipode_model = ANTIPODE(num_var=num_var, num_latent=num_latent,level_sizes=level_sizes,
                l_loc=l_mean, l_scale=l_scale,num_species=len(adata.obs['species'].unique()),
                num_batch=len(adata.obs['batch_name'].unique()),
                scale_factor=1e2 / (3*batch_size * num_var * num_labels * num_latent),
                bi_depth=len(level_sizes),tree_approx=True,prior_scale=100.,batch_embed=10)


# Setup an optimizer (Adam) and learning rate scheduler.
#Use OneCycleLR to give better convergence
scheduler=pyro.optim.OneCycleLR({'max_lr':0.002,'total_steps':max_steps,'div_factor':100,'optim_args':{},'optimizer':torch.optim.Adam})

# Setup a variational objective for gradient-based learning.
# JIT-compiled Trace_ELBO with 3 particles (not TraceEnum_ELBO): the label sites y1 are
# continuous (Beta / RelaxedBernoulli), so nothing is enumerated and the enumeration
# warning is switched off.
elbo = pyro.infer.JitTrace_ELBO(num_particles=3,strict_enumeration_warning=False)

svi = SVI(antipode_model.model, antipode_model.guide, scheduler, elbo)

In [None]:
# Use the GPU when present; a hard-coded 'cuda' fails on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
antipode_model.train()
antipode_model.zl_encoder.train()
num_epochs=2

grad_check=False
from collections import defaultdict
param_d=defaultdict(list)

adata_toppath=os.path.expanduser('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/data/taxtest/HvQvM')
adata_paths=os.listdir(adata_toppath)
adata_paths=sorted([os.path.join(adata_toppath,x) for x in adata_paths if 'HvQvMall_part' in x])
# Train on every part file except the first in sorted order (presumably part1, the
# in-memory `adata` passed as origdata), visiting them in reverse order.
training_paths=list(reversed(adata_paths[1:]))
if not training_paths:
    # Without any files the while-loop below would never advance `steps`.
    raise FileNotFoundError('No HvQvMall_part* files to train on in '+adata_toppath)

antipode_model=antipode_model.to(device)
antipode_model.set_approx(False)#Approx doesn't work very well for integration
loss_tracker=[]
pbar = tqdm.tqdm(total=max_steps, position=0)
while steps < max_steps:
    steps_at_pass_start=steps
    for adata_path in training_paths:
        #TODO integrate SCVI dataloader with proper chunking
        dataloader=make_dataloader(origdata=adata,adata_path=os.path.expanduser(adata_path),batch_size=batch_size)
        for x in dataloader:
            args=[y.to(device) for y in x]
            loss=svi.step(*args)
            scheduler.step()
            steps+=1
            pbar.update(1)
            loss_tracker.append(loss)
            # Snapshot params at steps k*print_every and k*print_every+1 (k >= 1) so the
            # grad-check cell can pair snapshots (0,1), (2,3), ... as single-step deltas.
            # Requiring steps >= print_every skips step 1 (1 % print_every == 1), which
            # would otherwise shift every pair by one.
            if grad_check and steps >= print_every and steps % print_every in (0, 1):
                for p in pyro.get_param_store():
                    param_d[p].append(pyro.param(p).clone().detach().cpu())
            if steps%print_every == 0:
                pbar.write("[Step %02d]  Loss: %.5f" % (steps, np.mean(loss_tracker[-print_every:])))
            # OneCycleLR raises once stepped past total_steps (= max_steps), so stop here
            # instead of finishing the current file.
            if steps >= max_steps:
                break
        if steps >= max_steps:
            break
    if steps == steps_at_pass_start:
        raise RuntimeError('Dataloaders yielded no batches; aborting to avoid an infinite loop.')

pbar.close()
allDone()
print("Finished training!")

In [None]:
# Inspect the LR-scheduler state Pyro's optimizer wrapper keeps for the 'locs' param.
svi.optim.get_state()['locs']['scheduler']

In [None]:
# Raw training loss for every SVI step.
seaborn.scatterplot(x=list(range(len(loss_tracker))),y=loss_tracker,alpha=0.5,s=2)
def moving_average(x, w):
    """Return the simple moving average of ``x`` over windows of length ``w``.

    Only windows lying fully inside ``x`` are kept ('valid' mode), so for
    ``len(x) >= w`` the result has ``len(x) - w + 1`` entries.
    """
    window_sums = np.convolve(x, np.ones(w), mode='valid')
    return window_sums / w
w=300
# Pad w//2 values before and w-1-w//2 after so the 'valid' moving average has exactly
# len(loss_tracker) points, aligned with the raw loss scatter. A symmetric int(w/2)
# padding yields one extra point whenever w is even.
mvavg=moving_average(np.pad(loss_tracker,(w//2,w-1-w//2),mode='edge'),w)
seaborn.lineplot(x=list(range(len(mvavg))),y=mvavg,color='coral')



In [None]:
if grad_check:
    # Single-step parameter deltas: snapshots in param_d[p] are paired (0,1), (2,3), ...
    # NOTE(review): assumes the training loop recorded snapshots in (step, step+1) pairs
    # -- confirm against its recording condition.
    param_delta=dict()
    for p in param_d.keys():
        param_delta[p]=[param_d[p][i]-param_d[p][i+1] for i in range(0,len(param_d[p])-1,2)]

    # Mean absolute change for each pair, per parameter.
    meandelta=dict()
    for p in param_delta.keys():
        meandelta[p]=np.array([x.abs().mean().numpy() for x in param_delta[p]])

    for p in meandelta.keys():
        # '' is a substring of every name, so every parameter is plotted; put e.g. 'edges'
        # here (as the trailing comment hints) to restrict the plot.
        if '' in p:#edges
            ax=seaborn.lineplot(x=list(range(len(meandelta[p]))),y=meandelta[p],label=p,)
            seaborn.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
            # Bare expression; has no effect inside the loop.
            ax
        #plt.show()

In [None]:
"""
antipode_model.dc_batch_encoder.cpu()
seaborn.histplot(antipode_model.dc_batch_encoder(torch.eye(batch_values.shape[1])).detach().cpu().numpy().flatten()[:10000],bins=100)
plt.show()
antipode_model.dc_batch_encoder.to(device)
"""

In [None]:
# Distribution of all entries of the 'z_decoder_weight' param.
seaborn.histplot(pyro.param('z_decoder_weight').detach().cpu().numpy().flatten(),bins=50)

In [None]:
# Clustered heatmap of the 'locs' param after z_transform (presumably per-label latent means).
seaborn.clustermap(antipode_model.z_transform(pyro.param('locs')).cpu().detach().numpy(),cmap='coolwarm')

In [None]:
# Clustered heatmap of the 'scales' param.
seaborn.clustermap(pyro.param('scales').cpu().detach().numpy())

In [None]:
# Gene-level NegativeBinomial inverse dispersion (s_theta in the model).
seaborn.histplot(pyro.param('s_inverse_dispersion').detach().cpu().numpy().flatten(),bins=50)


In [None]:
# Variance of z_decoder_weight along dim 1.
seaborn.histplot(pyro.param('z_decoder_weight').var(1).detach().cpu().numpy().flatten(),bins=50)


In [None]:
# Raw edge logits between consecutive tree levels (edges_0 is not shown).
seaborn.heatmap(pyro.param('edges_1').cpu().detach().numpy())
plt.show()
seaborn.heatmap(pyro.param('edges_2').cpu().detach().numpy())
plt.show()

In [None]:
# Same edges after a softmax over the last axis.
seaborn.heatmap(torch.softmax(pyro.param('edges_1'),dim=-1).cpu().detach().numpy())
plt.show()
seaborn.heatmap(torch.softmax(pyro.param('edges_2'),dim=-1).cpu().detach().numpy())
plt.show()

In [None]:
# Shared decoder weights vs the per-species decoder offsets species_dc (indices 0-2 assume 3 species).
seaborn.histplot(pyro.param('z_decoder_weight').detach().cpu().numpy().flatten(),color='yellow',bins=50,stat='percent')
seaborn.histplot(pyro.param('species_dc')[0].cpu().detach().numpy().flatten(),color='red',bins=50,stat='percent')
seaborn.histplot(pyro.param('species_dc')[1].cpu().detach().numpy().flatten(),color='blue',bins=50,stat='percent')
seaborn.histplot(pyro.param('species_dc')[2].cpu().detach().numpy().flatten(),color='green',bins=50,stat='percent')

In [None]:
# Overall variance of the shared decoder weights, and per-species variance of species_dc.
print(pyro.param('z_decoder_weight').var())
print(pyro.param('species_dc').flatten(1,2).var(1))

In [None]:
# 'locs' vs per-species latent offsets species_dm. Only the first histogram has a label,
# so the legend shows just 'ancestral'.
seaborn.histplot(pyro.param('locs').cpu().detach().numpy().flatten(),color="green",label='ancestral',bins=50)
seaborn.histplot(pyro.param('species_dm')[0,:,:].cpu().detach().numpy().flatten(),color="red",bins=50)
seaborn.histplot(pyro.param('species_dm')[1,:,:].cpu().detach().numpy().flatten(),color="pink",bins=50)
seaborn.histplot(pyro.param('species_dm')[2,:,:].cpu().detach().numpy().flatten(),color="purple",bins=50)
plt.legend()

In [None]:
# 'locs_dynam' param (the psi-scaled latent offset in the model).
seaborn.clustermap(pyro.param('locs_dynam').cpu().detach().numpy())

In [None]:
# Duplicate of the first z_decoder_weight histogram above.
seaborn.histplot(pyro.param('z_decoder_weight').detach().cpu().numpy().flatten(),bins=50)

In [None]:
# Duplicate of the z_decoder_weight.var(1) histogram above.
seaborn.histplot(pyro.param('z_decoder_weight').var(1).detach().cpu().numpy().flatten(),bins=50)


In [None]:
# Label intercepts vs per-species intercept offsets species_di (indices 0-2 assume 3 species).
seaborn.histplot(pyro.param('cluster_intercept').cpu().detach().numpy().flatten(),color='yellow',bins=50,stat='percent')
seaborn.histplot(pyro.param('species_di')[0].cpu().detach().numpy().flatten(),color='red',bins=50,stat='percent')
seaborn.histplot(pyro.param('species_di')[1].cpu().detach().numpy().flatten(),color='blue',bins=50,stat='percent')
seaborn.histplot(pyro.param('species_di')[2].cpu().detach().numpy().flatten(),color='green',bins=50,stat='percent')


In [None]:
device='cpu'
antipode_model=antipode_model.to(device)
avgpool=nn.AvgPool1d(2,stride=1)
# Now that we're done training we'll inspect the latent representations we've learned
import scanpy as sc
adata.obsm["X_umap"]=adata.obsm["X_original_umap"]
# Put the neural networks in evaluation mode (needed because of batch norm)
antipode_model.eval()
antipode_model.zl_encoder.eval()

# Batched inference uses the GPU when present and falls back to CPU; a hard-coded
# 'cuda' fails on CPU-only machines.
# NOTE(review): the model was just moved to CPU, so this presumably relies on
# batch_torch_outputs moving `function` to `inference_device` -- confirm.
inference_device='cuda' if torch.cuda.is_available() else 'cpu'

# Compute latent representation (z_loc) for each cell in the dataset.
# species_values must have one row per cell of the current adata.
inputs=[torch.tensor(adata.layers['spliced'].todense()).to(device),species_values]
function=antipode_model.zl_encoder
num_outs=4
encoded=batch_torch_outputs(inputs,function,batch_size=2048,device=inference_device)

# Label each cell by the index of its largest z_loc (encoded[0]) dimension.
clusternames=[str(i) for i in range(antipode_model.num_latent)]

adata.obs['LDA0max']=np.array(clusternames)[encoded[0].cpu().detach().numpy().argmax(1)]

sc.pl.umap(adata,color=['LDA0max'],legend_loc="on data",palette=sc.pl.palettes.godsnot_102)
sc.pl.umap(adata,color=['LDA0max'],palette=sc.pl.palettes.godsnot_102)
"""
for i in range(num_latent):
    adata.obs['test'+str(i)]=torch.softmax(encoded[0],dim=-1)[:,i].cpu().detach().numpy()
sc.pl.umap(adata,color=['test'+str(x) for x in range(num_latent)],use_raw=False,cmap='Purples')
"""
# Correlation between latent dimensions across cells.
seaborn.clustermap(np.corrcoef((antipode_model.z_transform(encoded[0])).T.cpu().detach().numpy()))


In [None]:
# Learned per-batch prior parameters of the library-size latent: l_mu vs l_scale.
seaborn.scatterplot(x=pyro.param('l_mu').detach().cpu().flatten(),y=pyro.param('l_scale').detach().cpu().flatten())

In [None]:
# Each cell's batch prior loc (l_mu indexed by its batch) vs the encoder's l_loc
# (encoded[2], presumably in the zl_encoder output order z_loc, z_scale, l_loc, l_scale).
seaborn.scatterplot(x=pyro.param('l_mu').detach().cpu().flatten()[batch_values.argmax(1)],y=encoded[2].detach().cpu().flatten(),alpha=0.1)

In [None]:
# Same comparison for the scale: batch prior l_scale vs the encoder's l_scale (encoded[3]).
seaborn.scatterplot(x=pyro.param('l_scale').detach().cpu().flatten()[batch_values.argmax(1)],y=encoded[3].detach().cpu().flatten())

In [None]:
# Encoder l_loc vs observed log total counts per cell.
seaborn.scatterplot(x=encoded[2].detach().cpu().flatten(),y=np.log(adata.obs['n_counts']))

In [None]:
# Run the classifier on the transformed latent means, on the GPU if one is available
# (a hard-coded 'cuda' fails on CPU-only machines). Outputs follow the classifier's
# declared order: [0] label probabilities, [1] psi_loc, [2] psi_scale.
classouts=batch_torch_outputs([(antipode_model.z_transform(encoded[0]))],antipode_model.classifier,batch_size=2048,device='cuda' if torch.cuda.is_available() else 'cpu')
o2=classouts[0]
# psi_loc has a trailing singleton dim; flatten it to one value per cell for .obs.
adata.obs['psi']=classouts[1].cpu().detach().numpy().reshape(-1)

In [None]:
y1=classouts[0].detach().cpu().numpy()
# Split the concatenated label probabilities into one array per tree level.
levels=[y1[...,antipode_model.level_indices[i]:antipode_model.level_indices[i+1]] for i in range(len(antipode_model.level_sizes))]
for i in range(len(levels)):
    adata.obs['level_'+str(i)]=levels[i].argmax(1)
    adata.obs['level_'+str(i)]=adata.obs['level_'+str(i)].astype(str)

# Cluster label = per-level argmax joined with '_' over every level below the root
# (level 0 has a single label here), for however many levels the model has.
cluster_label=adata.obs['level_1']
for i in range(2,len(levels)):
    cluster_label=cluster_label+'_'+adata.obs['level_'+str(i)]
adata.obs['antipode_model_cluster']=cluster_label

In [None]:
#TODO encapsulate into a function
y1=classouts[0].detach().cpu().numpy()

levels = [y1[..., antipode_model.level_indices[i]:antipode_model.level_indices[i + 1]]
          for i in range(len(antipode_model.level_sizes))]


for i, level in enumerate(levels):
    max_probs = level.max(axis=1)
    # Assign argmax if max probability > 0.5, else assign empty string
    adata.obs['level_' + str(i)] = np.where(max_probs > 0.5, level.argmax(axis=1).astype(str), '')

# Concatenate the levels to form the cluster label
adata.obs['antipode_model_cluster'] = adata.obs.apply(lambda x: '_'.join([x['level_' + str(i)] for i in range(len(levels))]), axis=1)

In [None]:
#TODO encapsulate into a function
levels=[y1[...,antipode_model.level_indices[i]:antipode_model.level_indices[i+1]] for i in range(len(antipode_model.level_sizes))]

level_edges=[safe_softmax(pyro.param('edges_'+str(i))).detach().cpu().numpy() for i in range(len(antipode_model.level_sizes)-1)]

#But still need to propagate edges to get the root value (check for cycles)
results=[levels[-1]]
for i in range(len(antipode_model.level_sizes) - 2, -1, -1):
    result=levels[i+1]@level_edges[i]
    results.append(result)
results=results[::-1]

for i in range(len(results)):
    adata.obs['prop_level_'+str(i)]=results[i].argmax(1)
    adata.obs['prop_level_'+str(i)]=adata.obs['prop_level_'+str(i)].astype(str)

adata.obs['prop_antipode_model_cluster']=adata.obs['prop_level_1']+'_'+adata.obs['prop_level_2']+'_'+adata.obs['prop_level_3']

In [None]:
# Relabel clusters with 100 or fewer cells as 'nan', for both the direct and the
# edge-propagated cluster labels.
for cluster_col in ('antipode_model_cluster', 'prop_antipode_model_cluster'):
    vcs = adata.obs[cluster_col].value_counts()
    legit_clusters = vcs.index[vcs > 100]
    is_small = ~adata.obs[cluster_col].isin(legit_clusters)
    adata.obs.loc[is_small, cluster_col] = 'nan'

In [None]:
# Per-level labels and combined cluster label on the original UMAP.
sc.pl.umap(adata,color=['level_1','level_2','level_3','antipode_model_cluster'],use_raw=False,legend_loc=None,palette=sc.pl.palettes.godsnot_102)

In [None]:
# Edge-propagated level labels.
# NOTE(review): the last panel is antipode_model_cluster, not prop_antipode_model_cluster -- intentional?
sc.pl.umap(adata,color=['prop_level_1','prop_level_2','prop_level_3','antipode_model_cluster'],use_raw=False,legend_loc=None,palette=sc.pl.palettes.godsnot_102)

In [None]:
# Label probabilities for the first 10k cells (rows kept in cell order). o2 is a torch
# tensor that may be on the GPU or track gradients, so convert to host numpy first.
seaborn.clustermap(o2[:10000,:].detach().cpu().numpy(),row_cluster=False)

In [None]:
# log10(1 + summed label probability over all cells), one value per label.
seaborn.histplot(np.log10(o2.sum(0).cpu().detach().numpy()+1))

In [None]:
# Binarised label membership (probability > 0.5) for the first 1000 cells; converted to
# host numpy because o2 is a torch tensor that may be on the GPU.
seaborn.clustermap(o2[:1000,:].detach().cpu().numpy()>0.5)

In [None]:
adata.obsm['X_antipode_model']=(antipode_model.z_transform(encoded[0])).detach().data.cpu().numpy()
# Subsample at most 100k cells (all of them if fewer; choice without replacement would
# otherwise raise) for a quicker neighbour graph / UMAP on the learned latent space.
n_subsample=min(100000,adata.n_obs)
random_choice=np.random.choice(adata.obs.index,size=n_subsample,replace=False)
random_choice=np.where(adata.obs.index.isin(random_choice))[0]
# .copy() makes xdata an independent AnnData rather than a view of adata, so in-place
# writes here and in later cells don't go through implicit view-to-copy conversions.
xdata=adata[random_choice,:].copy()
sc.pp.neighbors(xdata,n_neighbors=20, use_rep="X_antipode_model")
sc.tl.umap(xdata)
sc.pl.umap(xdata,color=['general_region','leiden','species'],use_raw=False,legend_loc='on data',palette=sc.pl.palettes.godsnot_102)

In [None]:
# Classifier psi (per-cell scaling of the dynamic latent offset) on the latent-space UMAP.
sc.pl.umap(xdata,color=['psi'],use_raw=False,legend_loc='on data',cmap='coolwarm')

In [None]:
# Per-level labels and combined cluster label on the latent-space UMAP.
sc.pl.umap(xdata,color=['level_1','level_2','level_3','antipode_model_cluster'],use_raw=False,legend_loc='on data',palette=sc.pl.palettes.godsnot_102)

In [None]:
# Direct vs edge-propagated cluster labels side by side.
sc.pl.umap(xdata,color=['antipode_model_cluster','prop_antipode_model_cluster'],use_raw=False,palette=sc.pl.palettes.godsnot_102)

In [None]:
# Replace xdata.X with the raw values for the genes in xdata.var, then normalise per cell,
# log-transform and scale (clipped at 10) for marker-gene plots.
# NOTE(review): newer scanpy deprecates normalize_per_cell in favour of normalize_total,
# which does not drop zero-count cells -- check before switching.
xdata.X=xdata.raw.X[:,xdata.raw.var.index.isin(xdata.var.index)]
sc.pp.normalize_per_cell(xdata)
sc.pp.log1p(xdata)
sc.pp.scale(xdata,max_value=10)

In [None]:
# Marker genes (scaled values in xdata.X) on the latent-space UMAP.
sc.pl.umap(xdata,color=['GBX2','EOMES','SIX3','OTX2','FOXG1','RBFOX3','TH','PDGFRA','AQP4','FOXJ1','AIF1','TTR','MOG','COL1A2','CD34','COL4A1','NPY','NKX2-1','FOXP2','SATB2','RORB','FEZF2','EMX1'],use_raw=False,cmap='Purples')

In [None]:
# A second marker panel; note 'TSHZ1' appears twice in this list.
sc.pl.umap(xdata,color=['DLX2','PROX1','SCGN','TSHZ1','MEIS2','NKX2-1','LHX6','CRABP1','TSHZ1','FOXG1','PDGFRA','AIF1','AQP4','EDNRB','FOXJ1','CD34','MKI67'],cmap='Purples',use_raw=False)


In [None]:
# Ribosomal-protein genes plus MEF2C.
sc.pl.umap(xdata,color=['RPL7','RPS17','RPL13A','MEF2C'],cmap='Purples',use_raw=False)


In [None]:
# Label co-occurrence: entry (i, j) is sum over cells of p_i * p_j, divided by label j's
# total probability mass. Computed in torch, then moved to host numpy for seaborn
# (o2 may be on the GPU or track gradients).
seaborn.clustermap((o2.T@o2/o2.sum(0)).detach().cpu().numpy())

In [None]:
# Latent-space UMAP coloured by batch, to check batch mixing.
sc.pl.umap(xdata,color=['batch_name'],use_raw=False,palette=sc.pl.palettes.godsnot_102)

In [None]:
# Learned batch embeddings: feed every one-hot batch vector (rows of the identity) through
# be_nn. Use the model's own batch count instead of a hard-coded 464, which breaks as
# soon as the number of batches changes.
seaborn.histplot(antipode_model.be_nn(torch.eye(antipode_model.num_batch)).cpu().detach().flatten().numpy())