# Graph Hypernetwork Forge - Getting Started

This notebook demonstrates the core capabilities of Graph Hypernetwork Forge, a framework for generating GNN weights dynamically from textual node descriptions.

## Key Features
- **Dynamic Weight Generation**: GNN parameters generated from text
- **Zero-Shot Transfer**: Apply to unseen knowledge graph domains
- **Modular Architecture**: Swap GNN backends and text encoders
- **Production Ready**: Training, evaluation, and deployment tools

## Installation

```bash
pip install graph-hypernetwork-forge
```

Or for development:
```bash
git clone https://github.com/danieleschmidt/Graph-Hypernetwork-Forge.git
cd Graph-Hypernetwork-Forge
pip install -e .
```

In [None]:
# Import required libraries
import torch
import numpy as np
import matplotlib.pyplot as plt

from graph_hypernetwork_forge import HyperGNN, TextualKnowledgeGraph
from graph_hypernetwork_forge.utils import SyntheticDataGenerator, HyperGNNTrainer

print("Graph Hypernetwork Forge loaded successfully!")

## 1. Creating Knowledge Graphs with Textual Metadata

Let's start by creating a knowledge graph in which each node carries a textual description.

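The `edge_index` tensor in the next cell follows the coordinate (COO) convention used by PyTorch Geometric. Row 0 holds source nodes, row 1 holds target nodes, and an undirected edge is stored once in each direction. As a minimal pure-Python sketch (the helper `to_undirected_edge_index` is hypothetical, not part of Graph Hypernetwork Forge), expanding the four undirected edges of a 4-node cycle reproduces exactly the two rows used below:

```python
def to_undirected_edge_index(pairs):
    """Expand undirected (u, v) pairs into a 2 x E COO edge list with both directions."""
    sources, targets = [], []
    for u, v in pairs:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]

# The 4-node cycle 0-1-2-3-0 used in the next cell
edge_index = to_undirected_edge_index([(0, 1), (1, 2), (2, 3), (3, 0)])
print(edge_index[0])  # [0, 1, 1, 2, 2, 3, 3, 0]
print(edge_index[1])  # [1, 0, 2, 1, 3, 2, 0, 3]
```

Four undirected edges therefore become 8 columns. Whether `kg.num_edges` reports 8 or 4 depends on the library's own convention.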
In [None]:
# Create a simple knowledge graph
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 0],
    [1, 0, 2, 1, 3, 2, 0, 3]
], dtype=torch.long)

node_texts = [
    "Alice is a machine learning researcher specializing in neural networks.",
    "Bob works as a software engineer developing AI applications.", 
    "Carol is a data scientist focusing on natural language processing.",
    "David is a product manager for AI-powered tools."
]

# Optional: node features and labels
node_features = torch.randn(4, 16)
node_labels = torch.tensor([0, 0, 0, 1])  # 0: Technical (Alice, Bob, Carol), 1: Management (David)

# Create the knowledge graph
kg = TextualKnowledgeGraph(
    edge_index=edge_index,
    node_texts=node_texts,
    node_features=node_features,
    node_labels=node_labels,
    metadata={"domain": "professional_network"}
)

print(f"Created knowledge graph with {kg.num_nodes} nodes and {kg.num_edges} edges")
print(f"Domain: {kg.metadata['domain']}")

# Display graph statistics
stats = kg.statistics()
for key, value in stats.items():
    if key != "metadata":
        print(f"{key}: {value}")

## 2. HyperGNN Model Initialization

The HyperGNN model consists of three main components:
1. **Text Encoder**: Converts text to embeddings
2. **Hypernetwork**: Generates GNN weights from text embeddings
3. **Dynamic GNN**: Applies generated weights to graph data

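A toy version of the path from text embedding to generated weights can be sketched in a few lines. This is an illustrative sketch under stated assumptions, not HyperGNN's implementation: the hypernetwork is a single linear map with fixed random (untrained) parameters, the "text embeddings" are hand-written vectors, and `make_hypernetwork` is a hypothetical name.

```python
import random

def make_hypernetwork(embed_dim, in_dim, out_dim, seed=0):
    """Return a function that maps a text embedding to an in_dim x out_dim weight matrix.

    Stand-in hypernetwork: one linear map with fixed random parameters (untrained).
    """
    rng = random.Random(seed)
    # One parameter row per generated weight entry, each of length embed_dim
    hyper_params = [[rng.gauss(0.0, 0.1) for _ in range(embed_dim)]
                    for _ in range(in_dim * out_dim)]

    def generate(embedding):
        # Linear map: flat[j] is the dot product of hyper_params[j] and the embedding
        flat = [sum(h * e for h, e in zip(row, embedding)) for row in hyper_params]
        # Reshape the flat vector into in_dim rows of out_dim columns
        return [flat[i * out_dim:(i + 1) * out_dim] for i in range(in_dim)]

    return generate

# Hand-written stand-ins for the encoder's embeddings of two different node texts
generate = make_hypernetwork(embed_dim=4, in_dim=3, out_dim=2)
w_researcher = generate([0.9, 0.1, 0.0, 0.2])
w_manager = generate([0.0, 0.3, 0.8, 0.1])
print(len(w_researcher), len(w_researcher[0]))  # 3 2
print(w_researcher != w_manager)  # True: different texts yield different weights
```

Because the weights are a function of the embedding, nodes with similar descriptions receive similar weights. In a real model the hypernetwork's parameters would be learned rather than random.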
In [None]:
# Initialize HyperGNN model
model = HyperGNN(
    text_encoder="sentence-transformers/all-MiniLM-L6-v2",
    gnn_backbone="GAT",  # Can be GCN, GAT, or SAGE
    hidden_dim=128,
    num_layers=2,
    dropout=0.1
)

print("Model Architecture:")
print(f"  Text Encoder: {model.text_encoder_name}")
print(f"  GNN Backbone: {model.gnn_backbone}")
print(f"  Hidden Dimension: {model.hidden_dim}")
print(f"  Number of Layers: {model.num_layers}")

# Count model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  Total Parameters: {total_params:,}")
print(f"  Trainable Parameters: {trainable_params:,}")

## 3. Dynamic Weight Generation

The key innovation is generating GNN weights dynamically from textual descriptions.

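The next cell summarizes each node's generated first-layer weights with a single number, `torch.norm(..., dim=(1, 2))`. With the default `p='fro'`, that is the Frobenius norm of each `[in_dim, out_dim]` slice. Here is a pure-Python equivalent, assuming the per-node `[num_nodes, in_dim, out_dim]` layout stated in that cell's comment (`per_node_frobenius` is a hypothetical helper):

```python
import math

def per_node_frobenius(weights):
    """Frobenius norm of each node's [in_dim, out_dim] weight matrix."""
    return [math.sqrt(sum(w * w for row in matrix for w in row)) for matrix in weights]

# Two nodes, each with a 2 x 2 weight matrix
weights = [
    [[3.0, 0.0], [0.0, 4.0]],  # node 0
    [[1.0, 1.0], [1.0, 1.0]],  # node 1
]
print(per_node_frobenius(weights))  # [5.0, 2.0]
```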
In [None]:
# Generate weights from node texts
model.eval()
with torch.no_grad():
    weights = model.generate_weights(kg.node_texts)

print(f"Generated weights for {len(weights)} layers:")
for layer_idx, layer_weights in enumerate(weights):
    print(f"\nLayer {layer_idx}:")
    for weight_name, weight_tensor in layer_weights.items():
        print(f"  {weight_name}: {list(weight_tensor.shape)}")

# Visualize weight variation across different text descriptions
first_layer_weights = weights[0]["weight"]  # [num_nodes, in_dim, out_dim]
weight_norms = torch.norm(first_layer_weights, dim=(1, 2))

plt.figure(figsize=(10, 4))
plt.bar(range(len(kg.node_texts)), weight_norms.numpy())
plt.xlabel('Node Index')
plt.ylabel('Weight Norm')
plt.title('Weight Variation Across Different Text Descriptions')
plt.xticks(range(len(kg.node_texts)), [f'Node {i}' for i in range(len(kg.node_texts))])
plt.tight_layout()
plt.show()

print("\nWeight norms for each node:")
for i, norm in enumerate(weight_norms):
    print(f"  Node {i} ('{kg.node_texts[i][:30]}...'): {norm.item():.4f}")

## 4. Forward Pass and Predictions

Now let's perform a forward pass to get node embeddings.

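Besides producing the embeddings, the next cell compares every pair of nodes by cosine similarity. Cosine similarity ranges from -1 to 1, so the color scale of the plot should cover negative values too. Below is a minimal pure-Python sketch of the pairwise matrix. `cosine_similarity_matrix` is hypothetical and, unlike `torch.cosine_similarity`, has no epsilon guard, so it fails on zero vectors.

```python
import math

def cosine_similarity_matrix(vectors):
    """Pairwise cosine similarities; entry [i][j] compares vectors i and j."""
    norms = [math.sqrt(sum(x * x for x in v)) for v in vectors]
    return [
        [sum(a * b for a, b in zip(u, v)) / (nu * nv) for v, nv in zip(vectors, norms)]
        for u, nu in zip(vectors, norms)
    ]

sims = cosine_similarity_matrix([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
print(sims[0])  # [1.0, 0.0, -1.0]: same, orthogonal, opposite
```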
In [None]:
# Perform forward pass
model.eval()
with torch.no_grad():
    node_embeddings = model(kg.edge_index, kg.node_features, kg.node_texts)

print(f"Generated node embeddings: {node_embeddings.shape}")
print("Embedding sample (first 5 dimensions):")
for i, embedding in enumerate(node_embeddings[:, :5]):
    print(f"  Node {i}: {embedding.tolist()}")

# Compute pairwise similarities
similarities = torch.cosine_similarity(
    node_embeddings.unsqueeze(1), 
    node_embeddings.unsqueeze(0), 
    dim=2
)

# Visualize similarity matrix
plt.figure(figsize=(8, 6))
plt.imshow(similarities.numpy(), cmap='viridis', vmin=-1, vmax=1)  # cosine similarity spans [-1, 1]
plt.colorbar(label='Cosine Similarity')
plt.title('Node Embedding Similarities')
plt.xlabel('Node Index')
plt.ylabel('Node Index')
plt.tight_layout()
plt.show()

print("\nPairwise similarities:")
for i in range(len(kg.node_texts)):
    for j in range(i+1, len(kg.node_texts)):
        sim = similarities[i, j].item()
        print(f"  Node {i} ↔ Node {j}: {sim:.4f}")

## 5. Zero-Shot Transfer Demonstration

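Because weights are generated from text, the same HyperGNN can be applied to a graph from an unseen domain without retraining. Note that `model` has not been trained at this point, so this cell demonstrates the mechanics of zero-shot application rather than the quality of the transfer.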

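The last step of the cell below mean-pools each domain's node embeddings and compares the two means by cosine similarity. A pure-Python sketch of that summary, using toy vectors (the helpers `mean_pool` and `cosine` are hypothetical):

```python
import math

def mean_pool(embeddings):
    """Average a list of equal-length vectors component-wise."""
    n = len(embeddings)
    return [sum(column) / n for column in zip(*embeddings)]

def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

# Toy per-node embeddings for two domains
source = [[1.0, 0.0], [0.0, 1.0]]  # two orthogonal nodes
target = [[2.0, 2.0], [4.0, 4.0]]  # two parallel nodes
domain_similarity = cosine(mean_pool(source), mean_pool(target))
print(round(domain_similarity, 6))  # 1.0
```

The two source nodes are orthogonal, yet the pooled score is 1.0. Mean pooling discards per-node structure, so the domain-level number is only a coarse summary.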
In [None]:
# Generate synthetic data from different domains
generator = SyntheticDataGenerator(seed=42)

# Source domain: Professional network (similar to our example)
source_kg = generator.generate_social_network(num_nodes=6, num_classes=2)

# Target domain: Academic papers (very different text patterns)
target_kg = generator.generate_citation_network(num_nodes=6, num_classes=2)

print("Source Domain (Social Network):")
for i, text in enumerate(source_kg.node_texts[:3]):
    print(f"  {i+1}. {text}")

print("\nTarget Domain (Citation Network):")
for i, text in enumerate(target_kg.node_texts[:3]):
    print(f"  {i+1}. {text}")

# Apply the SAME model to both domains
model.eval()
with torch.no_grad():
    # Source domain predictions
    source_embeddings = model(
        source_kg.edge_index, 
        source_kg.node_features, 
        source_kg.node_texts
    )
    
    # Target domain predictions (zero-shot!)
    target_embeddings = model(
        target_kg.edge_index, 
        target_kg.node_features, 
        target_kg.node_texts
    )

print(f"\nSource embeddings shape: {source_embeddings.shape}")
print(f"Target embeddings shape: {target_embeddings.shape}")
print("✅ Same model applied to the unseen target domain without retraining")

# Analyze domain differences
source_mean = source_embeddings.mean(dim=0)
target_mean = target_embeddings.mean(dim=0)
domain_similarity = torch.cosine_similarity(source_mean, target_mean, dim=0)

print(f"\nDomain-level embedding similarity: {domain_similarity.item():.4f}")

## 6. Training a Model

Let's train a HyperGNN model for node classification on synthetic social networks.

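The trainer's `history` records training and validation loss and, when available, validation accuracy. For node classification these are conventionally the mean cross-entropy over labeled nodes and the fraction of nodes whose highest-scoring class matches the label. Whether `HyperGNNTrainer` computes them exactly this way is an assumption. Here is a numerically stable pure-Python sketch (function names are hypothetical):

```python
import math

def cross_entropy(logits, labels):
    """Mean cross-entropy over nodes, using the log-sum-exp trick for stability."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)

def accuracy(logits, labels):
    """Fraction of nodes whose highest-scoring class equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)

# Three nodes, three classes (toy values)
logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5], [1.0, 1.2, 0.0]]
labels = [0, 2, 0]
print(accuracy(logits, labels))  # 2 of 3 correct: the third node is predicted as class 1
print(cross_entropy(logits, labels))
```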
In [None]:
# Generate training data
generator = SyntheticDataGenerator(seed=123)
train_graphs = []
for i in range(5):
    graph = generator.generate_social_network(num_nodes=20, num_classes=3)
    train_graphs.append(graph)

val_graphs = []
for i in range(2):
    graph = generator.generate_social_network(num_nodes=15, num_classes=3)
    val_graphs.append(graph)

print(f"Generated {len(train_graphs)} training graphs and {len(val_graphs)} validation graphs")

# Create a smaller model for quick training
train_model = HyperGNN(
    text_encoder="sentence-transformers/all-MiniLM-L6-v2",
    gnn_backbone="GAT",
    hidden_dim=64,
    num_layers=2,
)

# Set up trainer
trainer = HyperGNNTrainer(
    model=train_model,
    device="cpu",  # Use CPU for demo
)

# Train for a few epochs (demo purposes)
print("\nStarting training...")
history = trainer.train(
    train_graphs=train_graphs,
    val_graphs=val_graphs,
    num_epochs=3,  # Short training for demo
    task_type="node_classification",
    early_stopping_patience=10,
)

print("Training completed!")

# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.subplot(1, 2, 2)
if 'val_accuracy' in history:
    plt.plot(history['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.legend()

plt.tight_layout()
plt.show()

## 7. Advanced Features

### Custom Text Encoders
You can swap in a different text encoder by passing its model name:

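Encoders differ in output size. all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, while all-mpnet-base-v2 and bert-base-uncased use 768 dimensions, so the hypernetwork's input must match whichever encoder is chosen. A plain `transformers` model such as bert-base-uncased returns one vector per token rather than one per sentence. A common way to pool them, and the one described on the all-MiniLM-L6-v2 model card, is an attention-masked mean. How HyperGNN pools bert-base-uncased outputs is not stated here, so treat this as a sketch (`masked_mean_pool` is hypothetical):

```python
def masked_mean_pool(token_embeddings, attention_mask):
    """Average token vectors, skipping padding positions (mask value 0)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            total = [t + x for t, x in zip(total, vector)]
            count += 1
    # Guard against an all-padding input
    return [t / max(count, 1) for t in total]

tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]  # the last row is a padding token
mask = [1, 1, 0]
print(masked_mean_pool(tokens, mask))  # [2.0, 3.0]
```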
In [None]:
# Compare different text encoders
encoders = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/all-mpnet-base-v2",
    "bert-base-uncased"  # This will use transformers instead of sentence-transformers
]

sample_texts = [
    "A machine learning researcher working on deep learning models.",
    "A software engineer developing mobile applications."
]

for encoder_name in encoders:
    try:
        # Use a separate variable so the `model` from Section 2 is not overwritten
        # (it is reused in the batch-processing example below)
        encoder_model = HyperGNN(
            text_encoder=encoder_name,
            hidden_dim=32,
            num_layers=1
        )
        
        # Encode the sample texts with this model's text encoder
        encoder_model.eval()
        with torch.no_grad():
            embeddings = encoder_model.text_encoder(sample_texts)
        
        print(f"Encoder: {encoder_name}")
        print(f"  Embedding shape: {embeddings.shape}")
        print(f"  Embedding norm: {torch.norm(embeddings, dim=1).tolist()}")
        print()
        
    except Exception as e:
        print(f"Encoder {encoder_name}: Error - {e}")
        print()

### Graph Manipulation and Analysis

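Extracting a subgraph keeps only the edges whose endpoints are both in the selected node set. PyTorch Geometric's `torch_geometric.utils.subgraph(..., relabel_nodes=True)` also renumbers the kept nodes to 0..k-1. Whether `TextualKnowledgeGraph.subgraph` relabels the same way is an assumption. Below is a pure-Python sketch of filter-and-relabel on the 4-node cycle from Section 1 (`induced_subgraph` is hypothetical):

```python
def induced_subgraph(edge_index, nodes):
    """Keep edges with both endpoints in `nodes` and relabel those nodes to 0..len(nodes)-1."""
    new_id = {old: new for new, old in enumerate(nodes)}
    sources, targets = [], []
    for u, v in zip(*edge_index):
        if u in new_id and v in new_id:
            sources.append(new_id[u])
            targets.append(new_id[v])
    return [sources, targets]

# The 4-node cycle from Section 1
edge_index = [[0, 1, 1, 2, 2, 3, 3, 0],
              [1, 0, 2, 1, 3, 2, 0, 3]]
print(induced_subgraph(edge_index, [1, 2, 3]))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

Nodes 1, 2, 3 become 0, 1, 2, and both directions of the 3-0 edge disappear because node 0 was not selected.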
In [None]:
# Load from JSON (if you have a JSON file)
# kg_from_json = TextualKnowledgeGraph.from_json("your_graph.json")

# Create subgraphs
original_kg = generator.generate_social_network(num_nodes=10, num_classes=2)
subgraph = original_kg.subgraph([0, 1, 2, 3, 4])  # First 5 nodes

print(f"Original graph: {original_kg.num_nodes} nodes, {original_kg.num_edges} edges")
print(f"Subgraph: {subgraph.num_nodes} nodes, {subgraph.num_edges} edges")

# Get neighbor information
center_node = 0
neighbor_texts = original_kg.get_neighbor_texts(center_node, k_hops=1)

print(f"\nCenter node: {original_kg.node_texts[center_node][:50]}...")
print("Neighbors:")
for i, text in enumerate(neighbor_texts[:3]):  # Show first 3 neighbors
    print(f"  {i+1}. {text[:50]}...")

# Convert to PyTorch Geometric format
pyg_data = original_kg.to_pyg_data()
print("\nPyTorch Geometric Data:")
print(f"  Number of nodes: {pyg_data.num_nodes}")
print(f"  Edge index shape: {pyg_data.edge_index.shape}")
# PyG's Data returns None for unset attributes, so test for None rather than hasattr
if pyg_data.x is not None:
    print(f"  Node features shape: {pyg_data.x.shape}")
if pyg_data.y is not None:
    print(f"  Node labels shape: {pyg_data.y.shape}")

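`get_neighbor_texts(center_node, k_hops=1)` returns the texts of nodes near the center node. Collecting the nodes within k hops is a breadth-first search over the edge list. The sketch below assumes that is what `k_hops` means, and `k_hop_neighbors` is a hypothetical helper run on the Section 1 cycle:

```python
from collections import deque

def k_hop_neighbors(edge_index, center, k):
    """Nodes reachable from `center` in 1..k hops, following edge direction."""
    adjacency = {}
    for u, v in zip(*edge_index):
        adjacency.setdefault(u, []).append(v)
    distance = {center: 0}
    queue = deque([center])
    while queue:
        node = queue.popleft()
        if distance[node] == k:
            continue  # do not expand beyond k hops
        for neighbor in adjacency.get(node, []):
            if neighbor not in distance:
                distance[neighbor] = distance[node] + 1
                queue.append(neighbor)
    return sorted(n for n in distance if n != center)

edge_index = [[0, 1, 1, 2, 2, 3, 3, 0],
              [1, 0, 2, 1, 3, 2, 0, 3]]
print(k_hop_neighbors(edge_index, 0, 1))  # [1, 3]
print(k_hop_neighbors(edge_index, 0, 2))  # [1, 2, 3]
```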
## 8. Performance Tips

### Batch Processing
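For large graphs, process nodes in batches. Each batch is built with `subgraph`, so edges between nodes in different batches are dropped and the resulting embeddings approximate, rather than reproduce, full-graph results: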

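To estimate how much neighborhood context a contiguous node-index split discards, the hypothetical helper below counts the directed edges whose endpoints land in different batches. If that count is large, a neighborhood-sampling loader (for example PyTorch Geometric's `NeighborLoader`) is a more faithful alternative to splitting by index.

```python
def cross_batch_edges(edge_index, num_nodes, batch_size):
    """Count directed edges whose endpoints fall in different contiguous node-index batches."""
    batch_of = [node // batch_size for node in range(num_nodes)]
    return sum(1 for u, v in zip(*edge_index) if batch_of[u] != batch_of[v])

# The 4-node cycle from Section 1, split into batches {0, 1} and {2, 3}
edge_index = [[0, 1, 1, 2, 2, 3, 3, 0],
              [1, 0, 2, 1, 3, 2, 0, 3]]
dropped = cross_batch_edges(edge_index, num_nodes=4, batch_size=2)
print(f"{dropped} of {len(edge_index[0])} directed edges dropped")  # 4 of 8 directed edges dropped
```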
In [None]:
# Generate a larger graph
large_kg = generator.generate_social_network(num_nodes=100, num_classes=5)

# Process in batches (simulate large-scale processing)
batch_size = 20
all_embeddings = []

model.eval()
with torch.no_grad():
    for i in range(0, large_kg.num_nodes, batch_size):
        end_idx = min(i + batch_size, large_kg.num_nodes)
        
        # Create subgraph for this batch
        batch_nodes = list(range(i, end_idx))
        batch_kg = large_kg.subgraph(batch_nodes)
        
        # Process batch
        batch_embeddings = model(
            batch_kg.edge_index,
            batch_kg.node_features,
            batch_kg.node_texts
        )
        
        all_embeddings.append(batch_embeddings)
        print(f"Processed batch {i//batch_size + 1}: nodes {i}-{end_idx-1}")

# Combine all embeddings
final_embeddings = torch.cat(all_embeddings, dim=0)
print(f"\nFinal embeddings shape: {final_embeddings.shape}")
print(f"Successfully processed {large_kg.num_nodes} nodes!")

## Summary

In this notebook, we've demonstrated:

1. **Knowledge Graph Creation**: Building graphs with textual metadata
2. **Dynamic Weight Generation**: Creating GNN weights from text descriptions
3. **Zero-Shot Transfer**: Applying models to new domains without retraining
4. **Training Pipeline**: Complete training and evaluation workflow
5. **Advanced Features**: Custom encoders, graph manipulation, and batch processing

## Next Steps

- Explore the `scripts/` directory for command-line tools
- Check out `examples/` for more specialized use cases
- Read the documentation for API details
- Try different GNN backbones (GCN, GraphSAGE)
- Experiment with domain-specific text encoders

Happy hypernetworking! 🚀