In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter

Definition of a graph `Data(x=[20, 5], edge_index=[2, 20], edge_attr=[20], y=[20])` (A graph with 20 nodes, 20 edges, 6 different edge types and 1 target label per node) is decomposed into

- x: Data stored in nodes (here, num_nodes=20 nodes x dim=5 features):

```
tensor([[ 0.1338,  0.8443, -1.8535,  1.1781, -1.4325],
        [ 0.6020,  0.9776,  0.3087, -0.5191, -0.8351],
        [-0.1933,  0.7537,  0.5730, -1.6758,  0.7305],
        [-0.0866,  0.0894,  1.5008, -0.7608,  0.7714],
        [-0.1519,  1.0402, -0.8959,  0.1330, -0.1550],
        [ 0.0830,  0.4799, -0.5175,  2.2978, -0.9849],
        [ 0.2736,  1.2479, -0.0293, -0.2829,  1.9557],
        [-0.9310,  0.2792, -0.5623, -0.4188, -0.5580],
        [-0.0870,  0.6735,  0.7130, -0.7389,  1.3117],
        [-0.5930, -1.0240, -0.6158,  0.1470, -0.7293],
        [ 0.4986, -1.7924,  0.9244,  1.1364,  1.2943],
        [ 0.5379,  0.4017,  0.3074, -0.8189,  0.5053],
        [ 2.5506,  0.5360, -0.0889, -0.2462, -0.9926],
        [ 0.5416, -0.1752, -0.3871, -0.2850,  0.3497],
        [-0.3282,  1.2847, -0.3605, -0.6254,  0.5127],
        [ 0.1769, -2.4991,  1.4926, -0.4233, -0.9469],
        [ 1.2659, -1.5897, -0.6909, -0.0036,  1.1759],
        [-0.4856,  0.3826, -0.6284, -0.4664,  0.0383],
        [ 0.2546,  1.5097,  1.2746, -0.5847,  1.1207],
        [ 0.9572,  0.4619, -0.4879, -0.4740, -0.1182]])
```

- edge_index: Encodes connections between nodes. It's always (2,num_edges) shape, the first list is source ids and the target ids for connections (here, num_edges=20):

```
tensor([[ 8,  5, 14,  5,  4, 14, 15,  4, 17, 13, 12, 15, 13, 11, 10, 12, 19, 0, 19, 16],
        [ 5,  8, 15,  0, 13,  0,  2, 17,  9, 16,  3, 12, 18,  4,  5, 18, 12, 8, 16, 10]])
```

- edge_attr: Encodes types of each edge (since we deal with multiple different edge types in RGCN). The shape is therefore (num_edges, num_edge_features=1) here with values [0, num_relations-1].

```
tensor([3, 2, 0, 3, 4, 5, 2, 1, 5, 3, 5, 0, 0, 4, 2, 2, 4, 1, 1, 5])
```

- y: target labels with (n_labels=num_nodes,label_dim=1) shape with n_labels=num_nodes since we do node-level labels (graph-level single prediction is also possible) and label_dim=1 since in this toy example we want to have classification (one value per node indicating target class).

```
tensor([1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0])
```

In [2]:
# Create a mock graph dataset to test how graph data works
from torch_geometric.data import Data

num_nodes = 20 # num of nodes
dim = 5 # dim of node features
num_edges = 22 # num of edges
num_classes = 2 # num of target classes
num_relations = 6 # num of edge types

def generate_random_graph():
    # --- Initialize nodes ---
    x = torch.randn(num_nodes, dim) # creates num_nodes x dim tensor of (random) node features <- TODO: replace with actual node features

    # --- Initialize edges --- 
    edge_index = torch.randint(num_nodes, (2, num_edges)) # creates 2 x num_edges tensor of (random) edges <- TODO: replace with actual edges

    # --- Introduce different types of edges ---
    edge_attr = torch.randint(num_relations, (num_edges,)) # creates num_edges x 1 tensor of (random) edge types <- TODO: replace with actual edge types

    # --- Initialize labels ---
    y = torch.randint(num_classes, (num_nodes,)) # creates num_nodes x 1 tensor of (random) labels <- TODO: replace with actual labels

    # --- Create a graph ---
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    return graph

graph = generate_random_graph()
print(graph)

Data(x=[20, 5], edge_index=[2, 22], edge_attr=[22], y=[20])


In [3]:
import torch_geometric.transforms as T
from torch_geometric.nn import RGCNConv
from torch import nn

class RGCNNet(torch.nn.Module):
    def __init__(
        self, 
        in_channels,
        h_channels_list,
        out_channels, 
        num_relations,
        num_bases=None, 
        aggr='mean',
        dp_rate=0.1, 
        bias=True,
        activation=nn.LeakyReLU(negative_slope=0.2, inplace=True)
    ):
        super(RGCNNet, self).__init__()
        self.num_layers = len(h_channels_list) + 1
        self.layers = []
        
        for i in range(self.num_layers - 1):
            in_channels = in_channels if i == 0 else h_channels_list[i - 1]
            out_channels = h_channels_list[i]
            self.layers += [
                RGCNConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    num_relations=num_relations,
                    num_bases=num_bases,
                    aggr=aggr,
                    bias=bias
                ),
                activation,
                nn.Dropout(p=dp_rate)
            ]
        self.layers += [
            RGCNConv(
                in_channels=h_channels_list[-1] if self.num_layers > 1 else in_channels,
                out_channels=out_channels,
                num_relations=num_relations,
                num_bases=num_bases,
                aggr=aggr,
                bias=bias
            ),
            activation,
        ]
        
        self.layers = torch.nn.ModuleList(self.layers)
        
    def forward(self, x, edge_index, edge_type):
        for layer in self.layers:
            if isinstance(layer, RGCNConv):
                x = layer(x, edge_index, edge_type)
            else:
                x = layer(x)
        return x


In [4]:
from torch.nn import CrossEntropyLoss

# --- Initialize model ---
# num_bases = 6 # num of bases # TODO: experiment with this
h_channels_list = [10, 5] # list of hidden layer sizes

model = RGCNNet(
    in_channels=dim,
    h_channels_list=h_channels_list,
    out_channels=num_classes,
    num_relations=num_relations,
    # num_bases=num_bases,
    aggr="mean",
    dp_rate=0.1,
    bias=True
)

In [5]:
# --- Initialize train and val loaders ---
from torch_geometric.loader import DataLoader

OVERFIT_BATCH_SIZE = 1
train_batch = [generate_random_graph() for _ in range(OVERFIT_BATCH_SIZE)]
val_batch = [generate_random_graph() for _ in range(OVERFIT_BATCH_SIZE)]

### Overfit for reconstruction (should learn identity)

In [10]:
# Reset the model
# num_bases = 6 # num of bases # TODO: experiment with this
h_channels_list = [] # list of hidden layer sizes

activation = nn.Identity()

model = RGCNNet(
    in_channels=dim,
    h_channels_list=h_channels_list,
    out_channels=dim,
    num_relations=num_relations,
    # num_bases=num_bases,
    aggr="mean",
    dp_rate=0,
    bias=True,
    activation=activation
)
print(model)

train_loader = DataLoader(train_batch, batch_size=OVERFIT_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_batch, batch_size=OVERFIT_BATCH_SIZE, shuffle=False)

# --- Setup training loop ---
from tqdm import tqdm

epochs = 1000 # num of epochs
best_loss = float('inf') # initialize best loss
best_epoch = 0 # initialize best epoch

# --- Initialize optimizer ---
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

lr = 1e-1 # learning rate
weight_decay = 5e-4 # weight decay
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# --- Initialize reconstruction loss function ---
loss_fn = torch.nn.MSELoss()

# --- Initialize scheduler ---
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=500, min_lr=0.00001)

# --- Initialize tensorboard ---
from datetime import datetime
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/rgcn-test/overfit-recon/{OVERFIT_BATCH_SIZE}-lr-{lr}-{now.strftime("%Y%m%d-%H:%M:%S")}')

for epoch in (pbar:=tqdm(range(epochs))):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.x)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= len(train_loader)
    writer.add_scalar('Loss/train', epoch_loss, epoch)
    
    train_loss = epoch_loss
    
    # Scheduler listens to the train loss because we want to overfit here
    scheduler.step(epoch_loss)
    
    model.eval()
    epoch_loss = 0
    for batch in val_loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.x)
        epoch_loss += loss.detach().item()
    epoch_loss /= len(val_loader)
    writer.add_scalar('Loss/val', epoch_loss, epoch)
    
    val_loss = epoch_loss
    
    print(f"Epoch {epoch} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    # if epoch_loss < best_loss:
    #     best_loss = epoch_loss
    #     best_epoch = epoch
    #     torch.save(model.state_dict(), 'models/model.pth')
    
writer.flush()
writer.close()

RGCNNet(
  (layers): ModuleList(
    (0): RGCNConv(5, 5, num_relations=6)
    (1): Identity()
  )
)


  5%|▍         | 47/1000 [00:00<00:02, 467.35it/s]

Epoch 0 - Train loss: 1.6362, Val loss: 1.2069
Epoch 1 - Train loss: 0.7112, Val loss: 1.0650
Epoch 2 - Train loss: 0.3324, Val loss: 0.9865
Epoch 3 - Train loss: 0.2585, Val loss: 0.9456
Epoch 4 - Train loss: 0.2793, Val loss: 0.9267
Epoch 5 - Train loss: 0.2919, Val loss: 0.9202
Epoch 6 - Train loss: 0.2770, Val loss: 0.9193
Epoch 7 - Train loss: 0.2477, Val loss: 0.9156
Epoch 8 - Train loss: 0.2178, Val loss: 0.8990
Epoch 9 - Train loss: 0.1922, Val loss: 0.8655
Epoch 10 - Train loss: 0.1693, Val loss: 0.8209
Epoch 11 - Train loss: 0.1464, Val loss: 0.7746
Epoch 12 - Train loss: 0.1231, Val loss: 0.7342
Epoch 13 - Train loss: 0.1018, Val loss: 0.7034
Epoch 14 - Train loss: 0.0856, Val loss: 0.6830
Epoch 15 - Train loss: 0.0768, Val loss: 0.6717
Epoch 16 - Train loss: 0.0748, Val loss: 0.6668
Epoch 17 - Train loss: 0.0760, Val loss: 0.6650
Epoch 18 - Train loss: 0.0757, Val loss: 0.6635
Epoch 19 - Train loss: 0.0713, Val loss: 0.6609
Epoch 20 - Train loss: 0.0630, Val loss: 0.6574
Ep

 11%|█▏        | 114/1000 [00:00<00:01, 583.47it/s]

Epoch 112 - Train loss: 0.0001, Val loss: 0.2010
Epoch 113 - Train loss: 0.0001, Val loss: 0.1986
Epoch 114 - Train loss: 0.0001, Val loss: 0.1963
Epoch 115 - Train loss: 0.0001, Val loss: 0.1940


 18%|█▊        | 177/1000 [00:00<00:01, 602.11it/s]

Epoch 116 - Train loss: 0.0001, Val loss: 0.1917
Epoch 117 - Train loss: 0.0001, Val loss: 0.1893
Epoch 118 - Train loss: 0.0001, Val loss: 0.1870
Epoch 119 - Train loss: 0.0001, Val loss: 0.1848
Epoch 120 - Train loss: 0.0001, Val loss: 0.1826
Epoch 121 - Train loss: 0.0001, Val loss: 0.1804
Epoch 122 - Train loss: 0.0001, Val loss: 0.1783
Epoch 123 - Train loss: 0.0001, Val loss: 0.1762
Epoch 124 - Train loss: 0.0001, Val loss: 0.1740
Epoch 125 - Train loss: 0.0001, Val loss: 0.1719
Epoch 126 - Train loss: 0.0001, Val loss: 0.1698
Epoch 127 - Train loss: 0.0001, Val loss: 0.1677
Epoch 128 - Train loss: 0.0001, Val loss: 0.1657
Epoch 129 - Train loss: 0.0001, Val loss: 0.1636
Epoch 130 - Train loss: 0.0001, Val loss: 0.1617
Epoch 131 - Train loss: 0.0001, Val loss: 0.1597
Epoch 132 - Train loss: 0.0001, Val loss: 0.1578
Epoch 133 - Train loss: 0.0001, Val loss: 0.1559
Epoch 134 - Train loss: 0.0001, Val loss: 0.1540
Epoch 135 - Train loss: 0.0001, Val loss: 0.1522
Epoch 136 - Train lo

 24%|██▍       | 242/1000 [00:00<00:01, 617.36it/s]

Epoch 240 - Train loss: 0.0000, Val loss: 0.0422
Epoch 241 - Train loss: 0.0000, Val loss: 0.0417
Epoch 242 - Train loss: 0.0000, Val loss: 0.0412
Epoch 243 - Train loss: 0.0000, Val loss: 0.0407


 30%|███       | 304/1000 [00:00<00:01, 565.62it/s]

Epoch 244 - Train loss: 0.0000, Val loss: 0.0402
Epoch 245 - Train loss: 0.0000, Val loss: 0.0397
Epoch 246 - Train loss: 0.0000, Val loss: 0.0393
Epoch 247 - Train loss: 0.0000, Val loss: 0.0388
Epoch 248 - Train loss: 0.0000, Val loss: 0.0383
Epoch 249 - Train loss: 0.0000, Val loss: 0.0379
Epoch 250 - Train loss: 0.0000, Val loss: 0.0374
Epoch 251 - Train loss: 0.0000, Val loss: 0.0369
Epoch 252 - Train loss: 0.0000, Val loss: 0.0365
Epoch 253 - Train loss: 0.0000, Val loss: 0.0361
Epoch 254 - Train loss: 0.0000, Val loss: 0.0356
Epoch 255 - Train loss: 0.0000, Val loss: 0.0352
Epoch 256 - Train loss: 0.0000, Val loss: 0.0348
Epoch 257 - Train loss: 0.0000, Val loss: 0.0343
Epoch 258 - Train loss: 0.0000, Val loss: 0.0339
Epoch 259 - Train loss: 0.0000, Val loss: 0.0335
Epoch 260 - Train loss: 0.0000, Val loss: 0.0331
Epoch 261 - Train loss: 0.0000, Val loss: 0.0327
Epoch 262 - Train loss: 0.0000, Val loss: 0.0323
Epoch 263 - Train loss: 0.0000, Val loss: 0.0319
Epoch 264 - Train lo

 44%|████▎     | 435/1000 [00:00<00:00, 611.77it/s]

Epoch 355 - Train loss: 0.0000, Val loss: 0.0105
Epoch 356 - Train loss: 0.0000, Val loss: 0.0104
Epoch 357 - Train loss: 0.0000, Val loss: 0.0103
Epoch 358 - Train loss: 0.0000, Val loss: 0.0102
Epoch 359 - Train loss: 0.0000, Val loss: 0.0100
Epoch 360 - Train loss: 0.0000, Val loss: 0.0099
Epoch 361 - Train loss: 0.0000, Val loss: 0.0098
Epoch 362 - Train loss: 0.0000, Val loss: 0.0097
Epoch 363 - Train loss: 0.0000, Val loss: 0.0096
Epoch 364 - Train loss: 0.0000, Val loss: 0.0095
Epoch 365 - Train loss: 0.0000, Val loss: 0.0094
Epoch 366 - Train loss: 0.0000, Val loss: 0.0093
Epoch 367 - Train loss: 0.0000, Val loss: 0.0091
Epoch 368 - Train loss: 0.0000, Val loss: 0.0090
Epoch 369 - Train loss: 0.0000, Val loss: 0.0089
Epoch 370 - Train loss: 0.0000, Val loss: 0.0088
Epoch 371 - Train loss: 0.0000, Val loss: 0.0087
Epoch 372 - Train loss: 0.0000, Val loss: 0.0086
Epoch 373 - Train loss: 0.0000, Val loss: 0.0085
Epoch 374 - Train loss: 0.0000, Val loss: 0.0084
Epoch 375 - Train lo

 57%|█████▋    | 568/1000 [00:00<00:00, 637.18it/s]

Epoch 487 - Train loss: 0.0000, Val loss: 0.0027
Epoch 488 - Train loss: 0.0000, Val loss: 0.0027
Epoch 489 - Train loss: 0.0000, Val loss: 0.0027
Epoch 490 - Train loss: 0.0000, Val loss: 0.0027
Epoch 491 - Train loss: 0.0000, Val loss: 0.0026
Epoch 492 - Train loss: 0.0000, Val loss: 0.0026
Epoch 493 - Train loss: 0.0000, Val loss: 0.0026
Epoch 494 - Train loss: 0.0000, Val loss: 0.0026
Epoch 495 - Train loss: 0.0000, Val loss: 0.0026
Epoch 496 - Train loss: 0.0000, Val loss: 0.0025
Epoch 497 - Train loss: 0.0000, Val loss: 0.0025
Epoch 498 - Train loss: 0.0000, Val loss: 0.0025
Epoch 499 - Train loss: 0.0000, Val loss: 0.0025
Epoch 500 - Train loss: 0.0000, Val loss: 0.0025
Epoch 501 - Train loss: 0.0000, Val loss: 0.0024
Epoch 502 - Train loss: 0.0000, Val loss: 0.0024
Epoch 503 - Train loss: 0.0000, Val loss: 0.0024
Epoch 504 - Train loss: 0.0000, Val loss: 0.0024
Epoch 505 - Train loss: 0.0000, Val loss: 0.0024
Epoch 506 - Train loss: 0.0000, Val loss: 0.0024
Epoch 507 - Train lo

 70%|███████   | 700/1000 [00:01<00:00, 645.95it/s]

Epoch 621 - Train loss: 0.0000, Val loss: 0.0015
Epoch 622 - Train loss: 0.0000, Val loss: 0.0015
Epoch 623 - Train loss: 0.0000, Val loss: 0.0015
Epoch 624 - Train loss: 0.0000, Val loss: 0.0015
Epoch 625 - Train loss: 0.0000, Val loss: 0.0015
Epoch 626 - Train loss: 0.0000, Val loss: 0.0015
Epoch 627 - Train loss: 0.0000, Val loss: 0.0015
Epoch 628 - Train loss: 0.0000, Val loss: 0.0015
Epoch 629 - Train loss: 0.0000, Val loss: 0.0014
Epoch 630 - Train loss: 0.0000, Val loss: 0.0014
Epoch 631 - Train loss: 0.0000, Val loss: 0.0014
Epoch 632 - Train loss: 0.0000, Val loss: 0.0014
Epoch 633 - Train loss: 0.0000, Val loss: 0.0014
Epoch 634 - Train loss: 0.0000, Val loss: 0.0014
Epoch 635 - Train loss: 0.0000, Val loss: 0.0014
Epoch 636 - Train loss: 0.0000, Val loss: 0.0014
Epoch 637 - Train loss: 0.0000, Val loss: 0.0014
Epoch 638 - Train loss: 0.0000, Val loss: 0.0014
Epoch 639 - Train loss: 0.0000, Val loss: 0.0014
Epoch 640 - Train loss: 0.0000, Val loss: 0.0014
Epoch 641 - Train lo

 83%|████████▎ | 833/1000 [00:01<00:00, 652.11it/s]

Epoch 753 - Train loss: 0.0000, Val loss: 0.0013
Epoch 754 - Train loss: 0.0000, Val loss: 0.0013
Epoch 755 - Train loss: 0.0000, Val loss: 0.0013
Epoch 756 - Train loss: 0.0000, Val loss: 0.0013
Epoch 757 - Train loss: 0.0000, Val loss: 0.0013
Epoch 758 - Train loss: 0.0000, Val loss: 0.0013
Epoch 759 - Train loss: 0.0000, Val loss: 0.0013
Epoch 760 - Train loss: 0.0000, Val loss: 0.0013
Epoch 761 - Train loss: 0.0000, Val loss: 0.0013
Epoch 762 - Train loss: 0.0000, Val loss: 0.0013
Epoch 763 - Train loss: 0.0000, Val loss: 0.0013
Epoch 764 - Train loss: 0.0000, Val loss: 0.0013
Epoch 765 - Train loss: 0.0000, Val loss: 0.0013
Epoch 766 - Train loss: 0.0000, Val loss: 0.0013
Epoch 767 - Train loss: 0.0000, Val loss: 0.0013
Epoch 768 - Train loss: 0.0000, Val loss: 0.0013
Epoch 769 - Train loss: 0.0000, Val loss: 0.0013
Epoch 770 - Train loss: 0.0000, Val loss: 0.0013
Epoch 771 - Train loss: 0.0000, Val loss: 0.0013
Epoch 772 - Train loss: 0.0000, Val loss: 0.0013
Epoch 773 - Train lo

100%|██████████| 1000/1000 [00:01<00:00, 628.90it/s]

Epoch 885 - Train loss: 0.0000, Val loss: 0.0013
Epoch 886 - Train loss: 0.0000, Val loss: 0.0013
Epoch 887 - Train loss: 0.0000, Val loss: 0.0013
Epoch 888 - Train loss: 0.0000, Val loss: 0.0013
Epoch 889 - Train loss: 0.0000, Val loss: 0.0013
Epoch 890 - Train loss: 0.0000, Val loss: 0.0013
Epoch 891 - Train loss: 0.0000, Val loss: 0.0013
Epoch 892 - Train loss: 0.0000, Val loss: 0.0013
Epoch 893 - Train loss: 0.0000, Val loss: 0.0013
Epoch 894 - Train loss: 0.0000, Val loss: 0.0013
Epoch 895 - Train loss: 0.0000, Val loss: 0.0013
Epoch 896 - Train loss: 0.0000, Val loss: 0.0013
Epoch 897 - Train loss: 0.0000, Val loss: 0.0013
Epoch 898 - Train loss: 0.0000, Val loss: 0.0013
Epoch 899 - Train loss: 0.0000, Val loss: 0.0013
Epoch 900 - Train loss: 0.0000, Val loss: 0.0013
Epoch 901 - Train loss: 0.0000, Val loss: 0.0013
Epoch 902 - Train loss: 0.0000, Val loss: 0.0013
Epoch 903 - Train loss: 0.0000, Val loss: 0.0013
Epoch 904 - Train loss: 0.0000, Val loss: 0.0013
Epoch 905 - Train lo




### Overfit for label prediction

In [11]:
# --- Initialize tensorboard ---
from datetime import datetime
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/rgcn-test/overfit-labelpred/{OVERFIT_BATCH_SIZE}-{now.strftime("%Y%m%d-%H%M%S")}')

# Reset the model
# num_bases = 6 # num of bases # TODO: experiment with this
h_channels_list = [10, 5] # list of hidden layer sizes

activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

model = RGCNNet(
    in_channels=dim,
    h_channels_list=h_channels_list,
    out_channels=num_classes,
    num_relations=num_relations,
    # num_bases=num_bases,
    aggr="mean",
    dp_rate=0.1,
    bias=True,
    activation=activation
)

train_loader = DataLoader(train_batch, batch_size=OVERFIT_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_batch, batch_size=OVERFIT_BATCH_SIZE, shuffle=False)

# --- Setup training loop ---
from tqdm import tqdm

epochs = 500 # num of epochs
best_loss = float('inf') # initialize best loss
best_epoch = 0 # initialize best epoch

# --- Initialize optimizer ---
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

lr = 0.01 # learning rate
weight_decay = 5e-4 # weight decay
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# --- Initialize loss function ---
loss_fn = CrossEntropyLoss()

# --- Initialize scheduler ---
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=10, min_lr=0.00001)

for epoch in (pbar:=tqdm(range(epochs))):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= len(train_loader)
    writer.add_scalar('Loss/train', epoch_loss, epoch)
    
    train_loss = epoch_loss
    
    # Scheduler listens to the train loss because we want to overfit here
    scheduler.step(epoch_loss)
    
    model.eval()
    epoch_loss = 0
    for batch in val_loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.y)
        epoch_loss += loss.detach().item()
    epoch_loss /= len(val_loader)
    writer.add_scalar('Loss/val', epoch_loss, epoch)
    
    val_loss = epoch_loss
    
    print(f"Epoch {epoch} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch
        torch.save(model.state_dict(), 'models/model.pth')
    
writer.flush()
writer.close()

 12%|█▏        | 59/500 [00:00<00:01, 293.92it/s]

Epoch 0 - Train loss: 1.8595, Val loss: 1.7822
Epoch 1 - Train loss: 1.8226, Val loss: 1.7457
Epoch 2 - Train loss: 1.6045, Val loss: 1.7179
Epoch 3 - Train loss: 1.6339, Val loss: 1.6939
Epoch 4 - Train loss: 1.4299, Val loss: 1.6765
Epoch 5 - Train loss: 1.4436, Val loss: 1.6638
Epoch 6 - Train loss: 1.4082, Val loss: 1.6533
Epoch 7 - Train loss: 1.3843, Val loss: 1.6466
Epoch 8 - Train loss: 1.2170, Val loss: 1.6422
Epoch 9 - Train loss: 1.2106, Val loss: 1.6393
Epoch 10 - Train loss: 1.1649, Val loss: 1.6376
Epoch 11 - Train loss: 1.0636, Val loss: 1.6373
Epoch 12 - Train loss: 1.0628, Val loss: 1.6396
Epoch 13 - Train loss: 0.9081, Val loss: 1.6425
Epoch 14 - Train loss: 1.0018, Val loss: 1.6464
Epoch 15 - Train loss: 0.8516, Val loss: 1.6512
Epoch 16 - Train loss: 0.8860, Val loss: 1.6594
Epoch 17 - Train loss: 0.8681, Val loss: 1.6706
Epoch 18 - Train loss: 0.7134, Val loss: 1.6839
Epoch 19 - Train loss: 0.7898, Val loss: 1.7000
Epoch 20 - Train loss: 0.7716, Val loss: 1.7173
Ep

 24%|██▍       | 121/500 [00:00<00:01, 301.76it/s]

Epoch 59 - Train loss: 0.2017, Val loss: 2.7131
Epoch 60 - Train loss: 0.2779, Val loss: 2.7039
Epoch 61 - Train loss: 0.2599, Val loss: 2.6977
Epoch 62 - Train loss: 0.2055, Val loss: 2.6951
Epoch 63 - Train loss: 0.1968, Val loss: 2.6928
Epoch 64 - Train loss: 0.2487, Val loss: 2.6934
Epoch 65 - Train loss: 0.2897, Val loss: 2.6976
Epoch 66 - Train loss: 0.2815, Val loss: 2.7059
Epoch 67 - Train loss: 0.2140, Val loss: 2.7199
Epoch 68 - Train loss: 0.3272, Val loss: 2.7430
Epoch 69 - Train loss: 0.2325, Val loss: 2.7656
Epoch 70 - Train loss: 0.2717, Val loss: 2.7953
Epoch 71 - Train loss: 0.2557, Val loss: 2.8284
Epoch 72 - Train loss: 0.2722, Val loss: 2.8676
Epoch 73 - Train loss: 0.2365, Val loss: 2.8962
Epoch 74 - Train loss: 0.1852, Val loss: 2.9230
Epoch 75 - Train loss: 0.1993, Val loss: 2.9495
Epoch 76 - Train loss: 0.1568, Val loss: 2.9749
Epoch 77 - Train loss: 0.2350, Val loss: 2.9975
Epoch 78 - Train loss: 0.2439, Val loss: 3.0183
Epoch 79 - Train loss: 0.2457, Val loss:

 31%|███       | 153/500 [00:00<00:01, 306.62it/s]

Epoch 121 - Train loss: 0.0843, Val loss: 3.6959
Epoch 122 - Train loss: 0.2108, Val loss: 3.7105
Epoch 123 - Train loss: 0.2047, Val loss: 3.7347
Epoch 124 - Train loss: 0.1378, Val loss: 3.7514
Epoch 125 - Train loss: 0.1188, Val loss: 3.7672
Epoch 126 - Train loss: 0.1715, Val loss: 3.7576
Epoch 127 - Train loss: 0.1407, Val loss: 3.7529
Epoch 128 - Train loss: 0.0878, Val loss: 3.7613
Epoch 129 - Train loss: 0.0872, Val loss: 3.7806
Epoch 130 - Train loss: 0.0616, Val loss: 3.7990
Epoch 131 - Train loss: 0.0576, Val loss: 3.8169
Epoch 132 - Train loss: 0.0749, Val loss: 3.8275
Epoch 133 - Train loss: 0.0920, Val loss: 3.8431
Epoch 134 - Train loss: 0.1857, Val loss: 3.8585
Epoch 135 - Train loss: 0.0494, Val loss: 3.8733
Epoch 136 - Train loss: 0.0671, Val loss: 3.8873
Epoch 137 - Train loss: 0.0813, Val loss: 3.9032
Epoch 138 - Train loss: 0.1729, Val loss: 3.9157
Epoch 139 - Train loss: 0.0366, Val loss: 3.9271
Epoch 140 - Train loss: 0.0382, Val loss: 3.9372
Epoch 141 - Train lo

 43%|████▎     | 214/500 [00:00<00:01, 282.16it/s]

Epoch 183 - Train loss: 0.1138, Val loss: 4.5079
Epoch 184 - Train loss: 0.0305, Val loss: 4.4999
Epoch 185 - Train loss: 0.1706, Val loss: 4.4926
Epoch 186 - Train loss: 0.0297, Val loss: 4.4861
Epoch 187 - Train loss: 0.0285, Val loss: 4.4807
Epoch 188 - Train loss: 0.0249, Val loss: 4.4766
Epoch 189 - Train loss: 0.2080, Val loss: 4.4725
Epoch 190 - Train loss: 0.1805, Val loss: 4.4675
Epoch 191 - Train loss: 0.0237, Val loss: 4.4631
Epoch 192 - Train loss: 0.0689, Val loss: 4.4596
Epoch 193 - Train loss: 0.1105, Val loss: 4.4553
Epoch 194 - Train loss: 0.0580, Val loss: 4.4525
Epoch 195 - Train loss: 0.0986, Val loss: 4.4540
Epoch 196 - Train loss: 0.0354, Val loss: 4.4576
Epoch 197 - Train loss: 0.0469, Val loss: 4.4618
Epoch 198 - Train loss: 0.0865, Val loss: 4.4719
Epoch 199 - Train loss: 0.0235, Val loss: 4.4818
Epoch 200 - Train loss: 0.0310, Val loss: 4.4922
Epoch 201 - Train loss: 0.0363, Val loss: 4.5017
Epoch 202 - Train loss: 0.0336, Val loss: 4.5119
Epoch 203 - Train lo

 56%|█████▌    | 278/500 [00:00<00:00, 298.40it/s]

Epoch 240 - Train loss: 0.0293, Val loss: 4.6454
Epoch 241 - Train loss: 0.0254, Val loss: 4.6481
Epoch 242 - Train loss: 0.0408, Val loss: 4.6501
Epoch 243 - Train loss: 0.0245, Val loss: 4.6518
Epoch 244 - Train loss: 0.0251, Val loss: 4.6533
Epoch 245 - Train loss: 0.0238, Val loss: 4.6544
Epoch 246 - Train loss: 0.0139, Val loss: 4.6554
Epoch 247 - Train loss: 0.0263, Val loss: 4.6575
Epoch 248 - Train loss: 0.0135, Val loss: 4.6594
Epoch 249 - Train loss: 0.0332, Val loss: 4.6624
Epoch 250 - Train loss: 0.0189, Val loss: 4.6659
Epoch 251 - Train loss: 0.1708, Val loss: 4.6682
Epoch 252 - Train loss: 0.0168, Val loss: 4.6703
Epoch 253 - Train loss: 0.0187, Val loss: 4.6725
Epoch 254 - Train loss: 0.0485, Val loss: 4.6748
Epoch 255 - Train loss: 0.1625, Val loss: 4.6767
Epoch 256 - Train loss: 0.0223, Val loss: 4.6785
Epoch 257 - Train loss: 0.0287, Val loss: 4.6800
Epoch 258 - Train loss: 0.0153, Val loss: 4.6812
Epoch 259 - Train loss: 0.0238, Val loss: 4.6824
Epoch 260 - Train lo

 68%|██████▊   | 341/500 [00:01<00:00, 302.97it/s]

Epoch 304 - Train loss: 0.0244, Val loss: 4.6713
Epoch 305 - Train loss: 0.0330, Val loss: 4.6708
Epoch 306 - Train loss: 0.0351, Val loss: 4.6703
Epoch 307 - Train loss: 0.0115, Val loss: 4.6701
Epoch 308 - Train loss: 0.0406, Val loss: 4.6700
Epoch 309 - Train loss: 0.0585, Val loss: 4.6702
Epoch 310 - Train loss: 0.1620, Val loss: 4.6704
Epoch 311 - Train loss: 0.0122, Val loss: 4.6705
Epoch 312 - Train loss: 0.0090, Val loss: 4.6707
Epoch 313 - Train loss: 0.0138, Val loss: 4.6709
Epoch 314 - Train loss: 0.0342, Val loss: 4.6712
Epoch 315 - Train loss: 0.0518, Val loss: 4.6718
Epoch 316 - Train loss: 0.0588, Val loss: 4.6724
Epoch 317 - Train loss: 0.0402, Val loss: 4.6730
Epoch 318 - Train loss: 0.0106, Val loss: 4.6735
Epoch 319 - Train loss: 0.0288, Val loss: 4.6740
Epoch 320 - Train loss: 0.1639, Val loss: 4.6745
Epoch 321 - Train loss: 0.0367, Val loss: 4.6747
Epoch 322 - Train loss: 0.0285, Val loss: 4.6748
Epoch 323 - Train loss: 0.0209, Val loss: 4.6750
Epoch 324 - Train lo

 81%|████████  | 405/500 [00:01<00:00, 308.82it/s]

Epoch 367 - Train loss: 0.0509, Val loss: 4.6907
Epoch 368 - Train loss: 0.0558, Val loss: 4.6910
Epoch 369 - Train loss: 0.0287, Val loss: 4.6914
Epoch 370 - Train loss: 0.1559, Val loss: 4.6917
Epoch 371 - Train loss: 0.0613, Val loss: 4.6921
Epoch 372 - Train loss: 0.0460, Val loss: 4.6926
Epoch 373 - Train loss: 0.2206, Val loss: 4.6931
Epoch 374 - Train loss: 0.0282, Val loss: 4.6935
Epoch 375 - Train loss: 0.0110, Val loss: 4.6940
Epoch 376 - Train loss: 0.0813, Val loss: 4.6945
Epoch 377 - Train loss: 0.0100, Val loss: 4.6948
Epoch 378 - Train loss: 0.0134, Val loss: 4.6951
Epoch 379 - Train loss: 0.0512, Val loss: 4.6954
Epoch 380 - Train loss: 0.0160, Val loss: 4.6956
Epoch 381 - Train loss: 0.0189, Val loss: 4.6959
Epoch 382 - Train loss: 0.0428, Val loss: 4.6961
Epoch 383 - Train loss: 0.0281, Val loss: 4.6963
Epoch 384 - Train loss: 0.0381, Val loss: 4.6965
Epoch 385 - Train loss: 0.0288, Val loss: 4.6967
Epoch 386 - Train loss: 0.0104, Val loss: 4.6969
Epoch 387 - Train lo

 94%|█████████▍| 469/500 [00:01<00:00, 311.70it/s]

Epoch 431 - Train loss: 0.0508, Val loss: 4.6983
Epoch 432 - Train loss: 0.0092, Val loss: 4.6983
Epoch 433 - Train loss: 0.0124, Val loss: 4.6983
Epoch 434 - Train loss: 0.1279, Val loss: 4.6983
Epoch 435 - Train loss: 0.0305, Val loss: 4.6983
Epoch 436 - Train loss: 0.1483, Val loss: 4.6982
Epoch 437 - Train loss: 0.0263, Val loss: 4.6982
Epoch 438 - Train loss: 0.0315, Val loss: 4.6982
Epoch 439 - Train loss: 0.0593, Val loss: 4.6983
Epoch 440 - Train loss: 0.0398, Val loss: 4.6983
Epoch 441 - Train loss: 0.0505, Val loss: 4.6983
Epoch 442 - Train loss: 0.0171, Val loss: 4.6983
Epoch 443 - Train loss: 0.0247, Val loss: 4.6983
Epoch 444 - Train loss: 0.0242, Val loss: 4.6984
Epoch 445 - Train loss: 0.0186, Val loss: 4.6984
Epoch 446 - Train loss: 0.0305, Val loss: 4.6984
Epoch 447 - Train loss: 0.1027, Val loss: 4.6984
Epoch 448 - Train loss: 0.0287, Val loss: 4.6984
Epoch 449 - Train loss: 0.0236, Val loss: 4.6985
Epoch 450 - Train loss: 0.0158, Val loss: 4.6985
Epoch 451 - Train lo

100%|██████████| 500/500 [00:01<00:00, 302.94it/s]

Epoch 495 - Train loss: 0.0767, Val loss: 4.6987
Epoch 496 - Train loss: 0.0205, Val loss: 4.6987
Epoch 497 - Train loss: 0.0225, Val loss: 4.6988
Epoch 498 - Train loss: 0.0193, Val loss: 4.6988
Epoch 499 - Train loss: 0.0113, Val loss: 4.6988



