In [None]:
from time import time
import math, os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from model import Autoencoder
import numpy as np
import pandas as pd
import collections
from sklearn import metrics
import scanpy as sc
import matplotlib.pyplot as plt


In [None]:
# Load the expression matrix and transpose it.
# NOTE(review): presumably the file is stored genes x cells so that rows become cells after .T — confirm against the data.
adata = sc.read_h5ad('data/gene_sorted_filtered_matrix.h5ad').T
# Barcode and gene tables are read without a header row; column 0 of each supplies the obs/var names later on.
barcodes = pd.read_csv('data/barcodes_filtered.tsv', header=None, sep='\t')
genes = pd.read_csv('data/genes.tsv', header=None, sep='\t')
# Ground-truth label table, read with its header row (the "NAME" and "New_cellType" columns are used downstream).
ground_truth_labels = pd.read_csv('data/ground_truth_labels.tsv', sep='\t')

In [None]:
# following scdeepcluster here, encoding labels to ints and attaching to the anndata object

adata.obs_names = barcodes[0].values
adata.var_names = genes[0].values

# Index the label table by cell name. The guard keeps this cell re-runnable:
# after the first run "NAME" is already the index and is no longer a column.
if "NAME" in ground_truth_labels.columns:
    ground_truth_labels = ground_truth_labels.set_index("NAME")
elif ground_truth_labels.index.name != "NAME":
    raise KeyError("ground_truth_labels has no 'NAME' column or index")

# Map every barcode to its cell type and encode the types as integer codes.
# Barcodes that are missing from the label table map to NaN and get code -1.
y = pd.Categorical(adata.obs_names.map(ground_truth_labels["New_cellType"])).codes

# Surface unlabeled cells instead of letting the -1 code pass silently.
n_unlabeled = int((y == -1).sum())
if n_unlabeled:
    print(f"Warning: {n_unlabeled} cells have no ground truth label (encoded as -1)")

adata.obs['Group'] = y

In [None]:
# standard filtering from scanpy workflow, this is also what our baseline used, we could consider tweaking this though

sc.pp.filter_cells(adata, min_genes=100)
sc.pp.filter_genes(adata, min_cells=3)

# filter_cells removes rows from adata in place, but y was computed before filtering.
# Re-derive y from the surviving cells so the labels stay aligned (same length and order)
# with adata for model.fit, the metrics and the plots further down.
y = adata.obs['Group'].to_numpy()
assert len(y) == adata.n_obs, "label vector is out of sync with adata after filtering"

print(adata.shape)
print(adata.n_vars)

In [None]:
# ZINB loss uses size factors and raw X values
# after saving those, normalize counts

# Per-cell total counts as a flat 1-D array.
# np.asarray(...).ravel() handles every return type of X.sum(axis=1):
# an np.matrix (scipy sparse matrix) or a plain ndarray (scipy sparse array / dense X).
# The previous `.A1` attribute only exists on np.matrix.
adata.obs['n_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
# Size factor = a cell's total counts relative to the median cell.
adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)

# Densify only if X is still a sparse container (dense arrays have no .toarray()).
if hasattr(adata.X, 'toarray'):
    adata.X = adata.X.toarray()

# Keep a copy of the un-normalized counts in .raw before normalizing X in place.
adata.raw = adata.copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [None]:
# setting model args, defaults from scdeepcluster
# hyperparameter tuning on these?

input_dim = adata.n_vars      # one input unit per gene kept after filtering
encoder_layers = [256, 64]
z_dim = 32
decoder_layers = [64, 256]

# Use the GPU only when one is actually available, so the notebook also runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [None]:
# build model

# NOTE(review): the device is pinned to 'cpu' here even though the config cell defines a
# `device` variable; confirm whether the model should follow that setting instead.
model = Autoencoder(
    input_dim=adata.n_vars,
    z_dim=z_dim,
    encoder_layers=encoder_layers,
    decoder_layers=decoder_layers,
    device='cpu',
)
# print() already uses str(model) to render the object
print(model)

In [None]:
# load pretrained weights if they exist, otherwise do pretraining step


if os.path.isfile('AE_weights.pth.tar'):
    print("Loading pretrained model weights")
    # The model is built on the CPU, so remap all checkpoint storages to the CPU.
    # Without map_location, a checkpoint saved from GPU tensors fails to load on a
    # machine that has no CUDA device.
    checkpoint = torch.load('AE_weights.pth.tar', map_location='cpu')
    model.load_state_dict(checkpoint['ae_state_dict'])
else:
    # No checkpoint on disk: pretrain on the normalized X, with the raw counts and
    # per-cell size factors passed alongside.
    model.pretrain(X=adata.X, X_raw=adata.raw.X, size_factor=adata.obs.size_factors)

In [None]:
####

# Everything after this point is exploratory: how the model behaves when n_clusters is known vs. estimated, how the centroids are initialized, etc.
# So it's kind of messy.

In [None]:
# Before using the autoencoder's clustering layer to make predictions, check what the
# Leiden algorithm does on the pretrained latent space.
#
# Plan:
#   - latent representation -> AnnData -> knn -> leiden (resolution tweaked to 14 clusters) -> plot
#   - input matrix          -> pca -> knn -> leiden
# The idea is to compare a linear and a non-linear embedding while still using Leiden
# for the actual clustering.
#
# baseline:     pca->knn->leiden
# ae:        latent->knn->leiden

# Encode the normalized expression matrix, then bring the result back as a NumPy array.
latent_tensor = model.encodeBatch(torch.tensor(adata.X, dtype=torch.float64))
pretrain_latent = latent_tensor.cpu().numpy()

# Wrap the latent matrix in its own AnnData so the scanpy tools can run on it.
adata_latent = sc.AnnData(pretrain_latent)

In [None]:
# Autoencoder branch (latent->knn->leiden): neighbor graph on the latent AnnData,
# Leiden clustering on that graph, then a UMAP layout used by the plots below.
sc.pp.neighbors(adata_latent)
sc.tl.leiden(adata_latent, flavor="igraph", n_iterations=2, resolution=0.3)
sc.tl.umap(adata_latent)

In [None]:
# Baseline branch (pca->knn->leiden): same steps as the latent branch, but run on the
# normalized expression AnnData after a PCA; results are stored on adata in place.
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.leiden(adata, flavor="igraph", n_iterations=2, resolution=0.8)
sc.tl.umap(adata)

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(18, 6))

# One entry per panel, left to right: (AnnData to draw, obs column used for colour, panel title)
panels = [
    (adata_latent, "leiden", "Latent Cluster Predictions"),   # leiden clustering on latent representation
    (adata, "leiden", "PCA-based Cluster Predictions"),       # leiden clustering on PCs (same as baseline)
    (adata, "Group", "Ground Truth Clusters"),                # ground truth labels
]
for axis, (data, key, title) in zip(ax, panels):
    sc.pl.umap(data, color=key, ax=axis, title=title, show=False)

plt.tight_layout()
plt.show()

In [None]:
# Derive n_clusters and cluster_centers from the Leiden clustering of the latent space.
# Both are consumed by the clustering phase of the autoencoder below.

# Leiden labels as integer codes, one per cell
y_pred_init = np.asarray(adata_latent.obs['leiden'], dtype=int)

# Shared positional index so the latent features and the labels line up row by row
row_ids = np.arange(0, adata_latent.n_obs)
features = pd.DataFrame(adata_latent.X, index=row_ids)
Group = pd.Series(y_pred_init, index=row_ids, name="Group")

# Attach the label column, then average the latent coordinates within each cluster
Mergefeature = features.join(Group)
cluster_centers = Mergefeature.groupby("Group").mean().to_numpy()

# One centroid row per cluster
n_clusters = len(cluster_centers)
print('Estimated number of clusters: ', n_clusters)

In [None]:
# run clustering phase using n_clusters, cluster_centers, and y_pred_init from above cell
# in scdeepcluster, this is how the model is run when ground truth labels are not provided (therefore n_clusters has to be estimated using above cell)
# this copies line 145 on run_scdeepcluster.py
# NOTE(review): y (ground truth) is still passed to fit here; presumably it is only used for evaluation inside fit — confirm against model.py.
# Only the first of the four returned values (the predicted labels) is kept.

y_pred, _, _, _ = model.fit(X=adata.X, X_raw=adata.raw.X, size_factor=adata.obs.size_factors, n_clusters=n_clusters, init_centroid=cluster_centers, 
            y_pred_init=y_pred_init, y=y, num_epochs=300)



# run clustering phase using n_clusters which is known from provided ground truth labels
# cluster_centers and y_pred_init are found with kmeans - model.py line 136
# this copies line 130 on run_scdeepcluster.py
# NOTE(review): the commented-out call below is identical to the active call above, which contradicts this description.
# Presumably it should set n_clusters from the ground truth labels and leave init_centroid / y_pred_init unset
# so that the kmeans path in model.py is used — confirm against model.py before enabling it.
# 
# y_pred, _, _, _ = model.fit(X=adata.X, X_raw=adata.raw.X, size_factor=adata.obs.size_factors, n_clusters=n_clusters, init_centroid=cluster_centers, 
#            y_pred_init=y_pred_init, y=y, num_epochs=300)            
            

In [None]:
# final metrics

# Adjusted mutual information between ground truth and predicted labels.
# The value was previously computed with normalized_mutual_info_score (NMI) while
# being named and printed as AMI; compute the metric that is actually reported.
ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
# Adjusted Rand index between ground truth and predicted labels.
ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
print('Evaluating cells: AMI= %.4f, ARI= %.4f' % (ami, ari))

In [None]:
# plot clustering phase results

# Store both label vectors as strings on adata.obs so they can be used as colour keys.
adata.obs['y_pred'] = y_pred.astype(str)
adata.obs['Group'] = y.astype(str)

fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Left: clustering phase predictions. Right: ground truth labels.
panels = [
    ("y_pred", "Cluster Predictions"),
    ("Group", "Ground Truth Clusters"),
]
for axis, (key, title) in zip(ax, panels):
    sc.pl.umap(adata, color=key, ax=axis, title=title, show=False)

plt.tight_layout()
plt.show()