In [None]:
import pickle
import sys
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from scipy.stats import zscore

from matplotlib.patches import Rectangle, Patch


sys.path.insert(0, '../4_preffect')
from _config import configs
from preffect_factory import factory
import anndata as ad 
from _inference import( Inference )

In [None]:
# dataset
# Name of the dataset run to analyse; it is interpolated into `folder_search` below.
dataset_run = "dataset_subtype_5"

def list_subfolders(directory):
    """Return the names (not full paths) of the immediate sub-directories of `directory`.

    The names come back in whatever order `os.scandir` yields them; callers
    that need a stable order must sort the result themselves.
    """
    names = []
    for entry in os.scandir(directory):
        if entry.is_dir():
            names.append(entry.name)
    return names

# Sort key that orders folder names by their trailing number, so that
# "epoch_100" is sorted after "epoch_50".
def sort_key(s):
    """Return a sort key that orders names by their trailing integer.

    Whitespace is stripped from `s` first. Names ending in digits map to
    `(0, <number>)`; names without a trailing number map to `(1, <name>)`.
    Tagging the key this way keeps integers and strings from ever being
    compared with each other (which raises TypeError in Python 3), and places
    un-numbered names after the numbered ones, in alphabetical order.
    """
    s = re.sub(r'\s+', '', s)
    # Any trailing run of digits counts, whatever prefix precedes it.
    match = re.search(r'\d+$', s)
    if match:
        return (0, int(match.group()))
    return (1, s)

# Root directory that holds one sub-folder per run of this dataset
folder_search = f'/path/to/{dataset_run}/testing_single/separate_latent'

# Plain lexicographic ordering; the numeric-aware alternative would be
# `all_subfolders.sort(key=sort_key)`.
all_subfolders = list_subfolders(folder_search)
all_subfolders.sort()

# Keep a single run, chosen by its position in the sorted listing.
# Author's index notes: 3 jumbled, 4 jumbled/nodrop, 5 nodrop, 6 Normal
# (an earlier selection was [13:14]).
all_subfolders = all_subfolders[5:6]
print("All sub-folders:", all_subfolders)

Load a specific run into this notebook. We are only interested in the first one:
basic_M_1000_minibatch_200_epochs_1000_lr_0.001_lib_False_likelihood_NB_masking_MCAR_2_lambda_0.0

In [None]:
configs['task'] = 'reinstate'

# Reinstated PREFFECT objects, keyed by the order in which they were loaded
pr_reinstate = {}
pr_count = 0

# Only the first listed run is reinstated.
for dir_name in all_subfolders[:1]:
    # Folder names are stripped of any whitespace before building the path
    dir_name = re.sub(r'\s+', '', dir_name)
    full_path = f"{folder_search}/{dir_name}"

    configs['output_path'] = full_path
    configs['cuda_device_num'] = 4
    pr_reinstate[pr_count] = factory(task='reinstate', configs=configs, trigger_setup=True)

    # Set after the reinstate call, so it only applies to later factory tasks
    configs['always_save'] = False

    pr_count += 1


In [None]:
# Run inference on the first (and only) reinstated model.
pr_data = pr_reinstate[0]
#print(pr_data.train_dataset.gene_names[965])

# so position 965 is ERBB2 and 966 is ESR1
# Work on a copy of the model's configs so the task can be switched without touching pr_data.configs
# (assumes .copy() returns an independent top-level mapping — TODO confirm)
configs_inf = pr_data.configs.copy()
configs_inf['task'] = 'inference'
inference_instance = Inference(pr_data, task='inference', inference_key = configs_inf['inference_key'], configs=configs_inf)
inference_instance.run_inference()
# The key is overwritten after the run and before registration; presumably this makes the run
# available as inference_dict['endogenous'], which later cells read — TODO confirm
inference_instance.configs_inf['inference_key'] = 'endogenous'
inference_instance.register_inference_run()



In [None]:
# Shape of the first entry of the 'DAs' inference output
print(inference_instance.output['DAs'][0][0].shape)

In [None]:
# Print the `M` attribute of the training dataset (presumably a size/count — confirm against the dataset class)
print(pr_data.train_dataset.M)


In [None]:
# Now cluster_counts works and only displays clustering on \hat{mu}
#
# Iterate over the models that were actually reinstated rather than over every
# folder name: pr_reinstate is keyed by load order, and indexing it by folder
# position raises KeyError as soon as more folders are listed than were loaded.
for loop_count, preffect_obj in pr_reinstate.items():
    # Key i of pr_reinstate was loaded from all_subfolders[i]
    print(all_subfolders[loop_count])
    factory(task='cluster_counts', preffect_obj=preffect_obj, inference_key='endogenous', trigger_setup=False, configs=configs)



In [None]:
# Author notes: cluster_latent seemed hardcoded for batch (since fixed), but
# there is now a problem with Leiden clustering.

configs['always_save'] = False

# Iterate over the models that were actually reinstated rather than over every
# folder name: pr_reinstate is keyed by load order, and indexing it by folder
# position raises KeyError as soon as more folders are listed than were loaded.
for loop_count, preffect_obj in pr_reinstate.items():
    # Key i of pr_reinstate was loaded from all_subfolders[i]
    print(all_subfolders[loop_count])
    factory(task='cluster_latent', preffect_obj=preffect_obj, inference_key='endogenous', trigger_setup=False, configs=configs)

Now let's evaluate the Mu and Theta of the reinstated model.

In [None]:
# I'm having a weird rounding issue where it just randomly doesn't work
def truncate_to_one_decimal(arr):
    """Round `arr` down to one decimal place.

    Works on scalars and NumPy arrays. Because it uses `np.floor`, negative
    values move away from zero (e.g. -1.25 becomes -1.3).
    """
    scale = 10  # one decimal place
    floored = np.floor(arr * scale)
    return floored / scale


# Inference results stored under the 'endogenous' key of the first reinstated model
inf_reinstate = pr_reinstate[0].inference_dict['endogenous']
# `.var` of the first original AnnData in the training dataset
# (presumably the per-gene annotation table — confirm)
var = inf_reinstate.parent.train_dataset.anndatas_orig[0].var
#print(var)

In [None]:
# Read the PAM50 file holding the per-gene, per-subtype p and r values
pam50_path = "/path/to/9_Exploring_NB_In_PAM50/our_dcis.NB_PAM50.Median.Trim_5.Subtype.csv"

df = pd.read_csv(pam50_path)

# One row per gene, with flattened column names of the form "p <Subtype>" / "r <Subtype>"
df_pivot = df.pivot_table(index='Gene', columns='Subtype', values=['p','r'], aggfunc='first')
df_pivot.reset_index(inplace=True)
df_pivot.columns = [' '.join(col).strip() for col in df_pivot.columns.values]

#print(df_pivot)

category_values_list, category_counts_list, category_omegas_list, category_true_omegas_list = [], [], [], []

# ---------------------------------------------------------------------------
# Model outputs. None of this depends on the gene being processed, so it is
# computed once here instead of once per row of df_pivot.
# NOTE(review): assumes return_counts_as_anndata() yields the same data on
# every call — confirm.
# ---------------------------------------------------------------------------
model = 0
inf_reinstate = pr_reinstate[model].inference_dict['endogenous']
adata = inf_reinstate.return_counts_as_anndata()

hat_mu = adata[0].X
hat_theta = adata[0].layers["X_hat_theta"]
true_counts = adata[0].layers["original_counts"]
omega = adata[0].layers["px_omega"]
subtype = adata[0].obs['subtype']
gene_order = inf_reinstate.ds.gene_names
categories = subtype.cat.categories

# Convert the true counts to omega: each gene's share of its row's total count
library_size = np.sum(true_counts, axis=1)
true_omega = true_counts / library_size[:, np.newaxis]

# Boolean row selector for every subtype found in obs
category_masks = {category: subtype == category for category in categories}

# Key used in the mus/thetas dicts -> subtype label used in the df_pivot column names
subtype_labels = {
    'basal': 'Basal',
    'her2': 'Her2',
    'luma': 'LumA',
    'lumb': 'LumB',
    'normal': 'Normal',
}

print("Gene Subtype GenMu MuHat(mean) MuHat(stdev)")

# One pass per gene of the PAM50 table
for row_idx, row in df_pivot.iterrows():
    gene = row['Gene']

    # Convert p/r to mu (theta = r): mu = r * (1 - p) / p
    thetas = {key: row[f'r {label}'] for key, label in subtype_labels.items()}
    mus = {
        key: thetas[key] * (1 - row[f'p {label}']) / row[f'p {label}']
        for key, label in subtype_labels.items()
    }

    # Pull this gene's column out of every matrix
    column_index = gene_order.index(gene)
    gene_muhat = hat_mu[:, column_index]
    gene_mutheta = hat_theta[:, column_index]
    gene_truecounts = true_counts[:, column_index]
    gene_omegas = omega[:, column_index]
    gene_true_omegas = true_omega[:, column_index]

    # Split the per-row values by the subtype recorded in obs
    category_values = {category: gene_muhat[mask] for category, mask in category_masks.items()}
    category_counts = {category: gene_truecounts[mask] for category, mask in category_masks.items()}
    category_true_omega = {category: gene_true_omegas[mask] for category, mask in category_masks.items()}
    category_omegas = {category: gene_omegas[mask] for category, mask in category_masks.items()}

    category_values_list.append(category_values)
    category_counts_list.append(category_counts)
    category_omegas_list.append(category_omegas)
    category_true_omegas_list.append(category_true_omega)

    # Generating mu next to the mean / stdev of mu-hat for each subtype.
    # (The same summary can be printed for the true counts by iterating over
    # category_counts instead of category_values.)
    for category, values in category_values.items():
        avg_mu_subtype = truncate_to_one_decimal(np.mean(values))
        std_mu_subtype = truncate_to_one_decimal(np.std(values))

        print(f"{gene} {category} {np.round(mus[category], 1)} {avg_mu_subtype} {std_mu_subtype}")




In [None]:
def _frames_per_category(per_gene_records, gene_names, category_labels):
    """Build one DataFrame per category, with one column per gene.

    `per_gene_records[i]` is the {category: values} dict collected for the
    i-th entry of `gene_names`. Each returned frame gets a trailing
    'category' column holding its category label.
    """
    frames = []
    for label in category_labels:
        frame = pd.DataFrame({gene: per_gene_records[i][label] for i, gene in enumerate(gene_names)})
        frame['category'] = label
        frames.append(frame)
    return frames


def _min_max_scale(frame):
    """Min/max-scale every column of `frame`; a constant column becomes NaN (0/0)."""
    lowest, highest = frame.min(), frame.max()
    return (frame - lowest) / (highest - lowest)


gene_names = df_pivot['Gene']

# One frame per category for each of the four quantities
df_list_mu = _frames_per_category(category_values_list, gene_names, categories)
df_list_count = _frames_per_category(category_counts_list, gene_names, categories)
df_list_omega = _frames_per_category(category_omegas_list, gene_names, categories)
df_list_true_omega = _frames_per_category(category_true_omegas_list, gene_names, categories)

all_category_values_df = pd.concat(df_list_mu)
all_category_counts_df = pd.concat(df_list_count)
all_category_omegas_df = pd.concat(df_list_omega)
all_category_true_omegas_df = pd.concat(df_list_true_omega)

# For each stacked frame: pop the trailing 'category' column, min/max-scale the
# remaining gene columns, then transpose so that genes become rows.
# (Z-scoring was used originally; min/max scaling replaced it.)

# mu-hat
category_order = all_category_values_df.pop(all_category_values_df.columns[-1])
normalized_data = _min_max_scale(all_category_values_df).transpose()

# the input counts
category_order_counts = all_category_counts_df.pop(all_category_counts_df.columns[-1])
normalized_count_data = _min_max_scale(all_category_counts_df).transpose()

# true omegas (the un-normalized frame is transposed as well)
category_order = all_category_true_omegas_df.pop(all_category_true_omegas_df.columns[-1])
normalized_true_omega_data = _min_max_scale(all_category_true_omegas_df).transpose()
all_category_true_omegas_df = all_category_true_omegas_df.transpose()

# hat omegas (the un-normalized frame is transposed as well)
category_order = all_category_omegas_df.pop(all_category_omegas_df.columns[-1])
normalized_omega_data = _min_max_scale(all_category_omegas_df).transpose()
all_category_omegas_df = all_category_omegas_df.transpose()

print(normalized_omega_data)



In [None]:

from matplotlib.patches import Rectangle, Patch

# Colour used for each subtype in the category bar and the legend
category_colors = dict(
    basal='blue',
    her2='orange',
    luma='green',
    lumb='red',
    normal='purple',
)

# One colour per heatmap column, following the order of category_order
category_bar = category_order.map(category_colors)
white_red_cmap = LinearSegmentedColormap.from_list('white_red', ['white', 'darkred'])

# Heatmap of the normalized mu-hat values (rows = genes)
fig, ax = plt.subplots(figsize=(3, 8))
sns.heatmap(normalized_data, cmap=white_red_cmap, cbar=True, xticklabels=True, yticklabels=True, ax=ax)

# Category bar: one unit square per column, drawn just past the last heatmap row
bar_row = len(normalized_data)
for idx, color in enumerate(category_bar):
    ax.add_patch(Rectangle((idx, bar_row), 1, 1, color=color, transform=ax.transData, clip_on=False))

legend_patches = [Patch(color=shade, label=name) for name, shade in category_colors.items()]
ax.legend(handles=legend_patches, title='Categories', loc='upper right', bbox_to_anchor=(1.07, 1))

# Adjust the subplot margins to fit the category bar
fig.subplots_adjust(bottom=0.0)
fig.subplots_adjust(right=4)

# Labels and title; x tick labels are hidden
ax.set_title(r'$\hat{\mu}$ Heatmap [Normalized]')
ax.set_ylabel('Genes')
ax.set_xticks([])

plt.show()

In [None]:
# Same heatmap as for mu-hat, but for the normalized input counts
fig, ax = plt.subplots(figsize=(3, 8))
sns.heatmap(normalized_count_data, cmap=white_red_cmap, cbar=True, xticklabels=True, yticklabels=True, ax=ax)

# Category bar: one unit square per column, drawn just past the last heatmap row
bar_row = len(normalized_count_data)
for idx, color in enumerate(category_bar):
    ax.add_patch(Rectangle((idx, bar_row), 1, 1, color=color, transform=ax.transData, clip_on=False))

legend_patches = [Patch(color=shade, label=name) for name, shade in category_colors.items()]
ax.legend(handles=legend_patches, title='Categories', loc='upper right', bbox_to_anchor=(1.07, 1))

# Adjust the subplot margins to fit the category bar
fig.subplots_adjust(bottom=0.0)
fig.subplots_adjust(right=4)

# Labels and title; x tick labels are hidden
ax.set_title(r'True Counts (Input Data) Heatmap [Normalized]')
ax.set_ylabel('Genes')
ax.set_xticks([])

plt.show()

In [None]:
# PAM50 gene symbols, in the order the heatmap rows should follow
PAM50genes = [
    "EGFR", "CDH3", "PHGDH", "ACTR3B", "FOXC1",
    "MIA", "MYC", "FGFR4", "MDM2", "MLPH",
    "KRT14", "BCL2", "SFRP1", "KRT5", "KRT17",
    "SLC39A6", "ESR1", "CXXC5", "BLVRA", "FOXA1",
    "GPR160", "NAT1", "MAPT", "PGR", "BAG1",
    "TMEM45B", "ERBB2", "GRB7", "MMP11", "CDC20",
    "MKI67", "CCNE1", "CENPF", "NUF2", "EXO1",
    "KIF2C", "ORC6", "ANLN", "CDC6", "RRM2",
    "UBE2T", "NDC80", "CEP55", "MELK", "TYMS",
    "CCNB1", "BIRC5", "MYBL2", "PTTG1", "UBE2C",
]

# Re-arrange the rows to follow PAM50genes, then flip the row order.
# A symbol that is absent from the data yields an all-NaN row.
normalized_true_omega_data_re = normalized_true_omega_data.reindex(index=PAM50genes).iloc[::-1]

normalized_omega_data_re = normalized_omega_data.reindex(index=PAM50genes).iloc[::-1]

In [None]:
# Plot the observed and the PREFFECT omegas side by side.
# Both panels are drawn by the same loop and every label is set on an explicit
# Axes object, so no call relies on pyplot's implicit "current axes".
fig, axes = plt.subplots(1, 2, figsize=(14, 8))  # Adjust figsize as needed
ax1, ax2 = axes

# (axes, data, draw colorbar?, title) for each panel
panels = [
    (ax1, normalized_true_omega_data_re, False, r'Observed $\Omega$ of Input Data [Min/Max Norm.]'),
    (ax2, normalized_omega_data_re, True, r'$\Omega$ of PREFFECT Model [Min/Max Norm.]'),
]

for panel_ax, panel_data, show_cbar, panel_title in panels:
    sns.heatmap(
        panel_data,
        cmap=white_red_cmap,
        cbar=show_cbar,
        xticklabels=True,
        yticklabels=True,
        ax=panel_ax,
    )

    # Category bar: one unit square per column, drawn just past the last heatmap row
    for idx, color in enumerate(category_bar):
        panel_ax.add_patch(Rectangle((idx, len(panel_data)), 1, 1, color=color, transform=panel_ax.transData, clip_on=False))

    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel('Genes')
    panel_ax.set_xticks([])

# A single legend, attached to the second panel
legend_patches = [Patch(color=color, label=category) for category, color in category_colors.items()]
ax2.legend(handles=legend_patches, title='Categories', loc='upper right', bbox_to_anchor=(1.5, 1))

plt.tight_layout()
plt.show()


In [None]:
# Plot the observed counts and the PREFFECT mu-hat side by side.
# Both panels are drawn by the same loop and every label is set on an explicit
# Axes object, so no call relies on pyplot's implicit "current axes".
fig, axes = plt.subplots(1, 2, figsize=(14, 8))  # Adjust figsize as needed
ax1, ax2 = axes

# (axes, data, draw colorbar?, title) for each panel
panels = [
    (ax1, normalized_count_data, False, r'Observed Counts of Input Data [Min/Max Norm.]'),
    (ax2, normalized_data, True, r'$\hat{\mu}$ of PREFFECT Model [Min/Max Norm.]'),
]

for panel_ax, panel_data, show_cbar, panel_title in panels:
    sns.heatmap(
        panel_data,
        cmap=white_red_cmap,
        cbar=show_cbar,
        xticklabels=True,
        yticklabels=True,
        ax=panel_ax,
    )

    # Category bar: one unit square per column, drawn just past the last heatmap row
    for idx, color in enumerate(category_bar):
        panel_ax.add_patch(Rectangle((idx, len(panel_data)), 1, 1, color=color, transform=panel_ax.transData, clip_on=False))

    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel('Genes')
    panel_ax.set_xticks([])

# A single legend, attached to the second panel
legend_patches = [Patch(color=color, label=category) for category, color in category_colors.items()]
ax2.legend(handles=legend_patches, title='Categories', loc='upper right', bbox_to_anchor=(1.5, 1))

plt.tight_layout()
plt.show()


In [None]:
import torch

### edge comparison: original edges vs. the 'DAs' output of the model
pr_of_interest = pr_reinstate[0]

# Author note: the edges in this dataset's As_mask line up with the subtype
# recorded in obs; the plots below show what they look like after learning.
inf = pr_of_interest.inference_dict['endogenous']
recon_das = inf.output['DAs'][0][0]

# Value range and shape of the reconstructed matrix
print(torch.min(recon_das).item(), torch.max(recon_das).item(), recon_das.shape)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
true_ax, recon_ax = axes

# Left panel: the original edges
im1 = true_ax.imshow(inf.ds.As_orig[0].cpu().float(), cmap='viridis')
true_ax.set_title('True S-S Edges')
fig.colorbar(im1, ax=true_ax)

# Right panel: the reconstruction
im2 = recon_ax.imshow(recon_das.cpu().detach().float(), cmap='viridis')
recon_ax.set_title('Recon DAs')
fig.colorbar(im2, ax=recon_ax)

plt.tight_layout()
plt.show()
