# Contrastive Learning for Debiasing Genomic Embeddings

## Overview

This notebook implements supervised contrastive learning to improve pre-computed genomic embeddings by reducing confounding effects from ancestry and technical variables while preserving discriminative signal for the outcome of interest.

## Background

Genomic data often contains confounding factors (ancestry, sequencing batch, read depth, etc.) that can lead to spurious associations in downstream analyses. Traditional approaches have limitations: linear regression adjustment removes only linear confounder effects, and propensity score matching can discard samples that lack a good match. Contrastive learning offers an alternative: learn an embedding space where:

1. **Samples with the same phenotype** are pulled together (positive pairs)
2. **Samples with different phenotypes** but matched confounders are pushed apart (negative pairs)
3. **Confounder effects** are minimized through the matched design

## Study Design

We use a **matched case-control design** with:
- Cases: Samples with `is_positive = 1`
- Controls: 4 matched controls per case (matched on confounders)
- Each case-control group is identified by `case_matched` column

### Contrastive Pairs Definition

- **Positive pairs**: Two samples with the same `is_positive` label AND same `case_matched` group
- **Negative pairs**: Two samples with different `is_positive` labels AND same `case_matched` group

This ensures we're learning to discriminate between cases and controls while being invariant to the confounders they share.
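As a small illustration of these definitions, the sketch below builds both pair masks with NumPy for one toy batch. The arrays `labels` and `groups` and the helper `pair_masks` are hypothetical names used only for this example; they are not part of the notebook's pipeline.

```python
import numpy as np


def pair_masks(labels, groups):
    """Return boolean (n, n) masks for positive and negative pairs."""
    labels = np.asarray(labels)
    groups = np.asarray(groups)
    same_label = labels[:, None] == labels[None, :]
    same_group = groups[:, None] == groups[None, :]
    positive = same_label & same_group
    np.fill_diagonal(positive, False)  # a sample is not its own positive
    negative = ~same_label & same_group
    return positive, negative


# Toy batch: group 0 holds one case and two controls,
# group 1 holds one case and one control.
labels = [1, 0, 0, 1, 0]
groups = [0, 0, 0, 1, 1]
positive, negative = pair_masks(labels, groups)
print(positive.astype(int))
print(negative.astype(int))
```

Note that with a single case per group, the case has no positive partner inside its group, so every positive pair under these definitions is a control-control pair.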

## Training Strategy

1. Split data by `case_matched` groups (20% validation) to prevent leakage
2. Train a projection network using supervised contrastive loss
3. Monitor progress using cluster separation and logistic regression AUC
4. Compare embeddings before and after training using PCA visualizations

## Expected Outcome

The trained embeddings should:
- ✓ Improve separation between cases and controls (higher AUC)
- ✓ Reduce correlation with confounding variables
- ✓ Maintain or improve downstream predictive performance

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, calinski_harabasz_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
from IPython.display import display

warnings.filterwarnings("ignore")

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Data Loading

Load the parquet file containing pre-computed embeddings and metadata. The data should include:
- **Embeddings**: Pre-computed feature vectors (e.g., from a variant autoencoder, PRS, or other genomic model)
- **is_positive**: Binary outcome label (1 = case, 0 = control)
- **case_matched**: Group ID linking each case to its matched controls (the actual column name is set by `MATCH_CASE_COL` in the configuration cell below)
- **Confounders**: Variables to control for (ancestry, batch, read depth, etc.)
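The embedding layout used in the rest of the notebook can be sketched with plain NumPy: each gene contributes one vector per sample, and the vectors are stacked into a `(n_samples, gene_dim, n_genes)` array. The gene names and sizes below are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, gene_dim = 4, 3
gene_names = ["gene_a", "gene_b"]  # hypothetical column names

# One list of per-sample vectors per gene, mimicking array-valued columns.
columns = {
    name: [rng.normal(size=gene_dim) for _ in range(n_samples)]
    for name in gene_names
}

# Stack samples within each gene, then stack genes along the last axis.
stacked = np.stack([np.vstack(columns[name]) for name in gene_names], axis=-1)
print(stacked.shape)  # (4, 3, 2)
```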

In [None]:
USE_REAL_DATA = False

if USE_REAL_DATA:
    EMBEDDING_COLS = [
        "SIRPG",
        "BACH2",
    ]
    MATCH_CASE_COL = "matched_case_kitid"
else:
    EMBEDDING_COLS = [
        "α-syn",
        "β-k3",
        "γ-DR",
        "ε-BP",
        "λ-trans7",
        "ζ-lig",
        "ω-deg",
        "κ-pol",
        "η-rec",
        "δ-act2",
        "θ-phos",
        "μ-chan",
        "ξ-meth",
        "π-ox",
        "ρ-hydr",
        "σ-cat",
        "τ-tub",
        "υ-reg",
        "φ-fold",
        "χ-chap",
        "ψ-sens",
    ][:5]
    MATCH_CASE_COL = "case_id"

LABEL_COL = "is_positive"

data_path = "/home/ext_meehl_joshua_mayo_edu/strand_cohort_eda/genomic/evals/data/embeddings/t1d/t1d_v1_emb_matrix_named.parquet"
meta_path = "/home/ext_meehl_joshua_mayo_edu/strand_cohort_eda/genomic/evals/data/gsm/t1d/matched_has_diabetes1_v4_modified.csv"
cofound_path = "/home/ext_meehl_joshua_mayo_edu/strand_cohort_eda/genomic/datasets/golden_benchmark/data/cohort_design/master_cohort.csv"

synth_path = "/root/pre-phd-genomics/02_debiasing/data/synthetic_embeddings.parquet"

In [None]:
if USE_REAL_DATA:
    df_emb = pd.read_parquet(data_path)
    print(df_emb.shape)

    df_meta = pd.read_csv(meta_path)
    df_meta["sample_id"] = df_meta["tap_kitid"]
    mask = df_meta["sample_id"].isin(df_emb["sample_id"].unique())
    df_meta = df_meta[mask]
    print(df_meta.shape)
    # df_meta.head(10)

    df_conf = pd.read_csv(cofound_path)
    df_conf["sample_id"] = df_conf["tap_kitid"]
    mask = df_conf["sample_id"].isin(df_emb["sample_id"].unique())
    df_conf = df_conf[mask]
    print(df_conf.shape)
    # df_conf.head(10)

    cols = ["sample_id"] + EMBEDDING_COLS
    df = df_emb[cols].copy()
    df = df.merge(df_meta, on="sample_id", how="left")
    df = df.merge(df_conf, on="sample_id", how="left")
else:
    df = pd.read_parquet(synth_path)
    print(df.shape)
df.head(6)

In [None]:
# Extract embedding columns
# Each column contains a numpy array of shape (gene_dim,)
print(f"Found {len(EMBEDDING_COLS)} gene embedding columns: {EMBEDDING_COLS}")

# Extract embeddings properly from dataframe
# Each gene column contains arrays, we need to stack them into (n_samples, gene_dim, n_genes)
n_samples = len(df)
n_genes = len(EMBEDDING_COLS)

# Get gene_dim from first array
GENE_DIM = df[EMBEDDING_COLS[0]].iloc[0].shape[0]
print(f"Gene dimension: {GENE_DIM}")

# Create 3D tensor: (n_samples, gene_dim, n_genes)
embeddings = np.zeros((n_samples, GENE_DIM, n_genes))

for gene_idx, gene_col in enumerate(EMBEDDING_COLS):
    # Stack all arrays for this gene
    embeddings[:, :, gene_idx] = np.vstack(df[gene_col].values)

print(f"Embeddings shape: {embeddings.shape} (n_samples, gene_dim, n_genes)")
print(f"  - {n_samples} samples")
print(f"  - {GENE_DIM} features per gene")
print(f"  - {n_genes} genes")

In [None]:
# Verify the embeddings shape
print(f"Embeddings successfully reshaped to: {embeddings.shape}")
print(f"First sample, first gene, first 5 features: {embeddings[0, :5, 0]}")

In [None]:
# Embeddings are now ready for the genewise contrastive learning architecture
# Shape: (n_samples, gene_dim, n_genes)
print("✓ Embeddings ready for genewise architecture")
print(f"  Shape: {embeddings.shape}")
print(f"  Total features: {GENE_DIM * n_genes}")
print(f"  Format: Each of {n_genes} genes has {GENE_DIM} features")

In [None]:
# Perform PCA on original embeddings (flatten for PCA)
embeddings_flat = embeddings.reshape(embeddings.shape[0], -1)  # (n, gene_dim * n_genes)
pca_original = PCA(n_components=2)
pca_coords_original = pca_original.fit_transform(embeddings_flat)

# Add PCA coordinates to dataframe
df["PC1_original"] = pca_coords_original[:, 0]
df["PC2_original"] = pca_coords_original[:, 1]

print(f"Explained variance ratio: {pca_original.explained_variance_ratio_}")
print(f"Total variance explained: {pca_original.explained_variance_ratio_.sum():.3f}")

In [None]:
# Get unique matched groups (each represents a case + its 4 matched controls).
# Use MATCH_CASE_COL so the lookup matches the column configured above.
unique_groups = df[MATCH_CASE_COL].unique()
print(f"Total unique case groups: {len(unique_groups)}")

# Split groups into train/val (20% validation)
train_groups, val_groups = train_test_split(
    unique_groups, test_size=0.2, random_state=42
)
print(f"Train groups: {len(train_groups)}, Val groups: {len(val_groups)}")

# Create train and validation masks
train_mask = df[MATCH_CASE_COL].isin(train_groups)
val_mask = df[MATCH_CASE_COL].isin(val_groups)

df_train = df[train_mask].copy()
df_val = df[val_mask].copy()

print(f"\nTrain set: {len(df_train)} samples ({df_train[LABEL_COL].sum()} positive)")
print(f"Val set: {len(df_val)} samples ({df_val[LABEL_COL].sum()} positive)")

# Extract embeddings for train and val (already reshaped to 3D)
X_train = embeddings[train_mask]  # (n_train, gene_dim, n_genes)
y_train = df_train[LABEL_COL].values
X_val = embeddings[val_mask]  # (n_val, gene_dim, n_genes)
y_val = df_val[LABEL_COL].values

print(f"\nX_train shape: {X_train.shape}")
print(f"X_val shape: {X_val.shape}")

## 4. PyTorch Contrastive Learning Architecture

### GenewiseContrastiveProjector
A gene-aware projection network that works with **any number of genes**:
- **Input**: (batch_size, gene_dim, n_genes) where gene_dim is typically 1024
- **Shared gene projection**: Applies same transformation to each gene's features
- **Pooling**: Aggregates across genes (mean, max, or attention pooling)
- **Final projection**: Maps pooled features to output space
- **L2 Normalization**: Output embeddings lie on unit hypersphere

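**Key advantage**: The shared per-gene projection and the pooling step make the weights independent of the number of genes, so a trained model accepts any gene set (100 genes, 500 genes, etc.). Accuracy on a gene set that differs from the training set is not guaranteed and should be checked.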
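The shape logic behind this can be sketched in NumPy. This is a simplified stand-in, not the PyTorch module: it keeps only one shared linear map, a ReLU, mean pooling, and L2 normalization, and it omits BatchNorm, dropout, and the final projector. The function name `project_and_pool` and all sizes are assumptions for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
gene_dim, hidden_dim = 8, 4
W = rng.normal(size=(gene_dim, hidden_dim))  # one weight matrix shared by all genes


def project_and_pool(x, W):
    """x: (batch, gene_dim, n_genes) -> (batch, hidden_dim) unit vectors."""
    per_gene = np.transpose(x, (0, 2, 1)) @ W    # (batch, n_genes, hidden_dim)
    per_gene = np.maximum(per_gene, 0.0)         # ReLU
    pooled = per_gene.mean(axis=1)               # average over genes
    norms = np.linalg.norm(pooled, axis=1, keepdims=True)
    return pooled / np.clip(norms, 1e-12, None)  # L2-normalize each row


# The same weights handle inputs with 5 genes and with 21 genes.
out_5 = project_and_pool(rng.normal(size=(2, gene_dim, 5)), W)
out_21 = project_and_pool(rng.normal(size=(2, gene_dim, 21)), W)
print(out_5.shape, out_21.shape)
```

Because the gene axis is averaged away before any layer that depends on its length, the output shape depends only on `hidden_dim`.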

### ContrastiveLoss
Supervised contrastive loss adapted for matched case-control design:
- Uses temperature-scaled cosine similarity
- Only considers pairs within the same `case_matched` group
- Pulls together samples with same label, pushes apart samples with different labels
- Automatically handles variable numbers of positive pairs per sample
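Written out, the per-anchor loss described by these bullets can be expressed as follows. The notation is introduced here for illustration and is not used elsewhere in the notebook.

```latex
\mathcal{L}_i \;=\; -\frac{1}{|P(i)|} \sum_{p \in P(i)}
  \log \frac{\exp\!\left(z_i \cdot z_p / \tau\right)}
            {\sum_{a \in G(i)} \exp\!\left(z_i \cdot z_a / \tau\right)}
```

Here $z_i$ is the L2-normalized projection of sample $i$, $\tau$ is the temperature, $G(i)$ is the set of other samples in the batch that share the `case_matched` group of $i$, and $P(i) \subseteq G(i)$ is the subset with the same `is_positive` label. The batch loss is the mean of $\mathcal{L}_i$ over anchors with a non-empty $P(i)$.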

In [None]:
class GenewiseContrastiveProjector(nn.Module):
    """
    Gene-aware projection network for contrastive learning.
    Works with any number of genes by applying shared projection per gene,
    then pooling across genes.
    """

    def __init__(
        self, gene_dim=1024, hidden_dim=256, output_dim=128, dropout=0.1, pooling="mean"
    ):
        super(GenewiseContrastiveProjector, self).__init__()

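        # Fail fast on unsupported pooling modes instead of erroring inside forward()
        if pooling not in ("mean", "max", "attention"):
            raise ValueError(f"Unknown pooling mode: {pooling!r}")
        self.pooling = pooling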

        # Shared projection applied to each gene independently
        self.gene_projector = nn.Sequential(
            nn.Linear(gene_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Attention pooling (optional)
        if pooling == "attention":
            self.attention = nn.Linear(hidden_dim, 1)

        # Final projection after pooling
        self.final_projector = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """
        Args:
            x: shape (batch_size, gene_dim, num_genes)
        Returns:
            embeddings: shape (batch_size, output_dim)
        """
        batch_size, gene_dim, num_genes = x.shape

        # Reshape to (batch_size * num_genes, gene_dim)
        x = x.permute(0, 2, 1)  # (batch, genes, features)
        x = x.reshape(batch_size * num_genes, gene_dim)

        # Apply shared gene projection
        x = self.gene_projector(x)  # (batch*genes, hidden_dim)

        # Reshape back to (batch_size, num_genes, hidden_dim)
        x = x.reshape(batch_size, num_genes, -1)

        # Pool across genes
        if self.pooling == "mean":
            x = x.mean(dim=1)  # (batch, hidden_dim)
        elif self.pooling == "max":
            x = x.max(dim=1)[0]  # (batch, hidden_dim)
        elif self.pooling == "attention":
            weights = torch.softmax(self.attention(x), dim=1)  # (batch, genes, 1)
            x = (x * weights).sum(dim=1)  # (batch, hidden_dim)

        # Final projection
        x = self.final_projector(x)  # (batch, output_dim)

        return F.normalize(x, dim=1)


class ContrastiveLoss(nn.Module):
    """
    Supervised contrastive loss for matched case-control design.
    Positive pairs: same is_positive label AND same case_matched group
    Negative pairs: different is_positive label AND same case_matched group (matched controls)
    """

    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels, case_matched):
        """
        Args:
            features: normalized embeddings, shape (batch_size, embed_dim)
            labels: is_positive labels, shape (batch_size,)
            case_matched: case_matched group IDs, shape (batch_size,)
        """
        device = features.device
        batch_size = features.shape[0]

        # Compute similarity matrix
        similarity_matrix = torch.matmul(features, features.T) / self.temperature

        # Create masks for positive and negative pairs
        labels = labels.contiguous().view(-1, 1)
        case_matched = case_matched.contiguous().view(-1, 1)

        # Positive mask: same label AND same case_matched group (but not same sample)
        label_mask = torch.eq(labels, labels.T).float().to(device)
        case_mask = torch.eq(case_matched, case_matched.T).float().to(device)
        positive_mask = label_mask * case_mask
        positive_mask.fill_diagonal_(0)  # Exclude self-comparisons

        # Negative mask: different label AND same case_matched group
        negative_mask = (1 - label_mask) * case_mask

        # For numerical stability
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()

        # Compute log probabilities
        exp_logits = torch.exp(logits) * (
            1 - torch.eye(batch_size).to(device)
        )  # Exclude diagonal

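        # Restrict the denominator to pairs from the same case_matched group.
        # clamp_min guards against log(0) for samples with no in-batch group
        # partner; those rows are dropped below because they have no positives.
        group_mask = positive_mask + negative_mask
        denom = (exp_logits * group_mask).sum(dim=1, keepdim=True)
        log_prob = logits - torch.log(denom.clamp_min(torch.finfo(denom.dtype).tiny))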

        # Compute mean of log-likelihood over positive pairs
        mean_log_prob_pos = (positive_mask * log_prob).sum(1) / (
            positive_mask.sum(1) + 1e-8
        )

        # Loss is negative log-likelihood
        loss = -mean_log_prob_pos
        loss = loss[
            positive_mask.sum(1) > 0
        ].mean()  # Only compute loss for samples with positive pairs

        return loss


# Initialize model - works for ANY number of genes!
gene_dim = GENE_DIM  # From the embeddings we extracted
n_genes_model = n_genes  # Number of genes
print(f"Gene dimension: {gene_dim}")
print(f"Number of genes: {n_genes_model}")
print(
    f"Building genewise model with gene_dim={gene_dim}, hidden_dim=256, output_dim=128"
)

model = GenewiseContrastiveProjector(
    gene_dim=gene_dim,
    hidden_dim=256,
    output_dim=128,
    pooling="mean",  # Options: 'mean', 'max', 'attention'
).to(device)

criterion = ContrastiveLoss(temperature=0.07)

print(f"\nModel architecture:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"\nThis model works for ANY number of genes!")

In [None]:
# Note: this PCA is on the FULL dataset, before the later filtering step.
# It is for initial exploration only; the final PCA comparison is recomputed after training.

# Perform PCA on original embeddings (full dataset, before filtering)
pca_original_full = PCA(n_components=2)
embeddings_flat_full = embeddings.reshape(embeddings.shape[0], -1)
pca_coords_original_full = pca_original_full.fit_transform(embeddings_flat_full)

print("PCA on full dataset (before filtering):")
print(f"  Explained variance ratio: {pca_original_full.explained_variance_ratio_}")
print(
    f"  Total variance explained: {pca_original_full.explained_variance_ratio_.sum():.3f}"
)
print("\nNote: the final PCA for visualization is computed after training")

In [None]:
# Identify confounder columns (UPDATE THESE based on your data)
# Examples: ancestry, batch, read_depth, sequencing_platform, etc.
confounder_cols = [
    col
    for col in df.columns
    if col
    in [
        "has_t1d",
        "is_european",
        "is_gt_65_years",
        "vcf_assay_version",
        "is_female",
        "is_bmi_gt_30",  # Add your actual confounder column names
    ]
]

print(f"Confounder columns found: {confounder_cols}")
if len(confounder_cols) == 0:
    print(
        "WARNING: No confounder columns found. Please update the confounder_cols list above."
    )
    # Create dummy example for demonstration
    confounder_cols = [LABEL_COL]

In [None]:
# Create multi-plot visualization of PC1 and PC2 vs confounders
show_plt = True

if show_plt:

    from matplotlib.colors import ListedColormap

    n_confounders = len(confounder_cols) + 1  # +1 for the main is_positive plot
    n_cols = min(3, n_confounders)
    n_rows = (n_confounders + n_cols - 1) // n_cols

    # Use more contrasting colormaps
    contrast_cmap = ListedColormap(
        [
            "#e41a1c",
            "#ffff33",
            "#377eb8",
        ]
    )  # "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"])
    contrast_cmap_cont = "plasma"  # For continuous variables

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows))
    axes = axes.flatten() if n_confounders > 1 else [axes]

    # Plot PC1 vs PC2 colored by is_positive (main outcome)
    scatter = axes[0].scatter(
        df["PC1_original"],
        df["PC2_original"],
        c=df[LABEL_COL],
        cmap=contrast_cmap,
        alpha=0.7,
        s=30,
        edgecolor="k",
        linewidth=0.5,
    )
    axes[0].set_xlabel("PC1")
    axes[0].set_ylabel("PC2")
    axes[0].set_title("PC1 vs PC2 (Original) - Colored by is_positive")
    axes[0].legend(*scatter.legend_elements(), title="is_positive")

    # Plot PC1 vs PC2 colored by each confounder
    for idx, confounder in enumerate(confounder_cols, start=1):
        if idx >= len(axes):
            break

        # Check if confounder is categorical or continuous
        if df[confounder].dtype == "object" or df[confounder].nunique() < 10:
            # Categorical - use discrete, high-contrast colors
            scatter = axes[idx].scatter(
                df["PC1_original"],
                df["PC2_original"],
                c=pd.Categorical(df[confounder]).codes,
                cmap=contrast_cmap,
                alpha=0.25,
                s=30,
                edgecolor="k",
                linewidth=0.5,
            )
            axes[idx].legend(
                *scatter.legend_elements(), title=confounder, loc="best", fontsize=8
            )
        else:
            # Continuous - use a more vibrant colormap
            scatter = axes[idx].scatter(
                df["PC1_original"],
                df["PC2_original"],
                c=df[confounder],
                cmap=contrast_cmap_cont,
                alpha=0.25,
                s=30,
                edgecolor="k",
                linewidth=0.5,
            )
            plt.colorbar(scatter, ax=axes[idx], label=confounder)

        axes[idx].set_xlabel("PC1")
        axes[idx].set_ylabel("PC2")
        axes[idx].set_title(f"PC1 vs PC2 (Original) - Colored by {confounder}")

    # Hide unused subplots
    for idx in range(n_confounders, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()
    plt.savefig("pca_original_confounders.png", dpi=150, bbox_inches="tight")
    plt.show()

## 3. Train/Validation Split

**Critical**: We split by `case_matched` groups (not individual samples) to prevent data leakage.

If we split individuals at random, a case and its matched controls could land in different sets. The validation set would then contain samples whose confounder profile the model has already seen during training, which can inflate validation metrics.

Keeping each case-control group together means validation measures performance on matched groups the model has never seen.
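A minimal NumPy sketch of a group-level split with an explicit leakage check is shown below. The helper `split_by_group` and the toy group IDs are hypothetical; the notebook itself splits the unique group IDs with `train_test_split`.

```python
import numpy as np


def split_by_group(group_ids, val_fraction=0.2, seed=42):
    """Return boolean train/val masks that never split a group."""
    group_ids = np.asarray(group_ids)
    unique_groups = np.unique(group_ids)
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(unique_groups)
    n_val = max(1, int(round(val_fraction * len(shuffled))))
    val_groups = shuffled[:n_val]
    val_mask = np.isin(group_ids, val_groups)
    return ~val_mask, val_mask


# Ten toy groups of five samples each (one case plus four controls).
group_ids = np.repeat(np.arange(10), 5)
train_mask, val_mask = split_by_group(group_ids)

# No group may contribute samples to both sides of the split.
overlap = set(group_ids[train_mask]) & set(group_ids[val_mask])
print(len(overlap), train_mask.sum(), val_mask.sum())
```

scikit-learn's `GroupShuffleSplit` offers the same guarantee and could be used instead of a hand-written helper.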

In [None]:
train_only_mask = df["split"] == "train"
# Copy so the .loc assignment below modifies this frame rather than a view of the original
df = df[train_only_mask].copy()

missing_case_mask = df[MATCH_CASE_COL].isna()
df.loc[missing_case_mask, MATCH_CASE_COL] = df.loc[missing_case_mask, "sample_id"]

In [None]:
# Recreate train/val splits by MATCH_CASE_COL groups
# (train_test_split is already imported at the top of the notebook)
unique_groups = df[MATCH_CASE_COL].unique()

train_groups, val_groups = train_test_split(
    unique_groups, test_size=0.2, random_state=42
)
train_mask = df[MATCH_CASE_COL].isin(train_groups)
val_mask = df[MATCH_CASE_COL].isin(val_groups)
df_train = df[train_mask].copy()
df_val = df[val_mask].copy()
print(f"Train set: {len(df_train)} samples ({df_train[LABEL_COL].sum()} positive)")
print(f"Val set: {len(df_val)} samples ({df_val[LABEL_COL].sum()} positive)")

In [None]:
def extract_embeddings(df, EMBEDDING_COLS):
    """
    Extract embeddings from dataframe where each gene column contains numpy arrays.
    Returns a 3D array of shape (n_samples, gene_dim, n_genes).
    """
    n_samples = len(df)
    n_genes = len(EMBEDDING_COLS)

    # Get gene_dim from first array
    gene_dim = df[EMBEDDING_COLS[0]].iloc[0].shape[0]

    # Create 3D tensor: (n_samples, gene_dim, n_genes)
    embeddings = np.zeros((n_samples, gene_dim, n_genes))

    for gene_idx, gene_col in enumerate(EMBEDDING_COLS):
        # Stack all arrays for this gene
        embeddings[:, :, gene_idx] = np.vstack(df[gene_col].values)

    return embeddings.astype(np.float32)


X_train = extract_embeddings(df_train, EMBEDDING_COLS)
y_train = df_train[LABEL_COL].values
case_matched_train = pd.factorize(df_train[MATCH_CASE_COL])[0]

X_val = extract_embeddings(df_val, EMBEDDING_COLS)
y_val = df_val[LABEL_COL].values
case_matched_val = pd.factorize(df_val[MATCH_CASE_COL])[0]

print(f"X_train shape: {X_train.shape} (n_samples, gene_dim, n_genes)")
print(f"X_val shape: {X_val.shape}")
print(f"Ready for genewise architecture!")

In [None]:
# Define the EmbeddingDataset class
class EmbeddingDataset(Dataset):
    """Dataset for pre-computed embeddings with labels and case_matched groups"""

    def __init__(self, embeddings, labels, case_matched):
        # embeddings shape: (n_samples, gene_dim, n_genes)
        self.embeddings = torch.FloatTensor(embeddings)
        self.labels = torch.LongTensor(labels)
        self.case_matched = torch.LongTensor(case_matched)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx], self.case_matched[idx]


# Check if each factor in case_matched_train has exactly 5 samples
import collections

factor_counts = collections.Counter(case_matched_train)
print("Value counts for each factor in case_matched_train:")
print(f"Total unique groups: {len(factor_counts)}")
print(f"Min samples per group: {min(factor_counts.values())}")
print(f"Max samples per group: {max(factor_counts.values())}")
all_five = all(count == 5 for count in factor_counts.values())
print(f"All factors have exactly 5 samples? {all_five}")

# Create datasets and dataloaders for genewise architecture
print(f"\nCreating datasets for genewise architecture...")
train_dataset = EmbeddingDataset(X_train, y_train, case_matched_train)
val_dataset = EmbeddingDataset(X_val, y_val, case_matched_val)

# Create data loaders
batch_size = 64
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"\nDataLoaders created successfully!")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(
    f"  Batch shape: (batch_size={batch_size}, gene_dim={X_train.shape[1]}, n_genes={X_train.shape[2]})"
)

## 4. PyTorch Contrastive Learning Architecture

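### GenewiseContrastiveProjector
The projector defined above uses three linear layers in total (one shared per-gene layer, then two after pooling) to map the input embeddings to a lower-dimensional space (128D) optimized for contrastive learning: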
- **BatchNorm**: Stabilizes training and prevents internal covariate shift
- **Dropout**: Prevents overfitting to specific confounder patterns
- **L2 Normalization**: Output embeddings lie on unit hypersphere, making cosine similarity meaningful

In [None]:
# Model and loss already defined in cell-11 (GenewiseContrastiveProjector)
# This cell is skipped to avoid overwriting the genewise model
print("Using GenewiseContrastiveProjector model from cell-11")
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

## 5. Dataset and DataLoader

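PyTorch dataset wrapper for the embeddings. Each item is a tuple of:
1. Pre-computed embeddings (input features)
2. Labels (`is_positive`)
3. Case-match group IDs (for contrastive loss computation)

The default collation stacks these into batch tensors; no custom `collate_fn` is used.

The DataLoader shuffles training data with a batch size of 64. Because the contrastive loss only scores pairs from the same `case_matched` group, a batch contributes to the loss only when it contains at least two same-label samples from one group.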
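One way to guarantee within-group pairs in every batch is to build batches from whole groups. The sketch below is an alternative to the shuffled loader, not what the notebook does; `group_batches` and its parameters are hypothetical names.

```python
import numpy as np


def group_batches(group_ids, groups_per_batch, seed=0):
    """Return a list of index lists; each batch contains whole groups only."""
    group_ids = np.asarray(group_ids)
    rng = np.random.default_rng(seed)
    unique_groups = rng.permutation(np.unique(group_ids))
    batches = []
    for start in range(0, len(unique_groups), groups_per_batch):
        chosen = unique_groups[start : start + groups_per_batch]
        indices = np.flatnonzero(np.isin(group_ids, chosen))
        batches.append(indices.tolist())
    return batches


# Six toy groups of five samples, packed three groups (15 samples) per batch.
group_ids = np.repeat(np.arange(6), 5)
batches = group_batches(group_ids, groups_per_batch=3)
print([len(b) for b in batches])  # [15, 15]
```

A list like this can be passed to `DataLoader(train_dataset, batch_sampler=batches)` in place of `batch_size`, `shuffle`, and `drop_last`. Since the list is fixed, it would need to be regenerated with a new seed each epoch to vary the batch composition.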

## 6. Monitoring Functions

We track two complementary metrics during training:

### 1. Calinski-Harabasz Index (Geometric)
Ratio of between-cluster variance to within-cluster variance. Higher values indicate better-defined, more separated clusters.

### 2. Logistic Regression AUC (Discriminative)
Train a simple linear classifier on the embeddings and evaluate on validation set.
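For intuition about the first metric, the Calinski-Harabasz index can be computed by hand on a tiny example. The notebook uses scikit-learn's `calinski_harabasz_score`; the function `ch_index` below is a hand-written sketch of the same definition, with toy data chosen so the arithmetic is easy to follow.

```python
import numpy as np


def ch_index(X, labels):
    """Calinski-Harabasz index: [B / (k - 1)] / [W / (n - k)]."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    n, k = len(X), len(classes)
    overall_mean = X.mean(axis=0)
    between = 0.0  # B: size-weighted squared distance of cluster means to the overall mean
    within = 0.0   # W: squared distance of points to their own cluster mean
    for c in classes:
        members = X[labels == c]
        center = members.mean(axis=0)
        between += len(members) * np.sum((center - overall_mean) ** 2)
        within += np.sum((members - center) ** 2)
    return (between / (k - 1)) / (within / (n - k))


# Two tight clusters far apart: B = 100, W = 1, so CH = (100 / 1) / (1 / 2).
X = [[0.0], [1.0], [10.0], [11.0]]
labels = [0, 0, 1, 1]
print(ch_index(X, labels))  # 200.0
```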

In [None]:
def compute_cluster_separation(embeddings, labels):
    """
    Compute Calinski-Harabasz Index (Variance Ratio Criterion).
    Measures the ratio of between-cluster to within-cluster variance.
    Higher values indicate better-defined clusters.

    Returns:
        float: CH index score (higher is better, range [0, inf))
    """
    if len(np.unique(labels)) < 2:
        return 0.0

    return calinski_harabasz_score(embeddings, labels)


def compute_logistic_auc(X_train, y_train, X_val, y_val):
    """
    Train a logistic regression classifier and return validation AUC.
    """
    clf = LogisticRegression(max_iter=1000, random_state=42)
    clf.fit(X_train, y_train)

    # Predict probabilities
    y_pred_proba = clf.predict_proba(X_val)[:, 1]
    auc = roc_auc_score(y_val, y_pred_proba)

    return auc


def evaluate_embeddings(model, train_loader, val_loader, device):
    """
    Evaluate the quality of learned embeddings using cluster separation and AUC.
    """
    model.eval()

    # Get all train embeddings
    train_embeddings_list = []
    train_labels_list = []
    with torch.no_grad():
        for embeddings, labels, _ in train_loader:
            embeddings = embeddings.to(device)
            projected = model(embeddings)
            train_embeddings_list.append(projected.cpu().numpy())
            train_labels_list.append(labels.numpy())

    train_embeddings = np.vstack(train_embeddings_list)
    train_labels = np.concatenate(train_labels_list)

    # Get all val embeddings
    val_embeddings_list = []
    val_labels_list = []
    with torch.no_grad():
        for embeddings, labels, _ in val_loader:
            embeddings = embeddings.to(device)
            projected = model(embeddings)
            val_embeddings_list.append(projected.cpu().numpy())
            val_labels_list.append(labels.numpy())

    val_embeddings = np.vstack(val_embeddings_list)
    val_labels = np.concatenate(val_labels_list)

    # Compute metrics
    train_separation = compute_cluster_separation(train_embeddings, train_labels)
    val_separation = compute_cluster_separation(val_embeddings, val_labels)
    auc = compute_logistic_auc(
        train_embeddings, train_labels, val_embeddings, val_labels
    )

    return {
        "train_separation": train_separation,
        "val_separation": val_separation,
        "val_auc": auc,
        "train_embeddings": train_embeddings,
        "train_labels": train_labels,
        "val_embeddings": val_embeddings,
        "val_labels": val_labels,
    }


print("Monitoring functions defined successfully!")

## 7. Training Loop

Train the projection network with:
- **AdamW optimizer**: Decoupled weight decay for better generalization
- **Cosine annealing schedule**: Gradually reduces learning rate for fine-tuning
- **Evaluation at epoch 1 and then every 5 epochs**: Track cluster separation and AUC without slowing training
- **Best model selection**: Keep the weights with the highest validation AUC among the evaluated epochs

**Note**: Some batches may have NaN loss if they don't contain valid positive pairs. These are safely skipped.
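To see what the cosine schedule does to the learning rate, the closed form can be evaluated directly. This is a sketch of the standard cosine annealing formula with a minimum rate of 0 (the `CosineAnnealingLR` default), using the `learning_rate=1e-3` and `num_epochs=50` values from the cell below; the function name `cosine_lr` is hypothetical.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, t_max=50, eta_min=0.0):
    """Closed-form cosine annealing: learning rate after `epoch` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Start, midpoint, and end of the schedule.
for epoch in (0, 25, 50):
    print(epoch, cosine_lr(epoch))
```

The rate starts at `base_lr`, reaches half of it at the midpoint, and approaches zero at `t_max`, so the final epochs make only small updates.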

In [None]:
# Training hyperparameters
num_epochs = 50
learning_rate = 1e-3
weight_decay = 1e-5

# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Training history
history = {
    "train_loss": [],
    "val_loss": [],
    "train_separation": [],
    "val_separation": [],
    "val_auc": [],
}

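print(f"Training for {num_epochs} epochs...")
print("Initial evaluation before training:")
# Evaluate the untrained projector so later metrics have a baseline
initial_metrics = evaluate_embeddings(model, train_loader, val_loader, device)
print(
    f"  Val Sep: {initial_metrics['val_separation']:.4f} | Val AUC: {initial_metrics['val_auc']:.4f}"
)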

In [None]:
# Main training loop
best_val_auc = 0
best_model_state = None

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0
    train_batches = 0

    for embeddings, labels, case_matched in tqdm(
        train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"
    ):
        embeddings = embeddings.to(device)
        labels = labels.to(device)
        case_matched = case_matched.to(device)

        # Forward pass
        projected = model(embeddings)
        loss = criterion(projected, labels, case_matched)

        # Skip if loss is nan (can happen if batch has no valid pairs)
        if torch.isnan(loss):
            continue

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_batches += 1

    train_loss /= max(train_batches, 1)

    # Validation phase
    model.eval()
    val_loss = 0
    val_batches = 0

    with torch.no_grad():
        for embeddings, labels, case_matched in val_loader:
            embeddings = embeddings.to(device)
            labels = labels.to(device)
            case_matched = case_matched.to(device)

            projected = model(embeddings)
            loss = criterion(projected, labels, case_matched)

            if not torch.isnan(loss):
                val_loss += loss.item()
                val_batches += 1

    val_loss /= max(val_batches, 1)

    # Update learning rate
    scheduler.step()

    # Evaluate embeddings quality every 5 epochs
    if (epoch + 1) % 5 == 0 or epoch == 0:
        metrics = evaluate_embeddings(model, train_loader, val_loader, device)

        history["train_separation"].append(metrics["train_separation"])
        history["val_separation"].append(metrics["val_separation"])
        history["val_auc"].append(metrics["val_auc"])

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(
            f"  Train Sep: {metrics['train_separation']:.4f} | Val Sep: {metrics['val_separation']:.4f}"
        )
        print(f"  Val AUC: {metrics['val_auc']:.4f}")
        print()

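        # Save best model. Clone the tensors: state_dict() returns references to
        # the live parameters, so a shallow .copy() would keep changing with training.
        if metrics["val_auc"] > best_val_auc:
            best_val_auc = metrics["val_auc"]
            best_model_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }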

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)

print(f"Training complete!")
print(f"Best validation AUC: {best_val_auc:.4f}")

In [None]:
# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model based on validation AUC")

## 8. Training History Visualization

Visualize three key metrics over training:

1. **Contrastive Loss**: Should decrease and stabilize
2. **Calinski-Harabasz Index**: Should increase as clusters become better separated
3. **Validation AUC**: Should improve, indicating better downstream utility

If the training CH index keeps increasing while validation AUC plateaus (or the validation CH index stalls), the model may be overfitting to patterns specific to the training set.
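
For reference, the Calinski-Harabasz index is the ratio of between-class to within-class dispersion, each divided by its degrees of freedom. The cell below is a minimal NumPy sketch of that standard definition on synthetic data; it is not the notebook's `compute_cluster_separation`, and the function name `calinski_harabasz` and the toy arrays are illustrative assumptions. scikit-learn offers an equivalent in `sklearn.metrics.calinski_harabasz_score`.

```python
import numpy as np


def calinski_harabasz(X, labels):
    """Between-class over within-class dispersion, each scaled by its degrees of freedom."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    n_samples, n_classes = X.shape[0], len(classes)
    overall_mean = X.mean(axis=0)

    between, within = 0.0, 0.0
    for c in classes:
        members = X[labels == c]
        centroid = members.mean(axis=0)
        # Squared distance of the class centroid from the global mean, weighted by class size
        between += len(members) * np.sum((centroid - overall_mean) ** 2)
        # Squared distances of the class members from their own centroid
        within += np.sum((members - centroid) ** 2)

    return (between / (n_classes - 1)) / (within / (n_samples - n_classes))


# Hand-checkable case: centroids at 1 and 11, overall mean 6
# between = 2*25 + 2*25 = 100, within = 4, so the index is (100/1) / (4/2) = 50
toy_score = calinski_harabasz([[0.0], [2.0], [10.0], [12.0]], [0, 0, 1, 1])

# Synthetic stand-ins for embeddings (illustrative data only)
rng = np.random.default_rng(0)
y = np.repeat([0, 1], 200)
overlapping = rng.normal(size=(400, 8))
separated = overlapping + 3.0 * y[:, None]  # shift class 1 away from class 0

print(f"toy:         {toy_score:.1f}")
print(f"overlapping: {calinski_harabasz(overlapping, y):.2f}")
print(f"separated:   {calinski_harabasz(separated, y):.2f}")
```

Because the index has no upper bound and depends on sample size and dimensionality, it is most useful for comparing the same samples across epochs rather than across datasets.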

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curves
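# Use 1-indexed epochs so the x-axis matches the CH index and AUC panels
loss_epochs = range(1, len(history["train_loss"]) + 1)
axes[0].plot(loss_epochs, history["train_loss"], label="Train Loss", linewidth=2)
axes[0].plot(loss_epochs, history["val_loss"], label="Val Loss", linewidth=2)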
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Contrastive Loss")
axes[0].set_title("Training and Validation Loss")
axes[0].legend()
axes[0].grid(alpha=0.3)

# Calinski-Harabasz Index
# Evaluation runs at epoch 1 and then at every 5th epoch (5, 10, 15, ..., up to num_epochs)
eval_epochs = [1] + list(range(5, num_epochs + 1, 5))
print(f"Eval epochs: {eval_epochs}")
print(f"Number of evals: {len(eval_epochs)}")
print(f"History length: {len(history['train_separation'])}")

axes[1].plot(
    eval_epochs[: len(history["train_separation"])],
    history["train_separation"],
    marker="o",
    label="Train CH Index",
    linewidth=2,
)
axes[1].plot(
    eval_epochs[: len(history["val_separation"])],
    history["val_separation"],
    marker="o",
    label="Val CH Index",
    linewidth=2,
)
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Calinski-Harabasz Index")
axes[1].set_title("Calinski-Harabasz Index (Higher is Better)")
axes[1].legend()
axes[1].grid(alpha=0.3)

# AUC
axes[2].plot(
    eval_epochs[: len(history["val_auc"])],
    history["val_auc"],
    marker="o",
    color="green",
    label="Val AUC",  # without a label the legend would list only the baseline
    linewidth=2,
)
axes[2].axhline(y=0.5, color="red", linestyle="--", label="Random Baseline", alpha=0.5)
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Validation AUC")
axes[2].set_title("Logistic Regression AUC on Learned Embeddings")
axes[2].legend()
axes[2].grid(alpha=0.3)
axes[2].set_ylim([0.4, 1.0])

plt.tight_layout()
plt.savefig("training_history.png", dpi=150, bbox_inches="tight")
plt.show()

## 9. Post-Training Embedding Analysis

Generate the final debiased embeddings by passing all samples through the trained projection network, then compare to the original embeddings.

### What to Look For:

**In the is_positive plots:**
- ✓ Clearer separation between cases (red) and controls (blue)
- ✓ Tighter within-class clusters

**In the confounder plots:**
- ✓ Reduced correlation with confounders (more mixed colors)
- ✓ Random scatter rather than clear gradients

If confounders still show strong patterns, consider:
- Longer training
- Stronger weight decay
- Explicit adversarial debiasing
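
Visual inspection of the PCA plots can be complemented with a number. The sketch below, under stated assumptions, projects embeddings onto their top two principal components and reports the R² of a linear fit of a continuous confounder on those components; a value near 0 suggests little linear confounder signal in the leading PCs. The helper names `top_pcs` and `confounder_r2` and the synthetic data are hypothetical, and the check covers only continuous confounders and only linear association.

```python
import numpy as np


def top_pcs(X, n_components=2):
    """Project centered data onto its leading principal components via SVD."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T


def confounder_r2(pcs, confounder):
    """In-sample R^2 of an ordinary least squares fit: confounder ~ intercept + PCs."""
    design = np.column_stack([np.ones(len(pcs)), pcs])
    coef, *_ = np.linalg.lstsq(design, confounder, rcond=None)
    residuals = confounder - design @ coef
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((confounder - confounder.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


# Synthetic illustration (not real genomic data)
rng = np.random.default_rng(0)
n_samples, n_dims = 500, 16
confounder = rng.normal(size=n_samples)          # e.g. a read-depth-like covariate
noise = rng.normal(size=(n_samples, n_dims))

confounded = noise + confounder[:, None]         # confounder leaks into every dimension
clean = noise                                    # no confounder signal at all

r2_confounded = confounder_r2(top_pcs(confounded), confounder)
r2_clean = confounder_r2(top_pcs(clean), confounder)

print(f"R^2 with leaked confounder: {r2_confounded:.3f}")
print(f"R^2 without confounder:     {r2_clean:.3f}")
```

In the notebook, the same comparison could be run once on the flattened original embeddings and once on the trained embeddings for each continuous column in `confounder_cols`.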

In [None]:
# Generate post-training embeddings for all data (current filtered df)
model.eval()

# Re-extract embeddings from current df state (after filtering in cell-18)
print("Extracting embeddings from current df (after filtering)...")
print(f"Current df shape: {df.shape}")

current_embeddings = extract_embeddings(df, EMBEDDING_COLS)
print(
    f"Extracted embeddings shape: {current_embeddings.shape} (n_samples, gene_dim, n_genes)"
)

# Convert to a CPU tensor; only one batch at a time is moved to the device,
# so the full dataset never has to fit in device memory
all_embeddings = torch.as_tensor(current_embeddings, dtype=torch.float32)

with torch.no_grad():
    # Process in batches to avoid memory issues
    batch_size_inference = 256
    trained_embeddings_list = []

    for i in range(0, len(all_embeddings), batch_size_inference):
        batch = all_embeddings[i : i + batch_size_inference].to(device)
        projected = model(batch)
        trained_embeddings_list.append(projected.cpu().numpy())

    trained_embeddings = np.vstack(trained_embeddings_list)

print(f"\nGenerated trained embeddings: {trained_embeddings.shape}")
print(
    f"Reduced from {current_embeddings.shape[1] * current_embeddings.shape[2]} dimensions to {trained_embeddings.shape[1]} dimensions"
)

In [None]:
# Perform PCA on both current original embeddings and trained embeddings
# Flatten current_embeddings for PCA
current_embeddings_flat = current_embeddings.reshape(current_embeddings.shape[0], -1)

# PCA on original (current filtered data)
pca_original_current = PCA(n_components=2)
pca_coords_original = pca_original_current.fit_transform(current_embeddings_flat)

# PCA on trained embeddings
pca_trained = PCA(n_components=2)
pca_coords_trained = pca_trained.fit_transform(trained_embeddings)

# Add to dataframe
df["PC1_original"] = pca_coords_original[:, 0]
df["PC2_original"] = pca_coords_original[:, 1]
df["PC1_trained"] = pca_coords_trained[:, 0]
df["PC2_trained"] = pca_coords_trained[:, 1]

print(
    f"Original embeddings - Explained variance ratio: {pca_original_current.explained_variance_ratio_}"
)
print(
    f"Original embeddings - Total variance explained: {pca_original_current.explained_variance_ratio_.sum():.3f}"
)
print(
    f"\nTrained embeddings - Explained variance ratio: {pca_trained.explained_variance_ratio_}"
)
print(
    f"Trained embeddings - Total variance explained: {pca_trained.explained_variance_ratio_.sum():.3f}"
)

In [None]:
# Multi-plot comparison: Original vs Trained for each confounder
n_confounders = len(confounder_cols) + 1  # +1 for is_positive
n_cols = min(3, n_confounders)
n_rows = (n_confounders + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows * 2, n_cols, figsize=(6 * n_cols, 5 * n_rows * 2))
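# The grid always has at least two rows, so subplots() returns an array;
# the old `[axes]` fallback would have wrapped that array in a list
axes = axes.flatten()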

plot_idx = 0

# Plot is_positive first
for emb_type, pc1_col, pc2_col, title_suffix in [
    ("original", "PC1_original", "PC2_original", "Original"),
    ("trained", "PC1_trained", "PC2_trained", "Trained"),
]:
    scatter = axes[plot_idx].scatter(
        df[pc1_col], df[pc2_col], c=df[LABEL_COL], cmap="coolwarm", alpha=0.6, s=20
    )
    axes[plot_idx].set_xlabel("PC1")
    axes[plot_idx].set_ylabel("PC2")
    axes[plot_idx].set_title(f"PC1 vs PC2 ({title_suffix}) - is_positive")
    axes[plot_idx].legend(*scatter.legend_elements(), title="is_positive", fontsize=8)
    plot_idx += 1

# Plot each confounder
for confounder in confounder_cols:
    for emb_type, pc1_col, pc2_col, title_suffix in [
        ("original", "PC1_original", "PC2_original", "Original"),
        ("trained", "PC1_trained", "PC2_trained", "Trained"),
    ]:
        if plot_idx >= len(axes):
            break

        # Check if confounder is categorical or continuous
        if df[confounder].dtype == "object" or df[confounder].nunique() < 10:
            # Categorical
            scatter = axes[plot_idx].scatter(
                df[pc1_col],
                df[pc2_col],
                c=pd.Categorical(df[confounder]).codes,
                cmap="tab10",
                alpha=0.6,
                s=20,
            )
            axes[plot_idx].legend(
                *scatter.legend_elements(), title=confounder, loc="best", fontsize=8
            )
        else:
            # Continuous
            scatter = axes[plot_idx].scatter(
                df[pc1_col],
                df[pc2_col],
                c=df[confounder],
                cmap="viridis",
                alpha=0.6,
                s=20,
            )
            plt.colorbar(scatter, ax=axes[plot_idx], label=confounder)

        axes[plot_idx].set_xlabel("PC1")
        axes[plot_idx].set_ylabel("PC2")
        axes[plot_idx].set_title(f"PC1 vs PC2 ({title_suffix}) - {confounder}")
        plot_idx += 1

# Hide unused subplots
for idx in range(plot_idx, len(axes)):
    axes[idx].axis("off")

plt.tight_layout()
plt.savefig("pca_comparison_all_confounders.png", dpi=150, bbox_inches="tight")
plt.show()

print("\nVisualization complete! Check the saved PNG files for detailed comparisons.")

## 10. Summary and Export

### Final Metrics
Quantify the improvement achieved by contrastive learning:
- **Cluster separation improvement**: How much better are cases/controls separated?
- **AUC improvement**: Is the embedding more useful for downstream prediction?

### Saving Results
Optionally save:
1. **Trained embeddings**: For use in downstream analyses (GWAS, prediction models, etc.)
2. **Model weights**: To apply the same transformation to new samples

### Next Steps
With the debiased embeddings, you can:
- Run association studies with reduced confounding
- Train fairer predictive models
- Perform clustering or dimensionality reduction with less technical artifact
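
One possible way to export the trained embeddings is a single compressed `.npz` file that keeps sample identifiers and labels alongside the vectors, so rows cannot drift out of alignment downstream. The sketch below assumes hypothetical helper names (`save_embeddings`, `load_embeddings`), a hypothetical file name, and toy arrays; in the notebook the inputs would be `trained_embeddings` and the matching columns of `df`.

```python
import os
import tempfile

import numpy as np


def save_embeddings(path, embeddings, sample_ids, labels):
    """Write embeddings plus aligned sample IDs and labels to one compressed file."""
    embeddings = np.asarray(embeddings, dtype=np.float32)
    sample_ids = np.asarray(sample_ids)
    labels = np.asarray(labels)
    if not (len(embeddings) == len(sample_ids) == len(labels)):
        raise ValueError("embeddings, sample_ids and labels must have the same length")
    np.savez_compressed(
        path, embeddings=embeddings, sample_ids=sample_ids, labels=labels
    )


def load_embeddings(path):
    """Read back the three aligned arrays written by save_embeddings."""
    with np.load(path) as data:
        return data["embeddings"], data["sample_ids"], data["labels"]


# Toy round trip in a temporary directory (illustrative data only)
rng = np.random.default_rng(0)
demo_embeddings = rng.normal(size=(5, 4)).astype(np.float32)
demo_ids = np.array([f"sample_{i}" for i in range(5)])
demo_labels = np.array([1, 0, 0, 0, 0])

with tempfile.TemporaryDirectory() as tmp_dir:
    out_path = os.path.join(tmp_dir, "trained_embeddings.npz")
    save_embeddings(out_path, demo_embeddings, demo_ids, demo_labels)
    loaded_embeddings, loaded_ids, loaded_labels = load_embeddings(out_path)

print(loaded_embeddings.shape, loaded_ids[:2], loaded_labels)
```

For the model weights, the usual PyTorch pattern is `torch.save(model.state_dict(), path)` followed later by `model.load_state_dict(torch.load(path))` on a network built with the same architecture.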

In [None]:
# Final evaluation on all data
print("=" * 60)
print("FINAL SUMMARY")
print("=" * 60)
print(f"\nBest validation AUC achieved: {best_val_auc:.4f}")
print(f"\nOriginal embeddings:")
print(f"  Shape: {current_embeddings.shape} (n_samples, gene_dim, n_genes)")
print(f"  Total features: {current_embeddings.shape[1] * current_embeddings.shape[2]}")
print(
    f"  PCA variance explained (PC1+PC2): {pca_original_current.explained_variance_ratio_.sum():.3f}"
)

print(f"\nTrained embeddings:")
print(f"  Shape: {trained_embeddings.shape}")
print(
    f"  Dimension reduction: {current_embeddings.shape[1] * current_embeddings.shape[2]} → {trained_embeddings.shape[1]}"
)
print(
    f"  PCA variance explained (PC1+PC2): {pca_trained.explained_variance_ratio_.sum():.3f}"
)

# Compute final metrics on full dataset
# Flatten current embeddings for metrics
current_embeddings_flat = current_embeddings.reshape(current_embeddings.shape[0], -1)
final_ch_original = compute_cluster_separation(
    current_embeddings_flat, df[LABEL_COL].values
)
final_ch_trained = compute_cluster_separation(trained_embeddings, df[LABEL_COL].values)

print(f"\nCalinski-Harabasz Index (higher is better):")
print(f"  Original embeddings: {final_ch_original:.2f}")
print(f"  Trained embeddings: {final_ch_trained:.2f}")
if final_ch_original > 0:
    print(
        f"  Improvement: {((final_ch_trained - final_ch_original) / final_ch_original * 100):.2f}%"
    )
else:
    print(f"  Improvement: N/A (original CH index was 0)")

print("\n" + "=" * 60)
print("Files saved:")
print("  - pca_original_confounders.png")
print("  - training_history.png")
print("  - pca_comparison_is_positive.png")
print("  - pca_comparison_all_confounders.png")
print("=" * 60)
print("\n✨ Genewise architecture works with ANY number of genes!")