<a href="https://colab.research.google.com/github/camlab-bioml/2021_IMC_Jett/blob/main/VI_DAMM_0831.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F

import torch.nn as nn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

from matplotlib.backends.backend_pdf import PdfPages

from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA

import seaborn as sns

%matplotlib inline

In [2]:
#%%capture
#!pip install wandb --upgrade

In [3]:
#import wandb
#wandb.login()

In [4]:
## generate data
def generateData(n_clusters = 3, n_obs = 10000, n_features = 2):

  #n_clusters = 3; n_obs = 100; n_features = 2
  
  ## set truth expression means/covariances (multivariate) ##
  mu = np.random.rand(n_clusters, n_features)
  # mu = np.sort(mu, 0) ## sort expressions
  sigma = 0.001 * np.identity(n_features) ## variance-covariance matrix

  ## set truth cell size means/variances (univariate) ##
  psi = [np.random.normal(100, 25) for i in range(n_clusters)]
  #psi = np.arange(90, 90 + 5 * n_clusters, 5)
  psi = np.sort(psi, 0)
  omega = 1 ## standard deviation
  ###

  ## set latent variables distributions ##
  lambda_arr = np.random.binomial(1, .95, n_obs) # p=.95 (a cell belongs to singlet or doublet) 

  n_singlet = np.sum(lambda_arr == 1) ## number of cells in singlet clusters
  n_doublet = np.sum(lambda_arr == 0) ## number of cells in doublet clusters
  
  lambda0_arr = n_singlet / n_obs ## proportion of cells belong to singlet
  lambda1_arr = n_doublet / n_obs ## proportion of cells belong to doublet

  #pi_arr = np.sort(np.random.sample(n_clusters))
  pi_arr = np.sort(np.random.rand(n_clusters))
  pi_arr /= pi_arr.sum()

  n_doublet_clusters = int((n_clusters * n_clusters - n_clusters)/2 + n_clusters)
  #tau_arr = np.sort(np.random.sample(n_doublet_clusters))
  tau_arr = np.sort(np.random.rand(n_doublet_clusters))
  tau_arr /= tau_arr.sum()

  ## draw cells based on defined parameters theta1 = (mu, sigma, psi, omega) & theta2 = (lambda, pi, tau)
  x = np.zeros((n_singlet, n_features+5))
  for i in range(n_singlet):
    selected_cluster = np.random.choice(n_clusters, size = 1, p = pi_arr)[0] ## select a single cell cluster
    x[i] = np.append(np.random.multivariate_normal(mu[selected_cluster], sigma),
                     [np.random.normal(psi[selected_cluster], omega), 
                      0, selected_cluster, 0, selected_cluster + n_doublet_clusters])
  
  x[x < 0] = 1e-4
  lookups = np.triu_indices(n_clusters) # wanted indices
  xx = np.zeros((n_doublet, n_features+5))
  for i in range(n_doublet):
    selected_cluster = np.random.choice(n_doublet_clusters, p = tau_arr)

    indx1 = lookups[0][selected_cluster]
    indx2 = lookups[1][selected_cluster]

    xx[i] = np.append(np.random.multivariate_normal( (mu[indx1] + mu[indx2])/2, (sigma + sigma)/2 ),
                     [np.random.normal( (psi[indx1] + psi[indx2]), omega+omega ), 
                      1, indx1, indx2, selected_cluster])
  xx[xx < 0] = 1e-4
  xxx = np.append(x, xx).reshape(n_obs, n_features+5)

  truth_theta = {
    'log_mu': np.log(mu),
    'log_sigma': np.log(sigma),
    'log_psi': np.log(psi),
    'log_omega': np.log(omega),
    "log_lambda0": np.log(lambda0_arr),
    'log_pi': np.log(pi_arr),
    'log_tau': np.log(tau_arr)
  }

  return xxx[:,:n_features], xxx[:,n_features], xxx, truth_theta

  #return torch.tensor(xxx[:,:n_features]), torch.tensor(xxx[:,n_features]), torch.tensor(xxx), [mu, sigma, psi, omega], [lambda0_arr, pi_arr, tau_arr]

In [5]:
def compute_p_y_given_z(Y, Theta):
  """ Returns NxC
  p(y_n | z_n = c)
  """
  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma'])

  dist_Y = D.Normal(mu, sigma)
  return dist_Y.log_prob(Y.reshape(Y.shape[0], 1, nf)).sum(2) # <- sum because IID over G

def compute_p_s_given_z(S, Theta):
  """ Returns NxC
  p(s_n | z_n = c)
  """
  psi = torch.exp(Theta['log_psi'])
  omega = torch.exp(Theta['log_omega'])

  dist_S = D.Normal(psi, omega)
  return dist_S.log_prob(S.reshape(-1,1)) 

def compute_p_y_given_gamma(Y, Theta):
  """ NxCxC
  p(y_n | gamma_n = [c,c'])
  """

  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma'])

  mu2 = mu.reshape(1, nc, nf)
  mu2 = (mu2 + mu2.permute(1, 0, 2)) / 2.0 # C x C x G matrix 

  sigma2 = sigma.reshape(1, nc, nf)
  sigma2 = (sigma2 + sigma2.permute(1,0,2)) / 2.0

  dist_Y2 = D.Normal(mu2, sigma2)
  return  dist_Y2.log_prob(Y.reshape(-1, 1, 1, nf)).sum(3) # <- sum because IID over G

def compute_p_s_given_gamma(S, Theta):
  """ NxCxC
  p(s_n | gamma_n = [c,c'])
  """
  psi = torch.exp(Theta['log_psi'])
  omega = torch.exp(Theta['log_omega'])

  psi2 = psi.reshape(-1,1)
  psi2 = psi2 + psi2.T

  omega2 = omega.reshape(-1,1)
  omega2 = omega2 + omega2.T

  dist_S2 = D.Normal(psi2, omega2)
  return dist_S2.log_prob(S.reshape(-1, 1, 1))

In [6]:
class BasicForwardNet(nn.Module):
  """Encoder for when data is input without any encoding"""
  def __init__(self, input_dim, output_dim, hidden_dim = 5):
    super().__init__()
    
    self.input = nn.Linear(input_dim, hidden_dim)
    self.linear1 = nn.Linear(hidden_dim, hidden_dim)
    self.output = nn.Linear(hidden_dim, output_dim)
    
  def forward(self, x):
    out = F.relu(self.input(x))
    out = F.relu(self.linear1(out))
    out = self.output(out)
        
    return F.softmax(out, dim=1), F.log_softmax(out, dim=1) ## r/v/d log_r/log_v/log_d

In [7]:
def ELBO(Theta, Y, S, q1, log_q1, q2, log_q2):

  log_pi = F.log_softmax(Theta['is_pi'])
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1)).reshape(nc,nc)
  log_delta = F.log_softmax(Theta['is_delta'])
  
  p_y_given_z = compute_p_y_given_z(Y, Theta)
  p_s_given_z = compute_p_s_given_z(S, Theta)

  log_rzd0 = p_s_given_z + p_y_given_z + log_pi + log_delta[0]

  p_y_given_gamma = compute_p_y_given_gamma(Y, Theta)
  p_s_given_gamma = compute_p_s_given_gamma(S, Theta)

  log_vgd1 = p_y_given_gamma + p_s_given_gamma + log_tau + log_delta[1]

  #remove_indices = np.tril_indices(nc, -1) ## remove indices
  #log_rd1g[:, remove_indices[0], remove_indices[1]] = float("NaN")

  #q1 = r.exp() * log_rd0z #; q1[torch.isnan(q1)] = 0.0
  #q2 = v.exp() * log_rd1g #; q2[torch.isnan(q2)] = 0.0

  #q1 = r * (log_rd0z - log_r) ## r * log_r / r * log_rd0z
  #q2 = v.reshape(Y.shape[0], nc, nc) * (log_rd1g - log_v.reshape(Y.shape[0], nc, nc))

  #q1 = (d[:,0].reshape(-1,1) * r) * (log_rd0z - log_r) ## r * log_r / r * log_rd0z
  #q2 = (d[:,1].reshape(-1,1) * v).reshape(Y.shape[0], nc, nc) * (log_rd1g - log_v.reshape(Y.shape[0], nc, nc))

  elbo1 = q1 * (log_rzd0 - log_q1)
  elbo2 = q2 * (log_vgd1 - log_q2)

  return elbo1.sum() + elbo2.sum()

In [8]:
#F.sigmoid(torch.tensor([.3, .6]))

In [9]:
nc = 5; no = 1000; nf = 2

Y, S, XX, theta_true = generateData(n_clusters = nc, n_obs = no, n_features = nf)



In [10]:
N_INIT = 20

Y = np.array(Y)
S = np.array(S)

kms = [KMeans(nc).fit(Y) for i in range(N_INIT)]
inertias = [k.inertia_ for k in kms]
km = kms[np.argmin(np.array(inertias))] ## selected "best" kmeans based on inertia score
init_labels = km.labels_

mu_init = np.array([Y[init_labels == i,:].mean(0) for i in np.unique(init_labels)])
sigma_init = np.array([Y[init_labels == i,:].std(0) for i in np.unique(init_labels)])
psi_init = np.array([S[init_labels == i].mean() for i in np.unique(init_labels)])
omega_init = np.array([S[init_labels == i].std() for i in np.unique(init_labels)])
pi_init = np.array([np.mean(init_labels == i) for i in np.unique(init_labels)])
tau_init = np.ones((nc,nc))
tau_init = tau_init / tau_init.sum()

In [11]:
P = Y.shape[1] + 1
r_net = BasicForwardNet(P, nc)
v_net = BasicForwardNet(P, nc ** 2) 
d_net = BasicForwardNet(P, 2)

In [12]:
Theta = {
    'log_mu': np.log(mu_init) + 0.05 * np.random.randn(mu_init.shape[0], mu_init.shape[1]),
    'log_sigma': np.log(sigma_init), #np.zeros_like(sigma_init),
    'log_psi': np.log(psi_init),
    'log_omega': np.log(omega_init),
    "is_delta": F.log_softmax(torch.tensor([0.95, 1-0.95])),
    'is_pi': F.log_softmax(torch.tensor(pi_init)),
    'is_tau': F.log_softmax(torch.tensor(tau_init))
}
Theta = {k: torch.tensor(v, requires_grad=True) for (k,v) in Theta.items()}

#Theta['is_delta'].requires_grad = False
#Theta['is_pi'].requires_grad = False
#Theta['is_tau'].requires_grad = False

  
  import sys
  
  # Remove the CWD from sys.path while we load stuff.


In [13]:
Y = torch.tensor(Y)
S = torch.tensor(S)
YS = torch.hstack((Y,S.reshape(-1,1))).float()

In [14]:
N_ITER = 50000
lr = 1e-4
tol = 1e-3
params = list(Theta.values()) + list(r_net.parameters()) + list(v_net.parameters()) + list(d_net.parameters())
opt = optim.AdamW(params, lr=lr)

In [15]:
%%capture
!pip install wandb --upgrade

import wandb
wandb.login()

wandb: Paste an API key from your profile and hit enter: ··········


In [16]:
wandb.init(project='jett-vi',
           config={
    "N_EPOCHS": N_ITER,
    "LR": lr,
    "TOL": tol,
    'MODEL_TYPE': 'vi',
    'DATA_TYPE': 'toy_data'
    })

[34m[1mwandb[0m: Currently logged in as: [33myujulee[0m (use `wandb login --relogin` to force relogin)


In [None]:
#loss = []
for i in range(N_ITER):
  
  opt.zero_grad()
  r, log_r = r_net(YS)
  v, log_v = v_net(YS)
  d, log_d = d_net(YS)
  
  q1 = d[:,0].reshape(-1,1) * r
  log_q1 = log_d[:,0].reshape(-1,1) + log_r

  q2 = (d[:,1].reshape(-1,1) * v).reshape(Y.shape[0], nc, nc)
  log_q2 = (log_d[:,1].reshape(-1,1) + log_v).reshape(Y.shape[0], nc, nc)

  #assert( ((d_net(YS)[0][:,1].reshape(-1,1) * v_net(YS)[0]).sum(1) + (d_net(YS)[0][:,0].reshape(-1,1) * r_net(YS)[0]).sum(1)).sum() == 1000. )

  nelbo = -ELBO(Theta, Y, S, q1, log_q1, q2, log_q2)
  nelbo.backward()
  opt.step()
  
  if i % (1000 - 1) == 0:
    print("NELBO: {}; lambda: {}".format(nelbo.detach(), F.log_softmax(Theta['is_delta'].detach()).exp()))
  
  if i > 0 and abs(loss[-1] + nelbo.detach().sum()) < tol:
    break
           
  loss.append(-nelbo.detach().sum())
  
  '''
  wandb.log({
    'ITER': i + 1, 
    'elbo': elbo.detach(),
    'log_mu': Theta['log_mu'],
    'log_sigma': Theta['log_sigma'], #np.zeros_like(sigma_init),
    'log_psi': Theta['log_psi'],
    'log_omega': Theta['log_omega'],
    "is_delta": Theta['is_delta'],
    'is_pi': Theta['is_pi'],
    'is_tau': Theta['is_tau'],
    'r': r,
    'v': v,
    'd': d,
      })
  '''

  This is separate from the ipykernel package so we can avoid doing imports until
  after removing the cwd from sys.path.
  """


NELBO: 4711.335402218515; lambda: tensor([0.7555, 0.2445])
NELBO: 4675.178250545157; lambda: tensor([0.7590, 0.2410])
NELBO: 4639.460127334241; lambda: tensor([0.7625, 0.2375])
NELBO: 4604.096164185553; lambda: tensor([0.7659, 0.2341])
NELBO: 4568.95755924212; lambda: tensor([0.7693, 0.2307])
NELBO: 4533.798675935918; lambda: tensor([0.7726, 0.2274])
NELBO: 4498.364920571923; lambda: tensor([0.7759, 0.2241])
NELBO: 4462.518500089011; lambda: tensor([0.7792, 0.2208])
NELBO: 4426.053426691622; lambda: tensor([0.7824, 0.2176])
NELBO: 4388.692741426011; lambda: tensor([0.7856, 0.2144])
NELBO: 4350.275187313555; lambda: tensor([0.7886, 0.2114])
NELBO: 4311.027033592249; lambda: tensor([0.7916, 0.2084])
NELBO: 4271.816261392525; lambda: tensor([0.7945, 0.2055])
NELBO: 4233.2705610058465; lambda: tensor([0.7974, 0.2026])


In [18]:
ddd

NameError: ignored

In [None]:
F.log_softmax(Theta['is_delta']).exp()

In [None]:
np.exp(theta_true['log_lambda0'])

In [None]:
F.log_softmax(Theta['is_pi']).exp()

In [None]:
np.exp(theta_true['log_pi'])

In [None]:
#F.log_softmax(v.mean(0).reshape(-1)).reshape(nc, nc).exp()

In [None]:
F.log_softmax(Theta['is_tau'].reshape(-1)).reshape(nc, nc).exp()

In [None]:
np.exp(theta_true['log_tau'])

In [None]:
Theta['log_mu'].exp()

In [None]:
np.exp(theta_true['log_mu'])

In [None]:
lus = np.triu_indices(nc)

In [None]:
#from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, DataLoader, random_split

part1 = int(0.7 * XX.shape[0])
part2 = int((XX.shape[0] - part1) / 2)
part3 = XX.shape[0] - part1 - part2

train, valid, test = random_split(torch.Tensor(XX), [part1, part2, part3], generator=torch.Generator().manual_seed(42))

trainloader = DataLoader(train, batch_size=256, shuffle=True)
validloader = DataLoader(valid, batch_size=256, shuffle=False)
testloader = DataLoader(test, batch_size=256, shuffle=False)

In [None]:
'''
for epoch in range(N_ITER):
  
  for j, batch_data in enumerate(trainloader):
    
    bY = torch.tensor(batch_data[:,:nf])
    bS = torch.tensor(batch_data[:,nf])
    bYS = torch.hstack((bY,bS.reshape(-1,1))).float()

    #print(bY.shape)
    #print(bS.shape)
    opt.zero_grad()

    r, log_r = r_net(bYS)
    v, log_v = v_net(bYS)
    d, log_d = d_net(bYS)
  
    loss = -ELBO(Theta, bY, bS, r, log_r, v, log_v, d, log_d)
    loss.backward()
    opt.step()
  
  if epoch % (1000 - 1) == 0:
    print(loss.detach())
    #print(F.log_softmax(d.mean(0)).exp())
    #print(F.log_softmax(Theta['is_pi']).exp())
    #print(loss.detach())
    #print(Theta['is_delta'].detach().exp())
    #print(Theta['is_pi'].detach().exp())
'''

In [None]:
F.log_softmax(Theta['is_delta']).exp()

In [None]:
F.log_softmax(Theta['is_pi']).exp()

In [None]:
F.log_softmax(Theta['is_tau'].reshape(-1)).reshape(nc, nc).exp()