In [1]:
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T
import pickle
import pandas as pd
import geopandas as gpd

import torch_geometric
from torch_geometric.data import Data
from torch.utils.data import DataLoader, Dataset
from torch_geometric.transforms import LineGraph

from torch_geometric.data import Batch
from torch_geometric.data import Data, Batch

from shapely.geometry import LineString
import tqdm 
import torch.nn.functional as F

def collate_fn(data_list):
    return Batch.from_data_list(data_list)

# Abstract

This is the current working version.
The steps are the following:

1. Load data
2. Pick a loss function
3. Split into train and test data
4. Training loop

## 1. Load data and create the dataset

In [2]:
with open('../results/results_pop_1pct_toy_example.pkl', 'rb') as f:
    results_dict = pickle.load(f)
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
class MyGeometricDataset(Dataset):
    def __init__(self, data_list):
        self.data_list = data_list
        
    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]

# Create your data objects
datalist = []
for key, df in results_dict.items():
    if isinstance(df, pd.DataFrame):
        gdf = gpd.GeoDataFrame(df, geometry='geometry')
        gdf.crs = "EPSG:2154"  # Assuming the original CRS is EPSG:2154
        gdf.to_crs("EPSG:4326", inplace=True)
        
        nodes = []
        edges = []
        edge_car_volumes = []
        node_to_idx = {}
        capacities = {}
        edge_positions = []

        # Iterate through the rows of the GeoDataFrame
        for idx, row in gdf.iterrows():
            from_node = row['from_node']
            to_node = row['to_node']
            car_volume = row['vol_car']
            capacity = row['capacity']
            
            # Get coordinates from the LINESTRING geometry
            coords = list(row.geometry.coords)
            from_position = coords[0]
            to_position = coords[-1]
            
            # Assign unique indices to nodes
            if from_node not in node_to_idx:
                node_to_idx[from_node] = len(nodes)
                nodes.append(from_node)
                capacities[node_to_idx[from_node]] = capacity

            if to_node not in node_to_idx:
                node_to_idx[to_node] = len(nodes)
                nodes.append(to_node)
                capacities[node_to_idx[to_node]] = capacity
            
            # Append edge index and attributes
            edge = (node_to_idx[from_node], node_to_idx[to_node])
            if edge not in edges:
                edges.append(edge)
                edge_car_volumes.append(car_volume)  # Target values
                
                # Compute edge position (e.g., midpoint)
                edge_position = ((from_position[0] + to_position[0]) / 2, (from_position[1] + to_position[1]) / 2)
                edge_positions.append(edge_position)

        # Convert lists to tensors
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_positions_tensor = torch.tensor(edge_positions, dtype=torch.float)
        
        x = torch.tensor([[capacities[i]] for i in range(len(nodes))], dtype=torch.float)
        
        # Create Data object
        data = Data(edge_index=edge_index, x=x)
        
        # Transform to line graph
        linegraph_transformation = LineGraph()
        linegraph_data = linegraph_transformation(data)
        
        # Prepare the x for line graph: index and capacity
        linegraph_x = torch.zeros((linegraph_data.num_nodes, 2), dtype=torch.float)
        
        for i, (from_idx, to_idx) in enumerate(edges):
            capacity = capacities[from_idx]  # Assuming capacity is the same for from and to node
            linegraph_x[i, 0] = i  # Index
            linegraph_x[i, 1] = capacity
        
        linegraph_data.x = linegraph_x
        
        # Target tensor for car volumes
        linegraph_data.y = torch.tensor(edge_car_volumes, dtype=torch.float).unsqueeze(1)
        
        if linegraph_data.validate(raise_on_error=True):
            datalist.append(linegraph_data)
        else:
            print("Invalid line graph data")
            
dataset = MyGeometricDataset(datalist)

## Define the model

In [20]:
class GnnModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch_geometric.nn.GCNConv(2, 16)
        self.conv2 = torch_geometric.nn.GCNConv(16, 1)
        # self.layers = nn.Sequential(
        # nn.Linear(3, 64),
        # nn.ReLU(),
        # nn.Linear(64, 32),
        # nn.ReLU(),
        # nn.Linear(32, 1)
        # )
        
    def forward(self, x):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
    
def get_data_from_dataloader(is_train, batch_size, dataset):
    sub_dataset = torch.utils.data.Subset(dataset, range(0, int(len(dataset) * 0.8)) if is_train else range(int(len(dataset) * 0.2), len(dataset)))
    return DataLoader(dataset=sub_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

def validate_model(model, valid_dl, loss_func, device):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0
    with torch.inference_mode():
        for data in valid_dl:
            data, expected = data.x.to(device), data.y.to(device)
            predicted = model(data)
            val_loss += loss_func(predicted, expected)*expected.size(0)
    return val_loss 

## Train and test the model

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Train the model
for epoch in range(10):
    # 🐝 initialise a wandb run
    wandb.init(
        project="gnn_1",
        config={
            "epochs": 10,
            "batch_size": 1,
            "lr": 0.01,
            "dropout": random.uniform(0.01, 0.80),
            })
    config = wandb.config
    
    # Get data
    train_dl = get_data_from_dataloader(dataset  = dataset, is_train=True, batch_size=config.batch_size)
    valid_dl = get_data_from_dataloader(dataset  = dataset, is_train=False, batch_size=config.batch_size)
    n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

    # Get the model
    model = GnnModel().to(device)
    
    # Define loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    loss_fct = torch.nn.MSELoss()
    
    # Training
    example_ct = 0
    step_ct = 0
    for epoch in range(config.epochs):
        model.train()
        print(train_dl)
        for step, data in enumerate(train_dl):
            input_node_features, expected_node_feats = data.x.to(device), data.y.to(device)
            predicted = model(data)
            train_loss = loss_fct(predicted, expected_node_feats)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            wandb.log({"train_loss": train_loss.item(), "step": step_ct})
            print(f"epoch: {epoch}, step: {step}, loss: {train_loss.item()}")
            example_ct += len(expected_node_feats)
            step_ct += 1
            
        val_loss = validate_model(model, valid_dl, loss_fct, device)
        wandb.log({"val_loss": val_loss,  "step": step_ct})
        # print(f"epoch: {epoch}, val_loss: {val_loss}")
        
wandb.summary["val_loss"] = val_loss
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▂▃▅▆▇█
train_loss,▄▇▅▆▇█▁

0,1
step,6.0
train_loss,21002.53906


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167530099757843, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x1688183a0>
epoch: 0, step: 0, loss: 21002.5390625
epoch: 0, step: 1, loss: 21086.25
epoch: 0, step: 2, loss: 21170.69921875
epoch: 0, step: 3, loss: 21187.080078125
epoch: 0, step: 4, loss: 21163.212890625
epoch: 0, step: 5, loss: 21110.92578125
epoch: 0, step: 6, loss: 21147.349609375
epoch: 0, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x1688183a0>
epoch: 1, step: 0, loss: 21170.69921875
epoch: 1, step: 1, loss: 21110.92578125
epoch: 1, step: 2, loss: 21163.212890625
epoch: 1, step: 3, loss: 21086.25
epoch: 1, step: 4, loss: 21187.080078125
epoch: 1, step: 5, loss: 21002.5390625
epoch: 1, step: 6, loss: 21147.349609375
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x1688183a0>
epoch: 2, step: 0, loss: 21163.212890625
epoch: 2, step: 1, loss: 21002.5390625
epoch: 2, step: 2, loss: 21086.25
epoch: 2, step: 3, loss: 21147.349609375
epoch: 2, step: 4, loss: 21187.0800781

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▁▄█▅▇▅▄▁▇▁▆▇█▇▇▅▁█▅▇▇▆▅▄▇█▇▄█▆▁▇▅▇▇▁▁▅█▇
val_loss,▅▅▅▁▅▅▁██▁

0,1
step,70.0
train_loss,21170.69922
val_loss,5321843712.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112055100319493, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x168c80cd0>
epoch: 0, step: 0, loss: 21147.349609375
epoch: 0, step: 1, loss: 21170.69921875
epoch: 0, step: 2, loss: 21163.212890625
epoch: 0, step: 3, loss: 21110.92578125
epoch: 0, step: 4, loss: 21187.080078125
epoch: 0, step: 5, loss: 21002.5390625
epoch: 0, step: 6, loss: 21086.25
epoch: 0, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168c80cd0>
epoch: 1, step: 0, loss: 21147.349609375
epoch: 1, step: 1, loss: 21110.92578125
epoch: 1, step: 2, loss: 21002.5390625
epoch: 1, step: 3, loss: 21187.080078125
epoch: 1, step: 4, loss: 21086.25
epoch: 1, step: 5, loss: 21170.69921875
epoch: 1, step: 6, loss: 21163.212890625
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168c80cd0>
epoch: 2, step: 0, loss: 21187.080078125
epoch: 2, step: 1, loss: 21147.349609375
epoch: 2, step: 2, loss: 21002.5390625
epoch: 2, step: 3, loss: 21086.25
epoch: 2, step: 4, loss: 21110.9257812

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▆▇▅▁▆▅█▇█▆▄▇▄▆▇▁▁▄▇▇▄▅▇█▅▇▆▁▁▇▄▅█▇▄▇▁▄▇▅
val_loss,▅▅▅▅█▁█▁▅█

0,1
step,70.0
train_loss,21110.92578
val_loss,5321844736.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011143871300090622, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x168cff1f0>
epoch: 0, step: 0, loss: 21187.080078125
epoch: 0, step: 1, loss: 21110.92578125
epoch: 0, step: 2, loss: 21147.349609375
epoch: 0, step: 3, loss: 21170.69921875
epoch: 0, step: 4, loss: 21002.5390625
epoch: 0, step: 5, loss: 21086.25
epoch: 0, step: 6, loss: 21163.212890625
epoch: 0, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168cff1f0>
epoch: 1, step: 0, loss: 21163.212890625
epoch: 1, step: 1, loss: 21147.349609375
epoch: 1, step: 2, loss: 21086.25
epoch: 1, step: 3, loss: 21187.080078125
epoch: 1, step: 4, loss: 21170.69921875
epoch: 1, step: 5, loss: 21002.5390625
epoch: 1, step: 6, loss: 21110.92578125
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168cff1f0>
epoch: 2, step: 0, loss: 21170.69921875
epoch: 2, step: 1, loss: 21147.349609375
epoch: 2, step: 2, loss: 21086.25
epoch: 2, step: 3, loss: 21187.080078125
epoch: 2, step: 4, loss: 21002.539062

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,█▅▇▄▇▆█▁▇▆█▅▁▇█▄▄█▆▇▆▇▁▄█▇▁▆▄▅▇▁▆▇▄▁▄▆▇▅
val_loss,▅▅▅▅▅▁█▅▅▅

0,1
step,70.0
train_loss,21110.92578
val_loss,5321844224.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167389355574011, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x168c18970>
epoch: 0, step: 0, loss: 21147.349609375
epoch: 0, step: 1, loss: 21086.25
epoch: 0, step: 2, loss: 21187.080078125
epoch: 0, step: 3, loss: 21163.212890625
epoch: 0, step: 4, loss: 21002.5390625
epoch: 0, step: 5, loss: 21170.69921875
epoch: 0, step: 6, loss: 21110.92578125
epoch: 0, val_loss: 5321843712.0
<torch.utils.data.dataloader.DataLoader object at 0x168c18970>
epoch: 1, step: 0, loss: 21163.212890625
epoch: 1, step: 1, loss: 21147.349609375
epoch: 1, step: 2, loss: 21002.5390625
epoch: 1, step: 3, loss: 21187.080078125
epoch: 1, step: 4, loss: 21170.69921875
epoch: 1, step: 5, loss: 21110.92578125
epoch: 1, step: 6, loss: 21086.25
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168c18970>
epoch: 2, step: 0, loss: 21110.92578125
epoch: 2, step: 1, loss: 21147.349609375
epoch: 2, step: 2, loss: 21002.5390625
epoch: 2, step: 3, loss: 21187.080078125
epoch: 2, step: 4, loss: 21163.2

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▆▄▇▇▇▆█▅▅▆█▄▅▇▁██▁▆▄▇█▅▇▄▅▁█▅▁▇▄▄▇▅▁▅▆█▇
val_loss,▁▅▅▁▅▁▅▅█▅

0,1
step,70.0
train_loss,21170.69922
val_loss,5321844224.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167503244359977, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x168c8e380>
epoch: 0, step: 0, loss: 21147.349609375
epoch: 0, step: 1, loss: 21002.5390625
epoch: 0, step: 2, loss: 21163.212890625
epoch: 0, step: 3, loss: 21086.25
epoch: 0, step: 4, loss: 21170.69921875
epoch: 0, step: 5, loss: 21110.92578125
epoch: 0, step: 6, loss: 21187.080078125
epoch: 0, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168c8e380>
epoch: 1, step: 0, loss: 21110.92578125
epoch: 1, step: 1, loss: 21170.69921875
epoch: 1, step: 2, loss: 21163.212890625
epoch: 1, step: 3, loss: 21086.25
epoch: 1, step: 4, loss: 21147.349609375
epoch: 1, step: 5, loss: 21187.080078125
epoch: 1, step: 6, loss: 21002.5390625
epoch: 1, val_loss: 5321844736.0
<torch.utils.data.dataloader.DataLoader object at 0x168c8e380>
epoch: 2, step: 0, loss: 21147.349609375
epoch: 2, step: 1, loss: 21170.69921875
epoch: 2, step: 2, loss: 21086.25
epoch: 2, step: 3, loss: 21187.080078125
epoch: 2, step: 4, loss: 21002.539062

VBox(children=(Label(value='0.001 MB of 0.020 MB uploaded\r'), FloatProgress(value=0.047015344311377244, max=1…

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▆▁▄▅▅▇▄█▆▇█▇█▅▇▄█▇▅▇▅▇▁▇▁█▄▅▅▁▆▄█▇▆▁▄█▅▇
val_loss,▁█▁█▁▁▁▁▁▁

0,1
step,70.0
train_loss,21170.69922
val_loss,5321844224.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168028711108492, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x1664e9030>
epoch: 0, step: 0, loss: 21110.92578125
epoch: 0, step: 1, loss: 21163.212890625
epoch: 0, step: 2, loss: 21086.25
epoch: 0, step: 3, loss: 21002.5390625
epoch: 0, step: 4, loss: 21147.349609375
epoch: 0, step: 5, loss: 21187.080078125
epoch: 0, step: 6, loss: 21170.69921875
epoch: 0, val_loss: 5321844736.0
<torch.utils.data.dataloader.DataLoader object at 0x1664e9030>
epoch: 1, step: 0, loss: 21086.25
epoch: 1, step: 1, loss: 21002.5390625
epoch: 1, step: 2, loss: 21170.69921875
epoch: 1, step: 3, loss: 21163.212890625
epoch: 1, step: 4, loss: 21110.92578125
epoch: 1, step: 5, loss: 21147.349609375
epoch: 1, step: 6, loss: 21187.080078125
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x1664e9030>
epoch: 2, step: 0, loss: 21110.92578125
epoch: 2, step: 1, loss: 21170.69921875
epoch: 2, step: 2, loss: 21147.349609375
epoch: 2, step: 3, loss: 21163.212890625
epoch: 2, step: 4, loss: 21086.

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▅▇▁█▄▁▇▆▅▇▇▁▁▅█▄▆█▅▁▆▇▁▅▁▅▄▆▄▇▁▇▁▄▅▇▇▇▁▆
val_loss,█▅██▁▅▅▅█▅

0,1
step,70.0
train_loss,21147.34961
val_loss,5321844224.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0111505129665602, max=1.0))…

<torch.utils.data.dataloader.DataLoader object at 0x1664e9db0>
epoch: 0, step: 0, loss: 21147.349609375
epoch: 0, step: 1, loss: 21086.25
epoch: 0, step: 2, loss: 21170.69921875
epoch: 0, step: 3, loss: 21187.080078125
epoch: 0, step: 4, loss: 21163.212890625
epoch: 0, step: 5, loss: 21002.5390625
epoch: 0, step: 6, loss: 21110.92578125
epoch: 0, val_loss: 5321843712.0
<torch.utils.data.dataloader.DataLoader object at 0x1664e9db0>
epoch: 1, step: 0, loss: 21110.92578125
epoch: 1, step: 1, loss: 21002.5390625
epoch: 1, step: 2, loss: 21086.25
epoch: 1, step: 3, loss: 21163.212890625
epoch: 1, step: 4, loss: 21170.69921875
epoch: 1, step: 5, loss: 21147.349609375
epoch: 1, step: 6, loss: 21187.080078125
epoch: 1, val_loss: 5321844736.0
<torch.utils.data.dataloader.DataLoader object at 0x1664e9db0>
epoch: 2, step: 0, loss: 21163.212890625
epoch: 2, step: 1, loss: 21170.69921875
epoch: 2, step: 2, loss: 21147.349609375
epoch: 2, step: 3, loss: 21187.080078125
epoch: 2, step: 4, loss: 21110

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▆▄█▁▅▁▇▆▇▇█▄█▅▄▇▄▁▅██▇▆▄▄█▅▇▇▇▄▁█▇▅▆▄█▆▇
val_loss,▁██▁█▅███▅

0,1
step,70.0
train_loss,21163.21289
val_loss,5321844224.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167437499777104, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x168c8db40>
epoch: 0, step: 0, loss: 21147.349609375
epoch: 0, step: 1, loss: 21086.25
epoch: 0, step: 2, loss: 21002.5390625
epoch: 0, step: 3, loss: 21187.080078125
epoch: 0, step: 4, loss: 21110.92578125
epoch: 0, step: 5, loss: 21163.212890625
epoch: 0, step: 6, loss: 21170.69921875
epoch: 0, val_loss: 5321844736.0
<torch.utils.data.dataloader.DataLoader object at 0x168c8db40>
epoch: 1, step: 0, loss: 21086.25
epoch: 1, step: 1, loss: 21002.5390625
epoch: 1, step: 2, loss: 21170.69921875
epoch: 1, step: 3, loss: 21147.349609375
epoch: 1, step: 4, loss: 21187.080078125
epoch: 1, step: 5, loss: 21163.212890625
epoch: 1, step: 6, loss: 21110.92578125
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168c8db40>
epoch: 2, step: 0, loss: 21163.212890625
epoch: 2, step: 1, loss: 21170.69921875
epoch: 2, step: 2, loss: 21002.5390625
epoch: 2, step: 3, loss: 21187.080078125
epoch: 2, step: 4, loss: 21110.9

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▆▄█▇▄▁▆▇▇▇█▄▄▅▇▁█▄▇▇▇▁█▄▁█▇▅▇█▇▁▆▅▇█▇▆▄▇
val_loss,█▅▅█▅▅█▁██

0,1
step,70.0
train_loss,21170.69922
val_loss,5321844736.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167901388964513, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x162c6c4c0>
epoch: 0, step: 0, loss: 21086.25
epoch: 0, step: 1, loss: 21147.349609375
epoch: 0, step: 2, loss: 21110.92578125
epoch: 0, step: 3, loss: 21002.5390625
epoch: 0, step: 4, loss: 21187.080078125
epoch: 0, step: 5, loss: 21163.212890625
epoch: 0, step: 6, loss: 21170.69921875
epoch: 0, val_loss: 5321844736.0
<torch.utils.data.dataloader.DataLoader object at 0x162c6c4c0>
epoch: 1, step: 0, loss: 21002.5390625
epoch: 1, step: 1, loss: 21163.212890625
epoch: 1, step: 2, loss: 21187.080078125
epoch: 1, step: 3, loss: 21170.69921875
epoch: 1, step: 4, loss: 21147.349609375
epoch: 1, step: 5, loss: 21086.25
epoch: 1, step: 6, loss: 21110.92578125
epoch: 1, val_loss: 5321843712.0
<torch.utils.data.dataloader.DataLoader object at 0x162c6c4c0>
epoch: 2, step: 0, loss: 21086.25
epoch: 2, step: 1, loss: 21163.212890625
epoch: 2, step: 2, loss: 21170.69921875
epoch: 2, step: 3, loss: 21110.92578125
epoch: 2, step: 4, loss: 21147.3496093

VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded\r'), FloatProgress(value=0.09911242603550297, max=1.…

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▄▆▁▇▁▇▇▄▄▇▅█▆▄▁▇█▇▇▅▄▇▆▇▆▇█▁▇▅▇█▇▄█▆▄▅▇▇
val_loss,█▁▅▅▅▅▁▅▅█

0,1
step,70.0
train_loss,21170.69922
val_loss,5321844736.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011135936110966011, max=1.0…

<torch.utils.data.dataloader.DataLoader object at 0x168e31390>
epoch: 0, step: 0, loss: 21110.92578125
epoch: 0, step: 1, loss: 21002.5390625
epoch: 0, step: 2, loss: 21163.212890625
epoch: 0, step: 3, loss: 21187.080078125
epoch: 0, step: 4, loss: 21170.69921875
epoch: 0, step: 5, loss: 21086.25
epoch: 0, step: 6, loss: 21147.349609375
epoch: 0, val_loss: 5321844736.0
<torch.utils.data.dataloader.DataLoader object at 0x168e31390>
epoch: 1, step: 0, loss: 21147.349609375
epoch: 1, step: 1, loss: 21187.080078125
epoch: 1, step: 2, loss: 21163.212890625
epoch: 1, step: 3, loss: 21086.25
epoch: 1, step: 4, loss: 21170.69921875
epoch: 1, step: 5, loss: 21110.92578125
epoch: 1, step: 6, loss: 21002.5390625
epoch: 1, val_loss: 5321844224.0
<torch.utils.data.dataloader.DataLoader object at 0x168e31390>
epoch: 2, step: 0, loss: 21110.92578125
epoch: 2, step: 1, loss: 21187.080078125
epoch: 2, step: 2, loss: 21170.69921875
epoch: 2, step: 3, loss: 21147.349609375
epoch: 2, step: 4, loss: 21163.

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▅▁█▄▆█▄▅▅█▆▄▇▄▅█▇▆▁█▇▅▇█▁▇▄▅▇▅▇▆▇▁▄▅▆▁▇▅
val_loss,█▅▅▅▅▁▅▁█▅

0,1
step,70.0
train_loss,21110.92578
val_loss,5321844224.0


## Analysing the model

In [None]:
# Evaluate the model
# model.eval()
# with torch.no_grad():
#     pred = model(data).cpu()
#     target = data.y.view(-1, 1).cpu()
#     mse = F.mse_loss(pred, target).item()
#     rmse = torch.sqrt(torch.tensor(mse)).item()
#     print(f'Mean Squared Error: {mse:.4f}')
#     print(f'Root Mean Squared Error: {rmse:.4f}')

# # Calculate target value statistics for comparison
# target_values = target.numpy()
# mean_target = target_values.mean()
# std_target = target_values.std()
# min_target = target_values.min()
# max_target = target_values.max()

# print(f'Mean of target values: {mean_target:.4f}')
# print(f'Standard deviation of target values: {std_target:.4f}')
# print(f'Minimum target value: {min_target:.4f}')
# print(f'Maximum target value: {max_target:.4f}')

Mean Squared Error: 20835.3457
Root Mean Squared Error: 144.3445
Mean of target values: 51.4052
Standard deviation of target values: 134.8809
Minimum target value: 0.0000
Maximum target value: 1593.0000
