In [1]:
import math
import random
import pickle

import numpy as np
import pandas as pd
import geopandas as gpd
import tqdm
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, Subset

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import LineGraph

from shapely.geometry import LineString

# Abstract

This notebook generates the dataset; the model itself is built in the companion notebook `gnn_for_policy_traffic_prediction_2`.

## 1. Load data and create the dataset

In [2]:
# Read the pickled simulation results into `results_dict` (a mapping that is
# iterated with `.items()` further down).
# NOTE: pickle can execute arbitrary code while loading -- only open files you trust.
results_path = '../results/results_pop_1pm_first_1400.pkl'
with open(results_path, 'rb') as results_file:
    results_dict = pickle.load(results_file)

In [3]:
# Mode strings known to the encoder. The integer code of a mode string is its
# position in this tuple (shown in the trailing comment), so the order matters.
_MODE_STRINGS = (
    'bus',                                     # 0
    'car',                                     # 1
    'car_passenger',                           # 2
    'pt',                                      # 3
    'bus,car,car_passenger',                   # 4
    'bus,car,car_passenger,pt',                # 5
    'car,car_passenger',                       # 6
    'pt,rail,train',                           # 7
    'bus,pt',                                  # 8
    'rail',                                    # 9
    'pt,subway',                               # 10
    'artificial,bus',                          # 11
    'artificial,rail',                         # 12
    'artificial,stopFacilityLink,subway',      # 13
    'artificial,subway',                       # 14
    'artificial,stopFacilityLink,tram',        # 15
    'artificial,tram',                         # 16
    'artificial,bus,stopFacilityLink',         # 17
    'artificial,funicular,stopFacilityLink',   # 18
    'artificial,funicular',                    # 19
)

# Lookup table: mode string -> integer code
mode_mapping = {mode_string: code for code, mode_string in enumerate(_MODE_STRINGS)}


def encode_modes(modes):
    """Return the integer code of a mode string.

    Parameters
    ----------
    modes : hashable
        A mode string such as ``'car'`` or ``'bus,pt'``.

    Returns
    -------
    int
        The code stored in ``mode_mapping``, or ``-1`` when ``modes`` is not
        one of its keys.
    """
    if modes in mode_mapping:
        return mode_mapping[modes]
    return -1  # sentinel for unknown modes

In [4]:
# Create data objects: one line-graph `Data` object per simulation result.
# In the line graph every road link (one DataFrame row) becomes a node, so the
# per-link attributes become the node features `x`, the car volume becomes the
# node-level target `y`, and the link midpoint becomes the node position `pos`.
datalist = []
counter = 0  # number of entries visited so far (DataFrames or not)
# Links are directed (from_node -> to_node), so always build the directed line
# graph, even if a network happens to contain every link in both directions.
linegraph_transformation = LineGraph(force_directed=True)

for key, df in results_dict.items():
    counter += 1
    if isinstance(df, pd.DataFrame):
        # Work on a copy and avoid in-place operations so the frames stored in
        # `results_dict` are never modified; re-running this cell then starts
        # from the same untouched input every time.
        gdf = gpd.GeoDataFrame(df.copy(), geometry='geometry')
        gdf = gdf.set_crs("EPSG:2154", allow_override=True)  # Assuming the original CRS is EPSG:2154
        gdf = gdf.to_crs("EPSG:4326")

        # Create dictionaries for nodes and edges
        nodes = pd.concat([gdf['from_node'], gdf['to_node']]).unique()
        node_to_idx = {node: idx for idx, node in enumerate(nodes)}

        gdf['from_idx'] = gdf['from_node'].map(node_to_idx)
        gdf['to_idx'] = gdf['to_node'].map(node_to_idx)

        # LineGraph coalesces `edge_index` (sorts the edges by (source, target)
        # and merges duplicates) before it builds the line graph, so line-graph
        # node j is the j-th edge in *sorted* order. Sort the rows the same way
        # so that x / y / pos created below line up with the line-graph nodes.
        gdf = gdf.sort_values(['from_idx', 'to_idx']).reset_index(drop=True)

        edges = gdf[['from_idx', 'to_idx']].to_numpy()
        edge_car_volumes = gdf['vol_car'].to_numpy()
        capacities = gdf['capacity'].to_numpy()
        freespeeds = gdf['freespeed'].to_numpy()
        lengths = gdf['length'].to_numpy()
        modes = gdf['modes'].to_numpy()
        # Unknown mode strings are encoded as -1 by `encode_modes`
        modes_encoded = gdf['modes'].map(encode_modes).to_numpy()

        # Midpoint between the first and the last coordinate of each link geometry
        edge_positions = np.array([((geom.coords[0][0] + geom.coords[-1][0]) / 2,
                                    (geom.coords[0][1] + geom.coords[-1][1]) / 2)
                                   for geom in gdf.geometry])

        # Convert lists to tensors
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_positions_tensor = torch.tensor(edge_positions, dtype=torch.float)
        target_values = torch.tensor(edge_car_volumes, dtype=torch.float).unsqueeze(1)

        # Create Data object (topology only; the per-link tensors are attached
        # after the transformation, when they describe line-graph nodes)
        data = Data(edge_index=edge_index, num_nodes=len(nodes))

        # Transform to line graph
        linegraph_data = linegraph_transformation(data)

        # If several rows share the same (from, to) pair, coalescing merges them
        # and the line graph has fewer nodes than there are rows; the per-link
        # tensors could then not be aligned with the nodes.
        if linegraph_data.num_nodes != len(gdf):
            print(f"Skipping {key}: {len(gdf)} links but "
                  f"{linegraph_data.num_nodes} line graph nodes (duplicate links?)")
            continue

        # Prepare the x for line graph: capacity, freespeed, length and mode code
        linegraph_x = torch.tensor(np.column_stack((capacities, freespeeds, lengths, modes_encoded)), dtype=torch.float)

        linegraph_data.x = linegraph_x
        linegraph_data.pos = edge_positions_tensor

        # Target tensor for car volumes
        linegraph_data.y = target_values

        # With raise_on_error=True an invalid object raises instead of returning
        # False, so the else branch only runs if validate() returns False quietly.
        if linegraph_data.validate(raise_on_error=True):
            datalist.append(linegraph_data)
        else:
            print("Invalid line graph data")

# Convert dataset to a list of dictionaries
data_dict_list = [{'x': lg_data.x, 'edge_index': lg_data.edge_index, 'pos': lg_data.pos, 'y': lg_data.y} for lg_data in datalist]

In [None]:
# Write the list of per-scenario dictionaries (x, edge_index, pos, y) to disk
dataset_path = 'dataset_1pm_0-1382_with_more_infos.pt'
torch.save(data_dict_list, dataset_path)