# GraphRNN Tutorial for DeepChem

In this tutorial, we demonstrate how to use the GraphRNN model in DeepChem for graph-based learning tasks such as molecular graph generation.
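The original GraphRNN (You et al., 2018) treats a graph as a sequence. A graph-level RNN adds nodes one at a time, and an edge-level RNN decides how each new node connects to the nodes added before it. Because it processes node and edge information sequentially, it can be used both to generate new graphs and to predict properties of graph-structured data.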
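To make the sequential view concrete, here is a minimal NumPy sketch of the sequence representation described in the GraphRNN paper. It fixes a breadth-first-search (BFS) node ordering, then describes each new node by a binary vector of its connections to the previous nodes, keeping only the most recent `m` of them. The function names, the BFS start node, and the most-recent-first column order are illustrative assumptions. This is not DeepChem's API.

```python
from collections import deque

import numpy as np


def bfs_order(adj: np.ndarray, start: int = 0) -> list:
    """Return a BFS node ordering of an undirected graph given as a 0/1 adjacency matrix."""
    n = adj.shape[0]
    order, seen = [], set()
    # Restart BFS from any unvisited node so disconnected graphs are fully covered
    for root in [start] + list(range(n)):
        if root in seen:
            continue
        seen.add(root)
        queue = deque([root])
        while queue:
            u = queue.popleft()
            order.append(u)
            for v in np.flatnonzero(adj[u]).tolist():
                if v not in seen:
                    seen.add(v)
                    queue.append(v)
    return order


def graph_to_sequence(adj: np.ndarray, m: int, start: int = 0) -> np.ndarray:
    """Encode a graph as an (n - 1, m) array: row i - 1 holds node i's links to its m predecessors.

    Column 0 is the immediately preceding node in BFS order, column 1 the one before it, etc.
    """
    order = bfs_order(adj, start)
    a = adj[np.ix_(order, order)]  # adjacency matrix permuted into BFS order
    n = a.shape[0]
    seq = np.zeros((n - 1, m), dtype=int)
    for i in range(1, n):
        for k in range(min(m, i)):
            seq[i - 1, k] = a[i, i - 1 - k]
    return seq


def sequence_to_adjacency(seq: np.ndarray) -> np.ndarray:
    """Decode a sequence back into a symmetric adjacency matrix (in BFS order)."""
    n, m = seq.shape[0] + 1, seq.shape[1]
    a = np.zeros((n, n), dtype=int)
    for i in range(1, n):
        for k in range(min(m, i)):
            a[i, i - 1 - k] = a[i - 1 - k, i] = seq[i - 1, k]
    return a


# A 5-node path 0-1-2-3-4 with an extra edge 0-2
adj = np.zeros((5, 5), dtype=int)
for u, v in [(0, 1), (1, 2), (2, 3), (3, 4), (0, 2)]:
    adj[u, v] = adj[v, u] = 1

seq = graph_to_sequence(adj, m=2)  # seq == [[1, 0], [1, 1], [1, 0], [1, 0]]
# Here the BFS order is the identity, so decoding recovers adj exactly
restored = sequence_to_adjacency(seq)
```

The BFS ordering is what makes the truncation to `m` entries reasonable: under BFS, a new node can only link to nodes that are still close behind it in the ordering. Generation then amounts to predicting the rows of `seq` one at a time and decoding them.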

In this notebook, we will:
- Create a synthetic dataset using `GraphData` objects.
- Initialize and train the GraphRNN model.
- Evaluate the model using DeepChem metrics.
- Inspect the model's predictions.


```python
# Import necessary libraries and set environment variables if needed
import os

# Optionally disable DGL's GraphBolt component if it is not required.
# Environment variables like this are usually read when a library is imported,
# so set it before any import that may load DGL.
os.environ['DGL_GRAPHBOLT_DISABLE'] = '1'

import numpy as np
import torch
import deepchem as dc
from deepchem.feat.graph_data import GraphData
from deepchem.models.torch_models.graphrnn import GraphRNN

# Display version information
print("DeepChem version:", dc.__version__)
print("PyTorch version:", torch.__version__)
```

```python
# Create synthetic graph data
np.random.seed(0)  # fix the seed so the synthetic data is reproducible

def create_synthetic_graph(num_nodes=10, node_feature_dim=16, num_neighbors=5, edge_feature_dim=8):
    """
    Create a synthetic GraphData object with random node and edge features.

    Each node i has directed edges to the next `num_neighbors` nodes
    (with wrap-around), giving num_nodes * num_neighbors edges in total.
    """
    # Random node features; shape: (num_nodes, node_feature_dim)
    node_features = np.random.randn(num_nodes, node_feature_dim)
    # Adjacency list: node i connects to nodes i+1, ..., i+num_neighbors (mod num_nodes)
    adjacency_list = {i: [(i + j + 1) % num_nodes for j in range(num_neighbors)] for i in range(num_nodes)}
    # Build edge_index (shape: (2, num_edges)) from the adjacency list; edges are grouped by source node
    edge_index_0, edge_index_1 = [], []
    for i in range(num_nodes):
        for neigh in adjacency_list[i]:
            edge_index_0.append(i)
            edge_index_1.append(neigh)
    edge_index = np.array([edge_index_0, edge_index_1])
    # GraphData requires one feature row per edge: shape (num_edges, edge_feature_dim).
    # Since edges are grouped by source node, a model that needs a per-node layout can use
    # edge_features.reshape(num_nodes, num_neighbors, edge_feature_dim).
    edge_features = np.random.randn(edge_index.shape[1], edge_feature_dim)
    # Extra keyword arguments such as adjacency_list are stored as attributes on the GraphData object
    return GraphData(node_features=node_features,
                     edge_index=edge_index,
                     edge_features=edge_features,
                     adjacency_list=adjacency_list)

# Generate a small list of synthetic graphs
graphs = [create_synthetic_graph() for _ in range(4)]

# Create a DeepChem dataset from these graphs with one random regression target per graph
dataset = dc.data.NumpyDataset(X=graphs, y=np.random.randn(len(graphs), 1))

print("Synthetic dataset created with", len(dataset.X), "graphs.")
```
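The nested loop in `create_synthetic_graph` builds `edge_index` in COO format: row 0 holds source nodes and row 1 holds target nodes, grouped by source. A minimal sketch of the same construction without Python loops, using NumPy broadcasting, is below. The helper names `ring_edge_index` and `loop_edge_index` are hypothetical, and the comparison simply confirms that the vectorized version reproduces the loop-based one.

```python
import numpy as np


def ring_edge_index(num_nodes: int, num_neighbors: int) -> np.ndarray:
    """Vectorized edge_index: node i -> (i + 1), ..., (i + num_neighbors) mod num_nodes."""
    # Sources: each node repeated num_neighbors times, e.g. [0, 0, 0, 1, 1, 1, ...]
    src = np.repeat(np.arange(num_nodes), num_neighbors)
    # Targets: offsets 1..num_neighbors added to every node via broadcasting, then wrapped
    offsets = np.arange(1, num_neighbors + 1)
    dst = ((np.arange(num_nodes)[:, None] + offsets[None, :]) % num_nodes).ravel()
    return np.stack([src, dst])


def loop_edge_index(num_nodes: int, num_neighbors: int) -> np.ndarray:
    """The loop-based construction from create_synthetic_graph, kept for comparison."""
    adjacency_list = {i: [(i + j + 1) % num_nodes for j in range(num_neighbors)] for i in range(num_nodes)}
    edge_index_0, edge_index_1 = [], []
    for i in range(num_nodes):
        for neigh in adjacency_list[i]:
            edge_index_0.append(i)
            edge_index_1.append(neigh)
    return np.array([edge_index_0, edge_index_1])


ei = ring_edge_index(10, 5)
print(ei.shape)  # (2, 50): num_edges = num_nodes * num_neighbors
print(np.array_equal(ei, loop_edge_index(10, 5)))
```

Keep `num_neighbors < num_nodes`: an offset equal to `num_nodes` wraps back to a self-loop, and larger offsets repeat existing edges.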

```python
# Initialize the GraphRNN model.
# These dimensions must match the defaults used in create_synthetic_graph above.
node_feature_dim = 16
edge_feature_dim = 8
num_neighbors = 5

model = GraphRNN(node_input_dim=node_feature_dim,
                 node_hidden_dim=32,
                 node_feature_dim=node_feature_dim,
                 edge_input_dim=edge_feature_dim,
                 edge_hidden_dim=32,
                 edge_feature_dim=edge_feature_dim,
                 num_neighbors=num_neighbors,
                 mode='regression')

# Train the model (a few epochs is enough for a demonstration)
print("Training GraphRNN model...")
model.fit(dataset, nb_epoch=5)
print("Training complete.")
```

```python
# Make predictions with the trained model
predictions = model.predict(dataset)

# Evaluate with a DeepChem metric: mean absolute error, averaged over tasks with np.mean.
# We evaluate on the training data with random targets, so the score only confirms
# that the pipeline runs; it says nothing about generalization.
mae_metric = dc.metrics.Metric(dc.metrics.mean_absolute_error, np.mean)
scores = model.evaluate(dataset, [mae_metric])
print("Evaluation scores:", scores)

# Inspect the shape of the predictions
print("Predictions shape:", predictions.shape)
```
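Mean absolute error is the average of |y_true - y_pred| over all samples, and the `np.mean` argument to `dc.metrics.Metric` averages the per-task scores (there is only one task here). Because the targets are random noise, an MAE value is easiest to read against a trivial baseline. The sketch below computes MAE by hand and compares it with always predicting the median target, which is the best constant prediction under absolute error. The helper names and the placeholder arrays are illustrative and are not outputs of the model above. With the notebook's objects you would pass `dataset.y` and `predictions`, assuming `predictions` has the same shape as `dataset.y`.

```python
import numpy as np


def mean_absolute_error(y_true, y_pred) -> float:
    """MAE = mean(|y_true - y_pred|) over all entries."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))


def constant_baseline_mae(y_true) -> float:
    """MAE of always predicting the median target (the best constant under absolute error)."""
    y_true = np.asarray(y_true, dtype=float)
    return mean_absolute_error(y_true, np.full_like(y_true, np.median(y_true)))


# Placeholder values standing in for dataset.y and model.predict(dataset)
y_true = np.array([[0.5], [-1.0], [2.0], [0.0]])
y_pred = np.array([[0.0], [-0.5], [1.0], [0.5]])

print("model MAE:   ", mean_absolute_error(y_true, y_pred))  # 0.625
print("baseline MAE:", constant_baseline_mae(y_true))         # 0.875
```

A model whose MAE is not clearly below the constant baseline has not learned anything useful about the targets.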

## Conclusion

In this tutorial, we demonstrated how to:
- Create a synthetic graph dataset using DeepChem's `GraphData`.
- Initialize and train a GraphRNN model within DeepChem.
- Evaluate the model using DeepChem metrics.
- Inspect the output predictions.

GraphRNN can be applied to tasks such as molecular graph generation and property prediction. For more detailed information, refer to:

- [DeepChem Documentation](https://deepchem.io/)
- [Deep Graph Library (DGL) Documentation](https://www.dgl.ai/)
- [StellarGraph (graph machine learning library)](https://github.com/stellargraph/stellargraph)
- [GraphRNN Original Paper](https://arxiv.org/abs/1802.08773)
