<a href="https://colab.research.google.com/github/camlab-bioml/2021_IMC_Jett/blob/main/damm_011122.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F

from sklearn.cluster import KMeans

from scipy.stats.mstats import winsorize

In [None]:
def compute_p_y_given_z(Y, Theta, dist, spillover_rate=None):
  '''Return the singlet log-likelihood log p(y_n | z_n = c) as an N x C tensor.

  Features are treated as independent, so per-feature log-densities are
  summed over the last axis.

  Args:
    Y: N x G expression tensor.
    Theta: parameter dict. Reads 'log_mu' and 'log_sigma' (expected C x G,
      as produced by kmeans_init). When ``spillover_rate`` is given it also
      reads 'is_so' and writes the adjusted mean to Theta['mnu'].
    dist: 'normal' or 'student' (Student-t with 2 degrees of freedom).
    spillover_rate: optional; when not None the component means are mixed
      with the module-level SPILLOVER_MAT.

  Returns:
    N x C tensor of log-likelihoods.

  Raises:
    ValueError: if ``dist`` is not 'normal' or 'student'.
  '''
  if dist not in ('normal', 'student'):
    # fail fast; otherwise dist_Y below would be unbound
    raise ValueError(f"dist must be 'normal' or 'student', got {dist!r}")

  if spillover_rate is None:
    mu = torch.exp(Theta['log_mu'])
  else:
    # NOTE(review): relies on the global SPILLOVER_MAT, and the adjusted means
    # are averaged over dim 0 before use -- confirm this is intended.
    l = torch.sigmoid(Theta['is_so']) * spillover_rate
    first_half = (1 - l).reshape(-1, 1, 1) * torch.exp(Theta['log_mu'])
    second_half = (l.reshape(-1, 1) * SPILLOVER_MAT).reshape(-1, 1, Y.shape[1])
    mu = (first_half + second_half).mean(0)
    Theta['mnu'] = mu

  sigma = torch.exp(Theta['log_sigma'])

  if dist == 'normal':
    dist_Y = D.Normal(loc=mu, scale=sigma)
  else:
    dist_Y = D.StudentT(df=2.0, loc=mu, scale=sigma)

  # N x 1 x G broadcast against the component parameters -> N x C x G,
  # then sum because the features are IID given the component.
  Y_reshape = Y.reshape(-1, 1, Y.shape[1])
  return dist_Y.log_prob(Y_reshape).sum(2)

def compute_p_y_given_gamma(Y, Theta, dist):
  """Return the doublet log-likelihood log p(y_n | gamma_n = [c, c']) as an N x C x C tensor.

  A doublet of components c and c' uses the element-wise average of their
  means and of their scales. Features are treated as independent and their
  log-densities summed.

  Args:
    Y: N x G expression tensor.
    Theta: parameter dict; reads 'log_mu' and 'log_sigma' (expected C x G).
    dist: 'normal' or 'student' (Student-t with 2 degrees of freedom).

  Returns:
    N x C x C tensor of log-likelihoods.

  Raises:
    ValueError: if ``dist`` is not 'normal' or 'student'.
  """
  if dist not in ('normal', 'student'):
    # fail fast; otherwise dist_Y2 below would be unbound
    raise ValueError(f"dist must be 'normal' or 'student', got {dist!r}")

  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma'])

  # pairwise averages: entry [c, c'] = (x_c + x_c') / 2  -> C x C x G
  mu2 = (mu.unsqueeze(0) + mu.unsqueeze(1)) / 2.0
  sigma2 = (sigma.unsqueeze(0) + sigma.unsqueeze(1)) / 2.0

  if dist == 'normal':
    dist_Y2 = D.Normal(loc=mu2, scale=sigma2)
  else:
    dist_Y2 = D.StudentT(df=2.0, loc=mu2, scale=sigma2)

  # N x 1 x 1 x G broadcast -> N x C x C x G, summed because features are IID
  Y_reshape = Y.reshape(-1, 1, 1, Y.shape[1])
  return dist_Y2.log_prob(Y_reshape).sum(3)

In [None]:
def compute_p_s_given_z(S, Theta, dist):
  '''Return the singlet cell-size log-likelihood log p(s_n | z_n = c) as an N x C tensor.

  Args:
    S: length-N tensor of cell sizes.
    Theta: parameter dict; reads 'log_psi' (component size locations) and
      'log_omega' (component size scales), one entry per component.
    dist: 'normal' or 'student' (Student-t with 2 degrees of freedom).

  Raises:
    ValueError: if ``dist`` is not 'normal' or 'student'.
  '''
  if dist not in ('normal', 'student'):
    # fail fast; otherwise dist_S below would be unbound
    raise ValueError(f"dist must be 'normal' or 'student', got {dist!r}")

  psi = torch.exp(Theta['log_psi'])
  omega = torch.exp(Theta['log_omega'])

  if dist == 'normal':
    dist_S = D.Normal(loc=psi, scale=omega)
  else:
    dist_S = D.StudentT(df=2.0, loc=psi, scale=omega)

  # N x 1 column broadcast against the per-component parameters -> N x C
  S_reshape = S.reshape(-1, 1)
  return dist_S.log_prob(S_reshape)

def compute_p_s_given_gamma(S, Theta, dist):
  """Return the doublet cell-size log-likelihood log p(s_n | gamma_n = [c, c']) as an N x C x C tensor.

  The doublet [c, c'] has location psi_c + psi_c' and scale omega_c + omega_c'.

  Args:
    S: length-N tensor of cell sizes.
    Theta: parameter dict; reads 'log_psi' and 'log_omega', one entry per component.
    dist: 'normal' or 'student' (Student-t with 2 degrees of freedom).

  Raises:
    ValueError: if ``dist`` is not 'normal' or 'student'.
  """
  if dist not in ('normal', 'student'):
    # fail fast; otherwise dist_S2 below would be unbound
    raise ValueError(f"dist must be 'normal' or 'student', got {dist!r}")

  # row vectors (1 x C) added to their transposes give C x C pairwise sums
  psi_row = torch.exp(Theta['log_psi']).reshape(1, -1)
  psi2 = psi_row + psi_row.permute(1, 0)

  # NOTE(review): for a sum of two independent normals the scale would be
  # sqrt(omega_c**2 + omega_c'**2); scales are added linearly here -- confirm intent.
  omega_row = torch.exp(Theta['log_omega']).reshape(1, -1)
  omega2 = omega_row + omega_row.permute(1, 0)

  if dist == 'normal':
    dist_S2 = D.Normal(loc=psi2, scale=omega2)
  else:
    dist_S2 = D.StudentT(df=2.0, loc=psi2, scale=omega2)

  S_reshape = S.reshape(-1, 1, 1)
  return dist_S2.log_prob(S_reshape)

In [None]:
def kmeans_init(Y, S, k):
  '''Initialise the model parameters from a k-means clustering of Y.

  Args:
    Y: N x G expression array (fed to sklearn KMeans).
    S: optional length-N array of cell sizes, or None.
    k: number of clusters / mixture components.

  Returns:
    dict of leaf tensors keyed by parameter name. Everything requires
    gradients except 'is_delta' and 'mnu'. 'log_psi'/'log_omega' are only
    present when S is given.
  '''
  kms = KMeans(k).fit(Y)
  init_labels = kms.labels_
  init_label_class = np.unique(init_labels)

  # per-cluster moments of the expression data (one row per cluster)
  # NOTE(review): a cluster with zero mean or std in a feature yields -inf after np.log.
  mu_init = np.array([Y[init_labels == c, :].mean(0) for c in init_label_class])
  sigma_init = np.array([Y[init_labels == c, :].std(0) for c in init_label_class])

  pi_init = np.array([np.mean(init_labels == c) for c in init_label_class])
  # uniform initial weights over doublet pairs (c, c')
  tau_init = np.ones((k, k))
  tau_init = tau_init / tau_init.sum()

  Theta = {
    'log_mu': np.log(mu_init),
    'log_sigma': np.log(sigma_init),
    'is_delta': np.log([0.95, 0.05]),  # initial singlet / doublet weights
    'is_pi': np.log(pi_init),
    'is_tau': np.log(tau_init),
    'is_so': torch.randn(Y.shape[0]),
    'mnu': torch.zeros(Y.shape[0], k, Y.shape[1])
  }

  if S is not None:
    psi_init = np.array([S[init_labels == c].mean() for c in init_label_class])
    omega_init = np.array([S[init_labels == c].std() for c in init_label_class])

    # no trailing comma: it previously wrapped this value in a 1-tuple,
    # producing a 1 x C tensor instead of a length-C one
    Theta['log_psi'] = np.log(psi_init)
    Theta['log_omega'] = np.log(omega_init)

  # Copy every value into a fresh leaf tensor. clone().detach() avoids the
  # torch.tensor(<tensor>) copy-construct warning for 'is_so' and 'mnu'.
  Theta = {
    name: torch.as_tensor(value).clone().detach().requires_grad_(True)
    for name, value in Theta.items()
  }
  Theta['is_delta'].requires_grad = False
  Theta['mnu'].requires_grad = False

  return Theta

def simulate_data(Y, S):
  '''Build a labelled synthetic singlet/doublet set by resampling the real cells.

  Draws floor(N/2) singlets (rows of Y sampled with replacement) and
  floor(N/2) doublets (the mean of two randomly drawn rows of Y). When cell
  sizes are given, they are resampled with the same indices, and a doublet's
  size is the sum of its two parents' sizes.

  Returns:
    (fake_Y, fake_S, fake_label): fake_S is None when S is None; fake_label
    is 1.0 for the singlet half and 0.0 for the doublet half.
  '''
  n_cells = Y.shape[0]
  half = n_cells // 2

  # the three draws happen in a fixed order: singlets, then both doublet parents
  single_idx = np.random.choice(n_cells, size=half, replace=True)
  parent_a = np.random.choice(n_cells, size=half)
  parent_b = np.random.choice(n_cells, size=half)

  singlets = Y[single_idx, :]
  doublets = (Y[parent_a, :] + Y[parent_b, :]) / 2.
  fake_Y = torch.tensor(np.vstack([singlets, doublets]))
  fake_label = torch.tensor(np.concatenate([np.ones(half), np.zeros(half)]))

  if S is None:
    return fake_Y, None, fake_label

  fake_S = torch.tensor(np.hstack([S[single_idx], S[parent_a] + S[parent_b]]))
  return fake_Y, fake_S, fake_label

class ConcatDataset(torch.utils.data.Dataset):
  '''Zip several indexable datasets: item i is the tuple of their i-th items.

  The length is that of the shortest member dataset.
  '''

  def __init__(self, *datasets):
    self.datasets = datasets

  def __len__(self):
    return min(map(len, self.datasets))

  def __getitem__(self, i):
    return tuple([member[i] for member in self.datasets])

def compute_nll_posteriors_p_singlet(Y, S, Theta, dist, spillover_rate=None):
  '''Score a batch under the singlet/doublet mixture.

  Args:
    Y: N x G expression tensor.
    S: length-N cell-size tensor, or None to ignore cell size.
    Theta: parameter dict (see kmeans_init).
    dist: 'normal' or 'student'.
    spillover_rate: forwarded to compute_p_y_given_z.

  Returns:
    cost: mean over the batch of -log p(data_n).
    prob_singlet: length-N tensor, p(d_n = 0 | data_n).
    r: N x C tensor, log p(d_n = 0, z_n = c | data_n).
    v: N x C x C tensor, log p(d_n = 1, gamma_n = [c, c'] | data_n).
  '''
  # explicit dim=0: implicit-dim log_softmax is deprecated (and is 0 for 1-D input)
  log_pi = F.log_softmax(Theta['is_pi'], dim=0)
  n_comp = log_pi.shape[0]
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1), dim=0).reshape(n_comp, n_comp)
  log_delta = F.log_softmax(Theta['is_delta'], dim=0)

  # singlet branch: log p(data_n, z_n = c | d_n = 0) -> N x C
  log_joint_z = compute_p_y_given_z(Y, Theta, dist, spillover_rate) + log_pi
  if S is not None:
    log_joint_z = log_joint_z + compute_p_s_given_z(S, Theta, dist)

  # doublet branch: log p(data_n, gamma_n = [c, c'] | d_n = 1) -> N x C x C
  log_joint_gamma = compute_p_y_given_gamma(Y, Theta, dist) + log_tau
  if S is not None:
    log_joint_gamma = log_joint_gamma + compute_p_s_given_gamma(S, Theta, dist)

  n_obs = Y.shape[0]
  prob_data = torch.hstack([
    log_joint_z + log_delta[0],
    log_joint_gamma.reshape(n_obs, -1) + log_delta[1],
  ])
  prob_data_norm = torch.logsumexp(prob_data, dim=1)  # log p(data_n), length N

  # posteriors: normalise per observation by broadcasting over the leading
  # axis (avoids Tensor.T on a 3-D tensor, which is deprecated)
  r = log_joint_z + log_delta[0] - prob_data_norm.reshape(-1, 1)
  v = log_joint_gamma + log_delta[1] - prob_data_norm.reshape(-1, 1, 1)

  prob_data_given_d0 = torch.logsumexp(log_joint_z, dim=1)  # log p(data_n | d=0)
  prob_singlet = torch.exp(prob_data_given_d0 + log_delta[0] - prob_data_norm)

  # average negative log-likelihood; reuses the normaliser computed above
  cost = -prob_data_norm.mean()

  return cost, prob_singlet, r, v

In [None]:
def prepLoader(Y, S, incSimDat, incSimCellSize, BatchSize):
  '''Wrap the real (and optionally simulated) data in a shuffled DataLoader.

  Each batch is a tuple laid out as
    (Y, [S], [fake_Y, [fake_S], fake_L])
  where S appears only when given, the simulated members only when
  incSimDat is set, and fake_S only when simulated cell sizes were produced.
  '''
  columns = [Y]
  if S is not None:
    columns.append(S)

  if incSimDat:
    # cell sizes are only simulated when requested (and available)
    fake_Y, fake_S, fake_L = simulate_data(Y, S if incSimCellSize else None)
    columns.append(fake_Y)
    if fake_S is not None:
      columns.append(fake_S)
    columns.append(fake_L)

  return torch.utils.data.DataLoader(ConcatDataset(*columns), batch_size=BatchSize, shuffle=True)

def batchSet(batch, incSimDat):
  '''Unpack a DataLoader batch built by prepLoader into a fixed 5-tuple.

  Returns:
    (bY, bS, bFY, bFS, bFL): real expressions, real sizes, simulated
    expressions, simulated sizes, simulated singlet labels. Members that are
    absent from the batch are returned as None.
  '''
  # Index the `batch` argument itself; the previous body read the
  # notebook-global `bat` and only worked inside the training loop.
  if incSimDat:
    if len(batch) == 3:
      return batch[0], None, batch[1], None, batch[2]
    if len(batch) == 4:
      return batch[0], batch[1], batch[2], None, batch[3]
    return batch[0], batch[1], batch[2], batch[3], batch[4]

  if len(batch) == 1:
    return batch[0], None, None, None, None
  return batch[0], batch[1], None, None, None

def batchTrain(batch, Theta, dist, spillover_rate=None, incSimDat=False, regularized_val=1000, optimizer=None):
  '''Run one optimisation step on a single DataLoader batch.

  Args:
    batch: tuple from a prepLoader DataLoader.
    Theta: parameter dict being optimised.
    dist: 'normal' or 'student'.
    spillover_rate: forwarded for the real data only.
    incSimDat: whether the batch carries simulated singlets/doublets; if so
      a BCE loss on their predicted singlet probability is added.
    regularized_val: weight of that BCE term.
    optimizer: torch optimizer to step. Defaults to the notebook-global
      `opt` so existing call sites keep working.

  Returns:
    (total_loss, real_nll, fake_nll, fake_bce) as detached tensors, or
    (real_nll, 0, 0, 0) when incSimDat is False.
  '''
  if optimizer is None:
    optimizer = opt  # backward-compatible fallback to the global optimizer

  bY, bS, bFY, bFS, bFL = batchSet(batch, incSimDat)

  optimizer.zero_grad()
  rnll, _, _, _ = compute_nll_posteriors_p_singlet(bY, bS, Theta, dist, spillover_rate)

  if not incSimDat:
    rnll.backward()
    optimizer.step()
    return rnll.detach(), 0, 0, 0

  # simulated data is scored without the spillover correction
  fnll, p_fake_singlet, _, _ = compute_nll_posteriors_p_singlet(bFY, bFS, Theta, dist, None)
  # BCELoss needs target and input of the same dtype; simulate_data emits
  # float64 labels, which would fail against float32 model output.
  floss = nn.BCELoss()(p_fake_singlet, bFL.to(p_fake_singlet.dtype))
  closs = rnll + regularized_val * floss
  closs.backward()
  optimizer.step()
  return closs.detach(), rnll.detach(), fnll.detach(), floss.detach()

In [None]:
import os

from google.colab import drive

# Mount Google Drive and move into the notebook folder so that the relative
# DAMM/data/... paths used when loading data resolve.
drive.mount('/content/gdrive')
os.chdir('/content/gdrive/MyDrive/Colab Notebooks/')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import scanpy as sc  # install with `pip install scanpy` if missing

# Basel-cohort IMC single-cell expression (AnnData). The same folder also
# holds mouse_5k.h5ad as an alternative input.
imc_data = sc.read_h5ad("DAMM/data/basel_cohort_single_cell_expression.h5ad")

# Spillover matrix, consumed by compute_p_y_given_z when spillover_rate is set.
# TODO: decide whether this should be winsorized like the expression data.
SPILLOVER_MAT = torch.load("DAMM/data/basel_mask_spillover")

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


In [None]:
no, nf = imc_data.shape  # number of cells (observations) and proteins (features)

# Cell sizes: clip only the upper 1% tail.
S = winsorize(imc_data.obs['size'], limits=[0, 0.01]).data

# Expression matrix: apply the same one-sided 1% clipping to each protein
# column independently, on a copy of the AnnData matrix.
X = imc_data.X.copy()
for i in range(nf):
  X[:, i] = winsorize(X[:, i], limits=[0, 0.01]).data

In [None]:
# --- model -------------------------------------------------------------
NC = 5                   # number of mixture components (k-means clusters)
noiseModel = 'student'   # 'normal' or 'student'
spillover_rate = None    # None disables the spillover correction

# --- optimisation ------------------------------------------------------
MAX_EPOCH = 50
BATCH_SIZE = 128
TOL = 1e-3               # convergence criterion

# --- simulated singlet/doublet regularisation --------------------------
incSimDat = True         # add simulated singlets/doublets to every batch
incSimCellSize = True    # also simulate cell sizes for them
# Weight of the simulated-data BCE term.
# NOTE(review): an older comment said "or None", but batchTrain multiplies
# the BCE loss by this value, so None would raise a TypeError.
regularized_val = 1000

In [None]:
# k-means initialisation works on the NumPy arrays; only afterwards are the
# inputs converted to tensors for training.
Theta = kmeans_init(X, S, NC)
Y, S = torch.tensor(X), torch.tensor(S)



In [None]:
opt = optim.Adam(Theta.values())
train_loader = prepLoader(Y, S, incSimDat, incSimCellSize, BATCH_SIZE)

# Per-epoch total loss (summed over batches), stored as Python floats.
loss = []
for epoch in range(MAX_EPOCH):

  with tqdm(train_loader, unit="batch") as tepoch:
    tepoch.set_description(f"Epoch {epoch}")

    rnlls = 0.0; fnlls = 0.0; floss = 0.0; tloss = 0.0
    for bat in tepoch:
      tot_loss, real_nll, fake_nll, fake_loss = batchTrain(bat, Theta, noiseModel, spillover_rate, incSimDat, regularized_val)

      # float() handles both the detached 0-d tensors and the literal 0s
      # that batchTrain returns
      tloss += float(tot_loss)
      rnlls += float(real_nll)
      fnlls += float(fake_nll)
      floss += float(fake_loss)

      tepoch.set_postfix(loss=tloss)

  loss.append(tloss)
  print('Epoch: {}: Loss: {}'.format(epoch, loss[-1]))

  # Stop once the 5-epoch moving average of the loss changes by less than
  # TOL *relative* to its magnitude. The loss is a sum over every batch
  # (order 1e6 in the logged run), so an absolute threshold of TOL never fires.
  if epoch > 10:
    recent = np.mean(loss[-5:])
    previous = np.mean(loss[-6:-1])
    if abs(recent - previous) < TOL * abs(recent):
      break

Epoch 0: 100%|██████████| 6271/6271 [02:33<00:00, 40.74batch/s, loss=3.3e+6]


Epoch: 0: Loss: 3298017.100899173


Epoch 1: 100%|██████████| 6271/6271 [02:33<00:00, 40.74batch/s, loss=2.63e+6]


Epoch: 1: Loss: 2629183.510330238


Epoch 2: 100%|██████████| 6271/6271 [02:33<00:00, 40.77batch/s, loss=2.57e+6]


Epoch: 2: Loss: 2573556.566849665


Epoch 3: 100%|██████████| 6271/6271 [02:34<00:00, 40.68batch/s, loss=2.54e+6]


Epoch: 3: Loss: 2536742.445507876


Epoch 4: 100%|██████████| 6271/6271 [02:34<00:00, 40.68batch/s, loss=2.51e+6]


Epoch: 4: Loss: 2511660.669029004


Epoch 5: 100%|██████████| 6271/6271 [02:34<00:00, 40.51batch/s, loss=2.49e+6]


Epoch: 5: Loss: 2493489.019274751


Epoch 6: 100%|██████████| 6271/6271 [02:36<00:00, 40.01batch/s, loss=2.48e+6]


Epoch: 6: Loss: 2478865.200222315


Epoch 7: 100%|██████████| 6271/6271 [02:39<00:00, 39.25batch/s, loss=2.47e+6]


Epoch: 7: Loss: 2466543.5282160193


Epoch 8: 100%|██████████| 6271/6271 [02:40<00:00, 39.10batch/s, loss=2.46e+6]


Epoch: 8: Loss: 2455889.392090912


Epoch 9: 100%|██████████| 6271/6271 [02:39<00:00, 39.21batch/s, loss=2.45e+6]


Epoch: 9: Loss: 2446730.4346180256


Epoch 10: 100%|██████████| 6271/6271 [02:40<00:00, 39.16batch/s, loss=2.44e+6]


Epoch: 10: Loss: 2438111.218512453


Epoch 11: 100%|██████████| 6271/6271 [02:39<00:00, 39.21batch/s, loss=2.43e+6]


Epoch: 11: Loss: 2430531.9071918493


Epoch 12: 100%|██████████| 6271/6271 [02:39<00:00, 39.25batch/s, loss=2.42e+6]


Epoch: 12: Loss: 2423731.063210821


Epoch 13: 100%|██████████| 6271/6271 [02:37<00:00, 39.72batch/s, loss=2.42e+6]


Epoch: 13: Loss: 2417694.50114292


Epoch 14: 100%|██████████| 6271/6271 [02:38<00:00, 39.55batch/s, loss=2.41e+6]


Epoch: 14: Loss: 2412651.173956381


Epoch 15: 100%|██████████| 6271/6271 [02:37<00:00, 39.84batch/s, loss=2.41e+6]


Epoch: 15: Loss: 2408465.874113013


Epoch 16: 100%|██████████| 6271/6271 [02:37<00:00, 39.90batch/s, loss=2.41e+6]


Epoch: 16: Loss: 2405046.8494217135


Epoch 17: 100%|██████████| 6271/6271 [02:41<00:00, 38.90batch/s, loss=2.4e+6]


Epoch: 17: Loss: 2401316.35489583


Epoch 18: 100%|██████████| 6271/6271 [02:40<00:00, 39.04batch/s, loss=2.4e+6]


Epoch: 18: Loss: 2398817.323984304


Epoch 19: 100%|██████████| 6271/6271 [02:44<00:00, 38.19batch/s, loss=2.4e+6]


Epoch: 19: Loss: 2395927.1943968036


Epoch 20: 100%|██████████| 6271/6271 [02:42<00:00, 38.55batch/s, loss=2.39e+6]


Epoch: 20: Loss: 2394189.3150067767


Epoch 21: 100%|██████████| 6271/6271 [02:43<00:00, 38.27batch/s, loss=2.39e+6]


Epoch: 21: Loss: 2391782.726567285


Epoch 22: 100%|██████████| 6271/6271 [02:43<00:00, 38.41batch/s, loss=2.39e+6]


Epoch: 22: Loss: 2390278.9470099215


Epoch 23: 100%|██████████| 6271/6271 [02:40<00:00, 39.02batch/s, loss=2.39e+6]


Epoch: 23: Loss: 2388578.8661501883


Epoch 24: 100%|██████████| 6271/6271 [02:39<00:00, 39.23batch/s, loss=2.39e+6]


Epoch: 24: Loss: 2386892.7617484326


Epoch 25: 100%|██████████| 6271/6271 [02:41<00:00, 38.88batch/s, loss=2.39e+6]


Epoch: 25: Loss: 2385855.4580163746


Epoch 26: 100%|██████████| 6271/6271 [02:42<00:00, 38.67batch/s, loss=2.38e+6]


Epoch: 26: Loss: 2384517.7738934355


Epoch 27: 100%|██████████| 6271/6271 [02:41<00:00, 38.72batch/s, loss=2.38e+6]


Epoch: 27: Loss: 2383774.015577328


Epoch 28: 100%|██████████| 6271/6271 [02:41<00:00, 38.91batch/s, loss=2.38e+6]


Epoch: 28: Loss: 2382391.3270837753


Epoch 29: 100%|██████████| 6271/6271 [02:44<00:00, 38.06batch/s, loss=2.38e+6]


Epoch: 29: Loss: 2382263.007541943


Epoch 30: 100%|██████████| 6271/6271 [02:42<00:00, 38.68batch/s, loss=2.38e+6]


Epoch: 30: Loss: 2381219.5095238658


Epoch 31: 100%|██████████| 6271/6271 [02:43<00:00, 38.44batch/s, loss=2.38e+6]


Epoch: 31: Loss: 2380738.1850750176


Epoch 32: 100%|██████████| 6271/6271 [02:43<00:00, 38.38batch/s, loss=2.38e+6]


Epoch: 32: Loss: 2380322.318315875


Epoch 33: 100%|██████████| 6271/6271 [02:43<00:00, 38.28batch/s, loss=2.38e+6]


Epoch: 33: Loss: 2379787.3650349546


Epoch 34: 100%|██████████| 6271/6271 [02:43<00:00, 38.40batch/s, loss=2.38e+6]


Epoch: 34: Loss: 2378892.1018270184


Epoch 35: 100%|██████████| 6271/6271 [02:44<00:00, 38.16batch/s, loss=2.38e+6]


Epoch: 35: Loss: 2378842.952868787


Epoch 36: 100%|██████████| 6271/6271 [02:46<00:00, 37.69batch/s, loss=2.38e+6]


Epoch: 36: Loss: 2378262.217773454


Epoch 37: 100%|██████████| 6271/6271 [02:44<00:00, 38.16batch/s, loss=2.38e+6]


Epoch: 37: Loss: 2377887.591316088


Epoch 38: 100%|██████████| 6271/6271 [02:46<00:00, 37.76batch/s, loss=2.38e+6]


Epoch: 38: Loss: 2377561.401906977


Epoch 39: 100%|██████████| 6271/6271 [02:45<00:00, 37.81batch/s, loss=2.38e+6]


Epoch: 39: Loss: 2377008.484215345


Epoch 40: 100%|██████████| 6271/6271 [02:45<00:00, 37.86batch/s, loss=2.38e+6]


Epoch: 40: Loss: 2376917.3875858905


Epoch 41: 100%|██████████| 6271/6271 [02:47<00:00, 37.42batch/s, loss=2.38e+6]


Epoch: 41: Loss: 2376402.3200523793


Epoch 42: 100%|██████████| 6271/6271 [02:47<00:00, 37.55batch/s, loss=2.38e+6]


Epoch: 42: Loss: 2376109.8330418933


Epoch 43: 100%|██████████| 6271/6271 [02:45<00:00, 37.84batch/s, loss=2.38e+6]


Epoch: 43: Loss: 2376221.25752942


Epoch 44: 100%|██████████| 6271/6271 [02:45<00:00, 37.92batch/s, loss=2.38e+6]


Epoch: 44: Loss: 2375699.3041182575


Epoch 45: 100%|██████████| 6271/6271 [02:43<00:00, 38.31batch/s, loss=2.38e+6]


Epoch: 45: Loss: 2375280.1971438434


Epoch 46: 100%|██████████| 6271/6271 [02:47<00:00, 37.35batch/s, loss=2.37e+6]


Epoch: 46: Loss: 2374839.968107286


Epoch 47: 100%|██████████| 6271/6271 [02:46<00:00, 37.70batch/s, loss=2.38e+6]


Epoch: 47: Loss: 2375129.5498166466


Epoch 48: 100%|██████████| 6271/6271 [02:44<00:00, 38.18batch/s, loss=2.37e+6]


Epoch: 48: Loss: 2374466.155406171


Epoch 49: 100%|██████████| 6271/6271 [02:44<00:00, 38.20batch/s, loss=2.37e+6]

Epoch: 49: Loss: 2374124.78787979



