In [12]:
import torch
from torch.utils.tensorboard import SummaryWriter

Definition of a graph `Data(x=[20, 5], edge_index=[2, 20], edge_attr=[20], y=[20])` (A graph with 20 nodes, 20 edges, 6 different edge types and 1 target label per node) is decomposed into

- x: Data stored in nodes (here, num_nodes=20 nodes x dim=5 features):

```
tensor([[ 0.1338,  0.8443, -1.8535,  1.1781, -1.4325],
        [ 0.6020,  0.9776,  0.3087, -0.5191, -0.8351],
        [-0.1933,  0.7537,  0.5730, -1.6758,  0.7305],
        [-0.0866,  0.0894,  1.5008, -0.7608,  0.7714],
        [-0.1519,  1.0402, -0.8959,  0.1330, -0.1550],
        [ 0.0830,  0.4799, -0.5175,  2.2978, -0.9849],
        [ 0.2736,  1.2479, -0.0293, -0.2829,  1.9557],
        [-0.9310,  0.2792, -0.5623, -0.4188, -0.5580],
        [-0.0870,  0.6735,  0.7130, -0.7389,  1.3117],
        [-0.5930, -1.0240, -0.6158,  0.1470, -0.7293],
        [ 0.4986, -1.7924,  0.9244,  1.1364,  1.2943],
        [ 0.5379,  0.4017,  0.3074, -0.8189,  0.5053],
        [ 2.5506,  0.5360, -0.0889, -0.2462, -0.9926],
        [ 0.5416, -0.1752, -0.3871, -0.2850,  0.3497],
        [-0.3282,  1.2847, -0.3605, -0.6254,  0.5127],
        [ 0.1769, -2.4991,  1.4926, -0.4233, -0.9469],
        [ 1.2659, -1.5897, -0.6909, -0.0036,  1.1759],
        [-0.4856,  0.3826, -0.6284, -0.4664,  0.0383],
        [ 0.2546,  1.5097,  1.2746, -0.5847,  1.1207],
        [ 0.9572,  0.4619, -0.4879, -0.4740, -0.1182]])
```

- edge_index: Encodes connections between nodes. It's always (2,num_edges) shape, the first list is source ids and the target ids for connections (here, num_edges=20):

```
tensor([[ 8,  5, 14,  5,  4, 14, 15,  4, 17, 13, 12, 15, 13, 11, 10, 12, 19, 0, 19, 16],
        [ 5,  8, 15,  0, 13,  0,  2, 17,  9, 16,  3, 12, 18,  4,  5, 18, 12, 8, 16, 10]])
```

- edge_attr: Encodes types of each edge (since we deal with multiple different edge types in RGCN). The shape is therefore (num_edges, num_edge_features=1) here with values [0, num_relations-1].

```
tensor([3, 2, 0, 3, 4, 5, 2, 1, 5, 3, 5, 0, 0, 4, 2, 2, 4, 1, 1, 5])
```

- y: target labels with (n_labels=num_nodes,label_dim=1) shape with n_labels=num_nodes since we do node-level labels (graph-level single prediction is also possible) and label_dim=1 since in this toy example we want to have classification (one value per node indicating target class).

```
tensor([1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0])
```

In [13]:
# Create a mock graph dataset to test how graph data works
from torch_geometric.data import Data

num_nodes = 20 # num of nodes
dim = 5 # dim of node features
num_edges = 22 # num of edges
num_classes = 2 # num of target classes
num_relations = 6 # num of edge types

def generate_random_graph():
    # --- Initialize nodes ---
    x = torch.randn(num_nodes, dim) # creates num_nodes x dim tensor of (random) node features <- TODO: replace with actual node features

    # --- Initialize edges --- 
    edge_index = torch.randint(num_nodes, (2, num_edges)) # creates 2 x num_edges tensor of (random) edges <- TODO: replace with actual edges

    # --- Introduce different types of edges ---
    edge_attr = torch.randint(num_relations, (num_edges,)) # creates num_edges x 1 tensor of (random) edge types <- TODO: replace with actual edge types

    # --- Initialize labels ---
    y = torch.randint(num_classes, (num_nodes,)) # creates num_nodes x 1 tensor of (random) labels <- TODO: replace with actual labels

    # --- Create a graph ---
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    return graph

graph = generate_random_graph()
print(graph)

Data(x=[20, 5], edge_index=[2, 22], edge_attr=[22], y=[20])


In [14]:
import torch_geometric.transforms as T
from torch_geometric.nn import RGCNConv
from torch import nn

class RGCNNet(torch.nn.Module):
    def __init__(
        self, 
        in_channels,
        h_channels_list,
        out_channels, 
        num_relations,
        num_bases=None, 
        aggr='mean',
        dp_rate=0.1, 
        bias=True
    ):
        super(RGCNNet, self).__init__()
        self.num_layers = len(h_channels_list) + 1
        self.layers = []
        
        for i in range(self.num_layers - 1):
            in_channels = in_channels if i == 0 else h_channels_list[i - 1]
            out_channels = h_channels_list[i]
            self.layers += [
                RGCNConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    num_relations=num_relations,
                    num_bases=num_bases,
                    aggr=aggr,
                    bias=bias
                ),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dp_rate)
            ]
        self.layers += [
            RGCNConv(
                in_channels=h_channels_list[-1],
                out_channels=out_channels,
                num_relations=num_relations,
                num_bases=num_bases,
                aggr=aggr,
                bias=bias
            ),
            nn.ReLU(inplace=True) # TODO: other final activation function here?
        ]
        
        self.layers = torch.nn.ModuleList(self.layers)
        
    def forward(self, x, edge_index, edge_type):
        for layer in self.layers:
            if isinstance(layer, RGCNConv):
                x = layer(x, edge_index, edge_type)
            else:
                x = layer(x)
        return x


In [15]:
from torch.nn import CrossEntropyLoss

# --- Initialize model ---
# num_bases = 6 # num of bases # TODO: experiment with this
h_channels_list = [10, 5] # list of hidden layer sizes

model = RGCNNet(
    in_channels=dim,
    h_channels_list=h_channels_list,
    out_channels=num_classes,
    num_relations=num_relations,
    # num_bases=num_bases,
    aggr="mean",
    dp_rate=0.1,
    bias=True
)

In [16]:
# --- Initialize optimizer ---
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

lr = 0.01 # learning rate
weight_decay = 5e-4 # weight decay
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# --- Initialize loss function ---
loss_fn = CrossEntropyLoss()

# --- Initialize scheduler ---
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001)

In [17]:
# --- Initialize train and val loaders ---
from torch_geometric.loader import DataLoader

OVERFIT_BATCH_SIZE = 20
train_batch = [generate_random_graph() for _ in range(OVERFIT_BATCH_SIZE)]
val_batch = [generate_random_graph() for _ in range(OVERFIT_BATCH_SIZE)]

train_loader = DataLoader(train_batch, batch_size=4, shuffle=True)
val_loader = DataLoader(val_batch, batch_size=4, shuffle=False)

In [18]:
# --- Initialize tensorboard ---
writer = SummaryWriter(log_dir=f'runs/rgcn-overfit-{OVERFIT_BATCH_SIZE}')

# --- Setup training loop ---
from tqdm import tqdm

epochs = 200 # num of epochs
best_loss = float('inf') # initialize best loss
best_epoch = 0 # initialize best epoch

for epoch in (pbar:=tqdm(range(epochs))):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= len(train_loader)
    writer.add_scalar('Loss/train', epoch_loss, epoch)
    
    train_loss = epoch_loss
    
    model.eval()
    epoch_loss = 0
    for batch in val_loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.y)
        epoch_loss += loss.detach().item()
    epoch_loss /= len(val_loader)
    writer.add_scalar('Loss/val', epoch_loss, epoch)
    
    val_loss = epoch_loss
    
    print(f"Epoch {epoch} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    # pbar.set_description(f"Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    scheduler.step(epoch_loss)
    
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch
        torch.save(model.state_dict(), 'models/model.pth')
    
writer.flush()
writer.close()

  0%|          | 0/200 [00:00<?, ?it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 0 - Train loss: 1.6910, Val loss: 1.5972
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of bat

  2%|▎         | 5/200 [00:00<00:03, 49.42it/s]

Epoch 3 - Train loss: 1.3861, Val loss: 1.4258
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 4 - Train loss: 1.3297, Val loss: 1.3685
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batc

  6%|▌         | 12/200 [00:00<00:03, 58.55it/s]

Epoch 11 - Train loss: 0.7759, Val loss: 0.9423
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 12 - Train loss: 0.7164, Val loss: 0.9023
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 10%|▉         | 19/200 [00:00<00:02, 60.64it/s]

Epoch 16 - Train loss: 0.5884, Val loss: 0.8565
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 17 - Train loss: 0.6127, Val loss: 0.8658
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 13%|█▎        | 26/200 [00:00<00:03, 57.82it/s]

Epoch 23 - Train loss: 0.5588, Val loss: 0.8768
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 24 - Train loss: 0.5371, Val loss: 0.8819
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 16%|█▌        | 32/200 [00:00<00:02, 57.96it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 28 - Train loss: 0.5124, Val loss: 0.9336
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 29 - Train

 19%|█▉        | 38/200 [00:00<00:02, 56.92it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 35 - Train loss: 0.4963, Val loss: 0.9580
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 36 - Train loss: 0.5006, Val loss: 0.9708
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 22%|██▏       | 44/200 [00:00<00:02, 57.51it/s]

Epoch 39 - Train loss: 0.4670, Val loss: 1.0110
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 40 - Train loss: 0.4768, Val loss: 1.0279
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 25%|██▌       | 50/200 [00:00<00:02, 55.17it/s]

Epoch 46 - Train loss: 0.4904, Val loss: 1.0440
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 47 - Train loss: 0.4526, Val loss: 1.0398
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 28%|██▊       | 56/200 [00:01<00:02, 52.51it/s]

Epoch 50 - Train loss: 0.4511, Val loss: 1.0492
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 51 - Train loss: 0.4675, Val loss: 1.0521
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 31%|███       | 62/200 [00:01<00:02, 52.90it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 61 - Train loss: 0.4465, Val loss: 1.0604
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 62 - Train

 34%|███▍      | 69/200 [00:01<00:02, 55.02it/s]

Epoch 68 - Train loss: 0.4538, Val loss: 1.0659
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 69 - Train loss: 0.4481, Val loss: 1.0665
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 38%|███▊      | 75/200 [00:01<00:02, 55.11it/s]

Epoch 72 - Train loss: 0.4466, Val loss: 1.0672
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 73 - Train loss: 0.4368, Val loss: 1.0670
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 40%|████      | 81/200 [00:01<00:02, 54.93it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 80 - Train loss: 0.4996, Val loss: 1.0714
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of ba

 44%|████▎     | 87/200 [00:01<00:02, 55.95it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 84 - Train loss: 0.4621, Val loss: 1.0719
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 85 - Train loss: 0.4616, Val loss: 1.0721
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 46%|████▋     | 93/200 [00:01<00:01, 55.91it/s]

Epoch 91 - Train loss: 0.4431, Val loss: 1.0725
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 92 - Train loss: 0.4338, Val loss: 1.0728
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 50%|████▉     | 99/200 [00:01<00:01, 56.65it/s]

Epoch 95 - Train loss: 0.4514, Val loss: 1.0735
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 96 - Train loss: 0.4526, Val loss: 1.0737
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), ba

 53%|█████▎    | 106/200 [00:01<00:01, 57.67it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 104 - Train loss: 0.4143, Val loss: 1.0746
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of b

 56%|█████▌    | 112/200 [00:01<00:01, 58.22it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 108 - Train loss: 0.4477, Val loss: 1.0747
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of b

 60%|█████▉    | 119/200 [00:02<00:01, 59.10it/s]

Epoch 115 - Train loss: 0.4897, Val loss: 1.0752
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 116 - Train loss: 0.4354, Val loss: 1.0754
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 

 62%|██████▎   | 125/200 [00:02<00:01, 59.29it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 120 - Train loss: 0.4339, Val loss: 1.0760
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of b

 66%|██████▌   | 132/200 [00:02<00:01, 59.93it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 128 - Train loss: 0.4624, Val loss: 1.0764
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 129 - Tra

 70%|██████▉   | 139/200 [00:02<00:01, 60.34it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 132 - Train loss: 0.4576, Val loss: 1.0764
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 133 - Train loss: 0.4301, Val loss: 1.0764
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 

 76%|███████▋  | 153/200 [00:02<00:00, 61.58it/s]

Epoch 144 - Train loss: 0.4372, Val loss: 1.0765
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 145 - Train loss: 0.4518, Val loss: 1.0765
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 

 80%|████████  | 160/200 [00:02<00:00, 61.35it/s]

Epoch 157 - Train loss: 0.4581, Val loss: 1.0768
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 158 - Train loss: 0.4403, Val loss: 1.0768
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 

 84%|████████▎ | 167/200 [00:02<00:00, 60.95it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 165 - Train loss: 0.4659, Val loss: 1.0768
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 166 - Train loss: 0.4642, Val loss: 1.0768
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 

 87%|████████▋ | 174/200 [00:03<00:00, 60.46it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 170 - Train loss: 0.4224, Val loss: 1.0769
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of b

 90%|█████████ | 181/200 [00:03<00:00, 60.23it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 177 - Train loss: 0.4474, Val loss: 1.0770
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 178 - Train loss: 0.4790, Val loss: 1.0770
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 

 94%|█████████▍| 188/200 [00:03<00:00, 60.90it/s]

Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 182 - Train loss: 0.4757, Val loss: 1.0770
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of b

100%|██████████| 200/200 [00:03<00:00, 58.33it/s]

Epoch 194 - Train loss: 0.4676, Val loss: 1.0772
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Epoch 195 - Train loss: 0.4198, Val loss: 1.0773
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), batch.edge_attr: torch.Size([88])
Dimensions of batch.x: torch.Size([80, 5]), batch.edge_index: torch.Size([2, 88]), 


