
# Лабораторная работа 4: GAT и RGAT

## Структура ноутбука
- Часть 1. GAT для классификации вершин на Cora: загрузка данных, модель, обучение, подбор lr.
- Часть 2. RGAT для link prediction на FB15k-237: подготовка данных, модель, обучение, метрики.
- Итоговые выводы по двум задачам.


In [1]:

!pip install torch_geometric
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

torch.manual_seed(17)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0



## Часть 1. GAT: классификация вершин на Cora
Cora - датасет цитирования статей по ML: вершины это статьи, ребра показывают кто кого цитирует, признаки bag-of-words по текстам, метка - научное направление.


In [2]:


dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0].to(device)

print(dataset)
print(f"Вершин: {data.num_nodes}, ребер: {data.num_edges}")
print(f"Признаков на вершину: {dataset.num_features}")
print(f"Классов: {dataset.num_classes}")
print(f"Train/val/test: {int(data.train_mask.sum())}/{int(data.val_mask.sum())}/{int(data.test_mask.sum())}")


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...


Cora()
Вершин: 2708, ребер: 10556
Признаков на вершину: 1433
Классов: 7
Train/val/test: 140/500/1000


Done!


Классификация вершин через Graph Attention Network (GAT)


In [3]:


class GAT(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        heads1: int = 8,
        heads2: int = 1,
        dropout: float = 0.6,
    ):
        super().__init__()
        self.dropout = dropout

        # первый слой: multi-head attention
        self.conv1 = GATConv(
            in_channels,
            hidden_channels,
            heads=heads1,
            dropout=dropout,
        )
        # второй слой: выдаем классы без concat чтобы не раздувать размерность
        self.conv2 = GATConv(
            hidden_channels * heads1,
            out_channels,
            heads=heads2,
            concat=False,
            dropout=dropout,
        )

    def forward(self, x, edge_index):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x  # логиты классов


In [4]:


def train_one_epoch(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)

    loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def accuracy_on_mask(model, data, mask):
    model.eval()
    logits = model(data.x, data.edge_index)
    pred = logits.argmax(dim=-1)
    correct = (pred[mask] == data.y[mask]).sum().item()
    total = int(mask.sum())
    return correct / max(total, 1)

@torch.no_grad()
def gather_stats(model, data):
    masks = {
        "train": data.train_mask,
        "val": data.val_mask,
        "test": data.test_mask,
    }
    return {split: accuracy_on_mask(model, data, mask) for split, mask in masks.items()}


In [5]:
# подбор learning rate

lrs = [0.008, 0.01, 0.02]
num_epochs = 200

best_val_acc = 0.0
best_stats = None
best_lr = None

for lr in lrs:
    print(f"\n=== Обучаем модель с lr = {lr} ===")

    model = GAT(
        in_channels=dataset.num_features,
        hidden_channels=8,
        out_channels=dataset.num_classes,
        heads1=8,
        heads2=1,
        dropout=0.6,
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4
    )

    for epoch in range(1, num_epochs + 1):
        loss = train_one_epoch(model, data, optimizer)

        if epoch % 25 == 0 or epoch == 1 or epoch == num_epochs:
            stats = gather_stats(model, data)
            print(
                f"Epoch {epoch:03d} | loss={loss:.4f} | "
                f"train={stats['train']:.3f} | "
                f"val={stats['val']:.3f} | "
                f"test={stats['test']:.3f}"
            )

    stats = gather_stats(model, data)
    if stats["val"] > best_val_acc:
        best_val_acc = stats["val"]
        best_stats = stats
        best_lr = lr

print("\n=== Лучший результат по валидации ===")
print(f"LR: {best_lr}")
print(f"Train acc: {best_stats['train']:.3f}")
print(f"Val   acc: {best_stats['val']:.3f}")
print(f"Test  acc: {best_stats['test']:.3f}")



=== Обучаем модель с lr = 0.008 ===


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  return float(loss)


Epoch 001 | loss=1.9917 | train=0.636 | val=0.516 | test=0.518
Epoch 025 | loss=0.6638 | train=0.993 | val=0.750 | test=0.769
Epoch 050 | loss=0.4482 | train=1.000 | val=0.758 | test=0.789
Epoch 075 | loss=0.6693 | train=1.000 | val=0.776 | test=0.795
Epoch 100 | loss=0.4335 | train=1.000 | val=0.776 | test=0.791
Epoch 125 | loss=0.6028 | train=1.000 | val=0.772 | test=0.792
Epoch 150 | loss=0.4207 | train=1.000 | val=0.782 | test=0.799
Epoch 175 | loss=0.4107 | train=1.000 | val=0.766 | test=0.788
Epoch 200 | loss=0.4135 | train=1.000 | val=0.766 | test=0.795

=== Обучаем модель с lr = 0.01 ===
Epoch 001 | loss=2.0041 | train=0.686 | val=0.538 | test=0.570
Epoch 025 | loss=0.6802 | train=0.986 | val=0.762 | test=0.773
Epoch 050 | loss=0.4187 | train=1.000 | val=0.788 | test=0.802
Epoch 075 | loss=0.4432 | train=1.000 | val=0.772 | test=0.792
Epoch 100 | loss=0.4657 | train=1.000 | val=0.778 | test=0.793
Epoch 125 | loss=0.4083 | train=1.000 | val=0.786 | test=0.801
Epoch 150 | loss=0.


### Вывод по GAT
Оптимальный lr=0.01: на трейне модель почти идеально подогналась, а валид/тест держатся около 78-80%, значит сильного переобучения нет.


## Часть 2. RGAT: link prediction на FB15k-237


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.datasets import RelLinkPredDataset
from torch_geometric.transforms import RandomLinkSplit

from torch_geometric.utils import subgraph
from torch_geometric.data import Data


FB15k-237 - граф знаний: вершины это сущности (люди, фильмы и т.п.), ребра - типизированные отношения между ними, всего 237 relation-типов.


In [7]:


dataset = RelLinkPredDataset(root='data/FB15k237', name='FB15k-237')
data = dataset[0]

print(data)
print(f"Число вершин (entities): {data.num_nodes}")
print(f"Число рёбер (троек): {data.edge_index.shape[1]}")
print(f"shape edge_type: {data.edge_type.shape}")
print(f"Макс тип ребра: {int(data.edge_type.max())}, всего отношений: {int(data.edge_type.max()) + 1}")


Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/entities.dict
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/relations.dict
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/test.txt
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/train.txt
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/valid.txt
Processing...


Data(edge_index=[2, 544230], num_nodes=14541, edge_type=[544230], train_edge_index=[2, 272115], train_edge_type=[272115], valid_edge_index=[2, 17535], valid_edge_type=[17535], test_edge_index=[2, 20466], test_edge_type=[20466])
Число вершин (entities): 14541
Число рёбер (троек): 544230
shape edge_type: torch.Size([544230])
Макс тип ребра: 473, всего отношений: 474


Done!


In [8]:


# немного подсократим датасет, чтобы обучение было быстрее в ноутбуке
max_edges = 40000

num_edges = data.edge_index.size(1)
perm = torch.randperm(num_edges)[:max_edges]

edge_index = data.edge_index[:, perm]
edge_type = data.edge_type[perm]

nodes = torch.unique(edge_index)
edge_index, edge_type = subgraph(
    nodes,
    edge_index,
    edge_attr=edge_type,
    relabel_nodes=True,
)

data = Data(
    edge_index=edge_index,
    edge_type=edge_type,
    num_nodes=int(edge_index.max()) + 1,
)


In [9]:


# разбиение рёбер на train/val/test для задачи link prediction

transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=False,
    add_negative_train_samples=True,
    neg_sampling_ratio=1.0,
)

train_data, val_data, test_data = transform(data)

train_data = train_data.to(device)
val_data   = val_data.to(device)
test_data  = test_data.to(device)

print(train_data)
print("train edge_label_index shape:", train_data.edge_label_index.shape)
print("train edge_label shape:", train_data.edge_label.shape)


Data(edge_index=[2, 32000], edge_type=[32000], num_nodes=12382, edge_label=[64000], edge_label_index=[2, 64000])
train edge_label_index shape: torch.Size([2, 64000])
train edge_label shape: torch.Size([64000])


In [10]:


from torch_geometric.nn import RGATConv

class RGATEncoder(nn.Module):
    def __init__(
        self,
        num_nodes: int,
        num_relations: int,
        emb_dim: int = 64,
        hidden_dim: int = 64,
        out_dim: int = 64,
        num_heads: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.dropout = dropout

        # обучаемые эмбеддинги вершин
        self.node_emb = nn.Embedding(num_nodes, emb_dim)

        # attention с учётом типа ребра
        self.conv1 = RGATConv(
            in_channels=emb_dim,
            out_channels=hidden_dim,
            num_relations=num_relations,
            heads=num_heads,
            dropout=dropout,
            concat=True,
        )

        self.conv2 = RGATConv(
            in_channels=hidden_dim * num_heads,
            out_channels=out_dim,
            num_relations=num_relations,
            heads=num_heads,
            dropout=dropout,
            concat=False,
        )

    def forward(self, edge_index, edge_type):
        x = self.node_emb.weight

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_type)

        return x  # размер [num_nodes, out_dim]

class LinkPredictor(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x_i, x_j):
        h = torch.cat([x_i, x_j], dim=-1)
        out = self.mlp(h).squeeze(-1)
        return out


In [11]:


def get_edge_embeddings(emb, edge_label_index):
    src, dst = edge_label_index
    return emb[src], emb[dst]


def train_one_epoch(encoder, predictor, data, optimizer):
    encoder.train()
    predictor.train()
    optimizer.zero_grad()

    node_emb = encoder(data.edge_index, data.edge_type)
    ei, ej = get_edge_embeddings(node_emb, data.edge_label_index)
    logits = predictor(ei, ej)
    labels = data.edge_label.float()

    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss.backward()
    optimizer.step()

    return float(loss)


In [12]:


@torch.no_grad()
def eval_split(encoder, predictor, data):
    encoder.eval()
    predictor.eval()

    node_emb = encoder(data.edge_index, data.edge_type)
    ei, ej = get_edge_embeddings(node_emb, data.edge_label_index)
    logits = predictor(ei, ej)
    probs = torch.sigmoid(logits)

    labels = data.edge_label.float()

    pred_labels = (probs > 0.5).float()
    acc = (pred_labels == labels).float().mean().item()

    return acc


In [13]:
num_nodes = data.num_nodes
num_relations = int(data.edge_type.max().item()) + 1

lrs = [1e-3, 7e-4]
num_epochs = 20

best_val_acc = 0.0
best_test_acc = 0.0
best_lr = None

for lr in lrs:
    print(f"\n=== Обучаем RGAT-модель для link prediction, lr = {lr} ===")

    encoder = RGATEncoder(
        num_nodes=num_nodes,
        num_relations=num_relations,
        emb_dim=64,
        hidden_dim=64,
        out_dim=64,
        num_heads=2,
        dropout=0.3,
    ).to(device)

    predictor = LinkPredictor(in_dim=64, hidden_dim=64).to(device)

    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(predictor.parameters()),
        lr=lr,
        weight_decay=1e-5,
    )

    for epoch in range(1, num_epochs + 1):
        loss = train_one_epoch(encoder, predictor, train_data, optimizer)

        if epoch % 10 == 0 or epoch == 1:
            train_acc = eval_split(encoder, predictor, train_data)
            val_acc = eval_split(encoder, predictor, val_data)
            test_acc = eval_split(encoder, predictor, test_data)

            print(
                f"Epoch {epoch:03d} | loss={loss:.4f} | "
                f"train_acc={train_acc:.3f} | val_acc={val_acc:.3f} | test_acc={test_acc:.3f}"
            )

    val_acc = eval_split(encoder, predictor, val_data)
    test_acc = eval_split(encoder, predictor, test_data)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
        best_lr = lr

print("\n=== Лучший результат по валидации для RGATLinkPrediction ===")
print(f"LR: {best_lr}")
print(f"Val   acc: {best_val_acc:.3f}")
print(f"Test  acc: {best_test_acc:.3f}")



=== Обучаем RGAT-модель для link prediction, lr = 0.001 ===
Epoch 001 | loss=0.6957 | train_acc=0.503 | val_acc=0.504 | test_acc=0.504
Epoch 010 | loss=0.6262 | train_acc=0.681 | val_acc=0.675 | test_acc=0.666
Epoch 020 | loss=0.5590 | train_acc=0.756 | val_acc=0.740 | test_acc=0.742

=== Обучаем RGAT-модель для link prediction, lr = 0.0007 ===
Epoch 001 | loss=0.6949 | train_acc=0.506 | val_acc=0.503 | test_acc=0.506
Epoch 010 | loss=0.6593 | train_acc=0.623 | val_acc=0.617 | test_acc=0.609
Epoch 020 | loss=0.5856 | train_acc=0.730 | val_acc=0.719 | test_acc=0.715

=== Лучший результат по валидации для RGATLinkPrediction ===
LR: 0.001
Val   acc: 0.740
Test  acc: 0.742



### Вывод по RGAT
RGAT выучил структуру гетерогенного графа: на валидации и тесте держится около 0.74 acc, значимого overfitting нет.
