In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
 
    


In [3]:
import sys

# Put the local pollock checkout at the front of the module search path so it
# is imported in preference to any other copy.
sys.path[:0] = ['/tf/pollock']

In [4]:
import logging
import os
import random
from collections import Counter
from importlib import reload
import time

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

import pollock
from pollock import PollockDataset, PollockModel, load_from_directory
# import pollock.models.analysis as pollock_analysis

In [5]:
import tensorflow as tf
from tensorflow.keras import layers

# Reset the global state Keras keeps between model builds, for easy reset of notebook state.
keras_backend = tf.keras.backend
keras_backend.clear_session()

In [6]:
from tensorflow.python.client import device_lib

# Show every device TensorFlow can see (CPU, XLA_CPU, XLA_GPU, ...).
# NOTE(review): tensorflow.python.* is private API; tf.config.list_physical_devices()
# is the public alternative if this breaks on a TF upgrade.
local_devices = device_lib.list_local_devices()
local_devices

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 9940628303412410795,
 name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 3818076523845842837
 physical_device_desc: "device: XLA_CPU device",
 name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 7657496466880006617
 physical_device_desc: "device: XLA_GPU device",
 name: "/device:XLA_GPU:1"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 6821635573982979800
 physical_device_desc: "device: XLA_GPU device",
 name: "/device:XLA_GPU:2"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 3901834504750236522
 physical_device_desc: "device: XLA_GPU device",
 name: "/device:XLA_GPU:3"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 13431627675899494463
 physical_device_desc: "device: XLA_GPU device",
 

In [None]:
# Workstation data/model roots. The next cell reassigns both for the container layout.
DATA_DIR, MODEL_DIR = (
    '/home/estorrs/data/single_cell_classification',
    '/home/estorrs/pollock/models',
)

In [None]:
# Container data/model roots. These override the workstation paths assigned in the
# previous cell. Either root can be redirected without editing the notebook by
# setting POLLOCK_DATA_DIR / POLLOCK_MODEL_DIR; the defaults are the original paths.
DATA_DIR = os.environ.get('POLLOCK_DATA_DIR', '/data/single_cell_classification')
MODEL_DIR = os.environ.get('POLLOCK_MODEL_DIR', '/models')

In [None]:
run_name = 'br'

# Seed every RNG this notebook draws from: python `random` (train/val split),
# numpy, and TensorFlow (weight initialisation and VAE reparameterization noise).
# Without this, Restart & Run All yields a different split and model every time.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Input files: expression counts matrix and per-cell metadata (tab separated).
raw_data_dir = os.path.join(DATA_DIR, 'tumor', 'BR', 'raw', 'houxiang_brca')
expression_fp = os.path.join(raw_data_dir, 'breast_counts_matrix.tsv')
label_fp = os.path.join(raw_data_dir, 'breast_metadata.tsv')

training_image_dir = os.path.join(MODEL_DIR, 'scratch', run_name)
model_save_dir = os.path.join(MODEL_DIR, run_name)

sample_column = 'Genes'  # NOTE(review): not referenced in the cells shown here — confirm still needed
sep = '\t'
cell_type_key = 'cell_type'  # metadata column holding the cell-type annotation

n_per_cell_type = 5000  # NOTE(review): unused here; the split cell caps training cells with its own `n`
epochs = 5  # NOTE(review): reassigned to 100 in the VAE hyperparameter cell
batch_size = 128

In [None]:
# Read the expression matrix from the HDF5 file (key 'df') sitting next to the TSV,
# presumably converted from it beforehand. os.path.splitext swaps only the final
# extension; str.replace('.tsv', '.h5') would also rewrite any '.tsv' in a directory name.
expression_h5_fp = os.path.splitext(expression_fp)[0] + '.h5'
expression_df = pd.read_hdf(expression_h5_fp, 'df')
expression_df

In [None]:
# Per-cell metadata table (delimiter taken from the config cell).
label_df = pd.read_csv(label_fp, sep=sep)
label_df

In [None]:
# Index the metadata by cell id so it can be aligned with the expression matrix rows.
# Guarded so re-running this cell is a no-op instead of a KeyError once 'cell_id'
# has already moved from a column into the index.
if 'cell_id' in label_df.columns:
    label_df = label_df.set_index('cell_id')
label_df

In [None]:
# Reorder the metadata to match the expression matrix row order. .loc raises
# KeyError if an expression cell has no metadata row; the assertion additionally
# catches duplicated cell ids, which would make .loc return extra rows.
label_df = label_df.loc[expression_df.index]
assert label_df.index.equals(expression_df.index), \
    'label_df must have exactly one metadata row per expression cell (duplicate cell_id?)'
label_df

In [None]:
# Wrap the expression values (cells as rows, genes as columns) and the aligned
# metadata in an AnnData object, carrying over the cell and gene labels.
adata = anndata.AnnData(X=expression_df.values, obs=label_df)
adata.var.index, adata.obs.index = expression_df.columns, expression_df.index
adata

In [None]:
# Number of cells per annotated cell type, most frequent first.
counts = Counter(adata.obs[cell_type_key].tolist())
sorted(counts.items(), key=lambda kv: kv[1], reverse=True)

In [None]:
## get rid of unknowns
# .copy() materializes the subset. Without it adata stays an AnnData *view*, and the
# in-place scanpy calls that follow (filter_cells, log1p, ...) would force an implicit
# view-to-copy conversion with an ImplicitModificationWarning.
adata = adata[adata.obs[cell_type_key] != 'Unknown'].copy()
adata

In [None]:
# Basic QC, in place: first drop cells with fewer than 200 detected genes, then
# genes detected in fewer than 3 of the remaining cells.
# NOTE: mitochondrial-fraction QC (percent_mito / n_counts scatter plots) was
# explored for this dataset but is currently disabled.
min_genes_per_cell = 200
min_cells_per_gene = 3
sc.pp.filter_cells(adata, min_genes=min_genes_per_cell)
sc.pp.filter_genes(adata, min_cells=min_cells_per_gene)

In [None]:
# adata = adata[adata.obs.n_genes < 6000, :]
# adata = adata[adata.obs.percent_mito < 0.05, :]

In [None]:
# log(1 + x) transform in place, then snapshot the full gene set into .raw
# before the highly-variable-gene subsetting below.
# NOTE(review): library-size normalization (sc.pp.normalize_total(adata, target_sum=1e4))
# is disabled, so log1p is applied to unnormalized counts — confirm this is intentional.
sc.pp.log1p(adata, copy=False)
adata.raw = adata
adata

In [None]:
# Flag highly variable genes by dispersion only (no mean cutoffs) and plot the selection.
sc.pp.highly_variable_genes(adata, min_disp=0.2, min_mean=None, max_mean=None)
sc.pl.highly_variable_genes(adata)


In [None]:
# How many genes were flagged as highly variable.
int(adata.var['highly_variable'].sum())

In [None]:
# Keep only the highly variable genes. .copy() gives a real AnnData rather than a view,
# so the in-place scaling below doesn't go through an implicit view-to-copy conversion.
adata = adata[:, adata.var.highly_variable].copy()
# Zero-center and scale each gene to unit variance, without clipping (max_value=None).
sc.pp.scale(adata, max_value=None)

adata

In [None]:
# Expression-space embedding used for plotting: PCA -> 10-nearest-neighbour graph
# on the first 30 PCs -> UMAP (written to adata.obsm['X_umap']).
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=10)
sc.tl.umap(adata)

In [None]:
## filter smartly
# Per-cell-type train/val split: cell types with more than `n` cells contribute
# exactly `n` training cells (so abundant types don't dominate); smaller types
# contribute int(80%) of their cells. Every other cell goes to validation.
n = 1000  # max training cells per cell type
small_type_train_frac = .8  # training fraction for cell types with <= n cells
split_seed = 0
split_rng = random.Random(split_seed)  # local seeded RNG -> identical split on every re-run

cell_type_to_idxs = {}
for cell_id, cell_type in zip(adata.obs.index, adata.obs[cell_type_key]):
    cell_type_to_idxs.setdefault(cell_type, []).append(cell_id)

def temp(ls):
    """Return the training subset of the cell ids `ls` belonging to one cell type."""
    if len(ls) > n:
        return split_rng.sample(ls, n)
    return split_rng.sample(ls, int(len(ls) * small_type_train_frac))

cell_type_to_idxs = {k: temp(ls) for k, ls in cell_type_to_idxs.items()}

# Convert the sampled cell ids into positional row indices of adata.
train_ids = [x for ls in cell_type_to_idxs.values() for x in ls]
is_train = np.isin(np.asarray(adata.obs.index), train_ids)
train_idxs = np.flatnonzero(is_train)
val_idxs = np.flatnonzero(~is_train)

train_idxs.shape, val_idxs.shape

In [None]:
# Materialize both splits as independent AnnData objects (not views of adata).
train_adata, val_adata = (adata[rows, :].copy() for rows in (train_idxs, val_idxs))

In [None]:
# Cell-type composition of the validation split, most frequent first.
Counter(val_adata.obs[cell_type_key].tolist()).most_common()

In [None]:
# Cell-type composition of the training split, most frequent first.
Counter(train_adata.obs[cell_type_key].tolist()).most_common()

In [None]:
## binarize for first pass

In [None]:
# Peek at the training matrix: HVG-subset, log1p-transformed and per-gene scaled.
train_matrix = train_adata.X
train_matrix

In [None]:
# import seaborn as sns
# sns.distplot(train_adata.X[:, train_adata.var.index=='KRT18'].flatten(), kde=False, bins=100)

In [None]:
# sns.distplot(train_adata.raw.X[:, train_adata.raw.var.index=='EPCAM'].flatten(), kde=False, bins=100)

In [None]:
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler

In [None]:
# Rescale every gene to [0, 1] using statistics from the training split only,
# so nothing from the validation cells leaks into preprocessing.
scaler = MinMaxScaler(feature_range=(0, 1))
X_train = scaler.fit_transform(np.copy(train_adata.X))
X_val = scaler.transform(np.copy(val_adata.X))

X_train.shape, X_val.shape

In [None]:
# Inspect the min-max scaled training matrix that feeds the VAE.
scaled_train_matrix = X_train
scaled_train_matrix

In [None]:
TRAIN_BUF = 10000  # shuffle buffer size for the training pipeline
BATCH_SIZE = batch_size  # reuse the config-cell value instead of a second hard-coded copy

# Cast to float32 to match the Keras layers' default dtype. compute_loss subtracts the
# decoder output from the raw batch, which fails with a dtype mismatch on float64 input.
train_dataset = (tf.data.Dataset.from_tensor_slices(X_train.astype(np.float32))
                 .shuffle(TRAIN_BUF)
                 .batch(BATCH_SIZE))
test_dataset = tf.data.Dataset.from_tensor_slices(X_val.astype(np.float32)).batch(BATCH_SIZE)

In [None]:
class BVAE(tf.keras.Model):
  """Variational autoencoder with fully connected encoder and decoder.

  The encoder (``inference_net``) maps an ``input_size`` vector through two
  800-unit ReLU/Dropout layers to ``2 * latent_dim`` outputs, which ``encode``
  splits into a posterior mean and log-variance. The decoder (``generative_net``)
  mirrors this and maps a ``latent_dim`` vector back to ``input_size`` linear outputs.

  Args:
    latent_dim: size of the latent space.
    input_size: number of input features (columns of the expression matrix).
  """

  def __init__(self, latent_dim, input_size):
    super(BVAE, self).__init__()
    self.latent_dim = latent_dim
    self.input_size = input_size
    self.inference_net = tf.keras.Sequential(
      [
          tf.keras.layers.InputLayer(input_shape=(input_size,)),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          # first half -> mean, second half -> log-variance (see encode)
          tf.keras.layers.Dense(latent_dim + latent_dim),
      ]
    )

    self.generative_net = tf.keras.Sequential(
        [
          tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(input_size),
        ]
    )

  @tf.function
  def sample(self, eps=None):
    """Decode latent vectors ``eps`` (100 standard-normal draws if omitted) through a sigmoid.

    NOTE(review): compute_loss trains the decoder's *linear* output with a squared-error
    loss, so the sigmoid applied here does not match training — confirm intent.
    """
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)

  def encode(self, x, training=None):
    """Return ``(mean, logvar)``, each of width ``latent_dim``, for inputs ``x``.

    ``training`` is forwarded to the encoder so Dropout can be enabled explicitly;
    the default (None) behaves exactly like calling the network without it.
    """
    mean, logvar = tf.split(self.inference_net(x, training=training),
                            num_or_size_splits=2, axis=1)
    return mean, logvar

  def reparameterize(self, mean, logvar):
    """Draw ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
    # tf.shape (dynamic) instead of mean.shape (static): the static batch dimension
    # is None when traced with an unknown batch size, which tf.random.normal rejects.
    eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
    return eps * tf.exp(logvar * .5) + mean

  def decode(self, z, apply_sigmoid=False, training=None):
    """Map latent vectors ``z`` back to input space.

    Returns the decoder's linear output, or its sigmoid when ``apply_sigmoid`` is set.
    ``training`` is forwarded to the decoder (controls Dropout); None keeps the default.
    """
    logits = self.generative_net(z, training=training)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs

    return logits

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
  """Log-density of a diagonal Gaussian N(mean, exp(logvar)) at ``sample``, summed over ``raxis``.

  Only the commented-out alternative loss below refers to this helper.
  """
  log2pi = tf.math.log(2. * np.pi)
  squared_dev = (sample - mean) ** 2.
  per_dim = -.5 * (squared_dev * tf.exp(-logvar) + logvar + log2pi)
  return tf.reduce_sum(per_dim, axis=raxis)

# @tf.function
# def compute_loss(model, x):
#   mean, logvar = model.encode(x)
#   z = model.reparameterize(mean, logvar)
#   x_logit = model.decode(z)

# #   cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
# #   logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])

#   logpx_z = tf.metrics.msle(x_logit, x)
#   logpz = log_normal_pdf(z, 0., 0.)
#   logqz_x = log_normal_pdf(z, mean, logvar)

#   return -tf.reduce_mean(logpx_z + ((logpz - logqz_x) * .0005))
# #   return -tf.reduce_mean(logpx_z)



#         def vae_loss(y_true, y_pred):
#             return K.mean(recon_loss(y_true, y_pred) + self.alpha * kl_loss(y_true, y_pred))

#         def kl_loss(y_true, y_pred):
#             return 0.5 * K.sum(K.exp(self.log_var) + K.square(self.mu) - 1. - self.log_var, axis=1)

#         def recon_loss(y_true, y_pred):
#             return 0.5 * K.sum(K.square((y_true - y_pred)), axis=1)

@tf.function
def compute_loss(model, x, alpha=0.00005):
  """Batch-mean VAE loss: 0.5 * squared reconstruction error + alpha * KL(q(z|x) || N(0, I)).

  Args:
    model: object exposing encode / reparameterize / decode (e.g. BVAE).
    x: batch of inputs, one example per row.
    alpha: weight applied to the KL term.
  """
  mean, logvar = model.encode(x)
  x_recon = model.decode(model.reparameterize(mean, logvar))

  # Per-example 0.5 * sum_j (x_j - x_hat_j)^2 against the decoder's linear output.
  reconstruction_loss = .5 * tf.reduce_sum(tf.square(x - x_recon), axis=1)
  # Per-example closed-form KL between N(mean, exp(logvar)) and a standard normal.
  kl_loss = .5 * tf.reduce_sum(tf.exp(logvar) + tf.square(mean) - 1. - logvar, axis=1)

  return tf.reduce_mean(reconstruction_loss + alpha * kl_loss)

@tf.function
def compute_apply_gradients(model, x, optimizer, alpha=.00005):
  """Run one optimizer step on batch ``x`` and return that batch's (pre-update) loss.

  Returning the loss is backward compatible, since callers that ignore the result are
  unaffected. It lets a training loop report training loss without a second forward pass.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x, alpha=alpha)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss

In [None]:
# VAE hyperparameters. NOTE: this reassigns `epochs` (the config cell set it to 5).
latent_dim = 100
alpha = 0.1  # weight on the KL term in compute_loss
epochs = 100

n_input_features = X_train.shape[1]
model = BVAE(latent_dim, n_input_features)

In [None]:
# Train the VAE. After every epoch, report the wall-clock time of the training pass
# and the mean loss over the validation batches.
for epoch in range(1, epochs + 1):
  start_time = time.time()
  for train_x in train_dataset:
    compute_apply_gradients(model, train_x, optimizer, alpha=alpha)
  elapsed = time.time() - start_time

  val_loss = tf.keras.metrics.Mean()
  for test_x in test_dataset:
    val_loss(compute_loss(model, test_x, alpha=alpha))
  print(f'epoch: {epoch}, val loss: {val_loss.result()}, train time: {elapsed:.1f}s')


In [None]:
# Embed the training and validation cells with the trained encoder.
# Feed the min-max scaled matrices (X_train / X_val): that is what the VAE was trained
# on, whereas train_adata.X / val_adata.X are the *unscaled* scaler inputs.
# Cast to float32 to match the Keras layers' dtype.
mean, logvar = model.encode(X_train.astype(np.float32))
train_embeddings = model.reparameterize(mean, logvar).numpy()

mean, logvar = model.encode(X_val.astype(np.float32))
val_embeddings = model.reparameterize(mean, logvar).numpy()

In [None]:
from sklearn.preprocessing import OrdinalEncoder
from sklearn.ensemble import RandomForestClassifier

# Fit the label encoder on every cell type in adata, not only the training split:
# the split gives a 1-cell type int(1 * .8) == 0 training cells, and transform()
# on the validation labels would then raise on the unseen category.
encoder = OrdinalEncoder()
encoder.fit(np.asarray(adata.obs[cell_type_key]).reshape(-1, 1))
y_train = encoder.transform(np.asarray(train_adata.obs[cell_type_key]).reshape(-1, 1)).flatten()
y_val = encoder.transform(np.asarray(val_adata.obs[cell_type_key]).reshape(-1, 1)).flatten()

In [None]:
# Fixed random_state so the forest's bootstrap/feature sampling is reproducible on re-run.
clf = RandomForestClassifier(random_state=0)

In [None]:
%%time
# Random forest on the VAE embeddings of the training cells.
clf.fit(X=train_embeddings, y=y_train)

In [None]:
# Mean accuracy on the cells the forest was fit on.
clf.score(X=train_embeddings, y=y_train)

In [None]:
# Mean accuracy on the held-out validation cells.
clf.score(X=val_embeddings, y=y_val)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Ordinal-encoded cell-type predictions for the validation cells.
preds = clf.predict(X=val_embeddings)

In [None]:
# Confusion matrix as integer percentages of each true class (rows = true, cols = predicted).
# Pass `labels` explicitly so rows/columns line up with encoder.categories_[0], which
# the heatmap uses as tick labels, even when a class is absent from both y_val and
# preds. Rows with no cells are left at 0 instead of dividing by zero.
class_labels = np.arange(len(encoder.categories_[0]), dtype=y_val.dtype)
c_mat = confusion_matrix(y_val, preds, labels=class_labels)
row_totals = c_mat.sum(axis=1, keepdims=True)
c_mat = np.divide(c_mat, row_totals, out=np.zeros(c_mat.shape, dtype=float),
                  where=row_totals != 0)
c_mat = (c_mat * 100).astype(np.int32)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Heatmap of the row-normalized validation confusion matrix (cells = % of the true class).
fig, ax = plt.subplots(figsize=(10, 9))
sns.heatmap(c_mat, xticklabels=encoder.categories_[0], yticklabels=encoder.categories_[0],
            cmap='Blues', annot=True, fmt='d', ax=ax)
ax.set_xlabel('predicted')
ax.set_ylabel('true')  # rows of c_mat are the true labels
ax.set_title('validation confusion matrix (% of true class)')
fig.tight_layout()
# fig.savefig('br_confusion_matrix.png')

In [None]:
# Embed every cell. adata.X is the unscaled matrix, so apply the same min-max
# scaler the VAE's training data went through (fit on the training split), then
# cast to float32 to match the Keras layers.
mean, logvar = model.encode(scaler.transform(adata.X).astype(np.float32))
cell_embeddings = model.reparameterize(mean, logvar).numpy()
cell_embeddings

In [None]:
# Attach the per-cell VAE embeddings to the AnnData object.
embedding_key = 'cell_embeddings'
adata.obsm[embedding_key] = cell_embeddings
adata

In [None]:
# Map the classifier's ordinal predictions for every cell back to cell-type names.
categories = encoder.categories_[0]
predicted_cell_types = [categories[int(code)] for code in clf.predict(cell_embeddings)]
adata.obs['predicted_cell_type'] = predicted_cell_types

In [None]:
import umap

In [None]:
# UMAP of the VAE embeddings (separate from scanpy's expression UMAP in obsm['X_umap']).
umap_reducer = umap.UMAP()
adata.obsm['embedding_umap'] = umap_reducer.fit_transform(adata.obsm['cell_embeddings'])
adata

In [None]:
# Expression-space UMAP (from the PCA/neighbors graph) of the validation cells,
# colored by true label, predicted label and sample. Uses cell_type_key rather than
# a hard-coded column name so it follows the config cell.
sc.pl.umap(adata[val_idxs], color=[cell_type_key, 'predicted_cell_type', 'sample_id'],
           frameon=False, ncols=1)

In [None]:
# Copy the two VAE-UMAP coordinates into obs columns so sc.pl.scatter can plot them.
umap_coords = adata.obsm['embedding_umap']
adata.obs['embedding_umap1'], adata.obs['embedding_umap2'] = umap_coords[:, 0], umap_coords[:, 1]

In [None]:
# VAE-embedding UMAP of the validation cells, colored by the true cell type
# (column taken from cell_type_key for consistency with the config cell).
sc.pl.scatter(adata[val_idxs], x='embedding_umap1', y='embedding_umap2', color=cell_type_key,
              frameon=False)

In [None]:
# VAE-embedding UMAP of the validation cells, colored by the classifier's prediction.
val_view = adata[val_idxs]
sc.pl.scatter(val_view, color='predicted_cell_type',
              x='embedding_umap1', y='embedding_umap2', frameon=False)

In [None]:
# VAE-embedding UMAP of the validation cells, colored by sample (checks for batch effects).
val_view = adata[val_idxs]
sc.pl.scatter(val_view, color='sample_id',
              x='embedding_umap1', y='embedding_umap2', frameon=False)