In [None]:
import anndata as ad
from anndata.experimental.pytorch import AnnLoader
import torch
import torch.optim
import torch.nn as nn
import numpy as np
import time
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
import os

# Display the working directory: relative output paths used later (the saved
# checkpoint files) are resolved against it.
working_dir = os.getcwd()
working_dir

# Import Data

In [None]:
import os  # local import keeps this cell self-contained

# Data directory is overridable via SCISOPRED_DATA_DIR instead of being hardcoded inline.
DATA_DIR = os.environ.get("SCISOPRED_DATA_DIR", "/work3/s193518/scIsoPred/data")
GENES_PATH = os.path.join(DATA_DIR, "bulk_processed_genes.h5ad")

# perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
t0 = time.perf_counter()
genes_df = ad.read_h5ad(GENES_PATH)
t1 = time.perf_counter()

print('loaded gene df in:', f'{t1-t0:.2f}', 'seconds')

In [None]:
import os  # local import keeps this cell self-contained

# Data directory is overridable via SCISOPRED_DATA_DIR instead of being hardcoded inline.
DATA_DIR = os.environ.get("SCISOPRED_DATA_DIR", "/work3/s193518/scIsoPred/data")
ISOFORMS_PATH = os.path.join(DATA_DIR, "bulk_processed_transcripts.h5ad")

# perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
t0 = time.perf_counter()
isoform_df = ad.read_h5ad(ISOFORMS_PATH)
t1 = time.perf_counter()

print('loaded isoform df in:', f'{t1-t0:.2f}', 'seconds')

In [None]:
# inspecting dimensions
print('number of genes =', genes_df.n_vars)
print('number of isoforms =', isoform_df.n_vars)
print('number of samples =', genes_df.n_obs)
print('proportion =', f'1 gene : {isoform_df.n_vars/genes_df.n_vars:.1f} isoforms')

# Later cells index the gene and isoform matrices with the SAME row indices (train/test
# split), so both AnnData objects must hold the same samples in the same order.
assert genes_df.n_obs == isoform_df.n_obs, "gene and isoform tables have different numbers of samples"
assert genes_df.obs_names.equals(isoform_df.obs_names), "gene and isoform samples differ or are in a different order"

## Filter Low-Expression Features
Remove genes and isoforms whose total count across all samples is below `MIN_COUNTS` (20).

In [None]:
# Convert to dense float32 tensors for filtering. AnnData `.X` may be a scipy sparse
# matrix (which has .toarray()) or already a dense numpy array, depending on how the
# .h5ad file was written; calling .toarray() unconditionally fails on dense input.
def _to_dense_tensor(matrix):
    """Return `matrix` (sparse or dense) as a dense float32 torch tensor."""
    dense = matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)
    return torch.from_numpy(dense).float()


genes_X = _to_dense_tensor(genes_df.X)
isoforms_Y = _to_dense_tensor(isoform_df.X)

print(f"Original shapes: genes {genes_X.shape}, isoforms {isoforms_Y.shape}")

In [None]:
# Filter genes / isoforms whose total count across all samples is < MIN_COUNTS
MIN_COUNTS = 20

# Guard against re-running this cell after the next cell has replaced genes_X /
# isoforms_Y with their filtered versions: the masks computed here would then no
# longer line up with genes_df.var / isoform_df.var.
assert genes_X.shape[1] == genes_df.n_vars and isoforms_Y.shape[1] == isoform_df.n_vars, (
    "genes_X / isoforms_Y are already filtered; re-run the dense-conversion cell first"
)

gene_sums = genes_X.sum(dim=0)
genes_to_keep = gene_sums >= MIN_COUNTS
genes_X_filtered = genes_X[:, genes_to_keep]

# .item() so the message shows a plain number rather than "tensor(N)"
print(f"Genes: {genes_X.shape[1]} -> {genes_X_filtered.shape[1]} (removed {(~genes_to_keep).sum().item()} low-expression genes)")

isoform_sums = isoforms_Y.sum(dim=0)
isoforms_to_keep = isoform_sums >= MIN_COUNTS
isoforms_Y_filtered = isoforms_Y[:, isoforms_to_keep]

print(f"Isoforms: {isoforms_Y.shape[1]} -> {isoforms_Y_filtered.shape[1]} (removed {(~isoforms_to_keep).sum().item()} low-expression isoforms)")

In [None]:
# Rebind the working matrices to their filtered versions.
# NOTE: this makes the filtering cell above non-idempotent; to re-filter,
# re-run from the dense-conversion cell.
genes_X, isoforms_Y = genes_X_filtered, isoforms_Y_filtered

# Feature metadata rows that correspond to the kept columns
genes_filtered_var = genes_df.var.loc[genes_to_keep.numpy()]
isoforms_filtered_var = isoform_df.var.loc[isoforms_to_keep.numpy()]

print(f"\nFinal dimensions: genes={genes_X.shape[1]}, isoforms={isoforms_Y.shape[1]}")

## Convert to Isoform Proportions per Gene

In [None]:
def convert_to_isoform_proportions(isoforms_Y, isoform_var_df):
    """
    Convert isoform counts to proportions per gene.
    For each gene (and each sample), isoform proportions sum to 1.

    Vectorized: per-gene totals are accumulated with a single `index_add_` instead of
    scanning the whole gene_id array once per gene (previously O(n_genes * n_isoforms)).

    Args:
        isoforms_Y: tensor of shape [n_samples, n_isoforms]
        isoform_var_df: DataFrame with 'gene_id' column mapping isoforms to genes;
            row i describes column i of isoforms_Y

    Returns:
        proportions: new tensor of shape [n_samples, n_isoforms]; each isoform count is
            divided by (total count of its gene's isoforms in that sample + 1e-8).
            Note: if all isoforms of a gene are 0 in a sample, their proportions are all 0
            (they do not sum to 1) for that sample.
    """
    gene_ids = np.asarray(isoform_var_df['gene_id'].values)
    # gene_index[i] = position of isoform i's gene within unique_genes
    unique_genes, gene_index = np.unique(gene_ids, return_inverse=True)

    print(f"Converting counts to proportions for {len(unique_genes)} unique genes...")

    gene_index = torch.as_tensor(gene_index.reshape(-1), dtype=torch.long, device=isoforms_Y.device)

    # Per-sample totals for every gene: column g sums all isoform columns whose gene index is g.
    gene_totals = torch.zeros(
        (isoforms_Y.shape[0], len(unique_genes)), dtype=isoforms_Y.dtype, device=isoforms_Y.device
    )
    gene_totals.index_add_(1, gene_index, isoforms_Y)

    # Broadcast each gene's total back onto its isoform columns, then divide.
    return isoforms_Y / (gene_totals[:, gene_index] + 1e-8)

In [None]:
# Isoform counts -> per-gene isoform proportions
isoforms_Y_proportions = convert_to_isoform_proportions(isoforms_Y, isoforms_filtered_var)

# Spot-check: the isoforms of the first listed gene should have proportions summing to ~1
iso_gene_ids = isoforms_filtered_var['gene_id'].values
sample_gene = iso_gene_ids[0]
sample_isoforms = iso_gene_ids == sample_gene
print(f"\nSample verification for gene {sample_gene}:")
for row in (0, 1):
    print(f"Sum of proportions (sample {row}): {isoforms_Y_proportions[row, sample_isoforms].sum():.4f}")

In [None]:
def create_gene_to_isoform_mapping(isoform_var_df):
    """
    Create a mapping from gene IDs to their isoform column indices.
    Used by the GeneAwareMLP class to apply a softmax per gene.

    Built in a single pass over the isoforms (instead of one full `np.where` scan per gene).

    Args:
        isoform_var_df: DataFrame with a 'gene_id' column; row i describes isoform column i.

    Returns:
        gene_to_isoform_map: dict gene_id -> ascending list of isoform indices (keys in sorted order)
        sorted_genes: list of gene IDs in sorted order
        isoform_to_position: dict mapping global isoform index -> (gene_id, position_in_gene)
    """
    gene_ids = isoform_var_df['gene_id'].values

    grouped = {}
    for iso_idx, gene_id in enumerate(gene_ids):
        grouped.setdefault(gene_id, []).append(iso_idx)

    # Sort genes for consistent ordering; the map's keys follow the same order.
    sorted_genes = sorted(grouped)
    gene_to_isoform_map = {gene_id: grouped[gene_id] for gene_id in sorted_genes}

    # Reverse mapping for reconstruction
    isoform_to_position = {}
    for gene_id, isoform_indices in gene_to_isoform_map.items():
        for pos, iso_idx in enumerate(isoform_indices):
            isoform_to_position[iso_idx] = (gene_id, pos)

    print(f"Created mapping for {len(sorted_genes)} genes")
    if sorted_genes:  # avoid ZeroDivisionError on an empty table
        print(f"Average isoforms per gene: {len(gene_ids) / len(sorted_genes):.2f}")

    return gene_to_isoform_map, sorted_genes, isoform_to_position


gene_to_isoform_map, sorted_gene_ids, isoform_to_position = create_gene_to_isoform_mapping(isoforms_filtered_var)

# A few examples of genes and how many isoforms each one has
print("\nExample genes and their isoforms:")
for gene_id in sorted_gene_ids[:3]:
    print(f"  {gene_id}: {len(gene_to_isoform_map[gene_id])} isoforms")

# Gene-to-Isoform Mapping (built above) for the Gene-Aware Model

## GenesDataset: log1p Transform and z-Score Normalization

In [None]:
class GenesDataset(Dataset):
    """Paired (gene expression, isoform target) samples with optional log1p + z-score preprocessing.

    Each item is a tuple of tensors shaped [1, n_genes] and [1, n_isoforms]; __getitem__ adds
    the leading singleton dimension, so a default DataLoader batch is [batch, 1, features].
    """

    def __init__(self, genes_X, isoforms_Y, log_transform=True, normalize=True, axis='feature',
                 transform_targets=True):
        """
        Arguments:
            genes_X (torch.Tensor): gene expression matrix [n_samples, n_genes]
            isoforms_Y (torch.Tensor): isoform abundance matrix [n_samples, n_isoforms]
            log_transform (bool): whether to apply log transform with pseudocount. Default = True
            normalize (bool): whether to apply z-score normalization. Default = True
            axis (str): 'feature' or 'sample' - axis along which to normalize
            transform_targets (bool): apply log_transform / normalize to isoforms_Y too.
                Set False when targets are already bounded quantities such as per-gene
                isoform proportions that a softmax output must match. Default = True

        Raises:
            ValueError: if normalize is True and axis is not 'feature' or 'sample'.
        """
        if normalize and axis not in ("feature", "sample"):
            raise ValueError("axis must be either 'feature' or 'sample'")

        # _preprocess clones, so the caller's tensors are never modified
        self.genes_X = self._preprocess(genes_X, log_transform, normalize, axis)
        if transform_targets:
            self.isoforms_Y = self._preprocess(isoforms_Y, log_transform, normalize, axis)
        else:
            self.isoforms_Y = isoforms_Y.clone()

    @staticmethod
    def _preprocess(matrix, log_transform, normalize, axis):
        """Clone `matrix`, then optionally log1p it and z-score it along `axis`."""
        out = matrix.clone()
        if log_transform:
            out = torch.log1p(out)
        if normalize:
            # 'feature': per gene/isoform across samples; 'sample': per sample across features
            dim = 0 if axis == "feature" else 1
            mean = out.mean(dim=dim, keepdim=True)
            std = out.std(dim=dim, keepdim=True) + 1e-8
            out = (out - mean) / std
        return out

    def __len__(self):
        return self.genes_X.size(0)

    def __getitem__(self, idx):
        return self.genes_X[idx].unsqueeze(0), self.isoforms_Y[idx].unsqueeze(0)

## Train/Test Split (Full Dataset)

In [None]:
# Random train/test split over the full dataset (no subsetting)
seed = 42
test_size = 0.1

# Seed numpy (used for the split) AND torch (model weight initialization, DataLoader
# shuffling) so a fresh "Run All" reproduces the same split and the same training run.
np.random.seed(seed)
torch.manual_seed(seed)

n_obs = genes_df.n_obs
indices = np.random.permutation(n_obs)

test_count = int(n_obs * test_size)
test_idx = indices[:test_count]
train_idx = indices[test_count:]

print(f"Train samples: {len(train_idx)}")
print(f"Test samples: {len(test_idx)}")

In [None]:
# Inputs: log1p + per-gene z-score, with mean/std estimated on the TRAINING split only
# and reused for the test split. This avoids leaking test statistics and keeps both
# splits on the same scale.
train_genes_log = torch.log1p(genes_X[train_idx])
test_genes_log = torch.log1p(genes_X[test_idx])
gene_mean = train_genes_log.mean(dim=0, keepdim=True)
gene_std = train_genes_log.std(dim=0, keepdim=True)
# Genes that are constant in the training split would otherwise divide test values by ~0.
gene_std = torch.where(gene_std > 1e-8, gene_std, torch.ones_like(gene_std))

# Targets: keep the raw per-gene isoform proportions in [0, 1]. The gene-aware model ends
# in a per-gene softmax, so log/z-scored targets would be unreachable for it.
train_dataset = GenesDataset(
    (train_genes_log - gene_mean) / gene_std,
    isoforms_Y_proportions[train_idx],
    log_transform=False,
    normalize=False,
)

test_dataset = GenesDataset(
    (test_genes_log - gene_mean) / gene_std,
    isoforms_Y_proportions[test_idx],
    log_transform=False,
    normalize=False,
)

print(f"Dataset created with {len(train_dataset)} training and {len(test_dataset)} test samples")

In [None]:
batch_size = 64

# Same batch size for both loaders; only the training loader is shuffled.
shared_loader_kwargs = dict(batch_size=batch_size)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
validation_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_kwargs)

## Gene-Aware Model Architecture

In [None]:
class GeneAwareMLP(nn.Module):
    """
    Gene-aware MLP that ensures isoform proportions sum to 1 per gene.
    Uses a shared encoder followed by gene-specific output heads with softmax.
    """
    def __init__(self, input_dim, hidden_dims, gene_to_isoform_map, sorted_gene_ids):
        """
        Args:
            input_dim: number of input genes
            hidden_dims: list of hidden layer dimensions
            gene_to_isoform_map: dict mapping gene_id -> list of isoform indices
            sorted_gene_ids: sorted list of gene IDs for consistent ordering

        Raises:
            ValueError: if two gene IDs map to the same head name (see _head_key).
        """
        super().__init__()

        # Shared encoder: [Linear -> LayerNorm -> ReLU -> Dropout] per hidden layer
        encoder_layers = []
        for in_size, out_size in zip([input_dim] + hidden_dims, hidden_dims):
            encoder_layers.append(nn.Linear(in_size, out_size))
            encoder_layers.append(nn.LayerNorm(out_size))
            encoder_layers.append(nn.ReLU())
            encoder_layers.append(nn.Dropout(p=0.5))

        self.encoder = nn.Sequential(*encoder_layers)

        # Gene-specific output heads: hidden_dims[-1] -> n_isoforms for that gene
        self.gene_heads = nn.ModuleDict()
        for gene_id in sorted_gene_ids:
            key = self._head_key(gene_id)
            if key in self.gene_heads:
                raise ValueError(f"gene IDs collide on head name {key!r}")
            n_isoforms = len(gene_to_isoform_map[gene_id])
            self.gene_heads[key] = nn.Linear(hidden_dims[-1], n_isoforms)

        self.gene_to_isoform_map = gene_to_isoform_map
        self.sorted_gene_ids = sorted_gene_ids

        # Output width, computed once instead of on every forward pass
        self.n_isoforms = sum(len(positions) for positions in gene_to_isoform_map.values())

        # Pre-compute output order for efficient reconstruction
        self._precompute_output_order()

    @staticmethod
    def _head_key(gene_id):
        """ModuleDict key for a gene's head.

        nn.Module child names may not contain '.', which gene identifiers (e.g. versioned
        IDs) can contain, so dots are replaced by '_'. IDs without dots are unchanged.
        """
        return str(gene_id).replace('.', '_')

    def _precompute_output_order(self):
        """
        Precompute the mapping from gene outputs to the final isoform tensor.

        output_mapping: list of (gene_idx, isoform_positions) in sorted_gene_ids order.
        _output_index: the same positions concatenated, so all per-gene softmax outputs can
            be written into the output tensor with one indexing operation. It is registered
            as a non-persistent buffer: it follows model.to(device) but is not added to
            state_dict, so checkpoint keys are unchanged.
        """
        self.output_mapping = []
        flat_positions = []
        for gene_idx, gene_id in enumerate(self.sorted_gene_ids):
            isoform_positions = self.gene_to_isoform_map[gene_id]
            self.output_mapping.append((gene_idx, isoform_positions))
            flat_positions.extend(isoform_positions)

        self.register_buffer(
            "_output_index", torch.as_tensor(flat_positions, dtype=torch.long), persistent=False
        )

    def forward(self, x):
        """
        Forward pass with gene-aware softmax.

        Args:
            x: gene inputs, [batch, n_genes] or [batch, 1, n_genes] (GenesDataset yields the
               latter). All non-batch dimensions are flattened into the feature axis.

        Returns:
            isoform_proportions: tensor [batch, 1, n_isoforms] where proportions
                                 sum to 1 for each gene
        """
        batch_size = x.size(0)

        # Flatten to [batch, n_genes] so every head yields [batch, k] logits and the softmax
        # below normalizes over the isoform axis (not over a singleton axis).
        encoded = self.encoder(x.reshape(batch_size, -1))  # [batch, hidden_dims[-1]]

        # Same dtype/device as the activations
        output = encoded.new_zeros(batch_size, self.n_isoforms)

        if len(self.sorted_gene_ids) > 0:
            per_gene = [
                torch.softmax(self.gene_heads[self._head_key(gene_id)](encoded), dim=-1)
                for gene_id in self.sorted_gene_ids
            ]
            # Concatenation order matches _output_index (sorted_gene_ids order).
            output[:, self._output_index] = torch.cat(per_gene, dim=-1)

        return output.unsqueeze(1)  # [batch, 1, n_isoforms]

### Initialize GeneAwareMLP model


In [None]:
use_cuda = torch.cuda.is_available()
hidden_dims = [2048, 1024, 512, 256, 128]

# Input width is the filtered gene count; one softmax head per gene in the mapping.
model = GeneAwareMLP(
    input_dim=genes_X.shape[1],
    hidden_dims=hidden_dims,
    gene_to_isoform_map=gene_to_isoform_map,
    sorted_gene_ids=sorted_gene_ids,
)

param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count / 1e6:.2f}M")
print(f"Number of gene-specific heads: {len(model.gene_heads)}")

if not use_cuda:
    print("Using CPU")
else:
    model.to('cuda')
    print("Using CUDA")

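### Baseline MLP (not gene-aware)

Note: running the initialization cell below rebinds `model` to this baseline, replacing the GeneAwareMLP created above.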

In [None]:
class MLP(nn.Module):
    """Plain feed-forward baseline.

    One [Linear -> LayerNorm -> ReLU -> Dropout(0.5)] stage per entry of `hidden_dims`,
    followed by an unconstrained linear output layer (no per-gene normalization).
    """

    def __init__(self, input_dim, hidden_dims, out_dim):
        super().__init__()

        widths = [input_dim] + hidden_dims
        layers = []
        for stage in range(len(hidden_dims)):
            width_in, width_out = widths[stage], widths[stage + 1]
            layers += [
                nn.Linear(width_in, width_out),
                nn.LayerNorm(width_out),
                nn.ReLU(),
                nn.Dropout(p=0.5),
            ]
        layers.append(nn.Linear(hidden_dims[-1], out_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Baseline (non-gene-aware) MLP with filtered input/output dimensions.
# NOTE(review): this cell rebinds `model`, discarding the GeneAwareMLP built earlier.
# Under "Run All" the training, verification and "gene_aware" checkpoint cells below
# therefore use this unconstrained baseline. Skip this cell to train the gene-aware model.
use_cuda = torch.cuda.is_available()
hidden_dims = [2048, 1024, 512, 256, 128]

n_isoforms_out = isoforms_Y_proportions.shape[1]  # filtered isoform count
model = MLP(input_dim=genes_X.shape[1], hidden_dims=hidden_dims, out_dim=n_isoforms_out)

param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count / 1e6:.2f}M")

if not use_cuda:
    print("Using CPU")
else:
    model.to('cuda')
    print("Using CUDA")

## Pearson Correlation Loss Function

In [None]:
class PearsonCorrelationLoss(nn.Module):
    """
    Pearson correlation loss: returns 1 - r, where r is the correlation between ALL
    elements of `pred` and `target` (both flattened, so one value per batch).
    Minimizing this maximizes correlation between predictions and targets.

    Args:
        eps: small constant added inside the square root of the denominator.
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        # Flatten to compute global correlation
        pred_flat = pred.flatten()
        target_flat = target.flatten()

        # Center the variables
        pred_centered = pred_flat - pred_flat.mean()
        target_centered = target_flat - target_flat.mean()

        numerator = (pred_centered * target_centered).sum()
        # eps goes INSIDE the sqrt: d sqrt(u)/du is infinite at u = 0, so with eps outside,
        # a constant pred (or target) produced 0 * inf = NaN gradients.
        denominator = torch.sqrt((pred_centered ** 2).sum() * (target_centered ** 2).sum() + self.eps)

        correlation = numerator / denominator

        # Return 1 - correlation (to minimize)
        return 1 - correlation


class CombinedLoss(nn.Module):
    """
    Weighted sum of two losses:
        mse_weight * MSE(pred, target) + corr_weight * (1 - Pearson r(pred, target))
    """
    def __init__(self, mse_weight=0.5, corr_weight=0.5):
        super().__init__()
        self.mse_weight, self.corr_weight = mse_weight, corr_weight
        self.mse_loss = nn.MSELoss()
        self.corr_loss = PearsonCorrelationLoss()

    def forward(self, pred, target):
        weighted_mse = self.mse_weight * self.mse_loss(pred, target)
        weighted_corr = self.corr_weight * self.corr_loss(pred, target)
        return weighted_mse + weighted_corr

# Training Utils

In [None]:
class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Call the instance once per epoch with the current metric value. It returns True once
    `patience` consecutive calls have failed to beat the best value by more than `min_delta`;
    after that, `early_stop` stays True.
    """

    def __init__(self, patience=5, min_delta=0.0, mode="min"):
        """
        Arguments:
            patience (int): number of epochs to wait after last improvement
            min_delta (float): minimum change to qualify as improvement
            mode (str): "min" to monitor decreasing metric, "max" for increasing
                (any value other than "min" behaves like "max")
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def _is_improvement(self, score):
        """True if `score` beats the stored best by more than min_delta in the monitored direction."""
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def __call__(self, current_score):
        # The first observation only initializes the reference value.
        if self.best_score is None:
            self.best_score = current_score
            return False

        if self._is_improvement(current_score):
            self.best_score = current_score
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

In [None]:
def pearson_corr(x, y):
    """Compute the Pearson correlation coefficient over all elements of x and y.

    The means and sums reduce over every element, so multi-dimensional inputs are
    treated as flat. If either input is constant, the result is 0 (1e-8 guards the
    division).
    """
    x_dev = x - x.mean()
    y_dev = y - y.mean()
    norm_product = torch.sqrt(torch.sum(x_dev ** 2)) * torch.sqrt(torch.sum(y_dev ** 2))
    return torch.sum(x_dev * y_dev) / (norm_product + 1e-8)


def _run_epoch(model, loader, loss_fn, device, optimizer=None, desc=None):
    """One pass over `loader`: trains (backprop + optimizer step) when `optimizer` is given,
    otherwise evaluates with gradients disabled.

    Returns:
        (mean loss, mean per-batch Pearson r), both weighted by batch size.

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train() for training passes, eval() otherwise
    loss_sum, corr_sum, n_seen = 0.0, 0.0, 0

    batches = tqdm(loader, desc=desc) if desc is not None else loader
    with torch.set_grad_enabled(is_train):
        for inputs, targets in batches:
            inputs, targets = inputs.float().to(device), targets.float().to(device)

            if is_train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            if is_train:
                loss.backward()
                optimizer.step()

            batch_n = inputs.size(0)
            loss_sum += loss.item() * batch_n
            with torch.no_grad():
                corr_sum += pearson_corr(outputs.flatten(), targets.flatten()).item() * batch_n
            n_seen += batch_n

    if n_seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / n_seen, corr_sum / n_seen


def train_model(model, train_loader, val_loader, optimizer, loss_fn, num_epochs, scheduler, early_stopping, device="cpu"):
    """Train `model` with per-epoch validation, LR scheduling on val loss, and early stopping.

    Args:
        model: module to train (moved to `device`).
        train_loader / val_loader: iterables of (inputs, targets) batches.
        optimizer: optimizer over model.parameters().
        loss_fn: callable(outputs, targets) -> scalar loss tensor.
        num_epochs: maximum number of epochs.
        scheduler: stepped with the validation loss each epoch (e.g. ReduceLROnPlateau).
        early_stopping: callable(val_loss) -> bool; training stops when it returns True.
        device: device string for model and batches.

    Returns:
        (train_losses, train_corrs, val_losses, val_corrs): per-epoch lists. The reported
        correlation is the batch-size-weighted mean of per-batch Pearson r.
    """
    model.to(device)

    train_losses_array, train_corr_array = [], []
    val_losses_array, val_corr_array = [], []

    for epoch in range(num_epochs):
        train_loss, train_corr = _run_epoch(
            model, train_loader, loss_fn, device,
            optimizer=optimizer, desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        val_loss, val_corr = _run_epoch(model, val_loader, loss_fn, device)

        train_losses_array.append(train_loss)
        train_corr_array.append(train_corr)
        val_losses_array.append(val_loss)
        val_corr_array.append(val_corr)

        # Report before the early-stopping check so the final epoch is printed too.
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f} | Train Corr: {train_corr:.4f} "
              f"| Val Loss: {val_loss:.4f} | Val Corr: {val_corr:.4f}")

        scheduler.step(val_loss)

        if early_stopping(val_loss):
            print(f"Stopping early at epoch {epoch+1}")
            break

    print("Finished training.")
    return train_losses_array, train_corr_array, val_losses_array, val_corr_array

# Train Model

In [None]:
# Training configuration
num_epochs = 500
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loss = 0.3 * MSE + 0.7 * (1 - Pearson r)
loss_fn = CombinedLoss(mse_weight=0.3, corr_weight=0.7)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the LR after 3 stagnant validation epochs; stop after 10 without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
early_stopping = EarlyStopping(patience=10, min_delta=1e-4, mode="min")

print("Starting training...")
t0 = time.time()

history = train_model(
    model, train_loader, validation_loader,
    optimizer, loss_fn, num_epochs,
    scheduler, early_stopping, device,
)
train_losses, train_corrs, val_losses, val_corrs = history

t1 = time.time()
elapsed = t1 - t0
print(f'\nTotal training time: {elapsed//60:.0f}m {elapsed%60:.0f}s')

# Visualize Results

In [None]:
epochs = range(1, len(val_losses) + 1)

fig, (ax_loss, ax_corr) = plt.subplots(1, 2, figsize=(12, 5))

# (axis, [(values, label, color), ...], title, y-label) for each panel
panels = [
    (ax_loss,
     [(train_losses, 'Train Loss', 'tab:blue'), (val_losses, 'Validation Loss', 'tab:orange')],
     'Loss over Epochs', 'Combined Loss'),
    (ax_corr,
     [(train_corrs, 'Train Correlation', 'tab:green'), (val_corrs, 'Validation Correlation', 'tab:red')],
     'Correlation over Epochs', 'Pearson Correlation'),
]

for ax, curves, title, ylabel in panels:
    for values, label, color in curves:
        ax.plot(epochs, values, label=label, color=color, linewidth=2)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)

fig.tight_layout()
plt.show()

### Verify that predicted isoform proportions sum to 1 within each gene

In [None]:
def verify_gene_proportions(predictions, gene_to_isoform_map, sorted_gene_ids, n_samples=5, n_genes=10, tol=1e-5):
    """Check that predicted isoform proportions sum to 1 within each gene and print a report.

    Args:
        predictions: tensor [batch, 1, n_isoforms] (the singleton axis is squeezed) or [batch, n_isoforms].
        gene_to_isoform_map: dict gene_id -> list of isoform column indices.
        sorted_gene_ids: gene IDs; the first `n_genes` of them are checked.
        n_samples: number of leading samples (rows) to check.
        n_genes: number of genes to check (default 10, the previously hardcoded value).
        tol: maximum allowed |sum - 1| per gene and sample.

    Returns:
        list of (gene_id, max_deviation) for genes whose deviation is >= tol
        (empty when every checked gene passes).
    """
    predictions = predictions.squeeze(1)  # [batch, n_isoforms]

    violations = []
    all_sums = []

    for gene_id in sorted_gene_ids[:n_genes]:
        isoform_indices = gene_to_isoform_map[gene_id]

        # Proportions of this gene's isoforms for the checked samples: [n_samples, n_isoforms_for_gene]
        sums = predictions[:n_samples, isoform_indices].sum(dim=1)
        all_sums.extend(sums.tolist())
        mean_sum = sums.mean().item()
        max_deviation = (sums - 1.0).abs().max().item()

        within_tol = max_deviation < tol
        status = "ok" if within_tol else "not ok"
        print(f"{status} Gene {gene_id}: mean_sum={mean_sum:.6f}, max_deviation={max_deviation:.2e}, n_isoforms={len(isoform_indices)}")

        if not within_tol:
            violations.append((gene_id, max_deviation))

    if violations:
        print(f"\n  Found {len(violations)} genes with sum != 1.0")
    else:
        print("\n All checked genes have proportions summing to 1.0!")

    if not all_sums:
        print("\nNo gene/sample combinations were checked.")
        return violations

    # Overall statistics
    all_sums_tensor = torch.tensor(all_sums)
    print("\nOverall statistics:")
    print(f"  Mean sum: {all_sums_tensor.mean():.6f}")
    print(f"  Std sum: {all_sums_tensor.std():.6f}")
    print(f"  Min sum: {all_sums_tensor.min():.6f}")
    print(f"  Max sum: {all_sums_tensor.max():.6f}")

    return violations

# Predict one validation batch and check the per-gene sums on its first few samples.
model.eval()
with torch.no_grad():
    first_inputs, _ = next(iter(validation_loader))
    sample_predictions = model(first_inputs.float().to(device)).cpu()

n_check = min(5, len(sample_predictions))
proportion_violations = verify_gene_proportions(sample_predictions, gene_to_isoform_map, sorted_gene_ids, n_samples=n_check)

In [None]:
# Save the checkpoint with everything needed to rebuild the model.
# 'model_class' records which architecture produced the weights: the baseline-MLP cell
# rebinds `model`, so without it this file could silently hold non-gene-aware weights.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_corrs': train_corrs,
    'val_corrs': val_corrs,
    'n_genes': genes_X.shape[1],
    'n_isoforms': isoforms_Y_proportions.shape[1],
    'gene_to_isoform_map': gene_to_isoform_map,
    'sorted_gene_ids': sorted_gene_ids,
    'hidden_dims': hidden_dims,
    'model_class': type(model).__name__,
}, 'isoform_model_gene_aware.pt')

print("Model saved to 'isoform_model_gene_aware.pt'")

# Evaluation and Residual Analysis

In [None]:
model.eval()
n_points = 5000  # max number of points drawn in the plots (metrics below use ALL points)

preds_list, targets_list = [], []
with torch.no_grad():
    for inputs, targets in validation_loader:
        outputs = model(inputs.float().to(device))
        preds_list.append(outputs.cpu())
        targets_list.append(targets.float().cpu())

preds_all = torch.cat(preds_list).flatten()
targets_all = torch.cat(targets_list).flatten()

# Metrics on every validation value. Previously they were computed on an unseeded random
# subsample, which made the "final" numbers noisy and non-reproducible.
rmse = torch.sqrt(torch.mean((preds_all - targets_all) ** 2))
corr = pearson_corr(preds_all, targets_all)

# Subsample for plotting only; seeded so the plotted subset is reproducible.
preds, targets = preds_all, targets_all
if len(preds_all) > n_points:
    plot_generator = torch.Generator().manual_seed(seed)
    plot_idx = torch.randperm(len(preds_all), generator=plot_generator)[:n_points]
    preds, targets = preds_all[plot_idx], targets_all[plot_idx]

residuals = targets - preds

plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.hist(residuals, alpha=0.3, bins=100)
plt.axvline(0, color="darkred", linestyle="--")
plt.xlabel("Residual (True - Predicted)")
plt.title("Residual Distribution")

plt.subplot(1, 3, 2)
plt.scatter(preds, residuals, alpha=0.3, s=10)
plt.axhline(0, color="darkred", linestyle="--")
plt.xlabel("Predicted Isoform Proportion")
plt.ylabel("Residual (True - Predicted)")
plt.title("Residual vs Predicted")

plt.subplot(1, 3, 3)
plt.scatter(preds, targets, alpha=0.2, s=10, c='darkred')
plt.plot([preds.min(), preds.max()], [preds.min(), preds.max()], 'k--', lw=2)
plt.xlabel("Predicted Isoform Proportion")
plt.ylabel("True Isoform Proportion")
plt.title("Predicted vs True")

plt.tight_layout()
plt.show()

print(f"\nFinal Metrics:")
print(f"RMSE: {rmse:.4f}")
print(f"Pearson Correlation: {corr:.4f}")

# Save Model

In [None]:
# Save model checkpoint. 'hidden_dims' and 'model_class' are included so the architecture
# can be rebuilt before loading 'model_state_dict' (n_genes / n_isoforms alone are not enough).
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_corrs': train_corrs,
    'val_corrs': val_corrs,
    'n_genes': genes_X.shape[1],
    'n_isoforms': isoforms_Y_proportions.shape[1],
    'hidden_dims': hidden_dims,
    'model_class': type(model).__name__,
}, 'isoform_model_improved.pt')

print("Model saved to 'isoform_model_improved.pt'")