In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import random
from collections import Counter
from importlib import reload
from pathlib import Path
import json
import time

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import umap
import scipy

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

# Clear Keras' global backend state so re-running this notebook starts from a clean session.
tf.keras.backend.clear_session()

In [None]:
DATA_DIR = '/data/single_cell_classification'
MODEL_DIR = '/models'

# Seed every RNG this notebook draws from so Restart & Run All is repeatable:
# Python `random`, NumPy's global state (np.random.shuffle below; presumably also used by
# pollock sampling and UMAP when no random_state is passed — verify), and TensorFlow
# (weight init, reparameterization noise, tf.data shuffling).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Looking at known markers to try to disentangle malignant cells

#### tumor datasets

In [None]:
# Pick ONE tumor dataset. Previously each dataset had its own cell, so Run All read all
# seven .h5ad files in turn and silently kept only the last one.
# run name -> (h5ad path relative to DATA_DIR/tumor, label of the malignant cell type)
TUMOR_DATASETS = {
    'HTAN_breast_v5': ('BR/raw/houxiang_brca/merged.h5ad', 'BR_Malignant'),
    'ccrcc_v1': ('CCRCC/yige/adata.h5ad', 'Malignant proximal tubule'),
    'pdac_v2': ('PDAC/pdac.h5ad', 'Ductal'),
    'myeloma_v1': ('myeloma/processed.h5ad', 'Plasma'),
    'hnscc_v1': ('HNSC/processed.h5ad', 'HNSC_Malignant/Epithelial'),
    'cesc_v1': ('CESC/cesc.h5ad', 'Malignant_Epithelial'),
    'gbm_v1': ('gbm/gbm.h5ad', 'Tumor'),
}

# Default is the dataset a full Run All of the old per-dataset cells ended up using.
run_name = 'gbm_v1'

relative_fp, malignant_cell_type = TUMOR_DATASETS[run_name]
adata = sc.read_h5ad(os.path.join(DATA_DIR, 'tumor', relative_fp))

In [None]:
# run_name = 'sc_cesc'

# expression_fp = os.path.join(DATA_DIR, 'tumor', 'CESC', 'raw', 'cesc_yize_v2',
#                             'Assigned_CESC_9_processed_cluster_review_final_gene_expression_format.tsv')
# label_fp = os.path.join(DATA_DIR, 'tumor', 'CESC', 'raw', 'cesc_yize_v2',
#                             'Assigned_CESC_9_processed_cluster_review_final_cell_metadata_format.tsv')

# model_save_dir = os.path.join(MODEL_DIR, run_name)

# sample_column = 'Genes'
# sep='\t'
# cell_type_key = 'cell_type'

# expression_df = pd.read_csv(expression_fp, sep=sep)
# expression_df

In [None]:
# expression_df = expression_df.set_index('Genes')
# expression_df = expression_df.transpose()
# expression_df

In [None]:
# run_name = 'sc_hnsc'

# expression_fp = os.path.join(DATA_DIR, 'tumor', 'HNSC', 'raw', 'hnsc_yize',
#                             'Assigned_WUHN_15_processed_cluster_review_gene_expression_format_2.tsv')
# label_fp = os.path.join(DATA_DIR, 'tumor', 'HNSC', 'raw', 'hnsc_yize',
#                             'Assigned_WUHN_15_processed_cluster_review_cell_metadata_format_2.tsv')

# model_save_dir = os.path.join(MODEL_DIR, run_name)

# sample_column = 'Genes'
# sep='\t'
# cell_type_key = 'cell_type'

In [None]:
# import h5py

In [None]:
# f = h5py.File(expression_fp)

In [None]:
# f.keys()

In [None]:
# df = pd.read_hdf(expression_fp.replace('.tsv', '.h5'), 'df')
# df

In [None]:
# df

In [None]:
# label_df = pd.read_csv(
#     label_fp,
#     sep=sep
#     )
# label_df = label_df.set_index('cell_id')
# label_df = label_df.loc[expression_df.index]
# label_df

In [None]:
# adata = anndata.AnnData(X=expression_df.values, obs=label_df)
# adata.obs.index = expression_df.index
# adata.var.index = expression_df.columns
# adata

In [None]:
# adata.write_h5ad('/data/single_cell_classification/tumor/CESC/cesc.h5ad')

In [None]:
# adata.obs['cell_type'] = [x if 'Malignant' not in x else 'Malignant_Epithelial' for x in adata.obs['cell_type']]

In [None]:
# Cell-type labels in the loaded dataset; malignant_cell_type must be one of them for the
# tumor subset taken further down.
set(adata.obs['cell_type'])

In [None]:
metadata = {}

## set up figure directory structure: outputs live under DATA_DIR/outputs/<run_name>
## (derived from DATA_DIR instead of repeating its literal path)
run_root = os.path.join(DATA_DIR, 'outputs', run_name)
figure_root = os.path.join(run_root, 'figures')
Path(figure_root).mkdir(parents=True, exist_ok=True)

###### save a processed adata object with all the genes for later

In [None]:
## save a version of processed adata to use later that has all the genes
# Independent copy of the loaded AnnData with every gene. The next cells log1p/scale it,
# and the marker cell later checks marker availability against adata_full.var.index.
adata_full = adata.copy()
# sc.pp.filter_cells(adata_full, min_genes=200)
# sc.pp.filter_genes(adata_full, min_cells=3)

# mito_genes = adata_full.var_names.str.startswith('MT-')
# # for each cell compute fraction of counts in mito genes vs. all genes
# # the `.A1` is only necessary as X is sparse (to transform to a dense array after summing)
# adata_full.obs['percent_mito'] = np.sum(
#     adata_full[:, mito_genes].X, axis=1) / np.sum(adata_full.X, axis=1)
# # add the total counts per cell as observations-annotation to adata
# adata_full.obs['n_counts'] = adata_full.X.sum(axis=1)

# sc.pl.scatter(adata_full, x='n_counts', y='percent_mito')
# sc.pl.scatter(adata_full, x='n_counts', y='n_genes')

In [None]:
# Value range of adata_full.X before the log1p / scale steps below.
np.min(adata_full.X), np.max(adata_full.X)

In [None]:
# adata = adata[adata_full.obs.percent_mito < 0.2, :]
# sc.pp.normalize_total(adata_full, target_sum=1e6)
# log1p in place, snapshot that unscaled state as .raw, then z-score each gene without clipping
sc.pp.log1p(adata_full, copy=False)
adata_full.raw = adata_full
sc.pp.scale(adata_full, zero_center=True, max_value=None)
adata_full

In [None]:
# .raw holds the log1p (pre-scaling) values set in the previous cell; check their range.
np.min(adata_full.raw.X), np.max(adata_full.raw.X)

#### stem dataset

In [None]:
# Stem-cell reference dataset (ERP016000); a later cell notes it should already be log-transformed.
stem_adata = sc.read_h5ad('/data/stemness/ERP016000/merged.h5ad')

In [None]:
# Summary of the stem AnnData (shape and annotation keys).
stem_adata

In [None]:
## keep only genes measured in both datasets so their feature spaces stay in sync
genes = set(adata.var.index) & set(stem_adata.var.index)
len(genes)

In [None]:
## Subset both objects to the same, identically ordered gene list. .copy() turns the
## slices into real AnnData objects, so the later in-place steps (log1p, .raw, scale)
## do not modify views of the unfiltered data (which triggers implicit actualization).
shared_genes = sorted(genes)
adata = adata[:, shared_genes].copy()
stem_adata = stem_adata[:, shared_genes].copy()
adata, stem_adata

In [None]:
## double check the normalization space for the stem data:
## ERP016000 should already be log-transformed, so only .raw + scale are applied to it below
np.min(stem_adata.X), np.max(stem_adata.X)

In [None]:
## only needs to be run if you need to filter stem data

# sc.pp.filter_cells(stem_adata, min_genes=200)
# sc.pp.filter_genes(stem_adata, min_cells=3)

# mito_genes = stem_adata.var_names.str.startswith('MT-')
# # for each cell compute fraction of counts in mito genes vs. all genes
# # the `.A1` is only necessary as X is sparse (to transform to a dense array after summing)
# stem_adata.obs['percent_mito'] = np.sum(
#     stem_adata[:, mito_genes].X, axis=1) / np.sum(stem_adata.X, axis=1)
# # add the total counts per cell as observations-annotation to adata
# stem_adata.obs['n_counts'] = stem_adata.X.sum(axis=1)

# sc.pl.scatter(stem_adata, x='n_counts', y='percent_mito')
# sc.pl.scatter(stem_adata, x='n_counts', y='n_genes')

In [None]:
## only run if you need to filter tumor cells

# sc.pp.filter_cells(adata, min_genes=200)
# sc.pp.filter_genes(adata, min_cells=3)

# mito_genes = adata.var_names.str.startswith('MT-')
# # for each cell compute fraction of counts in mito genes vs. all genes
# # the `.A1` is only necessary as X is sparse (to transform to a dense array after summing)
# adata.obs['percent_mito'] = np.sum(
#     adata[:, mito_genes].X, axis=1) / np.sum(adata.X, axis=1)
# # add the total counts per cell as observations-annotation to adata
# adata.obs['n_counts'] = adata.X.sum(axis=1)

# sc.pl.scatter(adata, x='n_counts', y='percent_mito')
# sc.pl.scatter(adata, x='n_counts', y='n_genes')

In [None]:
# adata = adata[adata.obs.percent_mito < 0.2, :]
# adata

In [None]:
# sc.pp.normalize_total(adata, target_sum=1e6)
# log1p in place, then freeze these unscaled values as .raw before the scaling cell
sc.pp.log1p(adata, copy=False)
adata.raw = adata
adata

In [None]:
## check normalization space: .raw holds the log1p values assigned in the previous cell
np.min(adata.raw.X), np.max(adata.raw.X)

In [None]:
## not going to filter out genes for now
# sc.pp.highly_variable_genes(adata, min_mean=0.0, max_mean=10., min_disp=0.25)
# sc.pl.highly_variable_genes(adata)

In [None]:
# np.count_nonzero(adata.var.highly_variable)

In [None]:
# adata = adata[:, adata.var.highly_variable]
# z-score every gene; max_value=None leaves scaled values unclipped
sc.pp.scale(adata, zero_center=True, max_value=None)
adata

In [None]:
# stem data is expected to be log-transformed already: snapshot it as .raw, then z-score
stem_adata.raw = stem_adata
sc.pp.scale(stem_adata, zero_center=True, max_value=None)
stem_adata

In [None]:
# PCA -> kNN graph on the first 30 PCs -> UMAP for the tumor dataset
N_NEIGHBORS, N_PCS = 10, 30
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata, n_neighbors=N_NEIGHBORS, n_pcs=N_PCS)
sc.tl.umap(adata)

In [None]:
# PCA -> kNN graph on the first 30 PCs -> UMAP for the stem dataset
N_NEIGHBORS, N_PCS = 10, 30
sc.tl.pca(stem_adata, svd_solver='arpack')
sc.pp.neighbors(stem_adata, n_neighbors=N_NEIGHBORS, n_pcs=N_PCS)
sc.tl.umap(stem_adata)

In [None]:
## add gene count: number of genes with non-zero log1p expression in each cell.
## .raw.X may be a scipy sparse matrix (the marker-status cell below checks for that), and
## np.count_nonzero(..., axis=1) does not accept sparse input, so handle both layouts.
from scipy import sparse as _sparse


def _nonzero_per_row(X):
    """Return a 1-D array with the count of non-zero entries in each row of X (dense or sparse)."""
    if _sparse.issparse(X):
        return np.asarray((X != 0).sum(axis=1)).ravel()
    return np.count_nonzero(X, axis=1).flatten()


adata.obs['gene_count'] = _nonzero_per_row(adata.raw.X)
stem_adata.obs['gene_count'] = _nonzero_per_row(stem_adata.raw.X)

In [None]:
"""
SCA1: ATXN1
CD29: ITGB1
OCT4: POU5F1
"""
all_markers = ['CD24', 'ITGB1', 'EPCAM', 'CD44', 'ATXN1']
## wnt signaling
all_markers += ['AXIN2', 'PTN', 'WIF1']
# more traditional stem markers
all_markers += ['SOX2', 'POU5F1', 'GATA6', 'NANOG']

print([m for m in all_markers if m not in genes])
markers = [m for m in all_markers if m in genes]
all_markers = [m for m in all_markers if m in adata_full.var.index]

In [None]:
# route scanpy's save= output into <figure_root>/vanilla_umap
figdir = os.path.join(figure_root, 'vanilla_umap')
os.makedirs(figdir, exist_ok=True)
sc.settings.figdir = figdir

In [None]:
# stem UMAP colored by time point, gene count, and each available marker
colors = ['day', 'gene_count', *markers]
sc.pl.umap(stem_adata, color=colors, ncols=1, color_map='Reds',
           save='_stem.pdf')

In [None]:
# obs columns available as plot color keys
set(adata.obs.columns)

In [None]:
# cell-type labels in the gene-filtered, preprocessed tumor AnnData
set(adata.obs['cell_type'])

In [None]:
# sample ids present; sample_id_overall below is derived from these
set(adata.obs['sample_id'])

In [None]:
# adata.obs['sample_id'] = list(adata.obs['orig.ident'])
# adata.obs['sample_id_overall'] = [x for x in adata.obs['sample']]

In [None]:
## add macro sample id: drop the last '-'-separated field of sample_id ('A-B-C' -> 'A-B').
## IDs with no '-' are kept whole instead of collapsing to '' (which would merge every
## such sample into one bogus group); this also covers datasets that need no stripping.
adata.obs['sample_id_overall'] = [str(x).rsplit('-', 1)[0] for x in adata.obs['sample_id']]

In [None]:
# UMAP of all tumor-dataset cells by sample, cell type, gene count and markers
colors = ['sample_id', 'cell_type', 'sample_id_overall', 'gene_count', *markers]
sc.pl.umap(adata, color=colors, ncols=1, color_map='Reds',
           save='_tumor_all_cells.pdf')

In [None]:
# keep only the malignant cells, as an independent AnnData
cell_type_key = 'cell_type'
is_malignant = adata.obs[cell_type_key] == malignant_cell_type
tumor_adata = adata[is_malignant].copy()
tumor_adata

In [None]:
# UMAP restricted to malignant cells
colors = ['sample_id_overall', 'sample_id', 'cell_type', 'gene_count', *markers]
sc.pl.umap(tumor_adata, color=colors, ncols=1, color_map='Reds',
           save='_tumor_only.pdf')

In [None]:
# sc.pl.dotplot(tumor_adata, markers, groupby='sample_id')

In [None]:
# choose 'all' or a single sample id to focus the tumor analysis on
sample_id = 'all'
# sample_id = 'TWCE-HT062B1-S1PAA1A1Z1B1'
# sample_id = 'TWCE-HT065B1-S1H7A2A1Z1B1'
colors = ['sample_id', 'cell_type', 'gene_count', *markers]
if sample_id == 'all':
    p_adata = tumor_adata
else:
    p_adata = tumor_adata[tumor_adata.obs['sample_id'] == sample_id]
sc.pl.umap(p_adata, color=colors, ncols=1, color_map='Reds',
           save=f'_tumor_{sample_id}.pdf')

In [None]:
# narrow tumor_adata to the chosen sample (no-op when sample_id == 'all')
tumor_adata = (tumor_adata if sample_id == 'all'
               else tumor_adata[tumor_adata.obs['sample_id'] == sample_id].copy())
tumor_adata

In [None]:
# record which tumor samples / cell counts feed this run
sample_ids = sorted(set(tumor_adata.obs['sample_id']))
metadata.update({
    'tumor_sample_ids': sample_ids,
    'tumor_num_samples': len(sample_ids),
    'tumor_cell_counts': Counter(tumor_adata.obs['cell_type']),
    'tumor_sample_cell_counts': Counter(tumor_adata.obs['sample_id']),
})

In [None]:
# Per-cell boolean flags: is NANOG / POU5F1 / SOX2 expressed above STATUS_THRESHOLD in the
# log1p values stored in .raw (i.e. before scaling)?
from scipy import sparse as _sparse

STATUS_THRESHOLD = .1
status_genes = {'nanog_status': 'NANOG', 'pou5f1_status': 'POU5F1', 'sox2_status': 'SOX2'}
colors = list(status_genes)
p_adata = tumor_adata.copy()

for status_key, gene in status_genes.items():
    # Select the column by raw's own var_names; .raw can carry a different gene set than .var.
    gene_mask = tumor_adata.raw.var_names == gene
    n_hits = int(gene_mask.sum())
    if n_hits != 1:
        raise KeyError(f'{gene} must appear exactly once in tumor_adata.raw.var_names, found {n_hits}')
    expressed = tumor_adata.raw[:, gene_mask].X > STATUS_THRESHOLD
    if _sparse.issparse(expressed):
        expressed = expressed.toarray()
    p_adata.obs[status_key] = np.asarray(expressed).flatten()

sc.pl.umap(p_adata,
           color=colors, ncols=1, color_map='Reds', use_raw=True,
           save=f'_tumor_{sample_id}_marker_status.pdf')

In [None]:
# copy the status flags back onto tumor_adata and tally them
for status_key in ('nanog_status', 'pou5f1_status', 'sox2_status'):
    tumor_adata.obs[status_key] = p_adata.obs[status_key]

metadata['tumor_marker_counts'] = {
    gene: Counter(tumor_adata.obs[f'{gene}_status'])
    for gene in ('nanog', 'pou5f1', 'sox2')
}
metadata['tumor_marker_counts']

In [None]:
import pollock

# For each status flag, keep the obs index of the first object returned by
# pollock.balancedish_training_generator (presumably a roughly class-balanced sample of 100/class).
nanog_cells, pou5f1_cells, sox2_cells = (
    pollock.balancedish_training_generator(tumor_adata, status_key, 100)[0].obs.index
    for status_key in ('nanog_status', 'pou5f1_status', 'sox2_status')
)

nanog_cells, pou5f1_cells, sox2_cells

In [None]:
# persist which cells were picked for each marker
metadata['tumor_marker_training_cells'] = dict(
    nanog=list(nanog_cells),
    pou5f1=list(pou5f1_cells),
    sox2=list(sox2_cells),
)

In [None]:
# union of all marker-selected cells, sorted, plotted on the tumor UMAP
marker_cell_ids = np.concatenate([nanog_cells, pou5f1_cells, sox2_cells])
idxs = np.asarray(sorted(set(marker_cell_ids)))
sc.pl.umap(tumor_adata[idxs], color=markers)

In [None]:
# Training tumor cells = the marker-selected cells in `idxs` plus up to `n` random others;
# all remaining tumor cells are held out for validation.
n = 200

remaining_idxs = np.asarray(list(tumor_adata.obs.index))
np.random.shuffle(remaining_idxs)

# Take the first n shuffled ids and drop those already in idxs, so this adds at most n
# (possibly fewer) cells. A set gives O(1) membership instead of scanning idxs each time;
# dtype=str keeps an empty result string-typed so the concatenate below still works.
marker_idx_set = set(idxs)
remaining_idxs = np.asarray([x for x in remaining_idxs[:n] if x not in marker_idx_set], dtype=str)

train_idxs = np.concatenate((idxs, remaining_idxs))

# Index.isin replaces rebuilding set(train_idxs) for every cell (quadratic) and keeps
# the original obs order of the validation cells.
val_idxs = np.asarray(tumor_adata.obs.index[~tumor_adata.obs.index.isin(train_idxs)])

# Slice, then copy: independent AnnData objects rather than views of a throwaway full
# copy, so the later obs['dataset'] assignments don't modify views.
tumor_train_adata = tumor_adata[train_idxs].copy()
tumor_val_adata = tumor_adata[val_idxs].copy()


tumor_train_adata, tumor_val_adata

In [None]:
# number of distinct cell ids in the tumor train and validation splits
len(set(tumor_train_adata.obs.index)), len(set(tumor_val_adata.obs.index))

In [None]:
# stem train/val split, balanced (per pollock's helper) over the 'day' labels
stem_split = pollock.balancedish_training_generator(stem_adata, 'day', 100)
stem_train_adata, stem_val_adata = stem_split

stem_train_adata, stem_val_adata

In [None]:
# tag each split with its source dataset, then stack tumor + stem per split
for split_adata, label in ((stem_train_adata, 'stem'), (stem_val_adata, 'stem'),
                           (tumor_train_adata, 'tumor'), (tumor_val_adata, 'tumor')):
    split_adata.obs['dataset'] = label

combined_train_adata = tumor_train_adata.concatenate(stem_train_adata)
combined_val_adata = tumor_val_adata.concatenate(stem_val_adata)
combined_train_adata, combined_val_adata

In [None]:
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler

# Row masks for each source dataset inside the combined train / val objects.
train_is_tumor = (combined_train_adata.obs['dataset'] == 'tumor').values
train_is_stem = (combined_train_adata.obs['dataset'] == 'stem').values
val_is_tumor = (combined_val_adata.obs['dataset'] == 'tumor').values
val_is_stem = (combined_val_adata.obs['dataset'] == 'stem').values

# Fit [0, 1] min-max scaling on the training cells only; reuse it for validation.
scaler = MinMaxScaler(feature_range=(0, 1))
X_train = scaler.fit_transform(np.copy(combined_train_adata.X))
X_val = scaler.transform(np.copy(combined_val_adata.X))

# The scaler applies the same per-gene affine map to every row independently, so each
# dataset's rows can be sliced straight out of the scaled matrices.
X_train_tumor = X_train[train_is_tumor]
X_train_stem = X_train[train_is_stem]
X_val_tumor = X_val[val_is_tumor]
X_val_stem = X_val[val_is_stem]

X_train.shape, X_val.shape

In [None]:
TRAIN_BUF = 10000
BATCH_SIZE = 64
VAL_MONITOR_CELLS = 1000  # per-dataset validation loss is tracked on at most this many cells


def _batched(array):
    """Unshuffled tf.data pipeline over the rows of `array`, batched by BATCH_SIZE."""
    return tf.data.Dataset.from_tensor_slices(array).batch(BATCH_SIZE)


# only the full training set is shuffled; the monitoring sets keep their row order
train_dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
val_dataset = _batched(X_val)
train_dataset_tumor = _batched(X_train_tumor)
train_dataset_stem = _batched(X_train_stem)
val_dataset_tumor = _batched(X_val_tumor[:VAL_MONITOR_CELLS])
val_dataset_stem = _batched(X_val_stem[:VAL_MONITOR_CELLS])

In [None]:
class BVAE(tf.keras.Model):
  """Dense variational autoencoder used to embed tumor and stem cells in one latent space.

  The encoder (`inference_net`) maps an `input_size`-long vector to 2 * latent_dim
  outputs, split into a mean and a log-variance. The decoder (`generative_net`) maps a
  latent vector back to `input_size` logits.

  NOTE: Keras Dropout layers only drop units when called with training=True.
  `encode`/`decode` accept an optional `training` flag for that. When it is omitted,
  as in the loss functions of this notebook, no dropout is applied.
  """

  def __init__(self, latent_dim, input_size):
    super().__init__()
    self.latent_dim = latent_dim
    self.input_size = input_size
    self.inference_net = tf.keras.Sequential(
      [
          tf.keras.layers.InputLayer(input_shape=(input_size,)),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(latent_dim + latent_dim),
      ]
    )

    self.generative_net = tf.keras.Sequential(
        [
          tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(800, activation='relu'),
          tf.keras.layers.Dropout(.2),
          tf.keras.layers.Dense(input_size),
        ]
    )

  @tf.function
  def sample(self, eps=None):
    """Decode latent codes `eps` (default: 100 draws from N(0, I)) into sigmoid outputs."""
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)

  def encode(self, x, training=None):
    """Return (mean, logvar), each of width latent_dim, for the rows of `x`."""
    mean, logvar = tf.split(self.inference_net(x, training=training), num_or_size_splits=2, axis=1)
    return mean, logvar

  def reparameterize(self, mean, logvar):
    """Draw z = mean + exp(logvar / 2) * eps with eps ~ N(0, I)."""
    # tf.shape gives the runtime shape; the static mean.shape can have an unknown (None)
    # batch dimension inside a traced tf.function, which tf.random.normal rejects.
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean

  def decode(self, z, apply_sigmoid=False, training=None):
    """Map latent codes to logits, or to sigmoid(logits) when apply_sigmoid is True."""
    logits = self.generative_net(z, training=training)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs

    return logits

In [None]:
optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
  """Log-density of `sample` under a diagonal Gaussian N(mean, exp(logvar)), summed over `raxis`.

  NOTE(review): not called anywhere in the visible part of this notebook.
  """
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)

def _row_squared_error(x, x_hat):
  """Per-row (per-cell) sum of squared differences between x and x_hat."""
  return tf.reduce_sum(tf.square(x - x_hat), axis=1)

@tf.function
def compute_loss(model, x, alpha=0.00005, boost_idxs=None, boost_value=2.):
  """Batch-mean VAE loss: 0.5 * squared reconstruction error + alpha * KL term.

  Args:
    model: object exposing encode / reparameterize / decode (e.g. BVAE).
    x: batch of inputs, one row per cell.
    alpha: weight of the KL divergence to N(0, I).
    boost_idxs: optional column indices whose squared error is added a second time,
      scaled by `boost_value` (so those columns weigh 1 + boost_value).
    boost_value: multiplier for the boosted-column error; must not be None when
      `boost_idxs` is given.

  Returns:
    Scalar tensor: mean over the batch of the per-cell loss.
  """
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)
  x_logit = model.decode(z)

  # closed-form KL(N(mean, exp(logvar)) || N(0, I)) per cell
  kl_loss = .5 * tf.reduce_sum(tf.exp(logvar) + tf.square(mean) - 1. - logvar, axis=1)

  reconstruction = _row_squared_error(x, x_logit)
  if boost_idxs is not None:
    reconstruction = reconstruction + boost_value * _row_squared_error(
        tf.gather(x, boost_idxs, axis=1), tf.gather(x_logit, boost_idxs, axis=1))
  reconstruction_loss = .5 * reconstruction

  return tf.reduce_mean(reconstruction_loss + alpha * kl_loss)

@tf.function
def compute_apply_gradients(model, x, optimizer, alpha=.00005, boost_idxs=None, boost_value=2.):
  """Take one optimizer step on `compute_loss` for batch `x` and return that batch's loss.

  boost_value defaults to 2. to match compute_loss. A None default made any call that
  passed boost_idxs without boost_value fail inside compute_loss.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x, alpha=alpha, boost_idxs=boost_idxs, boost_value=boost_value)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss
    

In [None]:
epochs = 250
latent_dim = 100
alpha = 5.  # weight of the KL term in compute_loss (its function default is 0.00005)

to_boost = ['SOX2', 'NANOG', 'POU5F1']
# Ascending column positions of the boosted genes. np.flatnonzero always returns an
# integer array; np.asarray([]) would be float64 if none of these genes survived the
# gene intersection, and tf.gather rejects float indices.
boost_idxs = np.flatnonzero(combined_train_adata.var.index.isin(to_boost))
boost_value = 1.

model = BVAE(latent_dim, X_train.shape[1])

In [None]:
# record the training hyperparameters alongside the run metadata
metadata['training'] = dict(
    epochs=epochs,
    latent_dim=latent_dim,
    alpha=alpha,
    boost_value=boost_value,
    boost_genes=to_boost,
)

In [None]:
def _mean_loss(dataset, idxs, value):
    """Average compute_loss over every batch of `dataset`, returned as a Mean metric."""
    metric = tf.keras.metrics.Mean()
    for batch in dataset:
        metric(compute_loss(model, batch, alpha=alpha, boost_idxs=idxs, boost_value=value))
    return metric


for epoch in range(1, epochs + 1):
    start_time = time.time()
    for train_x in train_dataset:
        compute_apply_gradients(model, train_x, optimizer, alpha=alpha,
                                boost_idxs=boost_idxs, boost_value=boost_value)
    end_time = time.time()

    # Evaluated every epoch, in this fixed order (each compute_loss draws reparameterization noise).
    train_tumor_loss = _mean_loss(train_dataset_tumor, boost_idxs, boost_value)
    train_stem_loss = _mean_loss(train_dataset_stem, boost_idxs, boost_value)
    val_tumor_loss = _mean_loss(val_dataset_tumor, boost_idxs, boost_value)
    nonboost_val_tumor_loss = _mean_loss(val_dataset_tumor, None, None)
    val_stem_loss = _mean_loss(val_dataset_stem, boost_idxs, boost_value)

    print(f'epoch: {epoch}, train tumor loss: {train_tumor_loss.result()}, '
          f'train stem loss: {train_stem_loss.result()}, val tumor loss: {val_tumor_loss.result()}, '
          f'val stem loss: {val_stem_loss.result()}, non-boosted val tumor loss: {nonboost_val_tumor_loss.result()}')

In [None]:
# Embed every cell (train + val, tumor + stem). The encoder was trained on MinMax-scaled
# inputs (X_train = scaler.fit_transform(...)), so the same fitted scaler must be applied
# here; feeding the z-scored combined_adata.X directly is out of the training distribution.
combined_adata = combined_train_adata.concatenate(combined_val_adata)
mean, logvar = model.encode(scaler.transform(combined_adata.X))
# NOTE(review): reparameterize draws a random sample around `mean`; using `mean` itself
# would give deterministic embeddings.
cell_embeddings = model.reparameterize(mean, logvar).numpy()
cell_embeddings.shape

In [None]:
# store the raw VAE embeddings and a 2-D UMAP of them
combined_adata.obsm['cell_embeddings'] = cell_embeddings
embedding_umap = umap.UMAP().fit_transform(combined_adata.obsm['cell_embeddings'])
combined_adata.obsm['cell_embedding_umap'] = embedding_umap
combined_adata.obs['cell_embedding_umap1'] = embedding_umap[:, 0]
combined_adata.obs['cell_embedding_umap2'] = embedding_umap[:, 1]

combined_adata

In [None]:
# route scanpy's save= output into <figure_root>/raw_bvae_cell_embeddings
figdir = os.path.join(figure_root, 'raw_bvae_cell_embeddings')
os.makedirs(figdir, exist_ok=True)
sc.settings.figdir = figdir

In [None]:
# raw-embedding UMAP of all cells, one figure per color variable
variables = ['dataset', 'day', 'gene_count', *markers]
for var in variables:
    sc.pl.scatter(
        combined_adata,
        x='cell_embedding_umap1', y='cell_embedding_umap2', color=var,
        frameon=False, color_map='Reds', save=f'_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# raw-embedding UMAP restricted to tumor cells
variables = ['dataset', 'day', 'gene_count', 'sample_id_overall', 'sample_id', *markers]
is_tumor_cell = combined_adata.obs['dataset'] == 'tumor'
for var in variables:
    sc.pl.scatter(
        combined_adata[is_tumor_cell],
        x='cell_embedding_umap1', y='cell_embedding_umap2', color=var,
        frameon=False, color_map='Reds',
        save=f'_tumor_only_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# raw-embedding UMAP restricted to stem cells
variables = ['dataset', 'day', 'gene_count', *markers]
is_stem_cell = combined_adata.obs['dataset'] == 'stem'
for var in variables:
    sc.pl.scatter(
        combined_adata[is_stem_cell],
        x='cell_embedding_umap1', y='cell_embedding_umap2', color=var,
        frameon=False, color_map='Reds',
        save=f'_stem_only_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# raw-embedding UMAP, one set of figures per macro tumor sample
variables = ['dataset', 'sample_id_overall', 'gene_count', 'sample_id', *markers]
for sample in sorted(set(tumor_adata.obs['sample_id_overall'])):
    in_sample = combined_adata.obs['sample_id_overall'] == sample
    for var in variables:
        sc.pl.scatter(
            combined_adata[in_sample],
            x='cell_embedding_umap1', y='cell_embedding_umap2', color=var,
            frameon=False, color_map='Reds',
            save=f'_tumor_only_umap_{sample}_{var}.pdf',
        )
        plt.show()

###### batch correct samples

In [None]:
def _sample_centroid(sid):
    """Mean raw VAE embedding of the combined_adata cells whose sample_id equals `sid`."""
    in_sample = combined_adata.obs['sample_id'] == sid
    return np.mean(combined_adata[in_sample].obsm['cell_embeddings'], axis=0)


sample_id_to_avg = {sid: _sample_centroid(sid)
                    for sid in sorted(set(tumor_adata.obs['sample_id']))}

# each tumor sample's centroid gets shifted onto the mean of all sample centroids
overall_mean = np.mean(np.asarray(list(sample_id_to_avg.values())), axis=0)
sample_id_to_delta = {sid: overall_mean - avg for sid, avg in sample_id_to_avg.items()}

In [None]:
def batch_correction_operation(latent, sample_id):
    """Return a tumor cell's embedding shifted by its sample's delta from sample_id_to_delta."""
    shift = sample_id_to_delta[sample_id]
    return latent + shift


def _corrected_embedding_rows():
    """Yield each cell's embedding: batch-corrected for tumor cells, unchanged otherwise."""
    obs = combined_adata.obs
    for latent, dataset, sid in zip(combined_adata.obsm['cell_embeddings'],
                                    obs['dataset'], obs['sample_id']):
        if dataset == 'tumor':
            yield batch_correction_operation(latent, sid)
        else:
            yield latent


corrected_embeddings = np.asarray(list(_corrected_embedding_rows()))

combined_adata.obsm['corrected_embeddings'] = corrected_embeddings
corrected_umap = umap.UMAP().fit_transform(combined_adata.obsm['corrected_embeddings'])
combined_adata.obsm['corrected_embeddings_umap'] = corrected_umap
combined_adata.obs['corrected_embeddings_umap1'] = corrected_umap[:, 0]
combined_adata.obs['corrected_embeddings_umap2'] = corrected_umap[:, 1]

combined_adata

In [None]:
# route scanpy's save= output into <figure_root>/batch_corrected_bvae_cell_embeddings
figdir = os.path.join(figure_root, 'batch_corrected_bvae_cell_embeddings')
os.makedirs(figdir, exist_ok=True)
sc.settings.figdir = figdir

In [None]:
# batch-corrected UMAP of all cells
variables = ['dataset', 'day', 'gene_count', *markers]
for var in variables:
    sc.pl.scatter(
        combined_adata,
        x='corrected_embeddings_umap1', y='corrected_embeddings_umap2', color=var,
        frameon=False, color_map='Reds', save=f'_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# combined_adata.uns.pop('sample_id_colors')

In [None]:
# batch-corrected UMAP restricted to tumor cells
variables = ['dataset', 'day', 'gene_count', 'sample_id_overall', 'sample_id', *markers]
is_tumor_cell = combined_adata.obs['dataset'] == 'tumor'
for var in variables:
    sc.pl.scatter(
        combined_adata[is_tumor_cell],
        x='corrected_embeddings_umap1', y='corrected_embeddings_umap2', color=var,
        frameon=False, color_map='Reds',
        save=f'_tumor_only_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# batch-corrected UMAP, one set of figures per macro tumor sample
variables = ['dataset', 'gene_count', 'sample_id_overall', 'sample_id', *markers]
for sample in sorted(set(tumor_adata.obs['sample_id_overall'])):
    in_sample = combined_adata.obs['sample_id_overall'] == sample
    for var in variables:
        sc.pl.scatter(
            combined_adata[in_sample],
            x='corrected_embeddings_umap1', y='corrected_embeddings_umap2', color=var,
            frameon=False, color_map='Reds',
            save=f'_tumor_only_umap_{sample}_{var}.pdf',
        )
        plt.show()

In [None]:
# from sklearn.decomposition import PCA
# combined_adata.obsm['cell_embedding_pca'] = PCA().fit_transform(combined_adata.obsm['cell_embeddings'])
# combined_adata.obsm['cell_embedding_pca'][combined_adata.obsm['cell_embedding_pca'] > 50] = 50
# combined_adata.obsm['cell_embedding_pca'][combined_adata.obsm['cell_embedding_pca'] < -50] = -50
# combined_adata.obs['cell_embedding_pca1'] = combined_adata.obsm['cell_embedding_pca'][:, 0]
# combined_adata.obs['cell_embedding_pca2'] = combined_adata.obsm['cell_embedding_pca'][:, 3]

# combined_adata

In [None]:
# variables = ['dataset', 'day']
# variables += markers
# for var in variables:
# #     if var in combined_adata.var.index or var in combined_adata.obs.columns:
#     sc.pl.scatter(combined_adata, x='cell_embedding_pca1', y='cell_embedding_pca2', color=var,
#                  frameon=False, color_map='Reds', )

#     plt.show()

###### transform batch-corrected embeddings

In [None]:
# Dataset centroids, taken in the space that `operation` (next cell) actually shifts: the
# batch-corrected embeddings. Uncorrected 'cell_embeddings' would leave the tumor cells
# off the midpoint whenever tumor samples differ in size, because batch correction moves
# the tumor centroid to the unweighted mean of the per-sample centroids. For stem cells
# the two spaces are identical.
def _dataset_centroid(label, key='corrected_embeddings'):
    """Mean of obsm[key] over the combined_adata cells whose obs['dataset'] equals `label`."""
    in_dataset = combined_adata.obs['dataset'] == label
    return np.mean(combined_adata[in_dataset].obsm[key], axis=0)


avg_tumor = _dataset_centroid('tumor')
avg_stem = _dataset_centroid('stem')

# midpoint between the two centroids; each dataset is shifted onto it
mean = np.mean(np.asarray([avg_stem, avg_tumor]), axis=0)

stem_delta = mean - avg_stem
tumor_delta = mean - avg_tumor

delta = avg_stem - avg_tumor

In [None]:
def operation(latent, dataset):
    """Shift one cell's embedding by its dataset's centroid delta.

    Parameters
    ----------
    latent : array-like
        Embedding vector of a single cell.
    dataset : str
        Dataset label of that cell; only 'stem' and 'tumor' are supported.

    Returns
    -------
    array-like
        ``latent + stem_delta`` for 'stem' cells and ``latent + tumor_delta`` for
        'tumor' cells. Both deltas are module-level globals set in an earlier cell.

    Raises
    ------
    ValueError
        For any other dataset label. Falling through and returning ``None`` would
        make the caller's ``np.asarray`` over all cells fail obscurely or build an
        object array.
    """
    if dataset == 'stem':
        return latent + stem_delta
    if dataset == 'tumor':
        return latent + tumor_delta
    # A 'normal' branch (latent + normal_delta) used to exist as commented-out code;
    # add it back explicitly if a third dataset is ever combined.
    raise ValueError(f"operation: no shift defined for dataset {dataset!r}")

# Seed for UMAP so the transformed-embedding layout, and every figure saved from it,
# is reproducible across kernel restarts. (umap-learn disables some parallelism when
# random_state is set; that is the price of determinism.)
UMAP_RANDOM_STATE = 42

# Shift every cell's batch-corrected embedding by its dataset-specific delta.
transformed_embeddings = np.asarray([
    operation(latent, dataset)
    for latent, dataset in zip(combined_adata.obsm['corrected_embeddings'],
                               combined_adata.obs['dataset'])
])
# The shift is per-row addition, so the shape must be unchanged.
assert transformed_embeddings.shape == np.shape(combined_adata.obsm['corrected_embeddings'])

combined_adata.obsm['transformed_embeddings'] = transformed_embeddings
combined_adata.obsm['transformed_embeddings_umap'] = umap.UMAP(
    random_state=UMAP_RANDOM_STATE).fit_transform(combined_adata.obsm['transformed_embeddings'])
# Expose the two UMAP axes as obs columns for sc.pl.scatter.
combined_adata.obs['transformed_embeddings_umap1'] = combined_adata.obsm['transformed_embeddings_umap'][:, 0]
combined_adata.obs['transformed_embeddings_umap2'] = combined_adata.obsm['transformed_embeddings_umap'][:, 1]

combined_adata

In [None]:
# Route scanpy's `save=` output for this section into its own (auto-created) folder.
figdir = os.path.join(figure_root, 'batch_corrected_and_transformed_bvae_cell_embeddings')
os.makedirs(figdir, exist_ok=True)
sc.settings.figdir = figdir

In [None]:
# Transformed-embedding UMAP of all cells, coloured by each variable.
variables = ['dataset', 'day', 'gene_count']
variables.extend(markers)
for var in variables:
    sc.pl.scatter(
        combined_adata,
        x='transformed_embeddings_umap1',
        y='transformed_embeddings_umap2',
        color=var,
        color_map='Reds',
        frameon=False,
        save=f'_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# Transformed-embedding UMAP restricted to tumor cells, coloured by each variable.
variables = ['dataset', 'gene_count', 'sample_id_overall', 'sample_id']
variables.extend(markers)
is_tumor = combined_adata.obs['dataset'] == 'tumor'
for var in variables:
    sc.pl.scatter(
        combined_adata[is_tumor],
        x='transformed_embeddings_umap1',
        y='transformed_embeddings_umap2',
        color=var,
        color_map='Reds',
        frameon=False,
        save=f'_tumor_only_umap_{var}.pdf',
    )
    plt.show()

In [None]:
# Transformed-embedding UMAP per tumor sample, coloured by each variable.
variables = ['dataset', 'gene_count', 'sample_id_overall', 'sample_id']
variables.extend(markers)
for sample in sorted(set(tumor_adata.obs['sample_id_overall'])):
    in_sample = combined_adata.obs['sample_id_overall'] == sample
    for var in variables:
        sc.pl.scatter(
            combined_adata[in_sample],
            x='transformed_embeddings_umap1',
            y='transformed_embeddings_umap2',
            color=var,
            color_map='Reds',
            frameon=False,
            save=f'_umap_{sample}_{var}.pdf',
        )
        plt.show()

In [None]:
# Persist run metadata. The `with` block closes (and flushes) the file even if
# serialization raises; the bare open() previously leaked the handle.
with open(os.path.join(run_root, 'metadata.json'), 'w') as metadata_file:
    json.dump(metadata, metadata_file)

### DEG

In [None]:
# Tumor cells only, with the last two '-'-separated fields stripped from each cell
# name (presumably so names match adata_full's index, which is used below).
combined_subset = combined_adata[combined_adata.obs['dataset'] == 'tumor'].copy()
stripped_names = []
for cell_name in combined_subset.obs.index:
    stripped_names.append('-'.join(cell_name.split('-')[:-2]))
combined_subset.obs.index = stripped_names
combined_subset

In [None]:
# Number of distinct cell-name prefixes (first two '-'-separated fields) among tumor cells.
# The slice belongs on the split result: `'-'[:2]` is just '-', so the old expression
# rejoined every name unchanged and counted full names instead of prefixes.
len({'-'.join(name.split('-')[:2]) for name in combined_subset.obs.index})

In [None]:
# Cell names of the combined object (same Index as .obs.index).
combined_adata.obs_names

In [None]:
# Stripped tumor cell names (same Index as .obs.index).
combined_subset.obs_names

In [None]:
# Full-gene-set AnnData restricted to the tumor cells kept in combined_subset, in
# combined_subset's row order (later cells copy obsm/obs over positionally).
# Subsetting first and then copying materialises only the selected cells, instead of
# copying all of adata_full and then taking a view. It also returns a real AnnData,
# so the obsm/obs assignments below don't trigger implicit view-to-copy conversion.
tumor_full = adata_full[combined_subset.obs.index].copy()
tumor_full

In [None]:
# Copy each embedding and its 2-D UMAP from combined_subset onto tumor_full, plus
# the UMAP axes as obs columns '<umap key>1' / '<umap key>2'.
# Rows line up because tumor_full was subset with combined_subset.obs.index.
embedding_transfers = [
    # (source embedding, source UMAP, target embedding key, target UMAP key)
    ('cell_embeddings', 'cell_embedding_umap', 'cell_embeddings', 'cell_embedding_umap'),
    ('corrected_embeddings', 'corrected_embeddings_umap', 'cell_embeddings_bc', 'cell_embedding_bc_umap'),
    ('transformed_embeddings', 'transformed_embeddings_umap', 'cell_embeddings_t', 'cell_embedding_t_umap'),
]
for src_embedding, src_umap, dst_embedding, dst_umap in embedding_transfers:
    tumor_full.obsm[dst_embedding] = np.copy(combined_subset.obsm[src_embedding])
    tumor_full.obsm[dst_umap] = np.copy(combined_subset.obsm[src_umap])
    umap_coords = tumor_full.obsm[dst_umap]
    tumor_full.obs[f'{dst_umap}1'] = list(umap_coords[:, 0].flatten())
    tumor_full.obs[f'{dst_umap}2'] = list(umap_coords[:, 1].flatten())

tumor_full

In [None]:
# Carry sample annotations over positionally; rows align because tumor_full was
# subset by combined_subset.obs.index.
tumor_full.obs['sample_id_overall'] = list(combined_subset.obs['sample_id_overall'])
tumor_full.obs['sample_id'] = list(combined_subset.obs['sample_id'])

# Genes with non-zero raw expression per cell. np.count_nonzero(..., axis=1) only
# handles dense arrays. `(X != 0).sum(axis=1)` works for dense arrays and for
# scipy.sparse matrices (which return an (n, 1) matrix, hence asarray().ravel()).
raw_X = tumor_full.raw.X
tumor_full.obs['gene_count'] = np.asarray((raw_X != 0).sum(axis=1)).ravel()

In [None]:
# CytoTRACE scores, one row per cell. Cell names in the file use '.' where the
# AnnData names use '-', so convert them.
# NOTE(review): this path is BR-specific regardless of which run_name is active.
cytotrace_path = '/data/single_cell_classification/tumor/BR/cytotrace_results.txt'
cytotrace_results = pd.read_csv(cytotrace_path, sep='\t')
cytotrace_results.columns = ['cytotrace']
cytotrace_results.index = cytotrace_results.index.map(lambda cell: cell.replace('.', '-'))

cytotrace_results

In [None]:
# Lookup: cell name -> CytoTRACE score.
cell_to_cytotrace = dict(zip(cytotrace_results.index, cytotrace_results['cytotrace']))

In [None]:
# Attach scores to tumor_full; cells missing from the CytoTRACE file get NaN.
tumor_full.obs['cytotrace'] = [cell_to_cytotrace.get(cell, np.nan) for cell in tumor_full.obs_names]
tumor_full

In [None]:
# Send scanpy's `save=` figures for the full-gene-set plots to their own folder.
figdir = os.path.join(figure_root, 'analysis', 'full_gene_set_plots')
os.makedirs(figdir, exist_ok=True)
sc.settings.figdir = figdir

In [None]:
"""
SCA1: ATXN1
CD29: ITGB1
OCT4: POU5F1
"""
marker_map = {
    'general': ['sample_id_overall', 'sample_id', 'gene_count', 'cytotrace'],
    'csc': ['CD24', 'ITGB1', 'EPCAM', 'CD44', 'ATXN1'],
    'wnt_signaling': ['AXIN2', 'PTN', 'WIF1'],
    'hedgehog_pathway': ['SHH', 'SMO', 'PTCH1'],
    'notch_signaling': ['NOTCH1', 'NOTCH2', 'JAG2', 'DLL1'],
    'quiecient stem cell': ['FGD5', 'HOXB5', 'MKI67'],
    'canonical_stem': ['SOX2', 'POU5F1', 'GATA6', 'NANOG']
}

all_markers = [v for vs in marker_map.values() for v in vs]

print([m for m in all_markers if m not in tumor_full.var.index])

marker_map = {k:[v for v in vs if v in tumor_full.var.index]
              for k, vs in marker_map.items()}
marker_map['general'] = ['sample_id_overall', 'sample_id', 'gene_count', 'cytotrace']

In [None]:
# tumor_full.uns.keys()
# tumor_full.uns.pop('')

In [None]:
# Inspect the attached CytoTRACE column.
tumor_full.obs.loc[:, 'cytotrace']

In [None]:
# UMAP of the cell embeddings, coloured by every variable of every marker group.
for marker_type, markers in marker_map.items():
    for var in markers:
        out_suffix = f'_umap_{marker_type}_{var}.pdf'
        sc.pl.scatter(
            tumor_full,
            x='cell_embedding_umap1',
            y='cell_embedding_umap2',
            color=var,
            color_map='Reds',
            frameon=False,
            save=out_suffix,
        )
        plt.show()

In [None]:
# UMAP of the transformed embeddings, coloured by every variable of every marker group.
for marker_type, markers in marker_map.items():
    for var in markers:
        out_suffix = f'_umap_t_{marker_type}_{var}.pdf'
        sc.pl.scatter(
            tumor_full,
            x='cell_embedding_t_umap1',
            y='cell_embedding_t_umap2',
            color=var,
            color_map='Reds',
            frameon=False,
            save=out_suffix,
        )
        plt.show()

In [None]:
import seaborn as sns
# Relationship between detected-gene count and CytoTRACE score per cell.
sns.scatterplot(data=tumor_full.obs, x='gene_count', y='cytotrace')

In [None]:
# For each marker (gene expression from .raw, or an obs column), scatter it against
# CytoTRACE and save one PDF per marker under <figure_root>/analysis.
for marker_type, markers in marker_map.items():
    for var in markers:
        print(var)
        p = tumor_full.obs.copy()
        is_gene = var in tumor_full.var.index
        if is_gene:
            p[var] = tumor_full.raw[:, var].X.flatten()
        out_path = os.path.join(figure_root, 'analysis', f'_cytotrace_scatter_{var}.pdf')
        sns.scatterplot(data=p, x=var, y='cytotrace')
        plt.savefig(out_path)
        plt.show()

In [None]:
# Same as the CytoTRACE scatters above, but against the per-cell detected-gene count.
for marker_type, markers in marker_map.items():
    for var in markers:
        print(var)
        p = tumor_full.obs.copy()
        is_gene = var in tumor_full.var.index
        if is_gene:
            p[var] = tumor_full.raw[:, var].X.flatten()
        out_path = os.path.join(figure_root, 'analysis', f'_gene_count_scatter_{var}.pdf')
        sns.scatterplot(data=p, x=var, y='gene_count')
        plt.savefig(out_path)
        plt.show()

In [None]:
import seaborn as sns

In [None]:
# Independent copy for re-clustering; tumor_full itself is left untouched.
duplicate = tumor_full.copy()
duplicate

In [None]:
# !conda install -y -c conda-forge leidenalg

In [None]:
# PCA (30 comps) -> 15-NN graph -> UMAP -> Leiden clusters, all stored on `duplicate`.
sc.pp.pca(duplicate, n_comps=30)
sc.pp.neighbors(duplicate, n_neighbors=15)
sc.tl.umap(duplicate)
sc.tl.leiden(duplicate)

In [None]:
# Expression-based UMAP coloured by the general obs annotations.
general_columns = marker_map['general']
sc.pl.umap(duplicate, color=general_columns)

In [None]:
# Expression-based Leiden clusters drawn on the cell-embedding UMAP coordinates.
sc.pl.scatter(
    duplicate,
    x='cell_embedding_umap1',
    y='cell_embedding_umap2',
    color='leiden',
    frameon=False,
)

In [None]:
# Stash the expression-based UMAP and Leiden labels; the next cells recompute both
# from the transformed embeddings and overwrite 'X_umap' / 'leiden'.
original_umap = np.copy(duplicate.obsm['X_umap'])
duplicate.obsm['X_umap_original'] = original_umap
duplicate.obs['leiden_original'] = list(duplicate.obs['leiden'])

In [None]:
# sc.pp.pca(duplicate, n_comps=30)
# Rebuild the 15-NN graph and UMAP on the transformed cell embeddings instead of PCA.
sc.pp.neighbors(duplicate, use_rep='cell_embeddings_t', n_neighbors=15)
sc.tl.umap(duplicate)

In [None]:
# Leiden on the embedding-based graph (overwrites obs['leiden']).
sc.tl.leiden(duplicate, resolution=1.0)

In [None]:
# Send scanpy's `save=` figures for the DEG section to their own folder.
figdir = os.path.join(figure_root, 'analysis', 'deg')
os.makedirs(figdir, exist_ok=True)
sc.settings.figdir = figdir

In [None]:
# Embedding-based UMAP coloured by the general obs annotations.
general_columns = marker_map['general']
sc.pl.umap(duplicate, color=general_columns, color_map='Reds')

In [None]:
# Embedding-based UMAP coloured by the canonical stemness genes present in the data.
stem_genes = marker_map['canonical_stem']
sc.pl.umap(duplicate, color=stem_genes, color_map='Reds')

In [None]:
# Scanpy UMAP of the new Leiden clusters, saved to the DEG figure folder.
sc.pl.umap(duplicate, color='leiden', save='_scanpy_leiden.pdf', color_map='Reds')

In [None]:
# Leiden clusters drawn on the transformed-embedding UMAP coordinates.
sc.pl.scatter(
    duplicate,
    x='cell_embedding_t_umap1',
    y='cell_embedding_t_umap2',
    color='leiden',
    frameon=False,
    save='_umap_leiden.pdf',
)

In [None]:
# Manual annotation: tumor sub-cluster name -> Leiden cluster ids from the
# resolution=1.0 run above. These ids are tied to that exact clustering.
name_map = {
    't1': ['5'],
    't2': ['3'],
    't3': ['1', '0', '4'],
    't4': ['2'],
    't5': ['6']
}

# Inverse lookup: Leiden id -> sub-cluster name.
r_name_map = {v: k for k, vs in name_map.items() for v in vs}

# Each Leiden id may be assigned to only one name.
assert len(r_name_map) == sum(len(vs) for vs in name_map.values()), \
    'a Leiden id is assigned to more than one cluster name in name_map'
# Re-clustering can produce ids the hand-written map doesn't know about. Fail with
# a clear message instead of a bare KeyError inside the comprehension.
unmapped_leiden = sorted(set(duplicate.obs['leiden']) - set(r_name_map))
if unmapped_leiden:
    raise ValueError(f'Leiden clusters without a name in name_map: {unmapped_leiden}')

duplicate.obs['cluster'] = [r_name_map[x] for x in duplicate.obs['leiden']]

In [None]:
# Discard any palette previously stored for the 'cluster' annotation (which was just
# rebuilt above), presumably so the next plot assigns colors for the current names.
# Pass a default: on a fresh kernel nothing has been plotted by 'cluster' yet, the
# key is absent, and a plain pop('cluster_colors') raises KeyError.
duplicate.uns.pop('cluster_colors', None)
duplicate.uns.keys()

In [None]:
# sc.pl.scatter(duplicate,
#                       x='cell_embedding_umap1', y='cell_embedding_umap2', color='cluster',
#                      frameon=False, save=f'_umap_cluster_name.pdf')

In [None]:
# Named sub-clusters drawn on the transformed-embedding UMAP coordinates.
sc.pl.scatter(
    duplicate,
    x='cell_embedding_t_umap1',
    y='cell_embedding_t_umap2',
    color='cluster',
    frameon=False,
    save='_umap_cluster_name.pdf',
)

In [None]:
# Marker-coloured transformed-embedding UMAPs again, now saved into the DEG folder.
for marker_type, markers in marker_map.items():
    for var in markers:
        out_suffix = f'_umap_{marker_type}_{var}.pdf'
        sc.pl.scatter(
            tumor_full,
            x='cell_embedding_t_umap1',
            y='cell_embedding_t_umap2',
            color=var,
            color_map='Reds',
            frameon=False,
            save=out_suffix,
        )
        plt.show()

In [None]:
# Marker genes per named sub-cluster via logistic regression.
sc.tl.rank_genes_groups(duplicate, 'cluster', method='logreg')

In [None]:
# Top-10 ranked genes for every sub-cluster, groups in sorted name order.
cluster_order = sorted(set(duplicate.obs['cluster']))
sc.pl.rank_genes_groups_dotplot(duplicate, groups=cluster_order, n_genes=10, save='_deg_all.pdf')

In [None]:
# sc.tl.rank_genes_groups(duplicate, groupby='cluster', method='logreg', groups=['t1', 't2'])

In [None]:
# sc.pl.rank_genes_groups_dotplot(duplicate, n_genes=15, save='_deg_t12.pdf')

In [None]:
# Dotplot of the curated gene groups per sub-cluster; passing a dict makes scanpy
# label the genes by group. 'general' holds obs columns, not genes, so it is left out.
gene_groups = {group: genes for group, genes in marker_map.items() if group != 'general'}
sc.pl.dotplot(duplicate, gene_groups, groupby='cluster', dot_max=.2,
              save='_deg_marker_groups.pdf')

In [None]:
# One tracksplot per gene group (skipping the obs-column group 'general').
for marker_group, markers in marker_map.items():
    if marker_group == 'general':
        continue
    sc.pl.tracksplot(duplicate, markers, groupby='cluster',
                     save=f'_trackplot_{marker_group}.pdf')

In [None]:
import seaborn as sns

In [None]:
# Cells per sub-cluster, split by sample; legend placed outside the axes.
ax = sns.countplot(data=duplicate.obs, x='cluster', hue='sample_id_overall')
ax.legend(bbox_to_anchor=(1.0, 1.02))
plt.tight_layout()
plt.savefig(os.path.join(figdir, 'countplot_cluster_by_sample.pdf'))

In [None]:
# Cells per sample, split by sub-cluster; sample labels rotated for readability.
ax = sns.countplot(data=duplicate.obs, x='sample_id_overall', hue='cluster')
plt.xticks(rotation=90)
ax.legend(bbox_to_anchor=(1.0, 1.02))
plt.tight_layout()
plt.savefig(os.path.join(figdir, 'countplot_sample_by_cluster.pdf'))

In [None]:
# Persist the annotated, re-clustered object for this run.
final_h5ad_path = os.path.join(run_root, 'final.h5ad')
duplicate.write_h5ad(final_h5ad_path)

In [None]:
# Sub-clusters t1 and t2 only, as an independent AnnData.
just_stem = duplicate[duplicate.obs['cluster'].isin(['t1', 't2'])].copy()
just_stem

In [None]:
# Drop a 'cluster' dendrogram inherited from `duplicate` (it was presumably computed
# for all sub-clusters, not just t1/t2). The default avoids a KeyError when no
# dendrogram was ever computed, e.g. if the plotting cells above were skipped.
just_stem.uns.pop('dendrogram_cluster', None)

In [None]:
# Small 5-NN graph on the cell embeddings, then a tightly packed UMAP (min_dist=0).
sc.pp.neighbors(just_stem, use_rep='cell_embeddings', n_neighbors=5)
sc.tl.umap(just_stem, min_dist=0)

In [None]:
# t1/t2 UMAP coloured by sample, two stemness genes and sub-cluster, in a 2-column grid.
stem_umap_colors = ['sample_id', 'NANOG', 'SOX2', 'cluster']
sc.pl.umap(just_stem, color=stem_umap_colors, ncols=2, color_map='Reds')

In [None]:
# NOTE(review): identical to the previous cell; candidate for removal.
stem_umap_colors = ['sample_id', 'NANOG', 'SOX2', 'cluster']
sc.pl.umap(just_stem, color=stem_umap_colors, ncols=2, color_map='Reds')

In [None]:
# t1 vs t2 marker genes with scanpy's default ranking method.
sc.tl.rank_genes_groups(just_stem, 'cluster')

In [None]:
# Top-10 ranked genes per group for the t1/t2 comparison.
top_n_genes = 10
sc.pl.rank_genes_groups_dotplot(just_stem, n_genes=top_n_genes)

In [None]:
## save adata
