## 1. Imports and Configuration

In [None]:
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import sys
import numpy as np
from shapely.geometry import LineString
from shapely.ops import linemerge
import igraph as ig
from pathlib import Path
import warnings
import os
from collections import defaultdict, deque
from itertools import chain
import matplotlib.lines as mlines

# Suppress every warning for the rest of the session (global side effect).
warnings.filterwarnings('ignore')

# Destination for the edge-level outputs (parquet files and figures).
# OUT_DIR is an alias used by the plotting cells; the folder is created if missing.
OUT_EDGES_DIR = Path(
    "/soge-home/projects/mistral/miraca/processed_data/processed_unimodal/edges_with_flow"
)
OUT_DIR = OUT_EDGES_DIR
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Map extent in degrees, used for axis limits in the plots.
LON_MIN = -12
LON_MAX = 32
LAT_MIN = 35
LAT_MAX = 72

print("✓ Imports complete")

## 2. Helper Functions

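Import the helper functions from the `miraca_flow_utils` module. If that module is not on the path, the helpers must be defined inline before running the cells below.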

In [None]:
# Pull the shared helpers from miraca_flow_utils; the cells below call these names directly.
try:
    from miraca_flow_utils import (
        _pick_col,
        _ensure_metric_length,
        _compute_travel_time_kh,
        build_igraph_from_edges,
        map_stations_to_nodes,
        compute_edge_capacity_tons,
        od_flow_allocation_capacity_constrained,
        plot_edges_by_flow_thickness,
    )
except ImportError:
    # Best-effort: only warn. Later cells raise NameError unless the helpers
    # are defined inline (e.g. copied from the original script).
    print("⚠ miraca_flow_utils not found; define functions inline or ensure module is in path")
else:
    print("✓ Imported helpers from miraca_flow_utils")

## 3. Load Data

In [None]:
print("Loading data...", flush=True)

# Country boundaries, clipped to the same extent the plots use (config cell constants).
countries_path = r"/soge-home/projects/mistral/miraca/incoming_data/spatial_data/admin/ne_10m/ne_10m_admin_0_countries.shp"
europe_shape = gpd.read_file(countries_path)
europe_bounds = {"xmin": LON_MIN, "xmax": LON_MAX, "ymin": LAT_MIN, "ymax": LAT_MAX}
europe_shape = europe_shape.cx[europe_bounds["xmin"]:europe_bounds["xmax"],
                               europe_bounds["ymin"]:europe_bounds["ymax"]]

# Rail network inputs
rail_edges_file = "/soge-home/projects/mistral/miraca/processed_data/processed_unimodal/europe_railways_edges_TENT.parquet"
rail_nodes_file = "/soge-home/projects/mistral/miraca/processed_data/processed_unimodal/europe_railways_nodes_TENT.parquet"
rail_stations_file = "/soge-home/projects/mistral/miraca/processed_data/processed_unimodal/europe_railways_stations_TENT.parquet"
od_flows_file = "/soge-home/projects/mistral/miraca/processed_data/lifelines_OD/rail_freight_ths_tons_OD.parquet"

edges_gdf = gpd.read_parquet(rail_edges_file)
nodes_gdf = gpd.read_parquet(rail_nodes_file)
stations_df = gpd.read_parquet(rail_stations_file)
od_flows = pd.read_parquet(od_flows_file)

# Ensure CRS: any spatial layer without one is assumed to be WGS84 (EPSG:4326).
for _layer in (edges_gdf, nodes_gdf, stations_df):
    if _layer.crs is None:
        _layer.set_crs("EPSG:4326", inplace=True, allow_override=True)
# Stations are later snapped to nodes by raw coordinate distance, so both
# layers must be expressed in the same CRS.
if stations_df.crs != nodes_gdf.crs:
    stations_df = stations_df.to_crs(nodes_gdf.crs)

# Compute travel time
edges_gdf = _compute_travel_time_kh(edges_gdf, speed_col='tag_maxspeed')

# Assign stable positional edge IDs (0..n-1) on a private copy.
edges_gdf = edges_gdf.copy()
edges_gdf['edge_id'] = np.arange(len(edges_gdf), dtype=int)

print(f"✓ Loaded {len(edges_gdf)} edges, {len(nodes_gdf)} nodes, {len(stations_df)} stations")
print(f"✓ Loaded {len(od_flows)} OD pairs")

## 4. Prepare OD Matrix

In [None]:
# Resolve the identifier columns on each layer.
edge_src = _pick_col(edges_gdf, ['from_id'])
edge_dst = _pick_col(edges_gdf, ['to_id'])
node_id_col = _pick_col(nodes_gdf, ['id'])
station_id_col = _pick_col(stations_df, ['id'])

# Lookup: station id (as str) -> network node id, built by the shared helper.
station2node = map_stations_to_nodes(
    stations_df,
    nodes_gdf,
    station_id_col_candidates=(station_id_col,),
    node_id_col=node_id_col,
)

# Every OD row needs a commodity label; note this adds the column to od_flows itself.
if 'origin_sector' not in od_flows.columns:
    od_flows['origin_sector'] = 'UNKNOWN'

# Translate OD station endpoints into node ids and keep only rows that map on
# both ends and carry a strictly positive numeric demand.
od = od_flows.copy()
for station_col, node_col in (('from_id', 'from_node'), ('to_id', 'to_node')):
    od[node_col] = od[station_col].astype(str).map(station2node)
od = od.dropna(subset=['from_node', 'to_node'])
od = od.astype({'from_node': str, 'to_node': str})
od['value'] = pd.to_numeric(od['value'], errors='coerce').fillna(0.0)
od = od.loc[od['value'] > 0]

print(f"✓ Valid OD pairs: {len(od)}")
print(f"✓ Total freight: {od['value'].sum():,.0f} tons/day")

## 5. Build Simplified Station-to-Station Network

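This step builds a simplified network in which each edge is the path between a pair of neighbouring stations. Paths are found by breadth-first search over the original rail graph (fewest original edges), without passing through another station. This greatly reduces the number of edges while preserving capacity bottlenecks: in Section 6 each path takes the minimum capacity of the original edges it contains.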

In [None]:
print("\n" + "="*80)
print("BUILDING SIMPLIFIED STATION-TO-STATION NETWORK")
print("="*80 + "\n")

# Build station-to-station paths (BFS between consecutive stations)
from scipy.spatial import cKDTree

# --- Snap every station to its nearest network node ---------------------------
# Both layers are stacked as (y, x) in their own CRS, so they must share a CRS
# for the nearest-neighbour distances to be meaningful.
station_yx = np.column_stack([stations_df.geometry.y.to_numpy(), stations_df.geometry.x.to_numpy()])
node_yx = np.column_stack([nodes_gdf.geometry.y.to_numpy(), nodes_gdf.geometry.x.to_numpy()])
_, nearest_node_pos = cKDTree(node_yx).query(station_yx)

station_ids = stations_df[station_id_col].astype(str).to_numpy()
node_ids = nodes_gdf.iloc[nearest_node_pos][node_id_col].astype(str).to_numpy()
station_to_node = dict(zip(station_ids, node_ids))
# Several stations may snap to the same node; the inversion keeps one
# representative station per node (the last one encountered).
node_to_station = {node: station for station, node in station_to_node.items()}

# --- Undirected adjacency list over the original edges ------------------------
# Column-wise zip instead of iterrows: same values, far less per-row overhead.
adjacency = defaultdict(list)
edge_data = {}
for src, dst, edge_id, travel_time, geom in zip(
    edges_gdf[edge_src].astype(str),
    edges_gdf[edge_dst].astype(str),
    edges_gdf['edge_id'],
    edges_gdf['travel_time'].astype(float),
    edges_gdf.geometry,
):
    edge_info = {'edge_id': edge_id, 'to_node': dst, 'travel_time': travel_time, 'geometry': geom}
    adjacency[src].append(edge_info)
    adjacency[dst].append({**edge_info, 'to_node': src})  # reverse direction
    edge_data[edge_id] = edge_info


def _bfs_station_paths(start_node, station_nodes, adjacency):
    """Yield ``(end_node, edge_ids)`` for each station node reachable from
    ``start_node`` without passing through another station node.

    Breadth-first, so each path has the fewest original edges (not necessarily
    the lowest travel time). Station nodes other than the start are endpoints
    only and are never expanded. Parent pointers replace per-entry path copies;
    a path is rebuilt only when a station is reached.
    """
    parent = {start_node: None}  # node -> (predecessor node, edge id used), None for the start
    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        if node != start_node and node in station_nodes:
            edge_ids = []
            step = parent[node]
            while step is not None:
                prev_node, eid = step
                edge_ids.append(eid)
                step = parent[prev_node]
            edge_ids.reverse()
            yield node, edge_ids
            continue
        for info in adjacency.get(node, ()):
            nxt = info['to_node']
            if nxt not in parent:  # first (shortest-hop) discovery wins
                parent[nxt] = (node, info['edge_id'])
                queue.append(nxt)


# --- Link neighbouring station nodes ------------------------------------------
# Iterate over unique station nodes: running BFS once per station would emit
# duplicate, geometrically identical paths (double-counted capacity) whenever
# several stations share a node. Pairs are deduplicated by unordered node pair.
paths = []
station_nodes = set(node_to_station)
processed_pairs = set()

for idx_s, (start_node, start_station_id) in enumerate(node_to_station.items()):
    if idx_s % 100 == 0 and idx_s > 0:
        print(f"  Processed {idx_s}/{len(node_to_station)} stations, found {len(paths)} paths...")

    for end_node, path_edges in _bfs_station_paths(start_node, station_nodes, adjacency):
        pair = tuple(sorted((start_node, end_node)))
        if pair in processed_pairs:
            continue
        processed_pairs.add(pair)

        geometries = [edge_data[eid]['geometry'] for eid in path_edges]
        paths.append({
            'from_station': start_station_id,
            'to_station': node_to_station[end_node],
            'from_node': start_node,
            'to_node': end_node,
            'travel_time': sum(edge_data[eid]['travel_time'] for eid in path_edges),
            'edge_ids': path_edges,
            'num_edges': len(path_edges),
            'geometry': linemerge(geometries) if len(geometries) > 1 else geometries[0],
        })

# Explicit columns so an empty result still yields a valid (empty) GeoDataFrame.
path_columns = ['from_station', 'to_station', 'from_node', 'to_node',
                'travel_time', 'edge_ids', 'num_edges', 'geometry']
simplified_edges = gpd.GeoDataFrame(pd.DataFrame(paths, columns=path_columns),
                                    geometry='geometry', crs=edges_gdf.crs)
simplified_edges['path_id'] = np.arange(len(simplified_edges))

print(f"\n✓ Simplified network: {len(simplified_edges)} paths from {len(edges_gdf)} original edges")
print(f"  Reduction: {100*(1 - len(simplified_edges)/len(edges_gdf)):.1f}%")
print(f"  Avg edges/path: {simplified_edges['num_edges'].mean():.1f}")

## 6. Calculate Capacities on Simplified Network

In [None]:
print("\nComputing capacities on simplified paths...")

# Stochastic per-train payload: ~700 t ± 35 t, expressed as /1000 (thousand
# tons), times a load factor ~0.90 ± 0.09 clipped to [0, 1]. Seeded for
# reproducible reruns.
# NOTE(review): the OD input file is "ths_tons" and payloads are divided by
# 1000, so capacities appear to be in thousand tons; the "tons/day" labels may
# be off by that factor — confirm.
rng = np.random.default_rng(42)
train_draw = np.maximum(rng.normal(700.0/1000, 35.0/1000, size=len(edges_gdf)), 0.0)
occ_draw = np.clip(rng.normal(0.90, 0.09, size=len(edges_gdf)), 0.0, 1.0)
eff_tons = train_draw * occ_draw

# Per-edge capacity from the shared helper; non-finite results count as zero.
cap_tons_orig = compute_edge_capacity_tons(edges_gdf, tt_col='travel_time', train_tons=eff_tons)
edges_gdf['capacity'] = np.where(np.isfinite(cap_tons_orig), cap_tons_orig, 0.0)


def _first_capacity_by_edge(edges_with_cap):
    """Return ``{edge_id: capacity}``, keeping the first row when an edge_id repeats."""
    lookup = {}
    for eid, cap in zip(edges_with_cap['edge_id'].to_numpy(), edges_with_cap['capacity'].to_numpy()):
        lookup.setdefault(eid, cap)
    return lookup


def compute_path_capacity(edge_ids_list, edges_with_cap, capacity_lookup=None):
    """Return the bottleneck (minimum) capacity of the edges along one path.

    Parameters
    ----------
    edge_ids_list : sequence
        Edge ids making up the path (list or 1-D array).
    edges_with_cap : DataFrame
        Table with ``edge_id`` and ``capacity`` columns.
    capacity_lookup : mapping, optional
        Precomputed ``{edge_id: capacity}``. Pass it when evaluating many
        paths; otherwise it is rebuilt from ``edges_with_cap`` on every call.

    Returns
    -------
    float
        Minimum capacity over the ids present in the table (missing ids are
        ignored); 0.0 for an empty path or when no id is found.
    """
    # len() instead of truthiness so numpy arrays are accepted too.
    if edge_ids_list is None or len(edge_ids_list) == 0:
        return 0.0
    if capacity_lookup is None:
        capacity_lookup = _first_capacity_by_edge(edges_with_cap)
    capacities = [capacity_lookup[eid] for eid in edge_ids_list if eid in capacity_lookup]
    return min(capacities) if capacities else 0.0


# Build the id -> capacity table once: scanning edges_gdf per path edge would
# cost O(path_edges x edges).
capacity_by_edge = _first_capacity_by_edge(edges_gdf)
simplified_edges['capacity'] = simplified_edges['edge_ids'].apply(
    lambda ids: compute_path_capacity(ids, edges_gdf, capacity_lookup=capacity_by_edge)
)

print(f"✓ Capacity range: {simplified_edges['capacity'].min():.0f} - {simplified_edges['capacity'].max():.0f} tons/day")

## 7. Plot Simplified Network Capacities

In [None]:
# Plot in geographic degrees so the LON/LAT limits from the config cell apply.
if simplified_edges.crs and simplified_edges.crs.to_epsg() != 4326:
    simplified_edges_plot = simplified_edges.to_crs("EPSG:4326")
else:
    simplified_edges_plot = simplified_edges.copy()

# Keep only drawable paths with positive capacity.
drawable = simplified_edges_plot.geometry.notna() & (simplified_edges_plot['capacity'] > 0)
simplified_edges_plot = simplified_edges_plot[drawable]

# Cap the colour scale at the 95th percentile so a few extreme paths do not wash out the map.
capacity_vmax = simplified_edges_plot['capacity'].quantile(0.95)

fig, ax = plt.subplots(figsize=(14, 10))
europe_shape.boundary.plot(ax=ax, color='#cccccc', linewidth=0.5, zorder=0)
simplified_edges_plot.plot(
    ax=ax,
    column='capacity',
    cmap='YlOrRd',
    linewidth=1.5,
    vmin=0,
    vmax=capacity_vmax,
    legend=True,
    legend_kwds={'label': 'Capacity (tons/day)', 'shrink': 0.7},
    zorder=2,
)

ax.set_title(f'Simplified Station-to-Station Paths ({len(simplified_edges_plot)} paths)', fontsize=14, fontweight='bold')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.set_xlim(LON_MIN, LON_MAX)
ax.set_ylim(LAT_MIN, LAT_MAX)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUT_DIR / 'rail_simplified_capacities.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved capacity plot")

## 8. Run Flow Allocation on Simplified Network

In [None]:
print("\n" + "="*80)
print("RUNNING FLOW ALLOCATION ON SIMPLIFIED NETWORK")
print("="*80 + "\n")

# Graph vertices must share an id space with the OD endpoints. The OD table was
# mapped to network *node* ids (od['from_node'] / od['to_node']), so each
# simplified path is keyed by its end nodes, not by station ids.
network_df_simplified = pd.DataFrame({
    'from_node': simplified_edges['from_node'].astype(str).values,
    'to_node': simplified_edges['to_node'].astype(str).values,
    'edge_id': simplified_edges['path_id'].values,
    'travel_time': simplified_edges['travel_time'].values,
    'capacity': simplified_edges['capacity'].values,
    'flow': 0.0
})

# Build igraph from simplified network (vertex 'name' = node id)
base_graph = ig.Graph.TupleList(
    network_df_simplified[['from_node', 'to_node', 'edge_id', 'travel_time']].itertuples(index=False, name=None),
    edge_attrs=['edge_id', 'travel_time'],
    directed=False
)

# Aggregate OD by pair and commodity
od_pairs = od[['from_node', 'to_node', 'value', 'origin_sector']].groupby(
    ['from_node', 'to_node', 'origin_sector'], as_index=False
)['value'].sum()

# Total per pair, plus each commodity's share of it (shares are not used in this cell).
pair_totals = od_pairs.groupby(['from_node', 'to_node'], as_index=False)['value'].sum().rename(columns={'value': 'total_value'})
od_with_tot = od_pairs.merge(pair_totals, on=['from_node', 'to_node'], how='left')
od_with_tot['share'] = np.where(od_with_tot['total_value'] > 0,
                                od_with_tot['value'] / od_with_tot['total_value'],
                                0.0)

# One row per OD pair carrying the total demand to route
flow_ods_total = od_with_tot[['from_node', 'to_node', 'total_value']].drop_duplicates()
flow_ods_total = flow_ods_total.rename(columns={'total_value': 'flow'})

# Only pairs whose endpoints are graph vertices can be routed; keep the rest
# aside and report them rather than passing unknown vertex names to the router.
graph_vertices = set(base_graph.vs['name']) if base_graph.vcount() > 0 else set()
routable = (flow_ods_total['from_node'].isin(graph_vertices)
            & flow_ods_total['to_node'].isin(graph_vertices))
unroutable_ods = flow_ods_total[~routable]
flow_ods_total = flow_ods_total[routable]
print(f"  Routable OD pairs: {len(flow_ods_total)} "
      f"(skipped {len(unroutable_ods)} pairs / {unroutable_ods['flow'].sum():,.0f} demand "
      f"with endpoints outside the simplified network)")

# Run capacity-constrained allocation
capacity_ods_all, unassigned_paths, network_df_simplified, progress_df = od_flow_allocation_capacity_constrained(
    flow_ods=flow_ods_total,
    network_dataframe=network_df_simplified,
    flow_column='flow',
    cost_column='travel_time',
    path_id_column='edge_id',
    attribute_list=None,
    origin_id_column='from_node',
    destination_id_column='to_node',
    network_capacity_column='capacity',
    directed=False,
    simple=False,
    store_edge_path=True,
    graph_base=base_graph,
    track_progress=True,
    early_stop_share=0.75
)

print(f"\n✓ Flow allocation complete")
print(f"  Assigned: {len(capacity_ods_all)} path groups")
print(f"  Unassigned: {len(unassigned_paths)}")

## 9. Map Flows Back to Original Edges

In [None]:
print("\nMapping flows back to original edges...")

# Flow assigned to each simplified path (path_id == edge_id in the allocation network).
simplified_edges['flow'] = network_df_simplified.set_index('edge_id')['flow'].reindex(
    simplified_edges['path_id']
).fillna(0.0).values

# Every original edge on a path with positive flow receives that path's full
# flow; edges shared by several paths accumulate the sum. explode() yields one
# row per (path, original edge), avoiding a full scan of edges_gdf per path and
# working whether edge_ids holds lists or arrays.
path_edge_flows = (
    simplified_edges.loc[simplified_edges['flow'] > 0, ['edge_ids', 'flow']]
    .explode('edge_ids')
    .dropna(subset=['edge_ids'])
)
flow_by_edge = path_edge_flows.groupby(
    path_edge_flows['edge_ids'].astype('int64')
)['flow'].sum()
edges_gdf['flow'] = edges_gdf['edge_id'].map(flow_by_edge).fillna(0.0).astype(float)

print(f"✓ Total flow on network: {edges_gdf['flow'].sum():,.0f} tons/day")
print(f"✓ Edges with flow: {(edges_gdf['flow'] > 0).sum()} / {len(edges_gdf)}")

## 10. Save Results

In [None]:
# Helper columns not persisted to disk (dropped only if present).
drop_cols = ['length', 'distance', 'from_infra', 'to_infra']
present_drop_cols = [col for col in drop_cols if col in edges_gdf.columns]
edges_out = edges_gdf.drop(columns=present_drop_cols).copy()

# Calibration factor: scale assigned flow up by 25%.
edges_out['flow'] = edges_out['flow'].fillna(0.0).astype(float) * 1.25
has_flow = edges_out['flow'] > 0

# Full table plus a positive-flow-only subset.
edges_out.to_parquet(OUT_EDGES_DIR / "rail_edges_with_freight_flow.parquet", index=False)
edges_out[has_flow].to_parquet(OUT_EDGES_DIR / "rail_edges_freight_flow_positive.parquet", index=False)

print(f"✓ Saved rail_edges_with_freight_flow.parquet")
print(f"✓ Saved rail_edges_freight_flow_positive.parquet ({has_flow.sum()} edges)")

## 11. Visualize Final Flow

In [None]:
fig, ax = plt.subplots(figsize=(12, 10), dpi=200)

# Country outlines underneath the network.
europe_shape.boundary.plot(ax=ax, color='#666666', linewidth=0.3, zorder=0)

# Draw in EPSG:4326 so the degree-based axis limits below line up.
if edges_out.crs and edges_out.crs.to_epsg() != 4326:
    edges_plot = edges_out.to_crs(4326)
else:
    edges_plot = edges_out

# Line width driven by flow via the shared helper (log10 scaling; widths
# presumably bounded by lw_min/lw_max — see the helper for the exact mapping).
plot_edges_by_flow_thickness(
    ax,
    edges_plot,
    flow_col='flow',
    scale='log',
    log_base=10.0,
    scale_div=250.0,
    classify_method='none',
    lw_min=0.1,
    lw_max=3.5,
)

ax.set_xlim(LON_MIN, LON_MAX)
ax.set_ylim(LAT_MIN, LAT_MAX)
ax.set_title("Rail Freight Flow (thousands tons/day)", fontsize=14, fontweight='bold')
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUT_DIR / "rail_edges_freight_flow.png", dpi=200)
plt.show()

print("✓ Saved flow visualization")

## 12. Summary Statistics

In [None]:
# Gather the figures first, then print the report.
banner = "=" * 80
n_edges_orig = len(edges_gdf)
n_paths = len(simplified_edges)
out_flow = edges_out['flow']          # calibrated flow (as saved)
nonzero_flow = out_flow[out_flow > 0]
total_capacity = edges_gdf['capacity'].sum()

print(banner)
print("RAIL FREIGHT FLOW ASSIGNMENT SUMMARY")
print(banner)

print(f"\nNetwork:")
print(f"  Original edges: {n_edges_orig:,}")
print(f"  Simplified paths: {n_paths:,}")
print(f"  Reduction: {100*(1 - n_paths/n_edges_orig):.1f}%")

print(f"\nOD Matrix:")
print(f"  Total OD pairs: {len(od):,}")
print(f"  Total demand: {od['value'].sum():,.0f} tons/day")

print(f"\nFlow Assignment:")
print(f"  Edges with flow: {len(nonzero_flow):,} / {len(edges_out):,}")
print(f"  Total flow on network: {out_flow.sum():,.0f} tons/day")
print(f"  Max edge flow: {out_flow.max():,.0f} tons/day")
print(f"  Mean edge flow (non-zero): {nonzero_flow.mean():,.0f} tons/day")

print(f"\nCapacity:")
print(f"  Total network capacity: {total_capacity:,.0f} tons/day")
print(f"  Utilization: {100*out_flow.sum()/total_capacity:.1f}%")

print("\n" + banner)
print("COMPLETE")
print(banner)