In [None]:
!pip install ogb>=1.3.3 torch_geometric pyvis torch torch-scatter


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_softmax
import numpy as np
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset
from sklearn.metrics import precision_recall_fscore_support
import time
from datetime import timedelta, datetime
from torch.amp import GradScaler, autocast
import random
from collections import defaultdict

# Export CUDA_LAUNCH_BLOCKING=1 for this process before any model or tensor is created.
# NOTE(review): presumably a debugging aid so CUDA errors are reported at the call that
# caused them; it typically slows GPU execution — confirm it should stay on for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [None]:
class MaskedDataset(torch.utils.data.Dataset):
    """Dataset of masked graphs paired with the hidden token and its node position.

    Triples whose mask position is not a valid node of the graph
    (``mask_idx >= data.num_nodes``) are discarded at construction time.

    Attributes:
        data_list: the retained graph objects.
        labels: the original token of the masked node, one per retained graph.
        mask_indices: position of the masked node inside its own graph.
    """

    def __init__(self, data_list, labels, mask_indices):
        retained = [
            (graph, target, position)
            for graph, target, position in zip(data_list, labels, mask_indices)
            if position < graph.num_nodes
        ]
        self.data_list = [graph for graph, _, _ in retained]
        self.labels = [target for _, target, _ in retained]
        self.mask_indices = [position for _, _, position in retained]

    def __len__(self):
        """Number of retained graphs."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return graph ``idx`` with ``x`` as float and ``y`` / ``mask_index`` attached.

        The stored graph object itself is updated and returned (no copy is made).
        """
        graph = self.data_list[idx]
        graph.x = graph.x.type(torch.float)
        graph.y = torch.tensor(self.labels[idx], dtype=torch.long)
        # NOTE(review): when these objects are PyG ``Data`` instances, the default batching
        # rule offsets attributes whose name contains "index" by the number of nodes of the
        # preceding graphs, so a batched ``mask_index`` is a batch-global node index.
        graph.mask_index = torch.tensor(self.mask_indices[idx], dtype=torch.long)
        return graph

In [None]:
def mask_node_token(data, mask_token_id=-1):
    """Hide one randomly chosen node of a graph.

    Args:
        data: graph object exposing a 2-D feature tensor ``x`` and a ``clone()`` method.
        mask_token_id: value written into every feature of the chosen node.

    Returns:
        ``(masked_data, original_token, mask_idx)`` where ``masked_data`` is a clone of
        ``data`` with row ``mask_idx`` of ``x`` overwritten, ``original_token`` is a copy
        of ``data.x[mask_idx, 0]`` taken before masking, and ``mask_idx`` is the chosen
        row. ``data`` itself is left untouched.
    """
    last_position = data.x.size(0) - 1
    mask_idx = random.randint(0, last_position)

    # Remember the first feature of the node before it is overwritten in the copy.
    original_token = data.x[mask_idx, 0].clone()

    masked_data = data.clone()
    masked_data.x[mask_idx] = mask_token_id
    return masked_data, original_token, mask_idx

In [None]:
def create_data_splits(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Randomly partition ``dataset`` into train / validation / test lists.

    The item order is shuffled with the global ``random`` module. The first
    ``int(train_ratio * n)`` items go to train, the next ``int(val_ratio * n)`` to
    validation and everything that remains to test (``test_ratio`` is accepted but the
    test share is simply the remainder).

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` as plain Python lists.
    """
    n_items = len(dataset)
    order = list(range(n_items))
    random.shuffle(order)

    first_cut = int(train_ratio * n_items)
    second_cut = first_cut + int(val_ratio * n_items)
    partitions = (order[:first_cut], order[first_cut:second_cut], order[second_cut:])

    # Materialise the three partitions in train -> val -> test order.
    train_dataset, val_dataset, test_dataset = (
        [dataset[position] for position in part] for part in partitions
    )

    print(f"set sizes: - train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")

    return train_dataset, val_dataset, test_dataset


In [None]:
def preprocess_dataset(dataset_name='ogbg-code2', sample_ratio=1, batch_size=128, num_workers=4):
    """Load an OGB graph dataset, mask one node per graph and build the data loaders.

    Steps: load ``dataset_name``, draw a random subset of its official training split,
    mask one random node in every graph, wrap the result in ``MaskedDataset``, split it
    into train / val / test and create one loader per split.

    Args:
        dataset_name: name passed to ``PygGraphPropPredDataset``.
        sample_ratio: fraction of the official training split to use; at least one
            graph is always drawn and the request is capped at the split size.
        batch_size: number of graphs per batch for all three loaders.
        num_workers: worker processes used by each loader.

    Returns:
        dict with ``train_loader``, ``val_loader``, ``test_loader`` and ``vocab_size``
        (largest masked-token label over *all* kept graphs, plus one).
    """
    print("loading dataset")
    pyg_dataset = PygGraphPropPredDataset(name=dataset_name)
    split_index = pyg_dataset.get_idx_split()

    train_indices = split_index['train'].tolist()
    # random.sample raises if asked for more items than exist, so cap the request.
    subset_size = min(len(train_indices), max(1, int(len(train_indices) * sample_ratio)))
    train_subset_indices = random.sample(train_indices, subset_size)
    dataset_subset = [pyg_dataset[i] for i in train_subset_indices]
    print(f"Number of graphs in subset: {len(dataset_subset)}")

    print("Applying masking to nodes")
    masked_data = []
    labels = []
    mask_indices = []

    for data in dataset_subset:
        # A graph without nodes has nothing to mask (and would make randint fail).
        if data.x.size(0) == 0:
            continue
        masked_data_item, original_token, mask_idx = mask_node_token(data)
        if mask_idx < data.num_nodes:
            masked_data.append(masked_data_item)
            labels.append(int(original_token.item()))
            mask_indices.append(mask_idx)

    print("Creating dataset")
    masked_dataset = MaskedDataset(masked_data, labels, mask_indices)

    print("Creating splits")
    train_data, val_data, test_data = create_data_splits(masked_dataset)

    # Only the training loader drops its last partial batch; dropping it for val/test
    # would silently exclude samples from the reported metrics (or empty the loader).
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )
    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )

    # The classifier needs one output per possible label. Deriving the size from a single
    # pass over the shuffled, drop_last training loader can miss labels that appear in a
    # later epoch or only in val/test, which would give CrossEntropyLoss out-of-range
    # targets. Use every label that was kept instead.
    vocab_size = max([0, *masked_dataset.labels]) + 1
    print(f"vocab size: {vocab_size}")

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'vocab_size': vocab_size
    }

In [None]:
class GATLayer(nn.Module):
    """Single attention-based message-passing layer with a GRU update and residual.

    For every edge ``src -> dst`` a scaled dot-product score between the query of the
    destination node and the key of the source node is computed, normalised with a
    softmax over the edges that share the same destination, and used to weight the
    source node's value vector. The summed messages are combined with the (normalised)
    node features by a ``GRUCell``; the result is normalised, projected, passed through
    dropout and added back to the normalised input.

    Args:
        in_channels: size of the incoming node features.
        hidden_channels: size of the attention projections and of the output.
        dropout_rate: dropout probability applied before the residual connection.
    """

    def __init__(self, in_channels: int, hidden_channels: int, dropout_rate: float = 0.1):
        super().__init__()
        self.hidden_channels = hidden_channels

        # Attention projections.
        self.query = nn.Linear(in_channels, hidden_channels)
        self.key = nn.Linear(in_channels, hidden_channels)
        self.value = nn.Linear(in_channels, hidden_channels)

        # Normalisation before attention and after the recurrent update.
        self.norm1 = nn.LayerNorm(in_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)

        # Recurrent update followed by a linear projection and dropout.
        self.gru = nn.GRUCell(hidden_channels, hidden_channels)
        self.output_proj = nn.Linear(hidden_channels, hidden_channels)
        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index):
        """Run one round of message passing.

        Args:
            x: node feature matrix, one row per node.
            edge_index: tensor that unpacks into ``(src, dst)`` node index vectors.

        Returns:
            Updated node features, one row per node.
        """
        normed = self.norm1(x)
        src, dst = edge_index

        queries = self.query(normed[dst])
        keys = self.key(normed[src])
        values = self.value(normed[src])

        # Scaled dot-product score per edge.
        scale = torch.sqrt(torch.tensor(self.hidden_channels).float())
        edge_scores = torch.sum(queries * keys, dim=-1) / scale

        # Normalise over all edges arriving at the same destination node.
        edge_weights = scatter_softmax(edge_scores, dst, dim=0)

        # Weighted sum of incoming value vectors for every node.
        aggregated = scatter_add(
            values * edge_weights.unsqueeze(-1), dst, dim=0, dim_size=normed.size(0)
        )

        # The normalised features are the GRU input, the aggregated messages its state.
        updated = self.gru(normed, aggregated)
        updated = self.dropout_layer(self.output_proj(self.norm2(updated)))

        # Residual connection onto the normalised input.
        return updated + normed

class GAT(nn.Module):
    """Node-level classifier built from a stack of ``GATLayer`` blocks.

    Node features are embedded linearly, normalised, passed through ReLU and dropout,
    refined by ``num_layers`` attention layers (each followed by ReLU), added back to
    the embedded features, normalised and finally mapped to ``out_channels`` scores per
    node by a two-layer MLP.

    Args:
        in_channels: number of input features per node.
        hidden_channels: width used throughout the network.
        out_channels: number of output scores per node.
        num_layers: how many ``GATLayer`` blocks to stack.
        dropout_rate: dropout probability used in every dropout module.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int,
                 num_layers: int = 2, dropout_rate: float = 0.1):
        super().__init__()

        self.num_layers = num_layers

        # Input stem.
        self.embed = nn.Linear(in_channels, hidden_channels)
        self.norm = nn.LayerNorm(hidden_channels)
        self.dropout_layer = nn.Dropout(dropout_rate)

        # Message-passing stack.
        self.layers = nn.ModuleList(
            GATLayer(hidden_channels, hidden_channels, dropout_rate=dropout_rate)
            for _ in range(num_layers)
        )

        # Prediction head.
        self.output_norm = nn.LayerNorm(hidden_channels)
        self.output = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_channels, out_channels)
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every ``nn.Linear``: Kaiming-normal weights, zero biases."""
        linear_modules = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linear_modules:
            nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, data):
        """Compute per-node output scores for the graph(s) in ``data``.

        Args:
            data: object exposing node features ``x`` and connectivity ``edge_index``.

        Returns:
            Tensor with one row of ``out_channels`` scores per node.
        """
        edge_index = data.edge_index

        hidden = self.embed(data.x.float())
        hidden = self.dropout_layer(F.relu(self.norm(hidden)))
        stem_output = hidden

        for block in self.layers:
            hidden = F.relu(block(hidden, edge_index))

        # Skip connection from the stem to the end of the stack.
        hidden = self.output_norm(hidden + stem_output)
        return self.output(hidden)

In [None]:
def get_max_vocab_size(train_loader):
    """Return one more than the largest target value seen in ``train_loader``.

    Every batch is expected to expose its targets as a tensor ``y``. The running
    maximum starts at 0, so an empty loader (or all-negative targets) yields 1.
    """
    batch_maxima = [batch.y.max().item() for batch in train_loader]
    return max([0] + batch_maxima) + 1

In [None]:
# Seed Python's and PyTorch's random number generators (CPU and, when present, all CUDA
# devices) with the same fixed value before any data is sampled or any model is built.
random.seed(11)
torch.manual_seed(11)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(11)

# Build the masked dataset and its loaders once, with the default arguments; the result
# dict (loaders + 'vocab_size') is read by the later cells.
processed_data = preprocess_dataset()

In [None]:
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict
from sklearn.metrics import f1_score

def select_masked_logits(outputs, batch):
    """Pick, for every graph in ``batch``, the output row of its masked node.

    PyG's default batching rule offsets every attribute whose name contains "index" by
    the number of nodes of the preceding graphs, so ``batch.mask_index`` already holds
    batch-global node positions and can index ``outputs`` directly. Treating it as a
    per-graph local index (and clamping it) reads the wrong node for every graph after
    the first one.

    Args:
        outputs: per-node model outputs, shape ``(num_nodes_in_batch, num_classes)``.
        batch: batched graphs exposing ``mask_index`` (one entry per graph) and
            ``batch`` (graph id of every node).

    Returns:
        Tensor of shape ``(num_graphs, num_classes)``.

    Raises:
        ValueError: if an index is out of range or does not point into its own graph.
    """
    node_idx = batch.mask_index.view(-1).to(outputs.device)
    graph_ids = torch.arange(node_idx.numel(), device=node_idx.device)

    in_range = bool(((node_idx >= 0) & (node_idx < outputs.size(0))).all())
    if not in_range or not bool((batch.batch.to(node_idx.device)[node_idx] == graph_ids).all()):
        raise ValueError(
            "batch.mask_index must hold batch-global node indices that point into "
            "their own graph"
        )
    return outputs[node_idx]


def evaluate_model(model, loader, criterion, device):
    """Evaluate masked-node prediction on every batch of ``loader``.

    Args:
        model: network returning one row of class scores per node.
        loader: iterable of batched graphs carrying ``y`` and ``mask_index``.
        criterion: loss called as ``criterion(logits, targets)``.
        device: device the batches are moved to.

    Returns:
        dict with
        ``loss`` (mean batch loss, ``nan`` when the loader is empty),
        ``accuracy`` (fraction of correct predictions),
        ``f1_macro`` (macro F1 over the classes present in the targets) and
        ``class_metrics`` (``{class: {'accuracy', 'correct', 'total', 'f1'}}``).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = select_masked_logits(model(batch), batch)
            targets = batch.y.view(-1)

            total_loss += criterion(logits, targets).item()
            num_batches += 1

            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            label_chunks.append(targets.cpu().numpy())

    # An empty loader would otherwise end in a division by zero.
    if num_batches == 0:
        return {
            'loss': float('nan'),
            'accuracy': 0.0,
            'f1_macro': 0.0,
            'class_metrics': {}
        }

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    # Restrict F1 to classes that actually occur in the targets.
    unique_classes, class_totals = np.unique(all_labels, return_counts=True)
    f1_macro = f1_score(all_labels, all_preds, average='macro', labels=unique_classes)
    f1_per_class = f1_score(all_labels, all_preds, average=None, labels=unique_classes)

    hits = all_preds == all_labels
    hit_classes, hit_counts = np.unique(all_labels[hits], return_counts=True)
    correct_by_class = dict(zip(hit_classes.tolist(), hit_counts.tolist()))

    # Plain Python numbers keep the result JSON-serialisable.
    class_metrics = {}
    for class_idx, total, f1 in zip(unique_classes.tolist(), class_totals.tolist(), f1_per_class.tolist()):
        correct = correct_by_class.get(class_idx, 0)
        class_metrics[class_idx] = {
            'accuracy': correct / total,
            'correct': correct,
            'total': total,
            'f1': f1
        }

    return {
        'loss': total_loss / num_batches,
        'accuracy': float(hits.mean()),
        'f1_macro': float(f1_macro),
        'class_metrics': class_metrics
    }


def train(model, train_loader, val_loader, test_loader, num_epochs, device, learning_rate=1e-3,
          patience=10, checkpoint_path='best_model.pt'):
    """Train ``model`` on masked-node prediction with early stopping on validation F1.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, feeds the macro
    F1 to a ``ReduceLROnPlateau`` scheduler and saves the weights whenever the F1
    improves. After training the best saved weights are restored and evaluated on
    ``test_loader``.

    Args:
        model: network returning one row of class scores per node.
        train_loader, val_loader, test_loader: loaders of batched graphs.
        num_epochs: maximum number of epochs.
        device: device used for the model and the batches.
        learning_rate: initial AdamW learning rate.
        patience: epochs without validation improvement before stopping early.
        checkpoint_path: file the best weights are written to and restored from.

    Returns:
        ``(model, test_metrics)`` where ``test_metrics`` is the dict produced by
        ``evaluate_model`` on the test loader.
    """
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # The deprecated ``verbose`` flag is not used; learning-rate drops are logged below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    criterion = nn.CrossEntropyLoss()

    # Start below any reachable F1 so the first epoch always writes a checkpoint; with a
    # start of 0.0 a run that never exceeds 0 would restore a stale or missing file.
    best_val_f1 = float('-inf')
    patience_counter = 0
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        batch_correct = 0
        batch_total = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()

            predictions = select_masked_logits(model(batch), batch)
            targets = batch.y.view(-1)
            loss = criterion(predictions, targets)

            loss.backward()
            # Clip after backward(): before it the gradients of this step do not exist yet.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            preds = predictions.argmax(dim=1)
            batch_correct += (preds == targets).sum().item()
            batch_total += len(targets)

            if (batch_idx + 1) % 5 == 0:
                current_loss = epoch_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, Loss = {current_loss:.4f}")

        val_metrics = evaluate_model(model, val_loader, criterion, device)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_metrics['f1_macro'])
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate to {lr_after:.4e}")

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        # max(..., 1) keeps the summary printable when the training loader is empty.
        print(f"Train Loss: {epoch_loss/max(num_batches, 1):.4f}")
        print(f"Train Accuracy: {batch_correct/max(batch_total, 1):.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}")
        print(f"Val F1 (macro): {val_metrics['f1_macro']:.4f}")

        if val_metrics['f1_macro'] > best_val_f1:
            best_val_f1 = val_metrics['f1_macro']
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print("new best model")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("stop early")
                break

        print("-" * 50)

    # Only restore weights written by this call (nothing is saved when num_epochs == 0).
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_metrics = evaluate_model(model, test_loader, criterion, device)

    print(f"\nFinal Test Results:")
    print(f"Test Loss: {test_metrics['loss']:.4f}")
    print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Test F1 (macro): {test_metrics['f1_macro']:.4f}")

    print("\nPer-class Test Metrics:")
    for class_idx, metrics in test_metrics['class_metrics'].items():
        print(f"Class {class_idx}:")
        print(f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['correct']}/{metrics['total']})")
        print(f"  F1-score: {metrics['f1']:.4f}")

    return model, test_metrics

In [None]:
import os
import json
from datetime import datetime

def setup_save_directory(save_dir='/content/model_results'):
    """Make sure the results directory exists and return its path.

    Args:
        save_dir: directory to create; defaults to the location used by this notebook.

    Returns:
        ``save_dir`` unchanged.
    """
    # exist_ok avoids the check-then-create race and makes repeated calls harmless.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir

def save_run_metrics(model_name, run_id, metrics, save_dir):
    """Write the summary metrics of one training run to a JSON file.

    The file is named ``metrics_<model_name>_<run_id>.json`` inside ``save_dir`` and
    contains the model name, the run id and the ``accuracy``, ``f1_macro`` and
    ``class_metrics`` entries of ``metrics``. The target path is printed afterwards.
    """
    results = {'model_name': model_name, 'run_id': run_id}
    for field in ('accuracy', 'f1_macro', 'class_metrics'):
        results[field] = metrics[field]

    filename = os.path.join(save_dir, f"metrics_{model_name}_{run_id}.json")
    with open(filename, 'w') as handle:
        json.dump(results, handle)

    print(f"Saved metrics to: {filename}")

def run_multiple_trainings(model_class, model_params, train_params, num_runs=3, model_name='gat_ablation'):
    """Train a fresh model several times and store the test metrics of every run.

    Args:
        model_class: class instantiated as ``model_class(**model_params)`` for each run.
        model_params: keyword arguments for the model constructor.
        train_params: keyword arguments forwarded to ``train`` (loaders, epochs,
            device, learning rate).
        num_runs: number of independent trainings.
        model_name: label used in the names of the saved metric files.

    Returns:
        List with the test-metrics dict of every run, in run order.
    """
    save_dir = setup_save_directory()
    all_metrics = []

    for run in range(num_runs):
        print(f"\nStarting Run {run + 1}/{num_runs}")

        # No hard-coded .to('cuda') here: train() moves the model to the device given in
        # train_params, so the configured device is the only one that matters.
        model = model_class(**model_params)

        _, metrics = train(model=model, **train_params)
        save_run_metrics(model_name, run, metrics, save_dir)

        all_metrics.append(metrics)

        print(f"Run {run + 1} completed")
        print(f"F1 Macro: {metrics['f1_macro']:.4f}")

    print(f"\nAll results saved in: {save_dir}")
    return all_metrics

# Constructor arguments for the network; the output size follows the label vocabulary
# computed during preprocessing.
model_params = dict(
    in_channels=2,
    hidden_channels=64,
    out_channels=processed_data['vocab_size'],
    num_layers=2,
    dropout_rate=0.1,
)

# Arguments forwarded to train() for every run.
train_params = dict(
    train_loader=processed_data['train_loader'],
    val_loader=processed_data['val_loader'],
    test_loader=processed_data['test_loader'],
    num_epochs=3,
    device='cuda',
    learning_rate=1e-3,
)

# Three independent trainings; one metrics dict per run is collected.
metrics_list = run_multiple_trainings(
    model_class=GAT,
    model_params=model_params,
    train_params=train_params,
    num_runs=3,
)