In [None]:
from importlib import reload
from load_cluster_data import load_cluster_data
import gc

import numpy as np
import torch
import pandas as pd 
import seaborn as sns
import collections

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Keyword arguments (device + dtype) that later cells splat into tensor-building
# helpers and pass to the CAVI routine.
float_type = dict()
float_type["device"] = device
float_type["dtype"] = torch.float  # single precision to save memory

# Hyperparameters handed to the CAVI routine as `hypers`; all three are 1.
# (alpha_prior: karin had 0.65)
hypers = {name: 1. for name in ("eta", "alpha_prior", "pi_prior")}

# Initial number of topics; the "Run LDA" cell further down reassigns K.
K = 10

import plotnine as p9
import scipy.sparse as sp
import matplotlib.pyplot as plt 
import splicing_PCA_utils
from splicing_PCA_utils import make_Y
from nuclear_norm_PCA import sparse_sum

from pca_kmeans_init import pca_kmeans_init
# Reload the module BEFORE star-importing from it. Names bound by
# `from module import *` are plain references to the objects that existed at
# import time; reloading afterwards rebinds the module's attributes but leaves
# those notebook-level names pointing at the stale, pre-edit functions.
import betabinomo_LDA_singlecells_kinit
reload(betabinomo_LDA_singlecells_kinit)
from betabinomo_LDA_singlecells_kinit import *

import plotnine as p9
from plotnine import ggplot, geom_point, aes, stat_smooth, facet_wrap, geom_violin, theme
import plotnine


### Load data

In [None]:
import sys

# Directories that hold project-local modules.
# NOTE(review): absolute, user-specific path — consider making it configurable.
PATH_TO_LEAFLET_REPO = '/gpfs/commons/home/kisaev/Leaflet/src/beta-binomial-mix/'

# Append each directory only once so that re-running this cell does not keep
# growing sys.path with duplicate entries (original order is preserved).
for _module_dir in ('../../utils', PATH_TO_LEAFLET_REPO):
    if _module_dir not in sys.path:
        sys.path.append(_module_dir)

# NOTE(review): the first cell of the notebook already imports
# load_cluster_data before these paths are added; on a fresh kernel that
# earlier import only works if the module is importable some other way.
from load_cluster_data import load_cluster_data


In [None]:
# Folder passed to load_cluster_data as its input_folder argument.
# NOTE(review): absolute cluster path — consider a configurable data directory.
input_files_folder = '/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/TabulaMurisBrain/Brain/train/'

# load_cluster_data returns five objects. Later cells use final_data,
# cell_ids_conversion and junction_ids_conversion as pandas tables; the two
# coo_* objects are not used in the visible cells (presumably sparse
# matrices, judging by their names — TODO confirm).
final_data, coo_counts_sparse, coo_cluster_sparse, cell_ids_conversion, junction_ids_conversion = load_cluster_data(
    input_folder = input_files_folder) 

### Run LDA

In [None]:
# Build the input object that is later handed to calculate_CAVI, forwarding
# the "device" and "dtype" entries of float_type as keyword arguments.
LDA_input=make_torch_data(final_data, **float_type)

In [None]:
# Sanity check: preview the loaded table and count its distinct indices.
print(final_data.head())
print(len(final_data.cell_id_index.unique())) # number of distinct cell_id_index values (cells)
print(len(final_data.junction_id_index.unique())) # number of distinct junction_id_index values (junctions). NOTE(review): an earlier run showed a suspicious 125,611 here; the author believes it is fixed — TODO confirm

In [None]:
# NOTE(review): num_trials is not read by any visible cell; per the author,
# more than one trial currently overflows GPU memory.
num_trials = 1 # can't currently run more than 1 or overflow GPU memory :( 
num_iters = 20 # passed to calculate_CAVI as num_iterations; should become a fed-in argument
# Number of topics; this replaces the K = 10 set in the first cell.
K = 30

# run coordinate ascent VI
print("Number of topics to be learned is: ", K)
# calculate_CAVI returns five values; ALPHA_f, PI_f and GAMMA_f are used by
# later cells, PHI_f is not used in the visible cells.
ALPHA_f, PI_f, GAMMA_f, PHI_f, elbos_all = calculate_CAVI(K, LDA_input, float_type, hypers = hypers, num_iterations = num_iters)
elbos_all = np.array(elbos_all)
# plot ELBO, skipping its first two entries (presumably so the early values
# do not dominate the y-axis — TODO confirm)
plt.plot(elbos_all[2:]); plt.show()

In [None]:
# Per-junction, per-topic probability computed as ALPHA / (ALPHA + PI)
# (presumably the mean of a Beta(ALPHA, PI) — TODO confirm against the model).
juncs_probs = ALPHA_f / (ALPHA_f+PI_f)

# how variable are juncs_probs across cell states/topics?
# The std is taken along axis 1 on the original object and only then copied
# to host memory as a NumPy array: matplotlib cannot consume a tensor that
# lives on the GPU (the theta_f cell further down uses .cpu() for the same
# reason). torch.as_tensor is a no-op for tensors and also accepts arrays.
juncs_probs_std = torch.as_tensor(juncs_probs.std(axis=1)).detach().cpu().numpy()
plt.hist(juncs_probs_std); plt.show()

In [None]:
# Table of junction probabilities with one integer-named column per topic
# (0 .. K-1). The values are detached and copied to host memory first because
# pandas cannot build a frame from a tensor that lives on the GPU.
juncs_probs_df = pd.DataFrame(
    torch.as_tensor(juncs_probs).detach().cpu().numpy(), columns = range(K)
)

# Attach the junction identifiers as two extra columns.
# NOTE(review): this assumes the rows of junction_ids_conversion are in the
# same order as the rows of juncs_probs — TODO confirm.
juncs_probs_df["junction_id_index"] = junction_ids_conversion.junction_id_index.values
juncs_probs_df["junction_id"] = junction_ids_conversion.junction_id.values

In [None]:
def plot_juncObsUsage(junc_index):
    """Plot the observed usage ratio of one junction, split by cell type.

    Prints the row(s) of ``junction_ids_conversion`` whose
    ``junction_id_index`` equals ``junc_index``, then prints a horizontally
    flipped violin-plus-points plot of ``juncratio`` per ``cell_type`` built
    from the matching rows of ``final_data`` and titled with the junction ID.

    Parameters
    ----------
    junc_index
        Value compared against the ``junction_id_index`` columns of
        ``junction_ids_conversion`` and ``final_data``.

    Raises
    ------
    IndexError
        If no row of ``junction_ids_conversion`` matches ``junc_index``.
    """
    # Look the junction up once; reuse the result for printing and the title.
    id_rows = junction_ids_conversion[junction_ids_conversion["junction_id_index"] == junc_index]
    print(id_rows)
    title_text = id_rows.junction_id.values[0]

    # Observations that belong to this junction only.
    observed = final_data[final_data.junction_id_index==junc_index]

    # Violin + raw points per cell type, flipped so cell types run down the y-axis.
    figure = (
        ggplot(observed, aes(x='cell_type', y='juncratio', fill="cell_type"))
        + geom_violin()
        + geom_point()
        + plotnine.labels.ggtitle(title_text)
        + plotnine.coords.coord_flip()
    )
    print(figure)

def plot_juncProbs(junc_index):
    """Plot the per-topic probability of one junction as a point plot.

    Prints the row(s) of ``junction_ids_conversion`` whose
    ``junction_id_index`` equals ``junc_index``, then prints a point plot of
    the junction's value in each of the first ``K`` columns of
    ``juncs_probs_df`` (x = column name, y = value), titled with the
    junction ID like ``plot_juncObsUsage``.

    Parameters
    ----------
    junc_index
        Value compared against the ``junction_id_index`` columns of
        ``junction_ids_conversion`` and ``juncs_probs_df``.

    Raises
    ------
    IndexError
        If no row of ``junction_ids_conversion`` matches ``junc_index``.
    """
    # Look the junction up once; reuse the result for printing and the title.
    id_rows = junction_ids_conversion[junction_ids_conversion["junction_id_index"] == junc_index]
    print(id_rows)
    junc_id = id_rows.junction_id.values[0]

    # Row(s) of the probability table for this junction.
    junc_dat = juncs_probs_df[juncs_probs_df.junction_id_index==junc_index]

    # melt() stacks every column into (variable, value) pairs, column by
    # column. Keeping the first K rows therefore selects the K topic columns
    # only when exactly one row matched and the topic columns come first in
    # juncs_probs_df — both hold for the frame as built in this notebook.
    junc_dat = junc_dat.melt().iloc[0:K]

    # The melted value column mixes numbers with the string junction_id, so
    # cast the retained rows back to float. assign() builds a new frame
    # rather than writing into the iloc slice, which avoids pandas'
    # SettingWithCopyWarning.
    junc_dat = junc_dat.assign(value=junc_dat["value"].astype(float))

    # One point per topic column: x is the column name, y its probability.
    plot = ggplot(junc_dat, aes(x='variable', y='value')) + geom_point() + plotnine.labels.ggtitle(junc_id)
    print(plot)


In [None]:
# Display the first row of juncs_probs_df as a quick sanity check.
juncs_probs_df.iloc[0]

In [None]:
# Set plotnine's module-level figure_size option to (4, 4) so plots are small;
# being a global option, it also applies to plotnine plots drawn afterwards.
plotnine.options.figure_size = (4, 4)
# Observed usage of the junction with junction_id_index 0.
plot_juncObsUsage(0)

In [None]:
# Per-topic probabilities of the junction with junction_id_index 0.
plot_juncProbs(0)

In [None]:
# Normalise GAMMA_f by its sum over dim 1 (kept as a size-1 dim so the
# division broadcasts), so the entries along dim 1 sum to 1.
theta_f = GAMMA_f / GAMMA_f.sum(1,keepdim=True)
# Copy to the CPU before handing the tensor to pandas.
theta_f_plot = pd.DataFrame(theta_f.cpu())
# NOTE(review): the column is called 'cell_id' but it stores the cell_type
# labels. It is attached positionally, which assumes cell_ids_conversion is
# ordered like the rows of theta_f — TODO confirm.
theta_f_plot['cell_id'] = cell_ids_conversion["cell_type"].to_numpy() # are these correct attachments? 
theta_f_plot.head()

In [None]:
# Distribution of column 9 of theta_f_plot, grouped by the 'cell_id' column
# (which holds cell_type labels).
sns.violinplot(data=theta_f_plot, y="cell_id", x=9)

In [None]:
# Distribution of column 0 of theta_f_plot, grouped by the 'cell_id' column
# (which holds cell_type labels).
sns.violinplot(data=theta_f_plot, y="cell_id", x=0)

In [None]:
# theta_f is the input to the PCA below; check its shape first.
theta_f.shape

In [None]:
# Project theta_f onto 10 principal components; later cells plot selected
# pairs of these components.
pca = PCA(n_components=10)
# scikit-learn operates on NumPy arrays, so detach the tensor and copy it to
# host memory first; a tensor living on the GPU cannot be converted implicitly.
pca_result = pca.fit_transform(theta_f.detach().cpu().numpy())

In [None]:
# Plot pca_result for the index pair (0, 1), presumably component 0 against
# component 1 labelled by cell type.
# NOTE(review): neither plot_pcas nor cell_types is defined in the visible
# cells; plot_pcas may come from a star-import, cell_types looks like
# leftover kernel state — TODO confirm it survives Restart & Run All.
plot_pcas(pca_result, cell_types, 0, 1)

In [None]:
# Same plot for the index pair (0, 2).
plot_pcas(pca_result, cell_types, 0, 2)

In [None]:
# Same plot for the index pair (0, 3).
plot_pcas(pca_result, cell_types, 0, 3)

In [None]:
# Same plot for the index pair (0, 4).
plot_pcas(pca_result, cell_types, 0, 4)

In [None]:
# Same plot for the index pair (0, 5).
plot_pcas(pca_result, cell_types, 0, 5)

In [None]:
# Display the projected coordinates returned by pca.fit_transform.
# NOTE(review): this echoes the whole array; a slice or .shape may be enough.
pca_result

In [None]:
# Heatmap of the PCA projection (rows and columns of pca_result as returned
# by pca.fit_transform).
sns.heatmap(pca_result)

In [None]:
# Re-check the shape of theta_f.
theta_f.shape

In [None]:
# Filled in by the next cell: cell type -> colour.
color_mapping = {}

# Distinct cell types in a stable, sorted order. Iterating a bare set of
# strings can yield a different order in every kernel session (string hash
# randomisation), which would silently reshuffle which colour each cell type
# receives. key=str keeps the sort well-defined even if labels have mixed types.
# NOTE(review): cell_types is not defined in the visible cells — TODO confirm
# where it comes from.
unique_cell_types = sorted(set(cell_types), key=str)
num_colors = len(unique_cell_types)

# One 'hsv' palette entry per distinct cell type.
color_palette = sns.color_palette('hsv', num_colors)  # Choose a color palette
color_palette

In [None]:
# Give every cell type the palette entry at the same position
# (color_palette was created with one entry per unique cell type).
for label, rgb in zip(unique_cell_types, color_palette):
    color_mapping[label] = rgb

In [None]:
# Shape of theta_f, checked before drawing it as a heatmap.
theta_f.shape

In [None]:
# Heatmap of theta_f. seaborn converts its input to a NumPy array, which
# fails for a tensor living on the GPU, so detach and copy to host memory first.
sns.heatmap(theta_f.detach().cpu().numpy())