In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import pickle

import numpy as np
import torch
import pyro
from pyro.optim import Adam, ClippedAdam
from pyro.infer import TraceMeanField_ELBO,TraceEnum_ELBO, TraceGraph_ELBO, Trace_ELBO

import congas as cg
from congas.models import LatentCategorical


In [3]:
# Load the pickled input data and model parameters from the working directory.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
# The `with` blocks guarantee each file is closed even if unpickling raises.
# NOTE(review): `data` is rebound later in the notebook to simulated tensors,
# so the object loaded here is only used until that point.
with open("data.pkl", "rb") as data_file:
    data = pickle.load(data_file)

with open("parameters.pkl", "rb") as param_file:
    param = pickle.load(param_file)

# Inference front-end, created with CUDA disabled.
interface = cg.Interface(CUDA = False)

In [4]:
# Simulation settings. `I` is the number of segments; N_SUBCLONES is defined
# once so the subclone count and both mixture-proportion vectors cannot drift
# out of sync.
I  = 15
N_SUBCLONES = 8

# Simulate 5000 cells with uniform mixture proportions across subclones for
# both modalities. Separate tensors are built for RNA and ATAC, as before.
sim_data = cg.simulate_data_congas(
    ncells=5000,
    seed=16,
    subclones=N_SUBCLONES,
    mixture_proportion_rna=torch.ones(N_SUBCLONES) / N_SUBCLONES,
    mixture_proportion_atac=torch.ones(N_SUBCLONES) / N_SUBCLONES,
    n_segments=I,
)
sim_data

{'atac': {'exp': array([[  82.,  309.,  279., ...,  128.,   54.,   27.],
         [  38.,  238.,  293., ...,  120.,  187.,  117.],
         [ 330.,  178.,  122., ...,   79.,  294.,  201.],
         ...,
         [ 753., 1097.,  326., ...,  218.,  252., 1179.],
         [ 121.,  205.,  189., ...,  302.,  391.,  506.],
         [ 293.,  375.,  319., ...,  435.,  536.,  324.]]),
  'MAF': array([[0.37053729, 0.32253585, 0.4158179 , ..., 0.28930478, 0.40497724,
          0.2818496 ],
         [0.24322206, 0.39135757, 0.34969442, ..., 0.37679794, 0.37214288,
          0.36666653],
         [0.41908695, 0.38388334, 0.35205534, ..., 0.38393695, 0.37789072,
          0.22120664],
         ...,
         [0.39910356, 0.41747507, 0.47764789, ..., 0.3003187 , 0.23592494,
          0.28347963],
         [0.23674662, 0.33409717, 0.36924853, ..., 0.30981061, 0.33964248,
          0.19481125],
         [0.43433506, 0.31206319, 0.21098797, ..., 0.2784139 , 0.36360118,
          0.37149347]]),
  'labels'

In [5]:
# Sanity check on the simulated RNA matrix: the saved output is (5000, 15),
# i.e. (ncells, n_segments) as requested from the simulator.
sim_data["rna"]["exp"].shape

(5000, 15)

In [6]:
# Model inputs, all converted to float32 tensors. The two count matrices are
# transposed; for RNA that turns the simulated (cells, segments) layout into
# (segments, cells). The ATAC matrix is presumably laid out the same way.
data = dict(
    data_atac=torch.tensor(sim_data["atac"]["exp"]).T.to(torch.float32),
    data_rna=torch.tensor(sim_data["rna"]["exp"]).T.to(torch.float32),
    norm_factor_atac=torch.tensor(sim_data["atac"]["norm_factors"]).to(torch.float32),
    norm_factor_rna=torch.tensor(sim_data["rna"]["norm_factors"]).to(torch.float32),
    pld=torch.tensor(sim_data["CNA_bulk"]["total"].values).to(torch.float32),
)

In [7]:
# Hyper-parameter overrides applied on top of the values loaded from
# parameters.pkl.
param["binom_prior_limits"] = torch.tensor([10,1000])

# Theta shape/rate values are taken from the simulation, per modality.
param["theta_shape_atac"] = torch.tensor(sim_data["atac"]["theta_shape_atac"]).float()
param["theta_rate_atac"] = torch.tensor(sim_data["atac"]["theta_rate_atac"]).float()
param["theta_shape_rna"] = torch.tensor(sim_data["rna"]["theta_shape_rna"]).float()
param["theta_rate_rna"] = torch.tensor(sim_data["rna"]["theta_rate_rna"]).float()

# NOTE(review): the data were simulated with 8 subclones. K appears to be the
# number of mixture components fitted (the learned mixture_weights has three
# entries) -- confirm that K = 3 is intentional.
param["K"] = 3
# NOTE(review): these values sum to 0.9, not 1 -- confirm the model
# normalises them or adjust the vector.
param["probs"] = torch.tensor([0.2000, 0.6000, 0.0500, 0.0250, 0.0250])
param["likelihood_atac"] = "NB"
param["likelihood_rna"] = "NB"
param["temperature"] = 10
param["lambda"] = 1
param["purity"] = None

# Initial NB size: one value per segment (length I) for each modality.
# A previous 4-element assignment to "nb_size_init_atac" was overwritten by
# this one before ever being used, so it has been removed.
param["nb_size_init_atac"] = torch.ones(I) * 150
param["nb_size_init_rna"] = torch.ones(I) * 150
param["CUDA"] = False
param["multiome"] = True



In [8]:
# Configure the inference interface with the model, optimizer and loss classes.
interface.set_model(LatentCategorical)
interface.set_optimizer(ClippedAdam)
interface.set_loss(Trace_ELBO)
# Attach the input tensors, then the hyper-parameters.
# NOTE(review): call order is kept exactly as written; presumably
# initialize_model must run before set_model_params -- confirm against the
# congas Interface API.
interface.initialize_model(data)
interface.set_model_params(param)


In [9]:
# Fit the model: 20 optimisation steps with learning rate 0.1.
# The value returned by run() is kept in `ll`.
ll = interface.run(
    steps=20,
    param_optimizer=dict(lr=0.1),
)

ELBO: 7.218102623  : 100%|██████████| 20/20 [00:01<00:00, 13.97it/s] 


Done!





In [10]:
# Pull the learned parameters out of the fitted interface and compute the
# information criteria. numpy is imported in the notebook's top import cell
# rather than here, so all dependencies are declared in one place.
lr = interface.learned_parameters()
ICs = interface.calculate_ICs()


Computing assignment probabilities
Computing information criteria.


In [11]:
# Number of ATAC cells whose assignment equals 0.
(lr["assignment_atac"] == 0).sum()

1117

In [12]:
# Number of RNA cells whose assignment equals 0.
(lr["assignment_rna"] == 0).sum()

1117

In [13]:
# Display every learned parameter.
# NOTE(review): in the saved output NB_size_atac is still at its initial value
# of 150 and segment_factor_atac is identical across segments. After only 20
# steps the ATAC parameters may not have been updated -- check convergence.
lr

{'mixture_weights': array([0.15130064, 0.12047607, 0.72822326], dtype=float32),
 'NB_size_rna': array([29.572866, 29.572866, 29.572866, 29.572866, 29.572866, 29.572866,
        29.572866, 29.572866, 29.572866, 29.572866, 29.572866, 29.572866,
        29.572866, 29.572866, 29.572866], dtype=float32),
 'segment_factor_rna': array([69.07002 , 70.03304 , 71.405876, 67.96777 , 82.21541 , 74.76908 ,
        67.96777 , 67.96777 , 67.96777 , 78.02877 , 78.8904  , 67.976364,
        74.759636, 74.76908 , 80.389626], dtype=float32),
 'NB_size_atac': array([150., 150., 150., 150., 150., 150., 150., 150., 150., 150., 150.,
        150., 150., 150., 150.], dtype=float32),
 'segment_factor_atac': array([18.49565, 18.49565, 18.49565, 18.49565, 18.49565, 18.49565,
        18.49565, 18.49565, 18.49565, 18.49565, 18.49565, 18.49565,
        18.49565, 18.49565, 18.49565], dtype=float32),
 'CNV_probabilities': array([[[0.00730497, 0.00730497, 0.01110983, 0.48714015, 0.48714015],
         [0.00731977, 0.00