# Patches with Simulation Data

## Import Packages

In [None]:
from ladder.data import get_data
from ladder.scripts import InterpretableWorkflow
import umap, torch, pyro, os
import torch.optim as opt

import numpy as np 
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import anndata as ad
import scanpy as sc

## Load and Prepare Data

In [None]:
# Load the simulated dataset from its .h5ad file.
adata = sc.read_h5ad("../../data/sim/01-pro/t100,s80,b0.h5ad")

# Model input should be raw counts (stated in docs), so X is replaced by the
# "counts" layer. The layer is copied so that X and layers["counts"] do not
# share one underlying matrix: an in-place transform of X would otherwise
# silently alter the stored raw counts as well.
adata.X = adata.layers["counts"].copy()

# Last expression of the cell: display the AnnData summary.
adata

## Run Patches in Interpretable Workflow - Condition Only

In [None]:
# Condition-only run: "group_id" is the single modelled factor.
factors = ["group_id"]

# Build the workflow on the loaded AnnData with a fixed random seed.
workflow = InterpretableWorkflow(adata, verbose=True, random_seed=42)

# Register the condition classes and the batch key, selecting the Patches model.
workflow.prep_model(
    factors,
    batch_key="sample_id",
    model_type='Patches',
    model_args={'ld_normalize' : True},
)

# Train the model. Lower the convergence threshold if you need a more accurate
# model; doing so will increase training time.
# NOTE(review): convergence_window (1000) is larger than max_epochs (100) —
# confirm that convergence-based stopping can still trigger with these values.
workflow.run_model(
    max_epochs=100,
    convergence_threshold=1e-5,
    convergence_window=1000,
)

# Persist the trained model.
workflow.save_model("../../data/sim/02-patches/t100,s80,b0-con")

In [None]:
# Show the loss plot for the workflow trained in the cell above
# (presumably the loss tracked during run_model — confirm against ladder docs).
workflow.plot_loss()

In [None]:
# Write the model's embeddings, then display the AnnData's `obsm` mapping as the
# cell output (presumably write_embeddings stores its results there — confirm).
workflow.write_embeddings()
workflow.anndata.obsm

In [None]:
# Run the workflow's reconstruction evaluation for the trained model
# (the reported metrics are defined by ladder — see its documentation).
workflow.evaluate_reconstruction()

In [None]:
# Compute the conditional and the common loadings, then display the `var`
# table as the cell output. Later cells read "*_score_Patches" columns from
# `var`, presumably added by these two calls — confirm against ladder docs.
workflow.get_conditional_loadings()
workflow.get_common_loadings()
workflow.anndata.var

In [None]:
# Print the 200 entries of `var` with the highest Condition2 score, highest first.
# The sorted Series already holds every score, so it is walked once instead of
# re-querying `var` with `.loc` for each gene (redundant work, and `.loc`
# returns a different shape if a label happens to be duplicated).
top_scores = (
    workflow.anndata.var["Condition2_score_Patches"]
    .sort_values(ascending=False)
    .head(200)
)
# Iterate the NumPy values so each score prints as the same scalar type as
# a direct `.values[0]` lookup would give.
for gene, score in zip(top_scores.index, top_scores.to_numpy()):
    print(gene, score)

In [None]:
# Export the loading scores (two condition columns plus the common column)
# from `var` to a CSV file next to the saved condition-only model.
loading_columns = [
    "Condition1_score_Patches",
    "Condition2_score_Patches",
    "common_score_Patches",
]
loadings = workflow.anndata.var.loc[:, loading_columns]
loadings.to_csv("../../data/sim/02-patches/t100,s80,b0-con_loadings.csv")

## Run Patches in Interpretable Workflow - Condition + Cluster

In [None]:
# Condition + cluster run: both "group_id" and "cluster_id" are modelled factors.
factors = ["group_id", "cluster_id"]

# Build a fresh workflow on the loaded AnnData with a fixed random seed.
# NOTE(review): this reuses the same `adata` object as the condition-only run —
# confirm InterpretableWorkflow does not carry over state written to it earlier.
workflow = InterpretableWorkflow(adata, verbose=True, random_seed=42)

# Register the condition classes and the batch key, selecting the Patches model.
workflow.prep_model(
    factors,
    batch_key="sample_id",
    model_type='Patches',
    model_args={'ld_normalize' : True},
)

# Train the model. Lower the convergence threshold if you need a more accurate
# model; doing so will increase training time.
# NOTE(review): convergence_window (1000) is larger than max_epochs (100) —
# confirm that convergence-based stopping can still trigger with these values.
workflow.run_model(
    max_epochs=100,
    convergence_threshold=1e-5,
    convergence_window=1000,
)

# Persist the trained model.
workflow.save_model("../../data/sim/02-patches/t100,s80,b0-con-clu")

In [None]:
# Show the loss plot for the condition + cluster workflow trained above
# (presumably the loss tracked during run_model — confirm against ladder docs).
workflow.plot_loss()

In [None]:
# Write the model's embeddings, then display the AnnData's `obsm` mapping as the
# cell output (presumably write_embeddings stores its results there — confirm).
workflow.write_embeddings()
workflow.anndata.obsm

In [None]:
# Run the workflow's reconstruction evaluation for the trained model
# (the reported metrics are defined by ladder — see its documentation).
workflow.evaluate_reconstruction()

In [None]:
# Compute the conditional and the common loadings, then display the `var`
# table as the cell output. Later cells read "*_score_Patches" columns from
# `var`, presumably added by these two calls — confirm against ladder docs.
workflow.get_conditional_loadings()
workflow.get_common_loadings()
workflow.anndata.var

In [None]:
# Print the 200 entries of `var` with the highest Condition2 score, highest first.
# The sorted Series already holds every score, so it is walked once instead of
# re-querying `var` with `.loc` for each gene (redundant work, and `.loc`
# returns a different shape if a label happens to be duplicated).
top_scores = (
    workflow.anndata.var["Condition2_score_Patches"]
    .sort_values(ascending=False)
    .head(200)
)
# Iterate the NumPy values so each score prints as the same scalar type as
# a direct `.values[0]` lookup would give.
for gene, score in zip(top_scores.index, top_scores.to_numpy()):
    print(gene, score)

In [None]:
# Export the loading scores (condition, common and group columns) from `var`
# to a CSV file next to the saved condition + cluster model.
loading_columns = [
    "Condition1_score_Patches",
    "Condition2_score_Patches",
    "common_score_Patches",
    "Group1_score_Patches",
    "Group2_score_Patches",
    "Group3_score_Patches",
]
loadings = workflow.anndata.var.loc[:, loading_columns]
loadings.to_csv("../../data/sim/02-patches/t100,s80,b0-con-clu_loadings.csv")