# 02-Training-PyTorch-GNN

## 1. Loading dataset

In [1]:
from torch_geometric.datasets import TUDataset

# download dataset
dataset = TUDataset(root="../datasets/", name="IMDB-BINARY")

## 2. Traing Node2Vec model

In [11]:
import numpy as np
import torch
from torch_geometric.nn import Node2Vec

device = 'cpu'
data = dataset
data.y = data.y

# Train Node2Vec model
emb_model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=30,
                     context_size=30, walks_per_node=20,
                     num_negative_samples=1, p=0.5, q=0.5).to(device)

loader = emb_model.loader(batch_size=128, shuffle=True, num_workers=0)
optimizer = torch.optim.Adam(emb_model.parameters(), lr=0.01)

In [12]:
def train_test_split(test_ratio=0.1):
    y = data.data.y
    test_size = int(len(y) * test_ratio)
    classes = np.unique(y)

    train_idx = []
    test_idx = []

    for cls in classes:
        idx = (y == cls).nonzero().view(-1)
        idx = idx[torch.randperm(idx.size(0))]
        train_idx.extend(idx[:-test_size])
        test_idx.extend(idx[-test_size:])

    train_idx = torch.tensor(train_idx).to(device)
    test_idx = torch.tensor(test_idx).to(device)

    return train_idx, test_idx


train_idx, test_idx = train_test_split(test_ratio=0.1)

In [13]:
def train(epoch, model):
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in tqdm(loader, desc=f'Epoch {epoch:03d}'):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


for epoch in range(1, 21):
    loss = train(epoch, emb_model)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

Epoch 001: 100%|██████████| 155/155 [00:05<00:00, 29.76it/s]


Epoch: 001, Loss: 8.1273


Epoch 002: 100%|██████████| 155/155 [00:05<00:00, 29.83it/s]


Epoch: 002, Loss: 5.1322


Epoch 003: 100%|██████████| 155/155 [00:05<00:00, 29.89it/s]


Epoch: 003, Loss: 4.1379


Epoch 004: 100%|██████████| 155/155 [00:05<00:00, 29.91it/s]


Epoch: 004, Loss: 3.5016


Epoch 005: 100%|██████████| 155/155 [00:05<00:00, 29.95it/s]


Epoch: 005, Loss: 2.9289


Epoch 006: 100%|██████████| 155/155 [00:05<00:00, 29.78it/s]


Epoch: 006, Loss: 2.4471


Epoch 007: 100%|██████████| 155/155 [00:05<00:00, 29.77it/s]


Epoch: 007, Loss: 2.0540


Epoch 008: 100%|██████████| 155/155 [00:05<00:00, 29.89it/s]


Epoch: 008, Loss: 1.7421


Epoch 009: 100%|██████████| 155/155 [00:05<00:00, 29.77it/s]


Epoch: 009, Loss: 1.4961


Epoch 010: 100%|██████████| 155/155 [00:05<00:00, 29.91it/s]


Epoch: 010, Loss: 1.3071


Epoch 011: 100%|██████████| 155/155 [00:05<00:00, 29.87it/s]


Epoch: 011, Loss: 1.1637


Epoch 012: 100%|██████████| 155/155 [00:05<00:00, 29.61it/s]


Epoch: 012, Loss: 1.0554


Epoch 013: 100%|██████████| 155/155 [00:05<00:00, 29.72it/s]


Epoch: 013, Loss: 0.9766


Epoch 014: 100%|██████████| 155/155 [00:05<00:00, 29.75it/s]


Epoch: 014, Loss: 0.9177


Epoch 015: 100%|██████████| 155/155 [00:05<00:00, 29.69it/s]


Epoch: 015, Loss: 0.8744


Epoch 016: 100%|██████████| 155/155 [00:05<00:00, 29.86it/s]


Epoch: 016, Loss: 0.8423


Epoch 017: 100%|██████████| 155/155 [00:05<00:00, 29.80it/s]


Epoch: 017, Loss: 0.8180


Epoch 018: 100%|██████████| 155/155 [00:05<00:00, 29.82it/s]


Epoch: 018, Loss: 0.7997


Epoch 019: 100%|██████████| 155/155 [00:05<00:00, 29.83it/s]


Epoch: 019, Loss: 0.7853


Epoch 020: 100%|██████████| 155/155 [00:05<00:00, 29.83it/s]

Epoch: 020, Loss: 0.7742





In [14]:
import pickle

with open('../models/node2vec.pkl', 'wb') as f:
    pickle.dump(emb_model.state_dict(), f)

## 3. Assigning node embeddings to nodes in dataset

In [15]:
import pickle

with open('../models/node2vec.pkl', 'rb') as f:
    emb_model.load_state_dict(state_dict=pickle.load(f))

In [16]:
node_embeddings = emb_model()

In [17]:
from torch_geometric.io import read_tu_data


class FilmsDataset(TUDataset):
    def __init__(self, root: str, name: str):
        super().__init__(root, name)

    def process(self):
        self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)

        data_list = [self.get(idx) for idx in range(len(self))]
        num_nodes = 0
        for data in data_list:
            data.x = node_embeddings[num_nodes: num_nodes + data.num_nodes]
            num_nodes += data.num_nodes

        self.data, self.slices = self.collate(data_list)
        self._data_list = None

        torch.save((self._data.to_dict(), self.slices, sizes),
                   self.processed_paths[0])

In [18]:
dataset = FilmsDataset(root="../datasets/", name="IMDB-BINARY")

## 4. GCN Class

In [19]:
import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

In [20]:
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout, return_embeds=False):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList([GCNConv(input_dim, hidden_dim)] + \
                                         [GCNConv(hidden_dim, hidden_dim) for i in range(num_layers - 2)] + \
                                         [GCNConv(hidden_dim, output_dim)])

        self.bns = torch.nn.ModuleList([torch.nn.BatchNorm1d(hidden_dim) for i in range(num_layers - 1)])

        self.softmax = torch.nn.LogSoftmax(dim=-1)

        self.dropout = dropout

        # Skip classification layer and return node embeddings
        self.return_embeds = return_embeds

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        for i in range(len(self.convs) - 1):
            x = self.convs[i](x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, adj_t)
        if self.return_embeds:
            out = x
        else:
            out = self.softmax(x)

        return out

In [21]:
class GCN_Graph(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super(GCN_Graph, self).__init__()

        self.gnn_node = GCN(input_dim, hidden_dim,
                            hidden_dim, num_layers, dropout, return_embeds=True)

        # Pooling layer
        self.pool = global_mean_pool

        # Output layer
        self.linear = torch.nn.Linear(hidden_dim, output_dim)

    def reset_parameters(self):
        self.gnn_node.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, batched_data):
        # Extract important attributes of our mini-batch
        x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch

        embed = self.gnn_node(x, edge_index)

        out = self.pool(embed, batch)

        out = self.linear(out)

        return out

## 5. Preparing data

In [22]:
train_dataset, test_dataset = dataset[train_idx], dataset[test_idx]

In [23]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

## 6. Training

In [24]:
model = GCN_Graph(input_dim=256, hidden_dim=128, output_dim=dataset.num_classes,
                  num_layers=2, dropout=0.2).to(device)

In [34]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [35]:
from tqdm import tqdm


def train(epoch, model, optimizer, criterion, train_loader):
    model.train()

    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
        optimizer.zero_grad()

        out = model(batch.to(device))

        loss = criterion(out, batch.y)
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

In [36]:
def test(model, test_loader):
    model.eval()

    correct = 0
    for batch in test_loader:
        with torch.no_grad():
            out = model(batch.to(device))

        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())

    return correct / len(test_loader.dataset)

In [37]:
epochs = 50
model.reset_parameters()
for epoch in range(epochs):
    loss = train(epoch, model, optimizer, criterion, train_loader)
    acc = test(model, test_loader)
    print(f"Epoch: {epoch + 1:03d}, Loss: {loss:.4f}, Acc: {acc:.4f}")

Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 41.05it/s]


Epoch: 001, Loss: 0.7206, Acc: 0.5250


Epoch 2: 100%|██████████| 7/7 [00:00<00:00, 40.45it/s]


Epoch: 002, Loss: 0.6608, Acc: 0.5750


Epoch 3: 100%|██████████| 7/7 [00:00<00:00, 39.55it/s]


Epoch: 003, Loss: 0.6027, Acc: 0.5550


Epoch 4: 100%|██████████| 7/7 [00:00<00:00, 42.13it/s]


Epoch: 004, Loss: 0.5497, Acc: 0.5700


Epoch 5: 100%|██████████| 7/7 [00:00<00:00, 41.49it/s]


Epoch: 005, Loss: 0.5075, Acc: 0.5700


Epoch 6: 100%|██████████| 7/7 [00:00<00:00, 40.71it/s]


Epoch: 006, Loss: 0.4397, Acc: 0.5800


Epoch 7: 100%|██████████| 7/7 [00:00<00:00, 31.67it/s]


Epoch: 007, Loss: 0.3793, Acc: 0.5850


Epoch 8: 100%|██████████| 7/7 [00:00<00:00, 39.84it/s]


Epoch: 008, Loss: 0.3100, Acc: 0.5600


Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 41.52it/s]


Epoch: 009, Loss: 0.2680, Acc: 0.5600


Epoch 10: 100%|██████████| 7/7 [00:00<00:00, 41.11it/s]


Epoch: 010, Loss: 0.1997, Acc: 0.5650


Epoch 11: 100%|██████████| 7/7 [00:00<00:00, 40.91it/s]


Epoch: 011, Loss: 0.1524, Acc: 0.5600


Epoch 12: 100%|██████████| 7/7 [00:00<00:00, 41.63it/s]


Epoch: 012, Loss: 0.1208, Acc: 0.5850


Epoch 13: 100%|██████████| 7/7 [00:00<00:00, 39.48it/s]


Epoch: 013, Loss: 0.0887, Acc: 0.5800


Epoch 14: 100%|██████████| 7/7 [00:00<00:00, 40.79it/s]


Epoch: 014, Loss: 0.0665, Acc: 0.5500


Epoch 15: 100%|██████████| 7/7 [00:00<00:00, 41.50it/s]


Epoch: 015, Loss: 0.0385, Acc: 0.5850


Epoch 16: 100%|██████████| 7/7 [00:00<00:00, 40.37it/s]


Epoch: 016, Loss: 0.0281, Acc: 0.5700


Epoch 17: 100%|██████████| 7/7 [00:00<00:00, 41.48it/s]


Epoch: 017, Loss: 0.0227, Acc: 0.5550


Epoch 18: 100%|██████████| 7/7 [00:00<00:00, 40.93it/s]


Epoch: 018, Loss: 0.0149, Acc: 0.5550


Epoch 19: 100%|██████████| 7/7 [00:00<00:00, 39.83it/s]


Epoch: 019, Loss: 0.0174, Acc: 0.5400


Epoch 20: 100%|██████████| 7/7 [00:00<00:00, 41.35it/s]


Epoch: 020, Loss: 0.0114, Acc: 0.5450


Epoch 21: 100%|██████████| 7/7 [00:00<00:00, 39.62it/s]


Epoch: 021, Loss: 0.0096, Acc: 0.5400


Epoch 22: 100%|██████████| 7/7 [00:00<00:00, 40.50it/s]


Epoch: 022, Loss: 0.0081, Acc: 0.5450


Epoch 23: 100%|██████████| 7/7 [00:00<00:00, 41.11it/s]


Epoch: 023, Loss: 0.0062, Acc: 0.5450


Epoch 24: 100%|██████████| 7/7 [00:00<00:00, 40.89it/s]


Epoch: 024, Loss: 0.0067, Acc: 0.5450


Epoch 25: 100%|██████████| 7/7 [00:00<00:00, 40.64it/s]


Epoch: 025, Loss: 0.0050, Acc: 0.5450


Epoch 26: 100%|██████████| 7/7 [00:00<00:00, 40.50it/s]


Epoch: 026, Loss: 0.0048, Acc: 0.5550


Epoch 27: 100%|██████████| 7/7 [00:00<00:00, 41.78it/s]


Epoch: 027, Loss: 0.0035, Acc: 0.5500


Epoch 28: 100%|██████████| 7/7 [00:00<00:00, 41.93it/s]


Epoch: 028, Loss: 0.0042, Acc: 0.5350


Epoch 29: 100%|██████████| 7/7 [00:00<00:00, 40.38it/s]


Epoch: 029, Loss: 0.0032, Acc: 0.5500


Epoch 30: 100%|██████████| 7/7 [00:00<00:00, 42.60it/s]


Epoch: 030, Loss: 0.0033, Acc: 0.5550


Epoch 31: 100%|██████████| 7/7 [00:00<00:00, 41.18it/s]


Epoch: 031, Loss: 0.0032, Acc: 0.5550


Epoch 32: 100%|██████████| 7/7 [00:00<00:00, 41.72it/s]


Epoch: 032, Loss: 0.0049, Acc: 0.5550


Epoch 33: 100%|██████████| 7/7 [00:00<00:00, 42.02it/s]


Epoch: 033, Loss: 0.0063, Acc: 0.5550


Epoch 34: 100%|██████████| 7/7 [00:00<00:00, 40.00it/s]


Epoch: 034, Loss: 0.0024, Acc: 0.5550


Epoch 35: 100%|██████████| 7/7 [00:00<00:00, 41.31it/s]


Epoch: 035, Loss: 0.0022, Acc: 0.5450


Epoch 36: 100%|██████████| 7/7 [00:00<00:00, 40.48it/s]


Epoch: 036, Loss: 0.0022, Acc: 0.5450


Epoch 37: 100%|██████████| 7/7 [00:00<00:00, 41.99it/s]


Epoch: 037, Loss: 0.0094, Acc: 0.5450


Epoch 38: 100%|██████████| 7/7 [00:00<00:00, 42.18it/s]


Epoch: 038, Loss: 0.0063, Acc: 0.5650


Epoch 39: 100%|██████████| 7/7 [00:00<00:00, 40.32it/s]


Epoch: 039, Loss: 0.0150, Acc: 0.5650


Epoch 40: 100%|██████████| 7/7 [00:00<00:00, 40.87it/s]


Epoch: 040, Loss: 0.0053, Acc: 0.5550


Epoch 41: 100%|██████████| 7/7 [00:00<00:00, 41.98it/s]


Epoch: 041, Loss: 0.0085, Acc: 0.5400


Epoch 42: 100%|██████████| 7/7 [00:00<00:00, 39.93it/s]


Epoch: 042, Loss: 0.0058, Acc: 0.5450


Epoch 43: 100%|██████████| 7/7 [00:00<00:00, 42.23it/s]


Epoch: 043, Loss: 0.0028, Acc: 0.5500


Epoch 44: 100%|██████████| 7/7 [00:00<00:00, 38.10it/s]


Epoch: 044, Loss: 0.0062, Acc: 0.5450


Epoch 45: 100%|██████████| 7/7 [00:00<00:00, 41.99it/s]


Epoch: 045, Loss: 0.0031, Acc: 0.5450


Epoch 46: 100%|██████████| 7/7 [00:00<00:00, 40.17it/s]


Epoch: 046, Loss: 0.0045, Acc: 0.5400


Epoch 47: 100%|██████████| 7/7 [00:00<00:00, 39.68it/s]


Epoch: 047, Loss: 0.0021, Acc: 0.5600


Epoch 48: 100%|██████████| 7/7 [00:00<00:00, 41.64it/s]


Epoch: 048, Loss: 0.0020, Acc: 0.5600


Epoch 49: 100%|██████████| 7/7 [00:00<00:00, 40.33it/s]


Epoch: 049, Loss: 0.0020, Acc: 0.5550


Epoch 50: 100%|██████████| 7/7 [00:00<00:00, 41.25it/s]


Epoch: 050, Loss: 0.0013, Acc: 0.5450


# Credits

Notebook was made by GCN-tutorial team for the Innopolis 'Data and Knowledge representation' course.

> [Polina Zelenskaya](github.com/cutefluffyfox) \
> [Said Kamalov](github.com/SaidKamalov) \
> [Lev Rekhlov](github.com/plov-cyber)