<a href="https://colab.research.google.com/github/camlab-bioml/2021_IMC_Jett/blob/main/VI_DAMM_0902.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F

import torch.nn as nn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

from matplotlib.backends.backend_pdf import PdfPages

from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA

import seaborn as sns

%matplotlib inline

In [2]:
#%%capture
#!pip install wandb --upgrade

In [3]:
#import wandb
#wandb.login()

In [4]:
## generate data
def generateData(n_clusters = 3, n_obs = 10000, n_features = 2):

  #n_clusters = 3; n_obs = 100; n_features = 2
  
  ## set truth expression means/covariances (multivariate) ##
  mu = np.random.rand(n_clusters, n_features)
  # mu = np.sort(mu, 0) ## sort expressions
  sigma = 0.001 * np.identity(n_features) ## variance-covariance matrix

  ## set truth cell size means/variances (univariate) ##
  psi = [np.random.normal(100, 25) for i in range(n_clusters)]
  #psi = np.arange(90, 90 + 5 * n_clusters, 5)
  psi = np.sort(psi, 0)
  omega = 1 ## standard deviation
  ###

  ## set latent variables distributions ##
  lambda_arr = np.random.binomial(1, .95, n_obs) # p=.95 (a cell belongs to singlet or doublet) 

  n_singlet = np.sum(lambda_arr == 1) ## number of cells in singlet clusters
  n_doublet = np.sum(lambda_arr == 0) ## number of cells in doublet clusters
  
  lambda0_arr = n_singlet / n_obs ## proportion of cells belong to singlet
  lambda1_arr = n_doublet / n_obs ## proportion of cells belong to doublet

  #pi_arr = np.sort(np.random.sample(n_clusters))
  pi_arr = np.sort(np.random.rand(n_clusters))
  pi_arr /= pi_arr.sum()

  n_doublet_clusters = int((n_clusters * n_clusters - n_clusters)/2 + n_clusters)
  #tau_arr = np.sort(np.random.sample(n_doublet_clusters))
  tau_arr = np.sort(np.random.rand(n_doublet_clusters))
  tau_arr /= tau_arr.sum()

  ## draw cells based on defined parameters theta1 = (mu, sigma, psi, omega) & theta2 = (lambda, pi, tau)
  x = np.zeros((n_singlet, n_features+5))
  for i in range(n_singlet):
    selected_cluster = np.random.choice(n_clusters, size = 1, p = pi_arr)[0] ## select a single cell cluster
    x[i] = np.append(np.random.multivariate_normal(mu[selected_cluster], sigma),
                     [np.random.normal(psi[selected_cluster], omega), 
                      0, selected_cluster, 0, selected_cluster + n_doublet_clusters])
  
  x[x < 0] = 1e-4
  lookups = np.triu_indices(n_clusters) # wanted indices
  xx = np.zeros((n_doublet, n_features+5))
  for i in range(n_doublet):
    selected_cluster = np.random.choice(n_doublet_clusters, p = tau_arr)

    indx1 = lookups[0][selected_cluster]
    indx2 = lookups[1][selected_cluster]

    xx[i] = np.append(np.random.multivariate_normal( (mu[indx1] + mu[indx2])/2, (sigma + sigma)/2 ),
                     [np.random.normal( (psi[indx1] + psi[indx2]), omega+omega ), 
                      1, indx1, indx2, selected_cluster])
  xx[xx < 0] = 1e-4
  xxx = np.append(x, xx).reshape(n_obs, n_features+5)

  truth_theta = {
    'log_mu': np.log(mu),
    'log_sigma': np.log(sigma),
    'log_psi': np.log(psi),
    'log_omega': np.log(omega),
    "log_lambda0": np.log(lambda0_arr),
    'log_pi': np.log(pi_arr),
    'log_tau': np.log(tau_arr)
  }

  return xxx[:,:n_features], xxx[:,n_features], xxx, truth_theta

  #return torch.tensor(xxx[:,:n_features]), torch.tensor(xxx[:,n_features]), torch.tensor(xxx), [mu, sigma, psi, omega], [lambda0_arr, pi_arr, tau_arr]

In [5]:
def compute_p_y_given_z(Y, Theta):
  """ Returns NxC
  p(y_n | z_n = c)
  """
  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma'])

  dist_Y = D.Normal(mu, sigma)
  return dist_Y.log_prob(Y.reshape(Y.shape[0], 1, nf)).sum(2) # <- sum because IID over G

def compute_p_s_given_z(S, Theta):
  """ Returns NxC
  p(s_n | z_n = c)
  """
  psi = torch.exp(Theta['log_psi'])
  omega = torch.exp(Theta['log_omega'])

  dist_S = D.Normal(psi, omega)
  return dist_S.log_prob(S.reshape(-1,1)) 

def compute_p_y_given_gamma(Y, Theta):
  """ NxCxC
  p(y_n | gamma_n = [c,c'])
  """

  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma'])

  mu2 = mu.reshape(1, nc, nf)
  mu2 = (mu2 + mu2.permute(1, 0, 2)) / 2.0 # C x C x G matrix 

  sigma2 = sigma.reshape(1, nc, nf)
  sigma2 = (sigma2 + sigma2.permute(1,0,2)) / 2.0

  dist_Y2 = D.Normal(mu2, sigma2)
  return  dist_Y2.log_prob(Y.reshape(-1, 1, 1, nf)).sum(3) # <- sum because IID over G

def compute_p_s_given_gamma(S, Theta):
  """ NxCxC
  p(s_n | gamma_n = [c,c'])
  """
  psi = torch.exp(Theta['log_psi'])
  omega = torch.exp(Theta['log_omega'])

  psi2 = psi.reshape(-1,1)
  psi2 = psi2 + psi2.T

  omega2 = omega.reshape(-1,1)
  omega2 = omega2 + omega2.T

  dist_S2 = D.Normal(psi2, omega2)
  return dist_S2.log_prob(S.reshape(-1, 1, 1))

In [6]:
class BasicForwardNet(nn.Module):
  """Encoder for when data is input without any encoding"""
  def __init__(self, input_dim, output_dim, hidden_dim = 5):
    super().__init__()
    
    self.input = nn.Linear(input_dim, hidden_dim)
    self.linear1 = nn.Linear(hidden_dim, hidden_dim)
    self.output = nn.Linear(hidden_dim, output_dim)
    
  def forward(self, x):
    out = F.relu(self.input(x))
    out = F.relu(self.linear1(out))
    out = self.output(out)
        
    return F.softmax(out, dim=1), F.log_softmax(out, dim=1) ## r/v/d log_r/log_v/log_d

In [7]:
def compute_joint_probs(Theta, Y, S):

  log_pi = F.log_softmax(Theta['is_pi'])
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1)).reshape(nc,nc)
  log_delta = F.log_softmax(Theta['is_delta'])
  
  p_y_given_z = compute_p_y_given_z(Y, Theta)
  p_s_given_z = compute_p_s_given_z(S, Theta)

  log_rzd0 = p_s_given_z + p_y_given_z + log_pi + log_delta[0]

  p_y_given_gamma = compute_p_y_given_gamma(Y, Theta)
  p_s_given_gamma = compute_p_s_given_gamma(S, Theta)

  log_vgd1 = p_y_given_gamma + p_s_given_gamma + log_tau + log_delta[1]

  #remove_indices = np.tril_indices(nc, -1) ## remove indices
  #log_rd1g[:, remove_indices[0], remove_indices[1]] = float("NaN")

  #q1 = r.exp() * log_rd0z #; q1[torch.isnan(q1)] = 0.0
  #q2 = v.exp() * log_rd1g #; q2[torch.isnan(q2)] = 0.0

  return log_rzd0, log_vgd1.reshape(Y.shape[0], nc*nc)

In [8]:
#F.sigmoid(torch.tensor([.3, .6]))

In [9]:
nc = 5; no = 1000; nf = 10

Y, S, XX, theta_true = generateData(n_clusters = nc, n_obs = no, n_features = nf)



In [10]:
N_INIT = 20

Y = np.array(Y)
S = np.array(S)

kms = [KMeans(nc).fit(Y) for i in range(N_INIT)]
inertias = [k.inertia_ for k in kms]
km = kms[np.argmin(np.array(inertias))] ## selected "best" kmeans based on inertia score
init_labels = km.labels_

mu_init = np.array([Y[init_labels == i,:].mean(0) for i in np.unique(init_labels)])
sigma_init = np.array([Y[init_labels == i,:].std(0) for i in np.unique(init_labels)])
psi_init = np.array([S[init_labels == i].mean() for i in np.unique(init_labels)])
omega_init = np.array([S[init_labels == i].std() for i in np.unique(init_labels)])
pi_init = np.array([np.mean(init_labels == i) for i in np.unique(init_labels)])
tau_init = np.ones((nc,nc))
tau_init = tau_init / tau_init.sum()

In [11]:
P = Y.shape[1] + 1
r_net = BasicForwardNet(P, nc)
v_net = BasicForwardNet(P, nc ** 2)
d_net = BasicForwardNet(P, 2)

In [12]:
Theta = {
    'log_mu': np.log(mu_init) + 0.05 * np.random.randn(mu_init.shape[0], mu_init.shape[1]),
    'log_sigma': np.log(sigma_init), #np.zeros_like(sigma_init),
    'log_psi': np.log(psi_init),
    'log_omega': np.log(omega_init),
    "is_delta": F.log_softmax(torch.tensor([0.95, 1-0.95])),
    'is_pi': F.log_softmax(torch.tensor(pi_init)),
    'is_tau': F.log_softmax(torch.tensor(tau_init))
}
Theta = {k: torch.tensor(v, requires_grad=True) for (k,v) in Theta.items()}

Theta['is_delta'].requires_grad = False
#Theta['is_pi'].requires_grad = False
#Theta['is_tau'].requires_grad = False

  
  import sys
  
  # Remove the CWD from sys.path while we load stuff.


In [13]:
Y = torch.tensor(Y)
S = torch.tensor(S)
YS = torch.hstack((Y,S.reshape(-1,1))).float()

In [14]:
N_ITER = 100000
lr = 1e-3
tol = 1e-6
params = list(Theta.values()) + list(r_net.parameters()) + list(v_net.parameters()) + list(d_net.parameters())
opt = optim.AdamW(params, lr=lr)

In [15]:
%%capture
!pip install wandb --upgrade

import wandb
wandb.login()

In [16]:
wandb.init(project='jett-vi',
           config={
    "N_EPOCHS": N_ITER,
    "LR": lr,
    "TOL": tol,
    'MODEL_TYPE': 'vi',
    'DATA_TYPE': 'toy_data'
    })

In [17]:
#elbo1 = q0 * (log_rzd0 - log_q0)
  #elbo2 = q1 * (log_vgd1 - log_q1)
  #nelbo = elbo1.sum() + elbo2.sum()
  #-elbo, entro, recon

  #assert( ((d_net(YS)[0][:,1].reshape(-1,1) * v_net(YS)[0]).sum(1) + (d_net(YS)[0][:,0].reshape(-1,1) * r_net(YS)[0]).sum(1)).sum() == 1000. )


In [18]:
F.log_softmax(Theta['is_delta']).exp()

  """Entry point for launching an IPython kernel.


tensor([0.7109, 0.2891])

In [19]:
np.exp(theta_true['log_lambda0'])

0.948

In [20]:
F.log_softmax(Theta['is_pi']).exp()

  """Entry point for launching an IPython kernel.


tensor([0.2379, 0.2205, 0.2082, 0.1656, 0.1678], dtype=torch.float64,
       grad_fn=<ExpBackward>)

In [21]:
np.exp(theta_true['log_pi'])

array([0.02103361, 0.02438582, 0.22756926, 0.31859506, 0.40841625])

In [22]:
#F.log_softmax(v.mean(0).reshape(-1)).reshape(nc, nc).exp()

In [23]:
F.log_softmax(Theta['is_tau'].reshape(-1)).reshape(nc, nc).exp()

  """Entry point for launching an IPython kernel.


tensor([[0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64,
       grad_fn=<ExpBackward>)

In [24]:
np.exp(theta_true['log_tau'])

array([0.00630912, 0.01578501, 0.04260968, 0.0430699 , 0.04610766,
       0.0507849 , 0.06659788, 0.07802901, 0.08070339, 0.08388491,
       0.08659892, 0.09785327, 0.09790554, 0.10138304, 0.10237778])

In [25]:
Theta['log_mu'].exp()

tensor([[0.9598, 0.4289, 0.2520, 0.5270, 0.6664, 0.4022, 0.0882, 0.3180, 0.2526,
         0.0664],
        [0.2401, 0.0269, 0.5955, 0.9111, 0.0357, 0.2994, 0.5842, 0.9004, 0.4905,
         0.1744],
        [0.3114, 0.1989, 0.9038, 0.0275, 0.5197, 0.1303, 0.1461, 0.8533, 0.0637,
         0.0738],
        [0.1133, 0.4906, 0.2214, 0.1646, 0.1766, 0.2295, 0.5400, 0.2331, 0.3784,
         0.6832],
        [0.3939, 0.6383, 0.8990, 0.5421, 0.1434, 0.1496, 0.6564, 0.8707, 0.1226,
         0.2588]], dtype=torch.float64, grad_fn=<ExpBackward>)

In [26]:
np.exp(theta_true['log_mu'])

array([[0.00904657, 0.55598951, 0.09080346, 0.09733775, 0.11390822,
        0.23035873, 0.60357522, 0.1050592 , 0.38603035, 0.83302096],
       [0.43674605, 0.76201288, 0.95740535, 0.49211773, 0.08576602,
        0.1148113 , 0.71157188, 0.96032855, 0.0777623 , 0.2623704 ],
       [0.31573514, 0.18051475, 0.95159539, 0.00943698, 0.47352374,
        0.11908228, 0.14722634, 0.87920615, 0.06163368, 0.07974902],
       [0.24253041, 0.02465594, 0.59050483, 0.93315681, 0.0313658 ,
        0.30368694, 0.53299172, 0.91242219, 0.48882501, 0.18550026],
       [0.96780974, 0.44234748, 0.26163233, 0.55654946, 0.70978338,
        0.39712541, 0.08835898, 0.30419818, 0.24552656, 0.06124418]])

In [None]:
loss = []
for i in range(N_ITER):
  
  opt.zero_grad()

  r, log_r = r_net(YS)
  v, log_v = v_net(YS)
  d, log_d = d_net(YS)
  
  ## row sums to 1 (from neural net)
  log_q0 = log_d[:,0].reshape(-1,1) + log_r ## like r in em version
  log_q1 = log_d[:,1].reshape(-1,1) + log_v ## like v in em version

  log_rzd0, log_vgd1 = compute_joint_probs(Theta, Y, S)

  entro = (log_q0.exp() * log_q0).sum() + (log_q1.exp() * log_q1).sum()
  recon = (log_q0.exp() * log_rzd0).sum() + (log_q1.exp() * log_vgd1).sum()
  nelbo = entro - recon

  #nelbo = (log_q0.exp() * (log_q0 - log_rzd0)).sum() + (log_q1.exp() * (log_q1 - log_vgd1)).sum()

  nelbo.backward()
  opt.step()
  
  wandb.log({
    'ITER': i + 1, 
    'nelbo': nelbo.detach(),
    'entropy': entro.detach(),
    'reconstruction_loss': recon.detach(),
    'log_mu': Theta['log_mu'],
    'log_sigma': Theta['log_sigma'], #np.zeros_like(sigma_init),
    'log_psi': Theta['log_psi'],
    'log_omega': Theta['log_omega'],
    "is_delta": Theta['is_delta'],
    'is_pi': Theta['is_pi'],
    'is_tau': Theta['is_tau'],
    'r': r,
    'v': v,
    'd': d,
    })
  
  if i % (1000 - 1) == 0:
    #print("NELBO: {}; lambda: {}; pi: {}".format(nelbo.detach(), F.log_softmax(Theta['is_delta'].detach()).exp(), F.log_softmax(Theta['is_pi'].detach()).exp()))
    #print("NELBO: {}; entro: {}; recon: {}; pi: {}".format(nelbo.detach(), entro.detach(), recon.detach(), F.log_softmax(Theta['is_pi'].detach()).exp()))
    print("NELBO: {}; pi: {}".format(nelbo.detach(), F.log_softmax(Theta['is_pi'].detach()).exp()))
  
  if i > 0 and abs(loss[-1] - nelbo.detach()) < tol:
    break
           
  loss.append(nelbo.detach())

  This is separate from the ipykernel package so we can avoid doing imports until
  after removing the cwd from sys.path.
  """


NELBO: 212923.44381713963; pi: tensor([0.2382, 0.2203, 0.2081, 0.1655, 0.1680], dtype=torch.float64)
NELBO: 7045.969665988164; pi: tensor([0.2507, 0.2076, 0.1962, 0.1645, 0.1810], dtype=torch.float64)
NELBO: 5304.534763378304; pi: tensor([0.2502, 0.2073, 0.1960, 0.1650, 0.1814], dtype=torch.float64)
NELBO: 4817.916598155215; pi: tensor([0.2496, 0.2072, 0.1960, 0.1655, 0.1818], dtype=torch.float64)
NELBO: 4685.253952916848; pi: tensor([0.2491, 0.2070, 0.1959, 0.1659, 0.1821], dtype=torch.float64)
NELBO: 4622.829624605565; pi: tensor([0.2486, 0.2069, 0.1959, 0.1663, 0.1823], dtype=torch.float64)
NELBO: 4575.669052501761; pi: tensor([0.2481, 0.2068, 0.1959, 0.1667, 0.1826], dtype=torch.float64)
NELBO: 4530.657024383784; pi: tensor([0.2476, 0.2067, 0.1959, 0.1671, 0.1828], dtype=torch.float64)
NELBO: 4481.51007124214; pi: tensor([0.2471, 0.2066, 0.1959, 0.1674, 0.1830], dtype=torch.float64)
NELBO: 4432.970539441295; pi: tensor([0.2466, 0.2065, 0.1960, 0.1677, 0.1832], dtype=torch.float64)


In [None]:
array([0.09716075, 0.12602853, 0.15238882, 0.2619916 , 0.36243031])

In [None]:
ddd

In [None]:
abs(loss[-1] - nelbo.detach())

In [None]:
i

In [None]:
loss[-1]

In [None]:
nelbo.detach()

In [None]:
lus = np.triu_indices(nc)

In [None]:
#from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, DataLoader, random_split

part1 = int(0.7 * XX.shape[0])
part2 = int((XX.shape[0] - part1) / 2)
part3 = XX.shape[0] - part1 - part2

train, valid, test = random_split(torch.Tensor(XX), [part1, part2, part3], generator=torch.Generator().manual_seed(42))

trainloader = DataLoader(train, batch_size=256, shuffle=True)
validloader = DataLoader(valid, batch_size=256, shuffle=False)
testloader = DataLoader(test, batch_size=256, shuffle=False)

In [None]:
'''
for epoch in range(N_ITER):
  
  for j, batch_data in enumerate(trainloader):
    
    bY = torch.tensor(batch_data[:,:nf])
    bS = torch.tensor(batch_data[:,nf])
    bYS = torch.hstack((bY,bS.reshape(-1,1))).float()

    #print(bY.shape)
    #print(bS.shape)
    opt.zero_grad()

    r, log_r = r_net(bYS)
    v, log_v = v_net(bYS)
    d, log_d = d_net(bYS)
  
    loss = -ELBO(Theta, bY, bS, r, log_r, v, log_v, d, log_d)
    loss.backward()
    opt.step()
  
  if epoch % (1000 - 1) == 0:
    print(loss.detach())
    #print(F.log_softmax(d.mean(0)).exp())
    #print(F.log_softmax(Theta['is_pi']).exp())
    #print(loss.detach())
    #print(Theta['is_delta'].detach().exp())
    #print(Theta['is_pi'].detach().exp())
'''

In [None]:
F.log_softmax(Theta['is_delta']).exp()

In [None]:
F.log_softmax(Theta['is_pi']).exp()

In [None]:
F.log_softmax(Theta['is_tau'].reshape(-1)).reshape(nc, nc).exp()