In [2]:
import importlib
from pathlib import Path
import os
import sys
import logging
import geopandas as gpd
import numpy as np
import pandas as pd
import networkx as nx
#import plotly.express as px
import matplotlib.pyplot as plt
parent_dir = str(Path().resolve().parent)
sys.path.append(parent_dir)

In [None]:
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.patches as mpatches

def add_full_trace(fig, px_data, pos):
    """Copy every trace of a plotly-express figure into one subplot cell.

    Parameters
    ----------
    fig : figure built with ``make_subplots``; each trace is added to it.
    px_data : iterable of traces, typically ``px.<chart>(...).data``.
    pos : ``(row, col)`` pair selecting the target subplot cell.
    """
    for single_trace in iter(px_data):
        target_cell = dict(row=pos[0], col=pos[1])
        fig.add_trace(single_trace, **target_cell)

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from disruptsc.parameters import Parameters
from disruptsc.model.model import Model
import disruptsc.paths as paths
import disruptsc.network.transport_network

In [5]:
scope = "Cambodia"
parameters = Parameters.load_parameters(paths.PARAMETER_FOLDER, scope)
parameters.export_files = False
parameters.adjust_logging_behavior()
model = Model(parameters)
parameters.io_cutoff = 0.1
parameters.adaptive_supplier_weight = False
#parameters.logistics['basic_cost_random'] = False
cache_params = (False, False, False, False)
cache_params = (True, True, True, True)

In [6]:
from  disruptsc.network.mrio import Mrio

In [7]:
mrio = Mrio.load_mrio_from_filepath(parameters.filepaths['mrio'], monetary_units="mUSD")

AttributeError: Can only use .str accessor with string values!

In [6]:
import disruptsc.network.transport_network
from disruptsc.network.transport_network import TransportNetwork

In [7]:
model.setup_transport_network(cache_params[0])

2025-05-05 14:00:58,539 - root - INFO - Transport network generated from temp file.
2025-05-05 14:00:58,770 - root - INFO - Total length of transport network is: 269893 km
2025-05-05 14:00:58,771 - root - INFO - maritime: 61760 km
2025-05-05 14:00:58,772 - root - INFO - multimodal: 8 km
2025-05-05 14:00:58,773 - root - INFO - pipelines: 21064 km
2025-05-05 14:00:58,773 - root - INFO - railways: 35336 km
2025-05-05 14:00:58,774 - root - INFO - roads: 151724 km
2025-05-05 14:00:58,778 - root - INFO - Nb of nodes: 8066, nb of edges: 12153


In [8]:
model.setup_agents(cache_params[1])

2025-05-05 14:00:58,941 - root - INFO - Firms, households, and countries generated from temp file.
2025-05-05 14:00:58,942 - root - INFO - Nb firms: 3145
2025-05-05 14:00:58,943 - root - INFO - Nb households: 56
2025-05-05 14:00:58,944 - root - INFO - Nb countries: 11


In [9]:
model.setup_sc_network(cache_params[2])

2025-05-05 14:00:59,389 - root - INFO - Supply chain generated from temp file.


In [10]:
model.set_initial_conditions()

2025-05-05 14:00:59,445 - root - INFO - Setting initial conditions to input-output equilibrium
2025-05-05 14:00:59,446 - root - INFO - Resetting variables on transport network
2025-05-05 14:00:59,485 - root - INFO - Resetting agents and commercial links variables


In [11]:
model.setup_logistic_routes(cache_params[3])

2025-05-05 14:01:01,140 - root - INFO - Logistic routes generated from temp file.


In [13]:
transport_input_share_per_firm = [
            sum([
                model.sc_network[u][v]['weight']
                for u, v in model.sc_network.in_edges(firm)
                if model.sc_network[u][v]['object'].product_type == 'transport'
            ])
            for firm in model.firms.values()
        ]

In [33]:
model.firm_table[['sector', 'sector_type']].drop_duplicates().set_index('sector')['sector_type'].to_dict()


{'ATP': 'transport',
 'BPH': 'manufacturing',
 'B_T': 'manufacturing',
 'CHM': 'manufacturing',
 'CMT': 'manufacturing',
 'CNS': 'construction',
 'CTL': 'agriculture',
 'EEQ': 'manufacturing',
 'ELE': 'manufacturing',
 'FMP': 'manufacturing',
 'FRS': 'agriculture',
 'FSH': 'agriculture',
 'GAS': 'oil_and_gas',
 'GDT': 'utility',
 'I_S': 'manufacturing',
 'LEA': 'manufacturing',
 'LUM': 'manufacturing',
 'MIL': 'agriculture',
 'MVH': 'manufacturing',
 'NFM': 'manufacturing',
 'NMM': 'manufacturing',
 'OAP': 'agriculture',
 'OFD': 'manufacturing',
 'OIL': 'oil_and_gas',
 'OME': 'manufacturing',
 'OMF': 'manufacturing',
 'OMT': 'manufacturing',
 'OTN': 'manufacturing',
 'OTP': 'transport',
 'PCR': 'manufacturing',
 'PDR': 'agriculture',
 'PFB': 'agriculture',
 'PPP': 'manufacturing',
 'P_C': 'manufacturing',
 'RMK': 'agriculture',
 'RPP': 'manufacturing',
 'SER': 'service',
 'SGR': 'agriculture',
 'TEX': 'manufacturing',
 'TRD': 'trade',
 'WAP': 'manufacturing',
 'WHS': 'transport',
 'WOL

In [26]:
model.firm_table['region_sector'].str.split('_')

id
0        [ARM, ATP]
1        [ARM, BPH]
2       [ARM, B, T]
3        [ARM, CHM]
4        [ARM, CMT]
           ...     
3140     [KGZ, RMK]
3141     [TJK, TEX]
3142     [AZE, RMK]
3143     [KGZ, WOL]
3144    [KAZ, P, C]
Name: region_sector, Length: 3145, dtype: object

### Static analysis

In [None]:
EPSILON = 1e-5


def add_borders(ax, selected_countries, countries_path="countries.geojson", margin=1):
    """Draw the borders of the selected countries on ``ax`` and zoom the axes to them.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    selected_countries : collection of codes matched against the ``iso_a3``
        column of the countries file.
    countries_path : path of a file readable by ``gpd.read_file`` containing
        country geometries with an ``iso_a3`` column.
    margin : padding added on every side of the selected bounds, in the
        units of the file's CRS.

    Raises
    ------
    ValueError
        If none of ``selected_countries`` is present in the file (the bounds
        would be NaN and the axis limits could not be set).
    """
    countries = gpd.read_file(countries_path)
    selected_boundaries = countries[countries["iso_a3"].isin(selected_countries)]  # ISO_A3
    if selected_boundaries.empty:
        raise ValueError(
            f"None of the countries {selected_countries!r} were found in {countries_path}"
        )
    selected_boundaries.boundary.plot(ax=ax, color="black", linewidth=1)
    # total_bounds is (minx, miny, maxx, maxy), so min <= max already holds.
    minx, miny, maxx, maxy = selected_boundaries.total_bounds
    ax.set_xlim(minx - margin, maxx + margin)
    ax.set_ylim(miny - margin, maxy + margin)

def calc_usd_km(df, what):
    """Return the sum over rows of ``df[what] * df['km']``.

    Presumably ``what`` is a monetary flow column, so the result is the
    flow weighted by edge length (e.g. USD·km) — unit depends on the inputs.
    """
    weighted = df[what].mul(df['km'])
    return weighted.sum()

def add_edge_id(ax, gdf, quantile=0.995, label_column="id"):
    """Annotate the longest edges of ``gdf`` on ``ax`` with an identifier.

    Only edges whose geometry length is strictly above the ``quantile``-th
    quantile of all edge lengths are labelled, to keep the map readable.
    Edges without a geometry (NaN length) are skipped.

    Parameters
    ----------
    ax : matplotlib Axes on which the edges are drawn.
    gdf : GeoDataFrame of line geometries.
    quantile : value in [0, 1]; raise it to label fewer edges.
    label_column : column whose value is written as the label.
    """
    lengths = gdf.geometry.length
    threshold = lengths.quantile(quantile)
    # Lengths are measured once; a vectorised mask selects the long edges.
    for _, row in gdf[lengths > threshold].iterrows():
        centroid = row.geometry.centroid  # label is anchored at the line's centroid
        ax.annotate(
            text=row[label_column],
            xy=(centroid.x, centroid.y),
            xytext=(0, 3),  # nudge the text slightly so it does not cover the line
            textcoords="offset points",
            fontsize=8,
            color="black",
            ha="center",
            va="center",
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.7),  # background for readability
        )

def plot_route(route, model, ax):
    """Draw, in red on ``ax``, the transport edges listed in ``route.transport_edge_ids``.

    The edges are looked up by label in ``model.transport_edges``.
    """
    route_edges = model.transport_edges.loc[route.transport_edge_ids]
    route_edges.plot(ax=ax, color="red")

def plot_perturbed_link(perturbed_links, time_step, what):
    """Show stacked histograms of ``what`` over the perturbed links of one time step.

    Three side-by-side panels split the histogram by ``shipment_method``,
    ``product_type`` and ``category`` respectively.
    """
    at_step = perturbed_links[perturbed_links['time_step'] == time_step]
    fig = make_subplots(rows=1, cols=3, subplot_titles=["Shipment Method", "Product Type", "Category"])
    for col_idx, split_by in enumerate(['shipment_method', 'product_type', 'category'], start=1):
        histogram = px.histogram(at_step, x=what, color=split_by)
        add_full_trace(fig, histogram.data, (1, col_idx))
    fig.update_layout(title_text="Perturbed commercial relationships", showlegend=False,
                      barmode="stack", height=300, width=800)
    fig.show()


def _flow_per_category(edges, groups, sign=1):
    """Return a (Category, Value) frame of length-weighted flow, one row per group.

    Each value is ``sign * calc_usd_km(edges, 'flow_' + group)``.
    """
    per_cat = {cat: sign * calc_usd_km(edges, 'flow_' + cat) for cat in groups}
    table = pd.Series(per_cat).reset_index()
    table.columns = ["Category", "Value"]
    return table


def plot_pie_chart(flow_dif_on_edges, groups):
    """Show two pie charts splitting flow decreases and increases by category.

    Parameters
    ----------
    flow_dif_on_edges : frame of per-edge flow differences with a
        ``flow_total`` column, a ``km`` column and one ``flow_<group>``
        column per entry of ``groups``.
    groups : category names; ``'flow_' + name`` must be a column.

    Edges whose ``flow_total`` change lies within ±EPSILON are treated as
    noise and appear in neither chart. Values on the decrease side are
    negated so the slices are sized by the magnitude of the decrease.
    """
    increased = flow_dif_on_edges[flow_dif_on_edges["flow_total"] > EPSILON]
    # Symmetric threshold: only changes below -EPSILON count as decreases.
    decreased = flow_dif_on_edges[flow_dif_on_edges["flow_total"] < -EPSILON]
    positive_flow_per_cat = _flow_per_category(increased, groups)
    negative_flow_per_cat = _flow_per_category(decreased, groups, sign=-1)

    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Negative Flow", "Positive Flow"),
                        specs=[[{"type": "domain"}, {"type": "domain"}]])  # Domain for pie charts
    fig.add_trace(px.pie(negative_flow_per_cat, names="Category", values="Value").data[0], row=1, col=1)
    fig.add_trace(px.pie(positive_flow_per_cat, names="Category", values="Value").data[0], row=1, col=2)
    fig.update_layout(title_text="Comparison of Positive and Negative Flow", showlegend=True, height=300)
    fig.show()

In [None]:
# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
add_borders(ax, model.mrio.regions)
model.transport_edges.plot(column="type", legend=True, ax=ax, cmap="tab10")  # Use a colormap like "tab10"*
plt.show()

### Simulation

In [None]:
from disruptsc.disruption.disruption import DisruptionList

parameters.simulation_type = "event"
ecuador_event = {
    "type": "capital_destruction",
    "description_type": "sectors_homogeneous",
    #"attribute": "name",
    'region_sectors': model.mrio.region_sector_names, #['ECU_CAR', 'ECU_CIN'],
    #"values": ["railways-GEO"],
    "destroyed_capital": 2510.0,
    "unit": "mUSD",
    'reconstruction_market': False,
    "start_time": 1,
    "duration": 1
}
eca_event = {
    "type": "transport_disruption",
    "description_type": "edge_attributes",
    "duration": 3,
    "start_time": 1,
    "attribute": "name",
    "values": ["baku"]
}
parameters.events = [ecuador_event]

model.disruption_list = DisruptionList.from_events_parameter(parameters.events,
                                                            parameters.monetary_units_in_model,
                                                            model.transport_edges, model.firm_table,
                                                            model.firms)
if len(model.disruption_list) == 0:
    raise ValueError("No disruption could be read")
logging.info(f"{len(model.disruption_list)} disruption(s) will occur")
model.disruption_list.log_info()

disrupted_edges = [list(event.keys()) for event in model.disruption_list]
disrupted_edges = [item for sublist in disrupted_edges for item in sublist]

In [None]:
from disruptsc.simulation.simulation import Simulation

simulation = Simulation("disruption")
model.set_initial_conditions()

t_final = 180
for t in range(t_final + 1):
    model.run_one_time_step(time_step=t, current_simulation=simulation)

household_loss_per_region = simulation.calculate_household_loss(model.household_table, per_region=True)
household_loss = sum(household_loss_per_region.values())
country_loss = simulation.calculate_country_loss()
print("")
print("======== Simulation terminated ========")
print(f"Household loss: {int(household_loss)} {parameters.monetary_units_in_model}.")
print(f"Country loss: {int(country_loss)} {parameters.monetary_units_in_model}.")

In [None]:
from disruptsc.simulation.simulation import Simulation

simulation = Simulation("stationary_test")
model.set_initial_conditions()
t_final = 1
for t in range(t_final + 1):
    model.run_one_time_step(time_step=t, current_simulation=simulation)

### Change of flows

In [None]:
flow_df = pd.DataFrame(simulation.transport_network_data)
#flow_df = flow_df[flow_df['flow_total'] != 0]
transport_edges_with_flows = {}
for time_step in flow_df['time_step'].unique():
    transport_edges_with_flows[time_step] = pd.merge(
        model.transport_edges, flow_df[flow_df['time_step'] == time_step],
        how="left", on="id")

#transport_edges_with_flows[1].set_index('id').loc[disrupted_edges, "flow_total"]

flow_types = [flow_type for flow_type in flow_df.columns if flow_type[:4] == "flow"]
flow_dif = transport_edges_with_flows[1].set_index('id')[flow_types] - transport_edges_with_flows[0].set_index('id')[flow_types]
flow_dif = flow_dif.dropna(subset=['flow_total'])
flow_dif = flow_dif[abs(flow_dif['flow_total']) > 1e-9]
flow_dif = flow_dif.sort_values('flow_total', ascending=False)
flow_dif_on_edges = pd.merge(flow_dif.reset_index(), model.transport_edges, how="left", on="id")
flow_dif_on_edges = gpd.GeoDataFrame(flow_dif_on_edges, crs=model.transport_edges.crs)
flow_dif_on_edges['change'] = "decrease"
flow_dif_on_edges.loc[flow_dif_on_edges['flow_total'] > 0, 'change'] = "increase"
flow_dif_on_edges.head()

In [None]:
print(flow_dif_on_edges.groupby('change').apply(calc_usd_km, "flow_total", include_groups=False).to_dict())
print("net:", calc_usd_km(flow_dif_on_edges, 'flow_total'))

In [None]:
main_sectors = ['agriculture', 'manufacturing', 'mining', 'import'] #!+ ['oil_and_gas']
plot_pie_chart(flow_dif_on_edges, main_sectors)

In [None]:
flow_types = ['domestic_B2B', 'domestic_B2C', 'import', 'export']
plot_pie_chart(flow_dif_on_edges, flow_types)

In [None]:
transport_edges_with_flows[0].to_file("flows.geojson", driver="GeoJSON")

In [None]:
ee = [7887, 7889]
for u, v, dat in tn2.edges(data=True):
    if (dat['type'] == "railways") and (dat['special'] == "custom"):
        print(u,v, dat)

In [None]:
tn2 = model.transport_network.copy()
tn2.remove_edge(7875, 7876)

In [None]:
for u, v in r1.transport_edges[:]:
    print(u, v, model.transport_network[u][v]['km'], model.transport_network[u][v]['cost_per_ton_0_container'])

In [None]:
for u, v in r2.transport_edges[:]:
    print(u, v, model.transport_network[u][v]['km'], model.transport_network[u][v]['cost_per_ton_0_container'])

In [None]:
parameters = Parameters.load_parameters(paths.PARAMETER_FOLDER, scope)
parameters.add_variability_to_basic_cost()
model.transport_network.ingest_logistic_data(parameters.logistics)

r1 = model.firms[who].choose_route(model.transport_network, model.firms[who].od_point, model.countries['Europe'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
plot_route(r1, model, ax)
print(r1.sum_indicator(model.transport_network, "cost_per_ton_0_container"))
r2 = model.firms[who].choose_route(model.transport_network, model.firms[who].od_point, model.countries['RUS'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
plot_route(r2, model, ax)
print(r2.sum_indicator(model.transport_network, "cost_per_ton_0_container"))
r3 = model.firms[who].choose_route(model.transport_network, model.countries['RUS'].od_point, model.countries['Europe'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
plot_route(r3, model, ax)
print(r3.sum_indicator(model.transport_network, "cost_per_ton_0_container"))
print(r2.sum_indicator(model.transport_network, "cost_per_ton_0_container") + r3.sum_indicator(model.transport_network, "cost_per_ton_0_container"))

In [None]:
who = 928
# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
add_borders(ax, model.mrio.regions)
time_step = 0
transport_edges_with_flows[time_step]["line_width"] = transport_edges_with_flows[time_step]["flow_total"] / transport_edges_with_flows[0]["flow_total"].max() * 5  # Adjust factor for visibility
transport_edges_with_flows[time_step].plot(
    ax=ax,
    linewidth=transport_edges_with_flows[time_step]["line_width"],  
    linestyle="-",
    alpha=0.8,
    capstyle="round"
)
#add_edge_id(ax, transport_edges_with_flows[time_step])
model.transport_nodes.loc[[185, 318, 6929,
                           #model.households['hh_25'].od_point, 
                           model.firms[who].od_point,
                           model.countries['Europe'].od_point]].plot(ax=ax, color="red")
r1 = model.firms[who].choose_route(model.transport_network, model.firms[who].od_point, model.countries['Europe'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
plot_route(r1, model, ax)
print(r1.sum_indicator(model.transport_network, "cost_per_ton_0_container"))
r2 = model.firms[who].choose_route(model.transport_network, model.firms[who].od_point, model.countries['RUS'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
plot_route(r2, model, ax)
print(r2.sum_indicator(model.transport_network, "cost_per_ton_0_container"))
r3 = model.firms[who].choose_route(model.transport_network, model.countries['RUS'].od_point, model.countries['Europe'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
plot_route(r3, model, ax)
print(r3.sum_indicator(model.transport_network, "cost_per_ton_0_container"))
if model.disruption_list:
    if pd.Series(['Transport' in str(type(event)) for event in model.disruption_list]).any():
        disrupted_edges_midpoint = model.transport_edges[model.transport_edges['id'].isin(disrupted_edges)].copy()
        disrupted_edges_midpoint["geometry"] = disrupted_edges_midpoint["geometry"].apply(lambda line: line.interpolate(0.5, normalized=True) if line else None)
        disrupted_edges_midpoint.plot(ax=ax, color="red", markersize=100, marker="o", label="Midpoints", zorder=10)

plt.show()

In [None]:
model.firms[931]

In [None]:
r2 = model.firms[931].choose_route(model.transport_network, model.firms[931].od_point, model.countries['CHN'].od_point, shipment_method="container", capacity_constraint=False, transport_cost_noise_level=0)
r2.transport_modes

In [None]:
from disruptsc.model.basic_functions import rescale_monetary_values, find_nearest_node_id

buying_countries = model.mrio.external_buying_countries
selling_countries = model.mrio.external_selling_countries
country_list = list(set(buying_countries) | set(selling_countries))
country_table = gpd.read_file(parameters.filepaths['region_table']).set_index('region').loc[country_list]
admissible_node_mode = ['roads', 'railways', 'maritime']
potential_nodes = model.transport_nodes[model.transport_nodes['type'].isin(admissible_node_mode)]
country_table['od_point'] = find_nearest_node_id(potential_nodes, country_table)
country_table

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
add_borders(ax, model.mrio.regions)
flow_dif_on_edges["line_width"] = abs(flow_dif_on_edges["flow_total"]) / abs(flow_dif_on_edges["flow_total"]).max() * 5  # Adjust factor for visibility
color_map = {"increase": "green", "decrease": "orange"}
if flow_dif_on_edges.shape[0] > 0:
    for change_type, color in color_map.items():
        cond = flow_dif_on_edges["change"] == change_type
        if cond.any():
            flow_dif_on_edges[cond].plot(
                ax=ax,
                color=color,
                linewidth=flow_dif_on_edges.loc[flow_dif_on_edges["change"] == change_type, "line_width"],  
                linestyle="-",
                alpha=0.8,
                capstyle="round",
                label=f"Change: {change_type}"
            )
if pd.Series(['Transport' in str(type(event)) for event in model.disruption_list]).any():
    disrupted_edges_midpoint = model.transport_edges[model.transport_edges['id'].isin(disrupted_edges)].copy()
    disrupted_edges_midpoint["geometry"] = disrupted_edges_midpoint["geometry"].apply(lambda line: line.interpolate(0.5, normalized=True) if line else None)
    disrupted_edges_midpoint.plot(ax=ax, color="red", markersize=100, marker="o", label="Midpoints", zorder=10)

legend_patches = [
    mpatches.Patch(color="green", label="Increase (+)"),
    mpatches.Patch(color="orange", label="Decrease (-)"),
    mpatches.Patch(color="red", label="Disrupted edges"),
]
ax.legend(handles=legend_patches, title="Legend", loc="upper right")

plt.show()

### Perturbed commercial relationships

In [None]:
  if len (simulation.sc_network_data) == 0:
    print("No perturbed links")

else:
    perturbed_links = pd.merge(pd.DataFrame(simulation.sc_network_data), model.commercial_link_table, how='left', on='pid')
    print((perturbed_links['time_step'] == 0).any())
    print(model.commercial_link_table.shape[0])
    print(perturbed_links.groupby('time_step')['status'].value_counts() / model.commercial_link_table.shape[0])

In [None]:
who = 'RUS'
sum([model.firms[firm_id].order_book[who] for firm_id in model.commercial_link_table.loc[model.commercial_link_table['buyer_id'] == who, 'supplier_id'].to_list()])

In [None]:
model.commercial_link_table.loc[model.commercial_link_table['buyer_id'] == "hh_25", :]

In [None]:
plot_perturbed_link(perturbed_links, 1, "price")

In [None]:
plot_perturbed_link(perturbed_links, 100, "fulfilment_rate")

In [None]:
perturbed_links['undelivered'] = perturbed_links['order'] - perturbed_links['delivery']
fig = make_subplots(rows=1, cols=2,
                    subplot_titles=("Product Type", "Category"),
                    specs=[[{"type": "domain"}, {"type": "domain"}]])  # Domain for pie charts
fig.add_trace(px.pie(perturbed_links, values="undelivered", names="product_type").data[0], row=1, col=1)
fig.add_trace(px.pie(perturbed_links, values="undelivered", names="category").data[0], row=1, col=2)
fig.update_layout(title_text="Undelivered product", showlegend=True, height=300)
fig.show()

In [None]:
perturbed_links.sort_values("undelivered", ascending=False)

In [None]:
perturbed_products = perturbed_links.groupby(['product', 'product_type', 'category', 'time_step'], as_index=False)[['order', 'delivery']].sum()
perturbed_products['baseline_order'] = perturbed_products["product"].map(perturbed_products[perturbed_products['time_step'] == 1].groupby("product")["order"].sum())
perturbed_products['relative_delivery'] = perturbed_products['delivery'] / perturbed_products['baseline_order']
px.line(perturbed_products.groupby(['time_step', 'product'], as_index=False)['relative_delivery'].mean(), x="time_step", y="relative_delivery", color="product")

In [None]:
perturbed_products = perturbed_links.groupby(['product_type', 'category', 'time_step'], as_index=False)[['order', 'delivery']].sum()
perturbed_products['baseline_order'] = perturbed_products["product_type"].map(perturbed_products[perturbed_products['time_step'] == 1].groupby("product_type")["order"].sum())
perturbed_products['relative_delivery'] = perturbed_products['delivery'] / perturbed_products['baseline_order']
px.line(perturbed_products.groupby(['time_step', 'product_type'], as_index=False)['relative_delivery'].mean(), x="time_step", y="relative_delivery", color="product_type")

In [None]:
more_expensive_links = perturbed_links[(perturbed_links['price'] > 1)].copy()
more_expensive_links['extra_spending'] = more_expensive_links['delivery'] * (more_expensive_links['price'] - 1)
more_expensive_links['household_region'] = more_expensive_links['buyer_id'].map(model.household_table.set_index('household')['region'])
more_expensive_links['product_origin'] = more_expensive_links['product'].apply(lambda s: s.split('_')[0])
more_expensive_links = more_expensive_links.dropna(subset='household_region')
more_expensive_links.groupby(['product_origin', 'product_type', 'household_region'])['extra_spending'].sum()

In [None]:
unfulfilled_links = perturbed_links[(perturbed_links['fulfilment_rate'] < 1)].copy()
unfulfilled_links['consumption_loss'] = unfulfilled_links['order'] - unfulfilled_links['delivery']
unfulfilled_links['household_region'] = unfulfilled_links['buyer_id'].map(model.household_table.set_index('household')['region'])
unfulfilled_links['product_origin'] = unfulfilled_links['product'].apply(lambda s: s.split('_')[0])
unfulfilled_links = unfulfilled_links.dropna(subset='household_region')
unfulfilled_links.groupby(['product_origin', 'product_type', 'household_region'])['consumption_loss'].sum()

In [None]:
import matplotlib.patches as mpatches
from shapely.geometry import LineString

what = "fulfilment_rate"  # price
threshold = 1.05
time_step = 1

# Create one straight LineString per commercial link routed on the transport
# network, from its "from" node to its "to" node.
perturbed_links_gdf = perturbed_links[perturbed_links['use_transport_network']].copy()
perturbed_links_gdf = perturbed_links_gdf[perturbed_links_gdf['time_step'] == time_step]
node_geometry = model.transport_nodes["geometry"]
# A list comprehension (instead of DataFrame.apply(axis=1)) also works when no
# link is left for this time step.
perturbed_links_gdf["geometry"] = [
    LineString([node_geometry.loc[origin], node_geometry.loc[destination]])
    for origin, destination in zip(perturbed_links_gdf["from"], perturbed_links_gdf["to"])
]
perturbed_links_gdf = gpd.GeoDataFrame(perturbed_links_gdf, geometry="geometry", crs=model.transport_nodes.crs)

# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
add_borders(ax, model.mrio.regions)
# Width from the absolute change, so decreases (negative flow_total) also get a
# positive width, as in the flow-change map above.
flow_dif_on_edges["line_width"] = flow_dif_on_edges["flow_total"].abs() / flow_dif_on_edges["flow_total"].abs().max() * 1  # Adjust factor for visibility
color_map = {"increase": "green", "decrease": "orange"}
for change_type, color in color_map.items():
    cond_change = flow_dif_on_edges["change"] == change_type
    if cond_change.any():  # nothing to draw for an empty subset
        flow_dif_on_edges[cond_change].plot(
            ax=ax,
            color=color,
            linewidth=flow_dif_on_edges.loc[cond_change, "line_width"],
            linestyle="-",
            alpha=0.8,
            capstyle="round",
            label=f"Change: {change_type}"
        )
disrupted_edges_midpoint = model.transport_edges[model.transport_edges['id'].isin(disrupted_edges)].copy()
disrupted_edges_midpoint["geometry"] = disrupted_edges_midpoint["geometry"].apply(lambda line: line.interpolate(0.5, normalized=True) if line else None)
disrupted_edges_midpoint.plot(ax=ax, color="red", markersize=100, marker="o", label="Midpoints", zorder=10)

cond = perturbed_links_gdf['fulfilment_rate'] < 0.95
#cond = perturbed_links_gdf[what] > threshold
if cond.any():
    perturbed_links_gdf[cond].plot(ax=ax, color="red", linewidth=2, label="Connections")

legend_patches = [
    mpatches.Patch(color="green", label="Increase (+)"),
    mpatches.Patch(color="orange", label="Decrease (-)"),
    mpatches.Patch(color="red", label="Disrupted edges"),
]
ax.legend(handles=legend_patches, title="Legend", loc="upper right")

plt.show()

### Perturbed agents

In [None]:
def plot_impact(df, relative=True, agg=True):
    """Plot extra spending and consumption loss over time for countries or households.

    The agent type is detected from the columns of ``df``:
    - a ``country`` column selects countries, with baseline ``spending`` at time step 0;
    - otherwise a ``household`` column selects households, with baseline
      ``tot_consumption`` at time step 0.

    Parameters
    ----------
    df : frame with ``time_step``, ``extra_spending``, ``consumption_loss``,
        the agent column and its baseline column. It is not modified.
    relative : if True, draw lines of ``1 - value / baseline``; otherwise
        draw stacked bars of the absolute values.
    agg : if True, sum over agents first and draw a single series per panel;
        otherwise draw one series per agent.

    Raises
    ------
    ValueError
        If ``df`` has neither a ``country`` nor a ``household`` column.
    """
    if "country" in df.columns:
        agent, baseline = "country", "spending"
    elif "household" in df.columns:
        agent, baseline = "household", "tot_consumption"
    else:
        raise ValueError("plot_impact expects a 'country' or a 'household' column")
    # Work on a copy so the caller's frame does not gain helper columns.
    df = df.copy()
    df['baseline_spending'] = df[agent].map(df[df['time_step'] == 0].set_index(agent)[baseline])
    if agg:
        df = df.groupby('time_step', as_index=False)[['baseline_spending', "extra_spending", "consumption_loss"]].sum()
        # One aggregated series: colouring by time_step would make px.line emit a
        # separate single-point trace per step, which draws no visible line.
        color_col = None
    else:
        color_col = agent
    # NOTE(review): these are 1 - ratio (1 = no impact), unlike the plain ratio
    # used in the country sandbox cell; confirm which convention is intended.
    df['relative_extra_spending'] = 1 - df['extra_spending'] / df['baseline_spending']
    df['relative_consumption_loss'] = 1 - df['consumption_loss'] / df['baseline_spending']
    fig = make_subplots(rows=1, cols=2, subplot_titles=("Extra spending", "Consumption loss"))
    if relative:
        add_full_trace(fig, px.line(df, x="time_step", y="relative_extra_spending", color=color_col).data, (1,1))
        add_full_trace(fig, px.line(df, x="time_step", y="relative_consumption_loss", color=color_col).data, (1,2))
        fig.update_layout(title_text="Relative impacts", showlegend=False, height=300)
    else:
        add_full_trace(fig, px.bar(df, x="time_step", y="extra_spending", color=color_col).data, (1,1))
        add_full_trace(fig, px.bar(df, x="time_step", y="consumption_loss", color=color_col).data, (1,2))
        fig.update_layout(title_text="Impacts", showlegend=False, height=300, barmode="stack")
    fig.show()

In [None]:
country_df = pd.DataFrame(simulation.country_data)
household_df = pd.DataFrame(simulation.household_data)
model.household_table['household'] = 'hh_' + model.household_table['id'].astype(str)
household_df['region'] = household_df['household'].map(model.household_table.set_index('household')["region"])

In [None]:
household_df.groupby(['region'])[['extra_spending', 'consumption_loss']].sum()

In [None]:
plot_impact(country_df, relative=True)

In [None]:
periods = [30, 90, 180]
household_df['total_loss'] = household_df['extra_spending'] + household_df['consumption_loss']
ts = household_df.groupby('time_step')['total_loss'].sum()
baseline = household_df.loc[household_df['time_step'] == 0, 'tot_consumption'].sum()
res2 = {
    period: ts[:period].sum() / (baseline * period)
    for period in periods
}
res2

In [None]:
periods = [30, 90, 180]
household_df['total_loss'] = household_df['extra_spending'] + household_df['consumption_loss']
ts = household_df.groupby('time_step')['total_loss'].sum()
baseline = household_df.loc[household_df['time_step'] == 0, 'tot_consumption'].sum()
res2 = {
    period: ts[:period].sum() / (baseline * period)
    for period in periods
}
res2

In [None]:
res1

In [None]:
periods = [30, 90, 180]
household_df['total_loss'] = household_df['extra_spending'] + household_df['consumption_loss']
ts = household_df.groupby('time_step')['total_loss'].sum()
baseline = household_df.loc[household_df['time_step'] == 0, 'tot_consumption'].sum()
res1 = {
    period: ts[:period].sum() / (baseline * period)
    for period in periods
}

In [None]:
plot_impact(household_df, relative=False, agg=True)

In [None]:
plot_impact(household_df, relative=True, agg=False)

In [None]:
firm_df = pd.DataFrame(simulation.firm_data)
model.firm_table['firm'] = model.firm_table['id']
firm_df = firm_df.merge(model.firm_table[['firm', 'region', 'sector', "region_sector", 'sector_type']], on="firm", how="left")
firm_df['baseline_production'] = firm_df["firm"].map(firm_df[firm_df['time_step'] == 0].set_index("firm")["production"])
firm_df['relative_production'] = firm_df['production'] / firm_df['baseline_production']

In [None]:
sector_df = firm_df.groupby(['region_sector', 'region', "time_step"])['production'].sum().reset_index()
sector_df['baseline_production'] = sector_df["region_sector"].map(sector_df[sector_df['time_step'] == 0].set_index("region_sector")["production"])
sector_df['relative_production'] = sector_df['production'] / sector_df['baseline_production']
sector_df = sector_df.sort_values('time_step')

In [None]:
px.line(sector_df, x="time_step", y="relative_production", color="region_sector")

In [None]:
px.line(firm_df, x="time_step", y="relative_production", color="firm")

In [None]:
inventory_df = [(row['region_sector'], row['time_step'], input_name, inventory) 
                for _, row in firm_df.iterrows()
                for input_name, inventory in row['inventory_duration'].items()]
inventory_df = pd.DataFrame(inventory_df, columns=['region_sector', 'time_step', 'input', 'inventory'])
inventory_df = inventory_df.groupby(['region_sector', 'time_step', 'input'], as_index=False)['inventory'].sum()
inventory_df['id'] = inventory_df['input'] + "->" + inventory_df['region_sector']
inventory_df['baseline'] = inventory_df["id"].map(inventory_df[inventory_df['time_step'] == 0].set_index("id")["inventory"])
inventory_df['relative_inventory'] = inventory_df['inventory'] / inventory_df['baseline']
print(inventory_df.shape)
inventory_df.head()

In [None]:
px.line(inventory_df, x="time_step", y="relative_inventory", color="id")

# Sandbox

In [None]:
capital = pd.concat([pd.DataFrame(model.firms.get_properties('sector'), index=[0]).transpose(), pd.DataFrame(model.firms.get_properties('capital_initial'), index=[0]).transpose()], axis=1)
capital.columns = ["sector", "capital"]
capital = capital.groupby('sector')['capital'].sum()
capital.head()

In [None]:
output = pd.concat([pd.DataFrame(model.firms.get_properties('sector'), index=[0]).transpose(), pd.DataFrame(model.firms.get_properties('eq_production'), index=[0]).transpose()], axis=1)
output.columns = ["sector", "output"]
output = output.groupby('sector')['output'].sum() * 365
output.head()

In [None]:
firm = model.firms[461]
print(firm.eq_production, firm.production, firm.rationing, firm.total_order, firm.production_target)
df = pd.DataFrame({"eq_needs": firm.eq_needs, "input_needs": firm.input_needs, "inventory": firm.inventory, "input_mix": firm.input_mix})
df

In [None]:
df['inventory'] / df['input_mix']

In [None]:
input_used = {input_id: firm.production * mix for input_id, mix in firm.input_mix.items()}
input_used

In [None]:
{input_id: quantity - input_used[input_id] for input_id, quantity in firm.inventory.items()}

In [None]:
firm.suppliers

In [None]:
model.firms.select_by_property('region_sector', ['ECU_AYG'])

In [None]:
inventory_df[inventory_df['inventory'] < 1e-3]

In [None]:
country_df = pd.DataFrame(simulation.country_data)
country_df['baseline_spending'] = country_df['country'].map(country_df[country_df['time_step'] == 0].set_index('country')['spending'])
country_df['relative_extra_spending'] = country_df['extra_spending'] / country_df['baseline_spending']
country_df['relative_consumption_loss'] = country_df['consumption_loss'] / country_df['baseline_spending']
fig = make_subplots(rows=1, cols=2,
                    subplot_titles=("Extra spending", "Consumption loss"))
fig_extra = px.bar(country_df, x="time_step", y="relative_extra_spending", color="country")
for trace in fig_extra.data:
    fig.add_trace(trace, row=1, col=1)
fig_loss = px.bar(country_df, x="time_step", y="relative_consumption_loss", color="country")
for trace in fig_loss.data:
    fig.add_trace(trace, row=1, col=2)
fig.update_layout(title_text="Country impacts", showlegend=True, height=300, barmode="stack")
fig.show()

# Test transport connectivity

In [None]:
from itertools import combinations, product
import networkx as nx
from tqdm import tqdm


def identify_unconnected_pairs(graph, pairs):
    """Return the node pairs that have no path between them in ``graph``.

    Parameters
    ----------
    graph : networkx graph
        Network to check, e.g. ``model.transport_network``.
    pairs : iterable of (node1, node2)
        Pairs to test. Any iterable is accepted; it is materialised once.

    Returns
    -------
    list of (node1, node2)
        Pairs, in input order, with no path from ``node1`` to ``node2``. A pair containing
        a node that is absent from ``graph`` is reported as unconnected rather than raising.
    """
    pairs = list(pairs)
    print(f"There are {len(pairs)} pairs of nodes")
    # Undirected graphs: label every node with its connected component once, so each pair
    # check is a dict lookup instead of a fresh path search over the whole network.
    component_of = None
    if not graph.is_directed():
        component_of = {
            node: component_id
            for component_id, component in enumerate(nx.connected_components(graph))
            for node in component
        }
    pairs_not_connected = []
    for node1, node2 in tqdm(pairs, desc="Processing pairs", total=len(pairs)):
        if node1 not in graph or node2 not in graph:
            # nx.has_path would raise NodeNotFound; a missing OD node is a disconnection.
            connected = False
        elif component_of is not None:
            connected = component_of[node1] == component_of[node2]
        else:
            # Directed graphs: reachability is direction-dependent, so search per pair.
            connected = nx.has_path(graph, node1, node2)
        if not connected:
            pairs_not_connected.append((node1, node2))
    print(f"There are {len(pairs_not_connected)} disconnected pairs")
    return pairs_not_connected


def identify_unconnected_pairs_two_sets(graph, set1, set2):
    """Return the pairs in ``set1 x set2`` (Cartesian product) that are unconnected in ``graph``."""
    pairs = list(product(set1, set2))
    return identify_unconnected_pairs(graph, pairs)


def identify_unconnected_pairs_one_set(graph, set1):
    """Return the unordered pairs of distinct elements of ``set1`` that are unconnected in ``graph``."""
    pairs = list(combinations(set1, 2))
    return identify_unconnected_pairs(graph, pairs)

In [None]:
# Country OD points that cannot reach each other on the transport network.
countries_not_connected = identify_unconnected_pairs_one_set(
    graph=model.transport_network,
    set1={*model.countries.get_properties('od_point').values()},
)

In [None]:
# (country OD point, household OD point) pairs with no transport path between them.
countries_households_not_connected = identify_unconnected_pairs_two_sets(
    graph=model.transport_network,
    set1={*model.countries.get_properties('od_point').values()},
    set2=set(model.household_table['od_point'].unique()),
)

In [None]:
# (firm OD point, household OD point) pairs with no transport path between them.
firms_households_not_connected = identify_unconnected_pairs_two_sets(
    graph=model.transport_network,
    set1=set(model.firm_table['od_point'].unique()),
    set2=set(model.household_table['od_point'].unique()),
)

In [None]:
# (firm OD point, country OD point) pairs with no transport path between them.
firms_countries_not_connected = identify_unconnected_pairs_two_sets(
    graph=model.transport_network,
    set1=set(model.firm_table['od_point'].unique()),
    set2={*model.countries.get_properties('od_point').values()},
)

In [None]:
# Firm OD points that cannot reach each other on the transport network.
firms_not_connected = identify_unconnected_pairs_one_set(
    graph=model.transport_network,
    set1=set(model.firm_table['od_point'].unique()),
)

In [None]:
# Zoom the transport network onto the two end nodes of one trip, to inspect why they
# may not be connected.
trip = (220, 1453)  # transport node ids of the trip's two ends
margin = 1  # padding around the trip's bounding box, in the layer's CRS units (TODO confirm units)
nodes = model.transport_nodes[model.transport_nodes['id'].isin(trip)]
# An empty selection would make total_bounds NaN and set_xlim fail with an unclear error.
assert not nodes.empty, f"none of the nodes {trip} are in model.transport_nodes"
minx, miny, maxx, maxy = nodes.total_bounds  # already ordered: min <= max

fig, ax = plt.subplots(figsize=(10, 8))
model.transport_edges.plot(column="type", legend=True, ax=ax, cmap="tab10")  # one colour per edge type
# Black keeps the trip nodes distinguishable: the default point colour is tab10's first colour.
nodes.plot(ax=ax, color="black")
ax.set_xlim(minx - margin, maxx + margin)
ax.set_ylim(miny - margin, maxy + margin)
ax.set_title(f"Transport edges around trip {trip[0]} -> {trip[1]}")
plt.show()

In [None]:
# get nodes disconnected to many pairs
import pandas as pd
# Rank transport nodes by how many firm-firm disconnected pairs they appear in.
# A distinct name (not `df`) keeps the firm-inspection table from being overwritten, and
# explicit columns keep this working when no disconnected pairs were found.
disconnected_pairs_df = pd.DataFrame(firms_not_connected, columns=["origin", "destination"])
origin_counts = disconnected_pairs_df["origin"].value_counts()
print(str(origin_counts.iloc[:30].index.to_list()))
print(origin_counts.index)
pd.concat([disconnected_pairs_df["origin"], disconnected_pairs_df["destination"]]).value_counts().iloc[:10]

In [None]:
# (country OD point, firm OD point) pairs without a direct edge in `G`.
# NOTE(review): `G` is not defined in the visible part of this notebook, and `G.has_edge`
# tests for a direct edge rather than a path — confirm what `G` is meant to be.
country_od_points = set(model.country_table['od_point'].unique())
firm_od_points = set(model.firm_table['od_point'].unique())
country_firm_pairs = product(country_od_points, firm_od_points)
# Comprehension variables do not leak, so the inspected `firm` from earlier stays intact.
country_firm_not_connected = [
    (origin, destination)
    for origin, destination in tqdm(country_firm_pairs, desc="Processing pairs",
                                    total=len(country_od_points) * len(firm_od_points))
    if not G.has_edge(origin, destination)
]

In [None]:
# (country OD point, household OD point) pairs without a direct edge in `G`.
# NOTE(review): `G` is not defined in the visible part of this notebook, and `G.has_edge`
# tests for a direct edge rather than a path — confirm what `G` is meant to be.
country_od_points = set(model.country_table['od_point'].unique())
household_od_points = set(model.household_table['od_point'].unique())
country_household_pairs = product(country_od_points, household_od_points)
# Comprehension variables do not leak, so the inspected `firm` from earlier stays intact.
country_household_not_connected = [
    (origin, destination)
    for origin, destination in tqdm(country_household_pairs, desc="Processing pairs",
                                    total=len(country_od_points) * len(household_od_points))
    if not G.has_edge(origin, destination)
]

In [None]:
# (firm OD point, household OD point) pairs without a direct edge in `G`.
# NOTE(review): `G` is not defined in the visible part of this notebook, and `G.has_edge`
# tests for a direct edge rather than a path — confirm what `G` is meant to be.
firm_od_points = set(model.firm_table['od_point'].unique())
household_od_points = set(model.household_table['od_point'].unique())
firm_household_pairs = product(firm_od_points, household_od_points)
# Comprehension variables do not leak, so the inspected `firm` from earlier stays intact.
firm_household_not_connected = [
    (origin, destination)
    for origin, destination in tqdm(firm_household_pairs, desc="Processing pairs",
                                    total=len(firm_od_points) * len(household_od_points))
    if not G.has_edge(origin, destination)
]

In [None]:
# Display the transport edges table (consider `.head()` to keep the saved output small).
model.transport_edges

In [None]:
# Disruption events configured in the parameters; display the first one.
events = model.parameters.events
first_event = events[0]
first_event

In [None]:
from disruptsc.disruption.disruption import DisruptionList, TransportDisruption

# Build the disruption list from all configured events. Positional arguments passed:
# events, monetary units, transport edges, firm table, firms.
DisruptionList.from_events_parameter(
    model.parameters.events,
    model.parameters.monetary_units_in_model,
    model.transport_edges,
    model.firm_table,
    model.firms,
)

In [None]:
# Build a TransportDisruption for the first event only, from the transport edges and the
# attribute/values given in that event.
disruption_object = TransportDisruption.from_edge_attributes(
    edges=model.transport_edges,
    values=events[0]['values'],
    attribute=events[0]['attribute'],
)
disruption_object

In [None]:
# One boolean Series per configured value: True where the edge's attribute string
# contains that value (str.contains with its default regex matching).
condition = list(map(
    lambda value: model.transport_edges[events[0]['attribute']].str.contains(value),
    events[0]['values'],
))
condition

In [None]:

# `condition` comes from the previous cell as a list with one boolean Series per value.
# Combine them column-wise and keep the edges that match ANY value. Uses the `pd` alias
# imported at the top of the notebook.
# NOTE: this rebinds `condition`, so re-run the previous cell before re-running this one.
condition = pd.concat(condition, axis=1).any(axis=1)