# Protein Structure to Graph Converter

## Tool Description

This notebook converts protein structures (PDB format) into residue-level graph representations for geometric deep learning applications.

### Purpose
- **Input:** Directory containing one or more PDB files
- **Output:** PyTorch Geometric `Data` objects (saved as `.pt` files), one per residue-level graph
- **Graph Construction:** Distance cutoff-based edges between Cα atoms

### Graph Definition (Scientific Specification)

**Nodes:** One node per standard amino acid residue with a Cα atom present

**Node Features:**
- Amino acid one-hot encoding (20 standard amino acids)
- Residue index (PDB residue sequence number, which may contain gaps)
- Chain ID (encoded as integer)
- Cα xyz coordinates (3D spatial position)

**Edges:**
- Undirected edges when Cα-Cα distance ≤ cutoff threshold
- Edge features: Euclidean distance (Å)
- Bidirectional representation (both edge directions stored)
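
As a minimal illustration of the edge rule above, the following sketch applies a cutoff to three made-up Cα positions using only the standard library. The coordinates, the `toy_coords` name, and the `cutoff_edges` helper are illustrative assumptions and not part of the notebook's pipeline.

```python
import math

# Hypothetical Cα coordinates (Å) for three residues; values are made up.
toy_coords = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (20.0, 0.0, 0.0)]

def cutoff_edges(coords, cutoff):
    """Return directed (src, dst, distance) triples for pairs within cutoff."""
    edges = []
    for i, a in enumerate(coords):
        for j, b in enumerate(coords):
            if i == j:
                continue  # no self-loops
            d = math.dist(a, b)
            if d <= cutoff:
                edges.append((i, j, d))
    return edges

edges = cutoff_edges(toy_coords, cutoff=8.0)
for src, dst, d in edges:
    print(f"{src} -> {dst}: {d:.1f} Å")
```

Only residues 0 and 1 are within 8 Å of each other, and the pair is stored once per direction.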

### Two Graph Generation Modes

**1. Monomer Mode:**
- One graph per chain
- Only intra-chain edges
- Output: `{protein_id}_{chain_id}.pt` (for example `protein_A.pt`)

**2. Complex Mode:**
- One graph per structure (all chains combined)
- Includes inter-chain edges if within cutoff
- Additional edge feature: `same_chain_flag` (1 = same chain, 0 = different chains)
- Output: `protein_complex.pt`
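
The difference between the two modes can be sketched on toy data. The example below assumes four made-up residues across two chains and a hypothetical `complex_edges` helper; it is not the notebook's implementation, only an illustration of how `same_chain_flag` is derived.

```python
import math

# Hypothetical residues: (chain_id, Cα xyz in Å). Values are made up.
residues = [
    ("A", (0.0, 0.0, 0.0)),
    ("A", (3.8, 0.0, 0.0)),
    ("B", (3.8, 5.0, 0.0)),
    ("B", (7.6, 5.0, 0.0)),
]
CUTOFF = 8.0

def complex_edges(residues, cutoff):
    """Directed edges over all chains: (src, dst, distance, same_chain_flag)."""
    out = []
    for i, (chain_i, xyz_i) in enumerate(residues):
        for j, (chain_j, xyz_j) in enumerate(residues):
            d = math.dist(xyz_i, xyz_j)
            if i != j and d <= cutoff:
                out.append((i, j, d, int(chain_i == chain_j)))
    return out

all_edges = complex_edges(residues, CUTOFF)
intra_chain = [e for e in all_edges if e[3] == 1]
inter_chain = [e for e in all_edges if e[3] == 0]
print(len(all_edges), len(intra_chain), len(inter_chain))
```

In monomer mode each chain is processed on its own, so only the intra-chain edges would be produced.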

### Usage Instructions

1. **Edit the User Configuration cell** below with your paths and parameters
2. **Run all cells** sequentially (Cell → Run All)
3. **Check outputs** in the specified OUTPUT_DIR

### Output Files

```
OUTPUT_DIR/
├── graphs_monomer/              # Per-chain graphs
├── graphs_complex/              # Multi-chain complex graphs
├── dataset_summary_monomer.csv
├── dataset_summary_complex.csv
├── failed_structures.txt
└── graph_definition.json
```

---

## 1. Installation & Imports

In [None]:
# Install required packages (if not already installed)
# Uncomment the following line if running in Google Colab or a fresh environment
# !pip install biopython torch torch-geometric numpy pandas tqdm

In [None]:
import os
import json
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Set
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from tqdm.auto import tqdm

from Bio import BiopythonWarning
from Bio.PDB import PDBParser, PPBuilder
from Bio.PDB.Structure import Structure
from Bio.PDB.Chain import Chain
from Bio.PDB.Residue import Residue

# Suppress Biopython warnings for cleaner output
warnings.simplefilter('ignore', BiopythonWarning)

print("✓ All imports successful")
print(f"  PyTorch version: {torch.__version__}")
print(f"  NumPy version: {np.__version__}")
print(f"  Pandas version: {pd.__version__}")

## 2. Version & Reproducibility

In [None]:
# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"✓ Random seed set to {RANDOM_SEED}")
print(f"  Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 3. User Configuration

**⚠️ EDIT THIS CELL WITH YOUR PARAMETERS**

In [None]:
# ============================================================
# USER CONFIGURATION - EDIT THIS CELL ONLY
# ============================================================

# Input/Output Directories
INPUT_DIR = "data/pdb_files/"           # Folder containing PDB files
OUTPUT_DIR = "data/graphs/"             # Output directory for graphs

# Graph Construction Parameters
DISTANCE_CUTOFF_ANGSTROM = 8.0          # Cα-Cα distance cutoff (Å)

# Structure Filtering Options
REMOVE_WATER = True                      # Remove water residues
REMOVE_HETERO = True                     # Remove heteroatoms (ligands, ions)
MIN_RESIDUES = 10                        # Minimum residues per chain (monomer mode) or per structure (complex mode)

# Output Format
SAVE_FORMAT = "pyg"                      # 'pyg' for PyTorch Geometric

# ============================================================
# END OF USER CONFIGURATION
# ============================================================

# Validate configuration
assert Path(INPUT_DIR).exists(), f"INPUT_DIR does not exist: {INPUT_DIR}"
assert DISTANCE_CUTOFF_ANGSTROM > 0, "DISTANCE_CUTOFF_ANGSTROM must be positive"
assert MIN_RESIDUES > 0, "MIN_RESIDUES must be positive"
assert SAVE_FORMAT in ["pyg"], f"SAVE_FORMAT must be 'pyg', got {SAVE_FORMAT}"

# Create output directories
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
Path(OUTPUT_DIR, "graphs_monomer").mkdir(exist_ok=True)
Path(OUTPUT_DIR, "graphs_complex").mkdir(exist_ok=True)

print("✓ Configuration validated")
print(f"  Input directory: {INPUT_DIR}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Distance cutoff: {DISTANCE_CUTOFF_ANGSTROM} Å")
print(f"  Min residues: {MIN_RESIDUES}")
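
`assert` statements are skipped when Python runs with the `-O` flag, so a variant that raises explicit exceptions may be preferable in scripted runs. The sketch below is one possible approach; the `validate_config` function is hypothetical and not part of the notebook.

```python
from pathlib import Path

def validate_config(input_dir: str, cutoff: float, min_residues: int, save_format: str) -> None:
    """Raise a descriptive error for each invalid setting (hypothetical helper)."""
    if not Path(input_dir).is_dir():
        raise FileNotFoundError(f"INPUT_DIR is not a directory: {input_dir}")
    if cutoff <= 0:
        raise ValueError("DISTANCE_CUTOFF_ANGSTROM must be positive")
    if min_residues <= 0:
        raise ValueError("MIN_RESIDUES must be positive")
    if save_format != "pyg":
        raise ValueError(f"SAVE_FORMAT must be 'pyg', got {save_format!r}")
```

It could be called as `validate_config(INPUT_DIR, DISTANCE_CUTOFF_ANGSTROM, MIN_RESIDUES, SAVE_FORMAT)` in place of the assertions.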

## 4. Amino Acid Encoding Tables

In [None]:
# Standard 20 amino acids (3-letter to 1-letter code)
AA_3_TO_1 = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
    'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
    'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
}

# Mapping amino acids to indices for one-hot encoding
AA_TO_INDEX = {aa: idx for idx, aa in enumerate(sorted(AA_3_TO_1.keys()))}
INDEX_TO_AA = {idx: aa for aa, idx in AA_TO_INDEX.items()}

NUM_AA_TYPES = len(AA_TO_INDEX)

print(f"✓ Amino acid encoding tables created")
print(f"  Number of standard amino acids: {NUM_AA_TYPES}")
print(f"  Example: {list(AA_TO_INDEX.items())[:3]}")
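
Because the indices come from sorting the three-letter codes, the one-hot positions are alphabetical by three-letter code (ALA=0, ARG=1, and so on up to VAL=19), not by one-letter code. The following sketch rebuilds the same mapping independently to show encoding and decoding; `one_hot` and `decode` are hypothetical helper names.

```python
# Rebuild the mapping the same way as above: sort the 20 three-letter codes.
CODES = sorted([
    "ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU",
    "MET", "ASN", "PRO", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR",
])
aa_to_index = {aa: idx for idx, aa in enumerate(CODES)}

def one_hot(resname: str) -> list:
    """Return a 20-element one-hot list for a three-letter residue name."""
    vec = [0.0] * len(CODES)
    vec[aa_to_index[resname]] = 1.0
    return vec

def decode(vec: list) -> str:
    """Return the three-letter code for a one-hot list."""
    return CODES[vec.index(1.0)]

print(aa_to_index["ALA"], aa_to_index["ARG"], aa_to_index["VAL"])
print(decode(one_hot("GLY")))
```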

## 5. PDB Parsing Utilities

In [None]:
@dataclass
class ResidueData:
    """Container for residue information."""
    residue_name: str
    residue_index: int
    chain_id: str
    ca_coords: np.ndarray
    

def is_valid_residue(residue: Residue, remove_water: bool, remove_hetero: bool) -> bool:
    """
    Check if a residue should be included in the graph.
    
    Args:
        residue: BioPython Residue object
        remove_water: Whether to exclude water molecules
        remove_hetero: Whether to exclude heteroatoms
        
    Returns:
        bool: True if residue is valid for inclusion
    """
    hetfield = residue.get_id()[0]
    
    # Check if water
    if remove_water and hetfield == 'W':
        return False
    
    # Check if heteroatom (ligands, ions, modified residues)
    if remove_hetero and hetfield.startswith('H'):
        return False
    
    # Must be a standard amino acid
    if residue.get_resname() not in AA_TO_INDEX:
        return False
    
    # Must have Cα atom
    if 'CA' not in residue:
        return False
    
    return True


def parse_pdb_structure(pdb_path: str, remove_water: bool, remove_hetero: bool) -> Optional[Structure]:
    """
    Parse a PDB file and return the structure.
    
    Args:
        pdb_path: Path to PDB file
        remove_water: Unused here; residue filtering happens in is_valid_residue
        remove_hetero: Unused here; residue filtering happens in is_valid_residue
        
    Returns:
        Structure object or None if parsing fails
    """
    parser = PDBParser(QUIET=True)
    
    try:
        structure_id = Path(pdb_path).stem
        structure = parser.get_structure(structure_id, pdb_path)
        return structure
    except Exception as e:
        print(f"  ✗ Failed to parse {pdb_path}: {str(e)}")
        return None


def extract_residue_data(
    chain: Chain,
    remove_water: bool,
    remove_hetero: bool
) -> List[ResidueData]:
    """
    Extract residue data from a protein chain.
    
    Args:
        chain: BioPython Chain object
        remove_water: Whether to exclude water molecules
        remove_hetero: Whether to exclude heteroatoms
        
    Returns:
        List of ResidueData objects
    """
    residue_list = []
    
    for residue in chain:
        if not is_valid_residue(residue, remove_water, remove_hetero):
            continue
        
        # Extract Cα coordinates
        ca_atom = residue['CA']
        ca_coords = ca_atom.get_coord()
        
        # Store residue data
        residue_data = ResidueData(
            residue_name=residue.get_resname(),
            residue_index=residue.get_id()[1],  # Residue sequence number
            chain_id=chain.get_id(),
            ca_coords=ca_coords
        )
        residue_list.append(residue_data)
    
    return residue_list


print("✓ PDB parsing utilities defined")
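
For context, `is_valid_residue` relies on the first element of Biopython's residue ID tuple: a blank string for standard residues, `'W'` for water, and `'H_'` followed by the residue name for other hetero residues. The sketch below mimics that check on plain tuples so it runs without Biopython; the `classify_hetfield` helper and the sample IDs are illustrative assumptions.

```python
def classify_hetfield(residue_id: tuple) -> str:
    """Classify a (hetfield, resseq, icode) tuple the way the filter above does."""
    hetfield = residue_id[0]
    if hetfield == "W":
        return "water"
    if hetfield.startswith("H"):
        return "hetero"
    return "standard"

# Sample IDs in Biopython's (hetfield, resseq, icode) layout; values are made up.
sample_ids = [(" ", 42, " "), ("W", 301, " "), ("H_GLC", 500, " ")]
for rid in sample_ids:
    print(rid, "->", classify_hetfield(rid))
```

Modified residues recorded as HETATM, such as selenomethionine (`MSE`), are dropped by the standard-amino-acid check even when `REMOVE_HETERO` is `False`.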

## 6. Residue Feature Builder

In [None]:
def build_node_features(
    residues: List[ResidueData],
    chain_to_idx: Optional[Dict[str, int]] = None
) -> np.ndarray:
    """
    Build node feature matrix from residue data.
    
    Features per node:
    - [0:20]: Amino acid one-hot encoding
    - [20]: Residue index
    - [21]: Chain ID (encoded)
    - [22:25]: Cα xyz coordinates
    
    Args:
        residues: List of ResidueData objects
        chain_to_idx: Optional mapping from chain_id to integer index
        
    Returns:
        Node feature matrix of shape (num_residues, 25)
    """
    num_residues = len(residues)
    node_features = np.zeros((num_residues, NUM_AA_TYPES + 5), dtype=np.float32)
    
    # Create chain encoding if not provided
    if chain_to_idx is None:
        unique_chains = sorted(set(r.chain_id for r in residues))
        chain_to_idx = {chain: idx for idx, chain in enumerate(unique_chains)}
    
    for i, res in enumerate(residues):
        # One-hot encode amino acid
        aa_idx = AA_TO_INDEX[res.residue_name]
        node_features[i, aa_idx] = 1.0
        
        # Residue index
        node_features[i, NUM_AA_TYPES] = res.residue_index
        
        # Chain ID (encoded)
        node_features[i, NUM_AA_TYPES + 1] = chain_to_idx[res.chain_id]
        
        # Cα coordinates
        node_features[i, NUM_AA_TYPES + 2:NUM_AA_TYPES + 5] = res.ca_coords
    
    return node_features


print("✓ Node feature builder defined")
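
To recover the individual fields from a feature matrix, the column layout documented in the docstring can be sliced directly. A minimal NumPy sketch with a made-up two-residue matrix follows; `split_node_features` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def split_node_features(x: np.ndarray):
    """Split an (N, 25) feature matrix into its documented fields."""
    aa_index = x[:, 0:20].argmax(axis=1)      # one-hot -> class index
    residue_number = x[:, 20].astype(int)     # stored as float32 in the matrix
    chain_index = x[:, 21].astype(int)
    ca_xyz = x[:, 22:25]
    return aa_index, residue_number, chain_index, ca_xyz

# Made-up matrix: residue 0 has class 0 (ALA), residue 1 has class 7 (GLY).
x = np.zeros((2, 25), dtype=np.float32)
x[0, 0] = 1.0
x[1, 7] = 1.0
x[:, 20] = [10, 11]
x[:, 21] = [0, 0]
x[:, 22:25] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

aa_index, residue_number, chain_index, ca_xyz = split_node_features(x)
print(aa_index, residue_number, ca_xyz.shape)
```

Raw residue numbers and coordinates share the matrix with the one-hot columns, so whether to normalize or separate them before training is a modeling choice left to the user.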

## 7. Distance Computation Function

In [None]:
def compute_distance_matrix(ca_coords: np.ndarray) -> np.ndarray:
    """
    Compute pairwise Euclidean distance matrix between Cα atoms.
    
    Args:
        ca_coords: Cα coordinates array of shape (num_residues, 3)
        
    Returns:
        Distance matrix of shape (num_residues, num_residues)
    """
    # Efficient vectorized distance computation
    diff = ca_coords[:, np.newaxis, :] - ca_coords[np.newaxis, :, :]
    distances = np.sqrt(np.sum(diff ** 2, axis=2))
    return distances


print("✓ Distance computation function defined")
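
The broadcasting expression above builds an intermediate array of shape `(N, N, 3)`, so memory use grows quadratically with the number of residues. As a sanity check, the sketch below compares the same broadcasting formula with an explicit double loop on made-up coordinates; `pairwise_broadcast` and `pairwise_loop` are hypothetical names.

```python
import math
import numpy as np

def pairwise_broadcast(coords: np.ndarray) -> np.ndarray:
    """Same formula as above: broadcast differences, then Euclidean norm."""
    diff = coords[:, np.newaxis, :] - coords[np.newaxis, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=2))

def pairwise_loop(coords: np.ndarray) -> np.ndarray:
    """Reference implementation with explicit loops (slow, for checking only)."""
    n = len(coords)
    out = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            out[i, j] = math.sqrt(sum((a - b) ** 2 for a, b in zip(coords[i], coords[j])))
    return out

# Made-up coordinates chosen so the distances are round numbers.
coords = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 12.0]])
d = pairwise_broadcast(coords)
print(d[0, 1], d[0, 2], d[1, 2])
```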

## 8. Cutoff Edge Builder

In [None]:
def build_edges_from_cutoff(
    distance_matrix: np.ndarray,
    cutoff: float,
    self_loops: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build edge index and edge attributes based on distance cutoff.
    
    Creates undirected edges (both directions) for residue pairs within cutoff.
    
    Args:
        distance_matrix: Pairwise distance matrix
        cutoff: Maximum distance for edge creation (Angstroms)
        self_loops: Whether to include self-loops (default: False)
        
    Returns:
        edge_index: Edge connectivity (2, num_edges)
        edge_attr: Edge features (num_edges, 1) containing distances
    """
    num_nodes = distance_matrix.shape[0]
    
    # Find edges within cutoff
    mask = distance_matrix <= cutoff
    if not self_loops:
        # Remove only the diagonal; a `distance > 0` test would also drop
        # distinct residues whose Cα coordinates happen to coincide
        np.fill_diagonal(mask, False)
    
    # Get edge indices
    src, dst = np.where(mask)
    
    # Edge features (distances)
    edge_distances = distance_matrix[src, dst]
    
    # Create edge index (PyTorch Geometric format)
    edge_index = np.stack([src, dst], axis=0)
    edge_attr = edge_distances.reshape(-1, 1)
    
    return edge_index, edge_attr


def build_edges_with_chain_info(
    distance_matrix: np.ndarray,
    chain_ids: np.ndarray,
    cutoff: float,
    self_loops: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build edges with additional chain information for multi-chain graphs.
    
    Args:
        distance_matrix: Pairwise distance matrix
        chain_ids: Array of chain IDs for each node
        cutoff: Maximum distance for edge creation
        self_loops: Whether to include self-loops
        
    Returns:
        edge_index: Edge connectivity (2, num_edges)
        edge_attr: Edge features (num_edges, 2) [distance, same_chain_flag]
    """
    # Build basic edges
    edge_index, edge_distances = build_edges_from_cutoff(
        distance_matrix, cutoff, self_loops
    )
    
    # Determine if edges are intra-chain or inter-chain
    src_chains = chain_ids[edge_index[0]]
    dst_chains = chain_ids[edge_index[1]]
    same_chain = (src_chains == dst_chains).astype(np.float32).reshape(-1, 1)
    
    # Combine edge features
    edge_attr = np.hstack([edge_distances, same_chain])
    
    return edge_index, edge_attr


print("✓ Edge builder functions defined")
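
Since the mask is applied to a symmetric distance matrix, every edge `(i, j)` should appear together with `(j, i)`. The check below is a sketch on a made-up 3×3 matrix; it reproduces the masking step with NumPy only, and `is_symmetric_edge_index` is a hypothetical helper.

```python
import numpy as np

def is_symmetric_edge_index(edge_index: np.ndarray) -> bool:
    """True if every directed edge (i, j) has a matching (j, i)."""
    pairs = set(map(tuple, edge_index.T.tolist()))
    return all((j, i) in pairs for i, j in pairs)

# Made-up symmetric distance matrix (Å) for three residues.
dist = np.array([
    [0.0, 3.8, 9.5],
    [3.8, 0.0, 6.1],
    [9.5, 6.1, 0.0],
])
mask = (dist <= 8.0) & ~np.eye(len(dist), dtype=bool)  # drop the diagonal
src, dst = np.where(mask)
edge_index = np.stack([src, dst], axis=0)

print(edge_index.tolist())
print(is_symmetric_edge_index(edge_index))
```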

## 9. Graph Builder Function

In [None]:
def build_protein_graph(
    residues: List[ResidueData],
    cutoff: float,
    include_chain_info: bool = False,
    chain_to_idx: Optional[Dict[str, int]] = None
) -> Data:
    """
    Build a PyTorch Geometric graph from residue data.
    
    Args:
        residues: List of ResidueData objects
        cutoff: Distance cutoff for edge creation (Angstroms)
        include_chain_info: Whether to include chain info in edge features
        chain_to_idx: Optional chain ID to index mapping
        
    Returns:
        PyTorch Geometric Data object
    """
    # Build node features
    node_features = build_node_features(residues, chain_to_idx)
    
    # Extract Cα coordinates for distance computation
    ca_coords = np.array([r.ca_coords for r in residues])
    
    # Compute distance matrix
    distance_matrix = compute_distance_matrix(ca_coords)
    
    # Build edges
    if include_chain_info:
        chain_ids = node_features[:, NUM_AA_TYPES + 1]  # Extract chain IDs
        edge_index, edge_attr = build_edges_with_chain_info(
            distance_matrix, chain_ids, cutoff, self_loops=False
        )
    else:
        edge_index, edge_attr = build_edges_from_cutoff(
            distance_matrix, cutoff, self_loops=False
        )
    
    # Convert to PyTorch tensors
    x = torch.from_numpy(node_features).float()
    edge_index = torch.from_numpy(edge_index).long()
    edge_attr = torch.from_numpy(edge_attr).float()
    
    # Create PyTorch Geometric Data object
    graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        num_nodes=len(residues)
    )
    
    return graph


print("✓ Graph builder function defined")

## 10. Monomer Graph Pipeline

### Generate Single-Chain Graphs

In [None]:
def process_monomer_graphs(
    input_dir: str,
    output_dir: str,
    cutoff: float,
    remove_water: bool,
    remove_hetero: bool,
    min_residues: int
) -> Tuple[List[Dict], List[str]]:
    """
    Process all PDB files and generate monomer (single-chain) graphs.
    
    Args:
        input_dir: Directory containing PDB files
        output_dir: Directory to save graphs
        cutoff: Distance cutoff (Angstroms)
        remove_water: Whether to exclude water
        remove_hetero: Whether to exclude heteroatoms
        min_residues: Minimum residues per chain
        
    Returns:
        summary_data: List of dictionaries with graph statistics
        failed_structures: List of failed structure IDs with reasons
    """
    pdb_files = sorted(Path(input_dir).glob("*.pdb"))  # sorted for a reproducible processing order
    
    if len(pdb_files) == 0:
        print(f"⚠ No PDB files found in {input_dir}")
        return [], []
    
    print(f"\n{'='*60}")
    print(f"MONOMER GRAPH GENERATION")
    print(f"{'='*60}")
    print(f"Found {len(pdb_files)} PDB files\n")
    
    summary_data = []
    failed_structures = []
    
    output_subdir = Path(output_dir) / "graphs_monomer"
    output_subdir.mkdir(exist_ok=True)
    
    for pdb_file in tqdm(pdb_files, desc="Processing structures"):
        protein_id = pdb_file.stem
        
        # Parse structure
        structure = parse_pdb_structure(str(pdb_file), remove_water, remove_hetero)
        if structure is None:
            failed_structures.append(f"{protein_id}: Failed to parse PDB")
            continue
        
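        # Process each chain independently, using the first model only so that
        # multi-model files (e.g. NMR ensembles) do not overwrite the same output
        for model in list(structure)[:1]: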
            for chain in model:
                chain_id = chain.get_id()
                
                # Extract residues
                residues = extract_residue_data(chain, remove_water, remove_hetero)
                
                # Skip chains that are too short
                if len(residues) < min_residues:
                    failed_structures.append(
                        f"{protein_id}_chain{chain_id}: Only {len(residues)} residues (< {min_residues})"
                    )
                    continue
                
                # Build graph
                try:
                    graph = build_protein_graph(
                        residues,
                        cutoff=cutoff,
                        include_chain_info=False
                    )
                    
                    # Save graph
                    output_path = output_subdir / f"{protein_id}_{chain_id}.pt"
                    torch.save(graph, output_path)
                    
                    # Record statistics
                    summary_data.append({
                        'protein_id': protein_id,
                        'chain_id': chain_id,
                        'num_nodes': graph.num_nodes,
                        'num_edges': graph.num_edges,
                    })
                    
                except Exception as e:
                    failed_structures.append(
                        f"{protein_id}_chain{chain_id}: Graph build failed - {str(e)}"
                    )
    
    print(f"\n✓ Monomer graphs generated: {len(summary_data)}")
    print(f"✗ Failed structures: {len(failed_structures)}")
    
    return summary_data, failed_structures


print("✓ Monomer graph pipeline defined")

### Run Monomer Graph Generation

In [None]:
# Generate monomer graphs
monomer_summary, monomer_failed = process_monomer_graphs(
    input_dir=INPUT_DIR,
    output_dir=OUTPUT_DIR,
    cutoff=DISTANCE_CUTOFF_ANGSTROM,
    remove_water=REMOVE_WATER,
    remove_hetero=REMOVE_HETERO,
    min_residues=MIN_RESIDUES
)

## 11. Complex (Multi-Chain) Graph Pipeline

### Generate Multi-Chain Complex Graphs

In [None]:
def process_complex_graphs(
    input_dir: str,
    output_dir: str,
    cutoff: float,
    remove_water: bool,
    remove_hetero: bool,
    min_residues: int
) -> Tuple[List[Dict], List[str]]:
    """
    Process all PDB files and generate complex (multi-chain) graphs.
    
    Args:
        input_dir: Directory containing PDB files
        output_dir: Directory to save graphs
        cutoff: Distance cutoff (Angstroms)
        remove_water: Whether to exclude water
        remove_hetero: Whether to exclude heteroatoms
        min_residues: Minimum total residues in structure
        
    Returns:
        summary_data: List of dictionaries with graph statistics
        failed_structures: List of failed structure IDs with reasons
    """
    pdb_files = sorted(Path(input_dir).glob("*.pdb"))  # sorted for a reproducible processing order
    
    if len(pdb_files) == 0:
        print(f"⚠ No PDB files found in {input_dir}")
        return [], []
    
    print(f"\n{'='*60}")
    print(f"COMPLEX (MULTI-CHAIN) GRAPH GENERATION")
    print(f"{'='*60}")
    print(f"Found {len(pdb_files)} PDB files\n")
    
    summary_data = []
    failed_structures = []
    
    output_subdir = Path(output_dir) / "graphs_complex"
    output_subdir.mkdir(exist_ok=True)
    
    for pdb_file in tqdm(pdb_files, desc="Processing structures"):
        protein_id = pdb_file.stem
        
        # Parse structure
        structure = parse_pdb_structure(str(pdb_file), remove_water, remove_hetero)
        if structure is None:
            failed_structures.append(f"{protein_id}: Failed to parse PDB")
            continue
        
        # Combine all chains in the structure
        all_residues = []
        chain_ids = []
        
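        # Use the first model only; merging every model of a multi-model file
        # (e.g. an NMR ensemble) would superimpose copies of the same residues
        for model in list(structure)[:1]: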
            for chain in model:
                chain_id = chain.get_id()
                residues = extract_residue_data(chain, remove_water, remove_hetero)
                
                if len(residues) > 0:
                    all_residues.extend(residues)
                    chain_ids.append(chain_id)
        
        # Skip if structure has too few residues
        if len(all_residues) < min_residues:
            failed_structures.append(
                f"{protein_id}: Only {len(all_residues)} total residues (< {min_residues})"
            )
            continue
        
        # Build complex graph
        try:
            # Create chain ID to index mapping
            unique_chains = sorted(set(chain_ids))
            chain_to_idx = {chain: idx for idx, chain in enumerate(unique_chains)}
            
            graph = build_protein_graph(
                all_residues,
                cutoff=cutoff,
                include_chain_info=True,
                chain_to_idx=chain_to_idx
            )
            
            # Save graph
            output_path = output_subdir / f"{protein_id}_complex.pt"
            torch.save(graph, output_path)
            
            # Count inter-chain edges
            same_chain_flags = graph.edge_attr[:, 1]
            num_inter_chain = int((same_chain_flags == 0).sum().item())
            
            # Record statistics
            summary_data.append({
                'protein_id': protein_id,
                'num_chains': len(unique_chains),
                'num_nodes': graph.num_nodes,
                'num_edges': graph.num_edges,
                'num_inter_chain_edges': num_inter_chain,
            })
            
        except Exception as e:
            failed_structures.append(
                f"{protein_id}: Graph build failed - {str(e)}"
            )
    
    print(f"\n✓ Complex graphs generated: {len(summary_data)}")
    print(f"✗ Failed structures: {len(failed_structures)}")
    
    return summary_data, failed_structures


print("✓ Complex graph pipeline defined")
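
Because both edge directions are stored, `num_inter_chain_edges` counts directed edges, so the number of unique inter-chain residue contacts is half that value. A small NumPy sketch with a made-up edge list illustrates the relationship; the array contents are illustrative assumptions.

```python
import numpy as np

# Made-up directed edges for 4 residues: nodes 0-1 in chain 0, nodes 2-3 in chain 1.
edge_index = np.array([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2],
])
chain_of_node = np.array([0, 0, 1, 1])

same_chain = chain_of_node[edge_index[0]] == chain_of_node[edge_index[1]]
num_inter_directed = int((~same_chain).sum())

# Count each undirected contact once by keeping only src < dst.
upper = edge_index[0] < edge_index[1]
num_inter_contacts = int((~same_chain & upper).sum())

print(num_inter_directed, num_inter_contacts)
```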

### Run Complex Graph Generation

In [None]:
# Generate complex graphs
complex_summary, complex_failed = process_complex_graphs(
    input_dir=INPUT_DIR,
    output_dir=OUTPUT_DIR,
    cutoff=DISTANCE_CUTOFF_ANGSTROM,
    remove_water=REMOVE_WATER,
    remove_hetero=REMOVE_HETERO,
    min_residues=MIN_RESIDUES
)

## 12. Save Summary Data & Logs

In [None]:
# Save monomer summary
if len(monomer_summary) > 0:
    df_monomer = pd.DataFrame(monomer_summary)
    monomer_csv_path = Path(OUTPUT_DIR) / "dataset_summary_monomer.csv"
    df_monomer.to_csv(monomer_csv_path, index=False)
    print(f"✓ Saved monomer summary: {monomer_csv_path}")
    
    # Display statistics
    print("\nMonomer Graph Statistics:")
    print(df_monomer.describe())
else:
    print("⚠ No monomer graphs generated")

# Save complex summary
if len(complex_summary) > 0:
    df_complex = pd.DataFrame(complex_summary)
    complex_csv_path = Path(OUTPUT_DIR) / "dataset_summary_complex.csv"
    df_complex.to_csv(complex_csv_path, index=False)
    print(f"\n✓ Saved complex summary: {complex_csv_path}")
    
    # Display statistics
    print("\nComplex Graph Statistics:")
    print(df_complex.describe())
else:
    print("⚠ No complex graphs generated")

# Save failed structures log
all_failed = monomer_failed + complex_failed
if len(all_failed) > 0:
    failed_path = Path(OUTPUT_DIR) / "failed_structures.txt"
    with open(failed_path, 'w') as f:
        f.write("Failed Structures and Reasons\n")
        f.write("="*60 + "\n\n")
        for item in all_failed:
            f.write(f"{item}\n")
    print(f"\n✓ Saved failed structures log: {failed_path}")
else:
    print("\n✓ All structures processed successfully")
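
Each log entry has the form `<id>: <reason>`, so the file can be read back into pairs for triage. The sketch below assumes the two header lines written above and uses made-up entries; `parse_failed_log` is a hypothetical helper.

```python
from collections import Counter

def parse_failed_log(text: str) -> list:
    """Return (structure_id, reason) pairs, skipping the header and blank lines."""
    entries = []
    for line in text.splitlines()[2:]:          # first two lines are the header
        if not line.strip():
            continue
        structure_id, _, reason = line.partition(": ")
        entries.append((structure_id, reason))
    return entries

# Made-up log content in the format written by the cell above.
sample = (
    "Failed Structures and Reasons\n"
    + "=" * 60 + "\n\n"
    + "1abc_chainB: Only 4 residues (< 10)\n"
    + "2xyz: Failed to parse PDB\n"
)
entries = parse_failed_log(sample)
print(entries)
print(Counter(reason.split(" ")[0] for _, reason in entries))
```

The log combines monomer and complex failures, so the same structure can appear more than once.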

## 13. Save Graph Definition (Scientific Transparency)

In [None]:
# Create graph definition metadata
graph_definition = {
    "version": "1.0",
    "created_date": pd.Timestamp.now().isoformat(),
    "graph_specification": {
        "node_definition": "One node per residue with Cα atom present",
        "edge_rule": f"Undirected edges when Cα-Cα distance ≤ {DISTANCE_CUTOFF_ANGSTROM} Å",
        "distance_cutoff_angstrom": DISTANCE_CUTOFF_ANGSTROM,
        "node_features": {
            "amino_acid_onehot": "20 standard amino acids (indices 0-19)",
            "residue_index": "PDB residue sequence number (index 20)",
            "chain_id": "Encoded chain identifier (index 21)",
            "ca_coordinates": "xyz coordinates in Angstroms (indices 22-24)"
        },
        "edge_features": {
            "distance": "Euclidean distance between Cα atoms (Angstroms)",
            "same_chain_flag": "Binary flag (complex mode only): 1=same chain, 0=different chains"
        }
    },
    "processing_parameters": {
        "remove_water": REMOVE_WATER,
        "remove_hetero": REMOVE_HETERO,
        "min_residues": MIN_RESIDUES
    },
    "amino_acid_encoding": AA_TO_INDEX,
    "statistics": {
        "total_monomer_graphs": len(monomer_summary),
        "total_complex_graphs": len(complex_summary),
        "failed_structures": len(all_failed)
    }
}

# Save graph definition
definition_path = Path(OUTPUT_DIR) / "graph_definition.json"
with open(definition_path, 'w') as f:
    json.dump(graph_definition, f, indent=2)

print(f"✓ Saved graph definition: {definition_path}")
print("\n" + "="*60)
print("GRAPH SPECIFICATION SUMMARY")
print("="*60)
print(f"Node definition: One node per residue with Cα atom")
print(f"Distance cutoff: {DISTANCE_CUTOFF_ANGSTROM} Å")
print(f"Edge rule: Undirected edges when Cα-Cα distance ≤ cutoff")
print(f"\nNode features ({NUM_AA_TYPES + 5} total):")
print(f"  - Amino acid one-hot (20)")
print(f"  - Residue index (1)")
print(f"  - Chain ID encoded (1)")
print(f"  - Cα coordinates xyz (3)")
print(f"\nEdge features:")
print(f"  - Monomer mode: Distance (1)")
print(f"  - Complex mode: Distance + same_chain_flag (2)")
print("="*60)
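
Downstream code can read `graph_definition.json` to confirm that a dataset was built with the expected settings. The sketch below writes a trimmed, made-up definition to a temporary directory and reads it back; the `load_cutoff` helper and the temporary path are assumptions for illustration.

```python
import json
import tempfile
from pathlib import Path

def load_cutoff(definition_path: Path) -> float:
    """Read the distance cutoff (Å) recorded in a graph definition file."""
    with open(definition_path, encoding="utf-8") as f:
        definition = json.load(f)
    return float(definition["graph_specification"]["distance_cutoff_angstrom"])

# Write a trimmed, made-up definition with the same key layout as above.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "graph_definition.json"
    path.write_text(
        json.dumps({"version": "1.0",
                    "graph_specification": {"distance_cutoff_angstrom": 8.0}}),
        encoding="utf-8",
    )
    cutoff = load_cutoff(path)

print(cutoff)
```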

## 14. Example: Load and Inspect a Graph

In [None]:
# Example: Load a generated graph and inspect its properties
print("\nExample: Loading and inspecting a graph\n")

# Try to load a monomer graph
monomer_graphs = list(Path(OUTPUT_DIR, "graphs_monomer").glob("*.pt"))
if len(monomer_graphs) > 0:
    example_graph_path = monomer_graphs[0]
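    # PyTorch >= 2.6 defaults to weights_only=True, which can reject pickled
    # PyG Data objects; only pass weights_only=False for files you trust
    graph = torch.load(example_graph_path, weights_only=False)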
    
    print(f"Loaded graph: {example_graph_path.name}")
    print(f"  Number of nodes (residues): {graph.num_nodes}")
    print(f"  Number of edges: {graph.num_edges}")
    print(f"  Node feature shape: {graph.x.shape}")
    print(f"  Edge index shape: {graph.edge_index.shape}")
    print(f"  Edge attribute shape: {graph.edge_attr.shape}")
    
    # Show feature breakdown
    print(f"\n  Node features breakdown:")
    print(f"    - Amino acid one-hot: [:, 0:20]")
    print(f"    - Residue index: [:, 20]")
    print(f"    - Chain ID: [:, 21]")
    print(f"    - Cα coordinates: [:, 22:25]")
    
    # Show edge statistics
    edge_distances = graph.edge_attr[:, 0]
    print(f"\n  Edge distance statistics (Å):")
    print(f"    - Min: {edge_distances.min():.2f}")
    print(f"    - Max: {edge_distances.max():.2f}")
    print(f"    - Mean: {edge_distances.mean():.2f}")
    print(f"    - Median: {edge_distances.median():.2f}")
else:
    print("No graphs generated to display example.")

print("\n" + "="*60)
print("PROCESSING COMPLETE")
print("="*60)
print(f"All outputs saved to: {OUTPUT_DIR}")

---

## Next Steps

After running this notebook, you can:

1. **Load graphs for machine learning:**
   ```python
   import torch
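   # weights_only=False is needed on PyTorch >= 2.6 to unpickle PyG Data objects; use only on trusted files
   graph = torch.load('data/graphs/graphs_monomer/protein_A.pt', weights_only=False)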
   ```

2. **Create PyTorch Geometric datasets:**
   ```python
   from torch_geometric.data import Dataset
   from torch_geometric.loader import DataLoader  # DataLoader lives in torch_geometric.loader since PyG 2.0
   # Build your custom dataset class
   ```

3. **Train graph neural networks:**
   - Use the generated graphs as input to GNN models
   - Tasks: protein function prediction, interaction prediction, etc.

4. **Analyze graph properties:**
   - Load the CSV summaries for statistical analysis
   - Visualize graph size distributions
   - Study inter-chain connectivity patterns
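
For the statistical analysis mentioned in step 4, the summary CSVs can be read with the standard library alone. The sketch below uses made-up rows with the same columns as `dataset_summary_monomer.csv`; the values and the `mean_degree` helper are illustrative assumptions.

```python
import csv
import io
import statistics

# Made-up rows with the columns written by the monomer pipeline.
sample_csv = """protein_id,chain_id,num_nodes,num_edges
1abc,A,100,900
1abc,B,50,400
2xyz,A,200,2000
"""

def mean_degree(row: dict) -> float:
    """Average number of stored (directed) edges per node."""
    return int(row["num_edges"]) / int(row["num_nodes"])

rows = list(csv.DictReader(io.StringIO(sample_csv)))
degrees = [mean_degree(r) for r in rows]
print(f"graphs: {len(rows)}")
print(f"median nodes: {statistics.median(int(r['num_nodes']) for r in rows)}")
print(f"mean degree per graph: {degrees}")
```

To analyze real output, replace `io.StringIO(sample_csv)` with an open file handle on the CSV in `OUTPUT_DIR`.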

---