In [None]:
# # Paired Benchmark Figures

# This section contains figures comparing different benchmarks for paired models across datasets:
# 1. Cell Type Accuracy Comparison
# 2. Cell Type Accuracy Ablation Study
# 3. Gene Correlation Distribution Analysis (all benchmarks)
# 4. Gene Correlation Distribution Analysis (SCHAF and its ablations only)


In [None]:
# Load paired benchmark data
import numpy as np
import pandas as pd
import scanpy as sc
from collections import defaultdict
from tqdm import tqdm

# Load ground truth and predictions
# NOTE(review): DATA_ROOT is not assigned in or above this cell; the only visible
# assignment is in the configuration cell further down, so on a fresh kernel that
# cell has to run before this one.
# NOTE(review): the two reads below use the exact same whole_sample.h5ad path, so
# the "ground truth" and the "prediction" for the new sample come from one file —
# confirm the intended ground-truth path.
ground_truth_out_of_sample = sc.read_h5ad(f'{DATA_ROOT}/data/xenium_cancer/xenium_cancer_inferences/whole_sample.h5ad')
pred_out_of_sample = sc.read_h5ad(f'{DATA_ROOT}/data/xenium_cancer/xenium_cancer_inferences/whole_sample.h5ad')

# One AnnData per fold (folds 0-3), keyed by fold number.
ground_truth_in_sample = {
    z: sc.read_h5ad(f'{DATA_ROOT}/data/xenium_cancer/cancer_in_sample_folds/fold_{z}_st.h5ad')
    for z in range(4)
}

pred_in_sample = {
    z: sc.read_h5ad(f'{DATA_ROOT}/data/xenium_cancer/xenium_cancer_inferences/fold_{z}.h5ad')
    for z in range(4)
}

# Process data
# log1p is applied in place and X is replaced by a dense ndarray, so re-running this
# loop without re-loading would apply log1p a second time (and .todense() would no
# longer exist on the ndarray).
for z in ground_truth_in_sample:
    sc.pp.log1p(ground_truth_in_sample[z])
    ground_truth_in_sample[z].X = np.array(ground_truth_in_sample[z].X.todense())

# Get common genes
# The shared gene list is taken from fold 0 only and then applied to every fold.
common_in_sample = np.intersect1d(pred_in_sample[0].var.index, ground_truth_in_sample[0].var.index)
pred_in_sample = {k: v[::,common_in_sample] for k, v in pred_in_sample.items()}
ground_truth_in_sample = {k: v[::,common_in_sample] for k, v in ground_truth_in_sample.items()}

common_out_of_sample = np.intersect1d(ground_truth_out_of_sample.var.index, pred_out_of_sample.var.index)
ground_truth_out_of_sample = ground_truth_out_of_sample[::,common_out_of_sample]
pred_out_of_sample = pred_out_of_sample[::,common_out_of_sample]

# Load benchmark predictions
# These keys are used both in the benchmark file names below and as dictionary keys
# in the later cells.
benchmarks = [
    'schaf_no_stage1', 
    'schaf_no_stage2',
    'spirit',
    'st_net',
    'deep_pt',
    'he2rna',
]

# Load in-sample benchmark predictions
# NOTE(review): absolute, machine-specific directory (not derived from DATA_ROOT).
infer_models_dir = '/storage2/ccomiter/schaf_benchmarks_infer_cancer_in_sample/'
benchmark_to_pred_in_sample = {}
for benchmark in benchmarks:
    res = {}
    for fold in range(4):
        res[fold] = sc.read_h5ad(f'{infer_models_dir}/benchmark_{benchmark}_fold_{fold}.h5ad')
        # Keep only genes that are also in the SCHAF / ground-truth gene set.
        res[fold] = res[fold][::,np.intersect1d(common_in_sample, res[fold].var.index)]
    benchmark_to_pred_in_sample[benchmark] = res

# Load out-of-sample benchmark predictions
# NOTE(review): absolute, machine-specific directory (not derived from DATA_ROOT).
infer_models_dir = '/storage2/ccomiter/schaf_benchmarks_infer_cancer_whole_sample/'
benchmark_to_pred_out_of_sample = {}
for benchmark in benchmarks:
    benchmark_to_pred_out_of_sample[benchmark] = sc.read_h5ad(f'{infer_models_dir}/benchmark_{benchmark}_whole_sample.h5ad')
    benchmark_to_pred_out_of_sample[benchmark] = benchmark_to_pred_out_of_sample[benchmark][::,np.intersect1d(common_out_of_sample, benchmark_to_pred_out_of_sample[benchmark].var.index)]

# Load mouse data
# NOTE(review): only the SCHAF predictions and ground truth are loaded for mouse;
# no benchmark predictions for mouse are loaded in this cell.
pred_mouse = {
    z: sc.read_h5ad(f'{DATA_ROOT}/data/xenium_cancer/mouse_inferences/fold_{z}.h5ad')
    for z in range(4)
}

ground_truth_mouse = {
    z: sc.read_h5ad(f'{DATA_ROOT}/data/xenium_cancer/mouse_folds/fold_{z}_st.h5ad')
    for z in range(4)
}

# Same in-place log1p + densification as for the in-sample ground truth above.
for z in ground_truth_mouse:
    sc.pp.log1p(ground_truth_mouse[z])
    ground_truth_mouse[z].X = np.array(ground_truth_mouse[z].X.todense())

# Shared genes are again determined from fold 0 and applied to all folds.
common_mouse = np.intersect1d(pred_mouse[0].var.index, ground_truth_mouse[0].var.index)
pred_mouse = {k: v[::,common_mouse] for k, v in pred_mouse.items()}
ground_truth_mouse = {k: v[::,common_mouse] for k, v in ground_truth_mouse.items()}


In [None]:
# Calculate correlations and metrics
import warnings

datasets = ['in_sample', 'out_of_sample', 'mouse']


def _gene_vector(adata, gene):
    """Return the values of `gene` across all observations as a flat 1-D array."""
    column = adata[:, gene].X
    if hasattr(column, 'toarray'):
        # X may still be a scipy sparse matrix (only the fold-wise ground truth is
        # densified at load time); densify so np.corrcoef receives plain numbers.
        column = column.toarray()
    return np.asarray(column).ravel()


def _per_gene_corrs_and_scores(gt, pred):
    """For every gene in `pred`: Pearson correlation with `gt` and std of the prediction.

    Returns two dicts keyed by gene name: (correlations, scores). A correlation is
    NaN when either vector is constant (np.corrcoef behaviour).
    """
    corrs, scores = {}, {}
    for gene in pred.var.index:
        truth_vec = _gene_vector(gt, gene)
        pred_vec = _gene_vector(pred, gene)
        corrs[gene] = np.corrcoef(truth_vec, pred_vec)[0, 1]
        scores[gene] = pred_vec.std()
    return corrs, scores


def _fold_proportions(gt_by_fold):
    """Fraction of all cells that falls in each fold, keyed by fold."""
    total_cells = float(sum(adata.shape[0] for adata in gt_by_fold.values()))
    return {fold: adata.shape[0] / total_cells for fold, adata in gt_by_fold.items()}


def _fold_weighted(per_fold, fold_to_prop):
    """Cell-count-weighted average of per-fold {gene: value} dicts.

    Only genes present in every fold are averaged, so a gene missing from one fold's
    prediction no longer raises KeyError (the original took the gene list from
    whichever fold happened to be processed last).
    """
    folds = list(per_fold)
    shared = [g for g in per_fold[folds[0]] if all(g in per_fold[f] for f in folds[1:])]
    return {g: sum(per_fold[f][g] * fold_to_prop[f] for f in folds) for g in shared}


def _models_with_predictions(benchmark_to_pred, dataset_name, wanted):
    """Return the models in `wanted` that have predictions; warn about the rest."""
    present = [m for m in wanted if m in benchmark_to_pred]
    missing = [m for m in wanted if m not in benchmark_to_pred]
    if missing:
        warnings.warn(f'No {dataset_name} predictions loaded for: {missing}; they are skipped.')
    return present


def _folded_metrics(benchmark_to_pred, fold_to_prop, dataset_name, wanted):
    """Fold-weighted per-gene correlations and scores for every available model."""
    gt_by_fold = benchmark_to_pred['gt']
    corrs_out, scores_out = {}, {}
    for model in _models_with_predictions(benchmark_to_pred, dataset_name, wanted):
        fold_corrs, fold_scores = {}, {}
        for fold in fold_to_prop:
            fold_corrs[fold], fold_scores[fold] = _per_gene_corrs_and_scores(
                gt_by_fold[fold], benchmark_to_pred[model][fold]
            )
        corrs_out[model] = _fold_weighted(fold_corrs, fold_to_prop)
        scores_out[model] = _fold_weighted(fold_scores, fold_to_prop)
    return corrs_out, scores_out


# Add SCHAF and ground truth to benchmark dictionaries
# NOTE(review): no visible cell creates `benchmark_to_pred_mouse` (benchmark
# predictions are only loaded for the two MBC datasets), so the original assignment
# raised NameError on a fresh kernel. Fall back to an empty dict so the SCHAF mouse
# results can still be computed; load the mouse benchmark predictions upstream to get
# the full comparison.
try:
    benchmark_to_pred_mouse
except NameError:
    warnings.warn('benchmark_to_pred_mouse is not defined; only SCHAF is evaluated on mouse.')
    benchmark_to_pred_mouse = {}
benchmark_to_pred_mouse['schaf'] = pred_mouse
benchmark_to_pred_mouse['gt'] = ground_truth_mouse
benchmark_to_pred_out_of_sample['schaf'] = pred_out_of_sample
benchmark_to_pred_out_of_sample['gt'] = ground_truth_out_of_sample
benchmark_to_pred_in_sample['schaf'] = pred_in_sample
benchmark_to_pred_in_sample['gt'] = ground_truth_in_sample

# Organize predictions by dataset
dataset_to_benchmark_to_pred = {
    'in_sample': benchmark_to_pred_in_sample,
    'out_of_sample': benchmark_to_pred_out_of_sample,
    'mouse': benchmark_to_pred_mouse
}

all_models = benchmarks + ['schaf']

# Out of sample correlations (single sample, no folds)
benchmark_to_out_of_sample_corrs = {}
benchmark_to_out_of_sample_scores = {}
for benchmark in _models_with_predictions(benchmark_to_pred_out_of_sample, 'out_of_sample', all_models):
    corrs_res, scores_res = _per_gene_corrs_and_scores(
        dataset_to_benchmark_to_pred['out_of_sample']['gt'],
        dataset_to_benchmark_to_pred['out_of_sample'][benchmark],
    )
    benchmark_to_out_of_sample_corrs[benchmark] = corrs_res
    benchmark_to_out_of_sample_scores[benchmark] = scores_res

# In sample correlations (averaged over folds, weighted by cells per fold)
in_sample_fold_to_prop = _fold_proportions(ground_truth_in_sample)
benchmark_to_in_sample_corrs, benchmark_to_in_sample_scores = _folded_metrics(
    benchmark_to_pred_in_sample, in_sample_fold_to_prop, 'in_sample', all_models
)

# Mouse correlations (averaged over folds, weighted by cells per fold)
mouse_fold_to_prop = _fold_proportions(ground_truth_mouse)
benchmark_to_mouse_corrs, benchmark_to_mouse_scores = _folded_metrics(
    benchmark_to_pred_mouse, mouse_fold_to_prop, 'mouse', all_models
)

# Organize results by dataset
dataset_to_benchmark_to_corrs = {
    'in_sample': benchmark_to_in_sample_corrs,
    'out_of_sample': benchmark_to_out_of_sample_corrs,
    'mouse': benchmark_to_mouse_corrs
}

dataset_to_benchmark_to_scores = {
    'in_sample': benchmark_to_in_sample_scores,
    'out_of_sample': benchmark_to_out_of_sample_scores,
    'mouse': benchmark_to_mouse_scores
}


In [None]:
# Plotting functions and constants
import matplotlib.pyplot as plt
import seaborn as sns

# Dataset and benchmark name mappings
# Maps the dataset keys used throughout the notebook to the labels shown on figures.
ds_to_name = {
    'mouse': 'In-Sample Xenium Mouse',
    'in_sample': 'In-Sample Xenium MBC',
    'out_of_sample': 'New-Sample Xenium MBC',
}

# Maps model keys (as used in file names and result dicts) to figure labels.
# NOTE(review): the stage numbers are crossed relative to the keys
# ('schaf_no_stage1' -> 'No Stage 2', 'schaf_no_stage2' -> 'No Stage 1').
# Confirm whether this is intentional or a mislabel.
bm_to_name = {
    'schaf': 'SCHAF',
    'spirit': 'SPiRiT',
    'st_net': 'ST-Net',
    'deep_pt': 'DeepPT',
    'he2rna': 'HE2RNA',
    'schaf_no_stage1': 'No Stage 2',
    'schaf_no_stage2': 'No Stage 1',
}

def plot_benchmarks_cell_type_accuracy(y, value, errors=None):
    """Draw grouped bars of a per-model metric for the three datasets and save as PDF.

    Args:
        y: Nested dict ``{dataset_key: {model_key: metric}}``.
        value: Text for the y-axis label.
        errors: Optional nested dict with the same layout as ``y`` holding error-bar
            sizes; missing entries are drawn with an error of 0.

    The figure is written to ``{OUTPUT_DIR}/supp2_bottom.pdf`` and closed. Models whose
    (renamed) key contains 'only' — i.e. the ablations — are left out.
    """
    def _regroup(table):
        # Swap the two dict levels ({outer: {inner: v}} -> {inner: {outer: v}}) while
        # rewriting 'no' to 'only' in both keys.
        regrouped = {}
        for outer_key, inner in table.items():
            for inner_key, entry in inner.items():
                inner_key = inner_key.replace('no', 'only')
                outer_key = outer_key.replace('no', 'only')
                regrouped.setdefault(inner_key, {})[outer_key] = entry
        return regrouped

    y = _regroup(y)
    if errors:
        errors = _regroup(errors)

    # One bar series per non-ablation model; one bar group per dataset.
    series_names = [name for name in y if 'only' not in name]
    group_keys = ['mouse', 'in_sample', 'out_of_sample']

    values = np.array([[y[name].get(group, 0) for group in group_keys] for name in series_names])
    if errors:
        error_values = np.array(
            [[errors[name].get(group, 0) for group in group_keys] for name in series_names]
        )
    else:
        error_values = np.zeros_like(values)

    fig, ax = plt.subplots(1, 1, figsize=(11, 11))
    bar_width = 0.1
    group_positions = np.arange(len(group_keys))

    for offset, name in enumerate(series_names):
        ax.bar(
            group_positions + offset * bar_width,
            values[offset],
            width=bar_width,
            label=name,
            yerr=error_values[offset],
            capsize=5,
            ecolor='black',
        )

    ax.set_xlabel('Dataset')
    ax.set_ylabel(value)
    # Centre each tick under its group of bars.
    ax.set_xticks(group_positions + bar_width * (len(series_names) - 1) / 2)
    ax.set_xticklabels([ds_to_name[key.replace('no', 'only')] for key in group_keys])
    plt.tight_layout()
    plt.setp(ax.spines.values(), linewidth=2)
    ax.spines[['right', 'top']].set_visible(False)

    fig.savefig(f'{OUTPUT_DIR}/supp2_bottom.pdf', transparent=True)
    plt.close(fig)

def plot_benchmarks_ablation(y, value):
    """Plot the cell-type-accuracy ablation (SCHAF vs. its ablated variants) and save it.

    Args:
        y: Nested dict ``{dataset_key: {model_key: metric}}``; only model keys that
            contain 'schaf' are plotted.
        value: Text for the y-axis label.

    The figure is written to ``{OUTPUT_DIR}/celltype_ablation_paired.png`` and closed.

    Fix: bar groups are now looked up with the labels from ``ds_to_name``. The previous
    hard-coded label 'In-Sample Mouse' never matched ``ds_to_name['mouse']``, so every
    mouse bar fell back to 0. The former ``.replace('no', 'only')`` calls were dropped:
    they were case-sensitive no-ops on every label in ``ds_to_name`` / ``bm_to_name``.
    """
    dataset_keys = ['mouse', 'in_sample', 'out_of_sample']
    dataset_labels = [ds_to_name[key] for key in dataset_keys]

    # Regroup {dataset: {model: v}} into {model label: {dataset label: v}}.
    variant_to_scores = {}
    for ds_key, model_scores in y.items():
        ds_label = ds_to_name[ds_key]
        for model_key, score in model_scores.items():
            if 'schaf' not in model_key:
                continue
            variant_to_scores.setdefault(bm_to_name[model_key], {})[ds_label] = score

    variants = list(variant_to_scores)
    # A dataset without a value for a variant is drawn as a zero-height bar.
    values = np.array(
        [[variant_to_scores[v].get(label, 0) for label in dataset_labels] for v in variants]
    )

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    bar_width = 0.1
    x = np.arange(len(dataset_labels))
    for i, variant in enumerate(variants):
        ax.bar(x + i * bar_width, values[i], width=bar_width, label=variant)

    ax.set_xlabel('Dataset')
    ax.set_ylabel(value)
    # Centre each tick under its group of bars.
    ax.set_xticks(x + bar_width * (len(variants) - 1) / 2)
    ax.set_xticklabels(dataset_labels)
    ax.set_ylim(.4, 1.)
    ax.legend(loc='upper right')
    plt.setp(ax.spines.values(), linewidth=2)
    ax.spines[['right', 'top']].set_visible(False)
    plt.tight_layout()

    fig.savefig(f'{OUTPUT_DIR}/celltype_ablation_paired.png', dpi=400, transparent=True)
    plt.close(fig)

def plot_gene_correlation_distributions(y, ablation_only=False):
    """Plot per-gene correlation histograms, one panel per dataset, and save as PNG.

    Args:
        y: Nested dict ``{dataset_key: {model_key: {gene: correlation}}}``.
        ablation_only: If True, keep only model keys containing 'schaf' (SCHAF and its
            ablations). If False, drop every model key containing 'stage' (the
            ablations) and keep the rest.

    Raises:
        ValueError: if no dataset has any model left after filtering.

    Output file (under ``OUTPUT_DIR``):
        * ``dists_benchmarks_paired.png`` when ``ablation_only`` is False,
        * ``dists_benchmarks_paired_ablation.png`` when it is True.
      Previously both modes wrote to the same path, so calling the function in both
      modes (as the figure-generation cell does) silently overwrote the first figure.
    """
    filtered = {}
    for dataset_key, model_to_corrs in y.items():
        for model_key, corrs in model_to_corrs.items():
            if ablation_only and 'schaf' not in model_key:
                continue
            if not ablation_only and 'stage' in model_key:
                continue
            filtered.setdefault(dataset_key, {})[model_key] = corrs

    # Fixed panel order; datasets with nothing left to plot are skipped instead of
    # raising KeyError.
    dataset_order = [d for d in ['mouse', 'in_sample', 'out_of_sample'] if d in filtered]
    if not dataset_order:
        raise ValueError('No correlation data left to plot after filtering.')

    # squeeze=False keeps `axs` 2-D, so indexing also works with a single panel.
    fig, axs = plt.subplots(len(dataset_order), 1, figsize=(12, 18), squeeze=False)
    axs = axs[:, 0]

    for ax, dataset_name in zip(axs, dataset_order):
        model_to_corrs = filtered[dataset_name]
        model_names = list(model_to_corrs)
        ax.set_title(ds_to_name[dataset_name], fontsize=18)

        # Bins start at -0.2 for the out-of-sample panel and at 0 otherwise; values
        # outside the bin range are not drawn (they still count towards the mean).
        ax.hist(
            [sorted(model_to_corrs[name].values()) for name in model_names],
            bins=np.arange(-.2 if 'out' in dataset_name else 0, 1.01, .1),
            label=[
                f'{bm_to_name[name]} (Mean: {np.mean(list(model_to_corrs[name].values())):.4f})'
                for name in model_names
            ],
        )

        ax.set_xlabel('Spatial Correlation', size='x-large')
        ax.set_ylabel('Number of Genes', size='x-large')
        ax.legend(loc='upper right')
        plt.setp(ax.spines.values(), linewidth=2)

    # tight_layout recomputes the subplot spacing, so the earlier explicit hspace
    # adjustment had no lasting effect and is omitted.
    fig.tight_layout()

    filename = 'dists_benchmarks_paired_ablation.png' if ablation_only else 'dists_benchmarks_paired.png'
    fig.savefig(f'{OUTPUT_DIR}/{filename}', dpi=400, transparent=True)
    plt.close(fig)


In [None]:
# Calculate cell type accuracy metrics
def small_adj2(new_cts):
    """Merge cell type labels for out-of-sample data: 8 becomes 7 and 5 becomes 1.

    Each label is compared after ``int()`` conversion; labels that are not remapped
    are passed through unchanged. Returns a numpy array.
    """
    merged_into = {8: 7, 5: 1}
    return np.array([merged_into.get(int(label), label) for label in new_cts])

def small_adj(new_cts):
    """Merge cell type labels for in-sample data: label 10 becomes 9.

    Each label is compared after ``int()`` conversion; all other labels are passed
    through unchanged. Returns a numpy array.
    """
    adjusted = []
    for label in new_cts:
        adjusted.append(9 if int(label) == 10 else label)
    return np.array(adjusted)

# Load cell type predictions
def _fold_weighted_accuracy(gt_by_fold, pred_by_fold, fold_to_prop, adjust=None):
    """Cell-type accuracy averaged over folds, plus its standard error across folds.

    Args:
        gt_by_fold: {fold: AnnData} whose ``obs['broad_clusters']`` holds the true labels.
        pred_by_fold: {fold: array of predicted labels}.
        fold_to_prop: {fold: fraction of all cells in that fold}; used as weights.
        adjust: Optional label-merging function applied to both truth and prediction.

    Returns:
        (weighted_accuracy, standard_error). The error is
        sqrt(sum((acc_fold - weighted_accuracy)^2) / (n - 1) / n) with n the number of
        folds; this equals the previous hard-coded ``/ 3. / 4.`` for four folds.
    """
    acc = {}
    for fold in fold_to_prop:
        gt = np.array(gt_by_fold[fold].obs['broad_clusters'])
        pred = pred_by_fold[fold]
        if adjust is not None:
            gt, pred = adjust(gt), adjust(pred)
        acc[fold] = (pred == gt).sum() / pred.shape[0]
    weighted = sum(acc[fold] * fold_to_prop[fold] for fold in fold_to_prop)
    n_folds = len(acc)
    err = (sum((acc[fold] - weighted) ** 2 for fold in acc) / (n_folds - 1.) / n_folds) ** .5
    return weighted, err


# Saved label arrays: one file per model for the new sample, one per model and fold
# for the fold-based datasets. Paths are relative to the working directory.
load_the_cts = {}
for dataset in datasets:
    per_model = {}
    for benchmark in benchmarks + ['schaf']:
        if dataset == 'out_of_sample':
            per_model[benchmark] = np.load(f'benchmark_celltypes_saved/{dataset}_{benchmark}.npy')
        else:
            per_model[benchmark] = {
                fold: np.load(f'benchmark_celltypes_saved/{dataset}_{benchmark}_{fold}.npy')
                for fold in range(4)
            }
    load_the_cts[dataset] = per_model

# Calculate out-of-sample accuracy
# NOTE(review): `pred_out_of_sample_for_cts` and `ground_truth_out_of_sample_for_cts`
# are not defined in any visible cell; they must exist in the kernel before this runs.
# Positions (within the 'spirit' prediction's cells) of cells that also appear in
# pred_out_of_sample_for_cts. flatnonzero always yields an integer index array, even
# when nothing matches (np.array([]) would be float and unusable as an index).
good_inds = np.flatnonzero(
    dataset_to_benchmark_to_pred['out_of_sample']['spirit'].obs.index.isin(
        pred_out_of_sample_for_cts.obs.index
    )
)

# The ground-truth labels do not depend on the model, so compute them once.
out_of_sample_gt = small_adj2(np.array(ground_truth_out_of_sample_for_cts.obs['broad_clusters']))
benchmark_to_out_of_sample_ct_acc = {}
for benchmark in benchmarks + ['schaf']:
    pred = small_adj2(load_the_cts['out_of_sample'][benchmark][good_inds])
    benchmark_to_out_of_sample_ct_acc[benchmark] = (pred == out_of_sample_gt).sum() / pred.shape[0]

# Calculate in-sample accuracy (labels merged with small_adj)
benchmark_to_in_sample_ct_acc = {}
benchmark_to_in_sample_ct_err = {}
for benchmark in benchmarks + ['schaf']:
    ct_acc_res, ct_acc_err = _fold_weighted_accuracy(
        dataset_to_benchmark_to_pred['in_sample']['gt'],
        load_the_cts['in_sample'][benchmark],
        in_sample_fold_to_prop,
        adjust=small_adj,
    )
    benchmark_to_in_sample_ct_acc[benchmark] = ct_acc_res
    benchmark_to_in_sample_ct_err[benchmark] = ct_acc_err

# Calculate mouse accuracy (labels compared as stored, no merging)
benchmark_to_mouse_ct_acc = {}
benchmark_to_mouse_ct_err = {}
for benchmark in benchmarks + ['schaf']:
    ct_acc_res, ct_acc_err = _fold_weighted_accuracy(
        dataset_to_benchmark_to_pred['mouse']['gt'],
        load_the_cts['mouse'][benchmark],
        mouse_fold_to_prop,
    )
    benchmark_to_mouse_ct_acc[benchmark] = ct_acc_res
    benchmark_to_mouse_ct_err[benchmark] = ct_acc_err

# Organize results by dataset
dataset_to_benchmark_to_ct_acc = {
    'in_sample': benchmark_to_in_sample_ct_acc,
    'out_of_sample': benchmark_to_out_of_sample_ct_acc,
    'mouse': benchmark_to_mouse_ct_acc
}

# The new sample has no folds, hence no across-fold error: defaultdict(int) yields 0.
dataset_to_benchmark_to_ct_err = {
    'in_sample': benchmark_to_in_sample_ct_err,
    'out_of_sample': defaultdict(int),
    'mouse': benchmark_to_mouse_ct_err
}


In [None]:
# Generate paired benchmark figures
# Each call below saves its figure to a file under OUTPUT_DIR and then closes it.

# 1. Cell Type Accuracy Comparison (with per-dataset error bars)
plot_benchmarks_cell_type_accuracy(dataset_to_benchmark_to_ct_acc, 'Cell Type Accuracy', dataset_to_benchmark_to_ct_err)

# 2. Cell Type Accuracy Ablation Study
plot_benchmarks_ablation(dataset_to_benchmark_to_ct_acc, 'Cell Type Accuracy')

# 3. Gene Correlation Distribution Analysis
# Full comparison across all benchmarks
plot_gene_correlation_distributions(dataset_to_benchmark_to_corrs, ablation_only=False)

# 4. Gene Correlation Distribution Analysis - Ablation Study
# Only SCHAF and its ablations
plot_gene_correlation_distributions(dataset_to_benchmark_to_corrs, ablation_only=True)


In [None]:
# # SCHAF Figures Generation

# This notebook generates all figures for the SCHAF paper. It includes:
# - Data loading and preprocessing
# - Visium data processing
# - HTAPP/MSKCC/Xenium visualization
# - Program analysis and correlation computation
# - Expression transformation and cell type prediction

# The code is organized into the following sections:
# 1. Configuration and imports
# 2. Data loading utilities
# 3. Analysis functions
# 4. Visualization functions
# 5. Figure generation


In [None]:
# Standard libraries
import os
import sys
import numpy as np
import pandas as pd
import logging
import datetime
import json
import functools
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Set, Optional, Union

# Data processing
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from scipy.stats import gaussian_kde, zscore
import sklearn.metrics
from numba import njit, prange
from scipy.spatial import cKDTree

# Image processing
import PIL
from PIL import Image
import imageio.v3 as iio
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Configure system settings
# Raise Pillow's pixel-count limit so very large images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = 4017126500
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path constants
# NOTE(review): DATA_ROOT is an absolute, machine-specific path, and cells earlier in
# the notebook already use DATA_ROOT / OUTPUT_DIR — this cell must run before them.
DATA_ROOT = '/storage/ccomiter/schaf_for_revision052424'
# Parent of the current working directory.
BASE_DIR = os.path.dirname(os.getcwd())
OUTPUT_DIR = os.path.join(DATA_ROOT, 'final_figures_schaf_revision_pngs')
XENIUM_DIR = os.path.join(DATA_ROOT, 'data/xenium_cancer')
MOUSE_DIR = os.path.join(DATA_ROOT, 'data/mouse_pup')

# Sample constants
# Sample identifiers, kept as strings.
HTAPP_KEYS = ['4531', '7179', '7479', '7629', '932', '6760', '7149', '4381', '8239']
PLACENTA_SAMPLES = ['7', '8', '9', '11']
LUNG_CANCER_SAMPLES = ['139193', '138681', '146259', '117956']

# Gene mapping dictionaries
# Each maps a sample identifier to one lower-case gene name.
# NOTE(review): presumably the gene shown for that sample in a figure — confirm.
HTAPP_GENES = {
    '7149': 'sumf2',
    '932': 'sat1',
    '6760': 'commd1',
    '4381': 'ppp1r15a'
}

PLACENTA_GENES = {
    '7': 'cd74',
    '8': 'hla-dra',
    '9': 'cd68',
    '11': 'cd14'
}

LUNG_CANCER_GENES = {
    '139193': 'cd3d',
    '138681': 'cd8a',
    '146259': 'cd4',
    '117956': 'cd19'
}

# Processing constants
# Default cut-off used by load_and_process_visium_data: spots farther than this from
# the maximum 'pxl_col_in_fullres' value are dropped.
MAX_VERT_DIST = 10000
# 3x3 matrix with last row [0, 0, 1].
# NOTE(review): looks like a 2-D affine transform in homogeneous coordinates for
# Visium coordinates — confirm source and direction of the mapping.
VISIUM_TRANSFORM = np.array([
    [0.130157158, 2.594980119, -12243.84897],
    [-2.594980119, 0.130157158, 40352.06194],
    [0, 0, 1],
])

# Image region constants
# Pixel positions at which load_base_images splits each histology image into four
# fold regions ('y' splits rows, 'x' splits columns).
XENIUM_LIMITS = {'x': 17700, 'y': 12900}
MOUSE_LIMITS = {'x': 36500, 'y': 19500}

# Sample code mappings
# Placenta sample id -> image code; load_placenta_images reads '<code>.jpg'.
PLACENTA_CODE_MAP = {
    '7': 'JS34',
    '8': 'JS40',
    '9': 'JS35',
    '11': 'JS36'
}

# Figure parameters
# Shared export settings and font sizes (points) for the named matplotlib sizes.
FIGURE_PARAMS = {
    'dpi': 400,
    'transparent': True,
    'font_size': {
        'xx-large': 16,
        'x-large': 14,
        'large': 12
    }
}


In [None]:
# Correlation and visualization utility functions
def get_cell_corrs(pred, real):
    """Compute cell-wise (row-wise) Pearson correlations between prediction and truth.

    Args:
        pred: 2-D array-like; iterated row by row.
        real: 2-D array-like paired row-by-row with ``pred``.

    Returns:
        1-D numpy array with one correlation per row pair. An undefined correlation
        (NaN, e.g. when a row is constant) is reported as 0.

    ``np.corrcoef`` is evaluated once per pair (it used to be evaluated twice: once
    for the NaN check and once for the value), and the divide/invalid warnings raised
    for constant rows are silenced because that case is handled explicitly.
    """
    corrs = []
    with np.errstate(divide='ignore', invalid='ignore'):
        for p, r in zip(pred, real):
            corr = np.corrcoef(p, r)[0, 1]
            corrs.append(0 if np.isnan(corr) else corr)
    return np.array(corrs)

def get_gene_corrs(pred, real):
    """Compute gene-wise (column-wise) Pearson correlations between prediction and truth.

    Args:
        pred: 2-D array; columns are compared (the array is transposed internally).
        real: 2-D array paired column-by-column with ``pred``.

    Returns:
        1-D numpy array with one correlation per column pair. An undefined correlation
        (NaN, e.g. when a column is constant) is reported as 0.

    ``np.corrcoef`` is evaluated once per pair (it used to be evaluated twice), and
    the divide/invalid warnings raised for constant columns are silenced because that
    case is handled explicitly.
    """
    corrs = []
    with np.errstate(divide='ignore', invalid='ignore'):
        for p, r in zip(pred.T, real.T):
            corr = np.corrcoef(p, r)[0, 1]
            corrs.append(0 if np.isnan(corr) else corr)
    return np.array(corrs)

def gene_corr_graph(pred, true, path):
    """Save a 50-bin histogram of gene-wise correlations between `pred` and `true` to `path`."""
    fig = plt.figure(figsize=(10, 10))
    axes = fig.gca()
    axes.hist(get_gene_corrs(pred, true), bins=50)
    fig.savefig(path)
    plt.close(fig)

def cell_corr_graph(pred, true, path):
    """Save a 50-bin histogram of cell-wise correlations between `pred` and `true` to `path`."""
    fig = plt.figure(figsize=(10, 10))
    axes = fig.gca()
    axes.hist(get_cell_corrs(pred, true), bins=50)
    fig.savefig(path)
    plt.close(fig)

def make_cell_corr_vis(pred, real, nonzero_coords, nonzero_areas, orig_hist, path):
    """Overlay per-cell correlations on the histology image and save the figure.

    Args:
        pred: Predicted expression, one row per cell.
        real: Ground-truth expression, paired row-by-row with ``pred``.
        nonzero_coords: Array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y) giving the
            position of each cell.
        nonzero_areas: Marker sizes passed to ``scatter`` (``s=``).
        orig_hist: Image shown underneath the points.
        path: Output file path.

    Fix: the colorbar is now bound explicitly to the correlation scatter. A bare
    ``plt.colorbar()`` uses the most recently created mappable, which was the
    ``imshow`` image, so the bar did not describe the 'coolwarm' correlation colours.
    """
    cell_corrs = get_cell_corrs(pred, real)
    plt.figure(figsize=(10, 10))
    points = plt.scatter(nonzero_coords[:, 0], nonzero_coords[:, 1],
                         c=cell_corrs, cmap='coolwarm', alpha=0.5, s=nonzero_areas)
    plt.imshow(orig_hist)
    plt.colorbar(points)
    plt.savefig(path)
    plt.close()

def better_gene_corr_graph(pred, true, path):
    """Save a seaborn histogram (50 bins, with KDE curve) of gene-wise correlations to `path`."""
    gene_corrs = get_gene_corrs(pred, true)
    fig = plt.figure(figsize=(10, 10))
    sns.histplot(data=gene_corrs, bins=50, kde=True, ax=fig.gca())
    fig.savefig(path)
    plt.close(fig)

def better_cell_corr_graph(pred, true, path):
    """Save a seaborn histogram (50 bins, with KDE curve) of cell-wise correlations to `path`."""
    cell_corrs = get_cell_corrs(pred, true)
    fig = plt.figure(figsize=(10, 10))
    sns.histplot(data=cell_corrs, bins=50, kde=True, ax=fig.gca())
    fig.savefig(path)
    plt.close(fig)


In [None]:
# Model utility functions
def get_loss(the_net, criterion, dataloader, device, transforms=None, attention_fn=None, is_ae=False):
    """Return the mean per-batch loss of `the_net` over `dataloader`.

    The network is put in eval mode and gradients are disabled. Each batch is passed
    through `transforms` and then `attention_fn` when they are given. With
    ``is_ae=True`` the batch itself is both input and target; otherwise the batch is
    unpacked as ``(input, target)``. Batches are weighted equally regardless of size.
    An empty dataloader raises ZeroDivisionError.
    """
    the_net.eval()
    batch_losses = []

    with torch.no_grad():
        for batch in dataloader:
            for hook in (transforms, attention_fn):
                if hook is not None:
                    batch = hook(batch)
            if is_ae:
                inputs = batch.to(device)
                targets = inputs
            else:
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
            outputs = the_net(inputs)
            batch_losses.append(criterion(outputs, targets).item())

    return sum(batch_losses) / len(batch_losses)

def get_res(the_net, dataloader, device, transforms=None, with_spatial=False, attention_fn=None):
    """Run `the_net` over `dataloader` and collect predictions and labels as arrays.

    The network is put in eval mode and gradients are disabled. Each batch is passed
    through `transforms` and then `attention_fn` when they are given. Batches are
    unpacked as ``(x, y)``, or as ``(x, y, xs, ys)`` when ``with_spatial`` is True.

    Returns:
        ``(preds, labels)`` or, with ``with_spatial=True``, ``(preds, labels, xs, ys)``,
        each a numpy array built from the concatenated per-batch entries.
    """
    the_net.eval()
    collected_preds, collected_labels = [], []
    collected_xs, collected_ys = [], []

    with torch.no_grad():
        for batch in dataloader:
            for hook in (transforms, attention_fn):
                if hook is not None:
                    batch = hook(batch)
            if not with_spatial:
                inputs, labels = batch
            else:
                inputs, labels, xs, ys = batch
                collected_xs += list(xs.numpy())
                collected_ys += list(ys.numpy())
            outputs = the_net(inputs.to(device))
            collected_preds += list(outputs.cpu().numpy())
            collected_labels += list(labels.numpy())

    preds_arr = np.array(collected_preds)
    labels_arr = np.array(collected_labels)
    if not with_spatial:
        return preds_arr, labels_arr
    return preds_arr, labels_arr, np.array(collected_xs), np.array(collected_ys)

def get_spatial_corr_mean(the_net, indiv_dataloaders, device, spatial_gene_inds, transforms=None):
    """Average, over dataloaders, of the mean gene-wise correlation on selected genes.

    For each dataloader the network's predictions and labels are restricted to the
    columns in `spatial_gene_inds`, gene-wise correlations are computed and averaged;
    the unweighted mean of those per-dataloader averages is returned.
    """
    per_loader_means = []
    for loader in indiv_dataloaders:
        preds, labels = get_res(the_net, loader, device, transforms)
        gene_corrs = get_gene_corrs(preds[:, spatial_gene_inds], labels[:, spatial_gene_inds])
        per_loader_means.append(np.mean(gene_corrs))
    return np.mean(per_loader_means)


In [None]:
def set_fig_params():
    """Apply the notebook-wide matplotlib defaults and the seaborn 'whitegrid' style."""
    plt.rcParams.update({
        'figure.figsize': [10, 10],
        'figure.dpi': 100,
        'savefig.dpi': 300,
        'font.size': 12,
    })
    sns.set_style('whitegrid')
    
def load_and_process_visium_data(base_dir=BASE_DIR, max_vert_dist=MAX_VERT_DIST):
    """Load and process Visium spatial transcriptomics data.

    Reads ``visium_breast_xenium.h5`` and ``tissue_positions_visium.csv`` from
    ``<base_dir>/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium``.

    Args:
        base_dir: Directory that contains the ``htapp_supervise`` tree.
        max_vert_dist: Spots whose distance from the maximum 'pxl_col_in_fullres'
            value exceeds this are dropped.

    Returns:
        AnnData restricted to the kept spots, with
        * ``obs`` replaced by the tissue-position table plus 'x' (row pixel position
          shifted to start at 0) and 'y' (distance from the maximum column position),
        * ``X`` log1p-transformed and stored as a dense ndarray,
        * lower-cased gene names.

    Changes relative to the previous version:
        * positional selections use ``.iloc`` — ``series[int_array]`` on a
          barcode-indexed Series relies on pandas' deprecated positional fallback;
        * the subset is copied before new ``obs`` columns are written, so the write
          does not go through an AnnData view;
        * ``X`` is stored as an ndarray rather than ``np.matrix``, matching how the
          other loaders in this notebook densify their data.
    """
    visium_path = os.path.join(base_dir, 'htapp_supervise/new_schaf_experiment_scripts/more_data/xenium')
    visium_adata = sc.read_10x_h5(os.path.join(visium_path, 'visium_breast_xenium.h5'))
    visium_adata.var_names_make_unique()

    # Process tissue positions: align the position table to the AnnData barcodes.
    tissue_positions = pd.read_csv(
        os.path.join(visium_path, 'tissue_positions_visium.csv')
    ).set_index('barcode')
    tissue_positions = tissue_positions.loc[visium_adata.obs.index]
    visium_adata.obs = tissue_positions

    # Filter and process coordinates
    vert_dist = (visium_adata.obs['pxl_col_in_fullres'].max() -
                 visium_adata.obs['pxl_col_in_fullres'])
    good_vis_inds = np.where(vert_dist <= max_vert_dist)[0]

    visium_xs = visium_adata.obs['pxl_row_in_fullres'].iloc[good_vis_inds].values
    visium_xs = visium_xs - visium_xs.min()
    visium_ys = vert_dist.iloc[good_vis_inds].values

    visium_adata = visium_adata[good_vis_inds].copy()
    visium_adata.obs['x'] = visium_xs
    visium_adata.obs['y'] = visium_ys

    sc.pp.log1p(visium_adata)
    visium_adata.X = np.asarray(visium_adata.X.todense())
    visium_adata.var.index = [gene.lower() for gene in visium_adata.var.index]

    return visium_adata

def load_base_images():
    """Load the Xenium and mouse histology images and cut each into four fold regions.

    Returns:
        ``(in_sample_hist, out_of_sample, mouse_hist, in_sample_fold_to_hist,
        mouse_fold_to_hist)``. The two fold dicts map fold 0-3 to a slice of the
        corresponding image: 0 = rows/cols before the split point, 1 = rows/cols
        after it, 2 = rows after / cols before, 3 = rows before / cols after.
    """
    def _quadrants(image, limits, row_end, col_end):
        # Split at (limits['y'], limits['x']); row_end / col_end cap the far edges.
        row_cut, col_cut = limits['y'], limits['x']
        return {
            0: image[:row_cut, :col_cut],
            1: image[row_cut:row_end, col_cut:col_end],
            2: image[row_cut:row_end, :col_cut],
            3: image[:row_cut, col_cut:col_end],
        }

    in_sample_hist = iio.imread(os.path.join(XENIUM_DIR, 'xenium_hist.png'))
    out_of_sample = iio.imread(os.path.join(XENIUM_DIR, 'HE_other_sample_xenium.tif'))
    mouse_hist = iio.imread(os.path.join(MOUSE_DIR, 'Xenium_V1_mouse_pup_he_image.ome.tif'))

    # NOTE(review): the end bounds below are hard-coded — presumably the image sizes; confirm.
    in_sample_fold_to_hist = _quadrants(in_sample_hist, XENIUM_LIMITS, 25761, 35402)
    mouse_fold_to_hist = _quadrants(mouse_hist, MOUSE_LIMITS, 41081, 81654)

    return in_sample_hist, out_of_sample, mouse_hist, in_sample_fold_to_hist, mouse_fold_to_hist

def load_htapp_images(base_dir=BASE_DIR):
    """Read every ``.tif`` in the HTAPP histology directory.

    Returns a dict keyed by the file name up to its first '.', with the image array
    as value. The directory is
    ``<base_dir>/htapp_supervise/new_schaf_experiment_scripts/more_data/htapp_hists``.
    """
    hists_dir = os.path.join(base_dir, 'htapp_supervise/new_schaf_experiment_scripts/more_data/htapp_hists')
    tif_names = [name for name in os.listdir(hists_dir) if name.endswith('.tif')]
    return {
        name.split('.')[0]: iio.imread(os.path.join(hists_dir, name))
        for name in tif_names
    }

def load_placenta_images(image_dir='newest_bestest_placenta_hes'):
    """
    Read one ``<code>.jpg`` histology image per entry of PLACENTA_CODE_MAP.

    Args:
        image_dir (str): Directory holding the placenta JPEGs

    Returns:
        dict: sample ID -> image array
    """
    return {
        sample_id: iio.imread(os.path.join(image_dir, f'{code}.jpg'))
        for sample_id, code in PLACENTA_CODE_MAP.items()
    }

def load_sc_data():
    """Load the single-cell AnnData for every HTAPP sample key found in the data directory.

    The directory is located relative to the part of the working directory that
    precedes '/ccomiter/'. A file is used for a key when the key is a substring of
    its name; if several files match one key, the one listed last wins. For each
    match a new AnnData is built from ``obsm['counts']`` (densified, log1p applied),
    with genes taken from ``uns['counts_var']`` and lower-cased.

    Returns:
        dict: sample key -> processed AnnData
    """
    root = os.getcwd().split("/ccomiter/")[0]
    sc_dir = f'{root}/ccomiter/htapp_supervise/final_scs/schtapp'
    the_scs = {}

    for file in os.listdir(sc_dir):
        matching_keys = (key for key in HTAPP_KEYS if key in file)
        for k in matching_keys:
            raw = sc.read_h5ad(f'{sc_dir}/{file}')
            dense_counts = np.array(raw.obsm['counts'].todense())
            processed = sc.AnnData(X=dense_counts, obs=raw.obs)
            sc.pp.log1p(processed)
            processed.var.index = raw.uns['counts_var']
            processed.var.index = [gene.lower() for gene in processed.var.index]
            the_scs[k] = processed

    return the_scs

def load_merfish_data():
    """Load the MERFISH AnnData for every sample found in the MERFISH directory.

    The directory is located relative to the part of the working directory that
    precedes '/ccomiter/'. For each directory entry, the text before its first '_'
    is taken as the sample key and ``<key>_merfish.h5ad`` is read; ``X`` is replaced
    by the densified ``obsm['counts']``, log1p is applied and gene names are
    lower-cased.

    NOTE(review): the file that is read is derived from the key, not from the listed
    entry itself, so any entry that is not ``<key>_merfish.h5ad`` triggers a (re)read
    of that file or a missing-file error — confirm the directory only holds such files.

    Returns:
        dict: sample key -> processed AnnData
    """
    root = os.getcwd().split("/ccomiter/")[0]
    mers_dir = f'{root}/ccomiter/htapp_supervise/final_mers'
    merfish_htapp = {}

    for entry in os.listdir(mers_dir):
        key = entry.partition('_')[0]
        adata = sc.read_h5ad(os.path.join(mers_dir, f'{key}_merfish.h5ad'))
        adata.X = np.array(adata.obsm['counts'].todense())
        sc.pp.log1p(adata)
        adata.var.index = [gene.lower() for gene in adata.var.index]
        merfish_htapp[key] = adata

    return merfish_htapp


In [None]:
def compute_program_correlations_and_scores(ground_truth_data, prediction_data, gene_sets, all_cancer_programs, fold_to_prop=None):
    """
    Evaluate gene programs per gene set on folds 0-3, optionally averaging over folds.

    For every gene set and fold, predictions and ground truth are restricted to the
    set's genes and passed to ``get_prog_info(truth, pred, all_cancer_programs)``,
    whose two return values are stored as (scores, correlations).

    Args:
        ground_truth_data: Dict of ground truth AnnData objects by fold (0-3)
        prediction_data: Dict of prediction AnnData objects by fold (0-3)
        gene_sets: Dict mapping a gene-set name to the genes to select
        all_cancer_programs: Passed through to ``get_prog_info``
        fold_to_prop: Optional dict of per-fold weights

    Returns:
        ``(correlations, scores)``. Without ``fold_to_prop`` each is
        ``{gene_set: {fold: get_prog_info result}}``. With it each is
        ``{gene_set: {program: weighted sum over folds}}``, where the programs are the
        keys of fold 0's correlation result.
    """
    folds = range(4)
    per_fold_corrs = {}
    per_fold_scores = {}

    for set_name, set_genes in gene_sets.items():
        corrs_by_fold = {}
        scores_by_fold = {}
        for z in folds:
            pred_subset = prediction_data[z][::, set_genes]
            truth_subset = ground_truth_data[z][::, set_genes]
            scores_by_fold[z], corrs_by_fold[z] = get_prog_info(truth_subset, pred_subset, all_cancer_programs)
        per_fold_corrs[set_name] = corrs_by_fold
        per_fold_scores[set_name] = scores_by_fold

    if fold_to_prop is None:
        return per_fold_corrs, per_fold_scores

    averaged_corrs = {}
    averaged_scores = {}
    for set_name in gene_sets:
        programs = per_fold_corrs[set_name][0]
        set_corrs = {}
        set_scores = {}
        for prog in programs:
            set_corrs[prog] = sum(per_fold_corrs[set_name][z][prog] * fold_to_prop[z] for z in folds)
            set_scores[prog] = sum(per_fold_scores[set_name][z][prog] * fold_to_prop[z] for z in folds)
        averaged_corrs[set_name] = set_corrs
        averaged_scores[set_name] = set_scores

    return averaged_corrs, averaged_scores


In [None]:
def make_label(info, celltypes):
    """
    Build a membership vector over `celltypes` and scale it to sum to one.

    Args:
        info: Mapping-like record (e.g. a DataFrame row) that supports
            `ct in info` and `info[ct]`
        celltypes: Ordered cell type names; the output has one entry per name

    Returns:
        numpy array with a 1 for every cell type that is present in `info`
        with a truthy value, divided by the number of such entries. When no
        entry is set, the all-zero vector is returned as is.
    """
    flags = np.array([1 if (ct in info and info[ct]) else 0 for ct in celltypes])
    total = flags.sum()
    # Only normalize when at least one cell type is flagged (avoids 0/0).
    return flags / total if total > 0 else flags

def process_spatial_labels(merfish_data, histology_data, celltypes, distance_threshold=None):
    """
    Label each histology coordinate with the cell types of its nearest MERFISH cell.

    For every sample key in `merfish_data`, a KD-tree is built over the cells
    whose `celltypes` columns in `obs` sum to more than zero (all cells when
    those columns cannot be summed). Each histology point is matched to its
    nearest tree point and `make_label` is applied to that cell's `obs` row.

    Args:
        merfish_data: Dict of MERFISH AnnData objects; `obs` must provide
            'x' and 'y' columns
        histology_data: Dict (same keys) of objects exposing 'x' and 'y'
            coordinate columns
        celltypes: List of cell type column names in `obs`
        distance_threshold: Optional; when given, only matches with a
            nearest-neighbour distance strictly below it are kept

    Returns:
        Dict with keys:
            'labels': per-sample array of label vectors
            'distances': per-sample array of nearest-neighbour distances
            'coordinates': {'x': ..., 'y': ...} coordinates of the selected
                cells (only for samples where the selection succeeded)
            'selected_indices': per-sample positions of the selected cells
                (only for samples where the selection succeeded)

    Raises:
        ValueError: if a sample has no cells to build the KD-tree from
    """
    import warnings

    to_plot_xs = {}
    to_plot_ys = {}
    to_select = {}
    the_labels = {}
    the_dists = {}

    for k in merfish_data:
        obs = merfish_data[k].obs
        # Work on positional numpy arrays so that `select` (integer positions
        # from np.where) is never interpreted as index labels.
        cell_xs = obs['x'].to_numpy()
        cell_ys = obs['y'].to_numpy()

        try:
            select = np.where(obs[celltypes].sum(axis=1) > 0)[0]
        except Exception as err:
            # Best-effort fallback kept from the original: use every cell.
            # Unlike a bare `except:`, this lets KeyboardInterrupt through and
            # reports why the selection failed.
            warnings.warn(
                f'Could not select annotated cells for sample {k!r} ({err}); '
                'falling back to all cells'
            )
            select = np.arange(len(obs))
        else:
            to_select[k] = select
            to_plot_xs[k] = list(cell_xs[select])
            to_plot_ys[k] = list(cell_ys[select])

        if len(select) == 0:
            raise ValueError(f'Sample {k!r} has no cells to match histology points against')

        tree = cKDTree(np.column_stack([cell_xs[select], cell_ys[select]]))

        query_points = np.column_stack([
            np.asarray(histology_data[k]['x']),
            np.asarray(histology_data[k]['y']),
        ])
        dists, nearest = tree.query(query_points, k=1)

        # `nearest` indexes the points inside the tree, i.e. the `select`
        # subset. Map it back to row positions of the full obs table before
        # looking the cell up; indexing obs directly with the tree index
        # returns an unrelated cell whenever some cells were filtered out.
        obs_rows = select[nearest]
        labels = [make_label(obs.iloc[row], celltypes) for row in obs_rows]

        the_labels[k] = np.array(labels)
        the_dists[k] = np.array(dists)

        if distance_threshold is not None:
            mask = the_dists[k] < distance_threshold
            the_labels[k] = the_labels[k][mask]
            the_dists[k] = the_dists[k][mask]

    return {
        'labels': the_labels,
        'distances': the_dists,
        'coordinates': {'x': to_plot_xs, 'y': to_plot_ys},
        'selected_indices': to_select
    }


In [None]:
# Data loading functions
def load_placenta_images(image_dir='newest_bestest_placenta_hes'):
    """
    Read the four placenta histology JPEGs found in `image_dir`.

    Args:
        image_dir (str): Folder containing one `<code>.jpg` file per sample

    Returns:
        dict: sample ID (string) -> image as returned by `iio.imread`
    """
    # Module-level Pillow setting, so it stays in effect for every later
    # PIL image load in this session.
    PIL.Image.MAX_IMAGE_PIXELS = 4017126500

    # Sample ID -> filename stem of its JPEG.
    sample_to_code = {'7': 'JS34', '8': 'JS40', '9': 'JS35', '11': 'JS36'}

    images = {}
    for sample_id in tqdm(sample_to_code, desc="Loading placenta images"):
        code = sample_to_code[sample_id]
        images[sample_id] = iio.imread(os.path.join(image_dir, f'{code}.jpg'))

    return images

def load_and_process_visium_data(base_dir=BASE_DIR, max_vert_dist=10000):
    """
    Load the Visium breast sample, attach spot positions, crop and log-transform.

    Steps:
        1. Read `visium_breast_xenium.h5` and make gene names unique.
        2. Replace `obs` with the rows of `tissue_positions_visium.csv`
           (indexed by 'barcode') that match the spots in the count matrix.
        3. Keep spots whose distance from the largest 'pxl_col_in_fullres'
           value is at most `max_vert_dist`.
        4. Store 'x' (row pixel coordinate shifted to start at 0) and 'y'
           (that distance) in `obs`, then apply `sc.pp.log1p`.

    Args:
        base_dir (str): Base directory for data. NOTE(review): the default is
            evaluated when this function is defined, so `BASE_DIR` must exist
            before this cell runs.
        max_vert_dist (int): Maximum distance from the largest
            'pxl_col_in_fullres' value for a spot to be kept

    Returns:
        AnnData: Processed Visium data (an independent copy, not a view)
    """
    visium_path = os.path.join(base_dir, 'htapp_supervise/new_schaf_experiment_scripts/more_data/xenium')
    visium_adata = sc.read_10x_h5(os.path.join(visium_path, 'visium_breast_xenium.h5'))
    visium_adata.var_names_make_unique()

    tissue_positions = pd.read_csv(
        os.path.join(visium_path, 'tissue_positions_visium.csv')
    ).set_index('barcode')
    tissue_positions = tissue_positions.loc[visium_adata.obs.index]
    visium_adata.obs = tissue_positions

    # Plain numpy arrays: the spots are selected by integer position below,
    # and passing integer positions to Series[...] on a label index is
    # deprecated in pandas.
    col_px = visium_adata.obs['pxl_col_in_fullres'].to_numpy()
    row_px = visium_adata.obs['pxl_row_in_fullres'].to_numpy()

    vert_dist = col_px.max() - col_px
    good_vis_inds = np.where(vert_dist <= max_vert_dist)[0]

    visium_xs = row_px[good_vis_inds]
    visium_xs = visium_xs - visium_xs.min()
    visium_ys = vert_dist[good_vis_inds]

    # Copy so the obs assignments and log1p below act on a real object
    # instead of implicitly materializing an AnnData view.
    visium_adata = visium_adata[good_vis_inds].copy()
    visium_adata.obs['x'] = visium_xs
    visium_adata.obs['y'] = visium_ys

    sc.pp.log1p(visium_adata)
    return visium_adata

# Constants and mappings
# Column ('x') and row ('y') positions at which create_fold_dictionaries
# splits each histology image into four regions.
XENIUM_LIMITS = {'x': 17700, 'y': 12900}
MOUSE_LIMITS = {'x': 36500, 'y': 19500}

# Placenta sample ID -> code (the same pairs load_placenta_images uses as
# JPEG filename stems).
PLACENTA_CODE_MAP = {
    '7': 'JS34',
    '8': 'JS40',
    '9': 'JS35',
    '11': 'JS36'
}

# Create fold dictionaries for image regions
def create_fold_dictionaries(in_sample_hist, mouse_hist):
    """
    Cut each histology image into four rectangular regions keyed by fold.

    Fold layout (rows, columns) relative to the split point:
        0: top-left, 1: bottom-right, 2: bottom-left, 3: top-right

    Args:
        in_sample_hist: In-sample histology image (rows first, then columns)
        mouse_hist: Mouse histology image (rows first, then columns)

    Returns:
        tuple: (in_sample_fold_to_hist, mouse_fold_to_hist) dictionaries
    """
    def _split(image, limits, row_stop, col_stop):
        # The four regions share two row ranges and two column ranges.
        top = slice(None, limits['y'])
        bottom = slice(limits['y'], row_stop)
        left = slice(None, limits['x'])
        right = slice(limits['x'], col_stop)
        return {
            0: image[top, left],
            1: image[bottom, right],
            2: image[bottom, left],
            3: image[top, right],
        }

    # Hard-coded lower/right bounds of the bottom and right regions.
    in_sample_fold_to_hist = _split(in_sample_hist, XENIUM_LIMITS, 25761, 35402)
    mouse_fold_to_hist = _split(mouse_hist, MOUSE_LIMITS, 41081, 81654)

    return in_sample_fold_to_hist, mouse_fold_to_hist


In [None]:
# Function decorator for logging
def log_function_call(func):
    """
    Decorator that logs the start, success and failure of `func`.

    Messages go through the module-level `logging` functions (root logger).
    Exceptions raised by `func` are logged at ERROR level and re-raised
    unchanged; the return value of `func` is passed through.

    Args:
        func: Callable to wrap

    Returns:
        The wrapped callable, carrying `func`'s name and docstring
    """
    # Imported here so the decorator does not depend on cell order: the
    # notebook's visible import cells only bring in `from functools import
    # wraps`, never the `functools` name itself.
    import functools
    import logging

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logging.info(f'Calling {func.__name__}')
        try:
            result = func(*args, **kwargs)
            logging.info(f'Successfully completed {func.__name__}')
            return result
        except Exception as e:
            logging.error(f'Error in {func.__name__}: {str(e)}')
            raise
    return wrapper

# Safe figure saving function
@log_function_call
def safe_save_figure(fig, filepath, dpi=FIGURE_PARAMS['dpi'], transparent=FIGURE_PARAMS['transparent']):
    """
    Save a matplotlib figure, creating the target directory when needed.

    Args:
        fig: matplotlib figure object
        filepath: path to save the figure; may be a bare filename, in which
            case the figure is written to the current working directory
        dpi: dots per inch for the output (default read from FIGURE_PARAMS
            when this function is defined)
        transparent: whether to use transparent background (default read from
            FIGURE_PARAMS when this function is defined)

    Raises:
        Exception: any error from directory creation or `fig.savefig` is
            logged and re-raised unchanged
    """
    try:
        # os.path.dirname returns '' for a bare filename and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when there is one.
        out_dir = os.path.dirname(filepath)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save figure
        fig.savefig(filepath, dpi=dpi, transparent=transparent,
                   bbox_inches='tight')
        logging.info(f'Successfully saved figure to {filepath}')

    except Exception as e:
        logging.error(f'Error saving figure to {filepath}: {str(e)}')
        raise


In [None]:
# Utility functions
def trans_good(x):
    """Undo log1p (exp(x) - 1), round to the nearest integer, then re-apply log1p."""
    counts = np.exp(x) - 1.
    return np.log1p(counts.round())

@njit(parallel=True)
def transform_coordinates(horis, verts):
    """
    Copy two coordinate arrays element by element inside a numba `prange` loop.

    As written, the "transformation" is the identity: every output element
    equals the corresponding input element.

    Args:
        horis: Horizontal coordinates
        verts: Vertical coordinates (indexed over the length of `horis`)

    Returns:
        tuple: (new_horis, new_verts) arrays with the same shape and dtype
        as the inputs
    """
    out_horis = np.zeros_like(horis)
    out_verts = np.zeros_like(verts)

    n_points = len(horis)
    for idx in prange(n_points):
        out_horis[idx] = horis[idx]
        out_verts[idx] = verts[idx]

    return out_horis, out_verts


In [None]:
# Analysis functions
@log_function_call
def get_prog_info(input_true_adata, input_pred_adata, programs):
    """
    Score gene programs in predicted expression and correlate them with truth.

    Both inputs are z-scored per cell (across genes, axis=1). For every
    program with at least five genes present in the prediction's `var` index,
    the per-cell mean z-score over those genes is computed for prediction and
    truth; NaNs in these means are replaced by 0.

    Args:
        input_true_adata: AnnData object with true expression
        input_pred_adata: AnnData object with predicted expression
        programs: Dictionary mapping program names to gene lists

    Returns:
        tuple: (program_to_score, program_to_corr) dictionaries, where the
        score is the standard deviation of the predicted per-cell program
        means and the correlation is the Pearson correlation between
        predicted and true per-cell program means
    """
    def _zscored_per_cell(adata):
        # New AnnData with the same obs/var and row-wise standardized X.
        return sc.AnnData(X=zscore(adata.X, axis=1), obs=adata.obs, var=adata.var)

    pred_adata = _zscored_per_cell(input_pred_adata)
    true_adata = _zscored_per_cell(input_true_adata)

    program_to_score = {}
    program_to_corr = {}

    for program, program_genes in tqdm(programs.items()):
        shared_genes = np.intersect1d(program_genes, pred_adata.var.index)
        # Programs with fewer than five usable genes are left out of both results.
        if len(shared_genes) >= 5:
            pred_means = np.nan_to_num(pred_adata[::, shared_genes].X.mean(axis=1))
            true_means = np.nan_to_num(true_adata[::, shared_genes].X.mean(axis=1))

            program_to_corr[program] = np.corrcoef(pred_means, true_means)[0, 1]
            program_to_score[program] = pred_means.std()

    return program_to_score, program_to_corr

@log_function_call
def calculate_cell_type_means(data, cell_types):
    """
    Average the rows of `data` separately for every distinct cell type label.

    Args:
        data: Expression matrix (cells × genes)
        cell_types: Cell type labels, one per row of `data`

    Returns:
        dict: label -> mean expression vector (mean over the rows carrying
        that label)
    """
    return {
        label: np.mean(data[cell_types == label], axis=0)
        for label in np.unique(cell_types)
    }

@log_function_call
def calculate_heterogeneity(data, cell_types):
    """
    Compute the per-gene coefficient of variation within each cell type.

    Args:
        data: Expression matrix (cells × genes)
        cell_types: Cell type labels, one per row of `data`

    Returns:
        dict: label -> array of std / mean per gene over the rows carrying
        that label; genes whose mean is exactly 0 get a value of 0
    """
    cv_by_type = {}

    for label in np.unique(cell_types):
        subset = data[cell_types == label]

        gene_mean = np.mean(subset, axis=0)
        gene_var = np.var(subset, axis=0)
        gene_std = np.sqrt(gene_var)

        # Division is skipped where the mean is 0; those entries keep the
        # zeros of the pre-filled output buffer.
        cv_by_type[label] = np.divide(
            gene_std, gene_mean,
            out=np.zeros_like(gene_var),
            where=gene_mean != 0,
        )

    return cv_by_type


In [None]:
# Data Loading and Preprocessing Functions
@log_function_call
def load_data(data_path, sample_id):
    """
    Read the three `.npy` arrays stored for one sample.

    The files are expected at `<data_path>/<kind>_<sample_id>.npy` for the
    kinds 'ground_truth', 'predictions' and 'cell_types'.

    Args:
        data_path: Path to data directory
        sample_id: Sample identifier used in the filenames

    Returns:
        dict: arrays under the keys 'ground_truth', 'predictions' and
        'cell_types'
    """
    def _read(kind):
        return np.load(os.path.join(data_path, f'{kind}_{sample_id}.npy'))

    # Dict key and filename prefix are the same string for each array.
    return {kind: _read(kind) for kind in ('ground_truth', 'predictions', 'cell_types')}

@log_function_call
def preprocess_data(data_dict):
    """
    Row-normalize expression matrices and compute cell type frequencies.

    Each row of 'ground_truth' and 'predictions' is divided by its row sum.
    Rows that sum to exactly 0 are returned as all zeros instead of the NaNs
    that a plain 0/0 division would produce.

    Args:
        data_dict: Dictionary with 'ground_truth' and 'predictions'
            (2-D, cells × genes) and 'cell_types' (one label per cell)

    Returns:
        dict: with keys
            'ground_truth_norm': row-normalized ground truth
            'predictions_norm': row-normalized predictions
            'cell_types': the input labels, unchanged
            'type_frequencies': label -> fraction of cells with that label
    """
    def _row_normalize(matrix):
        matrix = np.asarray(matrix)
        totals = np.sum(matrix, axis=1, keepdims=True)
        # Substitute 1 for zero totals so the division itself never hits 0/0,
        # then force those rows to 0 in the result.
        safe_totals = np.where(totals == 0, 1, totals)
        return np.where(totals != 0, matrix / safe_totals, 0.0)

    ground_truth_norm = _row_normalize(data_dict['ground_truth'])
    predictions_norm = _row_normalize(data_dict['predictions'])

    # Calculate cell type frequencies
    unique_types, type_counts = np.unique(
        data_dict['cell_types'],
        return_counts=True
    )
    type_freqs = type_counts / len(data_dict['cell_types'])

    return {
        'ground_truth_norm': ground_truth_norm,
        'predictions_norm': predictions_norm,
        'cell_types': data_dict['cell_types'],
        'type_frequencies': dict(zip(unique_types, type_freqs))
    }


In [None]:
# Neural Network Model
class CellTypeClassifier(nn.Module):
    """
    Fully connected network for cell type classification.

    Architecture: for every entry of `hidden_sizes` a Linear -> BatchNorm1d ->
    ReLU stage, followed by a final Linear layer and a Softmax over dim=1.

    Because Softmax is the last layer, `forward` returns class probabilities,
    not logits. Losses that expect raw logits (e.g. `nn.CrossEntropyLoss`)
    would apply a second softmax on top of them.

    Args:
        input_size: Number of input features
        num_classes: Number of output classes
        hidden_sizes: Sizes of the hidden layers, in order. The default is an
            immutable tuple so it cannot be altered between instantiations.
    """
    def __init__(self, input_size, num_classes, hidden_sizes=(1024, 256, 64)):
        super().__init__()
        layers = []
        prev_size = input_size

        for size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, size),
                nn.BatchNorm1d(size),
                nn.ReLU()
            ])
            prev_size = size

        layers.extend([
            nn.Linear(prev_size, num_classes),
            nn.Softmax(dim=1)
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities of shape (batch, num_classes)."""
        return self.model(x)

@log_function_call
def train_celltype_classifier(orig_adata, pred_adata, celltype_name,
                            batch_size=128, epochs=10, lr=1e-3):
    """
    Train a cell type classifier on original data and predict on new data.

    Only genes shared by both AnnData objects are used. Features are
    standardized per gene (each dataset with its own mean and std), divided
    by 10 before entering the network, and the loss is weighted by inverse
    class frequency.

    Args:
        orig_adata: Original AnnData object with cell type labels
            (assumes a dense `.X` — TODO confirm for sparse inputs)
        pred_adata: AnnData object to predict on
        celltype_name: Name of cell type column in obs
        batch_size: Batch size for training and prediction
        epochs: Number of training epochs
        lr: Learning rate

    Returns:
        np.array: Predicted cell type annotations, one per row of `pred_adata`
    """
    common_var = np.intersect1d(orig_adata.var.index, pred_adata.var.index)
    annos = np.unique(orig_adata.obs[celltype_name])
    anno_to_label = dict(zip(annos, range(len(annos))))
    label_to_anno = dict(zip(range(len(annos)), annos))

    def _standardize(features):
        # Per-gene z-scoring; a constant gene gives 0/0 = NaN, which
        # nan_to_num turns into 0.
        scaled = (features - features.mean(axis=0)) / features.std(axis=0)
        return np.nan_to_num(scaled)

    # Prepare data
    orig_features = _standardize(orig_adata[::,common_var].X)
    orig_labels = np.array([anno_to_label[anno]
                           for anno in orig_adata.obs[celltype_name]])
    n_train = orig_features.shape[0]

    # Create model and optimizer
    model = CellTypeClassifier(common_var.shape[0], len(annos)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Compute class weights (inverse class frequency)
    class_weights = torch.tensor([
        (float(n_train) / np.sum(orig_labels==i))
        for i in range(len(annos))
    ]).float().to(device)

    # CellTypeClassifier ends in Softmax, so it outputs probabilities.
    # nn.CrossEntropyLoss expects raw logits and would apply log-softmax on
    # top of them; the equivalent weighted loss for probabilities is NLLLoss
    # on their logarithm.
    criterion = nn.NLLLoss(weight=class_weights)

    # BatchNorm1d cannot run in training mode on a batch of one sample, so
    # drop the trailing batch only when it would contain exactly one.
    drop_last = n_train > batch_size and n_train % batch_size == 1

    # Create data loaders
    train_loader = DataLoader(
        TensorDataset(
            torch.from_numpy(orig_features),
            torch.from_numpy(orig_labels)
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=6,
        pin_memory=True,
        drop_last=drop_last
    )

    # Training loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        for batch, labels in train_loader:
            batch = batch.to(device) / 10.
            labels = labels.to(device)

            optimizer.zero_grad()
            probs = model(batch.float())
            # Clamp before the log so an exact 0 probability cannot yield -inf.
            loss = criterion(torch.log(probs.clamp_min(1e-12)), labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Each loss.item() is already a batch mean, so average over the
        # number of batches rather than the number of samples.
        print(f'Epoch {epoch+1}, Loss: {epoch_loss/max(len(train_loader), 1):.6f}')

    # Prepare prediction data
    pred_features = _standardize(pred_adata[::,common_var].X)

    pred_loader = DataLoader(
        TensorDataset(torch.from_numpy(pred_features)),
        batch_size=batch_size,
        shuffle=False,
        num_workers=6,
        pin_memory=True
    )

    # Make predictions
    model.eval()
    predictions = []

    with torch.no_grad():
        for (batch,) in pred_loader:
            batch = batch.to(device) / 10.
            outputs = model(batch.float())
            predictions.extend(outputs.cpu().numpy().argmax(axis=1))

    # Convert predictions to annotations
    new_annos = np.array([label_to_anno[label] for label in predictions])

    # Cleanup
    torch.cuda.empty_cache()

    return new_annos


In [None]:
# Plotting Functions
@log_function_call
def plot_htapp_schematic(hist_dict, output_dir=OUTPUT_DIR):
    """
    Create and save HTAPP schematic figure.

    Images are laid out four per row in the iteration order of `hist_dict`.
    The grid has at least three rows (the original 3x4 layout) and grows by
    whole rows when more than twelve images are supplied. Grid cells without
    an image are hidden.

    Args:
        hist_dict (dict): Dictionary of histology images; each image must be
            a 3-D array because its first two axes are swapped before plotting
        output_dir (str): Directory to save output figure
    """
    n_cols = 4
    n_images = len(hist_dict)
    # Ceiling division; never fewer than the original three rows.
    n_rows = max(3, -(-n_images // n_cols))

    # Create figure and axes (9x9 inches for the default three rows)
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(9, 3 * n_rows),
                          gridspec_kw={'hspace': 0, 'wspace': 0})
    panels = ax.ravel()

    # Plot each histology image
    for panel, (k, v) in zip(panels, hist_dict.items()):
        # Swap the row and column axes for the displayed orientation
        flipped_v = np.transpose(v, (1, 0, 2))

        # Plot and format
        panel.imshow(flipped_v)
        panel.axis('off')
        panel.set_title(f'HTAPP {k}', pad=15, loc='center')

    # Hide grid cells that received no image so no empty frames are drawn
    for panel in panels[n_images:]:
        panel.axis('off')

    # Adjust layout
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Save figure, then release this specific figure
    safe_save_figure(fig, os.path.join(output_dir, 'htapp_schematic.png'))
    plt.close(fig)

@log_function_call
def plot_mskcc_schematic(image_dir, output_dir=OUTPUT_DIR, images_per_row=6):
    """
    Create and save MSKCC schematic figure.

    Every .jpg/.jpeg/.png/.bmp file in `image_dir` is drawn in a grid, in
    sorted filename order, titled with the first six characters of its name.
    Files whose name contains '133729' or '129477' are cropped to the middle
    third of their rows.

    Args:
        image_dir (str): Directory containing image files
        output_dir (str): Directory to save output figure
        images_per_row (int): Number of images per row in grid

    Raises:
        ValueError: if `image_dir` contains no matching image files
    """
    # Sorted so the panel order does not depend on the filesystem's listing order
    image_files = sorted(
        f for f in os.listdir(image_dir)
        if f.endswith(('.jpg', '.jpeg', '.png', '.bmp'))
    )
    if not image_files:
        raise ValueError(f'No image files found in {image_dir}')

    # Calculate grid dimensions
    n_images = len(image_files)
    n_cols = images_per_row
    n_rows = (n_images + n_cols - 1) // n_cols

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), squeeze=False)
    axes = axes.flatten()

    # Plot images
    for panel, image_file in tqdm(zip(axes, image_files), total=n_images,
                                  desc="Processing images"):
        # The context manager closes the file handle once the image is drawn
        with Image.open(os.path.join(image_dir, image_file)) as img:
            # Special processing for specific images
            if any(sample_id in image_file for sample_id in ['133729', '129477']):
                pixels = np.array(img)
                third = pixels.shape[0] // 3
                # Explicit end index: `-third` would select nothing when third == 0
                shown = pixels[third:pixels.shape[0] - third]
            else:
                shown = img
            panel.imshow(shown)

        panel.axis('off')
        panel.set_title(f'MSKCC {image_file[:6]}')

    # Clear unused axes
    for panel in axes[n_images:]:
        panel.axis('off')

    # Adjust and save
    plt.tight_layout()
    safe_save_figure(fig, os.path.join(output_dir, 'mskcc_schematic.png'))
    plt.close(fig)

@log_function_call
def compute_mean_zscore(adata, gene_set, truncate_at=5):
    """
    Compute a per-cell gene set score as the mean of truncated z-scores.

    Each gene is standardized across cells (axis=0), the z-scores are clipped
    to [-truncate_at, truncate_at], and the clipped values are averaged over
    the genes of the set for every cell.

    NOTE(review): the previous implementation standardized along axis=1
    (across the genes of the set within each cell) and then averaged along
    the same axis, which is ~0 for every cell unless clipping triggers. The
    per-gene standardization used here is the reading that makes the score
    informative — confirm it matches the intended figure.

    Args:
        adata (AnnData): Expression data
        gene_set (list): List of genes to analyze
        truncate_at (float): Value to truncate z-scores at

    Returns:
        np.array: Mean z-scores per cell, or None when none of the genes is
        present in `adata.var.index`
    """
    # Get genes present in data
    genes = np.intersect1d(gene_set, adata.var.index)
    if len(genes) == 0:
        return None

    # Densify if the matrix exposes `toarray` (e.g. scipy sparse)
    expr = adata[:, genes].X
    if hasattr(expr, 'toarray'):
        expr = expr.toarray()
    expr = np.asarray(expr, dtype=float)

    # Per-gene z-scores across cells; a gene that is constant across cells
    # yields NaN and is counted as 0 so it cannot turn every score into NaN.
    zscores = np.nan_to_num(zscore(expr, axis=0))

    # Truncate values
    zscores = np.clip(zscores, -truncate_at, truncate_at)

    # Calculate mean over the genes of the set
    return np.mean(zscores, axis=1)


In [None]:
# Image loading and processing functions
def load_htapp_images(hists_dir='/mounts/stultzlab03/ccomiter/htapp_supervise/final_data0315/hists_may_good'):
    """
    Read every non-segmentation PNG in `hists_dir`.

    A file is loaded when its name ends in '.png' and does not contain 'seg'.
    The dictionary key is the part of the filename before its first '.'.

    Args:
        hists_dir (str): Directory containing HTAPP histology images

    Returns:
        dict: Dictionary mapping sample IDs to histology images
    """
    images = {}

    for filename in tqdm(os.listdir(hists_dir), desc="Loading HTAPP images"):
        is_histology_png = filename.endswith('.png') and 'seg' not in filename
        if is_histology_png:
            sample_key = filename.split('.')[0]
            images[sample_key] = iio.imread(os.path.join(hists_dir, filename))

    return images

def load_placenta_images(image_dir='newest_bestest_placenta_hes'):
    """
    Load the placenta histology images stored as `<code>.jpg` in `image_dir`.

    Args:
        image_dir (str): Directory containing placenta images

    Returns:
        dict: Dictionary mapping sample IDs (strings) to histology images
    """
    # Pillow setting assigned on the PIL.Image module, so it remains in
    # effect after this function returns.
    PIL.Image.MAX_IMAGE_PIXELS = 4017126500

    # (sample ID, filename stem) pairs, in loading order
    id_code_pairs = [('7', 'JS34'), ('8', 'JS40'), ('9', 'JS35'), ('11', 'JS36')]
    code_by_id = dict(id_code_pairs)

    loaded = {}
    for sample_id in tqdm(code_by_id.keys(), desc="Loading placenta images"):
        jpeg_path = os.path.join(image_dir, f'{code_by_id[sample_id]}.jpg')
        loaded[sample_id] = iio.imread(jpeg_path)

    return loaded

def load_base_images():
    """
    Read the three base histology images from XENIUM_DIR and MOUSE_DIR.

    Returns:
        tuple: (in_sample_hist, out_of_sample, mouse_hist) images, read from
        'xenium_hist.png', 'HE_other_sample_xenium.tif' (both in XENIUM_DIR)
        and 'Xenium_V1_mouse_pup_he_image.ome.tif' (in MOUSE_DIR)
    """
    def _read(folder, filename):
        return iio.imread(os.path.join(folder, filename))

    return (
        _read(XENIUM_DIR, 'xenium_hist.png'),
        _read(XENIUM_DIR, 'HE_other_sample_xenium.tif'),
        _read(MOUSE_DIR, 'Xenium_V1_mouse_pup_he_image.ome.tif'),
    )


In [None]:
# Visualization functions
def plot_htapp_schematic(hist_dict, output_dir='final_figures_schaf_revision_pngs'):
    """
    Create and save HTAPP schematic figure.

    Images are laid out four per row in the iteration order of `hist_dict`.
    The grid has at least three rows (the original 3x4 layout) and grows by
    whole rows when more than twelve images are supplied. Grid cells without
    an image are hidden. The figure is written to
    `<output_dir>/htapp_schematic.png` at 400 dpi with a transparent
    background; `output_dir` is created when missing.

    Args:
        hist_dict (dict): Dictionary of histology images; each image must be
            a 3-D array because its first two axes are swapped before plotting
        output_dir (str): Directory to save output figure
    """
    n_cols = 4
    n_images = len(hist_dict)
    # Ceiling division; never fewer than the original three rows.
    n_rows = max(3, -(-n_images // n_cols))

    # Create figure and axes (9x9 inches for the default three rows)
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(9, 3 * n_rows),
                          gridspec_kw={'hspace': 0, 'wspace': 0})
    panels = ax.ravel()

    # Plot each histology image
    for panel, (k, v) in zip(panels, hist_dict.items()):
        # Swap the row and column axes for the displayed orientation
        flipped_v = np.transpose(v, (1, 0, 2))

        # Plot and format
        panel.imshow(flipped_v)
        panel.axis('off')
        panel.set_title(f'HTAPP {k}', pad=15, loc='center')

    # Hide grid cells that received no image so no empty frames are drawn
    for panel in panels[n_images:]:
        panel.axis('off')

    # Adjust layout
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # savefig does not create directories; an empty output_dir means the
    # current working directory and needs no creation.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save and release this specific figure
    fig.savefig(os.path.join(output_dir, 'htapp_schematic.png'),
                dpi=400, transparent=True)
    plt.close(fig)

def plot_mskcc_schematic(image_dir, output_dir='final_figures_schaf_revision_pngs', 
                        images_per_row=6):
    """
    Create and save MSKCC schematic figure.

    Every .jpg/.jpeg/.png/.bmp file in `image_dir` is drawn in a grid, in
    sorted filename order, titled with the first six characters of its name.
    Files whose name contains '133729' or '129477' are cropped to the middle
    third of their rows. The figure is written to
    `<output_dir>/mskcc_schematic.png` at 400 dpi with a transparent
    background; `output_dir` is created when missing.

    Args:
        image_dir (str): Directory containing image files
        output_dir (str): Directory to save output figure
        images_per_row (int): Number of images per row in grid

    Raises:
        ValueError: if `image_dir` contains no matching image files
    """
    # Sorted so the panel order does not depend on the filesystem's listing order
    image_files = sorted(
        f for f in os.listdir(image_dir)
        if f.endswith(('.jpg', '.jpeg', '.png', '.bmp'))
    )
    if not image_files:
        raise ValueError(f'No image files found in {image_dir}')

    # Calculate grid dimensions
    n_images = len(image_files)
    n_cols = images_per_row
    n_rows = (n_images + n_cols - 1) // n_cols

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), squeeze=False)
    axes = axes.flatten()

    # Plot images
    for panel, image_file in tqdm(zip(axes, image_files), total=n_images,
                                  desc="Processing images"):
        # The context manager closes the file handle once the image is drawn
        with Image.open(os.path.join(image_dir, image_file)) as img:
            # Special processing for specific images
            if any(sample_id in image_file for sample_id in ['133729', '129477']):
                pixels = np.array(img)
                third = pixels.shape[0] // 3
                # Explicit end index: `-third` would select nothing when third == 0
                shown = pixels[third:pixels.shape[0] - third]
            else:
                shown = img
            panel.imshow(shown)

        panel.axis('off')
        panel.set_title(f'MSKCC {image_file[:6]}')

    # Clear unused axes
    for panel in axes[n_images:]:
        panel.axis('off')

    # Adjust layout
    fig.tight_layout()

    # savefig does not create directories; an empty output_dir means the
    # current working directory and needs no creation.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save and release this specific figure
    fig.savefig(os.path.join(output_dir, 'mskcc_schematic.png'),
                dpi=400, transparent=True)
    plt.close(fig)


In [None]:
# Program analysis utilities
def compute_mean_zscore(adata, gene_set, truncate_at=5):
    """
    Compute a per-cell gene set score as the mean of truncated z-scores.

    Each gene is standardized across cells (axis=0), the z-scores are clipped
    to [-truncate_at, truncate_at], and the clipped values are averaged over
    the genes of the set for every cell.

    NOTE(review): the previous implementation standardized along axis=1
    (across the genes of the set within each cell) and then averaged along
    the same axis, which is ~0 for every cell unless clipping triggers. The
    per-gene standardization used here is the reading that makes the score
    informative — confirm it matches the intended figure.

    Args:
        adata (AnnData): Expression data
        gene_set (list): List of genes to analyze
        truncate_at (float): Value to truncate z-scores at

    Returns:
        np.array: Mean z-scores per cell, or None when none of the genes is
        present in `adata.var.index`
    """
    # Get genes present in data
    genes = np.intersect1d(gene_set, adata.var.index)
    if len(genes) == 0:
        return None

    # Densify if the matrix exposes `toarray` (e.g. scipy sparse)
    expr = adata[:, genes].X
    if hasattr(expr, 'toarray'):
        expr = expr.toarray()
    expr = np.asarray(expr, dtype=float)

    # Per-gene z-scores across cells; a gene that is constant across cells
    # yields NaN and is counted as 0 so it cannot turn every score into NaN.
    zscores = np.nan_to_num(zscore(expr, axis=0))

    # Truncate values
    zscores = np.clip(zscores, -truncate_at, truncate_at)

    # Calculate mean over the genes of the set
    return np.mean(zscores, axis=1)

# Constants for image regions
# 'x' is the column and 'y' the row at which create_fold_dicts splits an image.
XENIUM_LIMITS = {'x': 17700, 'y': 12900}
MOUSE_LIMITS = {'x': 36500, 'y': 19500}

# Create fold dictionaries
def create_fold_dicts(in_sample_hist, mouse_hist):
    """
    Split both histology images into four rectangular fold regions.

    Fold numbering relative to the split point: 0 = top-left,
    1 = bottom-right, 2 = bottom-left, 3 = top-right.

    Args:
        in_sample_hist: In-sample histology image (rows first, then columns)
        mouse_hist: Mouse histology image (rows first, then columns)

    Returns:
        tuple: (in_sample_fold_to_hist, mouse_fold_to_hist) dictionaries
    """
    xen_col, xen_row = XENIUM_LIMITS['x'], XENIUM_LIMITS['y']
    in_sample_fold_to_hist = dict(enumerate([
        in_sample_hist[:xen_row, :xen_col],
        in_sample_hist[xen_row:25761, xen_col:35402],
        in_sample_hist[xen_row:25761, :xen_col],
        in_sample_hist[:xen_row, xen_col:35402],
    ]))

    mouse_col, mouse_row = MOUSE_LIMITS['x'], MOUSE_LIMITS['y']
    mouse_fold_to_hist = dict(enumerate([
        mouse_hist[:mouse_row, :mouse_col],
        mouse_hist[mouse_row:41081, mouse_col:81654],
        mouse_hist[mouse_row:41081, :mouse_col],
        mouse_hist[:mouse_row, mouse_col:81654],
    ]))

    return in_sample_fold_to_hist, mouse_fold_to_hist


In [None]:
from sklearn.cluster import AgglomerativeClustering
import sys
import os 
import yaml
import json
from collections import defaultdict
import wandb
import matplotlib as mpl
import scipy
from scipy.stats import zscore
import imageio.v3 as iio
from PIL import Image
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 933120000

import datetime
import random 
import numpy as np 
import pandas as pd
import scipy as sp 
from tqdm import tqdm 
import sklearn
import sklearn.model_selection
import torch
import torchvision
from torchvision import transforms
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
import scanpy as sc 
from numba import njit, prange
from scipy.spatial import cKDTree
import seaborn as sns
from scipy import stats
from sklearn.metrics import r2_score
import anndata as ad
from typing import Dict, List, Tuple, Optional
import pickle

# Add necessary paths
sys.path.extend([
    ".", "..",
    "/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324",
])

from models import MerNet, JustPartTwo
from utils import *
from plot_utils import *

# Set up GPU
# Expose only the first GPU to this process; fall back to the CPU when CUDA
# is unavailable so that `.to(device)` calls still work on machines without a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# Configuration and global parameters
import os

# Data paths
# Absolute, machine-specific root; the four directories below are derived
# from it with os.path.join.
# NOTE(review): OUTPUT_DIR and FIGURE_PARAMS are used as default argument
# values by functions defined in earlier cells of this notebook. Defaults are
# evaluated when the `def` runs, so this cell has to be executed before them.
DATA_ROOT = '/storage/ccomiter/schaf_for_revision052424'
XENIUM_DIR = os.path.join(DATA_ROOT, 'data/xenium_cancer')
MOUSE_DIR = os.path.join(DATA_ROOT, 'data/mouse_data')
HTAPP_DIR = os.path.join(DATA_ROOT, 'data/htapp_data')
OUTPUT_DIR = os.path.join(DATA_ROOT, 'final_schaf_code/final_figures_schaf_revision_pngs')

# File paths
# Collection name -> absolute path of a JSON file
# (presumably gene program definitions — confirm against the code that loads them).
PROGRAM_FILES = {
    'hallmark': '/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324/hallmark_programs.json',
    'cancer': '/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324/cancer_programs.json'
}

# Absolute path of a clusters.csv file, stored under the key 'clusters'.
CELL_TYPE_FILES = {
    'clusters': '/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium/analysis/clustering/gene_expression_kmeans_10_clusters/clusters.csv'
}

# Visualization parameters
# 'dpi' and 'transparent' are the save defaults read by safe_save_figure;
# 'font_sizes' maps size names to numbers (presumably point sizes — TODO confirm).
FIGURE_PARAMS = {
    'dpi': 400,
    'transparent': True,
    'font_sizes': {
        'xx-large': 16,
        'x-large': 14,
        'large': 12,
        'medium': 10,
        'small': 8
    }
}

# Color maps for marker genes
# Lower-case gene name -> matplotlib colormap.
# NOTE(review): `plt` is not imported by name in the notebook's main import
# cell (only `matplotlib as mpl`); presumably it arrives through one of the
# star imports — confirm, otherwise this line raises NameError.
MARKER_GENE_CMAPS = {
    'krt19': plt.cm.Greens,
    'col1a2': plt.cm.YlOrBr,
    'apoc1': plt.cm.Blues,
    'pecam1': plt.cm.Reds
}

# HTAPP sample IDs
# Named groups of sample IDs (strings); the groups overlap.
HTAPP_SAMPLES = {
    'main': ['6760', '7149', '7179'],
    'extended': ['4531', '6760', '7479', '7629'],
    'all': ['7149', '7179', '932', '6760', '8239', '7629', '4531', '7479', '4381']
}

# Cell type mapping
# Broad category name -> list of the specific cell type labels grouped under it.
CELL_TYPE_MAPPING = {
    'broad_to_specific': {
        'Tumor': ['MBC', 'MBC_stem-like', 'MBC_neuronal', 'MBC_chondroid'],
        'Vascular': ['Endothelial', 'Endothelial_sinusoidal', 'Endothelial_angiogenic', 'Endothelial_vascular'],
        'Immune': ['Macrophage', 'Monocyte', 'Neutrophil', 'B', 'T', 'NK'],
        'Fibrosis': ['Fibroblast', 'Chondrocyte', 'Smooth muscle_vascular']
    }
}

# Create output directory
# Side effect of running this cell: OUTPUT_DIR (and missing parents) is created on disk.
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# Error handling and logging setup
import logging
import traceback
from datetime import datetime
from functools import wraps

# Set up logging
# One log file per execution of this cell, named with the current timestamp
# and placed in OUTPUT_DIR.
# `datetime` here is the class from `from datetime import datetime` above;
# that import rebinds the name `datetime`, which the notebook's main import
# cell binds to the module (`import datetime`).
log_file = os.path.join(OUTPUT_DIR, f'analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
# Root logger: INFO and above, sent both to the log file and to a stream handler.
# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers (e.g. when this cell is re-run in the same kernel) unless
# force=True is passed.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
# Module-level logger used by the helper functions defined in this cell.
logger = logging.getLogger(__name__)

def log_function_call(func):
    """
    Decorator that reports the start and completion of `func` on `logger`.

    When `func` raises, the error message and the formatted traceback are
    logged at ERROR level and the exception is re-raised unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        name = func.__name__
        logger.info(f"Starting {name}")
        try:
            outcome = func(*args, **kwargs)
            logger.info(f"Completed {name}")
        except Exception as err:
            logger.error(f"Error in {name}: {str(err)}")
            logger.error(traceback.format_exc())
            raise
        return outcome
    return wrapper

class AnalysisError(Exception):
    """Custom exception for analysis errors."""


def validate_data(data, name, requirements):
    """
    Validate input data against requirements.

    Supported requirement strings:
        'shape': `data` has a `shape` attribute
        'positive': no element is negative (zeros are accepted)
        'finite': no element is NaN or infinite
        'normalized': every row sums to 1 (within np.allclose tolerance)
    Any other requirement string is ignored.

    Dense arrays and scipy sparse matrices are both accepted. For sparse
    input, the 'positive' and 'finite' checks look only at the explicitly
    stored entries, since implicit zeros always pass them.

    Args:
        data: Array-like or scipy sparse matrix to check
        name: Label used in error messages
        requirements: Iterable of requirement strings

    Raises:
        AnalysisError: if `data` is None or any requirement is violated
    """
    if data is None:
        raise AnalysisError(f"{name} is None")

    def _element_values():
        # np.isfinite and `>=` cannot be applied to scipy sparse matrices
        # directly, so compare their stored values instead.
        from scipy.sparse import issparse
        return data.tocoo().data if issparse(data) else data

    for req in requirements:
        if req == 'shape':
            if not hasattr(data, 'shape'):
                raise AnalysisError(f"{name} has no shape attribute")
        elif req == 'positive':
            if not np.all(_element_values() >= 0):
                raise AnalysisError(f"{name} contains negative values")
        elif req == 'finite':
            if not np.all(np.isfinite(_element_values())):
                raise AnalysisError(f"{name} contains non-finite values")
        elif req == 'normalized':
            if not np.allclose(data.sum(axis=1), 1.0):
                raise AnalysisError(f"{name} is not normalized")

def check_file_exists(filepath):
    """Return ``filepath`` unchanged when it exists; otherwise raise FileNotFoundError."""
    if os.path.exists(filepath):
        return filepath
    raise FileNotFoundError(f"File not found: {filepath}")

@log_function_call
def safe_save_figure(fig, filename, **kwargs):
    """Save ``fig`` to ``filename``, creating the target directory if needed.

    Args:
        fig: Object with a ``savefig`` method (a matplotlib figure).
        filename: Destination path. May be a bare file name, in which case
            the figure is written to the current working directory.
        **kwargs: Forwarded to ``fig.savefig``.

    On success the figure is closed with ``plt.close`` to free memory. Any
    exception is logged and re-raised; the figure is left open in that case.
    """
    try:
        # os.path.dirname returns '' for a bare file name and os.makedirs('')
        # raises, so only create the directory when there is one to create.
        target_dir = os.path.dirname(filename)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Save the figure
        fig.savefig(filename, **kwargs)
        logger.info(f"Saved figure to {filename}")

        # Close the figure to free memory
        plt.close(fig)
    except Exception as e:
        logger.error(f"Error saving figure {filename}: {str(e)}")
        raise


In [None]:
# Data loading functions with validation and error handling

@log_function_call
def load_out_of_sample_data():
    """
    Read the out-of-sample ground-truth and prediction files from XENIUM_DIR.

    Both files are loaded first; afterwards each ``.X`` is checked with
    ``validate_data`` for the 'shape' and 'finite' requirements.

    Returns:
        tuple: (ground_truth_data, predicted_data), both AnnData objects.

    Raises:
        FileNotFoundError: If either file is missing.
        AnalysisError: If validation fails.
    """
    try:
        loaded = []
        for label, path in (
            ('ground_truth', os.path.join(XENIUM_DIR, 'ground_truth_out_of_sample.h5ad')),
            ('predictions', os.path.join(XENIUM_DIR, 'predictions_out_of_sample.h5ad')),
        ):
            # check_file_exists hands the path back when the file is present
            loaded.append((label, sc.read_h5ad(check_file_exists(path))))

        for label, adata in loaded:
            validate_data(adata.X, label, ['shape', 'finite'])

        (_, ground_truth), (_, predictions) = loaded
        return ground_truth, predictions

    except Exception as e:
        logger.error(f"Error loading out-of-sample data: {str(e)}")
        raise

@log_function_call
def load_in_sample_data():
    """
    Read ground-truth and prediction files for folds 0-3 from XENIUM_DIR.

    For each fold the two files are loaded and then their ``.X`` is checked
    with ``validate_data`` ('shape', 'finite') before moving to the next fold.

    Returns:
        tuple: (ground_truth_data, predicted_data), each a dict keyed by
            fold index with AnnData values.

    Raises:
        FileNotFoundError: If any file is missing.
        AnalysisError: If validation fails.
    """
    try:
        ground_truth, predictions = {}, {}

        for fold in range(4):
            gt_path = os.path.join(XENIUM_DIR, f'ground_truth_fold_{fold}.h5ad')
            ground_truth[fold] = sc.read_h5ad(check_file_exists(gt_path))

            pred_path = os.path.join(XENIUM_DIR, f'predictions_fold_{fold}.h5ad')
            predictions[fold] = sc.read_h5ad(check_file_exists(pred_path))

            validate_data(ground_truth[fold].X, f'ground_truth_fold_{fold}', ['shape', 'finite'])
            validate_data(predictions[fold].X, f'predictions_fold_{fold}', ['shape', 'finite'])

        return ground_truth, predictions

    except Exception as e:
        logger.error(f"Error loading in-sample data: {str(e)}")
        raise

@log_function_call
def load_mouse_data():
    """
    Read mouse ground-truth and prediction files for folds 0-3 from MOUSE_DIR.

    NOTE(review): a later cell defines another ``load_mouse_data`` that reads
    from hard-coded paths and returns ``(pred, ground_truth)`` - the opposite
    order. When both cells are run, the later definition replaces this one.

    Returns:
        tuple: (ground_truth_data, predicted_data), each a dict keyed by
            fold index with AnnData values.

    Raises:
        FileNotFoundError: If any file is missing.
        AnalysisError: If validation fails.
    """
    try:
        ground_truth, predictions = {}, {}

        for fold in range(4):
            gt_path = os.path.join(MOUSE_DIR, f'ground_truth_mouse_fold_{fold}.h5ad')
            ground_truth[fold] = sc.read_h5ad(check_file_exists(gt_path))

            pred_path = os.path.join(MOUSE_DIR, f'predictions_mouse_fold_{fold}.h5ad')
            predictions[fold] = sc.read_h5ad(check_file_exists(pred_path))

            validate_data(ground_truth[fold].X, f'ground_truth_mouse_fold_{fold}', ['shape', 'finite'])
            validate_data(predictions[fold].X, f'predictions_mouse_fold_{fold}', ['shape', 'finite'])

        return ground_truth, predictions

    except Exception as e:
        logger.error(f"Error loading mouse data: {str(e)}")
        raise

@log_function_call
def load_programs():
    """
    Load the hallmark and cancer program JSON files and merge them.

    Each file is expected to map a program name to an entry holding a
    'geneSymbols' list; gene symbols are lower-cased. When a program name
    occurs in both files, the cancer entry wins.

    Returns:
        dict: program name -> list of lowercase gene symbols.

    Raises:
        FileNotFoundError: If a program file is missing.
        json.JSONDecodeError: If JSON parsing fails.
    """
    try:
        sources = ('hallmark', 'cancer')

        # Read both files before processing either of them.
        raw = {}
        for kind in sources:
            check_file_exists(PROGRAM_FILES[kind])
            with open(PROGRAM_FILES[kind]) as handle:
                raw[kind] = json.load(handle)

        # Hallmark first, then cancer, so cancer overrides on name clashes.
        all_programs = {}
        for kind in sources:
            lowered = {}
            for prog_name, entry in raw[kind].items():
                lowered[prog_name] = [symbol.lower() for symbol in entry['geneSymbols']]
            all_programs.update(lowered)

        return all_programs

    except Exception as e:
        logger.error(f"Error loading programs: {str(e)}")
        raise


In [None]:
# Set up plotting parameters
# NOTE(review): set_fig_params() below resets rcParams to matplotlib defaults
# before applying its own theme, so figure.figsize, image.cmap and ps.fonttype
# set here presumably do not survive that call - confirm whether that is intended.
plt.rcParams['figure.figsize'] = 10, 10
plt.rcParams['image.cmap'] = 'gray'
# Font type 42 is presumably chosen so text in exported PDF/PS stays editable (TrueType) - confirm.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
# NOTE(review): these two repeat the assignments above through `mpl`; likely redundant.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Truthy -> set_fig_params uses the named large font sizes; falsy -> small numeric sizes.
do_large = 1
def set_fig_params():
    """Reset matplotlib rcParams to defaults and apply the notebook's seaborn theme.

    Font sizes depend on the module-level ``do_large`` flag: named large sizes
    when it is truthy, small numeric sizes otherwise. Major tick length is
    always 2.
    """
    plt.rcParams.update(mpl.rcParamsDefault)
    plt.rcParams['pdf.fonttype'] = 42

    weight = 550
    heading_size = 'xx-large' if do_large else 8
    tick_label_size = 'x-large' if do_large else 7
    legend_size = 'xx-large' if do_large else 7

    theme_rc = {
        'figure.autolayout': True,
        'axes.titlesize': heading_size,
        'axes.titleweight': weight,
        'figure.titleweight': weight,
        'figure.titlesize': heading_size,
        'axes.labelsize': heading_size,
        'axes.labelpad': 2,
        'axes.labelweight': weight,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.labelsize': tick_label_size,
        'ytick.labelsize': tick_label_size,
        'legend.fontsize': legend_size,
        'figure.figsize': (3.5, 3.5/1.6),
        'xtick.direction': 'out',
        'ytick.direction': 'out',
        'xtick.major.size': 2,
        'ytick.major.size': 2,
        'xtick.major.pad': 2,
        'ytick.major.pad': 2,
        'font.family': 'sans-serif',
        'legend.frameon': False,
    }
    sns.set(context='paper', style='ticks', rc=theme_rc)

# Apply the figure theme once for the cells that follow.
set_fig_params()
%matplotlib inline


In [None]:
# Load histology images and define fold regions
# The data root is derived from the current working directory: everything up to
# the first "/ccomiter/" path component is kept as the prefix.
xen_dir = f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium'
in_sample_hist = iio.imread(os.path.join(xen_dir, 'xenium_hist.png'))
out_of_sample = iio.imread(os.path.join(xen_dir, 'HE_other_sample_xenium.tif'))

# Define fold regions for in-sample histology
# The image is cut into four quadrants at row y_lim / column x_lim.
y_lim = 12900
x_lim = 17700

# Every `A if 0 else B` below always evaluates to B; A looks like an earlier
# bound kept for reference.
in_sample_fold_to_hist = {
    0: in_sample_hist[:y_lim, :x_lim],
    1: in_sample_hist[y_lim:25761, x_lim:35395 if 0 else 35402],
    2: in_sample_hist[y_lim:25755 if 0 else 25761, :x_lim],
    3: in_sample_hist[:y_lim, x_lim:35402],
}

# Load mouse histology and define fold regions
# NOTE: xen_dir, x_lim and y_lim are rebound here for the mouse image.
xen_dir = f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/all_xenium_new_data/mouse_pup_data'
mouse_hist = iio.imread(os.path.join(xen_dir, 'Xenium_V1_mouse_pup_he_image.ome.tif'))

x_lim = 36500
y_lim = 19500

mouse_fold_to_hist = {
    0: mouse_hist[:y_lim, :x_lim],
    1: mouse_hist[y_lim:41081, x_lim:81654],
    2: mouse_hist[y_lim:38295 if 0 else 41081, :x_lim],
    3: mouse_hist[:y_lim, x_lim:77230 if 0 else 81654],
}


In [None]:
# Function to compute program information
def get_prog_info(input_true_adata, input_pred_adata, programs):
    """Score gene programs on predicted vs. true expression.

    Both inputs are z-scored along axis 1 (per observation). For every program
    with at least 5 genes present in the predicted data, the per-observation
    mean of the program's genes is computed on both sides (NaNs replaced by 0).

    Args:
        input_true_adata: AnnData with ground-truth expression.
        input_pred_adata: AnnData with predicted expression.
        programs: dict mapping program name -> iterable of gene names.

    Returns:
        tuple: (program_to_score, program_to_corr) where the score is the
            standard deviation of the predicted program means and the corr is
            the Pearson correlation between predicted and true program means.
    """
    def _zscored_copy(adata):
        return sc.AnnData(X=zscore(adata.X, axis=1), obs=adata.obs, var=adata.var)

    pred_adata = _zscored_copy(input_pred_adata)
    true_adata = _zscored_copy(input_true_adata)

    scores = {}
    corrs = {}
    for program, program_genes in tqdm(programs.items()):
        shared_genes = np.intersect1d(program_genes, pred_adata.var.index)
        if len(shared_genes) >= 5:
            pred_means = np.nan_to_num(pred_adata[::, shared_genes].X.mean(axis=1))
            true_means = np.nan_to_num(true_adata[::, shared_genes].X.mean(axis=1))
            corrs[program] = np.corrcoef(pred_means, true_means)[0, 1]
            scores[program] = pred_means.std()
    return scores, corrs


In [None]:
def calculate_program_metrics(gene_sets, pred_in_sample, ground_truth_in_sample, all_cancer_programs):
    """Compute program correlations/scores per gene set and fold, plus fold-weighted means.

    Args:
        gene_sets: dict mapping a gene-set name -> genes used to subset the data.
        pred_in_sample: dict fold (0-3) -> predicted AnnData.
        ground_truth_in_sample: dict fold (0-3) -> ground-truth AnnData.
        all_cancer_programs: dict program name -> genes, passed to ``get_prog_info``.

    Returns:
        tuple: (per-fold corrs, per-fold scores, averaged corrs, averaged scores),
            each keyed by gene-set name. Averages weight each fold by its
            share of ground-truth observations and cover the programs
            reported for fold 0.
    """
    folds = range(4)

    per_fold_corrs = {}
    per_fold_scores = {}
    for name, genes in gene_sets.items():
        corrs_by_fold = {}
        scores_by_fold = {}
        for z in folds:
            pred_subset = pred_in_sample[z][::, genes]
            true_subset = ground_truth_in_sample[z][::, genes]
            scores_by_fold[z], corrs_by_fold[z] = get_prog_info(true_subset, pred_subset, all_cancer_programs)
        per_fold_corrs[name] = corrs_by_fold
        per_fold_scores[name] = scores_by_fold

    # Fold weights: fraction of all ground-truth observations in each fold
    total_cells = float(sum(ground_truth_in_sample[z].shape[0] for z in folds))
    fold_to_prop = {z: ground_truth_in_sample[z].shape[0] / total_cells for z in folds}

    avg_corrs = {}
    avg_scores = {}
    for name in gene_sets:
        corrs_here = per_fold_corrs[name]
        scores_here = per_fold_scores[name]
        name_corrs = {}
        name_scores = {}
        for prog in corrs_here[0]:
            name_corrs[prog] = sum(corrs_here[z][prog] * fold_to_prop[z] for z in folds)
            name_scores[prog] = sum(scores_here[z][prog] * fold_to_prop[z] for z in folds)
        avg_corrs[name] = name_corrs
        avg_scores[name] = name_scores

    return (per_fold_corrs, per_fold_scores, avg_corrs, avg_scores)


In [None]:
def make_spot_sum_function(a, b, r):
    """Average the expression of two datasets over a shared square grid of spots.

    A grid with spacing ``2 * r`` is laid over the joint x/y extent of ``a`` and
    ``b``; every observation is assigned to its nearest grid centre and the
    rows of ``X`` are averaged per spot.

    Args:
        a, b: Objects with ``obs[['x', 'y']]`` coordinates and an ``X`` matrix
            (numpy array or anything providing ``toarray()``).
        r: Half of the grid spacing.

    Returns:
        tuple: (spot_means_a, spot_means_b, centres), restricted to spots that
            received at least one observation from both datasets.
    """
    a_coords = a.obs[['x', 'y']].values
    b_coords = b.obs[['x', 'y']].values

    # Joint bounding box of both datasets
    all_coords = np.vstack((a_coords, b_coords))
    x_min, x_max = np.min(all_coords[:, 0]), np.max(all_coords[:, 0])
    y_min, y_max = np.min(all_coords[:, 1]), np.max(all_coords[:, 1])

    # Grid centres, x-major order
    grid_size = 2 * r
    half = grid_size / 2
    x_edges = np.arange(x_min, x_max + grid_size, grid_size)
    y_edges = np.arange(y_min, y_max + grid_size, grid_size)
    grid_centers = np.array([(x + half, y + half) for x in x_edges for y in y_edges])

    kdtree = KDTree(grid_centers)
    n_spots = len(grid_centers)

    def _accumulate(coords, adata):
        """Return per-spot expression sums and observation counts for one dataset."""
        sums = np.zeros((n_spots, adata.X.shape[1]))
        counts = np.zeros((n_spots, adata.X.shape[1]))
        _, nearest = kdtree.query(coords, k=1)
        nearest = np.asarray(nearest).reshape(-1)
        dense = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
        # Unbuffered accumulation: repeated spot indices add up row by row
        np.add.at(sums, nearest, dense)
        np.add.at(counts, nearest, 1)
        return sums, counts

    spot_sums_b, cspot_sums_b = _accumulate(b_coords, b)
    spot_sums_a, cspot_sums_a = _accumulate(a_coords, a)

    # Keep spots populated by both datasets
    valid_spots = np.any(cspot_sums_a > 0, axis=1) & np.any(cspot_sums_b > 0, axis=1)

    # Sums -> means; empty spots stay at zero
    spot_means_a = np.zeros_like(spot_sums_a)
    spot_means_b = np.zeros_like(spot_sums_b)
    np.divide(spot_sums_a, cspot_sums_a, where=cspot_sums_a > 0, out=spot_means_a)
    np.divide(spot_sums_b, cspot_sums_b, where=cspot_sums_b > 0, out=spot_means_b)

    return spot_means_a[valid_spots], spot_means_b[valid_spots], grid_centers[valid_spots]


In [None]:
def generate_spatial_plots(pred_data, true_data, cell_type_labels, distance_threshold=30):
    """Generate spatial plots with cell type annotations.

    For every key ``k`` of ``pred_data``:
      1. each observation's label row is read positionally from
         ``cell_type_labels[k]`` and its distance to the nearest observation
         of ``true_data[k]`` is computed;
      2. one category index is chosen per observation: the single truthy
         entry of its label row, a random truthy entry (``random.choice``)
         when there are several, or 0 when there is none;
      3. observations closer than ``distance_threshold`` whose chosen label
         entry is > 0 are scattered and saved to
         ``final_figures_schaf_revision_pngs/spatial_plot_cell_types_{k}.png``.

    Args:
        pred_data: dict key -> AnnData-like with ``obs`` columns 'x' and 'y'.
        true_data: dict with the same keys, used as distance reference.
        cell_type_labels: dict key -> object supporting ``.iloc[i]`` that
            yields an iterable of per-category values.
        distance_threshold: Maximum distance to the nearest reference point.

    Returns:
        dict: key -> integer positions of the observations that were plotted.
    """
    # Create color map for cell types
    colors = {
        'Tumor': '#73d56d',
        'Normal': '#f3a3f6',
        'Vascular': '#feb052',
        'Immune': '#99d1fe',
        'Fibrosis': '#ced208'
    }

    def make_Ramp(ramp_colors):
        from colour import Color
        from matplotlib.colors import LinearSegmentedColormap
        color_ramp = LinearSegmentedColormap.from_list('cell_types', [Color(c1).rgb for c1 in ramp_colors])
        return color_ramp

    # Process cell type labels
    the_labels = {}
    the_dists = {}

    for k in pred_data.keys():
        labels = []
        dists = []
        for ind in range(pred_data[k].shape[0]):
            label = cell_type_labels[k].iloc[ind]
            dd = compute_distance_to_nearest(pred_data[k].obs.iloc[ind], true_data[k])
            labels.append(label)
            dists.append(dd)
        the_labels[k] = np.array(labels)
        the_dists[k] = np.array(dists)

    # Generate plots
    k_to_where = {}
    for k in pred_data.keys():
        # NOTE(review): the aspect ratio is derived from n_vars / n_obs of the
        # data matrix rather than from spatial extent - confirm this is intended.
        fig, ax = plt.subplots(figsize=(5, 5 * pred_data[k].shape[1] / pred_data[k].shape[0]))
        ax.set_xticks([])
        ax.set_yticks([])

        # Process cell type assignments
        tcs = []
        for l in the_labels[k]:
            vls = [i for i, j in enumerate(l) if j]
            if len(vls) == 0:
                tcs.append(0)
            elif len(vls) == 1:
                tcs.append(vls[0])
            else:
                # Unseeded random tie-break: repeated runs can colour such cells differently
                tcs.append(random.choice(vls))
        tcs = np.array(tcs)

        # Filter points by distance threshold
        to_show = np.where((the_dists[k] < distance_threshold) & (np.array([i[j] for i, j in zip(the_labels[k], tcs)]) > 0))[0]
        k_to_where[k] = to_show

        # Create scatter plot
        # to_show holds integer positions from np.where, so select with .iloc;
        # plain [] on a label-indexed Series relies on pandas' deprecated
        # positional fallback.
        plt.scatter(
            pred_data[k].obs['y'].iloc[to_show],
            pred_data[k].obs['x'].iloc[to_show],
            c=tcs[to_show],
            s=10,
            cmap=make_Ramp(list(colors.values()))
        )

        ax.spines[['left', 'right', 'top', 'bottom']].set_visible(True)
        plt.setp(ax.spines.values(), linewidth=2)
        plt.gca().invert_yaxis()

        # Save figure
        fig_name = f'spatial_plot_cell_types_{k}'
        plt.savefig(f'final_figures_schaf_revision_pngs/{fig_name}.png', dpi=400, transparent=True)
        plt.close()

    return k_to_where

def compute_distance_to_nearest(point, reference_data):
    """Return the Euclidean distance from ``point`` to its closest reference observation.

    Args:
        point: Mapping-like with 'x' and 'y' entries.
        reference_data: Object whose ``obs`` has 'x' and 'y' columns.
    """
    target = np.array([point['x'], point['y']])
    offsets = reference_data.obs[['x', 'y']].values - target
    return np.min(np.sqrt(np.sum(offsets ** 2, axis=1)))


In [None]:
def generate_correlation_heatmap(pred_data, true_data, common_genes):
    """Generate correlation heatmaps between predicted and true data.

    For each of folds 0-3 the gene-gene correlation matrix is computed for the
    predictions and for the ground truth. Genes are split into two clusters
    using fold 0's ground truth, and that ordering is reused for every fold.
    The "metacorrelation" of a fold is the Pearson correlation between the
    flattened predicted and true matrices. Per-fold results are averaged with
    weights proportional to the number of ground-truth observations.

    The averaged heatmaps are saved side by side to
    ``final_figures_schaf_revision_pngs/correlation_heatmap.png``.

    Args:
        pred_data: dict fold -> AnnData with predicted expression.
        true_data: dict fold -> AnnData with ground-truth expression.
        common_genes: Genes used to subset both datasets.

    Returns:
        float: Fold-weighted average metacorrelation.
    """
    fold_to_pred_heatmap = {}
    fold_to_true_heatmap = {}
    fold_to_metacorr = {}

    for z in range(4):
        # Prepare data
        pred_arr = np.array(pred_data[z][::,common_genes].X)
        true_arr = np.array(true_data[z][::,common_genes].X.squeeze())

        # Calculate correlation matrices (columns are variables)
        pred_heatmap = np.corrcoef(pred_arr, rowvar=0)
        true_heatmap = np.corrcoef(true_arr, rowvar=0)

        # Cluster genes on first fold
        if z == 0:
            # Ward linkage only supports Euclidean distance, which is also the
            # default, so it is not passed explicitly: the `affinity` keyword
            # was removed from recent scikit-learn releases.
            hierarchical_cluster = AgglomerativeClustering(n_clusters=2, linkage='ward')
            labels = hierarchical_cluster.fit_predict(true_arr.T)
            c1_inds = [i for i, l in enumerate(labels) if l]
            c2_inds = [i for i, l in enumerate(labels) if not l]

        # Reorder heatmaps based on clustering
        pred_heatmap = pred_heatmap[c1_inds+c2_inds][::,c1_inds+c2_inds]
        true_heatmap = true_heatmap[c1_inds+c2_inds][::,c1_inds+c2_inds]

        # Store results
        fold_to_pred_heatmap[z] = pred_heatmap
        fold_to_true_heatmap[z] = true_heatmap
        fold_to_metacorr[z] = np.corrcoef(pred_heatmap.reshape(-1), true_heatmap.reshape(-1))[0, 1]

    # Calculate weighted average across folds
    fold_to_prop = {}
    total_cells = float(sum(true_data[z].shape[0] for z in range(4)))
    for z in range(4):
        fold_to_prop[z] = true_data[z].shape[0] / total_cells

    avg_pred_heatmap = sum([fold_to_pred_heatmap[z]*fold_to_prop[z] for z in range(4)])
    avg_true_heatmap = sum([fold_to_true_heatmap[z]*fold_to_prop[z] for z in range(4)])
    avg_metacorr = sum(fold_to_metacorr[z]*fold_to_prop[z] for z in range(4))

    # Plot heatmaps
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    sns.heatmap(avg_true_heatmap, ax=ax1, cmap='coolwarm', center=0)
    ax1.set_title('Ground Truth Correlation')
    ax1.set_xticks([])
    ax1.set_yticks([])

    sns.heatmap(avg_pred_heatmap, ax=ax2, cmap='coolwarm', center=0)
    ax2.set_title('Predicted Correlation')
    ax2.set_xticks([])
    ax2.set_yticks([])

    plt.suptitle(f'Average Metacorrelation: {avg_metacorr:.3f}')
    plt.tight_layout()

    # Save figure
    plt.savefig('final_figures_schaf_revision_pngs/correlation_heatmap.png', dpi=400, transparent=True)
    plt.close()

    return avg_metacorr

def plot_correlation_histogram(gene_correlations, title, filename):
    """Save a 50-bin histogram of ``gene_correlations`` with mean and median markers.

    The PNG is written to ``final_figures_schaf_revision_pngs/{filename}.png``.
    """
    plt.figure(figsize=(8, 5))
    plt.hist(gene_correlations, bins=50, edgecolor='black')
    plt.title(title)
    plt.xlabel('Correlation')
    plt.ylabel('Count')

    # Dashed vertical markers: red for the mean, green for the median
    mean_corr = np.mean(gene_correlations)
    plt.axvline(x=mean_corr, color='r', linestyle='--', label=f'Mean: {mean_corr:.3f}')
    median_corr = np.median(gene_correlations)
    plt.axvline(x=median_corr, color='g', linestyle='--', label=f'Median: {median_corr:.3f}')

    plt.legend()
    plt.tight_layout()
    plt.savefig(f'final_figures_schaf_revision_pngs/{filename}.png', dpi=400, transparent=True)
    plt.close()


In [None]:
# Mouse-specific gene lists and program loading
# Lowercase gene symbols. How this list is consumed is not visible in this cell;
# presumably these are the genes used when training the mouse model - confirm.
mouse_train_genes = [
    'hmgcs2', 'krt19', 'sdc1', 'scgb3a2', 'dbpht2', 'try10', 'hrc', 
    'pck1', 'gpx6', 'reg3g', 'tubb3', 'sypl2', 'rbp1', 'slc14a2', 'cd5l', 
    'crct1', 'anxa8', 'slc4a4', '1110017d15rik', 'nupr1', 'gm94', 'higd1b', 
    'cxcr2', 'tmem59l', 'rsad2', 'podxl', 'aqp1', 'cd8a', 'folr2', 'des', 'col8a1', 
    'aspn', 'kl', 'nap1l5', 'cubn', 'prx', 'epcam', 'fxyd6', 'aldh1b1', 'gc', 'fcnb', 
    'etv1', 'ppargc1a', 'ascl1', 'myoz2', 'cnfn', 'hpx', 'a330076h08rik', 'aqp5', 
    'prss3', 'myoz1', 'dcdc2a', 'emb', 'ucma', 'pgam2', 'mlc1', 'ifitm6', 'scin', 
    'otor', 'tagln', 'cldn3', 'mpo', 'angpt2', 'ms4a6c', 'aqp3', 'ccdc153', 'slc17a7'
]

# Load and preprocess mouse data
def load_mouse_data():
    """Load mouse predictions and ground truth for folds 0-3 from fixed paths.

    Ground-truth matrices are log1p-transformed in place and converted from
    sparse to dense numpy arrays.

    NOTE(review): an earlier cell already defines ``load_mouse_data`` (reading
    from MOUSE_DIR and returning ``(ground_truth, predictions)``). This
    definition replaces it and returns the pair in the opposite order.

    Returns:
        tuple: (pred_mouse, ground_truth_mouse), each a dict fold -> AnnData.
    """
    folds = range(4)

    pred_mouse = {}
    for z in folds:
        pred_mouse[z] = sc.read_h5ad(f'/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/mouse_inferences/fold_{z}.h5ad')

    ground_truth_mouse = {}
    for z in folds:
        ground_truth_mouse[z] = sc.read_h5ad(f'/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/mouse_folds/fold_{z}_st.h5ad')

    # log1p first, then densify
    for adata in ground_truth_mouse.values():
        sc.pp.log1p(adata)
        adata.X = np.array(adata.X.todense())

    return pred_mouse, ground_truth_mouse

# Calculate mouse-specific correlations
def calculate_mouse_correlations(pred_mouse, ground_truth_mouse, common_mouse):
    """Per-gene prediction quality on the mouse folds, averaged across folds.

    For each fold 0-3 and each gene in ``common_mouse`` the Pearson correlation
    between true and predicted expression and the standard deviation of the
    prediction are computed. Folds are then combined with weights proportional
    to their number of ground-truth observations.

    Returns:
        tuple: (gene -> weighted mean correlation, gene -> weighted mean std).
    """
    folds = range(4)

    corrs_by_fold = {}
    scores_by_fold = {}
    for z in folds:
        fold_corrs, fold_scores = {}, {}
        for g in common_mouse:
            truth = np.array(ground_truth_mouse[z][::, g].X.squeeze())
            pred = np.array(pred_mouse[z][::, g].X.squeeze())
            fold_corrs[g] = np.corrcoef(truth, pred)[0, 1]
            fold_scores[g] = pred.std()
        corrs_by_fold[z] = fold_corrs
        scores_by_fold[z] = fold_scores

    # Weight of a fold = its share of all ground-truth observations
    total_cells = float(sum(ground_truth_mouse[z].shape[0] for z in folds))
    weights = {z: ground_truth_mouse[z].shape[0] / total_cells for z in folds}

    avg_corrs = {}
    avg_scores = {}
    for g in common_mouse:
        avg_corrs[g] = sum(corrs_by_fold[z][g] * weights[z] for z in folds)
        avg_scores[g] = sum(scores_by_fold[z][g] * weights[z] for z in folds)

    return avg_corrs, avg_scores


In [None]:
# Spatial visualization functions
def do_norm(x):
    """Min-max scale ``x`` to the [0, 1] range using its own extremes."""
    lowest = x.min()
    span = x.max() - lowest
    return (x - lowest) / span

def do_norm2(x, min_val, max_val):
    """Min-max scale ``x`` with caller-supplied bounds: ``min_val`` maps to 0, ``max_val`` to 1."""
    span = max_val - min_val
    return (x - min_val) / span

def plot_spatial_gene_expression(pred_data, true_data, genes, sample_keys, output_dir='final_figures_schaf_revision_pngs'):
    """Generate spatial plots for specific genes with custom colormaps.

    For every sample key one figure with two rows is produced: ground truth on
    top, predictions below, one column per gene. Expression is min-max scaled
    per gene and per dataset before plotting, and the Pearson correlation of
    the scaled vectors is shown in the column title.

    Args:
        pred_data: dict key -> AnnData with predictions; ``obs`` must contain
            'x_spot' and 'y_spot'.
        true_data: dict key -> AnnData with ground truth, same layout.
        genes: Genes to plot. Genes without a dedicated colormap fall back to
            viridis.
        sample_keys: Keys of the samples to plot.
        output_dir: Directory receiving ``spatial_gene_expression_{key}.png``.
    """
    g_to_cmaps = {
        'krt19': mpl.cm.Greens,  # MBC marker
        'col1a2': mpl.cm.YlOrBr,  # Fibroblast marker
        'apoc1': mpl.cm.Blues,   # Macrophage marker
        'pecam1': mpl.cm.Reds    # Endothelial/vascular marker
    }
    default_cmap = mpl.cm.viridis

    for k in sample_keys:
        # squeeze=False keeps `axs` two-dimensional even for a single gene, so
        # the axs[row, col] indexing below always works.
        fig, axs = plt.subplots(2, len(genes),
                               figsize=(5 * len(genes), 10),
                               squeeze=False)

        for i, g in enumerate(genes):
            this_cmap = g_to_cmaps.get(g, default_cmap)
            ax_pred = axs[1, i]
            ax_true = axs[0, i]

            # Set up axes
            ax_pred.set_aspect('auto')
            ax_true.set_aspect('auto')
            if i == 0:
                ax_pred.set_ylabel("SCHAF\nInferred", size='x-large')
                ax_true.set_ylabel("Ground Truth", size='x-large')
            ax_pred.set_xticks([])
            ax_pred.set_yticks([])
            ax_true.set_xticks([])
            ax_true.set_yticks([])

            # Get gene expression data
            t = np.array(true_data[k][::, g].X).squeeze()
            p = np.array(pred_data[k][::, g].X).squeeze()

            # Normalize data
            t = do_norm(t)
            p = do_norm(p)
            # After do_norm both vectors normally span [0, 1], so this second
            # rescaling onto their joint range leaves them unchanged.
            the_min = min(t.min(), p.min())
            the_max = max(t.max(), p.max())
            p = do_norm2(p, the_min, the_max)
            t = do_norm2(t, the_min, the_max)

            # Calculate correlation
            corr = np.corrcoef(p, t)[0, 1]
            ax_true.set_title(f'{g.upper()}\nSpatial Corr. = {corr:.3f}')

            # Plot data (spot coordinates are binned by integer division by 50)
            ax_true.scatter(
                true_data[k].obs['y_spot'] // 50,
                true_data[k].obs['x_spot'] // 50,
                c=t,
                s=60,
                vmin=0,
                vmax=1,
                cmap=this_cmap
            )

            ax_pred.scatter(
                pred_data[k].obs['y_spot'] // 50,
                pred_data[k].obs['x_spot'] // 50,
                c=p,
                s=60,
                vmin=0,
                vmax=1,
                cmap=this_cmap
            )

        plt.tight_layout()
        plt.savefig(f'{output_dir}/spatial_gene_expression_{k}.png',
                   dpi=400, transparent=True)
        plt.close()

def plot_spatial_correlation_stats(k_to_corrs, output_dir='final_figures_schaf_revision_pngs'):
    """Save three bar charts summarising per-sample spatial correlations.

    Args:
        k_to_corrs: dict sample -> dict gene -> correlation.
        output_dir: Directory receiving 'spatial_corrs_avg', 'spatial_corrs_count'
            and 'spatial_corrs_percent' as both PNG and PDF.

    NOTE(review): genes are counted with ``>= .4`` although the axis labels say
    "> .4", and the "percentage" chart plots fractions in [0, 1] against ticks
    running from 0 to 35 - confirm both are intended.
    """
    def plot_bar_from_dict(data_dict, title, y_label, filename):
        # Bars ordered by ascending value
        names = sorted(data_dict.keys(), key=data_dict.get)
        values = [data_dict[name] for name in names]

        fig, ax = plt.subplots(figsize=(7, 5))
        plt.bar(names, values)
        plt.ylabel(y_label)
        plt.xlabel("Sample")
        plt.title(title, loc='left')
        plt.xticks(rotation=90)

        # Fixed tick range: 0-0.4 for the average chart, 0-35 for the others
        if y_label == 'Average Spatial Correlation':
            tick_values = np.arange(0, .41, .1)
        else:
            tick_values = np.arange(0, 36, 5)
        plt.yticks(ticks=tick_values, size='x-large', labels=tick_values.round(2))

        plt.tight_layout()
        plt.setp(ax.spines.values(), linewidth=2)
        for extension, extra in (('png', {'dpi': 400}), ('pdf', {})):
            plt.savefig(f'{output_dir}/{filename}.{extension}', transparent=True, **extra)
        plt.close()

    def _n_good(gene_to_corr):
        return np.sum(np.array(list(gene_to_corr.values())) >= .4)

    # Plot average correlations
    k_to_avg_corr = {}
    for k, v in k_to_corrs.items():
        k_to_avg_corr[k] = np.mean(list(v.values()))
    plot_bar_from_dict(k_to_avg_corr, 'Average Spatial Correlation by Sample',
                      'Average Spatial Correlation', 'spatial_corrs_avg')

    # Plot number of well-correlated genes
    k_to_num_good_genes = {}
    for k, v in k_to_corrs.items():
        k_to_num_good_genes[k] = _n_good(v)
    plot_bar_from_dict(k_to_num_good_genes, 'Well-Correlated Genes by Sample',
                      'Number of Genes\nwith Corr. > .4', 'spatial_corrs_count')

    # Plot percentage of well-correlated genes
    k_to_percent_good_genes = {}
    for k, v in k_to_corrs.items():
        k_to_percent_good_genes[k] = _n_good(v) / float(len(v))
    plot_bar_from_dict(k_to_percent_good_genes, 'Percentage of Well-Correlated Genes',
                      'Percentage of Genes\nwith Corr. > .4', 'spatial_corrs_percent')


In [None]:
# Cell type analysis functions
def get_hetero(cells):
    """Return, per gene, the standard deviation of its range-scaled expression.

    Values are shifted by the gene's maximum and divided by its range. The
    shift does not affect the standard deviation, so the score equals the
    std of min-max scaled expression.
    """
    def _scaled_std(values):
        scaled = (values - values.max()) / (values.max() - values.min())
        return scaled.std()

    return {g: _scaled_std(np.array(cells[::, g].X).squeeze()) for g in cells.var.index}

def load_cell_type_labels():
    """Load the k-means cluster table and the per-fold inferred label arrays.

    Returns:
        tuple: (broad_clusters, xenium_labels, mouse_labels) where
            ``broad_clusters`` is the clusters CSV indexed by 'Barcode' and the
            other two are dicts fold (0-3) -> array loaded from
            ``cancer_fold_to_new_labels/{fold}.npy`` and
            ``mouse_fold_to_new_labels/{fold}.npy`` (relative paths).
    """
    clusters_csv = '/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium/analysis/clustering/gene_expression_kmeans_10_clusters/clusters.csv'
    broad_clusters = pd.read_csv(clusters_csv).set_index('Barcode')

    xenium_labels, mouse_labels = {}, {}
    for fold in range(4):
        xenium_labels[fold] = np.load(f'cancer_fold_to_new_labels/{fold}.npy')
        mouse_labels[fold] = np.load(f'mouse_fold_to_new_labels/{fold}.npy')

    return broad_clusters, xenium_labels, mouse_labels

def calculate_cell_type_means(pred_data, true_data, cell_type_labels, common_genes):
    """Mean expression over ``common_genes`` for each cell type 1-10.

    The same label vector selects rows in both datasets; cell types without
    any matching observation are omitted from both results.

    Returns:
        tuple: (ct_to_pred_mean, ct_to_true_mean), dicts keyed by cell type.
    """
    pred_means = {}
    true_means = {}

    for ct in range(1, 11):
        members = cell_type_labels == ct
        if np.sum(members) > 0:
            pred_means[ct] = pred_data[members, common_genes].X.mean(axis=0)
            true_means[ct] = true_data[members, common_genes].X.mean(axis=0)

    return pred_means, true_means

def calculate_global_stats(ground_truth_data, common_genes):
    """Pooled per-gene mean and standard deviation over all samples.

    Sample means are combined weighted by sample size. The pooled variance
    adds each sample's own variance to the squared offset of its mean from
    the pooled mean.

    Args:
        ground_truth_data: dict sample -> AnnData.
        common_genes: Genes to subset to.

    Returns:
        tuple: (all_means, all_stds), arrays with one entry per gene.
    """
    n_genes = len(common_genes)

    # Pass 1: size-weighted mean
    weighted_sum = np.zeros(n_genes)
    total_samples = 0
    for key, sample in ground_truth_data.items():
        n_obs = sample.shape[0]
        weighted_sum = weighted_sum + n_obs * (ground_truth_data[key][::, common_genes].X.mean(axis=0))
        total_samples += n_obs
    all_means = weighted_sum / total_samples

    # Pass 2: within-sample variance plus between-sample offset
    weighted_var = np.zeros(n_genes)
    for key, sample in ground_truth_data.items():
        n_obs = sample.shape[0]
        subset_x = ground_truth_data[key][::, common_genes].X
        sample_means = subset_x.mean(axis=0)
        weighted_var += n_obs * (subset_x.var(axis=0) + (sample_means - all_means) ** 2)

    all_stds = np.sqrt(weighted_var / total_samples)
    return all_means, all_stds

def load_merfish_data():
    """Load the two MERFISH cell-by-gene tables as log1p-transformed AnnData objects.

    Both CSVs are read from the working directory and indexed by their 'cell'
    column. The centroid columns are captured separately, then 'Tile' and the
    centroid columns are removed so the remaining columns form the expression
    matrix. Gene names are lower-cased before the log1p transform.

    Returns:
        tuple: (g1_adata, g2_adata, coords) where ``coords`` maps 'g1'/'g2' to
            {'x': Series, 'y': Series} of centroids.
    """
    raw_tables = {
        'g1': pd.read_csv('g1_new_cell_gene_matrix_with_centroids.csv'),
        'g2': pd.read_csv('g2_new_cell_gene_matrix_with_centroids.csv'),
    }
    indexed = {tag: table.set_index('cell') for tag, table in raw_tables.items()}

    coords = {
        tag: {'x': table['centroid_x'], 'y': table['centroid_y']}
        for tag, table in indexed.items()
    }

    # Everything left after dropping these columns is treated as expression
    expression = {
        tag: table.drop(columns=['Tile', 'centroid_x', 'centroid_y'])
        for tag, table in indexed.items()
    }

    g1_adata = sc.AnnData(expression['g1'])
    g2_adata = sc.AnnData(expression['g2'])

    for adata in (g1_adata, g2_adata):
        adata.var.index = [q.lower() for q in adata.var.index]
        sc.pp.log1p(adata)

    return g1_adata, g2_adata, coords

def load_out_of_sample_mouse_data():
    """Load the g1/g2 out-of-sample mouse predictions and score every gene.

    Returns:
        tuple: ``(g1_pred, g2_pred, g1_scores, g2_scores)``. Each ``*_scores``
        dict maps a variable name to the standard deviation of that
        variable's predicted values across all observations.
    """
    def _per_gene_std(adata):
        # One column slice per gene; the score is the std of that column
        return {
            gene: np.array(adata[::, gene].X).squeeze().std()
            for gene in adata.var.index
        }

    g1_pred = sc.read_h5ad('/mounts/stultzlab03_storage2/ccomiter/out_of_sample_mouse_infer_g1_best_res_fold2.h5ad')
    g2_pred = sc.read_h5ad('/mounts/stultzlab03_storage2/ccomiter/out_of_sample_mouse_infer_g2_final_022525.h5ad')

    return g1_pred, g2_pred, _per_gene_std(g1_pred), _per_gene_std(g2_pred)


In [None]:
# Cell type visualization functions
def plot_cell_type_expression_heatmap(ct_to_pred_mean, ct_to_true_mean, title, filename):
    """Save side-by-side heatmaps of per-cell-type mean expression.

    Args:
        ct_to_pred_mean: Dict mapping cell type to its predicted mean vector
        ct_to_true_mean: Dict mapping cell type to its true mean vector;
            must contain every key of ``ct_to_pred_mean``
        title: Figure-level title
        filename: Output file stem; the figure is written to
            ``final_figures_schaf_revision_pngs/<filename>.png``
    """
    # Rows follow the sorted cell types of the prediction dict
    ordered_cts = sorted(ct_to_pred_mean.keys())
    n_genes = ct_to_pred_mean[ordered_cts[0]].shape[0]
    matrix_shape = (len(ordered_cts), n_genes)

    pred_matrix = np.zeros(matrix_shape)
    true_matrix = np.zeros(matrix_shape)
    for row, ct in enumerate(ordered_cts):
        pred_matrix[row] = ct_to_pred_mean[ct]
        true_matrix[row] = ct_to_true_mean[ct]

    fig, (truth_ax, pred_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Only the left (ground-truth) panel carries the cell-type row labels
    panels = (
        (truth_ax, true_matrix, 'Ground Truth', ordered_cts),
        (pred_ax, pred_matrix, 'Predicted', False),
    )
    for axis, matrix, panel_title, row_labels in panels:
        sns.heatmap(matrix, ax=axis, cmap='coolwarm', center=0,
                    xticklabels=False, yticklabels=row_labels)
        axis.set_title(panel_title)
    truth_ax.set_ylabel('Cell Type')

    plt.suptitle(title)
    plt.tight_layout()
    output_path = f'final_figures_schaf_revision_pngs/{filename}.png'
    plt.savefig(output_path, dpi=400, transparent=True)
    plt.close()

def plot_cell_type_heterogeneity(hetero_scores, cell_types, title, filename):
    """Generate violin plots of cell type heterogeneity.

    One violin is drawn per cell type that is present in ``hetero_scores``;
    cell types without scores are skipped, and the x tick labels are built
    from the plotted cell types only so labels always line up with violins.

    Args:
        hetero_scores: Dict mapping cell type to a dict whose values are the
            heterogeneity scores to plot
        cell_types: Iterable of cell type identifiers to consider
        title: Plot title
        filename: Output file stem; the figure is written to
            ``final_figures_schaf_revision_pngs/<filename>.png``
    """
    plt.figure(figsize=(10, 6))

    # Keep only the cell types that actually have scores, in sorted order
    plotted_cts = [ct for ct in sorted(cell_types) if ct in hetero_scores]
    data = [list(hetero_scores[ct].values()) for ct in plotted_cts]

    # Create violin plot
    sns.violinplot(data=data)
    # Ticks are derived from the plotted types, not from all of cell_types,
    # otherwise a missing type would shift every later label by one violin.
    plt.xticks(range(len(plotted_cts)), [f'CT{ct}' for ct in plotted_cts],
               rotation=45)
    plt.ylabel('Heterogeneity Score')
    plt.title(title)

    plt.tight_layout()
    plt.savefig(f'final_figures_schaf_revision_pngs/{filename}.png',
                dpi=400, transparent=True)
    plt.close()

def plot_cell_type_spatial(pred_data, true_data, cell_type_labels, 
                          coords, title, filename,
                          pred_cell_type_labels=None):
    """Generate spatial plots colored by cell type.

    Args:
        pred_data: Accepted for interface compatibility; not read here
        true_data: Accepted for interface compatibility; not read here
        cell_type_labels: Numeric cell type label per point, used to color
            the 'Ground Truth' panel
        coords: Mapping with ``'x'`` and ``'y'`` point coordinates
        title: Figure-level title
        filename: Output file stem; the figure is written to
            ``final_figures_schaf_revision_pngs/<filename>.png``
        pred_cell_type_labels: Optional numeric label per point for the
            'Predicted' panel. When omitted, the panel falls back to
            ``cell_type_labels`` (the previous behaviour), which makes both
            panels identical.
    """
    true_labels = np.asarray(cell_type_labels)
    if pred_cell_type_labels is None:
        pred_labels = true_labels
        color_limits = {}
    else:
        pred_labels = np.asarray(pred_cell_type_labels)
        # Share one color range so a given label gets the same color in
        # both panels even if the two label sets span different ranges.
        combined = np.concatenate([true_labels.ravel(), pred_labels.ravel()])
        color_limits = {'vmin': np.nanmin(combined), 'vmax': np.nanmax(combined)}

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot ground truth
    scatter1 = ax1.scatter(coords['x'], coords['y'],
                          c=true_labels, cmap='tab20',
                          s=10, alpha=0.7, **color_limits)
    ax1.set_title('Ground Truth')
    ax1.set_xticks([])
    ax1.set_yticks([])

    # Plot predictions
    scatter2 = ax2.scatter(coords['x'], coords['y'],
                          c=pred_labels, cmap='tab20',
                          s=10, alpha=0.7, **color_limits)
    ax2.set_title('Predicted')
    ax2.set_xticks([])
    ax2.set_yticks([])

    # Add colorbar
    plt.colorbar(scatter1, ax=ax1, label='Cell Type')
    plt.colorbar(scatter2, ax=ax2, label='Cell Type')

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(f'final_figures_schaf_revision_pngs/{filename}.png',
                dpi=400, transparent=True)
    plt.close()

def plot_cell_type_correlations(pred_data, true_data, cell_type_labels,
                               common_genes, title, filename):
    """Generate correlation plots for each cell type.

    For every distinct label, the mean expression over the cells with that
    label is computed for the true and the predicted data, and the two mean
    vectors are scattered against each other with their Pearson r.

    Args:
        pred_data: Predicted expression; indexed as ``[mask, common_genes]``
            and read through ``.X``
        true_data: True expression; indexed the same way
        cell_type_labels: One label per cell (any array-like)
        common_genes: Gene selector applied to both datasets
        title: Figure-level title
        filename: Output file stem; the figure is written to
            ``final_figures_schaf_revision_pngs/<filename>.png``

    Raises:
        ValueError: If ``cell_type_labels`` is empty.
    """
    # Coerce once so `labels == ct` is always an element-wise mask,
    # even when a plain list is passed in.
    labels = np.asarray(cell_type_labels)
    cts = sorted(np.unique(labels))
    if not cts:
        raise ValueError('cell_type_labels is empty; nothing to plot')

    n_cols = min(4, len(cts))
    n_rows = (len(cts) + n_cols - 1) // n_cols

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so the
    # flatten below also works when there is a single cell type.
    fig, axes = plt.subplots(n_rows, n_cols, 
                            figsize=(5*n_cols, 5*n_rows),
                            squeeze=False)
    axes = axes.flatten()

    for i, ct in enumerate(cts):
        mask = labels == ct
        if np.sum(mask) > 0:
            # .mean(axis=0) may return a 1xG matrix depending on how X is
            # stored; flatten to a plain 1-D vector for plotting.
            x = np.asarray(true_data[mask, common_genes].X.mean(axis=0)).ravel()
            y = np.asarray(pred_data[mask, common_genes].X.mean(axis=0)).ravel()

            axes[i].scatter(x, y, alpha=0.5, s=10)
            axes[i].set_title(f'Cell Type {ct}')

            # Add correlation coefficient
            corr = np.corrcoef(x, y)[0, 1]
            axes[i].text(0.05, 0.95, f'r = {corr:.3f}',
                        transform=axes[i].transAxes)

            # Add diagonal line spanning the combined x/y extent
            lims = [
                np.min([axes[i].get_xlim(), axes[i].get_ylim()]),
                np.max([axes[i].get_xlim(), axes[i].get_ylim()]),
            ]
            axes[i].plot(lims, lims, 'k--', alpha=0.5)

    # Remove empty subplots
    for i in range(len(cts), len(axes)):
        fig.delaxes(axes[i])

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(f'final_figures_schaf_revision_pngs/{filename}.png',
                dpi=400, transparent=True)
    plt.close()


In [None]:
# Main Analysis Pipeline

def main():
    """
    Main analysis pipeline for processing and visualizing SCHAF results.

    For each (dataset name, sample ids, gene map) entry, every sample is
    loaded and preprocessed, per-sample figures are saved to paths under
    ``final_figures_schaf_revision_pngs/``, and dataset-level summaries are
    produced once all samples of the dataset are done.

    Dataset-level aggregation:
        * Correlation heatmaps are averaged with per-sample cell counts as
          weights (``calculate_weighted_heatmaps``).
        * The correlation-vs-prevalence plot uses, per cell type, the
          cell-count-weighted mean over samples of (a) the diagonal of the
          cell-type correlation matrix and (b) the sample's type frequency.

    Raises:
        Exception: Any error raised by a pipeline step is logged (with
            traceback) and re-raised unchanged.
    """
    logging.info('Starting main analysis pipeline')

    # Cell type identifiers evaluated for every sample
    cell_type_ids = list(range(1, 11))

    try:
        # Process each dataset
        datasets = [
            ('HTAPP', HTAPP_SAMPLES, HTAPP_GENES),
            ('Placenta', PLACENTA_SAMPLES, PLACENTA_GENES),
            ('MSKCC', LUNG_CANCER_SAMPLES, LUNG_CANCER_GENES)
        ]

        for dataset_name, samples, gene_map in datasets:
            logging.info(f'Processing {dataset_name} dataset')

            # Initialize containers for dataset-level metrics
            all_emd_values = []
            all_correlations = []   # per sample: matched-type correlations
            all_prevalences = []    # per sample: frequency of each cell type
            sample_weights = []     # per sample: number of cells
            all_heterogeneity = []  # collected but not plotted in this function
            fold_heatmaps = {}
            fold_weights = {}

            # Process each sample
            for sample_id in tqdm(samples, desc=f'Processing {dataset_name} samples'):
                # Load and preprocess data
                data = load_data(XENIUM_DIR, sample_id)
                processed_data = preprocess_data(data)

                # Calculate cell type means
                true_means = calculate_cell_type_means(
                    processed_data['ground_truth_norm'],
                    processed_data['cell_types']
                )
                pred_means = calculate_cell_type_means(
                    processed_data['predictions_norm'],
                    processed_data['cell_types']
                )

                # Calculate correlations
                corr_matrix = calculate_pairwise_correlations(
                    pred_means, true_means,
                    cell_type_ids
                )

                # Record this sample's contribution to the dataset-level
                # correlation-vs-prevalence plot. Frequencies are captured
                # here, per sample, rather than read from whichever sample
                # happened to be processed last.
                n_cells = processed_data['ground_truth_norm'].shape[0]
                all_correlations.append(np.diagonal(corr_matrix))
                all_prevalences.append(
                    [processed_data['type_frequencies'][ct] for ct in cell_type_ids]
                )
                sample_weights.append(n_cells)

                # Plot correlation matrix
                plot_correlation_matrix(
                    corr_matrix,
                    f'{dataset_name} - Sample {sample_id}',
                    f'correlations_{dataset_name.lower()}_{sample_id}'
                )

                # Calculate and plot correlation heatmaps
                pred_heatmap, true_heatmap, _ = calculate_correlation_heatmap(
                    processed_data['predictions_norm'],
                    processed_data['ground_truth_norm']
                )
                fold_heatmaps[sample_id] = (pred_heatmap, true_heatmap)
                fold_weights[sample_id] = n_cells

                fig = plot_correlation_heatmaps(
                    pred_heatmap, true_heatmap,
                    sample_id, dataset_name
                )
                safe_save_figure(
                    fig,
                    f'final_figures_schaf_revision_pngs/heatmaps_{dataset_name.lower()}_{sample_id}.png'
                )

                # Calculate EMD for the sample's gene
                # NOTE(review): the label comes from gene_map[sample_id] but
                # the values are always column 0 — confirm that column 0 is
                # that gene.
                gene_name = gene_map[sample_id]
                true_values = processed_data['ground_truth_norm'][:, 0]
                pred_values = processed_data['predictions_norm'][:, 0]

                # calculate_emd works on histogram counts, so bin the raw
                # per-cell values first. The lower edge stays at 0 unless
                # the data go negative, so no value falls outside the bins.
                lower = min(0, np.min(true_values), np.min(pred_values))
                upper = max(np.max(true_values), np.max(pred_values))
                bins = np.linspace(lower, upper, 50)
                true_hist, _ = np.histogram(true_values, bins=bins)
                pred_hist, _ = np.histogram(pred_values, bins=bins)
                emd = calculate_emd(true_hist, pred_hist, bins)
                all_emd_values.append(emd)

                # Plot expression distributions (receives the raw per-cell
                # values, exactly as before)
                fig = compare_expression_distributions(
                    true_values, pred_values,
                    gene_name, sample_id,
                    dataset_name,
                    round(np.mean(corr_matrix.diagonal()), 3),
                    round(emd, 3)
                )
                safe_save_figure(
                    fig,
                    f'final_figures_schaf_revision_pngs/distributions_{dataset_name.lower()}_{sample_id}.png'
                )

                # Plot spatial expression for marker genes
                marker_genes = list(gene_map.values())
                fig = plot_marker_genes_panel(
                    processed_data['ground_truth_norm'],
                    processed_data['predictions_norm'],
                    marker_genes,
                    sample_id,
                    dataset_name,
                    MARKER_GENE_CMAPS
                )
                safe_save_figure(
                    fig,
                    f'final_figures_schaf_revision_pngs/marker_genes_{dataset_name.lower()}_{sample_id}.png'
                )

                # Analyze gene correlations
                true_corr, pred_corr = analyze_gene_correlations(
                    processed_data['ground_truth_norm'],
                    processed_data['predictions_norm'],
                    marker_genes
                )

                fig = plot_gene_correlation_comparison(
                    true_corr, pred_corr,
                    marker_genes,
                    f'{dataset_name} - Sample {sample_id}'
                )
                safe_save_figure(
                    fig,
                    f'final_figures_schaf_revision_pngs/gene_correlations_{dataset_name.lower()}_{sample_id}.png'
                )

                # Analyze gene programs
                # Example program: all marker genes for this dataset
                fig, program_corr = analyze_gene_program(
                    processed_data['ground_truth_norm'],
                    processed_data['predictions_norm'],
                    marker_genes,
                    'Marker Genes Program',
                    f'{dataset_name} - Sample {sample_id}'
                )
                safe_save_figure(
                    fig,
                    f'final_figures_schaf_revision_pngs/gene_program_{dataset_name.lower()}_{sample_id}.png'
                )

                # Calculate heterogeneity
                true_het = calculate_heterogeneity(
                    processed_data['ground_truth_norm'],
                    processed_data['cell_types']
                )
                all_heterogeneity.extend(list(true_het.values()))

            # Nothing to summarise if the dataset had no samples
            if not fold_heatmaps:
                logging.warning(
                    f'No samples processed for {dataset_name}; '
                    'skipping dataset-level plots'
                )
                continue

            # Calculate and plot weighted average heatmaps
            weighted_pred, weighted_true = calculate_weighted_heatmaps(
                fold_heatmaps, fold_weights
            )
            fig = plot_correlation_heatmaps(
                weighted_pred, weighted_true,
                'Weighted Average', dataset_name
            )
            safe_save_figure(
                fig,
                f'final_figures_schaf_revision_pngs/weighted_heatmaps_{dataset_name.lower()}.png'
            )

            # Plot dataset-level metrics
            plot_emd_distribution(
                all_emd_values,
                dataset_name,
                f'emd_distribution_{dataset_name.lower()}'
            )

            # Plot correlation vs prevalence: one point per cell type,
            # averaged over samples with cell counts as weights
            mean_correlations = np.average(
                all_correlations, axis=0, weights=sample_weights
            )
            mean_prevalences = np.average(
                all_prevalences, axis=0, weights=sample_weights
            )
            plot_correlation_vs_prevalence(
                mean_correlations,
                mean_prevalences,
                range(len(cell_type_ids)),
                dataset_name,
                f'correlation_vs_prevalence_{dataset_name.lower()}'
            )

        logging.info('Analysis pipeline completed successfully')

    except Exception as e:
        # logging.exception records the traceback alongside the message
        logging.exception(f'Error in main analysis pipeline: {str(e)}')
        raise

# NOTE(review): main() calls helpers such as calculate_emd and
# plot_marker_genes_panel and reads MARKER_GENE_CMAPS, all of which are
# defined in cells *below* this one. Unless they are also defined earlier in
# the notebook, a fresh "Restart & Run All" reaches this call before those
# names exist — confirm, and if so move this invocation to the last cell.
if __name__ == '__main__':
    main()


In [None]:
# Spatial Visualization Functions

@log_function_call
def normalize_expression(expression_values, method='minmax'):
    """
    Normalize gene expression values.

    Non-floating input (e.g. integer counts or a plain list) is converted to
    float first, so the result is floating point on every path, including
    the constant-input shortcut. Floating inputs keep their precision.

    Args:
        expression_values: Array-like of expression values
        method: Normalization method ('minmax' or 'standard')

    Returns:
        numpy.ndarray: Normalized expression values. All zeros when the
        input is constant (zero range / zero standard deviation); an empty
        array when the input is empty.

    Raises:
        ValueError: If ``method`` is not 'minmax' or 'standard'.
    """
    # Validate the method up front so a bad name is reported even for
    # inputs that would otherwise take an early return.
    if method not in ('minmax', 'standard'):
        raise ValueError(f"Unknown normalization method: {method}")

    values = np.asarray(expression_values)
    if values.dtype.kind not in 'fc':
        values = values.astype(float)

    # np.min / np.max raise on empty arrays; there is nothing to normalize
    if values.size == 0:
        return values.copy()

    if method == 'minmax':
        min_val = np.min(values)
        max_val = np.max(values)
        if max_val == min_val:
            return np.zeros_like(values)
        return (values - min_val) / (max_val - min_val)

    # method == 'standard'
    mean = np.mean(values)
    std = np.std(values)
    if std == 0:
        return np.zeros_like(values)
    return (values - mean) / std

@log_function_call
def plot_spatial_expression(true_data, pred_data, gene_name, 
                          sample_id, dataset_name, 
                          cmap='viridis', size=20):
    """
    Plot spatial expression patterns for true and predicted data.

    Both panels share one min-max scaling computed over the true and the
    predicted values together, so colors are directly comparable.

    Args:
        true_data: AnnData object with true expression
        pred_data: AnnData object with predicted expression
        gene_name: Name of gene to plot
        sample_id: Sample identifier
        dataset_name: Name of dataset
        cmap: Colormap to use
        size: Size of scatter points

    Returns:
        The created matplotlib figure.
    """
    fig, (ax_true, ax_pred) = plt.subplots(1, 2, 
                                          figsize=(10, 5),
                                          constrained_layout=True)

    # Get expression values
    true_expr = np.array(true_data[:, gene_name].X).squeeze()
    pred_expr = np.array(pred_data[:, gene_name].X).squeeze()

    # Normalize values together
    min_val = min(true_expr.min(), pred_expr.min())
    max_val = max(true_expr.max(), pred_expr.max())
    value_range = max_val - min_val
    if value_range == 0:
        # Every value is identical (e.g. gene absent in both datasets);
        # avoid 0/0 and draw everything at the bottom of the colormap.
        true_expr_norm = np.zeros_like(true_expr, dtype=float)
        pred_expr_norm = np.zeros_like(pred_expr, dtype=float)
    else:
        true_expr_norm = (true_expr - min_val) / value_range
        pred_expr_norm = (pred_expr - min_val) / value_range

    # Plot true expression (obs['y'] on the horizontal axis, obs['x'] on
    # the vertical axis, with the vertical axis inverted)
    ax_true.scatter(true_data.obs['y'],
                   true_data.obs['x'],
                   c=true_expr_norm,
                   s=size,
                   vmin=0,
                   vmax=1,
                   cmap=cmap)
    ax_true.invert_yaxis()
    ax_true.set_title('Ground Truth', size='xx-large')
    ax_true.set_xticks([])
    ax_true.set_yticks([])

    # Plot predicted expression
    ax_pred.scatter(pred_data.obs['y'],
                   pred_data.obs['x'],
                   c=pred_expr_norm,
                   s=size,
                   vmin=0,
                   vmax=1,
                   cmap=cmap)
    ax_pred.invert_yaxis()
    ax_pred.set_title('SCHAF Predicted', size='xx-large')
    ax_pred.set_xticks([])
    ax_pred.set_yticks([])

    # Add correlation value (computed on the un-normalized values)
    corr = np.corrcoef(true_expr, pred_expr)[0, 1]
    fig.suptitle(f'{gene_name.upper()} - {dataset_name} {sample_id}\n' +
                f'Correlation: {corr:.3f}',
                size='xx-large')

    # Customize spines
    for ax in [ax_true, ax_pred]:
        ax.spines[['left', 'right', 'top', 'bottom']].set_visible(True)
        plt.setp(ax.spines.values(), linewidth=2)

    return fig

@log_function_call
def plot_marker_genes_panel(true_data, pred_data, marker_genes,
                          sample_id, dataset_name, cmaps=None):
    """
    Create a panel of spatial plots for marker genes.

    The figure has two rows (ground truth on top, prediction below) and one
    column per marker gene. Each panel is min-max normalized on its own via
    ``normalize_expression``.

    Args:
        true_data: AnnData object with true expression
        pred_data: AnnData object with predicted expression
        marker_genes: List of marker genes to plot
        sample_id: Sample identifier
        dataset_name: Name of dataset
        cmaps: Dictionary mapping lower-case gene names to colormaps;
            genes without an entry use 'viridis'

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``marker_genes`` is empty.
    """
    n_genes = len(marker_genes)
    if n_genes == 0:
        raise ValueError('marker_genes is empty; nothing to plot')

    # squeeze=False keeps `axs` two-dimensional even with a single gene,
    # so the axs[row, col] indexing below is always valid.
    fig, axs = plt.subplots(2, n_genes,
                           figsize=(5 * n_genes, 10),
                           constrained_layout=True,
                           squeeze=False)

    for i, gene in enumerate(marker_genes):
        # Get colormap
        cmap = cmaps.get(gene.lower(), 'viridis') if cmaps else 'viridis'

        # Get and normalize expression values
        true_expr = np.array(true_data[:, gene].X).squeeze()
        pred_expr = np.array(pred_data[:, gene].X).squeeze()

        true_expr_norm = normalize_expression(true_expr)
        pred_expr_norm = normalize_expression(pred_expr)

        # Plot true expression
        axs[0, i].scatter(true_data.obs['y'],
                         true_data.obs['x'],
                         c=true_expr_norm,
                         s=20,
                         vmin=0,
                         vmax=1,
                         cmap=cmap)
        axs[0, i].invert_yaxis()
        axs[0, i].set_title(gene.upper(), size='xx-large')
        axs[0, i].set_xticks([])
        axs[0, i].set_yticks([])

        if i == 0:
            axs[0, i].set_ylabel('Ground Truth', size='xx-large')

        # Plot predicted expression
        axs[1, i].scatter(pred_data.obs['y'],
                         pred_data.obs['x'],
                         c=pred_expr_norm,
                         s=20,
                         vmin=0,
                         vmax=1,
                         cmap=cmap)
        axs[1, i].invert_yaxis()
        axs[1, i].set_xticks([])
        axs[1, i].set_yticks([])

        if i == 0:
            axs[1, i].set_ylabel('SCHAF Predicted', size='xx-large')

        # Customize spines
        for ax in [axs[0, i], axs[1, i]]:
            ax.spines[['left', 'right', 'top', 'bottom']].set_visible(True)
            plt.setp(ax.spines.values(), linewidth=2)

    fig.suptitle(f'{dataset_name} {sample_id} Marker Genes',
                size='xx-large')

    return fig

# Per-gene colormaps for the marker-gene spatial panels. Keys are lower-case
# gene names because plot_marker_genes_panel looks genes up via gene.lower();
# genes without an entry fall back to 'viridis' there.
MARKER_GENE_CMAPS = dict(
    krt19=plt.cm.Greens,
    col1a2=plt.cm.YlOrBr,
    apoc1=plt.cm.Blues,
    pecam1=plt.cm.Reds,
    cd3d=plt.cm.Purples,
    cd8a=plt.cm.Oranges,
    cd4=plt.cm.RdPu,
    cd19=plt.cm.BuGn,
)


In [None]:
# Hierarchical Clustering and Heatmap Functions

from sklearn.cluster import AgglomerativeClustering
import seaborn as sns

@log_function_call
def calculate_correlation_heatmap(pred_data, true_data, 
                                cluster=True, n_clusters=2):
    """
    Calculate correlation heatmap with optional clustering.

    Correlations are taken between columns (``rowvar=False``). When
    ``cluster`` is true, the columns of ``true_data`` are clustered with
    Ward-linkage agglomerative clustering and both matrices are reordered
    so that columns of the same cluster are contiguous (cluster 0 first,
    original column order preserved inside each cluster).

    Args:
        pred_data: Predicted expression matrix
        true_data: True expression matrix
        cluster: Whether to perform hierarchical clustering
        n_clusters: Number of clusters for hierarchical clustering

    Returns:
        tuple: (pred_heatmap, true_heatmap, cluster_labels); the labels are
        in the original column order and are ``None`` when ``cluster`` is
        false.
    """
    # Calculate correlation matrices
    pred_heatmap = np.corrcoef(pred_data, rowvar=False)
    true_heatmap = np.corrcoef(true_data, rowvar=False)

    if not cluster:
        return pred_heatmap, true_heatmap, None

    # Ward linkage only supports Euclidean distances and that is the
    # default, so the distance argument is omitted. The old `affinity=`
    # keyword was renamed to `metric=` and removed in newer scikit-learn
    # releases, so passing it breaks on current versions.
    clustering = AgglomerativeClustering(
        n_clusters=n_clusters,
        linkage='ward'
    )
    labels = clustering.fit_predict(true_data.T)

    # A stable sort groups columns by cluster label while keeping the
    # original order within each cluster.
    order = np.argsort(labels, kind='stable')
    reorder = np.ix_(order, order)

    return pred_heatmap[reorder], true_heatmap[reorder], labels

@log_function_call
def plot_correlation_heatmaps(pred_heatmap, true_heatmap,
                            sample_id, dataset_name):
    """
    Draw the true and predicted correlation matrices side by side.

    The figure title reports the "meta-correlation": the Pearson
    correlation between the flattened predicted and true matrices.

    Args:
        pred_heatmap: Correlation matrix for predicted data
        true_heatmap: Correlation matrix for true data
        sample_id: Sample identifier
        dataset_name: Name of dataset

    Returns:
        The created matplotlib figure.
    """
    fig, panel_axes = plt.subplots(1, 2, figsize=(15, 7))

    # Left panel: ground truth, right panel: prediction
    panels = (
        (true_heatmap, 'Ground Truth Correlations'),
        (pred_heatmap, 'SCHAF Predicted Correlations'),
    )
    for axis, (matrix, panel_title) in zip(panel_axes, panels):
        sns.heatmap(matrix,
                    ax=axis,
                    cmap='seismic',
                    center=0,
                    square=True,
                    cbar_kws={'label': 'Correlation'})
        axis.set_title(panel_title, size='xx-large')

    # Correlation of the two matrices, entry by entry
    flat_pred = pred_heatmap.reshape(-1)
    flat_true = true_heatmap.reshape(-1)
    metacorr = np.corrcoef(flat_pred, flat_true)[0, 1]

    headline = (f'{dataset_name} {sample_id}\n' +
                f'Meta-correlation: {metacorr:.3f}')
    plt.suptitle(headline, size='xx-large')

    return fig

@log_function_call
def calculate_weighted_heatmaps(fold_data, weights):
    """
    Calculate weighted average heatmaps across folds.

    Weights are normalized over the folds present in ``fold_data`` only, so
    extra entries in ``weights`` do not dilute the average.

    Args:
        fold_data: Dictionary mapping folds to (pred_heatmap, true_heatmap)
        weights: Dictionary mapping folds to weights; must contain every
            key of ``fold_data``

    Returns:
        tuple: (weighted_pred_heatmap, weighted_true_heatmap)

    Raises:
        ValueError: If ``fold_data`` is empty or the relevant weights sum
            to zero.
        KeyError: If a fold in ``fold_data`` has no entry in ``weights``.
    """
    if not fold_data:
        raise ValueError('fold_data is empty; no heatmaps to average')

    # Only the folds that are actually averaged contribute to the total
    total_weight = sum(weights[fold] for fold in fold_data)
    if total_weight == 0:
        raise ValueError('fold weights sum to zero; cannot compute a weighted average')

    # Initialize weighted averages with the shape of the first heatmap
    sample_shape = np.shape(next(iter(fold_data.values()))[0])
    weighted_pred = np.zeros(sample_shape)
    weighted_true = np.zeros(sample_shape)

    # Calculate weighted sums
    for fold, (pred, true) in fold_data.items():
        weight = weights[fold] / total_weight
        weighted_pred += np.asarray(pred) * weight
        weighted_true += np.asarray(true) * weight

    return weighted_pred, weighted_true

@log_function_call
def plot_gene_correlation_matrix(true_data, pred_data, 
                               genes, sample_id, dataset_name):
    """
    Plot correlation matrix for specific genes.

    Each cell of the heatmap is the absolute difference between the
    gene-gene Pearson correlation in the true data and the same correlation
    in the predicted data.

    Args:
        true_data: True expression data
        pred_data: Predicted expression data
        genes: List of genes to include
        sample_id: Sample identifier
        dataset_name: Name of dataset

    Returns:
        The created matplotlib figure.
    """
    def _gene_gene_correlations(adata):
        # One slice and one corrcoef call per dataset instead of two
        # slices per gene pair.
        expr = adata[:, genes].X
        if hasattr(expr, 'toarray'):
            # Densify if the matrix is stored sparse
            expr = expr.toarray()
        expr = np.asarray(expr, dtype=float)
        # corrcoef returns a scalar for a single gene; keep it 2-D
        return np.atleast_2d(np.corrcoef(expr, rowvar=False))

    # Absolute difference between true and predicted correlations
    correlations = np.abs(
        _gene_gene_correlations(true_data) - _gene_gene_correlations(pred_data)
    )

    # Plot heatmap
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(correlations,
                xticklabels=genes,
                yticklabels=genes,
                cmap='YlOrRd',
                center=0.5,
                square=True,
                ax=ax)

    # Address the heatmap axes explicitly rather than relying on which
    # axes pyplot currently considers active.
    ax.set_title(f'{dataset_name} {sample_id}\nGene Correlation Differences',
                 size='xx-large')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    return fig


In [None]:
# Gene Analysis Functions

@log_function_call
def analyze_gene_correlations(true_data, pred_data, genes):
    """
    Analyze correlations between genes in true and predicted data.

    Each gene's expression vector is extracted once per dataset and the
    full gene-by-gene Pearson correlation matrix is computed in a single
    call, instead of re-slicing both datasets for every gene pair.

    Args:
        true_data: True expression data; indexed as ``[:, gene]`` and read
            through ``.X``
        pred_data: Predicted expression data; indexed the same way
        genes: List of genes to analyze

    Returns:
        tuple: (true_corr_matrix, pred_corr_matrix), each of shape
        (len(genes), len(genes)). Entries involving a constant gene are NaN.
    """
    genes = list(genes)
    if not genes:
        empty = np.zeros((0, 0))
        return empty, empty.copy()

    def _correlation_matrix(adata):
        # Rows are genes, columns are observations
        profiles = np.stack(
            [np.array(adata[:, gene].X).reshape(-1) for gene in genes]
        )
        # corrcoef collapses to a scalar for a single gene; keep it 2-D
        return np.atleast_2d(np.corrcoef(profiles))

    true_corr = _correlation_matrix(true_data)
    pred_corr = _correlation_matrix(pred_data)

    return true_corr, pred_corr

@log_function_call
def plot_gene_correlation_comparison(true_corr, pred_corr, 
                                   genes, dataset_name):
    """
    Show true and predicted gene-gene correlation matrices side by side.

    The figure title also reports the Pearson correlation between the two
    flattened matrices.

    Args:
        true_corr: True correlation matrix
        pred_corr: Predicted correlation matrix
        genes: List of gene names, used for both axes' tick labels
        dataset_name: Name of dataset

    Returns:
        The created matplotlib figure.
    """
    fig, panel_axes = plt.subplots(1, 2, figsize=(15, 7))

    # Left panel: ground truth, right panel: prediction
    panels = (
        (true_corr, 'Ground Truth Gene Correlations'),
        (pred_corr, 'Predicted Gene Correlations'),
    )
    for axis, (matrix, panel_title) in zip(panel_axes, panels):
        sns.heatmap(matrix,
                    xticklabels=genes,
                    yticklabels=genes,
                    cmap='RdBu_r',
                    center=0,
                    ax=axis)
        axis.set_title(panel_title)

    # Agreement between the two matrices, entry by entry
    flat_true = true_corr.flatten()
    flat_pred = pred_corr.flatten()
    corr = np.corrcoef(flat_true, flat_pred)[0, 1]

    headline = (f'{dataset_name}\nGene Correlation Comparison\n' +
                f'Overall Correlation: {corr:.3f}')
    plt.suptitle(headline, size='xx-large')

    return fig

@log_function_call
def analyze_gene_program(true_data, pred_data, gene_set,
                        program_name, dataset_name):
    """
    Score a gene program in true and predicted data and compare the scores.

    The program score of an observation is the mean expression over the
    genes in ``gene_set``. The left panel scatters predicted against true
    scores with an identity line; the right panel overlays the two score
    densities.

    Args:
        true_data: True expression data
        pred_data: Predicted expression data
        gene_set: List of genes in the program
        program_name: Name of the gene program
        dataset_name: Name of dataset

    Returns:
        tuple: (figure, Pearson correlation between true and predicted
        program scores)
    """
    def _program_score(adata):
        # Mean over the program's genes, one value per observation
        per_gene = [np.array(adata[:, gene].X).squeeze() for gene in gene_set]
        return np.mean(per_gene, axis=0)

    true_scores = _program_score(true_data)
    pred_scores = _program_score(pred_data)

    program_corr = np.corrcoef(true_scores, pred_scores)[0, 1]

    fig, (scatter_ax, density_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Predicted vs. true scores
    scatter_ax.scatter(true_scores, pred_scores, alpha=0.5)
    scatter_ax.set_xlabel('True Program Score')
    scatter_ax.set_ylabel('Predicted Program Score')

    # Identity line across the combined score range
    lower = min(true_scores.min(), pred_scores.min())
    upper = max(true_scores.max(), pred_scores.max())
    scatter_ax.plot([lower, upper], [lower, upper], 'r--')

    # Score densities
    for scores, legend_label in ((true_scores, 'Ground Truth'),
                                 (pred_scores, 'Predicted')):
        sns.kdeplot(data=scores, ax=density_ax, label=legend_label)
    density_ax.set_xlabel('Program Score')
    density_ax.set_ylabel('Density')
    density_ax.legend()

    headline = (f'{dataset_name} - {program_name}\n' +
                f'Program Correlation: {program_corr:.3f}')
    plt.suptitle(headline, size='xx-large')

    return fig, program_corr


In [None]:
# Advanced correlation analysis functions

@log_function_call
def calculate_pairwise_correlations(pred_means, true_means, cell_types):
    """
    Correlate every true cell-type profile with every predicted one.

    Entry ``[i, j]`` is the Pearson correlation between the true profile of
    ``cell_types[i]`` and the predicted profile of ``cell_types[j]``. Pairs
    where either profile is missing stay at 0, and NaN correlations are
    replaced via ``np.nan_to_num``.

    Args:
        pred_means: Dictionary mapping cell types to predicted expression
        true_means: Dictionary mapping cell types to true expression
        cell_types: List of cell type identifiers

    Returns:
        numpy.ndarray: Correlation matrix (cell_types × cell_types)
    """
    size = len(cell_types)
    result = np.zeros((size, size))

    for row, true_ct in enumerate(cell_types):
        # Rows for cell types without a true profile remain all zero
        if true_ct not in true_means:
            continue
        reference = true_means[true_ct]

        for col, pred_ct in enumerate(cell_types):
            if pred_ct not in pred_means:
                continue
            result[row, col] = np.corrcoef(reference, pred_means[pred_ct])[0, 1]

    return np.nan_to_num(result)

@log_function_call
def plot_correlation_matrix(corr_matrix, title, filename):
    """
    Plot correlation matrix as a heatmap.

    The color scale is fixed to [-1, 1] so that the white midpoint of the
    diverging 'seismic' colormap always corresponds to zero correlation and
    figures from different samples are directly comparable.

    Args:
        corr_matrix: numpy.ndarray of correlations (rows: true cell types,
            columns: predicted cell types)
        title: Plot title
        filename: Output file stem; the figure is saved to
            ``final_figures_schaf_revision_pngs/<filename>.png``
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Create heatmap; without explicit limits the colormap would be
    # stretched to the data range and zero would no longer map to white.
    im = ax.imshow(corr_matrix, cmap='seismic', aspect='auto',
                   vmin=-1, vmax=1)

    # Add colorbar, attached explicitly to this figure/axes
    fig.colorbar(im, ax=ax)

    # Customize plot
    ax.set_title(title, size='xx-large', loc='left')
    ax.set_xlabel('Predicted Cell Type', size='xx-large')
    ax.set_ylabel('True Cell Type', size='xx-large')

    # Add cell type labels (1-based position in the matrix)
    cell_types = range(1, corr_matrix.shape[0] + 1)
    ax.set_xticks(range(len(cell_types)))
    ax.set_yticks(range(len(cell_types)))
    ax.set_xticklabels(cell_types, size='x-large')
    ax.set_yticklabels(cell_types, size='x-large')

    plt.tight_layout()
    safe_save_figure(fig, 
                    f'final_figures_schaf_revision_pngs/{filename}.png',
                    dpi=FIGURE_PARAMS['dpi'],
                    transparent=FIGURE_PARAMS['transparent'])

@log_function_call
def plot_correlation_vs_prevalence(correlations, prevalences, 
                                 cell_types, title, filename):
    """
    Plot correlation vs prevalence scatter plot.

    The y-axis spans [0, 1] by default and is extended downwards (in steps
    of 0.1) when any correlation is negative, so no point is silently
    clipped.

    Args:
        correlations: List of correlation values
        prevalences: List of prevalence values (same length)
        cell_types: List of cell type indices, used to color the points
        title: Plot title
        filename: Output file stem; the figure is saved to
            ``final_figures_schaf_revision_pngs/<filename>.png``

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    correlations = np.asarray(correlations, dtype=float)
    prevalences = np.asarray(prevalences, dtype=float)
    if correlations.size == 0 or prevalences.size == 0:
        raise ValueError('correlations and prevalences must be non-empty')
    if correlations.shape != prevalences.shape:
        raise ValueError(
            f'correlations {correlations.shape} and prevalences '
            f'{prevalences.shape} must have the same shape'
        )

    fig, ax = plt.subplots(figsize=(8, 6))

    # Create scatter plot
    ax.scatter(prevalences, correlations,
               c=cell_types, cmap='rainbow', s=100)

    # Customize plot
    ax.set_title(title, size='xx-large', loc='left')
    ax.set_xlabel('Prevalence', size='xx-large')
    ax.set_ylabel('Correlation', size='xx-large')

    # Lower y limit: 0 unless there are negative correlations, in which
    # case round down to the next multiple of 0.1.
    y_min = min(0.0, np.floor(np.nanmin(correlations) * 10) / 10)
    x_max = np.nanmax(prevalences) * 1.1

    # Set axis limits and ticks
    ax.spines[['right', 'top']].set_visible(False)
    ax.set_ylim(y_min, 1)
    ax.set_xlim(0, x_max)

    # Add ticks
    # NOTE(review): the 0.05 x tick spacing presumably assumes prevalences
    # are fractions; confirm against the caller.
    ax.set_yticks(np.arange(y_min, 1.01, 0.1))
    ax.set_xticks(np.arange(0, x_max, 0.05))

    plt.tight_layout()
    safe_save_figure(fig,
                    f'final_figures_schaf_revision_pngs/{filename}.png',
                    dpi=FIGURE_PARAMS['dpi'],
                    transparent=FIGURE_PARAMS['transparent'])

@log_function_call
def calculate_weighted_correlations(pred_data, true_data, weights):
    """
    Calculate per-dataset correlations and their weighted average.

    Only keys present in all three mappings contribute. The weighted average
    is normalized by the weights of the keys that were actually used, so a
    weight whose data is missing no longer drags the average toward zero.

    Args:
        pred_data: Dictionary mapping fold/dataset to predicted data
        true_data: Dictionary mapping fold/dataset to true data
        weights: Dictionary mapping fold/dataset to weights

    Returns:
        tuple: (correlations, weighted_average). ``correlations`` maps each
        used key to the Pearson correlation of the flattened prediction and
        truth. ``weighted_average`` is 0.0 when no key could be used.
    """
    correlations = {}
    for key in pred_data:
        if key not in true_data or key not in weights:
            continue
        pred_flat = np.asarray(pred_data[key]).ravel()
        true_flat = np.asarray(true_data[key]).ravel()
        correlations[key] = np.corrcoef(pred_flat, true_flat)[0, 1]

    if not correlations:
        # Nothing to average; keep the historical "empty sum" result
        return correlations, 0.0

    # Normalize by the weights of the keys that actually contributed
    used_weight = sum(weights[k] for k in correlations)
    weighted_avg = sum(correlations[k] * weights[k] for k in correlations) / used_weight

    return correlations, weighted_avg


In [None]:
# Earth Mover's Distance (EMD) analysis functions

@log_function_call
def calculate_emd(hist1, hist2, bins):
    """
    Calculate Earth Mover's Distance between two histograms.

    Both histograms are normalized to unit mass and the EMD is taken as the
    area between their cumulative distributions.

    Args:
        hist1: First histogram values (array-like, one value per bin)
        hist2: Second histogram values (array-like, same length as hist1)
        bins: Bin edges for histograms. When ``len(bins) == len(hist1) + 1``
            the spacing between consecutive bin centers is used, so
            non-uniform bins are handled; otherwise the width of the first
            bin is applied to every bin (the original uniform-bin behavior).

    Returns:
        float: EMD value, or NaN when either histogram has zero total mass.
    """
    hist1 = np.asarray(hist1, dtype=float)
    hist2 = np.asarray(hist2, dtype=float)
    bins = np.asarray(bins, dtype=float)

    total1 = hist1.sum()
    total2 = hist2.sum()
    if total1 == 0 or total2 == 0:
        # An empty histogram cannot be normalized; report NaN explicitly
        # instead of relying on a 0/0 division warning.
        return np.nan

    # Absolute gap between the two cumulative distributions, per bin
    cdf_gap = np.abs(np.cumsum(hist1 / total1) - np.cumsum(hist2 / total2))

    if bins.size == hist1.size + 1:
        # Mass sits at bin centers; the gap after bin i is carried over the
        # distance to the next center. The last gap is ~0 because both CDFs
        # end at 1, so it is dropped.
        centers = (bins[:-1] + bins[1:]) / 2
        return np.sum(cdf_gap[:-1] * np.diff(centers))

    # Fallback: treat every bin as having the width of the first one
    return np.sum(cdf_gap) * (bins[1] - bins[0])

@log_function_call
def plot_emd_distribution(emd_values, dataset_name, filename):
    """
    Draw a histogram of per-gene EMD values and save it as a PNG.

    Args:
        emd_values: List of EMD values
        dataset_name: Name of dataset, used as the left-aligned title
        filename: Output filename stem; the figure is written to
            final_figures_schaf_revision_pngs/<filename>.png
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    # Average EMD, shown in the legend
    mean_emd = round(np.mean(emd_values), 3)

    # Histogram with 0.03-wide bins reaching just past the largest value
    ax.hist(emd_values,
            rwidth=0.7,
            bins=np.arange(0, max(emd_values) + 0.05, 0.03))

    # Title and axis labels
    ax.set_title(dataset_name, loc='left', size='xx-large')
    ax.set_xlabel('Earth Mover\'s Distance (EMD)', size='xx-large')
    ax.set_ylabel('Number of Genes', size='xx-large')

    # A single white point is plotted only so that the average EMD gets a
    # legend entry; no mean line is actually drawn.
    ax.plot([1], [1], '--', c='white',
            label=f'Avg. EMD = {mean_emd}')
    ax.legend(fontsize='xx-large', loc='upper right')

    # y ticks every 1000 genes up to the current axis top, x ticks every 0.25
    max_count = ax.get_ylim()[1]
    y_tick_positions = np.arange(0, max_count, 1000)
    plt.yticks(ticks=y_tick_positions,
              labels=np.arange(0, max_count, 1000).round(0),
              size='x-large')
    plt.xticks(ticks=np.arange(0, max(emd_values) + 0.05, 0.25),
              size='x-large')

    # Thicker frame
    plt.setp(ax.spines.values(), linewidth=2)

    plt.tight_layout()
    safe_save_figure(fig,
                    f'final_figures_schaf_revision_pngs/{filename}.png',
                    dpi=FIGURE_PARAMS['dpi'],
                    transparent=FIGURE_PARAMS['transparent'])

@log_function_call
def compare_expression_distributions(true_hist, pred_hist, 
                                  gene_name, sample_id,
                                  dataset_name, quality_score, emd_score):
    """
    Overlay kernel density estimates of true and predicted expression.

    Args:
        true_hist: Ground truth expression values (fed to gaussian_kde)
        pred_hist: Predicted expression values (fed to gaussian_kde)
        gene_name: Name of gene (upper-cased in the title)
        sample_id: Sample identifier
        dataset_name: Name of dataset
        quality_score: Quality score for the prediction (legend text)
        emd_score: EMD score for the distributions (legend text)

    Returns:
        The matplotlib figure; saving it is left to the caller.
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    # Fit one KDE per distribution
    kde_true = gaussian_kde(true_hist)
    kde_pred = gaussian_kde(pred_hist)

    # Shared evaluation grid covering the range of both inputs
    grid_low = min(min(true_hist), min(pred_hist))
    grid_high = max(max(true_hist), max(pred_hist))
    x = np.linspace(grid_low, grid_high, 100)

    # Evaluate each density once and reuse it below
    true_density = kde_true(x)
    pred_density = kde_pred(x)

    # Curves first, then the shaded areas beneath them
    ax.plot(x, true_density, linewidth=2, label='Ground Truth')
    ax.plot(x, pred_density, linewidth=2, label='Predicted')
    ax.fill_between(x, true_density, alpha=0.5)
    ax.fill_between(x, pred_density, alpha=0.5)

    # Title and axis labels
    ax.set_title(f'{gene_name.upper()} - {dataset_name} {sample_id}',
                 loc='left', size='xx-large')
    ax.set_xlabel('Gene Expression', size='xx-large')
    ax.set_ylabel('KDE-Density', size='xx-large')

    # A single white point carries the two scores into the legend
    ax.plot([1], [1], '--', c='white',
            label=f'Dist. Match Score={quality_score}\nEMD={emd_score}')
    ax.legend(fontsize='xx-large', loc='upper right')

    # Axis limits: a little room left of zero, top at the tallest density
    ax.set_xlim(-0.5, max(x))
    ax.set_ylim(0, max(max(true_density), max(pred_density)))

    # Thicker frame
    plt.setp(ax.spines.values(), linewidth=2)

    plt.tight_layout()
    return fig


In [None]:
# HTAPP visualization functions
def plot_htapp_marker_genes(pred_htapp, ground_truth_htapp, genes, keys, 
                          colormaps, title, filename):
    """
    Generate marker gene expression plots for HTAPP data.

    For every sample key one figure is written with two rows (top: ground
    truth, bottom: prediction) and one column per gene. Expression is
    min-max scaled per gene and per panel before plotting.

    Args:
        pred_htapp: Mapping from sample key to predicted data; indexed as
            ``[::, gene]`` and expected to expose ``.X`` and ``.obs['x'|'y']``
        ground_truth_htapp: Mapping from sample key to ground truth data,
            accessed the same way
        genes: Gene names to plot (one column each); a single gene is supported
        keys: Sample keys to plot (one figure each)
        colormaps: Mapping from gene name to the colormap used for that gene
        title: Currently unused; kept for interface compatibility
        filename: Output filename stem; each figure is written to
            final_figures_schaf_revision_pngs/<filename>_<key>.png
    """
    def _unit_scale(values):
        # Min-max scale to [0, 1]; a constant vector maps to zeros rather
        # than NaN from a 0/0 division.
        low = values.min()
        span = values.max() - low
        if span == 0:
            return np.zeros_like(values, dtype=float)
        return (values - low) / span

    for k in keys:
        # squeeze=False keeps axs two-dimensional even for a single gene
        fig, axs = plt.subplots(2, len(genes),
                               figsize=(5*len(genes), 5*2),
                               squeeze=False)

        for i, g in enumerate(genes):
            this_cmap = colormaps[g]
            ax_pred = axs[1, i]
            ax_true = axs[0, i]

            # Set aspect ratio and labels
            ax_pred.set_aspect('auto')
            ax_true.set_aspect('auto')

            if i == 0:
                ax_pred.set_ylabel("SCHAF\nInferred", size='x-large')
                ax_true.set_ylabel("ExpSCR", size='x-large')

            ax_true.set_title(g.upper())

            # Get and normalize expression values
            t = _unit_scale(np.array(ground_truth_htapp[k][::, g].X).squeeze())
            p = _unit_scale(np.array(pred_htapp[k][::, g].X).squeeze())

            # Plot ground truth (obs 'y' on the horizontal axis, 'x' on the vertical)
            ax_true.scatter(
                ground_truth_htapp[k].obs['y'],
                ground_truth_htapp[k].obs['x'],
                c=t, s=20, vmin=0, vmax=1,
                cmap=this_cmap
            )
            ax_true.invert_yaxis()
            ax_true.set_xticks([])
            ax_true.set_yticks([])

            # Plot predictions
            ax_pred.scatter(
                pred_htapp[k].obs['y'],
                pred_htapp[k].obs['x'],
                c=p, s=20, vmin=0, vmax=1,
                cmap=this_cmap
            )
            ax_pred.invert_yaxis()
            ax_pred.set_xticks([])
            ax_pred.set_yticks([])

            # Add borders
            for ax in [ax_pred, ax_true]:
                ax.spines[['left', 'right', 'top', 'bottom']].set_visible(True)
                plt.setp(ax.spines.values(), linewidth=2)

        plt.tight_layout()
        fig.savefig(f'final_figures_schaf_revision_pngs/{filename}_{k}.png',
                    dpi=400, transparent=True)
        # Close this figure explicitly so figures do not pile up across keys
        plt.close(fig)

def plot_htapp_schematic(histology_images, title, filename):
    """
    Generate schematic overview of HTAPP regions.

    Images are laid out four per row. The grid has at least three rows (the
    original 3x4 layout) and grows when more than twelve images are given.
    Panels without an image are left blank.

    Args:
        histology_images: Mapping from sample id to an image array with three
            axes; the first two axes are swapped before display
        title: Currently unused; kept for interface compatibility
        filename: Output filename stem; the figure is written to
            final_figures_schaf_revision_pngs/<filename>.png
    """
    n_cols = 4
    # Ceiling division, never fewer than the original three rows
    n_rows = max(3, -(-len(histology_images) // n_cols))

    # 3 inches per row keeps the original 9x9 canvas for the 3-row case
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(9, 3 * n_rows),
                          gridspec_kw={'hspace': 0, 'wspace': 0})

    # Hide every panel's axes up front so unused panels stay empty
    for panel in ax.flat:
        panel.axis('off')

    for i, (k, v) in enumerate(histology_images.items()):
        a, b = divmod(i, n_cols)
        # Flip x and y axes for correct orientation
        flipped_v = np.transpose(v, (1, 0, 2))
        ax[a, b].imshow(flipped_v)
        ax[a, b].set_title(f'HTAPP {k}', pad=15, loc='center')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    fig.savefig(f'final_figures_schaf_revision_pngs/{filename}.png',
                dpi=400, transparent=True)
    plt.close(fig)

def calculate_htapp_metrics(pred_htapp, ground_truth_htapp):
    """
    Compute per-gene scores and distances for every HTAPP sample.

    For each gene column the prediction is min-max scaled to [0, 1]. The
    distance is the Wasserstein distance between that scaled prediction and
    the ground truth column as given (not rescaled here). The score is the
    standard deviation of the scaled prediction.

    Args:
        pred_htapp: Mapping from sample key to an object exposing ``.X``
        ground_truth_htapp: Mapping from sample key to an object exposing
            ``.X`` and ``.shape``; its keys drive the iteration

    Returns:
        tuple: (scores_by_key, dists_by_key), each mapping sample key to an
        array with one entry per gene.
    """
    from scipy.stats import wasserstein_distance

    scores_by_key = {}
    dists_by_key = {}

    for key in ground_truth_htapp:
        gene_count = ground_truth_htapp[key].shape[1]
        truth_matrix = np.array(ground_truth_htapp[key].X)
        pred_matrix = pred_htapp[key].X

        gene_dists = []
        gene_scores = []
        for gene_idx in range(gene_count):
            pred_column = pred_matrix[::, gene_idx]
            # Min-max scale the predicted column
            scaled = (pred_column - pred_column.min()) / (pred_column.max() - pred_column.min())
            gene_dists.append(wasserstein_distance(scaled, truth_matrix[::, gene_idx]))
            gene_scores.append(scaled.std())

        scores_by_key[key] = np.array(gene_scores)
        dists_by_key[key] = np.array(gene_dists)

    return scores_by_key, dists_by_key

def plot_htapp_metrics(htapp_key_to_scores, htapp_key_to_dists, 
                      title, filename):
    """
    Draw KDE curves of HTAPP scores (left) and distances (right) per sample.

    Args:
        htapp_key_to_scores: Mapping from sample key to per-gene scores
        htapp_key_to_dists: Mapping from sample key to per-gene distances
        title: Figure-level title
        filename: Output filename stem; the figure is written to
            final_figures_schaf_revision_pngs/<filename>.png
    """
    fig, (score_ax, dist_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axes, values per sample, panel title, x label) for each panel
    panel_specs = (
        (score_ax, htapp_key_to_scores, 'Score Distributions', 'Score'),
        (dist_ax, htapp_key_to_dists, 'Distance Distributions', 'Wasserstein Distance'),
    )

    for panel, values_by_key, panel_title, x_label in panel_specs:
        # One density curve per HTAPP sample
        for key, values in values_by_key.items():
            sns.kdeplot(values, ax=panel, label=f'HTAPP {key}')
        panel.set_title(panel_title)
        panel.set_xlabel(x_label)
        panel.legend()

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(f'final_figures_schaf_revision_pngs/{filename}.png',
                dpi=400, transparent=True)
    plt.close()


In [None]:
# Cell type classification metrics and visualization
def get_freq(cts, n_classes=10):
    """
    Calculate frequency of each cell type.

    Args:
        cts: Sequence of integer cell type labels, numbered from 1
        n_classes: Number of cell types to report (labels 1..n_classes)

    Returns:
        Array of length ``n_classes`` where entry ``k - 1`` is the fraction of
        ``cts`` equal to ``k``. Cell types that never occur get 0 instead of
        raising KeyError, and an empty input yields all zeros.
    """
    from collections import Counter
    total = float(len(cts))
    counts = Counter(cts)
    if total == 0:
        # No labels at all: every frequency is zero
        return np.zeros(n_classes)
    # Counter returns 0 for labels that were never seen
    return np.array([counts[k] / total for k in range(1, n_classes + 1)])

def get_perclass_acc(preds, trus, n_classes=10):
    """
    Calculate per-class accuracy.

    Args:
        preds: Predicted integer labels (any array-like, lists included)
        trus: True integer labels, same length as ``preds``
        n_classes: Number of classes to report (labels 1..n_classes)

    Returns:
        Array of length ``n_classes``. Entry ``k - 1`` is the fraction of
        cells whose true label is ``k`` that were predicted as ``k``. Classes
        with no true cells are NaN (the accuracy is undefined there).
    """
    preds = np.asarray(preds)
    trus = np.asarray(trus)

    # NaN marks classes that never occur among the true labels
    res = np.full(n_classes, np.nan)
    for idx, k in enumerate(range(1, n_classes + 1)):
        of_class = trus == k
        n_true = of_class.sum()
        if n_true:
            res[idx] = (preds[of_class] == k).sum() / n_true
    return res

# Cell type color definitions
def get_cell_type_colors():
    """
    Return the hex color assigned to every cell type label.

    The mapping is assembled from smaller groups, merged in order. Two
    abbreviated labels ('Endothelial vasc.' and 'Smooth muscle vasc.') share
    the color of their long-form counterparts.
    """
    mbc_group = {
        'MBC': '#73d56d',
        'MBC_stem-like': '#146c18',
        'MBC_neuronal': '#39a13c',
        'MBC_chondroid': '#003b00',
    }
    endothelial_group = {
        'Endothelial': '#fc0303',
        'Endothelial_sinusoidal': '#6d0000',
        'Endothelial_angiogenic': '#dc7014',
        'Endothelial_vascular': '#aa3700',
        'Endothelial vasc.': '#aa3700',
    }
    fibroblast_muscle_group = {
        'Fibroblast': '#ced208',
        'Chondrocyte': '#515900',
        'Smooth muscle_vascular': '#748000',
        'Smooth muscle vasc.': '#748000',
        'Stellate': '#323400',
        'Skeletal muscle': '#a0a800',
    }
    purple_group = {
        'Adipocytes': '#bb6fc4',
        'Hepatocyte': '#f3a3f6',
        'Keratinocyte': '#53065f',
        'Neuron': '#873b92',
    }
    blue_group = {
        'Macrophage': '#99d1fe',
        'Monocyte': '#387fb9',
        'Neutrophil': '#003365',
        'Erythrocyte': '#66a8dd',
        'Mast': '#1a588e',
    }
    b_t_nk_group = {
        'B_plasma': '#f86652',
        'B': '#cf1917',
        'T': '#fbb2a1',
        'NK': '#860000',
    }
    # Broad categories
    broad_group = {
        'Tumor': '#73d56d',
        'Normal': '#f3a3f6',
        'Vascular': '#fc0303',
        'Immune': '#99d1fe',
        'Fibrosis': '#ced208',
    }
    return {
        **mbc_group,
        **endothelial_group,
        **fibroblast_muscle_group,
        **purple_group,
        **blue_group,
        **b_t_nk_group,
        **broad_group,
    }

def make_color_ramp(colors):
    """
    Build a linearly interpolated colormap through the given colors.

    Args:
        colors: Sequence of color specifications accepted by ``colour.Color``

    Returns:
        A ``LinearSegmentedColormap`` named 'custom_ramp'.
    """
    from colour import Color
    from matplotlib.colors import LinearSegmentedColormap

    # Convert every color to an RGB tuple before handing it to matplotlib
    rgb_stops = [Color(spec).rgb for spec in colors]
    return LinearSegmentedColormap.from_list('custom_ramp', rgb_stops)

def plot_cell_type_accuracies(pred_labels, true_labels, dataset_names):
    """
    Plot cell type classification accuracies.

    One bar-chart panel is drawn per dataset with the per-class accuracy for
    cell types 1..10. Each panel title reports the frequency-weighted total
    accuracy.

    Args:
        pred_labels: Sequence of predicted-label arrays, one per dataset
        true_labels: Sequence of true-label arrays, one per dataset
        dataset_names: Sequence of dataset names used in the panel titles;
            a single dataset is supported

    The figure is written to
    final_figures_schaf_revision_pngs/celltype_accuracies.png.
    """
    # squeeze=False keeps `axes` indexable even when there is one dataset
    fig, axes = plt.subplots(1, len(dataset_names),
                            figsize=(13.5, 4.3),
                            squeeze=False)
    axes = axes[0]

    cts = list(range(1, 11, 1))
    y_ticks = np.arange(0, 1.01, .1)

    for i, (preds, trues, name) in enumerate(zip(pred_labels,
                                                true_labels,
                                                dataset_names)):
        ax = axes[i]

        # Calculate metrics
        accuracies = get_perclass_acc(preds, trues)
        freqs = get_freq(trues)
        # nansum: a class with no true cells has an undefined accuracy and
        # must not turn the whole total into NaN
        weighted_acc = round(float(np.nansum(accuracies * freqs)), 3)

        # Create bar plot
        ax.bar(cts, accuracies, 0.6)

        # Customize plot
        ax.set_title(f'{name}\nTotal Accuracy: {weighted_acc}',
                    size='xx-large', loc='left')
        ax.set_xlabel('Cell Type', size='xx-large')
        if i == 0:
            ax.set_ylabel('Accuracy', size='xx-large')

        ax.spines[['right', 'top']].set_visible(False)
        # Only the first panel shows y tick labels
        ax.set_yticks(ticks=y_ticks,
                     labels="" if i > 0 else np.round(y_ticks, 2),
                     size='x-large')
        ax.set_xticks(ticks=np.arange(1, 10.01, 1),
                     labels=range(1, 11, 1),
                     size='x-large')
        plt.setp(ax.spines.values(), linewidth=2)

    plt.tight_layout()
    fig.savefig('final_figures_schaf_revision_pngs/celltype_accuracies.png',
                dpi=400, transparent=True)
    plt.close(fig)

def plot_cell_type_spatial_distribution(pred_data, cell_type_labels, 
                                      sample_id, title):
    """
    Plot spatial distribution of cell types.

    Every cell is drawn at its (obs 'y', obs 'x') position in the color that
    get_cell_type_colors() assigns to its label, so the points match the
    legend patches exactly.

    Args:
        pred_data: Object exposing ``.obs['x']`` and ``.obs['y']`` per cell
        cell_type_labels: Per-cell labels; each must be a key of the mapping
            returned by get_cell_type_colors()
        sample_id: Sample identifier, used in the output filename
        title: Currently unused; kept for interface compatibility

    The figure is written to
    final_figures_schaf_revision_pngs/spatial_distribution_<sample_id>.png.
    """
    # Get colors
    colors = get_cell_type_colors()
    labels = np.asarray(cell_type_labels)
    cell_types = np.unique(labels)

    # The labels are names, not numbers, so they cannot be pushed through a
    # colormap; give each point its category color directly instead.
    point_colors = [colors[label] for label in labels]

    # Create figure
    fig, ax = plt.subplots(figsize=(5, 5))

    # Plot cells
    ax.scatter(
        pred_data.obs['y'],
        pred_data.obs['x'],
        c=point_colors,
        s=5
    )

    # Customize plot
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines[['left', 'right', 'top', 'bottom']].set_visible(True)
    plt.setp(ax.spines.values(), linewidth=2)
    ax.invert_yaxis()

    # Add legend
    clean_label = lambda a: (' '.join(a.split('_'))).title().replace('Mbc', 'MBC').replace('Nk', 'NK')
    patches = [mpatches.Patch(color=colors[ct],
                            label=clean_label(ct))
              for ct in cell_types]
    ax.legend(handles=patches,
             bbox_to_anchor=(1.15, 1.),
             loc='upper left',
             fontsize='xx-large',
             frameon=False)

    # The legend is anchored outside the axes; a tight bounding box keeps it
    # inside the saved image.
    fig.savefig(f'final_figures_schaf_revision_pngs/spatial_distribution_{sample_id}.png',
                dpi=400, transparent=True, bbox_inches='tight')
    plt.close(fig)

def plot_pseudobulk_correlations(pred_means, true_means, cell_types, 
                               title, filename):
    """
    Plot pseudobulk correlations for each cell type.

    Args:
        pred_means: Mapping from cell type to predicted mean expression vector
        true_means: Mapping from cell type to true mean expression vector
        cell_types: Cell types to plot, in bar order. Types missing from
            either mapping are skipped, and only the plotted types are used
            as tick labels so bars and labels stay aligned.
        title: Plot title
        filename: Output filename stem; the figure is written to
            final_figures_schaf_revision_pngs/<filename>.png
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # Track the plotted cell types alongside their correlations
    plotted_types = []
    correlations = []
    for ct in cell_types:
        if ct in pred_means and ct in true_means:
            correlations.append(np.corrcoef(pred_means[ct], true_means[ct])[0, 1])
            plotted_types.append(ct)

    # Create bar plot
    positions = range(1, len(correlations) + 1)
    ax.bar(positions, correlations, 0.6)

    # Customize plot
    ax.set_xlabel('Cell Type', size='xx-large')
    ax.set_ylabel('Pseudobulk Correlation', size='xx-large')
    ax.spines[['right', 'top']].set_visible(False)
    ax.set_xticks(positions)
    ax.set_xticklabels(plotted_types, rotation=45, ha='right')

    ax.set_title(title)
    plt.tight_layout()
    fig.savefig(f'final_figures_schaf_revision_pngs/{filename}.png',
                dpi=400, transparent=True)
    plt.close(fig)


In [None]:
# Main execution pipeline
def main():
    """
    Execute the complete analysis pipeline.

    Steps, in order: create the output directory, load the data sets and
    cell type labels, compute program / cell type / HTAPP metrics, then
    write the figures to 'final_figures_schaf_revision_pngs/'.

    Besides the values it loads itself, this function reads names it never
    assigns: common_in_sample, common_out_of_sample, all_programs,
    pred_htapp, ground_truth_htapp, g_to_cmaps and the_hists. They must
    already exist at notebook (module) scope when main() is called.
    """
    print("Setting up output directory...")
    os.makedirs('final_figures_schaf_revision_pngs', exist_ok=True)
    
    print("\nLoading data...")
    # Load core data
    ground_truth_out_of_sample, pred_out_of_sample = load_out_of_sample_data()
    ground_truth_in_sample, pred_in_sample = load_in_sample_data()
    ground_truth_mouse, pred_mouse = load_mouse_data()
    
    # Load cell type labels
    # (broad_clusters is not referenced again in this function)
    print("Loading cell type labels...")
    broad_clusters, xenium_in_sample_fold_to_inferred_labels, mouse_fold_to_inferred_labels = load_cell_type_labels()
    
    # Load MERFISH data
    # (none of the three results is referenced again in this function)
    print("Loading MERFISH data...")
    g1_adata, g2_adata, merfish_coords = load_merfish_data()
    
    # Load out-of-sample mouse data
    # (none of the four results is referenced again in this function)
    print("Loading out-of-sample mouse data...")
    g1_pred, g2_pred, g1_scores, g2_scores = load_out_of_sample_mouse_data()
    
    print("\nCalculating metrics...")
    # Calculate program metrics
    # common_in_sample, common_out_of_sample and all_programs come from
    # notebook scope, not from this function.
    program_metrics = calculate_program_metrics(
        pred_in_sample, ground_truth_in_sample,
        pred_out_of_sample, ground_truth_out_of_sample,
        common_in_sample, common_out_of_sample,
        all_programs
    )
    
    # Calculate cell type metrics
    cell_type_metrics = calculate_cell_type_metrics(
        pred_in_sample, ground_truth_in_sample,
        pred_out_of_sample, ground_truth_out_of_sample,
        pred_mouse, ground_truth_mouse,
        xenium_in_sample_fold_to_inferred_labels,
        mouse_fold_to_inferred_labels
    )
    
    # Calculate HTAPP metrics
    # pred_htapp and ground_truth_htapp come from notebook scope.
    htapp_scores, htapp_dists = calculate_htapp_metrics(
        pred_htapp, ground_truth_htapp
    )
    
    print("\nGenerating visualizations...")
    # Generate program visualizations
    print("- Program visualizations...")
    plot_program_correlations(program_metrics)
    plot_program_scores(program_metrics)
    
    # Generate cell type visualizations
    # The three datasets are passed in the same order in both label lists
    # and in the list of panel names.
    print("- Cell type visualizations...")
    plot_cell_type_accuracies(
        [cell_type_metrics['mouse_pred_labels'],
         cell_type_metrics['xenium_pred_labels'],
         cell_type_metrics['out_of_sample_pred_labels']],
        [cell_type_metrics['mouse_true_labels'],
         cell_type_metrics['xenium_true_labels'],
         cell_type_metrics['out_of_sample_true_labels']],
        ['In-Sample Mouse', 'In-Sample Xenium MBC', 'New-Sample Xenium MBC']
    )
    
    # Plot cell type spatial distributions
    # NOTE(review): these sample ids ('7149', '7179') do not appear among the
    # marker-gene keys below ('7479', ...) - confirm they are not typos.
    print("- Spatial distributions...")
    for sample_id in ['6760', '7149', '7179']:
        plot_cell_type_spatial_distribution(
            pred_htapp[sample_id],
            cell_type_metrics['htapp_labels'][sample_id],
            sample_id,
            f'HTAPP Sample {sample_id}'
        )
    
    # Generate HTAPP visualizations
    # g_to_cmaps and the_hists come from notebook scope.
    print("- HTAPP visualizations...")
    plot_htapp_marker_genes(
        pred_htapp, ground_truth_htapp,
        ['krt19', 'col1a2', 'apoc1', 'pecam1'],
        ['4531', '6760', '7479', '7629'],
        g_to_cmaps,
        'HTAPP Marker Genes',
        'htapp_markers'
    )
    
    plot_htapp_schematic(the_hists, 'HTAPP Overview', 'htapp_schematic')
    plot_htapp_metrics(htapp_scores, htapp_dists, 
                      'HTAPP Performance Metrics', 'htapp_metrics')
    
    print("\nAnalysis complete! All results saved in 'final_figures_schaf_revision_pngs/'")

# NOTE(review): inside a notebook kernel __name__ is presumably "__main__",
# so executing this cell would run the whole pipeline - confirm that is intended.
if __name__ == "__main__":
    main()


In [None]:
# Generate MSKCC schematic figure
def generate_mskcc_schematic(image_dir='/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/smaller_images/smaller_images'):
    """
    Generate the MSKCC schematic figure: a grid of all sample thumbnails.

    Images are laid out six per row in sorted filename order, so the layout
    is reproducible. Panels without an image are left blank. For the two
    samples whose filename contains '133729' or '129477', only the middle
    third of the image rows is shown.

    Args:
        image_dir: Directory holding the thumbnails (.jpg/.jpeg/.png/.bmp).
            Defaults to the location used by the original analysis.

    The figure is written to
    final_figures_schaf_revision_pngs/mskcc_schematic.png. Nothing is written
    when the directory contains no matching image.

    NOTE(review): a function with the same name is defined again later in
    this file; whichever definition runs last is the one in effect.
    """
    # os.listdir order is arbitrary; sort for a stable panel layout
    image_files = sorted(
        f for f in os.listdir(image_dir)
        if f.endswith(('.jpg', '.jpeg', '.png', '.bmp'))
    )
    if not image_files:
        print(f'No images found in {image_dir}; skipping MSKCC schematic.')
        return

    n_images = len(image_files)
    n_cols = 6
    n_rows = (n_images + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), squeeze=False)
    axes = axes.flatten()

    # Hide every panel's axes up front so unused panels stay empty
    for panel in axes:
        panel.axis('off')

    for panel, image_file in tqdm(zip(axes, image_files), total=n_images):
        img_path = os.path.join(image_dir, image_file)
        # The context manager releases the file handle once the pixels have
        # been handed to imshow
        with Image.open(img_path) as img:
            if '133729' in image_file or '129477' in image_file:
                pixels = np.array(img)
                third = pixels.shape[0] // 3
                panel.imshow(pixels[third:-third,])
            else:
                panel.imshow(img)
        panel.set_title('MSKCC ' + image_file[:6])

    fig.savefig('final_figures_schaf_revision_pngs/mskcc_schematic.png', dpi=400, transparent=True)
    plt.close(fig)


In [None]:
# Load and process Visium data
def load_visium_data():
    """
    Load the Visium breast sample, attach spot positions and transform them.

    Steps: read the 10x count matrix, replace ``obs`` with the tissue
    position table aligned to the barcodes, keep the spots within 10000
    pixels of the largest ``pxl_col_in_fullres`` value, log1p-transform and
    densify the counts, lower-case the gene names, and finally store the
    affinely transformed pixel coordinates in ``obs['x']`` / ``obs['y']``.

    The input directory is derived from the current working directory: the
    part before '/ccomiter/' is used as the root.

    Returns:
        The processed AnnData object.
    """
    data_dir = f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium'

    visium_adata = sc.read_10x_h5(f'{data_dir}/visium_breast_xenium.h5')
    visium_adata.var_names_make_unique()

    # Load tissue positions and align them to the barcode order of the matrix
    tissue_positions = pd.read_csv(f'{data_dir}/tissue_positions_visium.csv').set_index('barcode')
    visium_adata.obs = tissue_positions.loc[visium_adata.obs.index]

    # Keep spots whose column coordinate lies within 10000 px of the maximum
    the_verts = (visium_adata.obs['pxl_col_in_fullres'].max() - visium_adata.obs['pxl_col_in_fullres'])
    good_vis_inds = np.where(the_verts<=10000)[0]
    # Copy so the subset is a standalone object rather than a view that is
    # modified below
    visium_adata = visium_adata[good_vis_inds].copy()

    # Log-transform the counts and store them as a dense ndarray, as the
    # other loaders in this file do
    sc.pp.log1p(visium_adata)
    visium_adata.X = np.array(visium_adata.X.todense())
    visium_adata.var.index = [q.lower() for q in visium_adata.var.index]

    # Affine transform applied to (col, row) pixel coordinates
    # NOTE(review): presumably maps Visium pixels into the Xenium frame - confirm.
    potential_trans = np.array([
        [0.130157158, 2.594980119, -12243.84897],
        [-2.594980119, 0.130157158, 40352.06194],
        [0, 0, 1],
    ])

    # Cast to float64 so the transformed coordinates cannot inherit an
    # integer dtype from the pixel columns
    horis2 = np.array(visium_adata.obs['pxl_col_in_fullres'], dtype=np.float64)
    verts2 = np.array(visium_adata.obs['pxl_row_in_fullres'], dtype=np.float64)
    new_horis2, new_verts2 = do_transform2(horis2, verts2, potential_trans)

    visium_adata.obs['x'] = new_horis2
    visium_adata.obs['y'] = new_verts2

    return visium_adata

@njit(parallel=True)
def do_transform2(horis2, verts2, potential_trans):
    """
    Apply a 3x3 homogeneous transform to paired coordinates.

    Each point ``(horis2[t], verts2[t], 1)`` is multiplied by
    ``potential_trans``; the first two components of the result are returned.

    Args:
        horis2: 1-D array of horizontal coordinates
        verts2: 1-D array of vertical coordinates, same length as horis2
        potential_trans: 3x3 float64 transformation matrix

    Returns:
        tuple: (new_horis, new_verts), always float64 so that integer inputs
        do not truncate the transformed coordinates.
    """
    n_points = len(horis2)
    # Allocate float outputs explicitly instead of copying the input dtype
    new_horis = np.empty(n_points, dtype=np.float64)
    new_verts = np.empty(n_points, dtype=np.float64)
    for t in prange(n_points):
        i = horis2[t] / 1.
        j = verts2[t] / 1.
        new = potential_trans.dot(np.array([i, j, 1.]))
        new_horis[t] = new[0]
        new_verts[t] = new[1]
    return new_horis, new_verts


In [None]:
# Load and process ground truth and prediction data
def load_ground_truth_and_predictions():
    """
    Load Xenium ground truth and SCHAF predictions, whole-sample and per fold.

    The whole-sample ground truth is built from the 10x cell-feature matrix
    (log1p-transformed, densified, gene names lower-cased). Its ``obs`` is
    the cell table with 'x'/'y' columns added by pushing the cell centroids
    through the inverse of the alignment matrix. The four in-sample ground
    truth folds are log1p-transformed and densified as well; predictions are
    returned as read from disk.

    Returns:
        tuple: (ground_truth_out_of_sample, pred_out_of_sample,
        ground_truth_in_sample, pred_in_sample), where the in-sample entries
        are dicts keyed by fold index 0..3.
    """
    # Per-cell metadata, indexed by cell id
    cell_table = pd.read_csv(f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium/cells.csv')
    cell_table = cell_table.set_index('cell_id')

    # Alignment matrix from disk; its inverse is applied to the centroids
    alignment = pd.read_csv(
        f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium/alignment_new_xen.csv',
        header=None,
    )
    inverse_alignment = np.linalg.inv(alignment.values.astype(np.float32)).astype('float64')

    # Transform the cell centroids and store them on the metadata table
    centroid_rows = np.array(cell_table['y_centroid'])
    centroid_cols = np.array(cell_table['x_centroid'])
    xs, ys = do_transform(centroid_cols, centroid_rows, inverse_alignment)
    cell_table['x'] = xs
    cell_table['y'] = ys

    # Whole-sample ground truth expression
    truth_whole = sc.read_10x_h5(
        f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium/cell_feature_matrix.h5'
    )
    sc.pp.log1p(truth_whole)
    truth_whole.X = np.array(truth_whole.X.todense())
    truth_whole.var.index = [gene.lower() for gene in truth_whole.var.index]
    truth_whole.obs = cell_table

    # Whole-sample predictions
    pred_whole = sc.read_h5ad('/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/xenium_cancer_inferences/whole_sample.h5ad')

    # Per-fold ground truth and predictions
    truth_by_fold = {
        z: sc.read_h5ad(f'/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/cancer_in_sample_folds/fold_{z}_st.h5ad')
        for z in range(4)
    }
    pred_by_fold = {
        z: sc.read_h5ad(f'/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/xenium_cancer_inferences/fold_{z}.h5ad')
        for z in range(4)
    }

    # Bring the fold ground truth to the same form as the whole sample
    for fold_truth in truth_by_fold.values():
        sc.pp.log1p(fold_truth)
        fold_truth.X = np.array(fold_truth.X.todense())

    return truth_whole, pred_whole, truth_by_fold, pred_by_fold

@njit(parallel=True)
def do_transform(horis, verts, inv_trans):
    """
    Rescale paired coordinates and apply a 3x3 homogeneous transform.

    Each coordinate is divided by 0.2125 and the point ``(i, j, 1)`` is then
    multiplied by ``inv_trans``; the first two components are returned.

    Args:
        horis: 1-D array of horizontal coordinates
        verts: 1-D array of vertical coordinates, same length as horis
        inv_trans: 3x3 float64 transformation matrix

    Returns:
        tuple: (new_horis, new_verts), always float64 so that the output
        dtype does not depend on the input dtype.
    """
    n_points = len(horis)
    # Allocate float outputs explicitly instead of copying the input dtype
    new_horis = np.empty(n_points, dtype=np.float64)
    new_verts = np.empty(n_points, dtype=np.float64)
    for t in prange(n_points):
        # NOTE(review): 0.2125 is presumably the pixel size used to convert
        # the centroids to pixel units - confirm.
        i = horis[t] / .2125
        j = verts[t] / .2125
        new = inv_trans.dot(np.array([i, j, 1.]))
        new_horis[t] = new[0]
        new_verts[t] = new[1]
    return new_horis, new_verts


In [None]:
def calculate_cell_type_heterogeneity(
    label_to_ground_truth_in_sample, label_to_pred_in_sample,
    label_to_ground_truth_mouse, label_to_pred_mouse,
    label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample,
    common_in_sample, common_mouse, common_out_of_sample
):
    """
    Compute per-cell-type heterogeneity for the three data sets.

    For the in-sample and mouse data each label maps to several folds; the
    heterogeneity is computed per fold on the shared genes and averaged. The
    out-of-sample data has a single entry per label.

    Returns:
    - label_to_avg_in_sample_heteros: Dict mapping cell types to average heterogeneity in in-sample data
    - label_to_avg_mouse_heteros: Dict mapping cell types to average heterogeneity in mouse data
    - label_to_heteros_out_of_sample: Dict mapping cell types to heterogeneity in out-of-sample data
    """
    def _average_over_folds(truth_by_label, pred_by_label, genes):
        # Mean heterogeneity across the folds of every label
        averaged = {}
        for label in truth_by_label:
            per_fold = [
                calculate_heterogeneity(truth_by_label[label][fold][genes],
                                        pred_by_label[label][fold][genes])
                for fold in truth_by_label[label]
            ]
            averaged[label] = np.mean(per_fold)
        return averaged

    in_sample_heteros = _average_over_folds(
        label_to_ground_truth_in_sample, label_to_pred_in_sample, common_in_sample
    )
    mouse_heteros = _average_over_folds(
        label_to_ground_truth_mouse, label_to_pred_mouse, common_mouse
    )

    # Out-of-sample: one value per label, no folds to average
    out_of_sample_heteros = {
        label: calculate_heterogeneity(
            label_to_ground_truth_out_of_sample[label][common_out_of_sample],
            label_to_pred_out_of_sample[label][common_out_of_sample],
        )
        for label in label_to_ground_truth_out_of_sample
    }

    return in_sample_heteros, mouse_heteros, out_of_sample_heteros

def calculate_heterogeneity(ground_truth, predictions):
    """
    Return the variance of the element-wise difference truth - prediction.

    Args:
    - ground_truth: Ground truth values
    - predictions: Predicted values

    Returns:
    - heterogeneity: Variance of (ground_truth - predictions) over all elements
    """
    return np.var(ground_truth - predictions)


In [None]:
def generate_histology_split_figures(
    in_sample_fold_to_hist,
    mouse_fold_to_hist,
    in_sample_hist,
    out_of_sample_hist
):
    """
    Generate figures showing performance splits by histology.

    Three panels are drawn on a 2x2 grid: box plots for the in-sample and
    mouse folds and a bar plot for the out-of-sample data. Each panel plots
    'Performance' (y) against 'Histology' (x).

    Args:
    - in_sample_fold_to_hist: Dict mapping folds to histology labels for in-sample data
    - mouse_fold_to_hist: Dict mapping folds to histology labels for mouse data
    - in_sample_hist: Histology labels for in-sample data (currently unused)
    - out_of_sample_hist: Histology labels for out-of-sample data; its
      ``.index`` supplies the samples passed to performance_metric

    The figure is written to
    final_figures_schaf_revision_pngs/histology_splits.png.

    NOTE(review): performance_metric is called with a fold key / sample
    index, while the later definition in this file expects a dict holding
    'ground_truth' and 'predictions' - confirm which one is intended.
    """
    def _draw_panel(position, plotter, histology, samples, panel_title):
        # One panel: performance (y) grouped by histology label (x)
        ax = plt.subplot(2, 2, position)
        frame = pd.DataFrame({
            'Histology': histology,
            'Performance': [performance_metric(sample) for sample in samples]
        })
        # Name the columns explicitly; without x/y seaborn treats the frame
        # as wide-form and does not group performance by histology.
        plotter(data=frame, x='Histology', y='Performance', ax=ax)
        ax.set_title(panel_title)
        ax.tick_params(axis='x', labelrotation=45)

    # Set up the figure
    fig = plt.figure(figsize=(15, 10))

    # Plot in-sample histology splits
    _draw_panel(1, sns.boxplot,
                [in_sample_fold_to_hist[fold] for fold in in_sample_fold_to_hist],
                list(in_sample_fold_to_hist),
                'In-sample Performance by Histology')

    # Plot mouse histology splits
    _draw_panel(2, sns.boxplot,
                [mouse_fold_to_hist[fold] for fold in mouse_fold_to_hist],
                list(mouse_fold_to_hist),
                'Mouse Performance by Histology')

    # Plot out-of-sample histology comparison
    _draw_panel(3, sns.barplot,
                out_of_sample_hist,
                out_of_sample_hist.index,
                'Out-of-sample Performance by Histology')

    plt.tight_layout()
    fig.savefig('final_figures_schaf_revision_pngs/histology_splits.png', dpi=400, bbox_inches='tight')
    plt.close(fig)

def generate_correlation_histograms(
    mouse_avg_in_sample_corrs,
    mouse_train_genes
):
    """
    Draw two side-by-side histograms of gene correlations and save them.

    The left panel is restricted to the training genes; the right panel
    shows every gene.

    Args:
    - mouse_avg_in_sample_corrs: Average correlations for mouse in-sample data;
      must support indexing by mouse_train_genes
    - mouse_train_genes: List of genes used in mouse training

    The figure is written to
    final_figures_schaf_revision_pngs/correlation_histograms.png.
    """
    plt.figure(figsize=(12, 6))

    # (restrict to training genes?, panel title) for the left and right panels
    panel_specs = (
        (True, 'Mouse Gene Correlations'),
        (False, 'All Gene Correlations'),
    )

    for position, (training_only, panel_title) in enumerate(panel_specs, start=1):
        plt.subplot(1, 2, position)
        if training_only:
            values = mouse_avg_in_sample_corrs[mouse_train_genes]
        else:
            values = mouse_avg_in_sample_corrs
        sns.histplot(data=values, bins=50)
        plt.title(panel_title)
        plt.xlabel('Correlation')
        plt.ylabel('Count')

    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/correlation_histograms.png', dpi=400, bbox_inches='tight')
    plt.close()

def performance_metric(sample):
    """
    Placeholder for the per-sample performance metric.

    No metric is computed yet: the argument is ignored and None is returned.
    """
    # TODO: implement the actual performance metric calculation
    return None


In [None]:
def performance_metric(sample_data):
    """
    Score a sample as the mean of Pearson correlation and R-squared.

    Args:
    - sample_data: Dictionary with 'ground_truth' and 'predictions' entries;
      both values must provide ``.flatten()``

    Returns:
    - float: 0.5 * (correlation + R-squared), computed on the flattened values
    """
    truth = sample_data['ground_truth']
    predicted = sample_data['predictions']

    # Both scores are computed on the flattened arrays
    truth_flat = truth.flatten()
    predicted_flat = predicted.flatten()

    pearson = np.corrcoef(truth_flat, predicted_flat)[0, 1]
    r_squared = r2_score(truth_flat, predicted_flat)

    # Equal-weight average of the two scores
    return 0.5 * (pearson + r_squared)


In [None]:
def load_htapp_images(image_dir=None):
    """
    Load HTAPP histology images.

    Args:
    - image_dir: Directory to scan for '.tif' / '.png' files. Defaults to the
      project's HTAPP image directory when omitted.

    Returns:
    - Dict: Mapping of HTAPP sample IDs to histology images, where the sample
      ID is the part of the file name before the first underscore
    """
    if image_dir is None:
        image_dir = '/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/htapp_images'
    htapp_images = {}

    # os.listdir returns entries in arbitrary order; sorting makes the result
    # reproducible when two files share a sample ID (the later name wins).
    for image_file in sorted(os.listdir(image_dir)):
        if not image_file.endswith(('.tif', '.png')):
            continue
        sample_id = image_file.split('_')[0]
        img_path = os.path.join(image_dir, image_file)
        htapp_images[sample_id] = iio.imread(img_path)

    return htapp_images

def load_mskcc_images(image_dir=None):
    """
    Load MSKCC histology images.

    Args:
    - image_dir: Directory to scan for '.jpg' / '.jpeg' / '.png' / '.bmp'
      files. Defaults to the project's MSKCC image directory when omitted.

    Returns:
    - Dict: Mapping of MSKCC sample IDs (first six characters of the file
      name) to histology images. Files whose name contains '133729' or
      '129477' are returned as arrays cropped to their middle third of rows;
      all others are returned as opened by `Image.open`.
    """
    if image_dir is None:
        image_dir = '/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/smaller_images/smaller_images'
    mskcc_images = {}

    # Sorted so that the result does not depend on the file system's listing order.
    for image_file in sorted(os.listdir(image_dir)):
        if not image_file.endswith(('.jpg', '.jpeg', '.png', '.bmp')):
            continue

        sample_id = image_file[:6]
        img_path = os.path.join(image_dir, image_file)
        img = Image.open(img_path)
        # Assuming `Image` is PIL's Image module (TODO confirm): open() is lazy
        # and keeps the file handle; load() reads the pixels now so the handle
        # can be released instead of one staying open per image.
        img.load()

        # Special processing for certain samples: keep the middle third of rows.
        if '133729' in image_file or '129477' in image_file:
            img = np.array(img)
            third = img.shape[0] // 3
            # Explicit end index: a plain `-third` would select nothing when third == 0.
            img = img[third:img.shape[0] - third]

        mskcc_images[sample_id] = img

    return mskcc_images

def generate_mskcc_schematic():
    """
    Generate MSKCC schematic figure showing all samples.

    Images from `load_mskcc_images` are laid out on a grid six columns wide
    and saved to 'final_figures_schaf_revision_pngs/mskcc_schematic.png'.

    Raises:
    - ValueError: If no MSKCC images were loaded
    """
    mskcc_images = load_mskcc_images()

    n_images = len(mskcc_images)
    if n_images == 0:
        # A zero-row grid cannot be created; fail with an explicit message.
        raise ValueError('No MSKCC images found; cannot build the schematic.')

    # Calculate grid dimensions (ceil division for the row count)
    n_cols = 6
    n_rows = (n_images + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8))
    axes = axes.flatten()

    # Plot each image on its own axis
    for ax, (sample_id, img) in zip(axes, mskcc_images.items()):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f'MSKCC {sample_id}')

    # Blank out the grid cells left over in the last row
    for ax in axes[n_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/mskcc_schematic.png', dpi=400, transparent=True)
    plt.close(fig)

def analyze_cell_type_specificity(adata, cell_type_key='cell_type'):
    """
    Analyze cell type-specific expression patterns.

    Args:
    - adata: AnnData object containing expression data
    - cell_type_key: Key in adata.obs containing cell type labels

    Returns:
    - Dict: For each cell type, a dict with
        'mean_expression': pd.Series of per-gene means, indexed by adata.var.index
        'std_expression': pd.Series of per-gene standard deviations, same index
        'n_cells': number of cells carrying that label
    """
    gene_names = adata.var.index
    labels = adata.obs[cell_type_key]
    results = {}

    for ct in labels.unique():
        ct_mask = labels == ct

        ct_expr = adata[ct_mask].X
        # Sparse matrices have no .std(); densify so both statistics work.
        if hasattr(ct_expr, 'toarray'):
            ct_expr = ct_expr.toarray()
        ct_expr = np.asarray(ct_expr)

        # Gene-indexed Series let downstream plotting select genes by name.
        results[ct] = {
            'mean_expression': pd.Series(ct_expr.mean(axis=0), index=gene_names),
            'std_expression': pd.Series(ct_expr.std(axis=0), index=gene_names),
            'n_cells': int(ct_mask.sum())
        }

    return results


In [None]:
def plot_cell_type_expression_patterns(cell_type_results, genes_of_interest=None):
    """
    Plot cell type-specific expression patterns as a heatmap.

    Args:
    - cell_type_results: Dict from analyze_cell_type_specificity. Each
      'mean_expression' entry may be a pd.Series (genes selected by label) or
      a plain array (genes selected by position).
    - genes_of_interest: Optional list of genes to highlight. Defaults to all
      genes of the first cell type: its Series index, or positions
      0..n_genes-1 when the entry is a plain array.
    """
    cell_types = list(cell_type_results.keys())

    def _profile(ct):
        """Mean-expression vector of one cell type, ready for indexing."""
        values = cell_type_results[ct]['mean_expression']
        if isinstance(values, pd.Series):
            return values
        # Arrays and (1, n) matrices carry no gene labels; flatten to 1-D so
        # positional indexing works.
        return np.asarray(values).ravel()

    if genes_of_interest is None:
        reference = _profile(cell_types[0])
        if isinstance(reference, pd.Series):
            genes_of_interest = reference.index
        else:
            genes_of_interest = np.arange(reference.size)

    # Create expression matrix: one row per cell type, one column per gene
    expr_matrix = np.zeros((len(cell_types), len(genes_of_interest)))
    for i, ct in enumerate(cell_types):
        expr_matrix[i] = _profile(ct)[genes_of_interest]

    # Plot heatmap
    plt.figure(figsize=(15, 10))
    # NOTE(review): center=0 is combined with the sequential 'viridis' map;
    # confirm the resulting colour scaling is the intended one.
    sns.heatmap(
        expr_matrix,
        xticklabels=genes_of_interest,
        yticklabels=cell_types,
        cmap='viridis',
        center=0
    )
    plt.title('Cell Type-Specific Expression Patterns')
    plt.xlabel('Genes')
    plt.ylabel('Cell Types')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/cell_type_expression_patterns.png', dpi=400, bbox_inches='tight')
    plt.close()

def plot_cell_type_correlations(adata1, adata2, cell_type_key='cell_type'):
    """
    Plot correlations between cell types in two datasets.

    For every pair of cell types present in both datasets, the mean expression
    profile of the type in adata1 is correlated with the mean profile of the
    type in adata2, and the resulting matrix is saved as a heatmap.

    Args:
    - adata1: First AnnData object
    - adata2: Second AnnData object
    - cell_type_key: Key for cell type annotations
    """
    # Get common cell types
    cell_types = np.intersect1d(
        adata1.obs[cell_type_key].unique(),
        adata2.obs[cell_type_key].unique()
    )

    def _mean_profiles(adata):
        """Mean expression vector per common cell type, in `cell_types` order."""
        labels = adata.obs[cell_type_key]
        return [adata[labels == ct].X.mean(axis=0) for ct in cell_types]

    # Each profile is computed once per dataset rather than once per matrix cell.
    # NOTE(review): profiles are compared position by position, so both
    # datasets are assumed to hold the same genes in the same order — confirm.
    means1 = _mean_profiles(adata1)
    means2 = _mean_profiles(adata2)

    # Calculate correlations
    corr_matrix = np.zeros((len(cell_types), len(cell_types)))
    for i, mean1 in enumerate(means1):
        for j, mean2 in enumerate(means2):
            corr_matrix[i, j] = np.corrcoef(mean1, mean2)[0, 1]

    # Plot correlation matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        corr_matrix,
        xticklabels=cell_types,
        yticklabels=cell_types,
        cmap='coolwarm',
        center=0,
        vmin=-1,
        vmax=1
    )
    plt.title('Cell Type Correlations Between Datasets')
    plt.xlabel('Dataset 2 Cell Types')
    plt.ylabel('Dataset 1 Cell Types')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/cell_type_correlations.png', dpi=400, bbox_inches='tight')
    plt.close()

def plot_cell_type_heterogeneity(cell_type_results):
    """
    Draw a bar chart of the mean coefficient of variation per cell type.

    Args:
    - cell_type_results: Dict from analyze_cell_type_specificity; the
      'std_expression' and 'mean_expression' entries of each cell type are used
    """
    cell_types = list(cell_type_results)

    # Per-gene CV = std / mean, averaged over genes. The 1e-10 offset keeps the
    # division defined for genes whose mean is zero.
    cv_values = [
        (stats['std_expression'] / (stats['mean_expression'] + 1e-10)).mean()
        for stats in cell_type_results.values()
    ]

    plt.figure(figsize=(12, 6))
    sns.barplot(x=cell_types, y=cv_values)
    plt.title('Cell Type Heterogeneity')
    plt.xlabel('Cell Type')
    plt.ylabel('Mean Coefficient of Variation')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/cell_type_heterogeneity.png', dpi=400, bbox_inches='tight')
    plt.close()


In [None]:
def load_visium_data():
    """
    Load and preprocess Visium spatial transcriptomics data.

    Reads the Visium count matrix and its tissue-position table, keeps the
    spots whose `pxl_col_in_fullres` lies within 10000 of the column maximum,
    log1p-transforms the counts into a dense array, lower-cases the gene
    names, and stores affine-transformed pixel coordinates in obs['x'] and
    obs['y'].

    Returns:
    - AnnData: Processed Visium data
    """
    # Both input files live in the same directory; resolve it once.
    xenium_dir = f'{os.getcwd().split("/ccomiter/")[0]}/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium'

    # Load Visium data
    visium_adata = sc.read_10x_h5(f'{xenium_dir}/visium_breast_xenium.h5')
    visium_adata.var_names_make_unique()

    # Load tissue positions, ordered to match the count matrix barcodes
    tissue_positions = pd.read_csv(f'{xenium_dir}/tissue_positions_visium.csv').set_index('barcode')
    tissue_positions = tissue_positions.loc[visium_adata.obs.index]
    visium_adata.obs = tissue_positions

    # Filter spots by distance from the largest column coordinate
    the_verts = (visium_adata.obs['pxl_col_in_fullres'].max() - visium_adata.obs['pxl_col_in_fullres'])
    good_vis_inds = np.where(the_verts <= 10000)[0]
    # copy(): the assignments below should modify a real object, not a view.
    visium_adata = visium_adata[good_vis_inds].copy()

    # Transform data
    sc.pp.log1p(visium_adata)
    # toarray() yields a plain ndarray (todense() would give np.matrix).
    visium_adata.X = visium_adata.X.toarray()
    visium_adata.var.index = [q.lower() for q in visium_adata.var.index]

    # Apply spatial transformation (2-D affine map in homogeneous form)
    potential_trans = np.array([
        [0.130157158, 2.594980119, -12243.84897],
        [-2.594980119, 0.130157158, 40352.06194],
        [0, 0, 1],
    ])

    # Cast to float so the transformed coordinates are not forced back into an
    # integer dtype inherited from the pixel columns.
    horis = np.array(visium_adata.obs['pxl_col_in_fullres'], dtype=np.float64)
    verts = np.array(visium_adata.obs['pxl_row_in_fullres'], dtype=np.float64)
    new_horis, new_verts = transform_coordinates(horis, verts, potential_trans)

    # These are the only x/y values stored: one entry per retained spot.
    visium_adata.obs['x'] = new_horis
    visium_adata.obs['y'] = new_verts

    return visium_adata

@njit(parallel=True)
def transform_coordinates(horis, verts, trans_matrix):
    """
    Transform spatial coordinates using transformation matrix.

    Each point (h, v) is mapped to the first two components of
    trans_matrix @ [h, v, 1]; the third component is not used.

    Args:
    - horis: Horizontal coordinates (1-D array)
    - verts: Vertical coordinates (1-D array, same length as horis)
    - trans_matrix: Transformation matrix; rows 0 and 1, columns 0-2 are read

    Returns:
    - Tuple: Transformed coordinates (new_horis, new_verts) as float64 arrays
    """
    n_points = len(horis)
    # float64 outputs: inheriting the input dtype would truncate the result
    # whenever integer pixel coordinates are passed in.
    new_horis = np.empty(n_points, dtype=np.float64)
    new_verts = np.empty(n_points, dtype=np.float64)

    for t in prange(n_points):
        i = horis[t]
        j = verts[t]
        # The matrix-vector product is written out so that no temporary array
        # is allocated per point and no dtype agreement between the inputs and
        # the matrix is required.
        new_horis[t] = trans_matrix[0, 0] * i + trans_matrix[0, 1] * j + trans_matrix[0, 2]
        new_verts[t] = trans_matrix[1, 0] * i + trans_matrix[1, 1] * j + trans_matrix[1, 2]

    return new_horis, new_verts

def trans_good(x):
    """
    Round log1p-scale values to the log1p of the nearest whole number.

    The input is mapped back with exp(x) - 1, rounded, and passed through
    log1p again.

    Args:
    - x: Input data

    Returns:
    - Transformed data
    """
    linear_scale = np.exp(x) - 1.
    whole_numbers = linear_scale.round()
    return np.log1p(whole_numbers)


In [None]:
class CellTypeClassifier(nn.Module):
    """
    Fully connected network for cell type classification.

    Every hidden width contributes a Linear -> BatchNorm1d -> ReLU group; the
    head is a Linear layer followed by Softmax over dim 1. Because of that
    final Softmax, `forward` returns class probabilities rather than logits,
    so losses that expect raw logits (e.g. nn.CrossEntropyLoss) should not be
    applied to the output directly.
    """

    def __init__(self, input_dim, hidden_dims, num_classes):
        super().__init__()

        # Consecutive width pairs describe the hidden Linear layers.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ]

        # Output head: project the last width onto the classes, then normalise.
        stack += [nn.Linear(widths[-1], num_classes), nn.Softmax(dim=1)]

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)

def predict_cell_types(orig_adata, pred_adata, celltype_key, device='cuda'):
    """
    Predict cell types using a neural network classifier.

    A CellTypeClassifier is trained on the genes shared by both datasets,
    using the labels in orig_adata, and then applied to pred_adata. Each
    dataset is z-scored per gene with its own mean and standard deviation.

    Args:
    - orig_adata: Original AnnData with cell type labels
    - pred_adata: AnnData to predict cell types for
    - celltype_key: Key in orig_adata.obs containing cell type labels
    - device: Device to run computations on

    Returns:
    - np.array: Predicted cell type labels
    - float: Classification accuracy on original data (the training cells)
    """
    # Get common genes
    common_var = np.intersect1d(orig_adata.var.index, pred_adata.var.index)

    # Get cell type labels and the two-way mapping to integer class ids
    annos = np.unique(orig_adata.obs[celltype_key])
    anno_to_label = dict(zip(annos, range(len(annos))))
    label_to_anno = dict(zip(range(len(annos)), annos))

    def _standardized_features(adata):
        """Per-gene z-scores over the shared genes; NaNs (zero-variance genes) become 0."""
        feats = adata[:, common_var].X
        if hasattr(feats, 'toarray'):
            # Sparse input: densify before the element-wise arithmetic.
            feats = feats.toarray()
        feats = np.asarray(feats)
        feats = (feats - feats.mean(axis=0)) / feats.std(axis=0)
        return np.nan_to_num(feats)

    def _predict_labels(features):
        """Predicted class ids for `features`, in the same row order as the input."""
        loader = DataLoader(
            TensorDataset(torch.from_numpy(features).float()),
            batch_size=batch_size,
            shuffle=False,
            num_workers=6,
            pin_memory=True
        )
        predicted = []
        with torch.no_grad():
            for (batch_features,) in loader:
                batch_features = batch_features.to(device) / 10.
                outputs = model(batch_features)
                predicted.extend(outputs.argmax(dim=1).cpu().numpy())
        return np.array(predicted)

    # Prepare data
    orig_features = _standardized_features(orig_adata)
    orig_labels = np.array([anno_to_label[anno] for anno in orig_adata.obs[celltype_key]])

    # Create model: wider hidden layers when there are more than 1024 input genes
    input_dim = len(common_var)
    if input_dim > (1 << 10):
        hidden_dims = [1 << 10, 1 << 8, 1 << 6]
    else:
        hidden_dims = [1 << 8, 1 << 6, 1 << 4]

    model = CellTypeClassifier(input_dim, hidden_dims, len(annos)).to(device)

    # Calculate class weights (inverse class frequency)
    num_cells = len(orig_labels)
    class_weights = torch.tensor([
        (float(num_cells) / np.sum(orig_labels == i))
        for i in range(len(annos))
    ]).float().to(device)

    # Training parameters
    batch_size = 128
    epochs = 10
    # The classifier already ends in Softmax, so its output is probabilities.
    # NLLLoss on their logarithm is the matching loss; CrossEntropyLoss would
    # apply a second softmax on top of them.
    criterion = nn.NLLLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Create data loader
    train_dataset = TensorDataset(
        torch.from_numpy(orig_features).float(),
        torch.from_numpy(orig_labels).long()
    )
    # BatchNorm1d cannot train on a batch of one sample; drop a trailing
    # singleton batch when the dataset size leaves exactly one cell over.
    drop_singleton = num_cells > batch_size and num_cells % batch_size == 1
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=6,
        pin_memory=True,
        drop_last=drop_singleton
    )

    # Train model
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch_features, batch_labels in train_loader:
            batch_features = batch_features.to(device) / 10.
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(batch_features)
            # clamp keeps log() finite if a probability underflows to 0.
            loss = criterion(torch.log(outputs.clamp_min(1e-12)), batch_labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # loss.item() is already a per-batch mean, so average over batches.
        print(f'Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader):.6f}')

    model.eval()

    # Predict on original data. A non-shuffled pass is required here so that
    # prediction i lines up with orig_labels[i].
    orig_pred = _predict_labels(orig_features)

    # Predict on new data
    pred_features = _standardized_features(pred_adata)
    new_pred = _predict_labels(pred_features)

    # Calculate accuracy on original data
    accuracy = np.mean(orig_pred == orig_labels)
    print(f'Classification accuracy on original data: {accuracy:.4f}')

    # Convert numeric predictions back to cell type labels
    new_pred = np.array([label_to_anno[label] for label in new_pred])

    return new_pred, accuracy


In [None]:
def analyze_spatial_data(pred_adata, visium_adata, fold):
    """
    Analyze spatial relationships between predicted data and Visium data.

    Visium spots inside the image bounds and inside the quadrant belonging to
    `fold` are kept. Predicted cells inside the region bounded by three fixed
    lines are each assigned to the first kept spot (in spot order) lying
    strictly within a radius of 130, and their expression over the shared
    genes is summed per spot.

    Args:
    - pred_adata: Predicted AnnData object (obs must provide 'x' and 'y')
    - visium_adata: Visium AnnData object (obs must provide 'x' and 'y')
    - fold: Fold number for analysis (0-3; any other value selects no spots)

    Returns:
    - Dict: Mapping of Xenium coordinates (x, y) to Visium indices, -1 when no
      spot is within the radius
    - np.array: Spot predictions, shape (number of kept spots, number of shared genes)
    """
    # Get common genes
    common_var = np.intersect1d(pred_adata.var.index, visium_adata.var.index)

    # Define region boundaries (quadrant split used for the folds)
    y_lim = 12900
    x_lim = 17700

    # Filter Visium spots: inside the image bounds and in this fold's quadrant
    all_vis_xs = np.array(list(visium_adata.obs['x']))
    all_vis_ys = np.array(list(visium_adata.obs['y']))
    in_image = (
        (0 <= all_vis_ys) & (all_vis_ys <= 25778)
        & (0 <= all_vis_xs) & (all_vis_xs < 35416)
    )
    quadrant_of_fold = {
        0: (all_vis_xs < x_lim) & (all_vis_ys < y_lim),
        1: (all_vis_xs >= x_lim) & (all_vis_ys >= y_lim),
        2: (all_vis_xs >= x_lim) & (all_vis_ys < y_lim),
        3: (all_vis_xs < x_lim) & (all_vis_ys >= y_lim),
    }
    in_fold = quadrant_of_fold.get(fold, np.zeros(len(all_vis_xs), dtype=bool))
    keep_spot = in_image & in_fold
    vis_xs = all_vis_xs[keep_spot]
    vis_ys = all_vis_ys[keep_spot]

    # Define boundary lines
    def get_line_params(p1, p2):
        """Get line parameters (slope and intercept)."""
        x1, y1 = p1
        x2, y2 = p2
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return m, b

    bottom_line = get_line_params((3201, 4850), (1386, 1458))
    right_line = get_line_params((33057, 8591), (33218, 4798))
    left_line = get_line_params((504, 25967), (726, 20755))

    # Filter predicted data points. A point is dropped when it fails any of
    # the four tests below.
    all_xen_xs = np.array(list(pred_adata.obs['x']))
    all_xen_ys = np.array(list(pred_adata.obs['y']))
    dropped = all_xen_ys > 25761
    m, b = bottom_line
    dropped |= all_xen_ys < m * all_xen_xs + b
    m, b = left_line
    dropped |= all_xen_ys < m * all_xen_xs + b
    m, b = right_line
    dropped |= all_xen_ys > m * all_xen_xs + b

    kept_rows = np.flatnonzero(~dropped)
    xen_xs = all_xen_xs[kept_rows]
    xen_ys = all_xen_ys[kept_rows]

    # Map Xenium spots to Visium spots: first spot strictly inside the radius
    xen_to_vis_ind = {}
    radius = 130

    for x, y in zip(xen_xs, xen_ys):
        hits = np.flatnonzero((vis_xs - x)**2 + (vis_ys - y)**2 < radius**2)
        xen_to_vis_ind[(x, y)] = int(hits[0]) if hits.size else -1

    # Calculate spot predictions by summing the expression of assigned cells
    num_spots = len(vis_xs)
    spot_preds = np.zeros((num_spots, len(common_var)))
    pred_expr = pred_adata[:, common_var].X

    for row, x, y in zip(kept_rows, xen_xs, xen_ys):
        ind = xen_to_vis_ind[(x, y)]
        if ind >= 0:
            spot_preds[ind] += pred_expr[row]

    return xen_to_vis_ind, spot_preds

def plot_spatial_predictions(visium_adata, spot_preds, fold, gene_name):
    """
    Plot spatial predictions for a specific gene.

    Draws a two-panel figure (measured expression on the left, predicted
    values on the right) and saves it under
    'final_figures_schaf_revision_pngs/' with the fold and gene in the file name.
    
    Args:
    - visium_adata: Visium AnnData object
    - spot_preds: Predicted spot values; indexed as spot_preds[:, gene_idx]
    - fold: Fold number (used only in the output file name)
    - gene_name: Name of gene to plot; must be present in visium_adata.var.index
    """
    # Create figure
    plt.figure(figsize=(12, 5))
    
    # Plot ground truth
    # NOTE(review): no `ax=` is passed to sc.pl.spatial, so whether the plot
    # lands in the subplot selected here depends on scanpy — confirm.
    plt.subplot(121)
    sc.pl.spatial(
        visium_adata,
        color=gene_name,
        show=False,
        title=f'Ground Truth - {gene_name}'
    )
    
    # Plot predictions
    plt.subplot(122)
    # NOTE(review): gene_idx is the gene's position in visium_adata.var.index.
    # analyze_spatial_data builds its spot_preds with one column per shared
    # gene and one row per fold-filtered spot, so if spot_preds comes from
    # there, neither axis is guaranteed to line up with visium_adata — verify
    # against the caller.
    gene_idx = visium_adata.var.index.get_loc(gene_name)
    # NOTE(review): `color` receives an array of values here rather than a
    # key name as in the call above — confirm sc.pl.spatial accepts this.
    sc.pl.spatial(
        visium_adata,
        color=spot_preds[:, gene_idx],
        show=False,
        title=f'Predictions - {gene_name}'
    )
    
    plt.tight_layout()
    plt.savefig(f'final_figures_schaf_revision_pngs/spatial_predictions_fold{fold}_{gene_name}.png', 
                dpi=400, bbox_inches='tight')
    plt.close()


In [None]:
def load_programs(programs_dir=None):
    """
    Load hallmark and cancer programs.

    Args:
    - programs_dir: Directory containing 'hallmark_programs.json' and
      'cancer_programs.json'. Defaults to the project's program directory
      when omitted.

    Returns:
    - Dict: Combined hallmark and cancer programs, mapping each program name
      to its 'geneSymbols' list. On a name clash the cancer program wins.
    """
    if programs_dir is None:
        programs_dir = '/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324'

    def _read_gene_sets(filename):
        """Read one program file and keep only each program's gene symbols."""
        with open(os.path.join(programs_dir, filename)) as f:
            raw_programs = json.load(f)
        return {name: entry['geneSymbols'] for name, entry in raw_programs.items()}

    # Combine programs: cancer entries are applied on top of hallmark entries
    programs = _read_gene_sets('hallmark_programs.json')
    programs.update(_read_gene_sets('cancer_programs.json'))
    return programs

def load_htapp_data():
    """
    Load and process HTAPP data.

    Single-cell and MERFISH samples are each log1p-transformed, given
    lower-cased gene names, and restricted to the genes shared by every
    sample of the same modality.

    Returns:
    - Dict: Ground truth HTAPP data
    - Dict: Predicted HTAPP data
    - Dict: MERFISH HTAPP data

    Raises:
    - ValueError: If no single-cell or no MERFISH sample could be loaded
    """
    these_keys = ['4531', '7179', '7479', '7629', '932', '6760', '7149', '4381', '8239']
    data_root = os.getcwd().split("/ccomiter/")[0]

    def _restrict_to_shared_genes(adatas, source_dir):
        """Subset every sample to the genes present in all samples."""
        if not adatas:
            raise ValueError(f'No samples were loaded from {source_dir}')
        gene_indexes = [adata.var.index for adata in adatas.values()]
        # The shared set is finished BEFORE any sample is subset; narrowing it
        # while subsetting would leave early samples with extra genes.
        shared = gene_indexes[0]
        for gene_index in gene_indexes:
            shared = np.intersect1d(shared, gene_index)
        return {k: adata[:, shared] for k, adata in adatas.items()}

    # Load single-cell data: a file belongs to a key when the key occurs in its name
    sc_dir = f'{data_root}/ccomiter/htapp_supervise/final_scs/schtapp'
    the_scs = {}
    for file in os.listdir(sc_dir):
        for k in these_keys:
            if k not in file:
                continue
            sc_adata = sc.read_h5ad(f'{sc_dir}/{file}')
            # Rebuild the object from the stored counts, then log1p-transform.
            new_v = sc.AnnData(X=np.array(sc_adata.obsm['counts'].todense()), obs=sc_adata.obs)
            sc.pp.log1p(new_v)
            new_v.var.index = sc_adata.uns['counts_var']
            new_v.var.index = [q.lower() for q in new_v.var.index]
            the_scs[k] = new_v

    # Get common genes
    the_scs = _restrict_to_shared_genes(the_scs, sc_dir)

    # Load predictions
    pred_htapp = {
        k: sc.read_h5ad(f'htapp_inferences_with_annos/{k}.h5ad')
        for k in these_keys
    }

    # Load MERFISH data. Only '<key>_merfish.h5ad' files define a sample, and
    # each key is read once regardless of what else is in the directory.
    mers_dir = f'{data_root}/ccomiter/htapp_supervise/final_mers'
    merfish_keys = sorted({
        f.split('_')[0]
        for f in os.listdir(mers_dir)
        if f.endswith('_merfish.h5ad')
    })
    merfish_htapp = {}
    for key in merfish_keys:
        mer = sc.read_h5ad(os.path.join(mers_dir, f'{key}_merfish.h5ad'))
        mer.X = np.array(mer.obsm['counts'].todense())
        sc.pp.log1p(mer)
        mer.var.index = [q.lower() for q in mer.var.index]
        merfish_htapp[key] = mer

    # Get common MERFISH genes
    merfish_htapp = _restrict_to_shared_genes(merfish_htapp, mers_dir)

    return the_scs, pred_htapp, merfish_htapp

def generate_cell_type_labels(merfish_data, segmentation_data):
    """
    Generate cell type labels from MERFISH and segmentation data.

    Every segmented centroid (truncated to integer pixels) receives the label
    vector of its nearest MERFISH cell. A label vector has one entry per cell
    type in the order Fibrosis_1, ImmuneCells_1, Normal_1, Tumor_1,
    Vasculature_1; flagged types share a total weight of 1, and a cell with
    no flags gets all zeros.

    Side effect: for samples '8239' and '932' a 'Vasculature_1' column is
    written into that sample's `obs` (copied from 'BloodClots_1' and
    'BloodVessels_1' respectively).

    Args:
    - merfish_data: Dict of MERFISH data (obs must provide 'x' and 'y')
    - segmentation_data: Dict of segmentation data with 'Centroid X px' and
      'Centroid Y px' entries

    Returns:
    - Dict: Cell type labels for each sample, shape (number of centroids, 5)
    """
    # Define cell types (column order of the label vectors)
    celltypes = ['Fibrosis_1', 'ImmuneCells_1', 'Normal_1', 'Tumor_1', 'Vasculature_1']

    the_labels = {}
    for k, mer in merfish_data.items():
        # Handle special cases: samples that name the vasculature column differently
        if k == '8239':
            mer.obs['Vasculature_1'] = mer.obs['BloodClots_1']
        elif k == '932':
            mer.obs['Vasculature_1'] = mer.obs['BloodVessels_1']

        obs = mer.obs

        # One normalized label vector per MERFISH cell, computed once per
        # sample. A missing column counts as "not flagged"; otherwise the
        # value's truthiness decides.
        flags = np.zeros((len(obs), len(celltypes)))
        for col, ct in enumerate(celltypes):
            if ct in obs.columns:
                flags[:, col] = obs[ct].to_numpy().astype(bool)
        totals = flags.sum(axis=1, keepdims=True)
        cell_labels = np.divide(flags, totals, out=np.zeros_like(flags), where=totals > 0)

        # Segmentation centroids, truncated to whole pixels
        seg = segmentation_data[k]
        centroids = np.column_stack([
            np.array(seg['Centroid X px']).astype(int),
            np.array(seg['Centroid Y px']).astype(int),
        ])
        if len(centroids) == 0:
            the_labels[k] = np.zeros((0, len(celltypes)))
            continue

        # Build KD-tree and look up all centroids in a single query
        tree = cKDTree(np.column_stack([obs['x'], obs['y']]))
        _, nearest = tree.query(centroids, k=1)

        the_labels[k] = cell_labels[nearest]

    return the_labels

def calculate_spatial_distances(merfish_data, segmentation_data):
    """
    Calculate the distance from each segmented centroid to its nearest MERFISH cell.

    Args:
    - merfish_data: Dict of MERFISH data (obs must provide 'x' and 'y')
    - segmentation_data: Dict of segmentation data with 'Centroid X px' and
      'Centroid Y px' entries

    Returns:
    - Dict: Distances for each sample, one value per centroid
    """
    the_dists = {}
    for k, mer in merfish_data.items():
        # Build KD-tree over the MERFISH cell coordinates
        tree = cKDTree(np.column_stack([mer.obs['x'], mer.obs['y']]))

        seg = segmentation_data[k]
        centroids = np.column_stack([
            np.asarray(seg['Centroid X px']),
            np.asarray(seg['Centroid Y px']),
        ])
        if len(centroids) == 0:
            the_dists[k] = np.array([])
            continue

        # A single batched query replaces one Python-level query per centroid.
        dists, _ = tree.query(centroids, k=1)
        the_dists[k] = dists

    return the_dists


In [None]:
def calculate_correlations_and_scores(
    ground_truth_out_of_sample, pred_out_of_sample,
    ground_truth_in_sample, pred_in_sample
):
    """
    Calculate correlations and scores for both in-sample and out-of-sample data.

    The correlation of a gene is the Pearson correlation between measured and
    predicted expression across cells. The score of a gene is the standard
    deviation of its predicted expression. In-sample values are averaged over
    folds 0-3, weighted by each fold's share of the ground-truth cells.

    Args:
    - ground_truth_out_of_sample: Ground truth out-of-sample data
    - pred_out_of_sample: Predicted out-of-sample data
    - ground_truth_in_sample: Dict of ground truth in-sample data (keys 0-3)
    - pred_in_sample: Dict of predicted in-sample data (keys 0-3)

    Returns:
    - Tuple: (out_of_sample_corrs, out_of_sample_scores, avg_in_sample_corrs,
      avg_in_sample_scores, common_out_of_sample, common_in_sample)
    """
    folds = range(4)

    def _dense(adata, genes=None):
        """Expression of `adata` (optionally restricted to `genes`) as a dense 2-D array."""
        values = (adata if genes is None else adata[:, genes]).X
        if hasattr(values, 'toarray'):
            values = values.toarray()
        return np.asarray(values)

    def _gene_correlations(truth, pred, genes):
        """Pearson correlation per gene between two datasets with matching rows."""
        if len(genes) == 0:
            return {}
        # One dense extraction per dataset instead of one slice per gene.
        truth_values = _dense(truth, genes)
        pred_values = _dense(pred, genes)
        return {
            g: np.corrcoef(truth_values[:, col], pred_values[:, col])[0, 1]
            for col, g in enumerate(genes)
        }

    def _gene_scores(pred):
        """Standard deviation per gene, keyed by the dataset's OWN gene names."""
        return dict(zip(pred.var.index, _dense(pred).std(axis=0)))

    # Get common genes
    common_out_of_sample = np.intersect1d(
        ground_truth_out_of_sample.var.index,
        pred_out_of_sample.var.index
    )
    common_in_sample = np.intersect1d(
        ground_truth_in_sample[0].var.index,
        pred_in_sample[0].var.index
    )
    all_pred_genes = list(pred_in_sample[0].var.index)

    # Calculate out-of-sample metrics. Scores are keyed by the out-of-sample
    # prediction's genes; pairing them with the in-sample gene list would
    # mislabel them whenever the two gene sets differ.
    out_of_sample_corrs = _gene_correlations(
        ground_truth_out_of_sample, pred_out_of_sample, common_out_of_sample
    )
    out_of_sample_scores = _gene_scores(pred_out_of_sample)

    # Calculate in-sample metrics per fold
    fold_to_in_sample_corrs = {
        z: _gene_correlations(ground_truth_in_sample[z], pred_in_sample[z], common_in_sample)
        for z in folds
    }
    fold_to_in_sample_scores = {z: _gene_scores(pred_in_sample[z]) for z in folds}

    # Calculate fold proportions (share of ground-truth cells per fold)
    total_cells = float(sum(ground_truth_in_sample[z].shape[0] for z in folds))
    fold_to_prop = {z: ground_truth_in_sample[z].shape[0] / total_cells for z in folds}

    # Calculate average in-sample metrics, weighted by fold size
    avg_in_sample_corrs = {
        g: sum(fold_to_in_sample_corrs[z][g] * fold_to_prop[z] for z in folds)
        for g in common_in_sample
    }
    avg_in_sample_scores = {
        g: sum(fold_to_in_sample_scores[z][g] * fold_to_prop[z] for z in folds)
        for g in all_pred_genes
    }

    return (
        out_of_sample_corrs,
        out_of_sample_scores,
        avg_in_sample_corrs,
        avg_in_sample_scores,
        common_out_of_sample,
        common_in_sample
    )

def calculate_program_metrics(
    pred_in_sample, ground_truth_in_sample,
    pred_out_of_sample, ground_truth_out_of_sample,
    common_in_sample, common_out_of_sample,
    all_programs
):
    """
    Compute correlation and spread of gene-program activity.

    A program's activity in a cell is the mean expression of its genes. For
    every program with at least five usable genes, measured and predicted
    activity are correlated, and the standard deviation of the predicted
    activity is recorded as the score. In-sample values are averaged over
    folds 0-3, weighted by each fold's share of the ground-truth cells.

    Args:
    - pred_in_sample: Dict of predicted in-sample data
    - ground_truth_in_sample: Dict of ground truth in-sample data
    - pred_out_of_sample: Predicted out-of-sample data
    - ground_truth_out_of_sample: Ground truth out-of-sample data
    - common_in_sample: Common genes for in-sample data
    - common_out_of_sample: Common genes for out-of-sample data
    - all_programs: Dict of gene programs

    Returns:
    - Tuple: (in-sample correlations, in-sample scores,
      out-of-sample correlations, out-of-sample scores), each a dict keyed by
      program name
    """
    # Share of ground-truth cells contributed by each fold.
    fold_sizes = [ground_truth_in_sample[z].shape[0] for z in range(4)]
    total_cells = float(sum(fold_sizes))
    fold_to_prop = {z: size / total_cells for z, size in enumerate(fold_sizes)}

    def _usable_genes(program_genes, gene_pool):
        """Lower-cased program genes that are also present in `gene_pool`."""
        lowered = [g.lower() for g in program_genes]
        return np.intersect1d(lowered, gene_pool)

    def _corr_and_score(truth, pred, genes):
        """Correlation of program activity and std of the predicted activity."""
        true_program = truth[:, genes].X.mean(axis=1)
        pred_program = pred[:, genes].X.mean(axis=1)
        return np.corrcoef(true_program, pred_program)[0, 1], pred_program.std()

    # In-sample: evaluate each fold, then take the size-weighted average.
    gene_set_to_avg_in_sample_program_corrs = {}
    gene_set_to_avg_in_sample_program_scores = {}

    for program_name, program_genes in all_programs.items():
        genes = _usable_genes(program_genes, common_in_sample)
        if len(genes) < 5:
            continue

        per_fold = [
            _corr_and_score(ground_truth_in_sample[fold], pred_in_sample[fold], genes)
            for fold in range(4)
        ]

        gene_set_to_avg_in_sample_program_corrs[program_name] = sum(
            per_fold[z][0] * fold_to_prop[z]
            for z in range(4)
        )
        gene_set_to_avg_in_sample_program_scores[program_name] = sum(
            per_fold[z][1] * fold_to_prop[z]
            for z in range(4)
        )

    # Out-of-sample: a single evaluation per program.
    gene_set_to_out_of_sample_program_corrs = {}
    gene_set_to_out_of_sample_program_scores = {}

    for program_name, program_genes in all_programs.items():
        genes = _usable_genes(program_genes, common_out_of_sample)
        if len(genes) < 5:
            continue

        corr, score = _corr_and_score(ground_truth_out_of_sample, pred_out_of_sample, genes)
        gene_set_to_out_of_sample_program_corrs[program_name] = corr
        gene_set_to_out_of_sample_program_scores[program_name] = score

    return (
        gene_set_to_avg_in_sample_program_corrs,
        gene_set_to_avg_in_sample_program_scores,
        gene_set_to_out_of_sample_program_corrs,
        gene_set_to_out_of_sample_program_scores
    )


In [None]:
# Calculate correlations and scores
def calculate_correlations_and_scores(ground_truth_out_of_sample, pred_out_of_sample, ground_truth_in_sample, pred_in_sample):
    """Compute per-gene correlations and prediction-spread scores.

    Parameters
    ----------
    ground_truth_out_of_sample, pred_out_of_sample
        AnnData-like objects; the code uses ``.var.index``, ``.X`` and
        ``obj[rows, gene]`` indexing on them.
    ground_truth_in_sample, pred_in_sample
        Mappings from fold number (0-3) to AnnData-like objects.

    Returns
    -------
    tuple
        ``(out_of_sample_corrs, out_of_sample_scores, avg_in_sample_corrs,
        avg_in_sample_scores, common_out_of_sample, common_in_sample)``.
        Correlations are keyed by the genes shared between prediction and
        ground truth; scores (per-gene standard deviation of the prediction)
        are keyed by every predicted gene. In-sample values are averaged over
        the four folds, weighted by each fold's share of ground-truth rows.
    """
    n_folds = 4

    def _per_gene_corrs(truth, pred, genes):
        # Pearson correlation between truth and prediction, one gene at a time.
        corrs = {}
        for g in genes:
            t = np.array(truth[::, g].X.squeeze())
            p = np.array(pred[::, g].X.squeeze())
            corrs[g] = np.corrcoef(t, p)[0, 1]
        return corrs

    def _per_gene_scores(pred):
        # Label every column's std with the gene names of *that* object, so the
        # names can never drift out of step with the columns they describe.
        return dict(zip(pred.var.index, np.asarray(pred.X.std(axis=0)).ravel()))

    common_out_of_sample = np.intersect1d(ground_truth_out_of_sample.var.index, pred_out_of_sample.var.index)
    common_in_sample = np.intersect1d(ground_truth_in_sample[0].var.index, pred_in_sample[0].var.index)
    all_pred_genes = list(pred_in_sample[0].var.index)

    # Out-of-sample metrics
    out_of_sample_corrs = _per_gene_corrs(ground_truth_out_of_sample, pred_out_of_sample, common_out_of_sample)
    out_of_sample_scores = _per_gene_scores(pred_out_of_sample)

    # In-sample metrics, one dict per fold
    fold_to_in_sample_corrs = {
        z: _per_gene_corrs(ground_truth_in_sample[z], pred_in_sample[z], common_in_sample)
        for z in range(n_folds)
    }
    fold_to_in_sample_scores = {z: _per_gene_scores(pred_in_sample[z]) for z in range(n_folds)}

    # Each fold is weighted by its share of the ground-truth rows
    total_cells = float(sum(ground_truth_in_sample[z].shape[0] for z in range(n_folds)))
    fold_to_prop = {z: ground_truth_in_sample[z].shape[0] / total_cells for z in range(n_folds)}

    avg_in_sample_corrs = {
        g: sum(fold_to_in_sample_corrs[z][g] * fold_to_prop[z] for z in range(n_folds))
        for g in common_in_sample
    }
    # Shared genes first, then the remaining predicted genes (keeps the key
    # order callers previously observed) without computing any gene twice.
    score_genes = dict.fromkeys(list(common_in_sample) + all_pred_genes)
    avg_in_sample_scores = {
        g: sum(fold_to_in_sample_scores[z][g] * fold_to_prop[z] for z in range(n_folds))
        for g in score_genes
    }

    return (out_of_sample_corrs, out_of_sample_scores, avg_in_sample_corrs, avg_in_sample_scores,
            common_out_of_sample, common_in_sample)


In [None]:
# Load and process programs
def load_programs(
    hallmark_path='/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324/hallmark_programs.json',
    cancer_path='/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324/cancer_programs.json',
):
    """Load the hallmark and cancer gene programs as one dictionary.

    Parameters
    ----------
    hallmark_path, cancer_path : str
        JSON files mapping a program name to an object that has a
        ``'geneSymbols'`` list. The defaults are the locations this notebook
        has always read from, so ``load_programs()`` behaves as before.

    Returns
    -------
    dict
        ``{program_name: [lower-cased gene symbols]}``. Cancer programs are
        merged on top of the hallmark programs, so a name present in both
        files keeps the cancer entry.
    """
    def _read_gene_sets(path):
        # One file -> {program name: lower-cased gene symbols}.
        with open(path) as f:
            raw_programs = json.load(f)
        return {
            name: [gene.lower() for gene in entry['geneSymbols']]
            for name, entry in raw_programs.items()
        }

    all_programs = _read_gene_sets(hallmark_path)
    all_programs.update(_read_gene_sets(cancer_path))
    return all_programs

# Main execution
def main():
    """Run the figure pipeline from data loading to the final schematics.

    Creates the ``final_figures_schaf_revision_pngs`` directory, then calls
    each load / compute / plot helper in sequence, printing a progress line
    before every stage. Takes no arguments and returns nothing.

    NOTE(review): several names passed to helpers below are never assigned in
    this function (they are flagged inline). They can only resolve as
    module-level globals; if no earlier cell defines them the call raises
    NameError. Confirm where they are meant to come from.
    """
    # Create output directory
    os.makedirs('final_figures_schaf_revision_pngs', exist_ok=True)

    # Load data
    print("Loading Visium data...")
    # NOTE(review): visium_adata is not used again in this function.
    visium_adata = load_visium_data()

    print("Loading ground truth and prediction data...")
    ground_truth_out_of_sample, pred_out_of_sample, ground_truth_in_sample, pred_in_sample = load_ground_truth_and_predictions()

    # Calculate correlations and scores
    print("Calculating correlations and scores...")
    results = calculate_correlations_and_scores(
        ground_truth_out_of_sample, pred_out_of_sample, 
        ground_truth_in_sample, pred_in_sample
    )
    out_of_sample_corrs, out_of_sample_scores, avg_in_sample_corrs, avg_in_sample_scores, common_out_of_sample, common_in_sample = results

    # Load programs
    print("Loading and processing programs...")
    all_programs = load_programs()

    # Calculate program metrics
    print("Calculating program metrics...")
    program_results = calculate_program_metrics(
        pred_in_sample, ground_truth_in_sample,
        pred_out_of_sample, ground_truth_out_of_sample,
        common_in_sample, common_out_of_sample,
        all_programs
    )
    gene_set_to_avg_in_sample_program_corrs, gene_set_to_avg_in_sample_program_scores, gene_set_to_out_of_sample_program_corrs, gene_set_to_out_of_sample_program_scores = program_results

    # Generate correlation heatmaps
    print("Generating correlation heatmaps...")
    generate_heatmaps(
        pred_in_sample, ground_truth_in_sample,
        pred_out_of_sample, ground_truth_out_of_sample,
        common_in_sample, common_out_of_sample
    )

    # Load and process mouse data
    print("Loading mouse data...")
    pred_mouse, ground_truth_mouse, pred_mouse_scores = load_mouse_data()


    # Calculate mouse metrics
    print("Calculating mouse metrics...")
    mouse_avg_in_sample_corrs, mouse_avg_in_sample_scores, common_mouse = calculate_mouse_metrics(
        pred_mouse, ground_truth_mouse, pred_mouse_scores
    )

    # Generate mouse heatmaps
    print("Generating mouse heatmaps...")
    generate_mouse_heatmaps(pred_mouse, ground_truth_mouse, common_mouse)

    # Load cell type labels
    print("Loading cell type labels...")
    broad_clusters, xenium_in_sample_fold_to_inferred_labels, mouse_fold_to_inferred_labels = load_cell_type_labels()

    # Calculate cell type-specific expression
    print("Calculating cell type-specific expression...")
    cell_type_results = calculate_cell_type_expression(
        pred_in_sample, ground_truth_in_sample,
        pred_out_of_sample, ground_truth_out_of_sample,
        pred_mouse, ground_truth_mouse,
        common_in_sample, common_mouse, common_out_of_sample,
        broad_clusters, xenium_in_sample_fold_to_inferred_labels, mouse_fold_to_inferred_labels
    )
    fold_to_ct_to_mouse_pred_mean, fold_to_ct_to_mouse_tru_mean, fold_to_ct_to_in_sample_pred_mean, fold_to_ct_to_in_sample_tru_mean, ct_to_out_of_sample_pred_mean, ct_to_out_of_sample_tru_mean = cell_type_results

    # Generate cell type expression heatmaps
    print("Generating cell type expression heatmaps...")
    generate_cell_type_heatmaps(
        fold_to_ct_to_mouse_pred_mean, fold_to_ct_to_mouse_tru_mean,
        fold_to_ct_to_in_sample_pred_mean, fold_to_ct_to_in_sample_tru_mean,
        ct_to_out_of_sample_pred_mean, ct_to_out_of_sample_tru_mean
    )

    # Calculate cell type-specific correlations
    print("Calculating cell type-specific correlations...")
    # NOTE(review): ground_truth_out_of_sample_for_cts and
    # pred_out_of_sample_for_cts are not assigned anywhere in main(); they must
    # exist as globals for this call to work -- confirm.
    dataset_to_ct_corrs_info = calculate_cell_type_correlations(
        pred_in_sample, ground_truth_in_sample,
        pred_out_of_sample, ground_truth_out_of_sample,
        pred_mouse, ground_truth_mouse,
        common_in_sample, common_mouse, common_out_of_sample,
        ground_truth_out_of_sample_for_cts, pred_out_of_sample_for_cts
    )

    # Generate cell type correlation plots
    print("Generating cell type correlation plots...")
    generate_cell_type_correlation_plots(dataset_to_ct_corrs_info)

    # Generate cell type heterogeneity plots
    print("Generating cell type heterogeneity plots...")
    generate_cell_type_heterogeneity_plots(dataset_to_ct_corrs_info)

    # Calculate cell type metacorrelations
    print("Calculating cell type metacorrelations...")
    # NOTE(review): none of the label_to_* arguments below are assigned in
    # main(); they must exist as globals -- confirm.
    label_to_in_sample_plot_metacorr, label_to_mouse_plot_metacorr, label_to_out_of_sample_plot_metacorr = calculate_cell_type_metacorrelations(
        label_to_ground_truth_in_sample, label_to_pred_in_sample,
        label_to_ground_truth_mouse, label_to_pred_mouse,
        label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample,
        common_in_sample, common_mouse, common_out_of_sample
    )

    # Generate metacorrelation plots
    print("Generating metacorrelation plots...")
    generate_metacorrelation_plots(
        label_to_in_sample_plot_metacorr,
        label_to_mouse_plot_metacorr,
        label_to_out_of_sample_plot_metacorr
    )

    # Calculate cell type heterogeneity
    print("Calculating cell type heterogeneity...")
    # NOTE(review): the three results unpacked here are not used again in this
    # function, and the label_to_* inputs are again not assigned in main().
    label_to_avg_in_sample_heteros, label_to_avg_mouse_heteros, label_to_heteros_out_of_sample = calculate_cell_type_heterogeneity(
        label_to_ground_truth_in_sample, label_to_pred_in_sample,
        label_to_ground_truth_mouse, label_to_pred_mouse,
        label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample,
        common_in_sample, common_mouse, common_out_of_sample
    )

    # Generate histology split figures
    print("Generating histology split figures...")
    # NOTE(review): in_sample_fold_to_hist, mouse_fold_to_hist, in_sample_hist
    # and out_of_sample are not assigned in main() -- confirm they are globals.
    generate_histology_split_figures(
        in_sample_fold_to_hist,
        mouse_fold_to_hist,
        in_sample_hist,
        out_of_sample
    )

    # Generate correlation histograms
    print("Generating correlation histograms...")
    # NOTE(review): mouse_train_genes is not assigned in main() -- confirm it is a global.
    generate_correlation_histograms(
        mouse_avg_in_sample_corrs,
        mouse_train_genes
    )

    # Generate HTAPP and MSKCC schematics
    print("Generating HTAPP and MSKCC schematics...")
    generate_htapp_schematic(the_hists)  # Note: the_hists needs to be defined
    generate_mskcc_schematic()

    print("All figures have been generated successfully!")
    
# NOTE(review): inside a notebook kernel __name__ is "__main__", so executing
# this cell calls main() immediately. In this part of the notebook several
# helpers main() uses (e.g. generate_heatmaps, load_mouse_data,
# calculate_mouse_metrics, load_cell_type_labels) are defined in cells *below*
# this one, so on a fresh kernel this call presumably fails with NameError
# unless they are also defined earlier -- verify the intended run order.
if __name__ == "__main__":
    main()


In [None]:
# Generate correlation heatmaps
def generate_heatmaps(pred_in_sample, ground_truth_in_sample, pred_out_of_sample, ground_truth_out_of_sample,
                     common_in_sample, common_out_of_sample):
    """Plot gene-gene correlation heatmaps (true, predicted, |difference|).

    Two PNG files are written under ``final_figures_schaf_revision_pngs/``:
    ``in_sample_heatmaps.png`` (fold-weighted average over the four folds) and
    ``out_of_sample_heatmaps.png``. Rows and columns are grouped by a
    two-cluster Ward clustering of the ground-truth gene profiles; for the
    in-sample figure the grouping is computed on fold 0 and reused for every
    fold. The "metacorr" shown in the titles is the Pearson correlation
    between the flattened predicted and true heatmaps.

    Parameters
    ----------
    pred_in_sample, ground_truth_in_sample
        Mappings from fold number (0-3) to AnnData-like objects.
    pred_out_of_sample, ground_truth_out_of_sample
        AnnData-like objects.
    common_in_sample, common_out_of_sample
        Gene names used to subset the in-sample / out-of-sample objects.

    Returns
    -------
    None
    """
    n_folds = 4

    def _two_cluster_order(true_arr):
        # Column order that puts the genes labelled 1 first, then those labelled 0.
        # The old `affinity='euclidean'` keyword is intentionally not passed: it
        # was deprecated in scikit-learn 1.2 and removed in 1.4, and Euclidean
        # distance (the only option Ward linkage accepts) is the default.
        clustering = AgglomerativeClustering(n_clusters=2, linkage='ward')
        labels = clustering.fit_predict(true_arr.T)
        first = [i for i, l in enumerate(labels) if l]
        second = [i for i, l in enumerate(labels) if not l]
        return first + second

    def _ordered_corr(arr, order):
        # Gene-gene correlation matrix with rows and columns permuted by `order`.
        heatmap = np.corrcoef(arr, rowvar=0)
        return heatmap[order][::, order]

    def _metacorr(pred_heatmap, true_heatmap):
        return np.corrcoef(pred_heatmap.reshape(-1), true_heatmap.reshape(-1))[0, 1]

    def _save_panels(true_heatmap, pred_heatmap, true_title, pred_title, out_path):
        # Three side-by-side panels: truth, prediction, absolute difference.
        plt.figure(figsize=(12, 4))

        plt.subplot(131)
        sns.heatmap(true_heatmap, cmap='coolwarm', center=0)
        plt.title(true_title)

        plt.subplot(132)
        sns.heatmap(pred_heatmap, cmap='coolwarm', center=0)
        plt.title(pred_title)

        plt.subplot(133)
        sns.heatmap(np.abs(true_heatmap - pred_heatmap), cmap='coolwarm')
        plt.title('Absolute Difference')

        plt.tight_layout()
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()

    # In-sample heatmaps, one pair per fold
    fold_to_pred_in_sample_heatmap = {}
    fold_to_true_in_sample_heatmap = {}
    fold_to_metacorr = {}
    gene_order = None
    for z in range(n_folds):
        pred_arr = np.array(pred_in_sample[z][::, common_in_sample].X)
        true_arr = np.array(ground_truth_in_sample[z][::, common_in_sample].X.squeeze())

        if gene_order is None:
            # Computed once (on fold 0) so every fold shares the same layout
            gene_order = _two_cluster_order(true_arr)

        pred_heatmap = _ordered_corr(pred_arr, gene_order)
        true_heatmap = _ordered_corr(true_arr, gene_order)

        fold_to_pred_in_sample_heatmap[z] = pred_heatmap
        fold_to_true_in_sample_heatmap[z] = true_heatmap
        fold_to_metacorr[z] = _metacorr(pred_heatmap, true_heatmap)

    # Each fold is weighted by its share of the ground-truth rows
    total_cells = float(sum(ground_truth_in_sample[z].shape[0] for z in range(n_folds)))
    fold_to_prop = {z: ground_truth_in_sample[z].shape[0] / total_cells for z in range(n_folds)}

    avg_pred_in_sample_heatmap = sum([fold_to_pred_in_sample_heatmap[z] * fold_to_prop[z] for z in range(n_folds)])
    avg_true_in_sample_heatmap = sum([fold_to_true_in_sample_heatmap[z] * fold_to_prop[z] for z in range(n_folds)])
    avg_in_sample_metacorr = sum(fold_to_metacorr[z] * fold_to_prop[z] for z in range(n_folds))

    # Out-of-sample heatmaps (clustered on their own ground truth)
    pred_arr = np.array(pred_out_of_sample[::, common_out_of_sample].X)
    true_arr = np.array(ground_truth_out_of_sample[::, common_out_of_sample].X)

    out_order = _two_cluster_order(true_arr)
    pred_heatmap_out_of_sample = _ordered_corr(pred_arr, out_order)
    true_heatmap_out_of_sample = _ordered_corr(true_arr, out_order)
    out_of_sample_metacorr = _metacorr(pred_heatmap_out_of_sample, true_heatmap_out_of_sample)

    # Save heatmap figures
    _save_panels(
        avg_true_in_sample_heatmap,
        avg_pred_in_sample_heatmap,
        'True In-Sample Correlation',
        f'Predicted In-Sample Correlation\nMetacorr: {avg_in_sample_metacorr:.3f}',
        'final_figures_schaf_revision_pngs/in_sample_heatmaps.png',
    )
    _save_panels(
        true_heatmap_out_of_sample,
        pred_heatmap_out_of_sample,
        'True Out-of-Sample Correlation',
        f'Predicted Out-of-Sample Correlation\nMetacorr: {out_of_sample_metacorr:.3f}',
        'final_figures_schaf_revision_pngs/out_of_sample_heatmaps.png',
    )


In [None]:
# Load and process mouse data
def load_mouse_data(
    pred_dir='/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/mouse_inferences',
    ground_truth_dir='/mounts/stultzlab03/ccomiter/schaf_for_revision052424/data/xenium_cancer/mouse_folds',
    scores_dir='mouse_inferences',
):
    """Load the four mouse folds: predictions, ground truth and score arrays.

    Parameters
    ----------
    pred_dir : str
        Directory containing ``fold_{z}.h5ad`` prediction files.
    ground_truth_dir : str
        Directory containing ``fold_{z}_st.h5ad`` ground-truth files.
    scores_dir : str
        Directory containing ``fold_{z}_scores.npy`` files.

    The defaults are the locations this notebook has always read from, so
    ``load_mouse_data()`` behaves as before.

    Returns
    -------
    tuple
        ``(pred_mouse, ground_truth_mouse, pred_mouse_scores)``, each a dict
        keyed by fold number 0-3. Ground truth is ``log1p``-transformed in
        place and its ``.X`` is replaced by a dense ``numpy`` array.
    """
    folds = range(4)

    # Load predictions
    pred_mouse = {z: sc.read_h5ad(f'{pred_dir}/fold_{z}.h5ad') for z in folds}

    # Load ground truth
    ground_truth_mouse = {z: sc.read_h5ad(f'{ground_truth_dir}/fold_{z}_st.h5ad') for z in folds}

    # Process ground truth data
    for adata in ground_truth_mouse.values():
        sc.pp.log1p(adata)
        matrix = adata.X
        # Sparse matrices need densifying; an already-dense matrix has no
        # `.todense()` and is just converted as-is instead of raising.
        adata.X = np.array(matrix.todense()) if hasattr(matrix, 'todense') else np.array(matrix)

    # Load scores
    pred_mouse_scores = {z: np.load(f'{scores_dir}/fold_{z}_scores.npy') for z in folds}

    return pred_mouse, ground_truth_mouse, pred_mouse_scores



# Calculate mouse metrics
def calculate_mouse_metrics(pred_mouse, ground_truth_mouse, pred_mouse_scores):
    """Fold-weighted per-gene correlations and scores for the mouse folds.

    Parameters
    ----------
    pred_mouse, ground_truth_mouse
        Mappings from fold number (0-3) to AnnData-like objects; the code uses
        ``.var.index``, ``.X``, ``.shape`` and ``obj[rows, gene]`` indexing.
    pred_mouse_scores
        Mapping from fold number to a sequence of per-gene scores, paired
        positionally with the genes of ``pred_mouse[0]``.

    Returns
    -------
    tuple
        ``(mouse_avg_in_sample_corrs, mouse_avg_in_sample_scores, common_mouse)``.
        Correlations are keyed by the genes shared between prediction and
        ground truth of fold 0; scores are keyed by every predicted gene. Both
        are averaged over folds, weighted by each fold's share of
        ground-truth rows.
    """
    n_folds = 4
    common_mouse = np.intersect1d(ground_truth_mouse[0].var.index, pred_mouse[0].var.index)
    all_pred_genes_mouse = list(pred_mouse[0].var.index)

    # Per-fold correlations (computed here) and scores (taken from the loaded arrays)
    mouse_fold_to_in_sample_corrs = {}
    mouse_fold_to_in_sample_scores = {}
    for z in range(n_folds):
        fold_corrs = {}
        for g in common_mouse:
            t = np.array(ground_truth_mouse[z][::, g].X.squeeze())
            p = np.array(pred_mouse[z][::, g].X.squeeze())
            fold_corrs[g] = np.corrcoef(t, p)[0, 1]
        mouse_fold_to_in_sample_corrs[z] = fold_corrs
        mouse_fold_to_in_sample_scores[z] = dict(zip(all_pred_genes_mouse, pred_mouse_scores[z]))

    # Each fold is weighted by its share of the ground-truth rows
    total_cells = float(sum(ground_truth_mouse[z].shape[0] for z in range(n_folds)))
    fold_to_prop = {z: ground_truth_mouse[z].shape[0] / total_cells for z in range(n_folds)}

    mouse_avg_in_sample_corrs = {
        g: sum(mouse_fold_to_in_sample_corrs[z][g] * fold_to_prop[z] for z in range(n_folds))
        for g in common_mouse
    }
    # Shared genes first, then the remaining predicted genes (keeps the key
    # order callers previously observed) without computing any gene twice.
    score_genes = dict.fromkeys(list(common_mouse) + all_pred_genes_mouse)
    mouse_avg_in_sample_scores = {
        g: sum(mouse_fold_to_in_sample_scores[z][g] * fold_to_prop[z] for z in range(n_folds))
        for g in score_genes
    }

    return mouse_avg_in_sample_corrs, mouse_avg_in_sample_scores, common_mouse

# Generate mouse heatmaps
def generate_mouse_heatmaps(pred_mouse, ground_truth_mouse, common_mouse):
    """Plot fold-averaged gene-gene correlation heatmaps for the mouse data.

    Writes ``final_figures_schaf_revision_pngs/mouse_heatmaps.png`` with three
    panels: true correlation, predicted correlation and their absolute
    difference. Rows and columns are grouped by a two-cluster Ward clustering
    of the fold-0 ground-truth gene profiles, reused for every fold. The
    "metacorr" in the title is the fold-weighted Pearson correlation between
    the flattened predicted and true heatmaps.

    Parameters
    ----------
    pred_mouse, ground_truth_mouse
        Mappings from fold number (0-3) to AnnData-like objects.
    common_mouse
        Gene names used to subset both objects.

    Returns
    -------
    None
    """
    n_folds = 4

    def _two_cluster_order(true_arr):
        # Column order that puts the genes labelled 1 first, then those labelled 0.
        # The old `affinity='euclidean'` keyword is intentionally not passed: it
        # was deprecated in scikit-learn 1.2 and removed in 1.4, and Euclidean
        # distance (the only option Ward linkage accepts) is the default.
        clustering = AgglomerativeClustering(n_clusters=2, linkage='ward')
        labels = clustering.fit_predict(true_arr.T)
        first = [i for i, l in enumerate(labels) if l]
        second = [i for i, l in enumerate(labels) if not l]
        return first + second

    def _ordered_corr(arr, order):
        # Gene-gene correlation matrix with rows and columns permuted by `order`.
        heatmap = np.corrcoef(arr, rowvar=0)
        return heatmap[order][::, order]

    fold_to_pred_mouse_heatmap = {}
    fold_to_true_mouse_heatmap = {}
    fold_to_metacorr = {}
    gene_order = None
    for z in range(n_folds):
        pred_arr = np.array(pred_mouse[z][::, common_mouse].X)
        true_arr = np.array(ground_truth_mouse[z][::, common_mouse].X.squeeze())

        if gene_order is None:
            # Computed once (on fold 0) so every fold shares the same layout
            gene_order = _two_cluster_order(true_arr)

        pred_heatmap = _ordered_corr(pred_arr, gene_order)
        true_heatmap = _ordered_corr(true_arr, gene_order)

        fold_to_pred_mouse_heatmap[z] = pred_heatmap
        fold_to_true_mouse_heatmap[z] = true_heatmap
        fold_to_metacorr[z] = np.corrcoef(pred_heatmap.reshape(-1), true_heatmap.reshape(-1))[0, 1]

    # Each fold is weighted by its share of the ground-truth rows
    total_cells = float(sum(ground_truth_mouse[z].shape[0] for z in range(n_folds)))
    fold_to_prop = {z: ground_truth_mouse[z].shape[0] / total_cells for z in range(n_folds)}

    # Calculate average heatmaps
    avg_pred_mouse_heatmap = sum([fold_to_pred_mouse_heatmap[z] * fold_to_prop[z] for z in range(n_folds)])
    avg_true_mouse_heatmap = sum([fold_to_true_mouse_heatmap[z] * fold_to_prop[z] for z in range(n_folds)])
    avg_mouse_metacorr = sum(fold_to_metacorr[z] * fold_to_prop[z] for z in range(n_folds))

    # Generate heatmap figure
    plt.figure(figsize=(12, 4))

    plt.subplot(131)
    sns.heatmap(avg_true_mouse_heatmap, cmap='coolwarm', center=0)
    plt.title('True Mouse Correlation')

    plt.subplot(132)
    sns.heatmap(avg_pred_mouse_heatmap, cmap='coolwarm', center=0)
    plt.title(f'Predicted Mouse Correlation\nMetacorr: {avg_mouse_metacorr:.3f}')

    plt.subplot(133)
    sns.heatmap(np.abs(avg_true_mouse_heatmap - avg_pred_mouse_heatmap), cmap='coolwarm')
    plt.title('Absolute Difference')

    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/mouse_heatmaps.png', dpi=300, bbox_inches='tight')
    plt.close()


In [None]:
# Load and process cell type labels
def load_cell_type_labels():
    """Read the cluster table and the per-fold inferred label arrays.

    Returns
    -------
    tuple
        ``(broad_clusters, xenium_in_sample_fold_to_inferred_labels,
        mouse_fold_to_inferred_labels)``: the clusters CSV as a DataFrame
        indexed by its ``Barcode`` column, followed by two dicts that map
        fold number 0-3 to the array stored in the matching ``.npy`` file.
    """
    # Cluster table, re-indexed by barcode in one step
    clusters_table = pd.read_csv(
        '/mounts/stultzlab03/ccomiter/htapp_supervise/new_schaf_experiment_scripts/more_data/xenium/analysis/clustering/gene_expression_kmeans_10_clusters/clusters.csv'
    ).set_index('Barcode')

    # Inferred labels: cancer and mouse arrays are read fold by fold
    cancer_labels_by_fold, mouse_labels_by_fold = {}, {}
    f = 0
    while f < 4:
        cancer_labels_by_fold[f] = np.load(f'cancer_fold_to_new_labels/{f}.npy')
        mouse_labels_by_fold[f] = np.load(f'mouse_fold_to_new_labels/{f}.npy')
        f += 1

    return clusters_table, cancer_labels_by_fold, mouse_labels_by_fold

# Calculate cell type-specific expression
def calculate_cell_type_expression(pred_in_sample, ground_truth_in_sample, pred_out_of_sample, ground_truth_out_of_sample,
                                 pred_mouse, ground_truth_mouse, common_in_sample, common_mouse, common_out_of_sample,
                                 broad_clusters, xenium_in_sample_fold_to_inferred_labels, mouse_fold_to_inferred_labels):
    """Mean expression per cell type (labels 1-10) for each dataset.

    Predictions are grouped by *inferred* labels, ground truth by the
    ``broad_clusters`` labels.

    Parameters
    ----------
    pred_in_sample, ground_truth_in_sample, pred_mouse, ground_truth_mouse
        Mappings from fold number (0-3) to AnnData-like objects. The ground
        truth objects must carry an ``obs['broad_clusters']`` column.
    pred_out_of_sample, ground_truth_out_of_sample
        AnnData-like objects, indexable by the barcodes in ``broad_clusters``.
    common_in_sample, common_mouse, common_out_of_sample
        Gene names used to subset the respective objects.
    broad_clusters
        DataFrame indexed by barcode with a ``'Cluster'`` column.
    xenium_in_sample_fold_to_inferred_labels, mouse_fold_to_inferred_labels
        Mappings from fold number to inferred label arrays. These are now
        actually used; passing ``None`` falls back to reading the
        ``cancer_fold_to_new_labels/{fold}.npy`` /
        ``mouse_fold_to_new_labels/{fold}.npy`` files, which is what the
        function previously always did.

    Returns
    -------
    tuple
        ``(fold_to_ct_to_mouse_pred_mean, fold_to_ct_to_mouse_tru_mean,
        fold_to_ct_to_in_sample_pred_mean, fold_to_ct_to_in_sample_tru_mean,
        ct_to_out_of_sample_pred_mean, ct_to_out_of_sample_tru_mean)``.
        The fold-level entries are ``{fold: {cell_type: per-gene mean}}``;
        the out-of-sample entries are ``{cell_type: per-gene mean}``.

    Notes
    -----
    Out-of-sample inferred labels are always read from
    ``cancer_fold_to_new_labels/out_of_sample.npy``. The barcode-restricted
    out-of-sample subsets built here are local and are not returned.
    """
    n_folds = 4
    cts = list(range(1, 11, 1))

    def _inferred_labels(fold_to_labels, fold, fallback_path):
        # Prefer the labels handed in by the caller; read from disk only if absent.
        if fold_to_labels is None:
            return np.load(fallback_path).astype(int)
        return np.asarray(fold_to_labels[fold]).astype(int)

    # Out-of-sample: keep only barcodes that have a cluster assignment. The
    # true labels are taken straight from the cluster table rather than being
    # written into `.obs` of a view and read back (same values, no implicit
    # copy of the subset).
    ground_truth_out_of_sample_for_cts = ground_truth_out_of_sample[broad_clusters.index]
    pred_out_of_sample_for_cts = pred_out_of_sample[broad_clusters.index]
    ostl = np.array(broad_clusters['Cluster']).astype(int)
    ospl = np.load('cancer_fold_to_new_labels/out_of_sample.npy').astype(int)

    # In-sample labels: inferred (for predictions) and true (for ground truth)
    isfpl = {
        i: _inferred_labels(xenium_in_sample_fold_to_inferred_labels, i, f'cancer_fold_to_new_labels/{i}.npy')
        for i in range(n_folds)
    }
    isftl = {
        i: np.array(ground_truth_in_sample[i].obs['broad_clusters']).astype(int)
        for i in range(n_folds)
    }

    # Mouse labels
    mfpl = {
        i: _inferred_labels(mouse_fold_to_inferred_labels, i, f'mouse_fold_to_new_labels/{i}.npy')
        for i in range(n_folds)
    }
    mftl = {
        i: np.array(ground_truth_mouse[i].obs['broad_clusters']).astype(int)
        for i in range(n_folds)
    }

    # Pooled per-gene mean and std of the in-sample ground truth across folds.
    # The variance combines each fold's own variance with the squared offset of
    # the fold mean from the pooled mean, weighted by fold size.
    num_genes = len(common_in_sample)
    all_means = np.zeros(num_genes)
    total_samples = 0
    for k, v in ground_truth_in_sample.items():
        all_means = all_means + v.shape[0] * (ground_truth_in_sample[k][::, common_in_sample].X.mean(axis=0))
        total_samples += v.shape[0]
    all_means = all_means / total_samples

    all_vars = np.zeros(num_genes)
    for k, v in ground_truth_in_sample.items():
        fold_matrix = ground_truth_in_sample[k][::, common_in_sample].X
        all_vars += v.shape[0] * (fold_matrix.var(axis=0) + (fold_matrix.mean(axis=0) - all_means) ** 2)
    all_vars = all_vars / total_samples
    all_stds = (all_vars) ** .5

    # Mouse cell type expression
    fold_to_ct_to_mouse_pred_mean = {}
    fold_to_ct_to_mouse_tru_mean = {}
    for z in range(n_folds):
        fold_to_ct_to_mouse_pred_mean[z] = {
            ct: pred_mouse[z][mfpl[z] == ct, common_mouse].X.mean(axis=0) for ct in cts
        }
        fold_to_ct_to_mouse_tru_mean[z] = {
            ct: ground_truth_mouse[z][mftl[z] == ct, common_mouse].X.mean(axis=0) for ct in cts
        }

    # In-sample cell type expression. Values are mapped through
    # `all_means + all_stds * x` before averaging.
    # NOTE(review): this looks like undoing a per-gene standardisation --
    # confirm that the in-sample matrices really are standardised.
    fold_to_ct_to_in_sample_pred_mean = {}
    fold_to_ct_to_in_sample_tru_mean = {}
    for z in range(n_folds):
        fold_to_ct_to_in_sample_pred_mean[z] = {
            ct: (all_means + (all_stds * pred_in_sample[z][isfpl[z] == ct, common_in_sample].X)).mean(axis=0)
            for ct in cts
        }
        fold_to_ct_to_in_sample_tru_mean[z] = {
            ct: (all_means + (all_stds * ground_truth_in_sample[z][isftl[z] == ct, common_in_sample].X)).mean(axis=0)
            for ct in cts
        }

    # Out-of-sample cell type expression
    ct_to_out_of_sample_pred_mean = {
        ct: pred_out_of_sample_for_cts[ospl == ct, common_out_of_sample].X.mean(axis=0) for ct in cts
    }
    ct_to_out_of_sample_tru_mean = {
        ct: ground_truth_out_of_sample_for_cts[ostl == ct, common_out_of_sample].X.mean(axis=0) for ct in cts
    }

    return (fold_to_ct_to_mouse_pred_mean, fold_to_ct_to_mouse_tru_mean,
            fold_to_ct_to_in_sample_pred_mean, fold_to_ct_to_in_sample_tru_mean,
            ct_to_out_of_sample_pred_mean, ct_to_out_of_sample_tru_mean)

# Generate cell type expression heatmaps
def generate_cell_type_heatmaps(fold_to_ct_to_mouse_pred_mean, fold_to_ct_to_mouse_tru_mean,
                               fold_to_ct_to_in_sample_pred_mean, fold_to_ct_to_in_sample_tru_mean,
                               ct_to_out_of_sample_pred_mean, ct_to_out_of_sample_tru_mean):
    """Save cell-type x gene mean-expression heatmaps for the three datasets.

    One PNG per dataset (mouse, in-sample, out-of-sample) is written under
    ``final_figures_schaf_revision_pngs/``; each has three panels: true means,
    predicted means and their absolute difference. Rows are cell types 1-10.
    For the mouse and in-sample datasets only fold 0 is drawn.
    """
    cell_types = list(range(1, 11, 1))

    def _render(true_lookup, pred_lookup, true_title, pred_title, out_path):
        # `true_lookup` / `pred_lookup` map a cell type to its per-gene mean;
        # they are called lazily so lookups happen at the same point as before.
        plt.figure(figsize=(15, 5))

        plt.subplot(131)
        true_matrix = np.array([true_lookup(ct) for ct in cell_types])
        sns.heatmap(true_matrix, cmap='coolwarm', center=0)
        plt.title(true_title)

        plt.subplot(132)
        pred_matrix = np.array([pred_lookup(ct) for ct in cell_types])
        sns.heatmap(pred_matrix, cmap='coolwarm', center=0)
        plt.title(pred_title)

        plt.subplot(133)
        sns.heatmap(np.abs(true_matrix - pred_matrix), cmap='coolwarm')
        plt.title('Absolute Difference')

        plt.tight_layout()
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()

    # Mouse (fold 0)
    _render(
        lambda ct: fold_to_ct_to_mouse_tru_mean[0][ct],
        lambda ct: fold_to_ct_to_mouse_pred_mean[0][ct],
        'True Mouse Cell Type Expression',
        'Predicted Mouse Cell Type Expression',
        'final_figures_schaf_revision_pngs/mouse_cell_type_heatmaps.png',
    )

    # In-sample (fold 0)
    _render(
        lambda ct: fold_to_ct_to_in_sample_tru_mean[0][ct],
        lambda ct: fold_to_ct_to_in_sample_pred_mean[0][ct],
        'True In-Sample Cell Type Expression',
        'Predicted In-Sample Cell Type Expression',
        'final_figures_schaf_revision_pngs/in_sample_cell_type_heatmaps.png',
    )

    # Out-of-sample
    _render(
        lambda ct: ct_to_out_of_sample_tru_mean[ct],
        lambda ct: ct_to_out_of_sample_pred_mean[ct],
        'True Out-of-Sample Cell Type Expression',
        'Predicted Out-of-Sample Cell Type Expression',
        'final_figures_schaf_revision_pngs/out_of_sample_cell_type_heatmaps.png',
    )


In [None]:
# Calculate cell type-specific correlations
def calculate_cell_type_correlations(pred_in_sample, ground_truth_in_sample, pred_out_of_sample, ground_truth_out_of_sample,
                                   pred_mouse, ground_truth_mouse, common_in_sample, common_mouse, common_out_of_sample,
                                   ground_truth_out_of_sample_for_cts, pred_out_of_sample_for_cts):
    """Per-gene prediction/ground-truth correlations within each cell type.

    Cells are grouped by the *ground-truth* ``obs['broad_clusters']`` label;
    for every label and gene the Pearson correlation between truth and
    prediction is computed (NaN is replaced by 0 via ``np.nan_to_num``). For
    the in-sample and mouse data the per-fold values are averaged over the
    four folds, weighted by the number of cells of that label in each fold.

    Side effects: ``obs['broad_clusters']`` is rewritten on every fold of
    ``ground_truth_in_sample`` and ``pred_in_sample`` (label 10 -> 9) and on
    ``ground_truth_out_of_sample_for_cts`` (label 8 -> 7, label 5 -> 1).

    Returns
    -------
    dict
        ``{'mouse' | 'in_sample' | 'out_of_sample': {cell_type: {gene: correlation}}}``.

    NOTE(review): the ``pred_out_of_sample`` and ``ground_truth_out_of_sample``
    parameters are not read in this body; only the ``*_for_cts`` objects are.
    The ``label_to_*`` dictionaries built here are local and are not returned.
    """
    # Helper functions for label remapping
    def do_in_sample_change(r):
        # Merge label 10 into label 9; every other label is kept as-is.
        return np.array([9 if int(x) == 10 else x for x in r])

    def do_new_sample_change(r):
        # Merge label 8 into 7 and label 5 into 1; every other label is kept as-is.
        return np.array([7 if int(x) == 8 else 1 if int(x) == 5 else x for x in r])

    # Update labels
    for fold in range(4):
        ground_truth_in_sample[fold].obs['broad_clusters'] = do_in_sample_change(ground_truth_in_sample[fold].obs['broad_clusters'])
        pred_in_sample[fold].obs['broad_clusters'] = do_in_sample_change(pred_in_sample[fold].obs['broad_clusters'])

    ground_truth_out_of_sample_for_cts.obs['broad_clusters'] = do_new_sample_change(ground_truth_out_of_sample_for_cts.obs['broad_clusters'])

    # Get unique labels
    labels = []
    for fold in range(4):
        labels = np.union1d(labels, sorted(set(ground_truth_in_sample[fold].obs['broad_clusters'])))
    labels = labels.astype(int)

    # Process in-sample data by cell type
    label_to_ground_truth_in_sample = {}
    label_to_pred_in_sample = {}
    # NOTE(review): label_to_pred_predicted_in_sample is filled below but not
    # read again in this function.
    label_to_pred_predicted_in_sample = {}

    for label in labels:
        # Ground truth
        this_dict = {}
        for fold in range(4):
            this_dict[fold] = ground_truth_in_sample[fold][ground_truth_in_sample[fold].obs['broad_clusters']==label]
        label_to_ground_truth_in_sample[label] = this_dict

        # Predictions based on true labels
        this_dict = {}
        for fold in range(4):
            this_dict[fold] = pred_in_sample[fold][ground_truth_in_sample[fold].obs['broad_clusters']==label]
        label_to_pred_in_sample[label] = this_dict

        # Predictions based on predicted labels
        this_dict = {}
        for fold in range(4):
            this_dict[fold] = pred_in_sample[fold][pred_in_sample[fold].obs['broad_clusters']==label]
        label_to_pred_predicted_in_sample[label] = this_dict

    # Process mouse data by cell type
    mouse_labels = []
    for fold in range(4):
        mouse_labels = np.union1d(mouse_labels, sorted(set(ground_truth_mouse[fold].obs['broad_clusters'])))
    mouse_labels = mouse_labels.astype(int)

    label_to_ground_truth_mouse = {}
    label_to_pred_mouse = {}
    # NOTE(review): label_to_pred_predicted_mouse is filled below but not read
    # again in this function.
    label_to_pred_predicted_mouse = {}

    for label in mouse_labels:
        # Ground truth
        this_dict = {}
        for fold in range(4):
            this_dict[fold] = ground_truth_mouse[fold][ground_truth_mouse[fold].obs['broad_clusters']==label]
        label_to_ground_truth_mouse[label] = this_dict

        # Predictions based on true labels
        this_dict = {}
        for fold in range(4):
            this_dict[fold] = pred_mouse[fold][ground_truth_mouse[fold].obs['broad_clusters']==label]
        label_to_pred_mouse[label] = this_dict

        # Predictions based on predicted labels
        this_dict = {}
        for fold in range(4):
            this_dict[fold] = pred_mouse[fold][pred_mouse[fold].obs['broad_clusters']==label]
        label_to_pred_predicted_mouse[label] = this_dict

    # Calculate correlations for each cell type
    label_to_fold_to_in_sample_corrs = {}
    label_to_fold_to_mouse_corrs = {}
    label_to_avg_in_sample_corrs = {}
    label_to_avg_mouse_corrs = {}

    # In-sample correlations
    for label in label_to_ground_truth_in_sample:
        label_to_fold_to_in_sample_corrs[label] = {}
        label_to_avg_in_sample_corrs[label] = {}

        for z in range(4):
            in_sample_corrs = {}
            for g in common_in_sample:
                t = np.array(label_to_ground_truth_in_sample[label][z][::,g].X.squeeze())
                p = np.array(label_to_pred_in_sample[label][z][::,g].X.squeeze())
                # Stored as (prediction std, correlation); NaN correlation becomes 0
                in_sample_corrs[g] = (p.std(), np.nan_to_num(np.corrcoef(t, p)[0, 1]))
            label_to_fold_to_in_sample_corrs[label][z] = in_sample_corrs

        # Calculate weighted average across folds
        fold_to_prop = {}
        total_cells = float(sum(label_to_ground_truth_in_sample[label][z].shape[0] for z in range(4)))
        for z in range(4):
            fold_to_prop[z] = label_to_ground_truth_in_sample[label][z].shape[0] / total_cells

        for g in common_in_sample:
            label_to_avg_in_sample_corrs[label][g] = (
                sum(label_to_fold_to_in_sample_corrs[label][z][g][0] * fold_to_prop[z] for z in range(4)),
                sum(label_to_fold_to_in_sample_corrs[label][z][g][1] * fold_to_prop[z] for z in range(4))
            )

    # Mouse correlations
    for label in label_to_ground_truth_mouse:
        label_to_fold_to_mouse_corrs[label] = {}
        label_to_avg_mouse_corrs[label] = {}

        for z in range(4):
            mouse_corrs = {}
            for g in common_mouse:
                t = np.array(label_to_ground_truth_mouse[label][z][::,g].X.squeeze())
                p = np.array(label_to_pred_mouse[label][z][::,g].X.squeeze())
                # Stored as (prediction std, correlation); NaN correlation becomes 0
                mouse_corrs[g] = (p.std(), np.nan_to_num(np.corrcoef(t, p)[0, 1]))
            label_to_fold_to_mouse_corrs[label][z] = mouse_corrs

        # Calculate weighted average across folds
        fold_to_prop = {}
        total_cells = float(sum(label_to_ground_truth_mouse[label][z].shape[0] for z in range(4)))
        for z in range(4):
            fold_to_prop[z] = label_to_ground_truth_mouse[label][z].shape[0] / total_cells

        for g in common_mouse:
            label_to_avg_mouse_corrs[label][g] = (
                sum(label_to_fold_to_mouse_corrs[label][z][g][0] * fold_to_prop[z] for z in range(4)),
                sum(label_to_fold_to_mouse_corrs[label][z][g][1] * fold_to_prop[z] for z in range(4))
            )

    # Out-of-sample correlations
    out_sample_labels = []
    # NOTE(review): the loop variable `fold` is not used in the body, so the
    # same union is computed four times; a single pass gives the same labels.
    for fold in range(4):
        out_sample_labels = np.union1d(out_sample_labels, sorted(set(ground_truth_out_of_sample_for_cts.obs['broad_clusters'])))
    out_sample_labels = out_sample_labels.astype(int)

    label_to_ground_truth_out_of_sample = {}
    label_to_pred_out_of_sample = {}

    for label in out_sample_labels:
        label_to_ground_truth_out_of_sample[label] = ground_truth_out_of_sample_for_cts[ground_truth_out_of_sample_for_cts.obs['broad_clusters']==label]
        # Predictions are selected with the ground-truth labels as well
        label_to_pred_out_of_sample[label] = pred_out_of_sample_for_cts[ground_truth_out_of_sample_for_cts.obs['broad_clusters']==label]

    # Calculate out-of-sample correlations
    label_to_corrs_out_of_sample = {}
    for label in label_to_ground_truth_out_of_sample:
        label_to_corrs_out_of_sample[label] = {}
        for g in common_out_of_sample:
            t = np.array(label_to_ground_truth_out_of_sample[label][::,g].X.squeeze())
            p = np.array(label_to_pred_out_of_sample[label][::,g].X.squeeze())
            label_to_corrs_out_of_sample[label][g] = (p.std(), np.nan_to_num(np.corrcoef(t, p)[0, 1]))

    # Prepare summary statistics
    dataset_to_ct_corrs_info = {}
    dataset_to_all_info = {
        'mouse': label_to_avg_mouse_corrs,
        'in_sample': label_to_avg_in_sample_corrs,
        'out_of_sample': label_to_corrs_out_of_sample,
    }

    for dataset in ['mouse', 'in_sample', 'out_of_sample']:
        all_info = dataset_to_all_info[dataset]
        ct_corrs_info = {}
        for ct, more_stuff in all_info.items():
            # v is a (std, correlation) pair; keep only the correlation
            this_res = {k: v[-1] for k, v in more_stuff.items()}
            ct_corrs_info[ct] = this_res
        dataset_to_ct_corrs_info[dataset] = ct_corrs_info

    return dataset_to_ct_corrs_info

# Generate cell type correlation plots
def generate_cell_type_correlation_plots(dataset_to_ct_corrs_info):
    """Save box plots of per-gene correlations, one panel per dataset.

    Parameters
    ----------
    dataset_to_ct_corrs_info : dict
        Must contain the keys ``'mouse'``, ``'in_sample'`` and
        ``'out_of_sample'``. Each value maps a cell-type label to a
        ``{gene: correlation}`` dict; one box is drawn per cell type, in
        sorted label order.

    Side effects
    ------------
    Writes ``final_figures_schaf_revision_pngs/cell_type_correlations.png``
    and closes the figure.
    """
    datasets = ['mouse', 'in_sample', 'out_of_sample']
    fig, axes = plt.subplots(1, len(datasets), figsize=(15, 5))

    for ax, dataset in zip(axes, datasets):
        ct_corrs_info = dataset_to_ct_corrs_info[dataset]
        # Sort once so the boxes and their tick labels are guaranteed to line up.
        cell_types = sorted(ct_corrs_info)

        ax.boxplot([list(ct_corrs_info[ct].values()) for ct in cell_types])
        # Tick labels are set explicitly instead of via boxplot(labels=...),
        # which Matplotlib 3.9 deprecated in favour of `tick_labels`; this
        # form works on both older and newer releases.
        ax.set_xticklabels([f'CT{ct}' for ct in cell_types])
        ax.set_title(f'{dataset.replace("_", " ").title()} Correlations')
        ax.set_ylabel('Correlation')
        ax.set_xlabel('Cell Type')
        ax.tick_params(axis='x', rotation=45)

    fig.tight_layout()
    fig.savefig('final_figures_schaf_revision_pngs/cell_type_correlations.png', dpi=300, bbox_inches='tight')
    plt.close(fig)


In [None]:
# Define cell type names
# Display names for the integer cell-type labels of each dataset. The outer
# keys match the dataset keys used by the plotting helpers in this notebook;
# the inner keys are the labels those helpers look up when titling a panel.
datasets_to_ct_name = {
    'in_sample': {
        1: 'Invasive Tumor',
        2: 'Ductal Carcinoma In Situ',
        3: 'Stromal Cells',
        4: 'T Cells',
        5: 'Macrophages',
        6: 'Endothelial Cells',
        7: 'Myoepithelial Cells',
        8: 'NK Cells',
        9: 'B Cells',
        10: 'Mature B Cells',
    },
    'out_of_sample': {
        1: 'T Cells',
        2: 'Stromal Cells',
        3: 'Macrophages',
        4: 'Invasive Tumor',
        5: 'T CD8 Cells',
        6: 'Endothelial Cells',
        7: 'Myoepithelial Cells',
        8: 'Epithelial Cells',
        9: 'Ductal Carcinoma In Situ',
        10: 'B Cells',
    },
    'mouse': {
        1: 'Bone/Cartilage',
        2: 'Skeletal Muscle',
        3: 'Blood Vessel',
        4: 'Brain',
        5: 'Lung',
        6: 'Liver',
        7: 'Skin',
        8: 'GI Tract',  # was misspelled 'GI Track' in the figure label
        9: 'Lymphatic Vessels',
        10: 'Kidney',
    },
}

# Human-readable dataset titles used as column/figure headings.
ds_to_name = {
    'in_sample': 'In-Sample Xenium MBC',
    'out_of_sample': 'New-Sample Xenium MBC',
    'mouse': 'In-Sample Mouse',
}

# Generate cell type heterogeneity plots
def generate_cell_type_heterogeneity_plots(dataset_to_ct_corrs_info):
    """Save a grid of per-cell-type histograms of per-gene correlations.

    Columns are datasets (in the iteration order of the input dict) and rows
    are that dataset's cell types (in the iteration order of its inner dict).

    Parameters
    ----------
    dataset_to_ct_corrs_info : dict
        ``{dataset_name: {cell_type_label: {gene: correlation}}}``.
        ``dataset_name`` must be a key of ``datasets_to_ct_name`` and
        ``ds_to_name``; each ``cell_type_label`` must be a key of
        ``datasets_to_ct_name[dataset_name]``.

    Side effects
    ------------
    Writes ``final_figures_schaf_revision_pngs/celltype_heterogeneity.png``
    and closes the figure.
    """
    num_celltypes = 10
    # Keep the historical 10-row layout, but grow the grid (and the figure
    # height proportionally) if some dataset has more cell types, instead of
    # failing with an IndexError.
    n_rows = max([num_celltypes] + [len(cts) for cts in dataset_to_ct_corrs_info.values()])
    # squeeze=False keeps `axs` two-dimensional even for a single dataset.
    fig, axs = plt.subplots(
        n_rows, len(dataset_to_ct_corrs_info),
        figsize=(18, 24 * n_rows / num_celltypes),
        squeeze=False,
    )

    for col, (dataset_name, celltypes) in enumerate(dataset_to_ct_corrs_info.items()):
        ct_to_name = datasets_to_ct_name[dataset_name]
        for row, (celltype, values) in enumerate(celltypes.items()):
            ax = axs[row, col]
            values = list(values.values())
            # Values outside [-0.2, 1.0] fall outside the bins and are not drawn.
            ax.hist(values, rwidth=.7, bins=np.arange(-.2, 1.01, .1))
            mean_value = np.mean(values)

            # The top panel of each column also carries the dataset heading.
            title = f'{ct_to_name[celltype]}'
            if row == 0:
                title = f'{ds_to_name[dataset_name]}\n{title}'
            ax.set_title(title, pad=15)

            ax.set_xlabel('Spatial Correlation', size='x-large')
            ax.set_ylabel('Number of Genes', size='x-large')
            plt.setp(ax.spines.values(), linewidth=2)

            # Display mean value
            ax.text(0.95, 0.95, f'Avg. = {mean_value:.2f}',
                    transform=ax.transAxes,
                    fontsize=15, ha='right', va='top',
                    bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))

        # Hide every unused panel in this column. Previously only the bottom
        # row was hidden, leaving empty frames when a dataset had two or more
        # cell types fewer than the number of rows.
        for unused_ax in axs[len(celltypes):, col]:
            unused_ax.axis('off')

    fig.tight_layout()
    fig.savefig('final_figures_schaf_revision_pngs/celltype_heterogeneity.png', dpi=400, transparent=True)
    plt.close(fig)

# Calculate metacorrelations for each cell type
def calculate_cell_type_metacorrelations(label_to_ground_truth_in_sample, label_to_pred_in_sample,
                                       label_to_ground_truth_mouse, label_to_pred_mouse,
                                       label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample,
                                       common_in_sample, common_mouse, common_out_of_sample,
                                       n_folds=4):
    """Compute per-cell-type "metacorrelations" for the three datasets.

    A metacorrelation is the Pearson correlation between the flattened
    gene-gene correlation matrix of the predictions and that of the ground
    truth. NaNs (e.g. from constant genes) are replaced by 0 at every step.

    Parameters
    ----------
    label_to_ground_truth_in_sample, label_to_pred_in_sample : dict
        ``{label: {fold: data}}`` for folds ``0 .. n_folds - 1``. Each
        ``data`` object must support ``data[::, genes].X`` and ``.shape``
        (AnnData-like — the concrete type is not visible here).
    label_to_ground_truth_mouse, label_to_pred_mouse : dict
        Same layout as the in-sample arguments.
    label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample : dict
        ``{label: data}`` (no folds).
    common_in_sample, common_mouse, common_out_of_sample : sequence
        Genes used to subset each dataset.
    n_folds : int, optional
        Number of folds to aggregate for the fold-based datasets (default 4).

    Returns
    -------
    tuple of dict
        ``(in_sample, mouse, out_of_sample)``, each mapping label to a
        metacorrelation. For fold-based datasets the value is the average
        over folds weighted by each fold's share of ground-truth cells.
    """
    def _metacorr(pred, truth, genes):
        # Both matrices are kept 2-D (cells x genes). The former `.squeeze()`
        # on the ground truth collapsed a single-cell subset to 1-D, which
        # made the two heatmaps differ in size and crashed the final corrcoef.
        pred_arr = np.array(pred[::, genes].X)
        true_arr = np.array(truth[::, genes].X)
        pred_heatmap = np.nan_to_num(np.corrcoef(pred_arr, rowvar=0))
        true_heatmap = np.nan_to_num(np.corrcoef(true_arr, rowvar=0))
        return np.nan_to_num(np.corrcoef(pred_heatmap.reshape(-1), true_heatmap.reshape(-1))[0, 1])

    def _fold_weighted_metacorr(label_to_truth, label_to_pred, genes):
        # Shared by the in-sample and mouse datasets, which use the same layout.
        label_to_metacorr = {}
        for label in label_to_truth:
            truth_by_fold = label_to_truth[label]
            pred_by_fold = label_to_pred[label]
            fold_to_metacorr = {
                z: _metacorr(pred_by_fold[z], truth_by_fold[z], genes)
                for z in range(n_folds)
            }
            # Weight each fold by its share of this label's ground-truth cells.
            total_cells = float(sum(truth_by_fold[z].shape[0] for z in range(n_folds)))
            fold_to_prop = {z: truth_by_fold[z].shape[0] / total_cells for z in range(n_folds)}
            label_to_metacorr[label] = sum(fold_to_metacorr[z]*fold_to_prop[z] for z in range(n_folds))
        return label_to_metacorr

    # Calculate in-sample metacorrelations
    label_to_in_sample_plot_metacorr = _fold_weighted_metacorr(
        label_to_ground_truth_in_sample, label_to_pred_in_sample, common_in_sample)

    # Calculate mouse metacorrelations
    label_to_mouse_plot_metacorr = _fold_weighted_metacorr(
        label_to_ground_truth_mouse, label_to_pred_mouse, common_mouse)

    # Calculate out-of-sample metacorrelations (single sample, no folds)
    label_to_out_of_sample_plot_metacorr = {
        label: _metacorr(label_to_pred_out_of_sample[label],
                         label_to_ground_truth_out_of_sample[label],
                         common_out_of_sample)
        for label in label_to_ground_truth_out_of_sample
    }

    return label_to_in_sample_plot_metacorr, label_to_mouse_plot_metacorr, label_to_out_of_sample_plot_metacorr

# Generate metacorrelation plots
def generate_metacorrelation_plots(label_to_in_sample_plot_metacorr, label_to_mouse_plot_metacorr, label_to_out_of_sample_plot_metacorr):
    """Draw one bar chart of per-cell-type metacorrelations per dataset.

    Each argument maps a cell-type label to its metacorrelation. Bars are
    ordered by sorted label and ticked with the display names from
    ``datasets_to_ct_name``. The three-panel figure is written to
    ``final_figures_schaf_revision_pngs/metacorrelations.png`` and closed.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # (values, key into datasets_to_ct_name, panel title) — left to right.
    panel_specs = (
        (label_to_in_sample_plot_metacorr, 'in_sample', 'In-Sample Metacorrelations'),
        (label_to_mouse_plot_metacorr, 'mouse', 'Mouse Metacorrelations'),
        (label_to_out_of_sample_plot_metacorr, 'out_of_sample', 'Out-of-Sample Metacorrelations'),
    )

    for panel, (label_to_metacorr, dataset_key, heading) in zip(axes, panel_specs):
        ordered_labels = sorted(label_to_metacorr.keys())
        bar_positions = range(len(label_to_metacorr))
        bar_heights = [label_to_metacorr[lbl] for lbl in ordered_labels]
        tick_names = [datasets_to_ct_name[dataset_key][lbl] for lbl in ordered_labels]

        panel.bar(bar_positions, bar_heights)
        panel.set_title(heading)
        panel.set_xticks(bar_positions)
        panel.set_xticklabels(tick_names, rotation=45, ha='right')
        panel.set_ylabel('Metacorrelation')

    plt.tight_layout()
    plt.savefig('final_figures_schaf_revision_pngs/metacorrelations.png', dpi=400, bbox_inches='tight')
    plt.close()


In [None]:
# Calculate cell type heterogeneity
def calculate_cell_type_heterogeneity(label_to_ground_truth_in_sample, label_to_pred_in_sample,
                                    label_to_ground_truth_mouse, label_to_pred_mouse,
                                    label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample,
                                    common_in_sample, common_mouse, common_out_of_sample,
                                    n_folds=4):
    """Compute per-gene expression heterogeneity within each cell type.

    Heterogeneity of a gene is the standard deviation of its values after
    min-max scaling across the given cells. Genes with no cells, a constant
    value, or that cannot be read score 0.

    Parameters
    ----------
    label_to_ground_truth_in_sample, label_to_pred_in_sample : dict
        ``{label: {fold: data}}`` for folds ``0 .. n_folds - 1``. Each
        ``data`` object must support ``data[::, gene].X`` and ``.shape``
        (AnnData-like — the concrete type is not visible here).
    label_to_ground_truth_mouse, label_to_pred_mouse : dict
        Same layout as the in-sample arguments.
    label_to_ground_truth_out_of_sample, label_to_pred_out_of_sample : dict
        ``{label: data}`` (no folds).
    common_in_sample, common_mouse, common_out_of_sample : sequence
        Genes to score for each dataset.
    n_folds : int, optional
        Number of folds to aggregate for the fold-based datasets (default 4).

    Returns
    -------
    tuple
        ``(in_sample, mouse, out_of_sample)``. Note the two layouts differ:
        the fold-based results are ``{label: {gene: (pred, truth)}}`` with
        values averaged over folds weighted by ground-truth cell counts,
        whereas the out-of-sample result is
        ``{label: ({gene: pred}, {gene: truth})}``.
    """
    def get_hetero(cells, genes):
        res = {}
        for g in genes:
            try:
                values = np.asarray(cells[::, g].X).reshape(-1)
                if values.size == 0:
                    # No cells for this label/fold: nothing to measure.
                    res[g] = 0
                    continue
                lo, hi = values.min(), values.max()
                if not hi > lo:
                    # Constant (or NaN) gene: scaling would divide by zero.
                    res[g] = 0
                    continue
                # Proper min-max scaling. The earlier code subtracted the max
                # instead of the min; the std is shift-invariant so results
                # are unchanged, but this is the intended form.
                res[g] = np.nan_to_num(((values - lo) / (hi - lo)).std())
            except Exception:
                # Best-effort per gene (e.g. a gene missing from `cells`), but
                # no longer swallowing KeyboardInterrupt/SystemExit as the
                # previous bare `except:` did.
                res[g] = 0
        return res

    def _fold_weighted_heteros(label_to_truth, label_to_pred, genes):
        # Shared by the in-sample and mouse datasets, which use the same layout.
        label_to_avg = {}
        for label in label_to_truth:
            truth_by_fold = label_to_truth[label]
            pred_by_fold = label_to_pred[label]
            # Per fold: (prediction heterogeneity, ground-truth heterogeneity).
            fold_to_heteros = {
                z: (get_hetero(pred_by_fold[z], genes), get_hetero(truth_by_fold[z], genes))
                for z in range(n_folds)
            }
            # Weight each fold by its share of this label's ground-truth cells.
            total_cells = float(sum(truth_by_fold[z].shape[0] for z in range(n_folds)))
            fold_to_prop = {z: truth_by_fold[z].shape[0] / total_cells for z in range(n_folds)}
            label_to_avg[label] = {
                g: (
                    sum(fold_to_heteros[z][0][g] * fold_to_prop[z] for z in range(n_folds)),
                    sum(fold_to_heteros[z][1][g] * fold_to_prop[z] for z in range(n_folds)),
                )
                for g in genes
            }
        return label_to_avg

    # Calculate in-sample heterogeneity
    label_to_avg_in_sample_heteros = _fold_weighted_heteros(
        label_to_ground_truth_in_sample, label_to_pred_in_sample, common_in_sample)

    # Calculate mouse heterogeneity
    label_to_avg_mouse_heteros = _fold_weighted_heteros(
        label_to_ground_truth_mouse, label_to_pred_mouse, common_mouse)

    # Calculate out-of-sample heterogeneity (single sample, no folds)
    label_to_heteros_out_of_sample = {}
    for label in label_to_ground_truth_out_of_sample:
        label_to_heteros_out_of_sample[label] = (
            get_hetero(label_to_pred_out_of_sample[label], common_out_of_sample),
            get_hetero(label_to_ground_truth_out_of_sample[label], common_out_of_sample)
        )

    return label_to_avg_in_sample_heteros, label_to_avg_mouse_heteros, label_to_heteros_out_of_sample

# Generate histology split figures
def generate_histology_split_figures(in_sample_fold_to_hist, mouse_fold_to_hist, in_sample_hist, out_of_sample):
    """Save the four figures illustrating the train/evaluation histology splits.

    Written under ``final_figures_schaf_revision_pngs/``:
    ``in_sample_hist_split.png``, ``mouse_hist_split.png``,
    ``new_sample_hist_split_train.png`` and ``new_sample_hist_split_eval.png``.

    Args:
        in_sample_fold_to_hist: indexable by fold ``0..3``; each entry is an
            image array indexed as ``[row, col, channel]``.
        mouse_fold_to_hist: same layout as ``in_sample_fold_to_hist``.
        in_sample_hist: single image array drawn as the "Training" sample of
            the new-sample figure.
        out_of_sample: single image array drawn as the "Evaluation" sample.

    Only channel 0 of every image is displayed, inverted as ``max - value``.
    In the 2x2 grids the panel at grid position (0, 1) is coloured and titled
    as the evaluation fold; the remaining three are training folds.
    """
    def render(panel, image, evaluation, **imshow_opts):
        # Blue = evaluation, red = training.
        channel = image[::, ::, 0]
        if evaluation:
            panel.imshow(channel.max() - channel, cmap=mpl.cm.Blues, **imshow_opts)
            panel.set_title('Evaluation', size='xx-large', c='blue')
        else:
            panel.imshow(channel.max() - channel, cmap=mpl.cm.Reds, **imshow_opts)
            panel.set_title('Training', size='xx-large', c='red')

    def export(stem):
        plt.savefig(f'final_figures_schaf_revision_pngs/{stem}.png', dpi=400, transparent=True)
        plt.close()

    def fold_grid(fold_images, layout, figsize, heading, flip_y):
        # `layout` maps fold -> (row, col) of its panel in the 2x2 grid.
        figure, grid = plt.subplots(2, 2, figsize=figsize)
        figure.suptitle(heading, fontsize="xx-large")
        for fold, (r, c) in layout.items():
            panel = grid[r, c]
            panel.axis('off')
            render(panel, fold_images[fold], r == 0 and c == 1, aspect='auto')
            if flip_y:
                panel.invert_yaxis()
        plt.tight_layout()

    # In-sample histology split (y-axis flipped on every panel)
    fold_grid(
        in_sample_fold_to_hist,
        {0: (1, 0), 1: (0, 1), 2: (0, 0), 3: (1, 1)},
        (20, 14),
        "In-Sample Xenium MBC",
        True,
    )
    export('in_sample_hist_split')

    # Mouse histology split (y-axis left as drawn)
    fold_grid(
        mouse_fold_to_hist,
        {0: (0, 0), 1: (1, 1), 2: (1, 0), 3: (0, 1)},
        (20, 12.5),
        "In-Sample Mouse",
        False,
    )
    export('mouse_hist_split')

    # New sample histology split: training image (rows reversed before display)
    figure, panel = plt.subplots(1, 1, figsize=(20, 14))
    figure.suptitle("New-Sample Xenium MBC", fontsize="xx-large")
    panel.axis('off')
    render(panel, in_sample_hist[::-1], False)
    export('new_sample_hist_split_train')

    # New sample histology split: evaluation image (rows reversed before display)
    figure, panel = plt.subplots(1, 1, figsize=(13, 20))
    panel.axis('off')
    render(panel, out_of_sample[::-1], True)
    export('new_sample_hist_split_eval')

# Generate correlation histograms
def generate_correlation_histograms(mouse_avg_in_sample_corrs, mouse_train_genes):
    """Save a histogram of per-gene correlations, split by training vs. hold-out genes.

    Parameters
    ----------
    mouse_avg_in_sample_corrs : dict
        ``{gene: correlation}`` for the mouse dataset.
    mouse_train_genes : container
        Genes considered "training" genes (membership is tested with ``in``);
        every other gene in ``mouse_avg_in_sample_corrs`` is hold-out.

    Side effects
    ------------
    Writes ``final_figures_schaf_revision_pngs/all_gene_corrs.png`` and closes
    the figure.
    """
    fig_name = 'all_gene_corrs'
    # NOTE(review): three panels are created but only the first is drawn, so
    # the saved figure contains two empty axes — confirm whether further
    # panels are still to be added or the grid should shrink to one.
    fig, ax = plt.subplots(1, 3, figsize=(13.5, 5.4))

    train_corrs = [mouse_avg_in_sample_corrs[g] for g in mouse_avg_in_sample_corrs if g in mouse_train_genes]
    holdout_corrs = [mouse_avg_in_sample_corrs[g] for g in mouse_avg_in_sample_corrs if g not in mouse_train_genes]

    def _avg_label(group, values):
        # The legend reports the mean of the data actually plotted; it used to
        # show hard-coded literals (0.445 / 0.377) that could drift from it.
        avg = f'{np.mean(values):.3f}' if len(values) else 'n/a'
        return f'{group} Avg. = {avg}'

    # Mouse correlations
    ax[0].hist(
        [train_corrs, holdout_corrs],
        color=['orange', 'blue'],
        label=[_avg_label('Training Genes', train_corrs), _avg_label('Hold-Out Genes', holdout_corrs)],
        alpha=.5, rwidth=.7, bins=np.arange(0, 1.01, .1)
    )
    ax[0].set_title('In-Sample Mouse', loc='left')
    ax[0].set_ylabel('Number of Genes')
    ax[0].set_xlabel('Spatial Correlation Coefficient')
    ax[0].set_xticks(ticks=np.arange(0, 1.01, .2), labels=np.round(np.arange(0, 1.01, .2), 2))
    ax[0].set_yticks(ticks=np.arange(0, 51, 10), labels=np.arange(0, 51, 10))
    ax[0].legend()

    fig.tight_layout()
    fig.savefig(f'final_figures_schaf_revision_pngs/{fig_name}.png', dpi=400, transparent=True)
    plt.close(fig)
