<a href="https://colab.research.google.com/github/bodadineshreddy/indictrans2/blob/main/GNNA2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%%capture
# Install PyG + OGB for this Colab runtime. %%capture silences pip's output, so use
# the torch version check below to see which build actually ended up installed.
# NOTE(review): the wheel index below targets torch-2.1.0+cu121, but the saved output
# of the version check shows 2.5.1+cu124. These compiled extensions must match the
# installed torch build — confirm the URL, or drop them if unused (none of the cells
# shown here import torch_scatter / torch_sparse / torch_cluster / torch_spline_conv).
!pip uninstall torch-scatter torch-cluster torch-spline-conv torch-sparse -y
!pip install torch-scatter torch-cluster torch-spline-conv torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
!pip install torch-geometric ogb
# NOTE(review): unpinned and run after the extension installs; if this ever changes
# the installed torch version, the extensions installed above will no longer match.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121




In [3]:
import torch

# Show the torch build the kernel really imported; the "+cuXXX" suffix is what the
# compiled PyG extension wheels have to agree with.
print("{}".format(torch.__version__))

2.5.1+cu124


In [4]:
import sys

import ogb.utils.url  # Import the submodule explicitly so its attributes can be patched


def decide_download(url):
    """Non-interactive stand-in for ``ogb.utils.url.decide_download``.

    Prints the URL and approves the download unconditionally, so dataset loading
    never blocks on a console prompt.

    Args:
        url: URL of the file OGB is about to download.

    Returns:
        True, always.
    """
    print(f"Auto-approving download for: {url}")
    return True


ogb.utils.url.decide_download = decide_download  # Apply the patch

# Rebinding the attribute on ogb.utils.url does not reach modules that already copied
# the name (presumably via ``from ogb.utils.url import decide_download`` — confirm),
# e.g. when ogb.graphproppred was imported before this cell or the cell is re-run.
# Rebind it in every already-loaded ogb submodule that holds that name as well.
for _module_name, _module in list(sys.modules.items()):
    if _module_name.startswith("ogb.") and hasattr(_module, "decide_download"):
        _module.decide_download = decide_download

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score
import random
from collections import Counter
from ogb.graphproppred import PygGraphPropPredDataset
import sys

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed the RNGs this notebook draws from: `random` (anonymous-walk sampling) and torch
# (DataLoader shuffling, nn.Linear initialisation), so Restart & Run All is repeatable
# (up to any nondeterministic CUDA kernels).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Load dataset and splits
BATCH_SIZE = 32
dataset = PygGraphPropPredDataset(name="ogbg-molhiv", root="dataset/")
split_idx = dataset.get_idx_split()
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=BATCH_SIZE, shuffle=False)

# Build a global vocabulary of anonymous walks from a subset of training graphs
def build_global_anonymous_walks(graphs, walk_length=5, num_walks=100, max_graphs=100):
    """Collect the distinct anonymous-walk patterns seen in the first graphs.

    For each of the first ``max_graphs`` graphs, ``num_walks`` random walks start at
    uniformly chosen nodes and follow outgoing edges of ``graph.edge_index`` for up
    to ``walk_length`` nodes; a walk ends early at a node with no outgoing edge.
    Each walk is encoded by replacing every node with the position of its first
    occurrence in the walk, e.g. ``[7, 3, 7, 9] -> (0, 1, 0, 3)``, so walks with the
    same revisit pattern share one key. Randomness comes from the ``random`` module.

    Args:
        graphs: Iterable of graph objects exposing ``edge_index`` (two rows:
            sources, targets; must support ``.tolist()``) and ``num_nodes``.
        walk_length: Maximum number of nodes per walk.
        num_walks: Number of walks sampled per graph.
        max_graphs: Only the first ``max_graphs`` graphs are scanned.

    Returns:
        list of tuple[int, ...]: the distinct walk encodings (set iteration order).
    """
    vocabulary = set()
    for graph_position, graph in enumerate(graphs):
        if graph_position >= max_graphs:
            break
        num_nodes = graph.num_nodes
        if not num_nodes:
            # A graph without nodes has no start node (random.randint(0, -1) raises).
            continue
        # One host-side adjacency build per graph instead of an O(E) tensor mask per
        # step. Neighbour lists keep edge order, so random.choice picks exactly as the
        # masked-tensor version did.
        adjacency = {}
        sources, targets = graph.edge_index.tolist()
        for src, dst in zip(sources, targets):
            adjacency.setdefault(src, []).append(dst)
        for _ in range(num_walks):
            walk = [random.randint(0, num_nodes - 1)]
            for _ in range(walk_length - 1):
                neighbors = adjacency.get(walk[-1])
                if not neighbors:
                    break  # dead end: the current node can never change, walk is final
                walk.append(random.choice(neighbors))
            vocabulary.add(tuple(walk.index(node) for node in walk))
    return list(vocabulary)

# Fix the feature columns once, from a sample of training graphs. Walk patterns that
# never appear here get no column later, so they are ignored at train/eval time.
global_vocabulary = build_global_anonymous_walks(
    dataset[split_idx["train"]],
    max_graphs=100,
    num_walks=100,
    walk_length=5,
)
print("Global vocabulary size:", len(global_vocabulary))

# Generate Anonymous Walk Embeddings using a fixed vocabulary
def generate_anonymous_walk_embeddings(batch_graphs, batch_size, walk_length=5, num_walks=100, vocabulary=None, output_device=None):
    """Turn each graph into a vector of anonymous-walk counts.

    Every graph gets ``num_walks`` random walks of up to ``walk_length`` nodes along
    outgoing edges (a walk stops early at a node with no outgoing edge). Each walk is
    encoded by first-occurrence positions, e.g. ``[7, 3, 7, 9] -> (0, 1, 0, 3)``, and
    the row for a graph holds how often each vocabulary entry was sampled.

    Args:
        batch_graphs: Sequence of graph objects exposing ``edge_index`` (two rows:
            sources, targets; supports ``.tolist()``) and ``num_nodes``. Tensors may
            live on any device; they are copied to the host once per graph.
        batch_size: Expected number of rows; a warning is printed on mismatch and
            surplus rows are dropped.
        walk_length: Maximum number of nodes per walk.
        num_walks: Walks sampled per graph.
        vocabulary: Ordered walk encodings defining the columns; sampled walks not
            in it are dropped. If None, the vocabulary is built from this batch.
        output_device: Device for the returned matrix; defaults to the module-level
            ``device``.

    Returns:
        (embedding_matrix, vocabulary): a float32 tensor of shape
        ``(min(len(batch_graphs), batch_size), len(vocabulary))`` and the column keys.
    """
    walk_counts = []
    for graph in batch_graphs:
        counts = Counter()
        num_nodes = graph.num_nodes
        if num_nodes:  # a graph without nodes has no start node; its row stays zero
            # Build the adjacency once per graph with a single host copy. Masking
            # edge_index on every step costs O(E) each time and, for CUDA tensors,
            # forces a device synchronisation on every step.
            adjacency = {}
            sources, targets = graph.edge_index.tolist()
            for src, dst in zip(sources, targets):
                adjacency.setdefault(src, []).append(dst)
            for _ in range(num_walks):
                walk = [random.randint(0, num_nodes - 1)]
                for _ in range(walk_length - 1):
                    neighbors = adjacency.get(walk[-1])
                    if not neighbors:
                        break  # dead end: the current node can never change, walk is final
                    walk.append(random.choice(neighbors))
                counts[tuple(walk.index(node) for node in walk)] += 1
        walk_counts.append(counts)

    if vocabulary is None:
        # Insert keys in first-sampled order, matching a set filled walk by walk.
        seen = set()
        for counts in walk_counts:
            seen.update(counts)
        vocabulary = list(seen)
    else:
        vocabulary = list(vocabulary)

    rows = [[counts.get(walk, 0) for walk in vocabulary] for counts in walk_counts]
    target_device = output_device if output_device is not None else device
    # reshape keeps the result 2-D even for an empty batch (torch.tensor([]) is 1-D).
    embedding_matrix = (
        torch.tensor(rows, dtype=torch.float32)
        .reshape(len(rows), len(vocabulary))
        .to(target_device)
    )

    if embedding_matrix.shape[0] != batch_size:
        print("Warning: Expected batch size", batch_size, "but got", embedding_matrix.shape[0], ". Adjusting.")
        embedding_matrix = embedding_matrix[:batch_size]

    return embedding_matrix, vocabulary

# Probe one real training batch with the global vocabulary; the embedding width it
# produces is the model's input dimension.
batch = next(iter(train_loader))
batch_graphs = batch.to_data_list()
sample_batch, _ = generate_anonymous_walk_embeddings(
    batch_graphs,
    batch_size=batch.num_graphs,
    walk_length=5,
    num_walks=100,
    vocabulary=global_vocabulary,
)
input_dim = sample_batch.size(1)
print("Detected input dimension:", input_dim)

# Define the Graph Classifier model
class GraphClassifier(nn.Module):
    """Two-layer MLP mapping a graph feature vector to class logits.

    Args:
        input_dim: Width of the input vector (here, the walk-vocabulary size).
        hidden_dim: Width of the hidden ReLU layer.
        num_classes: Number of logits produced per graph.

    ``forward`` returns raw logits (no sigmoid) of shape ``(-1, num_classes)``; pair it
    with a logits-based loss such as ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_classes=1):
        super().__init__()
        self.num_classes = num_classes
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        logits = self.fc2(self.relu(self.fc1(x)))
        # Reshape to (-1, num_classes), not (-1, 1): identical for the binary case, but
        # it no longer flattens several classes' logits into a single column.
        return logits.reshape(-1, self.num_classes)

# Training function with carriage return printing for batch loss updates
def train(model, train_loader, optimizer, loss_fn, device, vocabulary, walk_length=5, num_walks=100):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Module mapping walk-count embeddings to one logit per graph.
        train_loader: PyG DataLoader yielding batches with ``y`` and ``to_data_list()``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Loss taking ``(logits, labels)`` with labels shaped ``(-1, 1)``.
        device: Device the model lives on.
        vocabulary: Walk vocabulary passed to ``generate_anonymous_walk_embeddings``.
        walk_length, num_walks: Walk sampling parameters; keep them equal to the ones
            used to build ``vocabulary``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)
    for i, batch in enumerate(train_loader):
        # Sample walks on the CPU batch: moving it to the GPU first makes every walk
        # step a tiny CUDA op plus a host sync. Only labels/embeddings go to `device`.
        labels = batch.y.float().view(-1, 1).to(device)
        embeddings, _ = generate_anonymous_walk_embeddings(
            batch.to_data_list(),
            batch_size=batch.num_graphs,
            walk_length=walk_length,
            num_walks=num_walks,
            vocabulary=vocabulary,
        )
        embeddings = embeddings.to(device)  # no-op if already on `device`
        optimizer.zero_grad()
        outputs = model(embeddings)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        sys.stdout.write("\rTrain batch {}/{} Loss: {:.4f}".format(i+1, num_batches, batch_loss))
        sys.stdout.flush()
    sys.stdout.write("\n")
    return total_loss / num_batches if num_batches else float("nan")

# Evaluation function with minimal per-batch printing
def evaluate(model, data_loader, device, vocabulary, walk_length=5, num_walks=100):
    """Compute ROC-AUC of ``model`` on ``data_loader``.

    Args:
        model: Module producing one logit per graph.
        data_loader: PyG DataLoader yielding batches with ``y`` and ``to_data_list()``.
        device: Device the model lives on.
        vocabulary: Walk vocabulary passed to ``generate_anonymous_walk_embeddings``.
        walk_length, num_walks: Walk sampling parameters; keep them equal to the ones
            used to build ``vocabulary``.

    Returns:
        ``sklearn.metrics.roc_auc_score`` of the sigmoid scores against the labels.
        This raises if the loader is empty or only one class is present.
    """
    model.eval()
    y_true, y_pred = [], []
    num_batches = len(data_loader)
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            # Sample walks on the CPU batch (see train): avoids per-step GPU syncs.
            embeddings, _ = generate_anonymous_walk_embeddings(
                batch.to_data_list(),
                batch_size=batch.num_graphs,
                walk_length=walk_length,
                num_walks=num_walks,
                vocabulary=vocabulary,
            )
            scores = torch.sigmoid(model(embeddings.to(device)))
            # Flatten to plain 1-D lists (assumes one logit per graph) so roc_auc_score
            # gets vectors rather than a list of length-1 arrays.
            y_true.extend(batch.y.float().view(-1).tolist())
            y_pred.extend(scores.view(-1).cpu().tolist())
            sys.stdout.write("\rEval batch {}/{}".format(i+1, num_batches))
            sys.stdout.flush()
    sys.stdout.write("\n")
    return roc_auc_score(y_true, y_pred)

# Initialize model, optimizer, and loss function
model = GraphClassifier(input_dim=input_dim, num_classes=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()

# Training loop with epoch summaries; remember the weights of the best-validation epoch
num_epochs = 5
best_val_auc = float("-inf")
best_state = None
for epoch in range(num_epochs):
    print("Epoch {}/{}".format(epoch+1, num_epochs))
    train_loss = train(model, train_loader, optimizer, loss_fn, device, global_vocabulary)
    val_auc = evaluate(model, valid_loader, device, global_vocabulary)
    print("Epoch {} Loss = {:.4f} Val AUC = {:.4f}".format(epoch+1, train_loss, val_auc))
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        # Clone so later optimizer steps do not overwrite the saved snapshot.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

# test_loader is built above but was never used: report held-out AUC for the
# checkpoint selected on validation.
if best_state is not None:
    model.load_state_dict(best_state)
test_auc = evaluate(model, test_loader, device, global_vocabulary)
print("Best Val AUC = {:.4f} Test AUC = {:.4f}".format(best_val_auc, test_auc))


Using device: cuda


  self.data, self.slices = torch.load(self.processed_paths[0])


Global vocabulary size: 13
Detected input dimension: 13
Epoch 1/5
Train batch 609/1029 Loss: 0.0297