## MELD TX23
Because we observe quite strong batch effects between the replicates, we run MELD on each replicate (here: TX23) separately and afterwards compare the MELD likelihoods as if they had been calculated on the same graph (see [this GitHub issue](https://github.com/KrishnaswamyLab/MELD/issues/56)).

In [None]:
import pandas as pd
import numpy as np
import graphtools as gt
import matplotlib.pyplot as plt
import phate
import scprep
import meld
import sklearn
import scipy

In [None]:
# Load the inputs for this replicate: the PCA table (used below to build the
# kNN graph) and the per-cell metadata, both indexed by "cell_barcode".
pca_data = pd.read_csv(
    "../../output/MELD/pca_data_tx23.tsv",
    index_col="cell_barcode",
    sep="\t",
)
mdata = pd.read_csv(
    "../../output/MELD/mdata_tx23.tsv",
    index_col="cell_barcode",
    sep="\t",
    low_memory=False,
)

In [None]:
# Load the results of the MELD parameter search over (beta, knn).
# NOTE(review): the file has a .csv suffix but is read as tab-separated, as before.
results = pd.read_csv("../../output/MELD/parameter_search_meld_tx23.csv", sep="\t")

# Average every numeric metric over all rows sharing the same (beta, knn).
# numeric_only=True keeps this working on pandas >= 2.0 if the table carries any
# non-numeric column (plain .mean() would raise a TypeError there).
results_wide = (
    results.groupby(['beta', 'knn'])
    .mean(numeric_only=True)
    .sort_values(by='MSE')
    .reset_index()
)
ax = scprep.plot.scatter(results_wide['beta'], results_wide['knn'],
                         s=50, c=results_wide['MSE'], vmax=0.006, cmap='inferno_r')

# results_wide is already sorted by ascending MSE, so its first row is the best combination.
top_result = results_wide.iloc[0]
# Highlight the top performing combination with a large red dot
ax.scatter(top_result['beta'], top_result['knn'], c='r', s=100, linewidth=1, edgecolor='k')

### Running MELD

In [None]:
# Build one graph on all cells using the best-scoring (knn, beta) from the
# parameter search, then estimate one density per sample label in
# "orig.ident" (i.e. ctrl and notch).
best_knn = int(top_result['knn'])
best_beta = top_result['beta']
G = gt.Graph(pca_data, knn=best_knn, use_pygsp=True)
meld_op = meld.MELD(beta=best_beta)
sample_densities = meld_op.fit_transform(G, sample_labels=mdata['orig.ident'])

In [None]:
# Transform densities into likelihoods: L1-normalise each row, so that per cell
# the (absolute) values across all samples sum to 1.
sample_likelihoods = sklearn.preprocessing.normalize(sample_densities, norm='l1')
# NOTE(review): assumes MELD orders its density columns like np.unique (sorted labels).
sample_likelihoods = pd.DataFrame(sample_likelihoods, columns=np.unique(mdata["orig.ident"]))
sample_likelihoods.index = mdata["Barcode_unique"]
# Assign positionally: rows of sample_likelihoods follow the row order of
# pca_data/mdata. Assigning the Series directly would align on index labels,
# and mdata is indexed by "cell_barcode", not "Barcode_unique". Any barcode
# differing between the two would then silently become NaN.
mdata["N_likelihood"] = sample_likelihoods.loc[:, "TX23_N"].to_numpy()


## Run VFC for the progenitor population

In [None]:
# Flag the cells considered progenitors. In our case these are ISCs and EBs
# (from "celltype_manual") plus EEPs (from "high_res_annotation").
is_eep_row = mdata["high_res_annotation"] == "EEP"
eep_barcodes = mdata.loc[is_eep_row, "Barcode_unique"].to_numpy()
is_isc_or_eb = mdata["celltype_manual"].isin(["ISC", "EB"])
has_eep_barcode = mdata["Barcode_unique"].isin(eep_barcodes)
# Plain boolean array, so it selects rows positionally in pca_data and mdata alike.
progenitor_mask = (is_isc_or_eb | has_eep_barcode).to_numpy()
progenitor_pca = pca_data.loc[progenitor_mask]

In [None]:
# Calculate the vertex-frequency clustering (VFC) on the progenitor subset.
# This takes quite a while (the full Fourier basis of the subgraph is computed).
np.random.seed(0)  # same seed as the per-cell-type VFC run, for reproducibility
cluster = "progenitors"
vfc_op_per_cluster = {}
curr_G = gt.Graph(progenitor_pca, use_pygsp=True)
curr_G.compute_fourier_basis()
# Binary sample indicator: 1 for the "TX23_N" sample (whose likelihood is passed
# below), 0 otherwise. Keying on "TX23_N" instead of the exact control label
# keeps indicator and likelihood consistent and cannot collapse every cell to 1
# if the control sample is not spelled exactly "TX23".
curr_sample_labels = mdata['orig.ident'].loc[progenitor_mask]
curr_sample_labels = pd.Series((curr_sample_labels == "TX23_N").astype(int).to_numpy())
# Perturbation likelihood calculated above with MELD.
curr_likelihood = mdata['N_likelihood'].loc[progenitor_mask]
curr_vfc = meld.VertexFrequencyCluster(n_clusters=2)
curr_vfc.fit_transform(curr_G, curr_sample_labels, curr_likelihood)
vfc_op_per_cluster[cluster] = curr_vfc

In [None]:
# Produce the progenitor VFC cluster assignments for 2 and 3 clusters
# (predict(n) is the call that actually yields the per-cell labels).
cluster = "progenitors"
curr_vfc = vfc_op_per_cluster[cluster]
clusters_by_n = {n: curr_vfc.predict(n) for n in (2, 3)}
subclustering_results = {cluster: clusters_by_n}

In [None]:
# Export the progenitor results as a TSV table that can be read into R.
cluster = "progenitors"
progenitor_meta = mdata.loc[progenitor_mask]
df = pd.DataFrame({
    "Barcode": progenitor_meta['Barcode_unique'],
    "N_likelihood": progenitor_meta['N_likelihood'],
    "VFC_2": subclustering_results[cluster][2],
    "VFC_3": subclustering_results[cluster][3],
    "orig.ident": progenitor_meta['orig.ident'],
})
out_path = f"../../output/MELD/TX23_info_{cluster}.tsv"
df.to_csv(out_path, sep="\t", index=False)

## Repeat the analysis above for the individual cell types

In [None]:
## Run VFC separately for each cell type
np.random.seed(0)
vfc_op_per_cluster = {}

# Cell types to analyze. NOTE(review): the selection ("a certain number of cells
# in BOTH conditions") was made by hand; it is not recomputed here.
clusters = ['EB', 'EE', 'ISC', 'aEC', 'daEC', 'mEC', 'pEC']

for cluster in clusters:
    print(cluster)
    # Positional boolean mask, computed once per cell type and reused below.
    cell_mask = (mdata["celltype_manual"] == cluster).to_numpy()
    curr_G = gt.Graph(pca_data.loc[cell_mask], use_pygsp=True)
    curr_G.compute_fourier_basis()
    # 1 for the "TX23_N" sample (whose likelihood is used below), 0 otherwise;
    # does not depend on the exact spelling of the control label.
    curr_sample_labels = mdata['orig.ident'].loc[cell_mask]
    curr_sample_labels = pd.Series((curr_sample_labels == "TX23_N").astype(int).to_numpy())
    curr_likelihood = mdata['N_likelihood'].loc[cell_mask]
    curr_vfc = meld.VertexFrequencyCluster(n_clusters=3)
    curr_vfc.fit_transform(curr_G, curr_sample_labels, curr_likelihood)
    vfc_op_per_cluster[cluster] = curr_vfc

In [None]:
# Produce the VFC cluster assignments of every cell type for 2 and 3 clusters.
subclustering_results = {}
for cluster in clusters:
    print(cluster)
    curr_vfc = vfc_op_per_cluster[cluster]
    clusters_by_n = {n: curr_vfc.predict(n) for n in (2, 3)}
    subclustering_results[cluster] = clusters_by_n

In [None]:
# Export one TSV table per cell type so the results can be read into R.
for cluster in clusters:
    celltype_meta = mdata.loc[mdata['celltype_manual'] == cluster]
    vfc_labels = subclustering_results[cluster]
    df = pd.DataFrame({
        "Barcode": celltype_meta['Barcode_unique'],
        "N_likelihood": celltype_meta['N_likelihood'],
        "VFC_2": vfc_labels[2],
        "VFC_3": vfc_labels[3],
        "orig.ident": celltype_meta['orig.ident'],
    })
    df.to_csv(f"../../output/MELD/TX23_info_{cluster}.tsv", sep="\t", index=False)