<a href="https://colab.research.google.com/github/camlab-bioml/2021_IMC_Jett/blob/main/real_all_versions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
%%capture
!pip install wandb --upgrade

%pip install scanpy
import wandb

import argparse
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F

#from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, DataLoader, random_split

from sklearn.metrics import confusion_matrix
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import adjusted_mutual_info_score

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

import random

def compute_p_y_given_z(Y, Theta, reg=1e-6):
  """Per-cluster log-likelihood of the expression data under the singlet model.

  Returns an N x C tensor whose entry [n, c] is log p(y_n | z_n = c): the sum
  over features of independent Normal log-densities with mean
  exp(log_mu[c]) and std exp(log_sigma[c]) + reg.

  Args:
    Y: N x G expression tensor (one row per cell).
    Theta: parameter dict; reads 'log_mu' and 'log_sigma' (each C x G).
    reg: small constant added to the std for numerical stability.
  """
  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma']) + reg

  dist_Y = D.Normal(mu, sigma)
  # Insert a cluster axis (N x 1 x G) so the C x G parameters broadcast to
  # N x C x G; unsqueeze avoids depending on the global feature count NF.
  return dist_Y.log_prob(Y.unsqueeze(1)).sum(2)  # sum because features are IID given z

def compute_p_s_given_z(S, Theta, reg=1e-6):
  """Per-cluster log-likelihood of cell size under the singlet model.

  Returns an N x C tensor with entry [n, c] = log p(s_n | z_n = c), where the
  size of a cluster-c cell is Normal(exp(log_psi[c]), exp(log_omega[c]) + reg).

  Args:
    S: cell sizes, reshaped to a column so there is one row per cell.
    Theta: parameter dict; reads 'log_psi' and 'log_omega' (length C).
    reg: small constant added to the std for numerical stability.
  """
  size_dist = D.Normal(torch.exp(Theta['log_psi']),
                       torch.exp(Theta['log_omega']) + reg)
  # An N x 1 column broadcasts against the length-C parameter vectors.
  as_column = S.reshape(-1, 1)
  return size_dist.log_prob(as_column)

def compute_p_y_given_gamma(Y, Theta, reg=1e-6):
  """Log-likelihood of the expression data for every doublet pair.

  Returns an N x C x C tensor whose entry [n, c, c'] is
  log p(y_n | gamma_n = [c, c']). A [c, c'] doublet is modelled with the
  element-wise average of the two clusters' means and of their stds, with
  features treated as independent.

  Args:
    Y: N x G expression tensor (one row per cell).
    Theta: parameter dict; reads 'log_mu' and 'log_sigma' (each C x G).
    reg: small constant added to the std for numerical stability.
  """
  mu = torch.exp(Theta['log_mu'])
  sigma = torch.exp(Theta['log_sigma']) + reg

  # 1 x C x G plus its C x 1 x G permutation broadcasts to the symmetric
  # C x C x G matrix of pairwise averages. Shapes come from the parameters
  # themselves rather than the globals NC / NF.
  mu2 = mu.unsqueeze(0)
  mu2 = (mu2 + mu2.permute(1, 0, 2)) / 2.0

  sigma2 = sigma.unsqueeze(0)
  sigma2 = (sigma2 + sigma2.permute(1, 0, 2)) / 2.0

  dist_Y2 = D.Normal(mu2, sigma2)
  # N x 1 x 1 x G against C x C x G -> N x C x C x G, then sum over features.
  return dist_Y2.log_prob(Y[:, None, None, :]).sum(3)  # sum because features are IID given gamma

def compute_p_s_given_gamma(S, Theta, reg=1e-6):
  """Log-likelihood of cell size for every doublet pair.

  Returns an N x C x C tensor whose entry [n, c, c'] is
  log p(s_n | gamma_n = [c, c']). A [c, c'] doublet's size is Normal with
  mean psi_c + psi_c' and std omega_c + omega_c' (each omega already
  including reg).

  NOTE(review): for a sum of two independent Normal sizes the std would be
  sqrt(omega_c**2 + omega_c'**2); adding the stds is wider. Confirm this is
  the intended model.
  """
  psi_col = torch.exp(Theta['log_psi']).reshape(-1, 1)
  omega_col = (torch.exp(Theta['log_omega']) + reg).reshape(-1, 1)

  # Broadcasting a C x 1 column against its 1 x C transpose gives every pair.
  pair_mean = psi_col + psi_col.T
  pair_std = omega_col + omega_col.T

  return D.Normal(pair_mean, pair_std).log_prob(S.reshape(-1, 1, 1))

def compute_r_v_2(Y, S, Theta):
  """E-step quantities for the singlet/doublet mixture.

  Returns a 4-tuple:
    r: N x C log posteriors, log p(z = c, d = 0 | y, s).
    v: N x C x C log posteriors, log p(gamma = [c, c'], d = 1 | y, s).
    total observed-data log-likelihood, sum_n log p(y_n, s_n) (scalar tensor).
    p_singlet: length-N posterior probability p(d = 0 | y, s).
  """
  log_pi = F.log_softmax(Theta['is_pi'], 0)
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1), 0).reshape(NC, NC)
  log_delta = F.log_softmax(Theta['is_delta'], 0)

  # log p(y, s, z = c | d = 0): N x C
  singlet = compute_p_y_given_z(Y, Theta) + compute_p_s_given_z(S, Theta) + log_pi
  # log p(y, s, gamma = [c, c'] | d = 1), flattened to N x C*C
  doublet = (compute_p_y_given_gamma(Y, Theta) + compute_p_s_given_gamma(S, Theta)
             + log_tau).reshape(Y.shape[0], -1)

  singlet_joint = singlet + log_delta[0]
  doublet_joint = doublet + log_delta[1]

  # log p(y, s): marginalize over every singlet and doublet configuration.
  log_evidence = torch.logsumexp(torch.cat([singlet_joint, doublet_joint], dim=1), dim=1)
  evidence_col = log_evidence.unsqueeze(1)

  r = singlet_joint - evidence_col
  v = (doublet_joint - evidence_col).reshape(-1, NC, NC)
  p_singlet = torch.exp(torch.logsumexp(singlet, dim=1) + log_delta[0] - log_evidence)

  return r, v, log_evidence.sum(), p_singlet

## for EM version
def Q(Theta, Y, S, r, v, ignored_indices):
  """Expected complete-data log-likelihood, the EM M-step objective.

  Args:
    Theta: parameter dict being optimized.
    Y, S: expression matrix and cell sizes.
    r: log responsibilities for (z = c, d = 0), as returned by compute_r_v_2
      (N x C); treated as fixed weights.
    v: log responsibilities for (gamma = [c, c'], d = 1) (N x C x C).
    ignored_indices: unused; kept so existing calls keep working.

  Returns:
    Scalar tensor: sum exp(r) * log p(y, s, z, d = 0)
                 + sum exp(v) * log p(y, s, gamma, d = 1).
  """
  log_pi = F.log_softmax(Theta['is_pi'], 0)
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1), 0).reshape(NC, NC)
  log_delta = F.log_softmax(Theta['is_delta'], 0)

  singlet_term = (compute_p_s_given_z(S, Theta) + compute_p_y_given_z(Y, Theta)
                  + log_pi + log_delta[0])
  doublet_term = (compute_p_y_given_gamma(Y, Theta) + compute_p_s_given_gamma(S, Theta)
                  + log_tau + log_delta[1])

  weighted_singlet = singlet_term * r.exp()
  weighted_doublet = doublet_term * v.exp()
  return weighted_singlet.sum() + weighted_doublet.sum()

def ll(Y, S, Theta):
  """Total observed-data log-likelihood, sum_n log p(y_n, s_n).

  Marginalizes every cell over all singlet configurations (d = 0, z = c) and
  all doublet configurations (d = 1, gamma = [c, c']). The result stays
  attached to the autograd graph, so it can be maximized directly
  (torch_mle minimizes its negative).
  """
  log_pi = F.log_softmax(Theta['is_pi'], 0)
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1), 0).reshape(NC, NC)
  log_delta = F.log_softmax(Theta['is_delta'], 0)

  # log p(y, s, z = c | d = 0): N x C
  p_data_given_z_d0 = compute_p_y_given_z(Y, Theta) + compute_p_s_given_z(S, Theta) + log_pi
  # log p(y, s, gamma = [c, c'] | d = 1), flattened to N x C*C
  p_data_given_gamma_d1 = (compute_p_y_given_gamma(Y, Theta) + compute_p_s_given_gamma(S, Theta)
                           + log_tau).reshape(Y.shape[0], -1)

  # One row per cell holding every joint log-density; logsumexp over the row
  # marginalizes all latent configurations.
  p_data = torch.cat([p_data_given_z_d0 + log_delta[0], p_data_given_gamma_d1 + log_delta[1]], dim=1)
  return torch.logsumexp(p_data, dim=1).sum()

## for VI version
def compute_joint_probs(Theta, Y, S):
  """Joint log-densities for every latent configuration (VI reconstruction term).

  Returns:
    log_rzd0: N x C, log p(y_n, s_n, z_n = c, d_n = 0).
    log_vgd1: N x C*C, log p(y_n, s_n, gamma_n = [c, c'], d_n = 1), with the
      C x C pair axes flattened row-major.
  """
  log_pi = F.log_softmax(Theta['is_pi'], 0)
  log_tau = F.log_softmax(Theta['is_tau'].reshape(-1), 0).reshape(NC, NC)
  log_delta = F.log_softmax(Theta['is_delta'], 0)

  log_rzd0 = (compute_p_s_given_z(S, Theta) + compute_p_y_given_z(Y, Theta)
              + log_pi + log_delta[0])
  pair_joint = (compute_p_y_given_gamma(Y, Theta) + compute_p_s_given_gamma(S, Theta)
                + log_tau + log_delta[1])

  n_cells = Y.shape[0]
  return log_rzd0, pair_joint.reshape(n_cells, NC * NC)

class BasicForwardNet(nn.Module):
  """Fully connected encoder mapping raw input features to a categorical distribution.

  Layout: input Linear (leaky ReLU) -> `hidden_layer` Linear layers of width
  `hidden_dim` (ReLU each) -> output Linear of width `output_dim`.
  forward() returns (probabilities, log-probabilities), both normalized
  along dim 1.
  """
  def __init__(self, input_dim, output_dim, hidden_dim=20, hidden_layer=10):
    super().__init__()
    # Layers are registered in the order input, hidden stack, output, so
    # parameter initialization and state_dict keys are fixed.
    self.input = nn.Linear(input_dim, hidden_dim)
    hidden_stack = [nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layer)]
    self.linear1 = nn.ModuleList(hidden_stack)
    self.output = nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    h = F.leaky_relu(self.input(x))
    for layer in self.linear1:
      h = F.relu(layer(h))
    logits = self.output(h)
    return F.softmax(logits, dim=1), F.log_softmax(logits, dim=1)


## Training function: generalized EM (E-step, then one Adam step on Q)
def torch_em(Y, S, Theta):
  """Fit Theta in place by generalized EM.

  Each iteration runs a full E-step (compute_r_v_2, without gradients) and
  then a single Adam step that increases Q. It stops after
  N_ITER * N_ITER_OPT iterations, or once the log-likelihood reported by
  successive E-steps changes by less than TOL.

  Returns:
    List of the per-iteration log-likelihood tensors recorded before
    convergence.
  """
  optimizer = optim.Adam(Theta.values())
  history = []

  for step in range(N_ITER * N_ITER_OPT):
    # E-step: responsibilities under the current parameters (held fixed).
    with torch.no_grad():
      r, v, loglik, _ = compute_r_v_2(Y, S, Theta)

    # M-step: one gradient step on -Q.
    optimizer.zero_grad()
    neg_q = -Q(Theta, Y, S, r, v, None)
    neg_q.backward()
    optimizer.step()

    converged = step > 0 and abs(history[-1] - loglik) < TOL
    if converged:
      print(loglik)
      print(F.log_softmax(Theta['is_delta'], 0).exp())
      print(Theta['log_psi'].exp())
      break

    history.append(loglik)

  return history

## Training function: direct maximum-likelihood fit with Adam
def torch_mle(Y, S, Theta):
  """Fit Theta in place by minimizing the negative log-likelihood with Adam.

  Runs up to N_ITER * N_ITER_OPT steps and stops once successive negative
  log-likelihood values differ by less than TOL.

  Returns:
    List of per-iteration negative log-likelihoods, stored as detached
    tensors.
  """
  opt = optim.Adam(Theta.values())
  loss = []
  for epoch in range(N_ITER * N_ITER_OPT):
    opt.zero_grad()
    nll = -ll(Y, S, Theta)
    nll.backward()
    opt.step()

    # Detach before storing/comparing. Otherwise the history keeps every
    # iteration's autograd graph alive, and the comparison adds graph nodes.
    nll_value = nll.detach()
    if epoch > 0 and abs(loss[-1] - nll_value) < TOL:
      print(nll)
      print(F.log_softmax(Theta['is_delta'], 0).exp())
      break

    loss.append(nll_value)

  return loss
  
## Training function: amortized variational inference (minimizes the negative ELBO)
def torch_vi(Y, S, Theta):
    """Fit Theta in place by amortized variational inference.

    Three BasicForwardNet encoders map the standardized [Y | S] features to
    q(d) (2 classes), q(z | d = 0) (NC classes) and q(gamma | d = 1)
    (NC * NC classes). Theta and the encoder weights are optimized jointly
    with AdamW on the negative ELBO. Training runs up to N_ITER * N_ITER_OPT
    steps and stops once successive values differ by less than TOL.

    Returns:
        The negative ELBO tensor from the last iteration that ran.
    """
    YS = torch.hstack((Y, S.reshape(-1, 1))).float()
    # Column-standardize; a constant column keeps std 1 instead of giving 0/0 = NaN.
    col_std = YS.std(0)
    col_std = torch.where(col_std > 0, col_std, torch.ones_like(col_std))
    YS1 = (YS - YS.mean(0)) / col_std

    n_in = YS.shape[1]  # every feature column plus one column for cell size
    r_net = BasicForwardNet(n_in, NC)
    v_net = BasicForwardNet(n_in, NC ** 2)
    d_net = BasicForwardNet(n_in, 2)

    params = (list(Theta.values()) + list(r_net.parameters())
              + list(v_net.parameters()) + list(d_net.parameters()))
    opt = optim.AdamW(params)

    loss = []
    for epoch in range(N_ITER * N_ITER_OPT):
        opt.zero_grad()
        _, log_r = r_net(YS1)
        _, log_v = v_net(YS1)
        _, log_d = d_net(YS1)

        # Joint variational log-probabilities; each row of [q0 | q1] sums to 1.
        log_q0 = log_d[:, 0].reshape(-1, 1) + log_r  # log q(d = 0, z = c)
        log_q1 = log_d[:, 1].reshape(-1, 1) + log_v  # log q(d = 1, gamma = [c, c'])
        q0 = log_q0.exp()
        q1 = log_q1.exp()

        log_rzd0, log_vgd1 = compute_joint_probs(Theta, Y, S)

        # E_q[log q] over the same joint q used in `recon`. The z / gamma
        # entropies must be weighted by q(d) (which q0 / q1 do); summing the
        # unweighted r and v entropies over-counts entropy, and the objective
        # is then no longer a bound on log p.
        entro = (q0 * log_q0).sum() + (q1 * log_q1).sum()
        recon = (q0 * log_rzd0).sum() + (q1 * log_vgd1).sum()
        nelbo = entro - recon
        nelbo.backward()
        opt.step()

        # Detach so the history does not keep each iteration's graph alive.
        nelbo_value = nelbo.detach()
        if epoch > 0 and abs(loss[-1] - nelbo_value) < TOL:
            print(nelbo)
            print(F.log_softmax(Theta['is_delta'], 0).exp())
            break

        loss.append(nelbo_value)

    return nelbo

In [4]:
def make_plot(init_mu, init_psi, Theta, model_type, run):
  """Plot fitted parameters against their initial values.

  Shows, via plt.show():
    * one line plot per cluster j comparing init_mu[j] with exp(log_mu[j]);
    * one line plot comparing init_psi with the fitted sizes exp(log_psi);
    * one heatmap of the column-standardized [log_mu | log_psi] matrix.

  Relies on the notebook globals `adata` (column labels from
  adata.var_names) and `sns` (seaborn).

  Args:
    init_mu: initial per-cluster mean expression, one row per cluster.
    init_psi: initial per-cluster mean cell size.
    Theta: fitted parameter dict of tensors.
    model_type: label used in plot titles.
    run: run index shown in the heatmap title.
  """
  fitted_mu = Theta['log_mu'].exp().detach().numpy()
  for j in range(init_mu.shape[0]):
    plt.plot(init_mu[j], label="before")
    plt.plot(fitted_mu[j], label="after")
    plt.legend()
    plt.xlabel("proteins")
    plt.title("{} cluster id {}".format(model_type, j))
    plt.show()

  # Drawn once, after the per-cluster loop: neither the size plot nor the
  # heatmap depends on j.
  plt.plot(init_psi, label="before")
  plt.plot(Theta['log_psi'].exp().detach().numpy(), label="after")
  plt.legend()
  plt.xlabel("cluster id")
  plt.title("{} cellsizes".format(model_type))
  plt.show()

  mu_psi = torch.hstack((Theta['log_mu'], Theta['log_psi'].reshape(-1, 1))).detach().numpy()
  df = pd.DataFrame((mu_psi - mu_psi.mean(0)) / mu_psi.std(0),
                    columns=np.hstack((adata.var_names, 'size')))
  ax = sns.heatmap(df, xticklabels=True)
  ax.figure.tight_layout()
  plt.xlabel("proteins")
  plt.ylabel("cluster ID")
  plt.title("{} run {}".format(model_type, run))
  plt.show()

In [5]:
# Mount Google Drive (Colab only) and load basel_zuri_subsample.h5ad from the
# "Colab Notebooks" folder, keeping only the listed channels.
from google.colab import drive
drive.mount('/content/gdrive')

import os
# The relative read_h5ad path below is resolved against this directory.
os.chdir('/content/gdrive/MyDrive/Colab Notebooks/')

# Local data root; only referenced by the commented-out read below.
PATH = '/Users/jettlee/Desktop/DAMM/'
#PATH = '/home/campbell/yulee/DAMM/'

import scanpy as sc
#adata = sc.read_h5ad("{}data/basel_zuri_subsample.h5ad".format(PATH))
adata = sc.read_h5ad("basel_zuri_subsample.h5ad")

# Column subset by var name: these channels become the feature matrix
# (adata.X) used by the fitting loop and the protein labels used in make_plot.
adata = adata[:,['EGFR', 'ECadherin', 'ER', 'GATA3','Histone_H3_1', 
 'Ki67', 'SMA', 'Vimentin', 'cleaved_Parp', 'Her2',
 'p53', 'panCytokeratin', 'CD19', 'PR', 'Myc', 'Fibronectin', 'CK14',
 'Slug', 'CD20', 'vWF', 'Histone_H3_2', 'CK5', 'CD44', 'CD45', 'CD68',
 'CD3', 'CAIX', 'CK8/18', 'CK7', '80ArArArAr80Di', 
 'phospho Histone', 'phospho S6', 'phospho mTOR']]

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [7]:
from scipy.stats.mstats import winsorize

# Settings read as globals by the model and training functions defined above.
NC = 8  # number of mixture clusters (KMeans init and model components)
NO, NF = adata.shape  # cells x channels of the loaded AnnData
P = NF + 1  # encoder input width: all channels plus one cell-size column

N_ITER = 1000
N_ITER_OPT = 100  # training loops run for at most N_ITER * N_ITER_OPT steps
TOL = 1e-2  # stop when the tracked objective changes by less than this
N_INIT = 10

In [None]:
import seaborn as sns

for ii in range(5):

    print("run {}".format(ii))

    # Draw this run's cells from the FULL dataset into a separate variable.
    # Reassigning `adata` here would make each run resample from the previous
    # run's subsample. Sampling without replacement avoids duplicated cells.
    # `adata` itself stays intact (make_plot reads adata.var_names).
    n_sub = min(1000, adata.shape[0])
    cell_sel = np.random.choice(adata.shape[0], size=n_sub, replace=False)
    adata_sub = adata[cell_sel, :]

    # arcsinh(x / 5) transform, then clip the top 1% of every channel and of
    # cell area.
    YY = np.array(np.arcsinh(adata_sub.X / 5.))
    for i in range(NF):
        YY[:, i] = winsorize(YY[:, i], limits=[0, 0.01]).data
    SS = adata_sub.obs.Area
    SS = winsorize(SS, limits=[0, 0.01]).data

    # Initialize per-cluster parameters from a KMeans partition of YY.
    kms = KMeans(NC).fit(YY)
    init_labels = kms.labels_
    init_label_class = np.unique(init_labels)

    mu_init = np.array([YY[init_labels == c, :].mean(0) for c in init_label_class])
    sigma_init = np.array([YY[init_labels == c, :].std(0) for c in init_label_class])
    psi_init = np.array([SS[init_labels == c].mean() for c in init_label_class])
    omega_init = np.array([SS[init_labels == c].std() for c in init_label_class])
    pi_init = np.array([np.mean(init_labels == c) for c in init_label_class])

    # Uniform prior over the NC x NC doublet pairs.
    tau_init = np.ones((NC, NC))
    tau_init = tau_init / tau_init.sum()

    mu_psi0 = np.hstack((mu_init, psi_init.reshape(-1, 1)))

    Theta = {
        'log_mu': np.log(mu_init + 1e-6),
        'log_sigma': np.log(sigma_init + 1e-6),
        'log_psi': np.log(psi_init + 1e-6),
        'log_omega': np.log(omega_init + 1e-6),
        'is_delta': np.log([0.9, 1 - 0.9]),  # initial singlet / doublet proportions
        'is_pi': np.log(pi_init),
        'is_tau': np.log(tau_init),
    }

    # Independent copies so each fitting method starts from the same point.
    Theta0 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}
    Theta1 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}
    Theta2 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}
    Theta3 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}
    Theta4 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}
    Theta5 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}
    # Variants with the singlet/doublet proportion frozen.
    Theta0['is_delta'].requires_grad = False
    Theta2['is_delta'].requires_grad = False
    Theta4['is_delta'].requires_grad = False

    Y = torch.tensor(YY)
    S = torch.tensor(SS)

    if ii == 0:
        em1 = torch_em(Y, S, Theta1)
        make_plot(mu_init, psi_init, Theta1, "em", ii + 1)

        mle1 = torch_mle(Y, S, Theta3)
        make_plot(mu_init, psi_init, Theta3, "mle", ii + 1)

    vi1 = torch_vi(Y, S, Theta5)
    make_plot(mu_init, psi_init, Theta5, "vi", ii + 1)
