In [None]:
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
import pathlib 
import sys
import joblib
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import seaborn as sns
import matplotlib.pyplot as plt

script_directory = pathlib.Path("../utils/").resolve()
sys.path.insert(0, str(script_directory))
from data_loader import load_model_data
from model_utils import extract_latent_dims

In [2]:
# Define the location of the saved models and output directory for results
model_save_dir = pathlib.Path("../4.gene-expression-signatures/saved_models")
output_dir = pathlib.Path("results")
output_dir.mkdir(parents=True, exist_ok=True)

In [3]:
data_directory = pathlib.Path("../0.data-download/data").resolve()
dependency_file = pathlib.Path(f"{data_directory}/CRISPRGeneEffect.parquet").resolve()
gene_dict_file = pathlib.Path(f"{data_directory}/CRISPR_gene_dictionary.parquet").resolve()

In [4]:
# Load metadata
metadata_df_dir = pathlib.Path("../0.data-download/data/metadata_df.parquet")
metadata = pd.read_parquet(metadata_df_dir)
print(metadata.shape)

#Load dependency data
dependency_df, gene_dict_df = load_model_data(dependency_file, gene_dict_file)
dependency_df.head()


(958, 3)
(1150, 18444)


Unnamed: 0,ModelID,ATP6V1B2,PGP,LONP1,PKN2,TXNDC17,TACC3,LARS2,ACD,PBRM1,...,MAT2A,RPAP3,SRFBP1,SELENOI,ZSWIM6,METTL17,SOX17,GARS1,POP1,PSMC6
0,ACH-000001,-0.722523,-0.379114,-1.079895,-0.304835,-0.19277,-0.804104,-0.643955,-0.095769,-0.258745,...,-1.609151,-0.96057,-0.655273,0.060792,-0.112401,-1.288937,-0.124743,-1.239178,-1.173456,-1.856352
1,ACH-000004,-2.612545,-0.337075,-1.900047,-0.470642,-0.188651,-0.531188,-0.250625,-0.604683,-0.422913,...,-2.037072,-0.865169,-0.440801,-0.088064,-0.072403,-0.127299,0.067915,-1.040483,-1.350506,-2.20755
2,ACH-000005,-2.434407,-0.00219,-1.250451,-0.155679,-0.27172,-0.241461,-0.317184,-0.511741,-0.014821,...,-2.151609,-0.947788,-0.336573,-0.041227,0.062079,-0.090011,0.126301,-1.122754,-1.184203,-1.329586
3,ACH-000007,-1.926781,0.315856,-1.095203,0.031807,-0.223783,-0.323253,-0.706735,-0.572485,-0.37608,...,-1.795881,-0.402845,-0.708576,-0.028369,0.002692,-0.400043,-0.084501,-1.211087,-1.255302,-1.35177
4,ACH-000009,-1.449962,0.054887,-1.307617,-0.200971,-0.202928,-0.615397,-0.271526,-0.212804,-0.189308,...,-1.831,-0.759129,-0.457776,-0.018156,0.188474,-0.483733,0.037198,-1.142514,-1.367033,-1.45041


In [None]:
# Initialize the MinMaxScaler
scaler = MinMaxScaler()

# Apply the scaler to the numeric columns
dependency_df[dependency_df.select_dtypes(include=["number"]).columns] = scaler.fit_transform(
    dependency_df.select_dtypes(include=["number"])
)

In [None]:
train_and_test_subbed_dir = pathlib.Path("../0.data-download/data/train_and_test_subbed.parquet")
train_and_test_subbed = pd.read_parquet(train_and_test_subbed_dir)

train_and_test_subbed[train_and_test_subbed.select_dtypes(include=["number"]).columns] = scaler.fit_transform(
    train_and_test_subbed.select_dtypes(include=["number"])
)

# Convert DataFrame to NumPy and then Tensor
train_test_array = train_and_test_subbed.to_numpy()
train_test_tensor = torch.tensor(train_test_array, dtype=torch.float32)

#Create TensorDataset and DataLoader
tensor_dataset = TensorDataset(train_test_tensor)
train_and_test_subbed_loader = DataLoader(tensor_dataset, batch_size=32, shuffle=False)

In [7]:
results = []

for model_file in model_save_dir.glob("*.joblib"):
    model_file_name = model_file.stem
    try:
        parts = model_file_name.split("_")
        model_name = parts[0]
        num_components = int(parts[3])  # total latent dimensions
        init = int(parts[7])  # initialization value
        seed = int(parts[9])  # seed value
    except (IndexError, ValueError):
        print(f"Skipping file {model_file} due to unexpected filename format.")
        continue

    # Load the model
    print(f"Loading model from {model_file}")
    try:
        model = joblib.load(model_file)
    except Exception as e:
        print(f"Failed to load model from {model_file}: {e}")
        continue
        
    # Extract z, original input, and reconstruction
    latent_df, original_data, reconstructed_data = extract_latent_dims(
        model_name, model, dependency_df, train_and_test_subbed_loader, metadata
    )
    print(original_data)
    print(reconstructed_data)

    # Convert to torch tensors
    original_tensor = torch.tensor(original_data, dtype=torch.float32)
    reconstructed_tensor = torch.tensor(reconstructed_data, dtype=torch.float32)

    # Clamp reconstructions to avoid log(0)
    reconstructed_tensor = torch.clamp(reconstructed_tensor, min=1e-7, max=1 - 1e-7)

    # Compute BCE loss across all elements
    mse = F.mse_loss(reconstructed_tensor, original_tensor, reduction='mean')

    results.append({
        "model": model_name,
        "latent_dim": num_components,
        "init": init,
        "mse": mse.item()
    })
    print("Original min/max:", original_data.min(), original_data.max())
    print("Reconstructed min/max:", reconstructed_data.min(), reconstructed_data.max())

# Convert results to DataFrame
recon_df = pd.DataFrame(results)
print(recon_df)

Empty DataFrame
Columns: []
Index: []


In [None]:
# Set global font sizes
plt.rcParams.update({
    "font.size": 16,          # Base font size
    "axes.titlesize": 18,     # Facet title
    "axes.labelsize": 16,     # Axis labels
    "xtick.labelsize": 7,    # X tick labels
    "ytick.labelsize": 14,    # Y tick labels
    "legend.fontsize": 14,    # Legend text
    "legend.title_fontsize": 16  # Legend title
})

# Convert latent_dim to a categorical type for equal spacing
recon_df['latent_dim'] = recon_df['latent_dim'].astype(str)

# Convert latent_dim to ordered categories for within_model_df
dimension_order = sorted(recon_df["latent_dim"].unique(), key=int)

recon_df["latent_dim"] = pd.Categorical(
    recon_df["latent_dim"], categories=dimension_order, ordered=True
)

# FacetGrid
g = sns.FacetGrid(
    recon_df,
    col="model",
    col_wrap=3,
    height=4,
    sharey=True
)

g.map_dataframe(
    sns.scatterplot,
    x="latent_dim", 
    y="mse",
    hue="init",
    style="init"
)

# Rotate x-axis labels
for ax in g.axes.flatten():
    for label in ax.get_xticklabels():
        label.set_rotation(90)

# Format
g.set_titles(col_template="{col_name}")
g.set_axis_labels("Latent Dimension", "Reconstruction MSE")
g.add_legend(title="Init")
g.tight_layout()
plt.subplots_adjust(top=0.9, wspace=0.3)

KeyError: 'latent_dim'