In [62]:
import torch
from torch.utils.tensorboard import SummaryWriter

Definition of a graph `Data(x=[20, 5], edge_index=[2, 20], edge_attr=[20], y=[20])` (A graph with 20 nodes, 20 edges, 6 different edge types and 1 target label per node) is decomposed into

- x: Data stored in nodes (here, num_nodes=20 nodes x dim=5 features):

```
tensor([[ 0.1338,  0.8443, -1.8535,  1.1781, -1.4325],
        [ 0.6020,  0.9776,  0.3087, -0.5191, -0.8351],
        [-0.1933,  0.7537,  0.5730, -1.6758,  0.7305],
        [-0.0866,  0.0894,  1.5008, -0.7608,  0.7714],
        [-0.1519,  1.0402, -0.8959,  0.1330, -0.1550],
        [ 0.0830,  0.4799, -0.5175,  2.2978, -0.9849],
        [ 0.2736,  1.2479, -0.0293, -0.2829,  1.9557],
        [-0.9310,  0.2792, -0.5623, -0.4188, -0.5580],
        [-0.0870,  0.6735,  0.7130, -0.7389,  1.3117],
        [-0.5930, -1.0240, -0.6158,  0.1470, -0.7293],
        [ 0.4986, -1.7924,  0.9244,  1.1364,  1.2943],
        [ 0.5379,  0.4017,  0.3074, -0.8189,  0.5053],
        [ 2.5506,  0.5360, -0.0889, -0.2462, -0.9926],
        [ 0.5416, -0.1752, -0.3871, -0.2850,  0.3497],
        [-0.3282,  1.2847, -0.3605, -0.6254,  0.5127],
        [ 0.1769, -2.4991,  1.4926, -0.4233, -0.9469],
        [ 1.2659, -1.5897, -0.6909, -0.0036,  1.1759],
        [-0.4856,  0.3826, -0.6284, -0.4664,  0.0383],
        [ 0.2546,  1.5097,  1.2746, -0.5847,  1.1207],
        [ 0.9572,  0.4619, -0.4879, -0.4740, -0.1182]])
```

- edge_index: Encodes connections between nodes. It's always (2,num_edges) shape, the first list is source ids and the target ids for connections (here, num_edges=20):

```
tensor([[ 8,  5, 14,  5,  4, 14, 15,  4, 17, 13, 12, 15, 13, 11, 10, 12, 19, 0, 19, 16],
        [ 5,  8, 15,  0, 13,  0,  2, 17,  9, 16,  3, 12, 18,  4,  5, 18, 12, 8, 16, 10]])
```

- edge_attr: Encodes types of each edge (since we deal with multiple different edge types in RGCN). The shape is therefore (num_edges, num_edge_features=1) here with values [0, num_relations-1].

```
tensor([3, 2, 0, 3, 4, 5, 2, 1, 5, 3, 5, 0, 0, 4, 2, 2, 4, 1, 1, 5])
```

- y: target labels with (n_labels=num_nodes,label_dim=1) shape with n_labels=num_nodes since we do node-level labels (graph-level single prediction is also possible) and label_dim=1 since in this toy example we want to have classification (one value per node indicating target class).

```
tensor([1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0])
```

In [63]:
# Create a mock graph dataset to test how graph data works
from torch_geometric.data import Data

num_nodes = 20 # num of nodes
dim = 5 # dim of node features
num_edges = 25 # num of edges
num_classes = 2 # num of target classes
num_relations = 6 # num of edge types

def generate_random_graph():
    # --- Initialize nodes ---
    x = torch.randn(num_nodes, dim) # creates num_nodes x dim tensor of (random) node features <- TODO: replace with actual node features

    # --- Initialize edges --- 
    edge_index = torch.randint(num_nodes, (2, num_edges)) # creates 2 x num_edges tensor of (random) edges <- TODO: replace with actual edges

    # --- Introduce different types of edges ---
    edge_attr = torch.randint(num_relations, (num_edges,)) # creates num_edges x 1 tensor of (random) edge types <- TODO: replace with actual edge types

    # --- Initialize labels ---
    y = torch.randint(num_classes, (num_nodes,)) # creates num_nodes x 1 tensor of (random) labels <- TODO: replace with actual labels

    # --- Create a graph ---
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    return graph

graph = generate_random_graph()
print(graph)

Data(x=[20, 5], edge_index=[2, 25], edge_attr=[25], y=[20])


In [64]:
import torch_geometric.transforms as T
from torch_geometric.nn import RGCNConv
from torch import nn

class RGCNNet(torch.nn.Module):
    def __init__(
        self, 
        in_channels,
        h_channels_list,
        out_channels, 
        num_relations,
        num_bases=None, 
        aggr='mean',
        dp_rate=0.1, 
        bias=True
    ):
        super(RGCNNet, self).__init__()
        self.num_layers = len(h_channels_list) + 1
        self.layers = []
        
        for i in range(self.num_layers - 1):
            in_channels = in_channels if i == 0 else h_channels_list[i - 1]
            out_channels = h_channels_list[i]
            self.layers += [
                RGCNConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    num_relations=num_relations,
                    num_bases=num_bases,
                    aggr=aggr,
                    bias=bias
                ),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dp_rate)
            ]
        self.layers += [
            RGCNConv(
                in_channels=h_channels_list[-1],
                out_channels=out_channels,
                num_relations=num_relations,
                num_bases=num_bases,
                aggr=aggr,
                bias=bias
            ),
            nn.ReLU(inplace=True) # TODO: other final activation function here?
        ]
        
        self.layers = torch.nn.ModuleList(self.layers)
        
    def forward(self, x, edge_index, edge_type):
        for layer in self.layers:
            if isinstance(layer, RGCNConv):
                x = layer(x, edge_index, edge_type)
            else:
                x = layer(x)
        return x


In [65]:
from torch.nn import CrossEntropyLoss

# --- Initialize model ---
# num_bases = 6 # num of bases # TODO: experiment with this
h_channels_list = [10, 5] # list of hidden layer sizes

model = RGCNNet(
    in_channels=dim,
    h_channels_list=h_channels_list,
    out_channels=num_classes,
    num_relations=num_relations,
    # num_bases=num_bases,
    aggr="mean",
    dp_rate=0.1,
    bias=True
)

In [66]:
# --- Initialize optimizer ---
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

lr = 0.01 # learning rate
weight_decay = 5e-4 # weight decay
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# --- Initialize loss function ---
loss_fn = CrossEntropyLoss()

# --- Initialize scheduler ---
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001)

In [67]:
# --- Initialize train and val loaders ---
from torch_geometric.loader import DataLoader

OVERFIT_BATCH_SIZE = 20
train_batch = [generate_random_graph() for _ in range(OVERFIT_BATCH_SIZE)]
val_batch = [generate_random_graph() for _ in range(OVERFIT_BATCH_SIZE)]

train_loader = DataLoader(train_batch, batch_size=5, shuffle=True)
val_loader = DataLoader(val_batch, batch_size=5, shuffle=False)

In [68]:
# --- Initialize tensorboard ---
writer = SummaryWriter(log_dir=f'runs/rgcn-overfit-{OVERFIT_BATCH_SIZE}')

# --- Setup training loop ---
from tqdm import tqdm

epochs = 200 # num of epochs
best_loss = float('inf') # initialize best loss
best_epoch = 0 # initialize best epoch

for epoch in (pbar:=tqdm(range(epochs))):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= len(train_loader)
    writer.add_scalar('Loss/train', epoch_loss, epoch)
    
    train_loss = epoch_loss
    
    model.eval()
    epoch_loss = 0
    for batch in val_loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = loss_fn(out, batch.y)
        epoch_loss += loss.detach().item()
    epoch_loss /= len(val_loader)
    writer.add_scalar('Loss/val', epoch_loss, epoch)
    
    val_loss = epoch_loss
    
    print(f"Epoch {epoch} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    # pbar.set_description(f"Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    scheduler.step(epoch_loss)
    
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch
        torch.save(model.state_dict(), 'models/model.pth')
    
writer.flush()
writer.close()

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 0 - Train loss: 1.6954, Val loss: 1.6175
Epoch 1 - Train loss: 1.5185, Val loss: 1.4984
Epoch 2 - Train loss: 1.3944, Val loss: 1.4113
Epoch 3 - Train loss: 1.3103, Val loss: 1.3372
Epoch 4 - Train loss: 1.2327, Val loss: 1.2723
Epoch 5 - Train loss: 1.1790, Val loss: 1.2205


  4%|▍         | 8/200 [00:00<00:02, 72.71it/s]

Epoch 6 - Train loss: 1.1050, Val loss: 1.1762
Epoch 7 - Train loss: 1.0577, Val loss: 1.1343
Epoch 8 - Train loss: 0.9797, Val loss: 1.0924
Epoch 9 - Train loss: 0.9194, Val loss: 1.0509
Epoch 10 - Train loss: 0.9050, Val loss: 1.0101
Epoch 11 - Train loss: 0.8300, Val loss: 0.9787
Epoch 12 - Train loss: 0.7962, Val loss: 0.9537
Epoch 13 - Train loss: 0.7689, Val loss: 0.9327


  8%|▊         | 16/200 [00:00<00:02, 67.75it/s]

Epoch 14 - Train loss: 0.7375, Val loss: 0.9187
Epoch 15 - Train loss: 0.7063, Val loss: 0.9115
Epoch 16 - Train loss: 0.6942, Val loss: 0.8956
Epoch 17 - Train loss: 0.6755, Val loss: 0.8881
Epoch 18 - Train loss: 0.6499, Val loss: 0.8834
Epoch 19 - Train loss: 0.6448, Val loss: 0.8714


 12%|█▏        | 24/200 [00:00<00:02, 70.58it/s]

Epoch 20 - Train loss: 0.6186, Val loss: 0.8692
Epoch 21 - Train loss: 0.6363, Val loss: 0.8665
Epoch 22 - Train loss: 0.6113, Val loss: 0.8617
Epoch 23 - Train loss: 0.5908, Val loss: 0.8597
Epoch 24 - Train loss: 0.5952, Val loss: 0.8670
Epoch 25 - Train loss: 0.5655, Val loss: 0.8770
Epoch 26 - Train loss: 0.5554, Val loss: 0.8910
Epoch 27 - Train loss: 0.5566, Val loss: 0.8997
Epoch 28 - Train loss: 0.5505, Val loss: 0.9076
Epoch 29 - Train loss: 0.5494, Val loss: 0.9279


 16%|█▌        | 32/200 [00:00<00:02, 73.94it/s]

Epoch 30 - Train loss: 0.5622, Val loss: 0.9344
Epoch 31 - Train loss: 0.5363, Val loss: 0.9377
Epoch 32 - Train loss: 0.5413, Val loss: 0.9419
Epoch 33 - Train loss: 0.5079, Val loss: 0.9419
Epoch 34 - Train loss: 0.5065, Val loss: 0.9497
Epoch 35 - Train loss: 0.5232, Val loss: 0.9592


 20%|██        | 40/200 [00:00<00:02, 75.88it/s]

Epoch 36 - Train loss: 0.5141, Val loss: 0.9688
Epoch 37 - Train loss: 0.5072, Val loss: 0.9800
Epoch 38 - Train loss: 0.5133, Val loss: 0.9859
Epoch 39 - Train loss: 0.5265, Val loss: 0.9930
Epoch 40 - Train loss: 0.5251, Val loss: 0.9890
Epoch 41 - Train loss: 0.5276, Val loss: 0.9851
Epoch 42 - Train loss: 0.5457, Val loss: 0.9851
Epoch 43 - Train loss: 0.5126, Val loss: 0.9863
Epoch 44 - Train loss: 0.4964, Val loss: 0.9906
Epoch 45 - Train loss: 0.4983, Val loss: 0.9975
Epoch 46 - Train loss: 0.5036, Val loss: 1.0050


 24%|██▍       | 48/200 [00:00<00:01, 77.18it/s]

Epoch 47 - Train loss: 0.5030, Val loss: 1.0136
Epoch 48 - Train loss: 0.4767, Val loss: 1.0204
Epoch 49 - Train loss: 0.4723, Val loss: 1.0249
Epoch 50 - Train loss: 0.5320, Val loss: 1.0272
Epoch 51 - Train loss: 0.4709, Val loss: 1.0324
Epoch 52 - Train loss: 0.4799, Val loss: 1.0400


 28%|██▊       | 56/200 [00:00<00:01, 77.66it/s]

Epoch 53 - Train loss: 0.4789, Val loss: 1.0441
Epoch 54 - Train loss: 0.4792, Val loss: 1.0466
Epoch 55 - Train loss: 0.4738, Val loss: 1.0487
Epoch 56 - Train loss: 0.4913, Val loss: 1.0493
Epoch 57 - Train loss: 0.4746, Val loss: 1.0495
Epoch 58 - Train loss: 0.4557, Val loss: 1.0502
Epoch 59 - Train loss: 0.4699, Val loss: 1.0506
Epoch 60 - Train loss: 0.4674, Val loss: 1.0513
Epoch 61 - Train loss: 0.4520, Val loss: 1.0520
Epoch 62 - Train loss: 0.4804, Val loss: 1.0515
Epoch 63 - Train loss: 0.4645, Val loss: 1.0511


 32%|███▎      | 65/200 [00:00<00:01, 78.80it/s]

Epoch 64 - Train loss: 0.4943, Val loss: 1.0505
Epoch 65 - Train loss: 0.4873, Val loss: 1.0509
Epoch 66 - Train loss: 0.5139, Val loss: 1.0517
Epoch 67 - Train loss: 0.4914, Val loss: 1.0524
Epoch 68 - Train loss: 0.4979, Val loss: 1.0533
Epoch 69 - Train loss: 0.4612, Val loss: 1.0544


 36%|███▋      | 73/200 [00:00<00:01, 79.14it/s]

Epoch 70 - Train loss: 0.4645, Val loss: 1.0554
Epoch 71 - Train loss: 0.4572, Val loss: 1.0574
Epoch 72 - Train loss: 0.4548, Val loss: 1.0588
Epoch 73 - Train loss: 0.4808, Val loss: 1.0597
Epoch 74 - Train loss: 0.4608, Val loss: 1.0607
Epoch 75 - Train loss: 0.4925, Val loss: 1.0615
Epoch 76 - Train loss: 0.4658, Val loss: 1.0620
Epoch 77 - Train loss: 0.4582, Val loss: 1.0632
Epoch 78 - Train loss: 0.4868, Val loss: 1.0634
Epoch 79 - Train loss: 0.4880, Val loss: 1.0639
Epoch 80 - Train loss: 0.4665, Val loss: 1.0645


 41%|████      | 82/200 [00:01<00:01, 79.93it/s]

Epoch 81 - Train loss: 0.4747, Val loss: 1.0645
Epoch 82 - Train loss: 0.4731, Val loss: 1.0644
Epoch 83 - Train loss: 0.4412, Val loss: 1.0647
Epoch 84 - Train loss: 0.4318, Val loss: 1.0650
Epoch 85 - Train loss: 0.4733, Val loss: 1.0653
Epoch 86 - Train loss: 0.4734, Val loss: 1.0658


 46%|████▌     | 91/200 [00:01<00:01, 80.12it/s]

Epoch 87 - Train loss: 0.4728, Val loss: 1.0658
Epoch 88 - Train loss: 0.4558, Val loss: 1.0660
Epoch 89 - Train loss: 0.4807, Val loss: 1.0659
Epoch 90 - Train loss: 0.4768, Val loss: 1.0661
Epoch 91 - Train loss: 0.4999, Val loss: 1.0664
Epoch 92 - Train loss: 0.4524, Val loss: 1.0666
Epoch 93 - Train loss: 0.4712, Val loss: 1.0668
Epoch 94 - Train loss: 0.4936, Val loss: 1.0667
Epoch 95 - Train loss: 0.4814, Val loss: 1.0667
Epoch 96 - Train loss: 0.5031, Val loss: 1.0667
Epoch 97 - Train loss: 0.4693, Val loss: 1.0666


 50%|█████     | 100/200 [00:01<00:01, 80.84it/s]

Epoch 98 - Train loss: 0.4652, Val loss: 1.0664
Epoch 99 - Train loss: 0.4841, Val loss: 1.0664
Epoch 100 - Train loss: 0.4530, Val loss: 1.0663
Epoch 101 - Train loss: 0.4721, Val loss: 1.0662
Epoch 102 - Train loss: 0.4510, Val loss: 1.0661
Epoch 103 - Train loss: 0.4585, Val loss: 1.0661


 55%|█████▍    | 109/200 [00:01<00:01, 80.02it/s]

Epoch 104 - Train loss: 0.4719, Val loss: 1.0661
Epoch 105 - Train loss: 0.4487, Val loss: 1.0663
Epoch 106 - Train loss: 0.4889, Val loss: 1.0665
Epoch 107 - Train loss: 0.4770, Val loss: 1.0666
Epoch 108 - Train loss: 0.4589, Val loss: 1.0666
Epoch 109 - Train loss: 0.4402, Val loss: 1.0667
Epoch 110 - Train loss: 0.4599, Val loss: 1.0668
Epoch 111 - Train loss: 0.4748, Val loss: 1.0669
Epoch 112 - Train loss: 0.4697, Val loss: 1.0670
Epoch 113 - Train loss: 0.4471, Val loss: 1.0672
Epoch 114 - Train loss: 0.4822, Val loss: 1.0672


 59%|█████▉    | 118/200 [00:01<00:01, 80.68it/s]

Epoch 115 - Train loss: 0.4398, Val loss: 1.0673
Epoch 116 - Train loss: 0.4432, Val loss: 1.0674
Epoch 117 - Train loss: 0.4340, Val loss: 1.0675
Epoch 118 - Train loss: 0.4726, Val loss: 1.0676
Epoch 119 - Train loss: 0.4979, Val loss: 1.0676
Epoch 120 - Train loss: 0.4751, Val loss: 1.0676


 64%|██████▎   | 127/200 [00:01<00:00, 80.18it/s]

Epoch 121 - Train loss: 0.4859, Val loss: 1.0676
Epoch 122 - Train loss: 0.4847, Val loss: 1.0675
Epoch 123 - Train loss: 0.4693, Val loss: 1.0675
Epoch 124 - Train loss: 0.4771, Val loss: 1.0675
Epoch 125 - Train loss: 0.4452, Val loss: 1.0675
Epoch 126 - Train loss: 0.4510, Val loss: 1.0675
Epoch 127 - Train loss: 0.4899, Val loss: 1.0675
Epoch 128 - Train loss: 0.4487, Val loss: 1.0675
Epoch 129 - Train loss: 0.4392, Val loss: 1.0676
Epoch 130 - Train loss: 0.4760, Val loss: 1.0676


 68%|██████▊   | 136/200 [00:01<00:00, 80.02it/s]

Epoch 131 - Train loss: 0.4886, Val loss: 1.0676
Epoch 132 - Train loss: 0.4530, Val loss: 1.0676
Epoch 133 - Train loss: 0.4778, Val loss: 1.0676
Epoch 134 - Train loss: 0.4675, Val loss: 1.0676
Epoch 135 - Train loss: 0.4773, Val loss: 1.0676
Epoch 136 - Train loss: 0.4897, Val loss: 1.0676
Epoch 137 - Train loss: 0.4837, Val loss: 1.0676


 72%|███████▎  | 145/200 [00:01<00:00, 80.10it/s]

Epoch 138 - Train loss: 0.4679, Val loss: 1.0676
Epoch 139 - Train loss: 0.4520, Val loss: 1.0676
Epoch 140 - Train loss: 0.4589, Val loss: 1.0676
Epoch 141 - Train loss: 0.4696, Val loss: 1.0676
Epoch 142 - Train loss: 0.4667, Val loss: 1.0676
Epoch 143 - Train loss: 0.5019, Val loss: 1.0676
Epoch 144 - Train loss: 0.4781, Val loss: 1.0676
Epoch 145 - Train loss: 0.4715, Val loss: 1.0676
Epoch 146 - Train loss: 0.4593, Val loss: 1.0676
Epoch 147 - Train loss: 0.4860, Val loss: 1.0676


 77%|███████▋  | 154/200 [00:01<00:00, 79.31it/s]

Epoch 148 - Train loss: 0.4677, Val loss: 1.0676
Epoch 149 - Train loss: 0.4893, Val loss: 1.0676
Epoch 150 - Train loss: 0.4616, Val loss: 1.0676
Epoch 151 - Train loss: 0.4597, Val loss: 1.0677
Epoch 152 - Train loss: 0.4746, Val loss: 1.0677
Epoch 153 - Train loss: 0.4515, Val loss: 1.0677


 82%|████████▏ | 163/200 [00:02<00:00, 80.07it/s]

Epoch 154 - Train loss: 0.4399, Val loss: 1.0677
Epoch 155 - Train loss: 0.4784, Val loss: 1.0677
Epoch 156 - Train loss: 0.4754, Val loss: 1.0677
Epoch 157 - Train loss: 0.4925, Val loss: 1.0677
Epoch 158 - Train loss: 0.4528, Val loss: 1.0677
Epoch 159 - Train loss: 0.4737, Val loss: 1.0677
Epoch 160 - Train loss: 0.4781, Val loss: 1.0677
Epoch 161 - Train loss: 0.4500, Val loss: 1.0677
Epoch 162 - Train loss: 0.4716, Val loss: 1.0677
Epoch 163 - Train loss: 0.4651, Val loss: 1.0677
Epoch 164 - Train loss: 0.4660, Val loss: 1.0677
Epoch 165 - Train loss: 0.4527, Val loss: 1.0677
Epoch 166 - Train loss: 0.4769, Val loss: 1.0678
Epoch 167 - Train loss: 0.5100, Val loss: 1.0678
Epoch 168 - Train loss: 0.4820, Val loss: 1.0678


 90%|█████████ | 181/200 [00:02<00:00, 78.16it/s]

Epoch 169 - Train loss: 0.4830, Val loss: 1.0678
Epoch 170 - Train loss: 0.4862, Val loss: 1.0678
Epoch 171 - Train loss: 0.4638, Val loss: 1.0678
Epoch 172 - Train loss: 0.4717, Val loss: 1.0678
Epoch 173 - Train loss: 0.4908, Val loss: 1.0678
Epoch 174 - Train loss: 0.4581, Val loss: 1.0678
Epoch 175 - Train loss: 0.4564, Val loss: 1.0678
Epoch 176 - Train loss: 0.4754, Val loss: 1.0678
Epoch 177 - Train loss: 0.4437, Val loss: 1.0678
Epoch 178 - Train loss: 0.4605, Val loss: 1.0678
Epoch 179 - Train loss: 0.4572, Val loss: 1.0678
Epoch 180 - Train loss: 0.4642, Val loss: 1.0679
Epoch 181 - Train loss: 0.4991, Val loss: 1.0679
Epoch 182 - Train loss: 0.4607, Val loss: 1.0679
Epoch 183 - Train loss: 0.4943, Val loss: 1.0679
Epoch 184 - Train loss: 0.4692, Val loss: 1.0679
Epoch 185 - Train loss: 0.4677, Val loss: 1.0679


 95%|█████████▌| 190/200 [00:02<00:00, 78.81it/s]

Epoch 186 - Train loss: 0.4647, Val loss: 1.0679
Epoch 187 - Train loss: 0.4513, Val loss: 1.0679
Epoch 188 - Train loss: 0.4779, Val loss: 1.0679
Epoch 189 - Train loss: 0.4757, Val loss: 1.0679
Epoch 190 - Train loss: 0.4716, Val loss: 1.0678
Epoch 191 - Train loss: 0.4647, Val loss: 1.0678
Epoch 192 - Train loss: 0.4596, Val loss: 1.0678
Epoch 193 - Train loss: 0.4894, Val loss: 1.0679
Epoch 194 - Train loss: 0.4422, Val loss: 1.0679
Epoch 195 - Train loss: 0.4643, Val loss: 1.0678
Epoch 196 - Train loss: 0.4772, Val loss: 1.0679
Epoch 197 - Train loss: 0.4859, Val loss: 1.0678


100%|██████████| 200/200 [00:02<00:00, 78.41it/s]

Epoch 198 - Train loss: 0.4941, Val loss: 1.0678
Epoch 199 - Train loss: 0.4824, Val loss: 1.0678



