# GraphSAGE Training for Memory R1

This notebook trains a GraphSAGE model to produce structural node embeddings, using random-walk co-occurrence pairs and a skip-gram objective as the training signal.

**Requirements:**
- Google Colab with GPU runtime (T4 or L4)
- PyTorch Geometric (installed automatically by the first cell if it is missing)

**Training Pipeline:**
1. Generate/load graph data
2. Extract GraphSAGE view and features
3. Generate random walks and pairs
4. Train model with skip-gram objective
5. Evaluate and export

## 1. Setup Environment

In [None]:
# Install PyTorch Geometric if not already installed
try:
    import torch_geometric
    print(f"PyTorch Geometric version: {torch_geometric.__version__}")
except ImportError:
    print("Installing PyTorch Geometric...")
    !pip install torch-geometric -q
    import torch_geometric
    print(f"Installed PyTorch Geometric: {torch_geometric.__version__}")

!git clone https://github.com/celestice106/graphsage
%cd graphsage


In [None]:
# Install other requirements
!pip install pyyaml scikit-learn matplotlib tqdm -q

In [None]:
# Report the PyTorch build and choose the device used by the following cells
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    # CPU fallback: the notebook still runs, only slower
    print("WARNING: No GPU available. Training will be slow.")
    device = torch.device('cpu')
else:
    # Describe the first visible GPU before selecting it
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    device = torch.device('cuda')

print(f"\nUsing device: {device}")

In [None]:
# Put the project root on sys.path so the project packages can be imported
# (adjust if your layout differs)
import sys
from pathlib import Path

_cwd = Path('.').absolute()

# When launched from a sub-folder such as notebooks/, the root is the parent
# directory (recognized by its src/ folder); otherwise the current directory
# is taken to be the project root.
project_root = _cwd.parent if (_cwd.parent / 'src').exists() else _cwd
sys.path.insert(0, str(project_root))

print(f"Project root: {project_root}")

## 2. Load Configuration

In [None]:
from config import load_config

# Load default config
config = load_config()

# Override for Colab training.
# The training device follows the `device` selected by the GPU-check cell
# instead of a hard-coded 'cuda', so a CPU-only runtime is configured correctly.
config['training']['device'] = device.type
config['training']['batch_size'] = 1024  # Larger batch for GPU
config['training']['epochs'] = 100
config['training']['log_every'] = 5

# Print key settings
print("Configuration:")
print(f"  Walk length: {config['walks']['length']}")
print(f"  Walks per node: {config['walks']['per_node']}")
print(f"  Context window: {config['walks']['context_window']}")
print(f"  Hidden dim: {config['model']['hidden_dim']}")
print(f"  Output dim: {config['model']['output_dim']}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Learning rate: {config['training']['learning_rate']}")

## 3. Create/Load Graph Data

In [None]:
from src.data import GraphSAGEDataset, GraphLoader

# Size and seed of the synthetic graph
NUM_MEMORIES = 500  # memory nodes
NUM_ENTITIES = 100  # entity nodes
SEED = 42

# Build a mock graph (a graph loaded from file could be substituted here)
print("Creating synthetic graph...")
mock_kwargs = dict(
    num_memories=NUM_MEMORIES,
    num_entities=NUM_ENTITIES,
    seed=SEED,
    undirected=True,
)
dataset = GraphSAGEDataset.from_mock(**mock_kwargs)

# Summarize the graph: one (label, statistics key, format spec) row per line
stats = dataset.get_statistics()
print(f"\nGraph Statistics:")
for label, key, spec in (
    ("Nodes", "num_nodes", ""),
    ("Edges", "num_edges", ""),
    ("Avg degree", "avg_degree", ".2f"),
    ("Density", "density", ".4f"),
    ("Isolated nodes", "isolated_nodes", ""),
):
    print(f"  {label}: {format(stats[key], spec)}")

In [None]:
# Fetch the graph data on the selected device (GPU when one is available)
data = dataset.get_data(device=device)

for label, tensor in (("Features", data.x), ("Edge index", data.edge_index)):
    print(f"{label} shape: {tensor.shape}")
print(f"Data on: {data.x.device}")

## 4. Generate Random Walks and Pairs

In [None]:
from src.walks import RandomWalkGenerator, CooccurrencePairSampler, DegreeBiasedNegativeSampler
import time

# Configuration
walk_config = config['walks']

print("Generating random walks...")
# perf_counter is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments
start_time = time.perf_counter()

# Create walker; the edge index is handed over as a CPU tensor
walker = RandomWalkGenerator(
    edge_index=data.edge_index.cpu(),
    num_nodes=data.num_nodes,
    walk_length=walk_config['length'],
    walks_per_node=walk_config['per_node'],
    seed=walk_config['seed']
)

walks = walker.generate_all_walks(verbose=True)
elapsed = time.perf_counter() - start_time

print(f"\nGenerated {len(walks)} walks in {elapsed:.2f}s")
# Guard the average against an empty result to avoid a ZeroDivisionError
if len(walks) > 0:
    print(f"Average walk length: {sum(len(w) for w in walks) / len(walks):.1f}")
else:
    print("Average walk length: n/a (no walks were generated)")

In [None]:
# Derive the positive (co-occurrence) pairs from the generated walks
print("\nExtracting positive pairs...")
start_time = time.time()

pair_sampler = CooccurrencePairSampler(context_window=walk_config['context_window'])
pairs = pair_sampler.extract_pairs(walks)

print(f"Extracted {len(pairs):,} pairs in {time.time() - start_time:.2f}s")

# Summary statistics reported by the sampler for the same walks
pair_stats = pair_sampler.get_statistics(walks)
for label, key, spec in (
    ("Unique pairs", "unique_pairs", ","),
    ("Avg contexts per target", "avg_contexts_per_target", ".1f"),
):
    print(f"{label}: {format(pair_stats[key], spec)}")

In [None]:
# Build the degree-biased negative sampler from the 'negatives' config section
neg_config = config['negatives']
neg_exponent = neg_config['exponent']

neg_sampler = DegreeBiasedNegativeSampler(
    edge_index=data.edge_index,
    num_nodes=data.num_nodes,
    exponent=neg_exponent,
    device=device,
)

print(f"Negative sampler ready (exponent={neg_exponent})")

## 5. Create Model

In [None]:
from src.model import ProductionGraphSAGE

# Architecture settings come from the 'model' section; the input width from 'features'
model_config = config['model']
feature_dim = config['features']['dimensions']

model_kwargs = dict(
    in_channels=feature_dim,
    hidden_channels=model_config['hidden_dim'],
    out_channels=model_config['output_dim'],
    num_layers=model_config['num_layers'],
    dropout=model_config['dropout'],
    normalize_output=model_config['normalize_output'],
)
# Instantiate the network and place it on the selected device
model = ProductionGraphSAGE(**model_kwargs).to(device)

print(f"Model created:")
print(f"  Parameters: {model.count_parameters():,}")
for label, value in (
    ("Input dim", feature_dim),
    ("Hidden dim", model_config['hidden_dim']),
    ("Output dim", model_config['output_dim']),
):
    print(f"  {label}: {value}")
print(model)

## 6. Train Model

In [None]:
import os

from src.training import GraphSAGETrainer

# Make sure the output folders exist and point the config at them
for out_key in ('checkpoints', 'logs'):
    os.makedirs(out_key, exist_ok=True)
    config['paths'][out_key] = out_key

# Wire the model, graph tensors, positive pairs and negative sampler together
trainer_kwargs = dict(
    model=model,
    features=data.x,
    edge_index=data.edge_index,
    positive_pairs=pairs,
    negative_sampler=neg_sampler,
    config=config,
)
trainer = GraphSAGETrainer(**trainer_kwargs)

In [None]:
# Run the training loop for the configured number of epochs
train_config = config['training']
num_epochs = train_config['epochs']
banner = "=" * 60

print(f"Starting training for {num_epochs} epochs...")
print(banner)

# trainer.train returns the value reported below as the best loss
best_loss = trainer.train(num_epochs=num_epochs)

print(banner)
print(f"Training complete! Best loss: {best_loss:.4f}")

## 7. Evaluate Embeddings

In [None]:
from src.utils.metrics import evaluate_embeddings, check_embedding_health

# Pull the trained embeddings out of the trainer
embeddings = trainer.get_embeddings()
print(f"Embeddings shape: {embeddings.shape}")

# The health check yields a pass/fail flag plus a collection of issue messages
is_healthy, issues = check_embedding_health(embeddings)
verdict = 'PASSED' if is_healthy else 'FAILED'
print(f"\nHealth check: {verdict}")
for issue in issues or ():
    print(f"  - {issue}")

In [None]:
# Evaluate the embeddings against the graph's edges and print a grouped report
results = evaluate_embeddings(embeddings, data.edge_index)

rule = "=" * 60
print("\n" + rule)
print("Evaluation Results")
print(rule)

print(f"\nNeighbor Similarity:")
neighbor = results['neighbor_similarity']
print(f"  Connected pairs: {neighbor['neighbor_sim_mean']:.4f}")
print(f"  Random pairs: {neighbor['random_sim_mean']:.4f}")
print(f"  Gap: {neighbor['sim_gap']:.4f}")

print(f"\nLink Prediction:")
link_pred = results['link_prediction']
print(f"  AUC-ROC: {link_pred['auc_roc']:.4f}")
print(f"  Avg Precision: {link_pred['avg_precision']:.4f}")

print(f"\nEmbedding Quality:")
emb_stats = results['embedding_stats']
print(f"  Normalized: {emb_stats['is_normalized']}")
print(f"  Collapsed: {emb_stats['is_collapsed']}")

## 8. Visualize Embeddings

In [None]:
from src.utils.visualization import plot_embeddings_tsne, plot_similarity_distribution

# Project the embeddings with t-SNE; the figure is shown inline and written to disk
print("Creating t-SNE visualization...")
tsne_options = dict(
    output_path='embeddings_tsne.png',
    show=True,
    sample_size=300,  # sample for faster visualization
)
plot_embeddings_tsne(embeddings, **tsne_options)

In [None]:
# Plot the similarity distribution computed from the embeddings and the edge index
print("Creating similarity distribution plot...")
similarity_options = dict(
    output_path='similarity_distribution.png',
    show=True,
)
plot_similarity_distribution(embeddings, data.edge_index, **similarity_options)

## 9. Export Model

In [None]:
import os

# Destination of the production checkpoint; create its folder when missing
export_path = 'exports/graphsage_production.pt'
os.makedirs(os.path.dirname(export_path), exist_ok=True)

# Bundle the weights with the config, dimensions and metrics in a single file
export_payload = {
    'model_state_dict': model.state_dict(),
    'config': config,
    'embedding_dim': model.out_channels,
    'in_channels': model.in_channels,
    'best_loss': best_loss,
    'evaluation': results,
}
torch.save(export_payload, export_path)

size_kb = os.path.getsize(export_path) / 1024
print(f"Model exported to: {export_path}")
print(f"File size: {size_kb:.1f} KB")

## 10. Test Inference

In [None]:
from src.inference import MemoryR1StructuralEncoder

# Create encoder from exported model.
# The device string follows the `device` selected by the GPU-check cell rather
# than a hard-coded 'cuda', so this cell also runs on CPU-only runtimes.
# NOTE(review): assumes the encoder accepts 'cpu' as well as 'cuda' - confirm.
encoder = MemoryR1StructuralEncoder(
    model_path=export_path,
    device=device.type,
    cache_embeddings=True
)

# Test inference on the full graph held by the dataset
test_embeddings = encoder.encode_all(dataset.full_graph)

print(f"Inference successful!")
print(f"Embeddings shape: {test_embeddings.shape}")
print(f"Embedding dim: {encoder.embedding_dim}")

In [None]:
# Benchmark inference
import time

import numpy as np

# torch.cuda.synchronize() raises on machines without CUDA, so only
# synchronize when a GPU is actually present.
_cuda_present = torch.cuda.is_available()

# Warm up so one-off initialization cost is not counted in the timings
for _ in range(5):
    _ = encoder.encode_all(dataset.full_graph, force_recompute=True)
if _cuda_present:
    torch.cuda.synchronize()

# Benchmark: time 50 forced recomputations
times = []
for _ in range(50):
    start = time.perf_counter()
    _ = encoder.encode_all(dataset.full_graph, force_recompute=True)
    if _cuda_present:
        # Wait for queued GPU work to finish before stopping the clock
        torch.cuda.synchronize()
    times.append(time.perf_counter() - start)

times_ms = [t * 1000 for t in times]

print(f"\nInference Benchmark:")
print(f"  Mean: {np.mean(times_ms):.2f} ms")
print(f"  Std: {np.std(times_ms):.2f} ms")
print(f"  Min: {np.min(times_ms):.2f} ms")
print(f"  Max: {np.max(times_ms):.2f} ms")
print(f"  Throughput: {1000 / np.mean(times_ms):.1f} inferences/sec")

## Summary

Training complete! The model is exported and ready for use with Memory R1.

**Next Steps:**
1. Download `exports/graphsage_production.pt` for use in Memory R1
2. Or save to Google Drive for persistence:
   ```python
   from google.colab import drive
   drive.mount('/content/drive')
   !cp exports/graphsage_production.pt /content/drive/MyDrive/
   ```