In [None]:
# Imports and global plotting configuration for the notebook.
from importlib import reload
from load_cluster_data import load_cluster_data
from pca_kmeans_init import pca_kmeans_init
# NOTE(review): np, pd, plt and os are used in later cells but are never
# imported explicitly here; presumably this star import supplies them --
# confirm, or import them explicitly.
from betabinomo_mix_singlecells import *
import betabinomo_mix_singlecells
# Re-execute the module so the `betabinomo_mix_singlecells.<name>` calls below
# use its current definition (presumably to pick up edits made during a session).
reload(betabinomo_mix_singlecells)
import torch
import sklearn.manifold 
import plotnine as p9
import time
# NOTE(review): plotnine is imported a second time here (same alias as above).
import plotnine as p9
from plotnine import ggplot, geom_point, aes, stat_smooth, facet_wrap, geom_violin, theme, element_blank, geom_text
import plotnine
from tqdm import tqdm
# Default plotnine figures to a small 4 x 4 size.
plotnine.options.figure_size = (4, 4)
import seaborn as sns
# Use the seaborn "whitegrid" theme for all seaborn / matplotlib figures.
sns.set_theme(style="whitegrid")

### Settings and Load data

In [None]:
# Path to an HDF5 input file (not referenced again in the visible cells).
input_file = '/gpfs/commons/scratch/kisaev/ss_tabulamuris_test/Leaflet/clustered_junctions_noanno.txt_anno_free_50_500000_10_5_0.1_single_cell.h5'

# Folder holding the input data for each tissue / cell-type sample.
input_files_folder = '/gpfs/commons/scratch/kisaev/ss_tabulamuris_test/Leaflet/Marrow/'

# Fix the torch random seed so repeated runs start from the same state.
torch.manual_seed(42)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

MAKE_PCA_TSNE = True

# Device / dtype options; unpacked as keyword arguments when building tensors.
float_type = dict(
    device=device,
    dtype=torch.float,  # save memory
)

# Model hyper-parameters handed to calculate_CAVI.
hypers = dict(
    eta=1.,
    alpha_prior=1.,  # karin had 0.65
    pi_prior=1.,
)

In [None]:
# load_cluster_data yields five objects; unpack them in the order returned.
loaded_cluster_data = load_cluster_data(input_folder=input_files_folder)
(final_data,
 coo_counts_sparse,
 coo_cluster_sparse,
 cell_ids_conversion,
 junction_ids_conversion) = loaded_cluster_data

In [None]:
# N and J are the first and second dimensions of the sparse cluster matrix.
N, J = coo_cluster_sparse.shape[:2]

# Build the torch inputs, forwarding the device / dtype options as keywords.
torch_inputs = betabinomo_mix_singlecells.make_torch_data(final_data, **float_type)
cell_index_tensor, junc_index_tensor, my_data = torch_inputs

In [None]:
num_trials = 5   # should also be an argument that gets fed in
num_iters = 35   # should also be an argument that gets fed in
K = 22

# Re-execute the module before fitting so the call below uses its current code.
reload(betabinomo_mix_singlecells)

# Fit the model num_trials times (no initial labels) and keep every result,
# timing the whole run.
start_time = time.time()
results = []
for trial_index in range(num_trials):
    trial_result = betabinomo_mix_singlecells.calculate_CAVI(
        K, my_data, float_type, hypers,
        init_labels=None, num_iterations=num_iters,
    )
    results.append(trial_result)

elapsed_seconds = time.time() - start_time
print(f"This took {elapsed_seconds} seconds")

##### Open question: why do some trials converge after 20 iterations while others go on for 100 iterations?

### Check how consistently cells get co-assigned together across trials

In [None]:
# Extract PHI (cell-by-cluster assignment probabilities, the 4th element of
# each calculate_CAVI result) from every trial.
all_iters_PHI_f = [result[3] for result in results]

# One DataFrame of hard cluster assignments per trial.
dfs_list = []

for i, PHI_var in enumerate(all_iters_PHI_f):

    # PHI is a torch tensor that may live on the GPU (see `device`); NumPy
    # cannot read CUDA memory, so move it to the host before taking argmax.
    if isinstance(PHI_var, torch.Tensor):
        probability_tensor = PHI_var.detach().cpu().numpy()
    else:
        probability_tensor = np.asarray(PHI_var)

    # Cell ids are simply the row positions 0 .. n_cells - 1.
    cell_ids = np.arange(probability_tensor.shape[0])

    # Hard assignment: the cluster with the highest probability for each cell.
    cluster_ids = probability_tensor.argmax(axis=1)

    df = pd.DataFrame({"cell_id": cell_ids, "cluster_id": cluster_ids})

    # The "iteration" column stores the trial index (name kept because the
    # following cells filter on it).
    df["iteration"] = i
    dfs_list.append(df)

# Stack all trials into one long DataFrame.
concatenated_df = pd.concat(dfs_list, ignore_index=True)

In [None]:
# One N x N co-assignment indicator matrix per trial: entry (a, b) is 1 when
# cells a and b received the same cluster label in that trial and 0 otherwise
# (the diagonal is therefore always 1).
all_iters_results = [None] * num_trials

for trial in range(num_trials):
    clusters = concatenated_df.loc[concatenated_df["iteration"] == trial, ["cell_id", "cluster_id"]]

    # int() also unwraps NumPy scalars and 0-d tensors, so the comparison
    # below always runs on plain integer arrays.
    trial_cell_ids = np.array([int(c) for c in clusters["cell_id"]], dtype=int)
    trial_labels = np.array([int(c) for c in clusters["cluster_id"]], dtype=int)

    # A single broadcast comparison replaces one DataFrame filter per cell.
    same_cluster = trial_labels[:, None] == trial_labels[None, :]

    cell_by_cell_matrix = np.zeros((N, N))
    cell_by_cell_matrix[np.ix_(trial_cell_ids, trial_cell_ids)] = same_cluster

    all_iters_results[trial] = cell_by_cell_matrix

In [None]:
# Enumerate every unordered pair of trials.
import itertools
all_pairs = list(itertools.combinations(range(num_trials), 2))

# Per-pair summary tables, concatenated after the loop.
dfs_list = []

for pair in tqdm(all_pairs):
    first_trial, second_trial = pair

    # Element-wise difference of the two co-assignment matrices: 0 where the
    # trials agree about a cell pair, +1 / -1 where only one trial co-assigns it.
    x = all_iters_results[first_trial] - all_iters_results[second_trial]

    # Tabulate how often each difference value occurs and its share of the total.
    unique, counts = np.unique(x, return_counts=True)
    df = pd.DataFrame({'unique': unique, 'counts': counts})
    df = df.assign(
        percentage=df['counts'] / df['counts'].sum(),
        pair=str(pair),
    )
    dfs_list.append(df)

# Stack the per-pair tables into one long DataFrame.
concatenated_iters_comp = pd.concat(dfs_list, ignore_index=True)

In [None]:
# Keep only the columns needed for the comparison.
concatenated_iters_comp = concatenated_iters_comp[["unique", "pair", "percentage"]]

# Wide format for the heatmap: one row per trial pair, one column per
# difference value. A value that never occurs for a pair has no row in the
# long table and would become NaN after the pivot; its share is 0, and the
# clustering used for the heatmap cannot handle non-finite entries.
concatenated_iters_comp_wide = (
    concatenated_iters_comp
    .pivot(index='pair', columns='unique', values='percentage')
    .fillna(0.0)
)

In [None]:
# Order the comparison rows from the largest to the smallest share.
concatenated_iters_comp = concatenated_iters_comp.sort_values(by=['percentage'], ascending=False)

# Rows with difference 0 are the cell pairs on which two trials agree.
matching_rows = concatenated_iters_comp[concatenated_iters_comp['unique'] == 0]
print(f"The minimum percentage of matching cell pairs across all trials is {matching_rows['percentage'].min().round(2)}")

In [None]:
# Clustered heatmap of the wide comparison table: rows are trial pairs,
# columns are the difference values; every row label is drawn.
clustermap_options = dict(cmap="crest", linewidth=.1, figsize=(4, 6), yticklabels=1)
g = sns.clustermap(concatenated_iters_comp_wide, **clustermap_options)

# Drop the dendrogram drawn above the columns.
g.ax_col_dendrogram.remove()

### Evaluate the learned posteriors

In [None]:
# The last element of each result holds the ELBO trace; its last entry is the
# final ELBO. float() also unwraps 0-d tensors so argmax sees plain numbers.
final_elbos = [float(trial_result[-1][-1]) for trial_result in results]
best = np.argmax(final_elbos)
print(f"The trial with the highest ELBO was {best}")

# Rows whose pair label contains `best` as a whole number. Word boundaries
# stop e.g. trial 1 from also matching a label such as "(10, 11)".
best_res_comp = concatenated_iters_comp[
    concatenated_iters_comp.pair.str.contains(rf"\b{best}\b", regex=True)
]
print(best_res_comp[best_res_comp['unique'] == 0])

# Best-agreeing pair: restrict to the agreement rows (difference 0) and sort
# here, so the answer does not depend on an earlier cell having sorted the frame.
matching_rows = concatenated_iters_comp[concatenated_iters_comp['unique'] == 0]
top_match = matching_rows.sort_values(by='percentage', ascending=False).iloc[0]
print(f"The pair with the highest percentage overlap is {top_match.pair}")
print(f"The highest percentage overlap across pairs was {round(float(top_match.percentage), 3)}")

# Unpack the posterior parameters of the best trial and plot its ELBO trace
# (the first entry is skipped).
ALPHA_f, PI_f, GAMMA_f, PHI_f, elbos_all = results[best]
elbos_all = np.array(elbos_all)
plt.plot(elbos_all[1:]); plt.show()

In [None]:
# Ratio ALPHA / (ALPHA + PI) for every entry (presumably the Beta posterior
# mean of junction usage -- confirm against the model definition).
juncs_probs = ALPHA_f / (ALPHA_f + PI_f)

# Histogram of all ratios, 20 bins.
juncs_probs_flat = juncs_probs.cpu().numpy().flatten()
plt.hist(juncs_probs_flat, 20)
plt.show()

In [None]:
# Average assignment probability per cell type. Note that the grouping column
# is called 'cell_id' but holds the cell_type labels.
theta_f_plot = pd.DataFrame(PHI_f.cpu().numpy()).assign(
    cell_id=cell_ids_conversion["cell_type"].to_numpy()
)
theta_f_plot_summ = theta_f_plot.groupby('cell_id').mean()
print(theta_f_plot_summ)

In [None]:
# Display GAMMA_f, the third element of the best trial's result.
GAMMA_f

In [None]:
# Draw 10 samples of cluster proportions, theta ~ Dirichlet(GAMMA_f).
# GAMMA_f is a torch tensor that may live on the GPU, so it is moved to the
# host before being handed to NumPy; a seeded generator keeps the draw
# reproducible across runs.
gamma_concentration = (
    GAMMA_f.detach().cpu().numpy() if isinstance(GAMMA_f, torch.Tensor) else np.asarray(GAMMA_f)
)
np.random.default_rng(42).dirichlet(gamma_concentration, 10)

In [None]:
# Display GAMMA_f normalised to sum to one (each entry divided by the total).
GAMMA_f / GAMMA_f.sum()

In [None]:
# How much each cell state is used.
# The latent proportions describe the general prevalence of each cluster in
# the dataset:
#   cluster proportions      theta ~ dirichlet(GAMMA_f)
#   per-cell assignment      z_c | theta ~ categorical(theta)
# Here theta is summarised by GAMMA_f normalised to sum to one.
theta = (GAMMA_f / GAMMA_f.sum()).cpu().numpy()

# Ascending order, shown as the cell output.
theta_sorted = np.sort(theta)
theta_sorted

In [None]:
# NOTE(review): this displays the full PHI_f object rather than a preview.
PHI_f # matrix of probabilities of each cell belonging to each cluster

In [None]:
# Bar chart of the cluster proportions from largest to smallest; the x axis
# is the rank 1 .. K.
cluster_ranks = np.arange(K) + 1
plt.bar(cluster_ranks, theta_sorted[::-1])
plt.show()

In [None]:
# Keep only the clusters whose overall proportion exceeds 1%.
to_keep = theta > 0.01

# Assignment probabilities restricted to the kept clusters.
x = PHI_f.cpu().numpy()[:, to_keep]
# (row-wise standardisation left disabled)
#x -= x.mean(1,keepdims=True)
#x /= x.std(1,keepdims=True)

# Histogram of all remaining probabilities, 100 bins; the return value is
# bound to `_` so it is not echoed as cell output.
_ = plt.hist(x.flatten(), 100)

In [None]:
# Hard-assign each cell to its most probable kept cluster. `x` only holds the
# columns selected by `to_keep`, so the argmax positions are mapped back to
# the original cluster ids; otherwise the crosstab columns would be labelled
# with positions inside the filtered matrix instead of cluster ids.
kept_cluster_ids = np.flatnonzero(to_keep)
assigned_cluster = kept_cluster_ids[x.argmax(axis=1)]

# Counts of cells per (cell type, assigned cluster).
ct = pd.crosstab(cell_ids_conversion["cell_type"], assigned_cluster)
print(ct)

ct_np = ct.to_numpy()
print(ct_np)

ct_np = ct_np / ct_np.sum(1, keepdims=True) # normalize cell-type counts
print(ct_np)

# Then normalise each cluster column to sum to one.
ct_np = ct_np / ct_np.sum(0, keepdims=True)
print(ct_np)

# Build a new float frame with the same labels rather than writing floats
# into the integer-typed count frame.
ct = pd.DataFrame(ct_np, index=ct.index, columns=ct.columns)

# clustermap creates its own figure, so no separate plt.figure() is opened
# (that only produced an extra empty figure).
sns.clustermap(ct, dendrogram_ratio=0.15, vmin = None, figsize=(12,6), annot = False)

## **The dataset used to generate the heatmap above likely has a bug**

In [None]:
# juncs_probs is a torch tensor (possibly on the GPU); convert it to a host
# NumPy array first so the DataFrame holds plain floats instead of tensor
# objects and the construction also works when the device is CUDA.
juncs_probs_np = (
    juncs_probs.detach().cpu().numpy() if isinstance(juncs_probs, torch.Tensor) else np.asarray(juncs_probs)
)

# One column per cell state, named cell_state_0 .. cell_state_{K-1}.
juncs_probs_df = pd.DataFrame(juncs_probs_np, columns=["cell_state_" + str(col) for col in range(K)])

# Attach the junction index and junction id from the conversion table
# (assumes its rows are in the same order as the rows of juncs_probs -- TODO confirm).
juncs_probs_df["junction_id_index"] = junction_ids_conversion.junction_id_index.values
juncs_probs_df["junction_id"] = junction_ids_conversion.junction_id.values

def plot_juncObsUsage(junc_index):
    """Show the observed usage of one junction, split by cell type.

    Prints the conversion-table row(s) for `junc_index`, then the number of
    `final_data` rows per cell type for that junction, and finally a
    horizontal violin + point plot of `juncratio` by `cell_type`, titled with
    the junction id.

    Raises IndexError if `junc_index` is not in `junction_ids_conversion`.
    """
    # Look the junction up once; the first match provides the plot title.
    conversion_rows = junction_ids_conversion[junction_ids_conversion["junction_id_index"] == junc_index]
    print(conversion_rows)
    junc_id = conversion_rows.junction_id.values[0]

    # Observations belonging to this junction only.
    is_this_junction = final_data.junction_id_index == junc_index
    junc_dat = final_data[is_this_junction]
    print(junc_dat.cell_type.value_counts())

    # Violin of the junction usage ratio per cell type with the raw points on
    # top; coord_flip turns the plot by 90 degrees.
    plot = ggplot(junc_dat, aes(x='cell_type', y='juncratio', fill="cell_type"))
    for component in (geom_violin(), geom_point(), plotnine.labels.ggtitle(junc_id), plotnine.coords.coord_flip()):
        plot = plot + component

    print(plot)

def plot_juncProbs(junc_index):
    """Plot the per-cell-state value of `juncs_probs_df` for one junction.

    Prints the conversion-table row(s) for `junc_index`, then a point plot
    with one point per cell state (x-axis tick labels hidden).

    Raises IndexError if `junc_index` is not in `junction_ids_conversion`.
    """
    # Look the junction up once; indexing the first match fails loudly when
    # the junction is unknown.
    conversion_rows = junction_ids_conversion[junction_ids_conversion["junction_id_index"] == junc_index]
    print(conversion_rows)
    junc_id = conversion_rows.junction_id.values[0]

    # Row for this junction in long format; the first K melted rows are the
    # cell-state columns because they come first in juncs_probs_df.
    is_this_junction = juncs_probs_df.junction_id_index == junc_index
    junc_dat = juncs_probs_df[is_this_junction].melt().head(K)
    junc_dat["value"] = junc_dat["value"].astype(float)

    # One point per cell state, without x-axis tick labels.
    plot = ggplot(junc_dat, aes(x='variable', y='value'))
    plot = plot + geom_point()
    plot = plot + theme(axis_text_x=element_blank())
    print(plot)

In [None]:
# Standard deviation of each junction across the K cell-state columns
# (the first K columns of juncs_probs_df).
cell_state_columns = juncs_probs_df.columns[:K]
juncs_probs_df["sd"] = juncs_probs_df[cell_state_columns].std(axis=1)

In [None]:
# Rank junctions by how much one cell state stands out from the other states.
def top10_juncs(cell_state):
    """Return the 10 junctions where `cell_state` exceeds the other states most.

    For every junction the difference between the `cell_state` column and the
    mean of all *other* cell-state columns is stored in
    `juncs_probs_df["diff"]` (side effect kept for callers that inspect it).

    Parameters
    ----------
    cell_state : str
        Name of a cell-state column of `juncs_probs_df`, e.g. "cell_state_3".

    Returns
    -------
    numpy.ndarray
        `junction_id_index` values of the 10 rows with the largest difference,
        in descending order of difference.
    """
    # Select the reference columns by their "cell_state_" prefix rather than
    # by excluding known bookkeeping columns: the "diff" column written by a
    # previous call (or any other added column) must not enter the mean.
    state_columns = [col for col in juncs_probs_df.columns if str(col).startswith("cell_state_")]
    other_states = [col for col in state_columns if col != cell_state]

    juncs_probs_df["diff"] = juncs_probs_df[cell_state] - juncs_probs_df[other_states].mean(axis=1)
    top10 = juncs_probs_df.sort_values(by="diff", ascending=False).head(10)
    return top10.junction_id_index.values

    # Ideas noted for later:
    # think of actually using the distributions... use the full beta distribution via KL divergence... (pairwise)
    # are the distributions across cell states for junctions more different than if they were coming from the same cell state

In [None]:
def log_beta(a, b):
    """Element-wise log of the Beta function: lgamma(a) + lgamma(b) - lgamma(a + b)."""
    log_gamma_terms = torch.lgamma(a) + torch.lgamma(b)
    return log_gamma_terms - torch.lgamma(a + b)

def score(a, b):
    """Sum of the element-wise log-Beta terms minus the log-Beta of the summed parameters.

    `a` and `b` are tensors of matching shape; the result is a 0-d tensor.
    """
    separate_terms = log_beta(a, b).sum()
    pooled_term = log_beta(a.sum(), b.sum())
    return separate_terms - pooled_term

In [None]:
# Inspect the dimensions of ALPHA_f.
ALPHA_f.shape

In [None]:
# Likelihood-ratio / Bayes-factor style score for ALL junctions, contrasting
# two cell states. The pair being compared is named here instead of being
# repeated as a literal.
STATES_TO_COMPARE = [3, 7]

# Parameters of the two states for every junction (rows index junctions).
alpha_pair = ALPHA_f[:, STATES_TO_COMPARE]
pi_pair = PI_f[:, STATES_TO_COMPARE]

# Same quantity as score(a, b) applied row by row, evaluated for all
# junctions in one pass instead of a Python loop with one .item() per row.
junction_scores = log_beta(alpha_pair, pi_pair).sum(dim=1) - log_beta(alpha_pair.sum(dim=1), pi_pair.sum(dim=1))
scores_all_juncs = junction_scores.tolist()

# Score table with the junction index attached.
scores_all_juncs_df = pd.DataFrame(scores_all_juncs, columns = ["score"])
scores_all_juncs_df["junction_id_index"] = junction_ids_conversion.junction_id_index.values

# The five highest-scoring junctions are inspected in the next cell.
ranked_scores = scores_all_juncs_df.sort_values(by="score", ascending=False)
juncs_test = ranked_scores.head(5).junction_id_index.values
scores_all_juncs_df.hist(column="score", bins=30)

In [None]:
# For each of the top-scoring junctions, show the observed usage per cell
# type and then the per-cell-state values.
for junc_to_plot in juncs_test:
    for plot_junction in (plot_juncObsUsage, plot_juncProbs):
        plot_junction(junc_to_plot)

In [None]:
# Display the cell-id conversion table returned by load_cluster_data.
cell_ids_conversion

In [None]:
# Convert PHI_f to a DataFrame and add the cell id and cell type of each row.
# The conversion is guarded so that re-running this cell does not wrap the
# DataFrame again and rename the id / type columns to "CellState_*".
if not isinstance(PHI_f, pd.DataFrame):
    # PHI_f is a torch tensor that may live on the GPU; move it to the host
    # so the DataFrame holds plain floats.
    phi_values = PHI_f.detach().cpu().numpy() if isinstance(PHI_f, torch.Tensor) else np.asarray(PHI_f)
    PHI_f = pd.DataFrame(phi_values)

    # Prefix every probability column with "CellState_".
    PHI_f.columns = ["CellState_" + str(i) for i in range(PHI_f.shape[1])]

PHI_f['cell_id'] = cell_ids_conversion.cell_id.values
PHI_f['cell_type'] = cell_ids_conversion.cell_type.values

In [None]:
# Preview the first rows of the PHI_f DataFrame.
PHI_f.head()

In [None]:
# Save the cell assignments to file for downstream analysis.
# `os` is imported here because the notebook never imports it explicitly.
import os

output_dir = '/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/TabulaMurisMarrow'

# Make sure the target folder exists so the write below cannot fail on a
# missing directory.
os.makedirs(output_dir, exist_ok=True)

# BBmixture model cell assignments
output_file = os.path.join(output_dir, 'Leaflet_BBmixture.csv')
PHI_f.to_csv(output_file, index=True, header=True)
print('Saved Leaflet latent cell states to {}'.format(output_file))