<a href="https://colab.research.google.com/github/camlab-bioml/2021_IMC_Jett/blob/main/VI_R_sweep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3

%%capture
!pip install wandb --upgrade

%pip install scanpy

import wandb

import argparse
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F

#from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, DataLoader, random_split

from sklearn.metrics import confusion_matrix
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import adjusted_mutual_info_score

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

import random

import math
import pprint
import scanpy as sc
from scipy.stats.mstats import winsorize

def compute_p_y_given_z(Y, Theta, reg=1e-6):
  """Log-likelihood of every expression row under every singlet cluster.

  Returns an (N, C) tensor whose entry [n, c] is log p(y_n | z_n = c): the sum
  over the G columns of Y of independent Normal log-densities with mean
  exp(Theta['log_mu'][c]) and std exp(Theta['log_sigma'][c]) + reg.
  """
  means = Theta['log_mu'].exp()
  stds = Theta['log_sigma'].exp() + reg
  n_rows, n_feats = Y.shape[0], Y.shape[1]

  # (N, 1, G) broadcasts against the (C, G) parameters -> (N, C, G).
  per_feature = D.Normal(means, stds).log_prob(Y.reshape(n_rows, 1, n_feats))
  # Features are treated as independent, so their log-densities add.
  return per_feature.sum(2)

def compute_p_s_given_z(S, Theta, reg=1e-6):
  """Log-likelihood of every size value under every singlet cluster.

  Returns an (N, C) tensor whose entry [n, c] is log p(s_n | z_n = c) for a
  Normal with mean exp(Theta['log_psi'][c]) and std
  exp(Theta['log_omega'][c]) + reg.
  """
  loc = Theta['log_psi'].exp()
  scale = Theta['log_omega'].exp() + reg
  # An (N, 1) column broadcasts against the per-cluster parameters.
  column = S.reshape(-1, 1)
  return D.Normal(loc, scale).log_prob(column)

def compute_p_y_given_gamma(Y, Theta, reg=1e-6):
  """Log-likelihood of every expression row under every cluster pair.

  Returns an (N, C, C) tensor whose entry [n, c, c'] is
  log p(y_n | gamma_n = (c, c')). A pair's mean and std per feature are the
  averages of the two clusters' means and stds. Features are treated as
  independent, so their log-densities are summed.
  """
  mu = Theta['log_mu'].exp()
  sigma = Theta['log_sigma'].exp() + reg
  n_clusters, n_feats = mu.shape[0], Y.shape[1]

  def _pairwise_mean(t):
    # (C, G) -> (C, C, G) with entry [c, c', g] = (t[c, g] + t[c', g]) / 2.
    row = t.reshape(1, n_clusters, n_feats)
    return (row + row.permute(1, 0, 2)) / 2.0

  dist = D.Normal(_pairwise_mean(mu), _pairwise_mean(sigma))
  # (N, 1, 1, G) broadcasts against (C, C, G) -> (N, C, C, G); sum out G.
  return dist.log_prob(Y.reshape(-1, 1, 1, n_feats)).sum(3)

def compute_p_s_given_gamma(S, Theta, reg=1e-6):
  """Log-likelihood of every size value under every cluster pair.

  Returns an (N, C, C) tensor whose entry [n, c, c'] is
  log p(s_n | gamma_n = (c, c')). The pair's mean is psi_c + psi_c' and its
  std is omega_c + omega_c'.

  NOTE(review): for a sum of two independent Normals the std would be
  sqrt(omega_c**2 + omega_c'**2). Summing the stds is a modelling choice;
  confirm it is intended.
  """
  flat_psi = Theta['log_psi'].exp().reshape(-1)
  flat_omega = (Theta['log_omega'].exp() + reg).reshape(-1)

  # Column + row broadcasting yields the (C, C) table of pairwise sums.
  loc = flat_psi.unsqueeze(1) + flat_psi.unsqueeze(0)
  scale = flat_omega.unsqueeze(1) + flat_omega.unsqueeze(0)
  return D.Normal(loc, scale).log_prob(S.reshape(-1, 1, 1))

def compute_joint_probs(Theta, Y, S):
  """Log joint densities for the singlet (d=0) and doublet (d=1) branches.

  Returns:
    log_rzd0: (N, C) tensor; entry [n, c] is
      log p(s_n|z=c) + log p(y_n|z=c) + log pi_c + log delta_0.
    log_vgd1: (N, C*C) tensor; entry [n, c*C + c'] is
      log p(y_n|gamma=(c,c')) + log p(s_n|gamma=(c,c')) + log tau_cc'
      + log delta_1, i.e. the (C, C) pair grid flattened row-major.
  """
  n_clusters = Theta['log_mu'].shape[0]

  # The mixing weights live in Theta as unconstrained logits ("is_*") and are
  # normalised here.
  log_pi = F.log_softmax(Theta['is_pi'], 0)
  # tau is normalised jointly over all C*C ordered pairs.
  # NOTE(review): (c, c') and (c', c) both carry mass, so each unordered pair
  # is counted twice; confirm this is intended.
  flat_log_tau = F.log_softmax(Theta['is_tau'].reshape(-1), 0)
  log_tau = flat_log_tau.reshape(n_clusters, n_clusters)
  log_delta = F.log_softmax(Theta['is_delta'], 0)

  singlet = (compute_p_s_given_z(S, Theta) + compute_p_y_given_z(Y, Theta)
             + log_pi + log_delta[0])
  doublet = (compute_p_y_given_gamma(Y, Theta) + compute_p_s_given_gamma(S, Theta)
             + log_tau + log_delta[1])

  return singlet, doublet.reshape(Y.shape[0], n_clusters * n_clusters)
  
## for VI version
class BasicForwardNet(nn.Module):
  """Fully-connected ReLU encoder producing a categorical distribution per row.

  Architecture:
    Linear(input_dim, hidden_dim) -> ReLU
    -> [Linear(hidden_dim, hidden_dim) -> ReLU] * hidden_layer
    -> Linear(hidden_dim, output_dim) -> softmax over dim 1.

  Args:
    input_dim: number of input features per row.
    output_dim: number of output categories.
    hidden_dim: width of every hidden layer.
    hidden_layer: number of hidden-to-hidden layers.
  """
  def __init__(self, input_dim, output_dim, hidden_dim, hidden_layer):
    super().__init__()
    self.input = nn.Linear(input_dim, hidden_dim)
    # Attribute names are kept unchanged so existing state_dicts still load.
    self.linear1 = nn.ModuleList(
        [nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layer)]
    )
    self.output = nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    """Return ``(probs, log_probs)`` for a (batch, input_dim) input.

    The ReLU is applied *after* each hidden linear layer. Applying it to the
    raw input, as before, zeroed every negative feature (e.g. of the
    standardized data passed in by ``train``). It also left the last hidden
    layer and the output layer composed with no non-linearity between them.
    """
    out = F.relu(self.input(x))
    for layer in self.linear1:
      out = F.relu(layer(out))
    logits = self.output(out)
    return F.softmax(logits, dim=1), F.log_softmax(logits, dim=1)

class BasicForwardNet_leaky(nn.Module):
  """Fully-connected leaky-ReLU encoder producing a categorical distribution.

  Architecture:
    Linear(input_dim, hidden_dim) -> LeakyReLU
    -> [Linear(hidden_dim, hidden_dim) -> LeakyReLU] * hidden_layer
    -> Linear(hidden_dim, output_dim) -> softmax over dim 1.

  Args:
    input_dim: number of input features per row.
    output_dim: number of output categories.
    hidden_dim: width of every hidden layer.
    hidden_layer: number of hidden-to-hidden layers.
  """
  def __init__(self, input_dim, output_dim, hidden_dim, hidden_layer):
    super().__init__()
    self.input = nn.Linear(input_dim, hidden_dim)
    # Attribute names are kept unchanged so existing state_dicts still load.
    self.linear1 = nn.ModuleList(
        [nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layer)]
    )
    self.output = nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    """Return ``(probs, log_probs)`` for a (batch, input_dim) input.

    The activation follows each hidden linear layer. Previously it was applied
    to the raw input, which distorted negative input features. That placement
    also left the last hidden layer and the output layer with no
    non-linearity between them.
    """
    out = F.leaky_relu(self.input(x))
    for layer in self.linear1:
      out = F.leaky_relu(layer(out))
    logits = self.output(out)
    return F.softmax(logits, dim=1), F.log_softmax(logits, dim=1)

In [None]:
def build_dataset(NC, path="basel_zuri_subsample.h5ad", markers=None,
                  random_state=None):
  """Load and preprocess the IMC data and initialise the model parameters.

  Marker intensities are arcsinh-transformed (cofactor 5), and each marker is
  winsorized at its upper 1%. Cell area (``obs.Area``) is winsorized the same
  way. A KMeans clustering of the transformed intensities seeds the
  per-cluster parameters.

  Args:
    NC: number of clusters.
    path: h5ad file to read (the default is the previously hard-coded file).
    markers: columns of the AnnData object to keep. ``None`` uses the 33-marker
      panel this notebook was written for.
    random_state: forwarded to ``KMeans`` so initialisations can be made
      reproducible. ``None`` keeps the previous unseeded behaviour.

  Returns:
    YS: float32 tensor of intensities with area appended as the last column.
    YS_std: ``YS`` standardised column-wise. Columns with zero std are only
      centred, so they become zeros instead of NaN.
    Y: tensor of transformed, winsorized intensities.
    S: tensor of winsorized areas.
    Theta0: dict of leaf tensors (``requires_grad=True``) holding the initial
      log-means/log-stds, cluster log-weights ``is_pi``, pair log-weights
      ``is_tau`` and singlet/doublet log-weights ``is_delta``.
  """
  if markers is None:
    markers = ['EGFR', 'ECadherin', 'ER', 'GATA3', 'Histone_H3_1',
               'Ki67', 'SMA', 'Vimentin', 'cleaved_Parp', 'Her2',
               'p53', 'panCytokeratin', 'CD19', 'PR', 'Myc', 'Fibronectin', 'CK14',
               'Slug', 'CD20', 'vWF', 'Histone_H3_2', 'CK5', 'CD44', 'CD45', 'CD68',
               'CD3', 'CAIX', 'CK8/18', 'CK7', '80ArArArAr80Di',
               'phospho Histone', 'phospho S6', 'phospho mTOR']

  adata = sc.read_h5ad(path)
  adata = adata[:, list(markers)]

  YY = np.arcsinh(adata.X / 5.)
  for i in range(YY.shape[1]):
    # winsorize returns a masked array; .data holds the clipped values.
    YY[:, i] = winsorize(YY[:, i], limits=[0, 0.01]).data

  SS = winsorize(adata.obs.Area, limits=[0, 0.01]).data

  Y = torch.tensor(YY)
  S = torch.tensor(SS)
  YS = torch.hstack((Y, S.reshape(-1, 1))).float()

  init_labels = KMeans(NC, random_state=random_state).fit(YY).labels_
  init_label_class = np.unique(init_labels)

  mu_init = np.array([YY[init_labels == c, :].mean(0) for c in init_label_class])
  sigma_init = np.array([YY[init_labels == c, :].std(0) for c in init_label_class])
  psi_init = np.array([SS[init_labels == c].mean() for c in init_label_class])
  omega_init = np.array([SS[init_labels == c].std() for c in init_label_class])
  pi_init = np.array([np.mean(init_labels == c) for c in init_label_class])

  # Uniform prior over all NC*NC ordered cluster pairs.
  tau_init = np.ones((NC, NC))
  tau_init = tau_init / tau_init.sum()

  # The 1e-6 offsets keep np.log finite for zero means/stds.
  Theta = {
    'log_mu': np.log(mu_init + 1e-6),
    'log_sigma': np.log(sigma_init + 1e-6),
    'log_psi': np.log(psi_init + 1e-6),
    'log_omega': np.log(omega_init + 1e-6),
    'is_delta': np.log([0.9, 1 - 0.9]),
    'is_pi': np.log(pi_init),
    'is_tau': np.log(tau_init),
  }
  Theta0 = {k: torch.tensor(v, requires_grad=True) for (k, v) in Theta.items()}

  col_mean = YS.mean(0)
  col_std = YS.std(0)
  # Guard constant columns, which would otherwise divide by zero -> NaN.
  col_std = torch.where(col_std > 0, col_std, torch.ones_like(col_std))

  return YS, (YS - col_mean) / col_std, Y, S, Theta0

In [None]:
def build_network(P, NC, hidden_dim, hidden_layer, activation_func):
  """Build the three variational encoders sharing one architecture.

  Args:
    P: number of input features per row.
    NC: number of clusters.
    hidden_dim: hidden-layer width.
    hidden_layer: number of hidden-to-hidden layers.
    activation_func: ``"relu"`` or ``"leaky_relu"``.

  Returns:
    ``(r_net, v_net, d_net)``, with NC, NC**2 and 2 output categories
    respectively.

  Raises:
    ValueError: if ``activation_func`` is not one of the supported names.
      Previously an unknown name surfaced as an UnboundLocalError at
      ``return``.
  """
  if activation_func == "relu":
    net_cls = BasicForwardNet
  elif activation_func == "leaky_relu":
    net_cls = BasicForwardNet_leaky
  else:
    raise ValueError(
        f"Unknown activation_func {activation_func!r}; "
        "expected 'relu' or 'leaky_relu'")

  r_net = net_cls(P, NC, hidden_dim, hidden_layer)
  v_net = net_cls(P, NC ** 2, hidden_dim, hidden_layer)
  d_net = net_cls(P, 2, hidden_dim, hidden_layer)
  return r_net, v_net, d_net

In [None]:
def build_optimizer(Theta, r_net, v_net, d_net, learning_rate=1e-3):
  """Create a single Adam optimizer over all trainable quantities.

  The parameter list is the values of ``Theta`` followed by the parameters of
  ``r_net``, ``v_net`` and ``d_net``, in that order.
  """
  trainable = list(Theta.values())
  for net in (r_net, v_net, d_net):
    trainable.extend(net.parameters())
  return optim.Adam(trainable, lr=learning_rate)

In [None]:
def train_epoch(Theta, YS, Y, S, r_net, v_net, d_net, optimizer):
  """Take one full-batch gradient step on the negative ELBO.

  The variational distribution is a single categorical per cell over
  C + C*C outcomes:
    q(d=0, z=c)            = d_0 * r_c
    q(d=1, gamma=(c, c'))  = d_1 * v_cc'
  where r, v and d are the softmax outputs of ``r_net``, ``v_net`` and
  ``d_net`` applied to ``YS``.

  Args:
    Theta: dict of model parameters (consumed by ``compute_joint_probs``).
    YS: encoder input rows.
    Y, S: data passed to ``compute_joint_probs``.
    r_net, v_net, d_net: encoders returning ``(probs, log_probs)``.
    optimizer: optimizer over Theta and the encoder parameters.

  Returns:
    ``(nelbo, entro, recon)`` as detached scalar tensors. Detaching them
    prevents callers that store the values from keeping every epoch's
    autograd graph alive.
  """
  optimizer.zero_grad()

  _, log_r = r_net(YS)
  _, log_v = v_net(YS)
  _, log_d = d_net(YS)

  log_q0 = log_d[:, 0].reshape(-1, 1) + log_r  # log q(d=0, z=c)
  log_q1 = log_d[:, 1].reshape(-1, 1) + log_v  # log q(d=1, gamma=(c,c'))
  q0, q1 = log_q0.exp(), log_q1.exp()

  log_rzd0, log_vgd1 = compute_joint_probs(Theta, Y, S)

  # Negative entropy of the joint q above. This equals
  #   sum d log d + sum d_0 * (r log r) + sum d_1 * (v log v).
  # The previous unweighted (r log r) + (v log v) over-counted the entropy of
  # the branch that is switched off, so the objective was not a valid ELBO.
  entro = (q0 * log_q0).sum() + (q1 * log_q1).sum()
  recon = (q0 * log_rzd0).sum() + (q1 * log_vgd1).sum()
  nelbo = entro - recon

  nelbo.backward()
  optimizer.step()

  return nelbo.detach(), entro.detach(), recon.detach()

In [None]:
def train(config=None):
  """Run one wandb sweep trial: build data, networks and optimizer, then train.

  Training runs up to ``N_ITER * N_ITER_OPT`` full-batch steps. It stops
  early when two consecutive nelbo values differ by less than ``tol``, and at
  that point prints the final nelbo and the learned singlet/doublet
  proportions.
  """
  with wandb.init(config=config):
    config = wandb.config
    config.tol = 1e-2
    config.N_ITER = 1000
    config.N_ITER_OPT = 100
    config.NC = 8
    config.NF = 33
    config.P = 34  # NF markers + cell area

    YS, YS1, Y, S, Theta0 = build_dataset(config.NC)
    net1, net2, net3 = build_network(config.P, config.NC, config.hidden_dim,
                                     config.hidden_layer, config.activation_func)
    optimizer = build_optimizer(Theta0, net1, net2, net3)

    # Keep only the previous value as a plain float. Storing the nelbo tensors
    # themselves would retain an autograd graph per epoch.
    prev_nelbo = None
    for epoch in range(config.N_ITER * config.N_ITER_OPT):
      nelbo, entro, recon = train_epoch(Theta0, YS1, Y, S, net1, net2, net3, optimizer)
      nelbo_val = float(nelbo)
      wandb.log({'epoch': epoch, 'entropy': float(entro),
                 'reconstruction_loss': float(recon), 'nelbo': nelbo_val})

      if prev_nelbo is not None and abs(prev_nelbo - nelbo_val) < config.tol:
        print(nelbo_val)
        print(F.log_softmax(Theta0['is_delta'], 0).exp())
        break

      prev_nelbo = nelbo_val

In [None]:
# Random-search sweep minimising the logged 'nelbo'.
sweep_config = {
    'method': 'random'
    }

metric = {
    'name': 'nelbo',
    'goal': 'minimize'
    }

sweep_config['metric'] = metric

# NOTE(review): for wandb's q_log_uniform, min/max appear to be given in
# natural-log space (hence math.log). The sampled hidden_dim values in the run
# logs (5..15) are consistent with that.
# The dict previously declared 'hidden_layer' twice. Only the second entry
# (the explicit values list) ever took effect, so the dead q_log_uniform spec
# has been removed.
parameters_dict = {
    'activation_func': {
        'values': ['relu', 'leaky_relu']
        },
    'hidden_dim': {
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(5),
        'max': math.log(20)
        },
    'hidden_layer': {
        'values': [2, 4, 6, 8, 10]
        },
    }

sweep_config['parameters'] = parameters_dict

In [None]:
import os
from google.colab import drive

# Mount Google Drive, then work from the Colab Notebooks folder so that
# relative paths (such as the .h5ad file read in build_dataset) resolve there.
drive.mount('/content/gdrive')
os.chdir('/content/gdrive/MyDrive/Colab Notebooks/')

In [None]:
# Authenticate with Weights & Biases without hard-coding a secret in the
# notebook. wandb.login() reads the WANDB_API_KEY environment variable or an
# existing ~/.netrc entry, and otherwise prompts interactively.
# NOTE(review): the API key previously committed here is exposed and should
# be revoked.
wandb.login()

sweep_id = wandb.sweep(sweep_config, project='vi_sweep_win.99_5000')

wandb.agent(sweep_id, train, count=50)

[34m[1mwandb[0m: Currently logged in as: [33myujulee[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: cxbrrtb4
Sweep URL: https://wandb.ai/yujulee/vi_sweep_win.99_5000/sweeps/cxbrrtb4


[34m[1mwandb[0m: Agent Starting Run: g5madhh7 with config:
[34m[1mwandb[0m: 	activation_func: leaky_relu
[34m[1mwandb[0m: 	hidden_dim: 6
[34m[1mwandb[0m: 	hidden_layer: 2


tensor(-162891.1142, dtype=torch.float64, grad_fn=<SumBackward0>)
tensor([0.3641, 0.6359], dtype=torch.float64, grad_fn=<ExpBackward>)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
entropy,-14992.16602
epoch,5104.0
nelbo,-162891.11419
reconstruction_loss,147898.94818


0,1
entropy,▁▅▅▅▆▆▇▇▇███████████████████████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
nelbo,█▄▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
reconstruction_loss,▁▅▅▆▆▇▇▇▇▇▇▇▇▇██████████████████████████


[34m[1mwandb[0m: Agent Starting Run: mekkuwdc with config:
[34m[1mwandb[0m: 	activation_func: leaky_relu
[34m[1mwandb[0m: 	hidden_dim: 13
[34m[1mwandb[0m: 	hidden_layer: 4


tensor(-164534.2485, dtype=torch.float64, grad_fn=<SumBackward0>)
tensor([0.0938, 0.9062], dtype=torch.float64, grad_fn=<ExpBackward>)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
entropy,-12109.39551
epoch,3421.0
nelbo,-164534.24854
reconstruction_loss,152424.85303


0,1
entropy,▁▄▄▅▆▆▇▇▇▇▇█████████████████████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
nelbo,█▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
reconstruction_loss,▁▅▆▆▆▇▇▇▇▇▇▇████████████████████████████


[34m[1mwandb[0m: Agent Starting Run: etgupdgq with config:
[34m[1mwandb[0m: 	activation_func: relu
[34m[1mwandb[0m: 	hidden_dim: 15
[34m[1mwandb[0m: 	hidden_layer: 2


tensor(-165980.0305, dtype=torch.float64, grad_fn=<SumBackward0>)
tensor([0.0336, 0.9664], dtype=torch.float64, grad_fn=<ExpBackward>)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
entropy,-11711.02148
epoch,5112.0
nelbo,-165980.03047
reconstruction_loss,154269.00898


0,1
entropy,▁▄▆▆▆▇▇▇▇▇▇▇▇███████████████████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
nelbo,█▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
reconstruction_loss,▁▄▅▅▆▆▇▇▇▇▇▇▇▇██████████████████████████


[34m[1mwandb[0m: Agent Starting Run: vvkl9dy0 with config:
[34m[1mwandb[0m: 	activation_func: relu
[34m[1mwandb[0m: 	hidden_dim: 9
[34m[1mwandb[0m: 	hidden_layer: 4


tensor(-165479.0080, dtype=torch.float64, grad_fn=<SumBackward0>)
tensor([0.0623, 0.9377], dtype=torch.float64, grad_fn=<ExpBackward>)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
entropy,-11675.57031
epoch,4384.0
nelbo,-165479.008
reconstruction_loss,153803.43769


0,1
entropy,▁▃▅▆▆▇▇█████████████████████████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
nelbo,█▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
reconstruction_loss,▁▄▅▆▆▇▇▇▇▇▇█████████████████████████████


[34m[1mwandb[0m: Agent Starting Run: 3dqwq7hc with config:
[34m[1mwandb[0m: 	activation_func: leaky_relu
[34m[1mwandb[0m: 	hidden_dim: 5
[34m[1mwandb[0m: 	hidden_layer: 2
