In [None]:
!pip install ogb>=1.3.3 torch_geometric pyvis torch torch-scatter


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset
from sklearn.metrics import precision_recall_fscore_support
import time
from datetime import timedelta, datetime
from torch.amp import GradScaler, autocast
import random
from collections import defaultdict
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt

# Debugging aid (per CUDA docs): makes kernel launches synchronous so device-side
# errors are reported at the failing call. This serialises GPU work and slows training.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")


In [None]:
def preprocess_topological_layers(dataset):
    """Attach a topological layer index to every node of every graph.

    Runs a layered Kahn's algorithm: layer 0 holds nodes with in-degree 0, and each
    following layer holds the not-yet-visited nodes whose remaining in-degree
    drops to 0 once all edges leaving the previous layer are removed.

    Nodes that never reach in-degree 0 (nodes on, or downstream of, a cycle) keep
    the initial value 0.

    Args:
        dataset: iterable of graph objects exposing ``num_nodes`` and a
            ``[2, E]`` ``edge_index`` (row 0 = source, row 1 = destination).

    Returns:
        list of the same graph objects, each mutated in place to carry
        ``topo_layers`` (a ``[num_nodes]`` long tensor).
    """
    processed_dataset = []

    for data in dataset:
        num_nodes = data.num_nodes
        src, dst = data.edge_index[0], data.edge_index[1]
        device = dst.device

        # In-degree of every node in one vectorised pass (duplicate edges counted).
        in_degree = torch.bincount(dst, minlength=num_nodes)

        visited = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        topo_layers = torch.zeros(num_nodes, dtype=torch.long, device=device)

        current_layer = (in_degree == 0).nonzero(as_tuple=False).view(-1)
        layer_idx = 0
        while current_layer.numel() > 0:
            topo_layers[current_layer] = layer_idx
            visited[current_layer] = True

            # Remove all out-edges of the whole layer at once. The previous
            # per-node scan of edge_index was O(V * E).
            out_edges = torch.isin(src, current_layer)
            in_degree -= torch.bincount(dst[out_edges], minlength=num_nodes)

            current_layer = (~visited & (in_degree == 0)).nonzero(as_tuple=False).view(-1)
            layer_idx += 1

        data.topo_layers = topo_layers
        processed_dataset.append(data)

    return processed_dataset

In [None]:
class MaskedDataset(torch.utils.data.Dataset):
    """Dataset of masked graphs paired with the masked node's label and position.

    Only triples whose mask index is a valid node of its graph
    (``mask_idx < data.num_nodes``) are kept.
    """

    def __init__(self, data_list, labels, mask_indices):
        kept = [
            (graph, target, position)
            for graph, target, position in zip(data_list, labels, mask_indices)
            if position < graph.num_nodes
        ]
        self.data_list = [graph for graph, _, _ in kept]
        self.labels = [target for _, target, _ in kept]
        self.mask_indices = [position for _, _, position in kept]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        # Mutates and returns the stored graph object itself (not a copy):
        # features become float, and y / mask_index are attached as 0-dim long tensors.
        sample = self.data_list[idx]
        sample.x = sample.x.float()
        sample.y = torch.tensor(self.labels[idx], dtype=torch.long)
        sample.mask_index = torch.tensor(self.mask_indices[idx], dtype=torch.long)
        return sample

In [None]:
def mask_node_token(data, mask_token_id=-1):
    """Hide one uniformly chosen node's features behind ``mask_token_id``.

    The input graph is left untouched; a clone is returned with every feature
    column of the chosen row overwritten.

    Returns:
        (masked clone, original value of column 0 at the chosen row, chosen row index)
    """
    chosen = random.randint(0, data.x.size(0) - 1)
    target = data.x[chosen, 0].clone()
    corrupted = data.clone()
    corrupted.x[chosen] = mask_token_id
    return corrupted, target, chosen

In [None]:
def create_data_splits(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Randomly partition ``dataset`` into train / val / test lists.

    Train and val sizes are ``int(ratio * len(dataset))``; the test split is
    whatever remains, so ``test_ratio`` is accepted but not used. The shuffle
    uses Python's global ``random`` state.
    """
    n = len(dataset)
    order = list(range(n))
    random.shuffle(order)

    train_end = int(train_ratio * n)
    val_end = train_end + int(val_ratio * n)
    chunks = (order[:train_end], order[train_end:val_end], order[val_end:])

    train_dataset, val_dataset, test_dataset = (
        [dataset[i] for i in chunk] for chunk in chunks
    )

    print(f"set sizes: - train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")

    return train_dataset, val_dataset, test_dataset


In [None]:
def _make_loader(data, batch_size, num_workers, shuffle, drop_last):
    """Wrap ``data`` in a PyG DataLoader with the notebook's shared loader settings."""
    return DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
    )


def preprocess_dataset(dataset_name='ogbg-code2', sample_ratio=1, batch_size=128, num_workers=4):
    """Load an OGB graph dataset, mask one node per graph, and build data loaders.

    Steps:
    1. Sample ``sample_ratio`` of OGB's official train split.
    2. Compute topological layers.
    3. Mask one random node per graph; its column-0 feature is the label.
    4. Re-split 70/15/15.
    5. Wrap each split in a loader.

    Args:
        dataset_name: OGB graph property prediction dataset name.
        sample_ratio: fraction of the OGB train split to use (at least 1 graph).
        batch_size: graphs per batch for all loaders.
        num_workers: DataLoader worker processes.

    Returns:
        dict with ``train_loader``, ``val_loader``, ``test_loader`` and
        ``vocab_size`` (1 + the largest label over all splits, at least 1).
    """
    pyg_dataset = PygGraphPropPredDataset(name=dataset_name)
    split_index = pyg_dataset.get_idx_split()

    # Only OGB's train split is sampled; it is re-split into train/val/test below.
    train_indices = split_index['train']
    train_subset_indices = random.sample(
        train_indices.tolist(),
        max(1, int(len(train_indices) * sample_ratio))
    )
    dataset_subset = [pyg_dataset[i] for i in train_subset_indices]
    print(f"Number of graphs in subset: {len(dataset_subset)}")

    print("Computing topological batches")
    dataset_subset = preprocess_topological_layers(dataset_subset)

    print("Applying masking to nodes")
    masked_data = []
    labels = []
    mask_indices = []

    for data in dataset_subset:
        masked_data_item, original_token, mask_idx = mask_node_token(data)
        if mask_idx < data.num_nodes:
            masked_data.append(masked_data_item)
            labels.append(int(original_token.item()))
            mask_indices.append(mask_idx)

    print("Creating dataset")
    masked_dataset = MaskedDataset(masked_data, labels, mask_indices)

    print("Creating splits")
    train_data, val_data, test_data = create_data_splits(masked_dataset)

    train_loader = _make_loader(train_data, batch_size, num_workers, shuffle=True, drop_last=True)
    # Evaluation loaders keep the last partial batch: dropping it would silently
    # leave up to batch_size - 1 graphs out of the val/test metrics.
    val_loader = _make_loader(val_data, batch_size, num_workers, shuffle=False, drop_last=False)
    test_loader = _make_loader(test_data, batch_size, num_workers, shuffle=False, drop_last=False)

    # Vocabulary over *all* labels. Deriving it from the shuffled, drop_last train
    # loader could miss labels present only in val/test (or in dropped samples),
    # making CrossEntropyLoss index past the output layer.
    vocab_size = max([0, *labels]) + 1
    print(f"Vocab size: {vocab_size}")

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'vocab_size': vocab_size
    }

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_softmax

class DAGNNLayer(nn.Module):
    """Attention + GRU message passing over a DAG, processed in topological order.

    For each topological layer, every node attends (scaled dot-product) over its
    incoming edges whose source node has already been processed. The
    attention-weighted values become the GRU hidden state; the node's current
    features are the GRU input. The stacked GRU outputs are then normalised,
    projected, passed through dropout and added to the layer-normalised input.

    Args:
        in_channels: width of the input features. ``forward`` feeds ``x`` directly
            into ``GRUCell(hidden_channels, ...)`` and adds it to a
            ``hidden_channels``-wide output, so it must equal ``hidden_channels``.
        hidden_channels: width of queries/keys/values, GRU state and output.
        dropout_rate: dropout probability applied to the projected output.
    """

    def __init__(self, in_channels: int, hidden_channels: int, dropout_rate: float = 0.1):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.norm1 = nn.LayerNorm(in_channels)

        self.query = nn.Linear(in_channels, hidden_channels)
        self.key = nn.Linear(in_channels, hidden_channels)
        self.value = nn.Linear(in_channels, hidden_channels)

        self.norm2 = nn.LayerNorm(hidden_channels)

        self.gru = nn.GRUCell(hidden_channels, hidden_channels)

        self.output_proj = nn.Linear(hidden_channels, hidden_channels)

        self.dropout_layer = nn.Dropout(dropout_rate)

    def attention_for_layer(self, x, edge_index, current_nodes, prev_nodes):
        """Aggregate messages from already-processed predecessors into ``current_nodes``.

        Args:
            x: node features, one row per global node id.
            edge_index: ``[2, E]`` edge list (row 0 = source, row 1 = destination).
            current_nodes: 1-D tensor of node ids updated in this step.
            prev_nodes: 1-D tensor of node ids already processed.

        Returns:
            ``[len(current_nodes), channels]`` tensor whose i-th row is the message
            for ``current_nodes[i]``. Nodes without an eligible incoming edge get a
            zero row.
        """
        # Keep only edges that end in this layer and start at a processed node.
        mask = torch.isin(edge_index[1], current_nodes) & torch.isin(edge_index[0], prev_nodes)

        if not mask.any():
            return torch.zeros((len(current_nodes), x.size(1)), device=x.device)

        src, dst = edge_index[:, mask]

        q = self.query(x[dst])
        k = self.key(x[src])
        v = self.value(x[src])

        scores = torch.sum(q * k, dim=-1) / (self.hidden_channels ** 0.5)

        # Softmax over the incoming edges of each destination node.
        attention_weights = scatter_softmax(scores, dst, dim=0)
        weighted_messages = v * attention_weights.unsqueeze(-1)

        # Map global dst ids to their row in ``current_nodes`` with a lookup tensor.
        # A per-edge Python dict built with .item() forces one host sync per element.
        position = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
        position[current_nodes] = torch.arange(len(current_nodes), device=x.device)
        dst_positions = position[dst]

        return scatter_add(weighted_messages, dst_positions, dim=0,
                           dim_size=len(current_nodes))

    def forward(self, x, edge_index, topo_layers):
        """Run one DAG message-passing sweep.

        Args:
            x: ``[N, in_channels]`` node features.
            edge_index: ``[2, E]`` edge list.
            topo_layers: ``[N]`` topological layer index per node.

        Returns:
            ``[N, hidden_channels]`` updated features (residual added).
        """
        x = self.norm1(x)
        # Residual input. ``x`` is overwritten in place per layer below, so work on
        # a copy. If ``identity`` aliased ``x``, the residual would pick up the GRU
        # states instead of the normalised input.
        identity = x
        x = x.clone()

        h = torch.zeros_like(x)
        max_layer = int(topo_layers.max()) if topo_layers.numel() > 0 else -1

        processed_nodes = torch.tensor([], dtype=torch.long, device=x.device)

        for layer_idx in range(max_layer + 1):
            current_nodes = (topo_layers == layer_idx).nonzero().squeeze(-1)
            if current_nodes.numel() == 0:
                continue

            # Attention for this topological batch.
            messages = self.attention_for_layer(x, edge_index, current_nodes, processed_nodes)

            h[current_nodes] = self.gru(x[current_nodes], messages)
            # Later layers attend over the updated states of processed predecessors.
            x[current_nodes] = h[current_nodes]

            processed_nodes = torch.cat([processed_nodes, current_nodes])

        h = self.norm2(h)
        h = self.output_proj(h)
        h = self.dropout_layer(h)

        return h + identity

class DAGNN(nn.Module):
    """Node-level classifier: linear embedding, a stack of DAGNNLayer blocks, MLP head.

    Args:
        in_channels: number of feature columns in ``data.x`` (fed to a Linear layer).
        hidden_channels: width used throughout the stack.
        out_channels: number of output logits per node.
        num_layers: number of DAGNNLayer blocks.
        dropout_rate: dropout probability for the embedding, layers and head.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int,
                 num_layers: int = 2, dropout_rate: float = 0.1):
        super().__init__()

        self.num_layers = num_layers

        # Submodules are created in a fixed order; keep it so parameter names and
        # seeded initialisation stay stable.
        self.embed = nn.Linear(in_channels, hidden_channels)
        self.norm = nn.LayerNorm(hidden_channels)
        self.dropout_layer = nn.Dropout(dropout_rate)

        stacked = [
            DAGNNLayer(hidden_channels, hidden_channels, dropout_rate=dropout_rate)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stacked)

        self.output_norm = nn.LayerNorm(hidden_channels)
        head = [
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_channels, out_channels),
        ]
        self.output = nn.Sequential(*head)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (ReLU gain) weights and zero biases for every Linear layer."""
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for lin in linears:
            nn.init.kaiming_normal_(lin.weight, nonlinearity='relu')
            if lin.bias is not None:
                nn.init.zeros_(lin.bias)

    def forward(self, data):
        """Return per-node logits for a (batched) graph with ``x``, ``edge_index``, ``topo_layers``."""
        edge_index, topo_layers = data.edge_index, data.topo_layers

        h = self.dropout_layer(F.relu(self.norm(self.embed(data.x.float()))))
        skip = h

        for block in self.layers:
            h = F.relu(block(h, edge_index, topo_layers))

        return self.output(self.output_norm(h + skip))

In [None]:
def get_max_vocab_size(train_loader):
    """Return 1 + the largest ``y`` value seen in ``train_loader`` (never below 1).

    Iterates the loader once; every batch must expose a tensor ``y``.
    """
    largest = max([0, *(batch.y.max().item() for batch in train_loader)])
    return largest + 1

In [None]:
# Seed the RNGs the pipeline draws from before preprocessing. Python's `random`
# drives subsampling, masking and splitting; torch drives init, dropout and shuffling.
torch.manual_seed(11)
random.seed(11)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(11)

processed_data = preprocess_dataset()

In [None]:
def evaluate_model(model, loader, criterion, device):
    """Evaluate masked-node-token prediction over every batch of ``loader``.

    Args:
        model: maps a batch to per-node logits ``[num_nodes_in_batch, num_classes]``.
        loader: PyG loader whose batches carry ``y`` (one label per graph) and
            ``mask_index`` (the masked node of each graph).
        criterion: loss called as ``criterion(logits, targets)``.
        device: device each batch is moved to.

    Returns:
        dict with:
        - ``loss``: mean of the per-batch losses.
        - ``accuracy``.
        - ``f1_macro``: macro-F1 over the classes present in the true labels.
        - ``class_metrics``: per true class, its accuracy, correct, total and f1.

        An empty loader yields zeros and an empty ``class_metrics``.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            outputs = model(batch)

            # PyG's collate offsets every attribute whose name contains "index"
            # by the running node count (Data.__inc__). So mask_index already
            # holds batch-global node rows; no per-graph start offset is needed.
            predictions = outputs[batch.mask_index]
            total_loss += criterion(predictions, batch.y).item()
            num_batches += 1

            preds = predictions.argmax(dim=1).cpu().tolist()
            labels = batch.y.cpu().tolist()
            all_preds.extend(preds)
            all_labels.extend(labels)

            for pred, true in zip(preds, labels):
                class_total[true] += 1
                if pred == true:
                    class_correct[true] += 1

    if num_batches == 0:
        return {'loss': 0.0, 'accuracy': 0, 'f1_macro': 0.0, 'class_metrics': {}}

    avg_loss = total_loss / num_batches
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    unique_classes = np.unique(all_labels)

    # zero_division=0 matches sklearn's default value without the warning spam.
    f1_macro = float(f1_score(all_labels, all_preds, average='macro',
                              labels=unique_classes, zero_division=0))
    f1_per_class = f1_score(all_labels, all_preds, average=None,
                            labels=unique_classes, zero_division=0)

    # Plain Python keys/values so the metrics stay JSON-serialisable.
    f1_dict = dict(zip(unique_classes.tolist(), f1_per_class.tolist()))

    total = sum(class_total.values())
    correct = sum(class_correct.values())
    accuracy = correct / total if total > 0 else 0

    class_metrics = {}
    for class_idx in class_total.keys():
        class_metrics[class_idx] = {
            'accuracy': class_correct[class_idx] / class_total[class_idx],
            'correct': class_correct[class_idx],
            'total': class_total[class_idx],
            'f1': f1_dict.get(class_idx, 0.0)
        }

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'class_metrics': class_metrics
    }

def _plot_training_curves(train_losses, train_accuracies, val_losses, val_accuracies,
                          val_f1_scores, learning_rates, batches_per_epoch,
                          path='training_curves.png'):
    """Save a 2x2 figure (loss, accuracy, val macro-F1, learning rate) to ``path``.

    Training curves are per batch. Validation points sit at the batch index where
    each epoch ended, so both series share one x-axis.
    """
    val_x = np.arange(1, len(val_losses) + 1) * batches_per_epoch

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_loss, ax_acc, ax_f1, ax_lr = axes.ravel()

    ax_loss.plot(train_losses, label='Training')
    ax_loss.plot(val_x, val_losses, label='Validation')
    ax_loss.set(title='Loss over Time', xlabel='Batch (validation at end of each epoch)', ylabel='Loss')
    ax_loss.legend()
    ax_loss.grid(True)

    ax_acc.plot(train_accuracies, label='Training')
    ax_acc.plot(val_x, val_accuracies, label='Validation')
    ax_acc.set(title='Accuracy over Time', xlabel='Batch (validation at end of each epoch)', ylabel='Accuracy')
    ax_acc.legend()
    ax_acc.grid(True)

    ax_f1.plot(val_f1_scores)
    ax_f1.set(title='Validation F1 Macro Score', xlabel='Epoch', ylabel='F1 Score')
    ax_f1.grid(True)

    ax_lr.plot(learning_rates)
    ax_lr.set(title='Learning Rate', xlabel='Epoch', ylabel='Learning Rate')
    ax_lr.grid(True)

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train(model, train_loader, val_loader, test_loader, num_epochs, device, learning_rate=1e-3):
    """Train ``model`` on masked-node prediction, early-stopping on val macro-F1.

    Training setup:
    - AdamW (weight decay 0.01) with gradient clipping at norm 1.0.
    - ReduceLROnPlateau on val macro-F1 (factor 0.5, patience 1).
    - Early stop after 10 epochs without improvement.

    The best-val-F1 weights are written to ``best_model.pt`` and restored before
    the final test evaluation. Curves are saved to ``training_curves.png``.

    Returns:
        ``(model, test_metrics, metrics)``, where ``metrics`` holds:
        - per-batch running training loss and accuracy;
        - per-epoch validation loss, accuracy and F1;
        - learning rates;
        - ``test_metrics``.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    # ``verbose`` is deprecated on LR schedulers; LR reductions are printed below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=1
    )
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    val_f1_scores = []
    learning_rates = []

    # Start at -inf so the first evaluated epoch is always checkpointed. Starting
    # at 0.0 let a run whose val F1 never exceeded 0 reload a stale best_model.pt
    # left behind by an earlier run.
    best_val_f1 = float('-inf')
    best_state = None
    patience = 10
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        batch_correct = 0
        batch_total = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()

            outputs = model(batch)
            # PyG's collate offsets attributes whose name contains "index" by the
            # running node count, so mask_index already points at batch-global rows.
            predictions = outputs[batch.mask_index]
            loss = criterion(predictions, batch.y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            num_batches = batch_idx + 1
            epoch_loss += loss.item()
            preds = predictions.argmax(dim=1)
            batch_correct += (preds == batch.y).sum().item()
            batch_total += len(batch.y)

            train_losses.append(epoch_loss / num_batches)
            train_accuracies.append(batch_correct / batch_total)

            if num_batches % 5 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}: "
                      f"Loss = {epoch_loss/(batch_idx + 1):.4f}, "
                      f"Accuracy = {batch_correct/batch_total:.4f}")

        val_metrics = evaluate_model(model, val_loader, criterion, device)

        val_losses.append(val_metrics['loss'])
        val_accuracies.append(val_metrics['accuracy'])
        val_f1_scores.append(val_metrics['f1_macro'])

        lr_before = optimizer.param_groups[0]['lr']
        learning_rates.append(lr_before)
        scheduler.step(val_metrics['f1_macro'])
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate to {lr_after:.2e}")

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {epoch_loss / max(num_batches, 1):.4f}")
        print(f"Train Accuracy: {batch_correct / max(batch_total, 1):.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}")
        print(f"Val F1 (macro): {val_metrics['f1_macro']:.4f}")

        if val_metrics['f1_macro'] > best_val_f1:
            best_val_f1 = val_metrics['f1_macro']
            # Keep an in-memory copy so the restore below cannot pick up another run's file.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, 'best_model.pt')
            print("new best found")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("early stop")
                break

        print("-" * 50)

    _plot_training_curves(train_losses, train_accuracies, val_losses, val_accuracies,
                          val_f1_scores, learning_rates, len(train_loader))

    if best_state is not None:
        model.load_state_dict(best_state)
    test_metrics = evaluate_model(model, test_loader, criterion, device)

    print(f"\nFinal Test Results:")
    print(f"Test Loss: {test_metrics['loss']:.4f}")
    print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Test F1 (macro): {test_metrics['f1_macro']:.4f}")

    metrics = {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'val_f1_scores': val_f1_scores,
        'learning_rates': learning_rates,
        'test_metrics': test_metrics
    }

    return model, test_metrics, metrics

In [None]:
import os
import json
from datetime import datetime

def setup_save_directory(save_dir='/content/model_results'):
    """Create ``save_dir`` (and any parents) if needed and return it.

    Args:
        save_dir: output directory; defaults to the Colab path used by this notebook.

    Returns:
        ``save_dir`` unchanged.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir

def save_run_metrics(model_name, run_id, metrics, save_dir):
    """Write one run's test metrics to ``<save_dir>/metrics_<model_name>_<run_id>.json``.

    Only ``accuracy``, ``f1_macro`` and ``class_metrics`` are taken from ``metrics``,
    prefixed by the model name and run id.
    """
    payload = {'model_name': model_name, 'run_id': run_id}
    payload.update({field: metrics[field] for field in ('accuracy', 'f1_macro', 'class_metrics')})

    out_path = os.path.join(save_dir, f"metrics_{model_name}_{run_id}.json")
    with open(out_path, 'w') as handle:
        json.dump(payload, handle)
    print(f"Saved metrics to: {out_path}")

def run_multiple_trainings(model_class, model_params, train_params, num_runs=3, model_name='dagnn_main'):
    """Train ``num_runs`` fresh models; save each run's test metrics and val-F1 plot.

    Args:
        model_class: model constructor, called as ``model_class(**model_params)``.
        model_params: keyword arguments for ``model_class``.
        train_params: keyword arguments for ``train`` (everything except ``model``).
        num_runs: number of independent runs.
        model_name: name used in the saved metrics file names.

    Returns:
        ``(per-run test metrics, per-run training metrics)`` as two lists.
    """
    save_dir = setup_save_directory()
    # Follow the training device instead of hard-coding CUDA, so CPU runs work.
    device = train_params.get('device', 'cuda')
    all_metrics = []
    all_training_metrics = []

    for run in range(num_runs):
        print(f"\nStarting Run {run + 1}/{num_runs}")

        model = model_class(**model_params).to(device)

        _, test_metrics, training_metrics = train(model=model, **train_params)

        save_run_metrics(model_name, run, test_metrics, save_dir)

        fig, ax = plt.subplots(figsize=(15, 10))
        ax.plot(training_metrics['val_f1_scores'])
        ax.set(title=f'Run {run + 1} Validation F1 Macro Score', xlabel='Epoch', ylabel='F1 Score')
        ax.grid(True)
        fig.savefig(os.path.join(save_dir, f'val_f1_run_{run + 1}.png'))
        plt.close(fig)

        all_metrics.append(test_metrics)
        all_training_metrics.append(training_metrics)

        print(f"Run {run + 1} completed")
        print(f"F1 Macro: {test_metrics['f1_macro']:.4f}")

    print(f"\nAll results saved in: {save_dir}")
    return all_metrics, all_training_metrics

# Model hyperparameters. ``in_channels`` must equal the number of feature columns
# in ``data.x``, since the model applies ``nn.Linear(in_channels, ...)`` to it.
model_params = dict(
    in_channels=2,
    hidden_channels=64,
    out_channels=processed_data['vocab_size'],
    num_layers=2,
    dropout_rate=0.1,
)

train_params = dict(
    train_loader=processed_data['train_loader'],
    val_loader=processed_data['val_loader'],
    test_loader=processed_data['test_loader'],
    num_epochs=3,
    device='cuda',
    learning_rate=1e-3,
)

metrics_list, training_metrics_list = run_multiple_trainings(
    model_class=DAGNN,
    model_params=model_params,
    train_params=train_params,
    num_runs=3,
)

In [None]:
def hyperparam_sensitivity(learning_rates=(1e-2, 1e-3, 1e-4), layer_counts=(2, 3, 4),
                           num_epochs=3, device='cuda'):
    """Grid-search learning rate x number of DAGNN layers and plot the comparison.

    Trains one fresh ``DAGNN`` per (lr, num_layers) pair on the loaders in the
    global ``processed_data``. A 2x2 summary figure is written to
    ``hyperparameter_comparison.png``.

    Args:
        learning_rates: non-empty sequence of learning rates to try.
        layer_counts: non-empty sequence of layer counts to try.
        num_epochs: epochs per configuration.
        device: device for the models and for ``train``.

    Returns:
        dict mapping ``"lr={lr}_layers={n}"`` to
        ``{'test_metrics': ..., 'training_metrics': ...}``.
    """
    learning_rates = list(learning_rates)
    layer_counts = list(layer_counts)

    def config_key(lr, num_layers):
        return f"lr={lr}_layers={num_layers}"

    results = {}

    base_model_params = {
        'in_channels': 2,
        'hidden_channels': 64,
        'out_channels': processed_data['vocab_size'],
        'dropout_rate': 0.1
    }

    base_train_params = {
        'train_loader': processed_data['train_loader'],
        'val_loader': processed_data['val_loader'],
        'test_loader': processed_data['test_loader'],
        'num_epochs': num_epochs,
        'device': device
    }

    for lr in learning_rates:
        for num_layers in layer_counts:
            print(f"training with lr={lr}, layers={num_layers}")

            model_params = {**base_model_params, 'num_layers': num_layers}
            train_params = {**base_train_params, 'learning_rate': lr}

            model = DAGNN(**model_params).to(device)
            # ``train`` is this notebook's training entry point.
            _, test_metrics, training_metrics = train(model=model, **train_params)

            results[config_key(lr, num_layers)] = {
                'test_metrics': test_metrics,
                'training_metrics': training_metrics
            }

    # Reference values for the 1-D sweeps; fall back to the first grid value when
    # the usual defaults (2 layers, lr=1e-3) are not part of the grid.
    ref_layers = 2 if 2 in layer_counts else layer_counts[0]
    ref_lr = 1e-3 if 1e-3 in learning_rates else learning_rates[0]

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))
    ax_lr, ax_layers, ax_f1_bar, ax_loss_bar = axes.ravel()

    for lr in learning_rates:
        metrics = results[config_key(lr, ref_layers)]['training_metrics']
        ax_lr.plot(metrics['val_f1_scores'], label=f'lr={lr}')
    ax_lr.set(title=f'Validation F1 Score vs Learning Rate ({ref_layers} layers)',
              xlabel='Epoch', ylabel='F1 Score')
    ax_lr.legend()
    ax_lr.grid(True)

    for layers in layer_counts:
        metrics = results[config_key(ref_lr, layers)]['training_metrics']
        ax_layers.plot(metrics['val_f1_scores'], label=f'{layers} layers')
    ax_layers.set(title=f'Validation F1 Score vs Number of Layers (lr={ref_lr:g})',
                  xlabel='Epoch', ylabel='F1 Score')
    ax_layers.legend()
    ax_layers.grid(True)

    x = np.arange(len(learning_rates))
    width = min(0.25, 0.8 / len(layer_counts))
    # Centre the tick labels under each group of bars.
    tick_pos = x + width * (len(layer_counts) - 1) / 2
    tick_labels = [str(lr) for lr in learning_rates]

    for i, layers in enumerate(layer_counts):
        f1_scores = [results[config_key(lr, layers)]['test_metrics']['f1_macro']
                     for lr in learning_rates]
        ax_f1_bar.bar(x + i * width, f1_scores, width, label=f'{layers} layers')
    ax_f1_bar.set(xlabel='Learning Rate', ylabel='Final F1 Score', title='Final F1 Score Comparison')
    ax_f1_bar.set_xticks(tick_pos)
    ax_f1_bar.set_xticklabels(tick_labels)
    ax_f1_bar.legend()

    for i, layers in enumerate(layer_counts):
        final_losses = []
        for lr in learning_rates:
            history = results[config_key(lr, layers)]['training_metrics']['val_losses']
            final_losses.append(history[-1] if history else float('nan'))
        ax_loss_bar.bar(x + i * width, final_losses, width, label=f'{layers} layers')
    ax_loss_bar.set(xlabel='Learning Rate', ylabel='Final Validation Loss',
                    title='Final Validation Loss Comparison')
    ax_loss_bar.set_xticks(tick_pos)
    ax_loss_bar.set_xticklabels(tick_labels)
    ax_loss_bar.legend()

    fig.tight_layout()
    fig.savefig('hyperparameter_comparison.png')
    plt.close(fig)

    return results

# Run the grid search defined above (run_hyperparameter_comparison does not exist).
comparison_results = hyperparam_sensitivity()

# Configuration with the highest test macro-F1; the first one wins ties.
if comparison_results:
    best_config = max(
        comparison_results,
        key=lambda k: comparison_results[k]['test_metrics']['f1_macro'],
    )
    best_f1 = comparison_results[best_config]['test_metrics']['f1_macro']
else:
    best_config, best_f1 = None, 0

print(f"\nBest configuration: {best_config}")
print(f"Best F1 macro score: {best_f1:.4f}")