# GEARS Model Usage

This notebook demonstrates how to use the GEARS model for predicting perturbation outcomes.

## Features
- Instantiate a GEARS model from a config and graphs
- Build perturbation and co-expression graphs
- Prepare data for GEARS
- Run predictions


In [None]:
import torch
import numpy as np
from perturblab.models import Model
from perturblab.models.gears import GEARSConfig
from perturblab.types import GeneVocab
from perturblab.methods.gears import build_perturbation_graph

# Small demo vocabulary: five example gene symbols wrapped in a GeneVocab.
# `gene_vocab` is reused by the later cells (graph building, model config, inputs).
genes = [
    'TP53',
    'BRCA1',
    'KRAS',
    'MYC',
    'EGFR',
]
gene_vocab = GeneVocab(genes)

print(f"Gene vocabulary: {len(gene_vocab)} genes")


## Build Graphs for GEARS


In [None]:
import pandas as pd

from perturblab.methods.gears import dataframe_to_weighted_graph

# Perturbation graph (GO-based) over the demo vocabulary.
pert_graph_options = dict(similarity='jaccard', threshold=0.1, show_progress=True)
pert_graph = build_perturbation_graph(gene_vocab, **pert_graph_options)

# Dummy co-expression graph for the demo; in practice these edges would be
# computed from expression data. Each record is (source, target, weight).
coexpr_edge_records = [
    ('TP53', 'BRCA1', 0.5),
    ('BRCA1', 'TP53', 0.5),
    ('KRAS', 'MYC', 0.3),
]
coexpr_edges = pd.DataFrame(
    coexpr_edge_records,
    columns=['source', 'target', 'weight'],
)
# node_names ties the gene symbols in the edge list to the vocabulary's ordering.
coexpr_graph = dataframe_to_weighted_graph(
    coexpr_edges,
    node_names=gene_vocab.itos,
)

print(f"Perturbation graph: {pert_graph.n_nodes} nodes, {pert_graph.n_unique_edges} edges")
print(f"Co-expression graph: {coexpr_graph.n_nodes} nodes, {coexpr_graph.n_unique_edges} edges")


## Create GEARS Model


In [None]:
from perturblab.methods.gears import weighted_graph_to_dataframe
from perturblab.models.gears._modeling import GEARSModel


def _edge_frame_to_tensors(edges):
    """Convert an edge-list DataFrame into ``(edge_index, edge_weight)`` tensors.

    Args:
        edges: DataFrame with 'source', 'target' and 'weight' columns.
            'source'/'target' are presumably integer node indices, since the
            frames passed here are exported with ``include_node_names=False``
            (TODO confirm against ``weighted_graph_to_dataframe``).

    Returns:
        Tuple of a ``torch.long`` tensor of shape (2, num_edges) and a
        ``torch.float32`` tensor of shape (num_edges,).
    """
    # Cast explicitly so the conversion does not depend on the column dtypes
    # pandas happened to infer (e.g. for an edge list with no rows).
    endpoints = edges[['source', 'target']].to_numpy(dtype='int64')
    weights = edges['weight'].to_numpy(dtype='float32')
    edge_index = torch.tensor(endpoints, dtype=torch.long).t()
    edge_weight = torch.tensor(weights, dtype=torch.float32)
    return edge_index, edge_weight


config = GEARSConfig(
    num_genes=len(gene_vocab),
    # Number of unique perturbations. Tied to the perturbation graph so that it
    # matches the node set described by G_go instead of an unrelated constant.
    num_perts=pert_graph.n_nodes,
    hidden_size=64,
    num_gene_gnn_layers=2,
    num_go_gnn_layers=1,
    decoder_hidden_size=16,
    num_similar_genes_go_graph=20,
    num_similar_genes_co_express_graph=20,
    coexpress_threshold=0.4,
    uncertainty=False,
    uncertainty_reg=1.0,
    direction_lambda=0.1,
    no_perturb=False
)

# Convert the graphs built earlier into edge-index / edge-weight tensors for the model.

# Perturbation graph -> edge list -> tensors
pert_edges = weighted_graph_to_dataframe(pert_graph, include_node_names=False)
G_go, G_go_weight = _edge_frame_to_tensors(pert_edges)

# Co-expression graph -> edge list -> tensors.
# Uses its own name so the name-based `coexpr_edges` frame from the graph-building
# cell is not silently replaced by an index-based one.
coexpr_index_edges = weighted_graph_to_dataframe(coexpr_graph, include_node_names=False)
G_coexpress, G_coexpress_weight = _edge_frame_to_tensors(coexpr_index_edges)

# Create GEARS model (requires graphs in __init__)
model = GEARSModel(
    config=config,
    G_coexpress=G_coexpress,
    G_coexpress_weight=G_coexpress_weight,
    G_go=G_go,
    G_go_weight=G_go_weight,
    device='cpu'  # Use 'cuda' if GPU available
)

print(f"Model loaded: {type(model).__name__}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


## Prepare Input Data


In [None]:
from perturblab.models.gears.io import GEARSInput
from torch_geometric.data import Data, Batch

# Create sample input data
# In practice, this would come from your dataset
num_samples = 3
num_genes = len(gene_vocab)

# Seed the generator so the random demo values are the same on every
# Restart Kernel -> Run All.
torch.manual_seed(0)

# Baseline expression (control condition), flattened across samples
x = torch.randn(num_samples * num_genes)

# Perturbation indices for each sample
# Sample 0: perturb genes 0, 1 (TP53, BRCA1)
# Sample 1: perturb gene 2 (KRAS)
# Sample 2: control (no perturbation)
pert_idx = [[0, 1], [2], [-1]]

# Expected batch assignment (which sample each gene belongs to)
batch = torch.arange(num_samples).repeat_interleave(num_genes)

# One PyTorch Geometric Data object per sample. Batch.from_data_list generates
# the `batch` vector itself: putting all samples into a single Data object with
# a hand-made `batch` attribute gets that attribute replaced (or rejected,
# depending on the PyG version), so every gene would end up in graph 0.
data_list = [
    Data(
        x=x[sample * num_genes:(sample + 1) * num_genes],
        pert_idx=pert_idx[sample],
    )
    for sample in range(num_samples)
]
batch_data = Batch.from_data_list(data_list)

# Sanity checks: collation must keep the flattened layout and yield one graph per sample.
assert batch_data.x.shape == x.shape
assert batch_data.batch.tolist() == batch.tolist()

print(f"Input shape: {batch_data.x.shape}")
print(f"Batch size: {batch_data.batch.max().item() + 1}")
print(f"Perturbations: {pert_idx}")


## Run Prediction

In [None]:
# Put the model in eval mode before running the forward pass.
model.eval()

# Assemble the GEARS input.
# Note: GEARSInput takes gene_expression / graph_batch_indices rather than x / batch.
gears_input_fields = dict(
    gene_expression=batch_data.x,
    pert_idx=pert_idx,
    graph_batch_indices=batch_data.batch,
)
gears_input = GEARSInput(**gears_input_fields)

# Forward pass without gradient tracking.
with torch.no_grad():
    output = model(gears_input)

print(f"Output shape: {output.predictions.shape}")
print(f"Predictions for first sample: {output.predictions[0, :5]}")
