# scGPT Model Imputation Inference

This notebook demonstrates how to:
1. Load a pre-trained scGPT model
2. Perform gene expression imputation
3. Evaluate and visualize results

## Model Files Location

The model is loaded from the `models/scGPT/` directory. Note that `models/` is typically
a symlink to scratch storage on your GPU machine (e.g., `/scratch/user/st2/models/`).

Place your pre-trained scGPT model files in that directory with the following structure:
```
models/scGPT/
├── args.json              # Model configuration (required)
├── best_model.pt          # Model weights (required, or best_model.ckpt)
├── vocab.json             # Gene vocabulary (required)
├── var_dims.pkl           # Variable dimensions (optional; only printed by this notebook)
└── pert_one-hot-map.pt    # Perturbation mapping (optional; not loaded by this notebook)
```

In [None]:
import json
import os
import pickle
import sys
import warnings
from pathlib import Path

import torch
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

# Reproducibility: seed numpy and torch so the data split and random masks repeat.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Configuration

In [None]:
# Configuration: every knob a reader may want to tune lives in this cell.
# Paths can be overridden through environment variables so the notebook runs on
# other machines without editing it; the defaults are the original values.
MODEL_DIR = os.environ.get("SCGPT_MODEL_DIR", "models/scGPT")
DATA_PATH = os.environ.get(
    "SCGPT_DATA_PATH",
    "/home/b5cc/sanjukta.b5cc/aracneseq/datasets/k562_5k.h5ad",  # Update with your data path
)
MASK_RATIO = 0.2          # fraction of genes masked per cell
VAL_FRACTION = 0.1        # fraction of cells held out and evaluated
BATCH_SIZE = 16
NUM_BATCHES = None        # None = evaluate all batches
NUM_CELLS_VISUALIZE = 3   # number of cells plotted individually

# Fail fast on settings that would otherwise produce empty or degenerate runs.
assert 0.0 < MASK_RATIO < 1.0, "MASK_RATIO must be in (0, 1)"
assert 0.0 < VAL_FRACTION < 1.0, "VAL_FRACTION must be in (0, 1)"
assert BATCH_SIZE > 0, "BATCH_SIZE must be positive"
assert NUM_BATCHES is None or NUM_BATCHES > 0, "NUM_BATCHES must be None or positive"

## Load scGPT Model

In [None]:
model_dir = Path(MODEL_DIR)

# args.json holds the hyper-parameters the checkpoint was saved with.
with (model_dir / "args.json").open("r") as args_file:
    model_args = json.load(args_file)

# Echo every setting so the notebook output records the configuration used.
print("Model Configuration:")
for setting, setting_value in model_args.items():
    print(f"  {setting}: {setting_value}")

In [None]:
# Load vocabulary
with open(model_dir / "vocab.json", "r") as f:
    vocab = json.load(f)

print(f"Vocabulary size: {len(vocab)} genes")
print(f"First 5 genes: {list(vocab.keys())[:5]}")

In [None]:
# Load variable dimensions if available
var_dims_path = model_dir / "var_dims.pkl"
if var_dims_path.exists():
    with open(var_dims_path, "rb") as f:
        var_dims = pickle.load(f)
    print(f"Loaded variable dimensions: {var_dims}")
else:
    var_dims = None
    print("No var_dims.pkl found")

In [None]:
# Import scGPT
try:
    from scgpt.model import TransformerModel
    from scgpt.tokenizer import tokenize_and_pad_batch
    from scgpt.preprocess import Preprocessor
    print("scGPT imported successfully")
except ImportError:
    print("ERROR: scGPT not installed. Install with: pip install scgpt")
    raise

In [None]:
# Create model
model = TransformerModel(
    ntoken=model_args.get("ntoken", len(vocab) + 1),
    d_model=model_args.get("d_model", 512),
    nhead=model_args.get("nhead", 8),
    d_hid=model_args.get("d_hid", 512),
    nlayers=model_args.get("nlayers", 12),
    dropout=model_args.get("dropout", 0.2),
    pad_token=model_args.get("pad_token", "<pad>"),
    pad_value=model_args.get("pad_value", 0),
    do_mvc=model_args.get("do_mvc", True),
    do_dab=model_args.get("do_dab", False),
    use_batch_labels=model_args.get("use_batch_labels", False),
    domain_spec_batchnorm=model_args.get("domain_spec_batchnorm", False),
    n_input_bins=model_args.get("n_input_bins", 51),
)

print(f"Created model with {sum(p.numel() for p in model.parameters()):,} parameters")

In [None]:
# Load model weights
checkpoint_path = model_dir / "best_model.pt"
if not checkpoint_path.exists():
    checkpoint_path = model_dir / "best_model.ckpt"

print(f"Loading weights from: {checkpoint_path}")

checkpoint = torch.load(checkpoint_path, map_location=device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded from checkpoint dict")
else:
    model.load_state_dict(checkpoint)
    print("Loaded state dict directly")

model = model.to(device)
model.eval()

print("Model loaded successfully!")

## Load and Prepare Data

In [None]:
# Load data
print(f"Loading data from: {DATA_PATH}")
adata = sc.read_h5ad(DATA_PATH)
print(f"Loaded {len(adata)} cells with {adata.n_vars} genes")
print(f"Data shape: {adata.X.shape}")

In [None]:
# Convert to tensor
if hasattr(adata.X, 'toarray'):
    expression = adata.X.toarray()
else:
    expression = adata.X

dataset = torch.tensor(expression).long()

# Dataset statistics
NUM_BINS = int(dataset.max().item())
NUM_GENES = dataset.shape[1]
VOCAB_SIZE = NUM_BINS + 1

print(f"Number of genes: {NUM_GENES}")
print(f"Number of bins: {NUM_BINS}")
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Min value: {dataset.min().item()}")
print(f"Max value: {dataset.max().item()}")

In [None]:
# Split data
indices = np.arange(len(dataset))
train_idx, val_idx = train_test_split(
    indices,
    test_size=VAL_FRACTION,
    random_state=SEED
)

# Use validation split for testing
eval_dataset = dataset[val_idx]
print(f"Validation set: {len(eval_dataset)} cells")

# Create data loader
test_loader = DataLoader(
    TensorDataset(eval_dataset),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,  # Use 0 for notebook
    pin_memory=True if device.type == 'cuda' else False
)

print(f"Number of batches: {len(test_loader)}")

## Define Imputation Function

In [None]:
def impute_with_scgpt(model, batch, mask, device, gene_ids=None, mask_value=-1):
    """
    Impute masked expression bins for one batch of cells with an scGPT model.

    Per the scGPT model API, ``TransformerModel.forward`` takes gene token ids
    (``src``) and expression values (``values``) as two separate tensors plus a
    key-padding mask, and returns a dict whose ``"mlm_output"`` entry holds one
    predicted value per gene. Masked positions are overwritten with
    ``mask_value`` rather than 0: 0 is a genuine expression bin, so masking with
    it would make "masked" indistinguishable from "not expressed".

    Args:
        model: scGPT ``TransformerModel`` (or any callable with the same forward
            signature), already on ``device`` and in eval mode.
        batch: Integer bin tensor [batch_size, num_genes].
        mask: Boolean tensor [batch_size, num_genes]; True marks positions to impute.
        device: torch device used for the forward pass.
        gene_ids: Vocabulary token ids of the genes (columns of ``batch``), shape
            [num_genes] or [batch_size, num_genes]. If None, ``arange(num_genes)``
            is used and a warning is issued; that is only meaningful when the
            columns are already ordered by vocabulary id.
        mask_value: Value written into masked positions of ``values``. -1 matches
            scGPT's usual pretraining ``mask_value``; confirm against args.json.

    Returns:
        Tensor with the same shape, dtype and device as ``batch``. Unmasked
        positions keep their original value. Masked positions hold the model's
        prediction: the argmax for [batch, genes, bins] logits, otherwise the
        output rounded to the nearest non-negative integer bin.

    Raises:
        KeyError: if the model returns a dict without "mlm_output" or "pred".
        ValueError: if the prediction shape does not match ``batch``.
    """
    num_cells, num_genes = batch.shape

    if gene_ids is None:
        warnings.warn(
            "impute_with_scgpt: no gene_ids given, using column indices as token ids; "
            "results are only meaningful if the data columns follow the vocabulary order.",
            stacklevel=2,
        )
        gene_ids = torch.arange(num_genes)
    gene_ids = torch.as_tensor(gene_ids, dtype=torch.long, device=device)
    if gene_ids.dim() == 1:
        # The same gene panel is shared by every cell in the batch.
        gene_ids = gene_ids.unsqueeze(0).expand(num_cells, -1)

    with torch.no_grad():
        device_mask = mask.to(device)
        values = batch.to(device=device, dtype=torch.float32).masked_fill(
            device_mask, float(mask_value)
        )
        # No padding: every cell carries the full, fixed-length gene panel.
        padding_mask = torch.zeros((num_cells, num_genes), dtype=torch.bool, device=device)

        output = model(gene_ids, values, src_key_padding_mask=padding_mask, batch_labels=None)

        if isinstance(output, dict):
            pred_key = next((k for k in ("mlm_output", "pred") if k in output), None)
            if pred_key is None:
                raise KeyError(
                    "Model output has neither 'mlm_output' nor 'pred'; "
                    f"available keys: {sorted(output)}"
                )
            predictions = output[pred_key]
        else:
            predictions = output

        if predictions.dim() == 3:  # [batch, genes, bins] logits
            predictions = predictions.argmax(dim=-1)
        elif predictions.is_floating_point():
            # Regression output: snap to the nearest integer bin instead of truncating.
            predictions = predictions.round()
        predictions = predictions.clamp(min=0)

        if predictions.shape != batch.shape:
            raise ValueError(
                f"Prediction shape {tuple(predictions.shape)} does not match "
                f"batch shape {tuple(batch.shape)}"
            )

        # Keep observed values, fill only the masked positions.
        result = batch.clone()
        local_mask = mask.to(batch.device)
        result[local_mask] = predictions.to(device=batch.device, dtype=batch.dtype)[local_mask]

        return result

## Run Imputation

In [None]:
print(f"Running imputation with mask_ratio={MASK_RATIO}")

all_original_masked = []
all_predicted_masked = []
sample_cells_original = []
sample_cells_imputed = []
sample_cells_masks = []

# NUM_BATCHES=None (or 0) means "all batches". Clamp to the loader length so the
# progress-bar total always matches the number of batches actually processed.
num_batches = min(NUM_BATCHES or len(test_loader), len(test_loader))

for i, (batch,) in enumerate(tqdm(test_loader, total=num_batches, desc="Imputing")):
    if i >= num_batches:
        break

    batch = batch.to(device)

    # Mask each (cell, gene) independently with probability MASK_RATIO, drawn from
    # torch's global RNG (seeded in the setup cell).
    mask = torch.rand_like(batch.float()) < MASK_RATIO

    # Impute masked positions
    imputed = impute_with_scgpt(model, batch, mask, device)

    # Only masked positions are scored.
    all_original_masked.append(batch[mask].cpu())
    all_predicted_masked.append(imputed[mask].cpu())

    # Keep the first NUM_CELLS_VISUALIZE cells for the per-cell plots.
    still_needed = NUM_CELLS_VISUALIZE - len(sample_cells_original)
    for j in range(min(still_needed, batch.size(0))):
        sample_cells_original.append(batch[j].cpu())
        sample_cells_imputed.append(imputed[j].cpu())
        sample_cells_masks.append(mask[j].cpu())

# torch.cat on an empty list gives an unhelpful error, so report the cause instead.
if not all_original_masked:
    raise RuntimeError(
        "No batches were evaluated; check VAL_FRACTION, BATCH_SIZE and NUM_BATCHES."
    )

# Concatenate all results
all_original_masked = torch.cat(all_original_masked)
all_predicted_masked = torch.cat(all_predicted_masked)

if all_original_masked.numel() == 0:
    raise RuntimeError("No positions were masked; increase MASK_RATIO.")

print(f"Total masked positions evaluated: {len(all_original_masked):,}")

## Calculate Metrics

In [None]:
banner = "=" * 50
print(banner)
print("scGPT Imputation Metrics")
print(banner)

# Every distance metric below uses |true - predicted|, so compute it once.
abs_err = (all_original_masked - all_predicted_masked).abs()

accuracy = all_original_masked.eq(all_predicted_masked).float().mean().item()
print(f"Exact match accuracy: {accuracy:.2%}")

mae = abs_err.float().mean().item()
print(f"Mean Absolute Error (bins): {mae:.2f}")

# Fraction of masked positions predicted within k bins of the truth.
within_k_metrics = {}
for k in (1, 3, 5, 10):
    within_k = abs_err.le(k).float().mean().item()
    within_k_metrics[k] = within_k
    print(f"Within {k} bins: {within_k:.2%}")

print(banner)

## Visualizations

In [None]:
# Scatter plot and error distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].scatter(
    all_original_masked.numpy(),
    all_predicted_masked.numpy(),
    alpha=0.3,
    s=10
)
axes[0].plot([0, NUM_BINS], [0, NUM_BINS], 'r--', label='Perfect prediction')
axes[0].set_xlabel('True Bin')
axes[0].set_ylabel('Predicted Bin')
axes[0].set_title('scGPT Imputation: Predicted vs True')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

errors = (all_predicted_masked - all_original_masked).numpy()
axes[1].hist(errors, bins=50, alpha=0.7)
axes[1].axvline(0, color='r', linestyle='--', label='Zero error')
axes[1].set_xlabel('Prediction Error (bins)')
axes[1].set_ylabel('Count')
axes[1].set_title('Prediction Error Distribution')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Individual cell visualizations
for cell_idx in range(len(sample_cells_original)):
    original_cell = sample_cells_original[cell_idx].numpy()
    imputed_cell = sample_cells_imputed[cell_idx].numpy()
    cell_mask = sample_cells_masks[cell_idx].numpy()
    
    fig, axes = plt.subplots(2, 1, figsize=(14, 6))
    
    gene_indices = np.arange(len(original_cell))
    
    # Original
    axes[0].bar(gene_indices, original_cell, alpha=0.7, label='Original', width=1.0)
    axes[0].scatter(gene_indices[cell_mask], original_cell[cell_mask],
                   c='red', s=20, zorder=5, label='Masked positions')
    axes[0].set_xlabel('Gene Index')
    axes[0].set_ylabel('Expression Bin')
    axes[0].set_title(f'Cell {cell_idx}: Original Expression (masked positions highlighted)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Imputed
    axes[1].bar(gene_indices, imputed_cell, alpha=0.7, label='Imputed', width=1.0, color='green')
    axes[1].scatter(gene_indices[cell_mask], imputed_cell[cell_mask],
                   c='darkgreen', s=20, zorder=5, label='Imputed positions')
    axes[1].set_xlabel('Gene Index')
    axes[1].set_ylabel('Expression Bin')
    axes[1].set_title(f'Cell {cell_idx}: scGPT Imputed Expression')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Cell-specific metrics
    cell_acc = (original_cell[cell_mask] == imputed_cell[cell_mask]).mean()
    cell_mae = np.abs(original_cell[cell_mask] - imputed_cell[cell_mask]).mean()
    print(f"Cell {cell_idx} - Accuracy: {cell_acc:.2%}, MAE: {cell_mae:.2f} bins")
    print()

## Save Results

In [None]:
# Save metrics to model directory
output_dir = model_dir / "imputation_results"
output_dir.mkdir(exist_ok=True)

metrics = {
    "model": "scGPT",
    "accuracy": accuracy,
    "mae_bins": mae,
    "within_k": within_k_metrics,
    "mask_ratio": MASK_RATIO,
    "num_masked_positions": len(all_original_masked),
    "model_dir": str(model_dir)
}

with open(output_dir / "metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)

print(f"Results saved to: {output_dir}")
print(f"Metrics saved to: {output_dir / 'metrics.json'}")

## Summary

This notebook demonstrated:
1. Loading a pre-trained scGPT model from checkpoint files
2. Running gene expression imputation on masked data
3. Evaluating imputation quality with multiple metrics
4. Visualizing results at both aggregate and single-cell levels

The metrics are written to `imputation_results/metrics.json` inside the model directory so they can be compared with those of other models.