In [ ]:
# Change the working directory to the project root.

In [2]:
cd ..

/Users/euxhenh/CMU/Research/Projects/Truffle


In [None]:
import os
import pickle

import anndata
import numpy as np
import pandas as pd
import seaborn as sns

from src.truffle import Truffle

## Load Data

In [None]:
# Dataset to analyse: keep exactly one `name` assignment active.
name = "Psoriasis_gse171012_pca20"
# name = "COVID_gse212041_pca50"
# name = "Crohn_gse112366_pca20"

# Dataset key = the text before the first underscore (e.g. "Psoriasis").
name_key = name.partition('_')[0]

# Optional score name forwarded to the state-diagram step:
# 'acuity' for COVID, 'PASI_scores' for psoriasis.
score = None

# Read the AnnData file for the selected dataset and show its summary.
adata = anndata.read_h5ad(f"data/{name}.h5ad")
print(adata)

## Run Truffle

In [None]:
# Create the Truffle model.
# NOTE(review): the meaning of max_path_len='auto' is defined in src.truffle — confirm.
truffle = Truffle(max_path_len='auto')

# Hand the dataset to the model.
# NOTE(review): 'umap_distance', 'subj', 'visit' and 'visit_order' are presumably
# names of entries stored in `adata` (adjacency, subject id, time point and
# time-point ordering) — confirm against Truffle.prepare.
truffle.prepare(
    adata,
    adj='umap_distance',
    subject_id='subj',
    time_point='visit',
    time_point_order='visit_order',
)

In [None]:
# Fit the model. The reported `Termination condition` should be optimal.
# If it is not, no solution was found and more relaxed parameters
# should be picked before fitting again.
truffle.fit(edge_capacity=1, node_capacity=None)

## Store results

In [None]:
# Store the fitted model's state so it can be reloaded later with
# Truffle.from_state_dict instead of fitting again.
state_dict = truffle.state_dict()
# open(..., "wb") does not create missing parent folders, so make sure the
# dump directory exists before writing.
os.makedirs("data/dumps", exist_ok=True)
with open(f"data/dumps/{name}_truffle.pkl", "wb") as f:
    pickle.dump(state_dict, f)

In [None]:
# Load
# with open(f"data/dumps/{name}_truffle.pkl", "rb") as f:
#     state_dict = pickle.load(f)
# truffle = Truffle.from_state_dict(state_dict)

## Analysis

In [None]:
# Build the state diagram over the 'leiden' clusters. The cells below read the
# 'state_diagram', 'initial_states' and 'final_states' entries of `out`.
out = truffle.get_state_diagram(
    adata,
    cluster='leiden',
    scores=score,
    prune_q=0.5,  # top fraction of edge weight to keep for a simplified diagram
)

In [None]:
# Show the state diagram as a heatmap with square cells and the value written in each cell.
sns.heatmap(out['state_diagram'], annot=True, square=True)

In [None]:
# Collect paths through the state diagram, given its initial and final states.
# NOTE(review): the result is later queried with `.most_common()`, so it is
# presumably a collections.Counter keyed by path — confirm.
top_paths = truffle.get_top_paths(
    out['state_diagram'],
    out['initial_states'],
    out['final_states'],
)

In [None]:
# Rank the paths once, most common first. Each entry of most_common() carries
# the path as its first element.
ranked_paths = [entry[0] for entry in top_paths.most_common()]

# For each minimum size, take the 3 highest-ranked paths. A path with k + 1
# entries counts as a trajectory of length k, so the thresholds 3, 4 and 5
# mean "length at least 2, 3 and 4". Because these are lower bounds, a long
# path can qualify for several thresholds; it is added only once so the same
# path is not exported repeatedly.
stem_paths = []
for min_entries in (3, 4, 5):
    for path in [p for p in ranked_paths if len(p) >= min_entries][:3]:
        if path not in stem_paths:
            stem_paths.append(path)
print(stem_paths)

### Prepare tab-separated (TSV) files for STEM

In [None]:
def prepare_for_STEM(adata, path, method: str = 'Truffle'):
    """Write the cluster centers along ``path`` as a tab-separated table for STEM.

    Parameters
    ----------
    adata
        Object providing ``adata.var_names`` (used as row labels) and
        ``adata.uns['leiden_']['cluster_centers_']``, which is indexed by
        ``str(cluster_id)`` and yields one center per cluster.
    path
        Sequence of cluster ids; each id becomes one column, in path order.
    method
        Label used as the prefix of the output file name.

    Notes
    -----
    The file is written to
    ``data/STEM/{name_key}/{method}_STEM_{ids}.csv`` where ``name_key`` is the
    notebook-level variable and ``ids`` is the cluster ids joined without a
    separator. The content is tab-separated despite the ``.csv`` extension.
    NOTE(review): because the ids are joined without a separator, paths with
    multi-digit ids can map to the same file name (e.g. 1,12 and 11,2).
    """
    cluster_centers = adata.uns['leiden_']['cluster_centers_']
    centers = [cluster_centers[str(cluster_id)] for cluster_id in path]
    # Transpose so rows follow adata.var_names and columns follow the path.
    average_exp = np.asarray(centers).T
    df = pd.DataFrame(average_exp, index=adata.var_names, columns=path)

    # to_csv does not create missing folders, and the folder changes with the
    # selected dataset, so create it on demand.
    out_dir = f"data/STEM/{name_key}"
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(f"{out_dir}/{method}_STEM_{''.join([str(p) for p in path])}.csv", sep='\t')

In [None]:
# Write one STEM input table per selected path, using the default method label.
for stem_path in stem_paths:
    prepare_for_STEM(adata, stem_path)