# Control groups
Inspect what our control groups actually look like. Given we have so many degrees of freedom, this entails charting a bunch of separate lines.

We need to view the following:

US imports:

    - Tariffed vs non-tariffed goods
    - From China and RoW

China exports:

    - Tariffed vs non-tariffed goods


UK imports:

    - Tariffed vs non-tariffed goods
    - From China and RoW

## Setup

In [1]:
import polars as pl
import plotly.express as px
import pandas as pd
from typing import List

import plotly.graph_objs as go
from plotly.subplots import make_subplots


In [2]:
# 1. Load data
# Lazily scan the unified trade/tariff parquet data at the path below; rows are
# only read when a query built on `unified_lf` is collected.
unified_lf_path = (
    "/Users/lukasalemu/Documents/00. Bank of England/03. MPIL/"
    "tariff_trade_analysis/data/final/unified_trade_tariff_partitioned"
)
unified_lf = pl.scan_parquet(unified_lf_path)

In [3]:
# Load the tariff table (cm_us_tariffs.csv) and derive the list of product
# codes whose "Effective Date" is after 1 January 2019. Later cells use
# `tariffed_products` to label each product as tariffed / non-tariffed.
HM_tariffs: pl.DataFrame = pl.read_csv(
    "/Users/lukasalemu/Documents/00. Bank of England/03. MPIL/tariff_trade_analysis/data/intermediate/cm_us_tariffs.csv",
    try_parse_dates=True,
)

# Compare against a native polars date literal (no pandas object needed) and
# keep the filtered frame under its own name, so `tariffed_products` is only
# ever bound to the final list of codes.
recent_tariffs = HM_tariffs.filter(
    pl.col("Effective Date") > pl.date(2019, 1, 1)
)

# Product codes in `unified_lf` are 6-character strings. If the CSV column was
# inferred as an integer, codes with a leading zero lose it in the cast to
# string and would never match in `is_in`; pad back to six characters.
# Codes that already have six or more characters are left unchanged.
tariffed_products: List[str] = (
    recent_tariffs["product_code"].cast(pl.Utf8).str.zfill(6).to_list()
)

In [4]:
# Preview the first five rows to check the columns and dtypes of the dataset.
unified_lf.head(5).collect()

year,reporter_country,partner_country,product_code,value,quantity,tariff_rate_pref,min_rate_pref,max_rate_pref,tariff_rate_mfn,min_rate_mfn,max_rate_mfn,average_tariff,unit_value,value_global_trend,value_detrended,quantity_global_trend,quantity_detrended,price_global_trend,unit_value_detrended,official_tariff,average_tariff_official,value_global_trend_right,quantity_global_trend_right,price_global_trend_right
str,str,str,str,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""1995""","""250""","""566""","""850720""",130.979004,10.187,0.0,0.0,0.0,57.509998,57.509998,57.509998,0.0,12.857466,2.925173,128.053833,2.925173,7.261827,2.925173,9.932293,0.0,0.0,2.925173,2.925174,2.925173
"""1995""","""250""","""608""","""620193""",83.441002,0.699,11.7,11.7,11.7,,,,11.7,119.371964,22.935419,60.505585,22.935429,-22.236429,22.935429,96.436539,0.0,11.7,22.935425,22.935425,22.935425
"""1995""","""276""","""686""","""300490""",22.0,1.2,0.0,0.0,0.0,,,,0.0,18.333332,41.930824,-19.930824,41.930832,-40.730831,41.930824,-23.597492,0.0,0.0,41.930836,41.930828,41.930828
"""1995""","""250""","""392""","""400610""",1.604,0.097,,,,3.9,3.9,3.9,3.9,16.536081,2.697191,-1.093191,2.697191,-2.600191,2.69719,13.838892,0.0,3.9,2.69719,2.69719,2.69719
"""1995""","""250""","""446""","""880330""",155.449997,0.419,0.0,0.0,0.0,,,,0.0,371.00238,495.343719,-339.893738,495.343781,-494.924774,495.343842,-124.341461,0.0,0.0,495.343658,495.343781,495.343781


In [5]:
# Exclude aluminium and steel products from every chart below: build the list
# of their 6-digit product codes, then filter those rows out of `unified_lf`.

# Steel products
_steel_codes = [
    # 720610 through 721650
    "720610", "720690", "720711", "720712", "720719", "720720", "720810",
    "720825", "720826", "720827", "720836", "720837", "720838", "720839",
    "720840", "720851", "720852", "720853", "720854", "720890", "720915",
    "720916", "720917", "720918", "720925", "720926", "720927", "720928",
    "720990", "721011", "721012", "721020", "721030", "721041", "721049",
    "721050", "721061", "721069", "721070", "721090", "721113", "721114",
    "721119", "721123", "721129", "721190", "721210", "721220", "721230",
    "721240", "721250", "721260", "721310", "721320", "721391", "721399",
    "721410", "721420", "721430", "721491", "721499", "721510", "721550",
    "721590", "721610", "721621", "721622", "721631", "721632", "721633",
    "721640", "721650",

    # 721699 through 730110
    "721699", "721710", "721720", "721730", "721790", "721810", "721891",
    "721899", "721911", "721912", "721913", "721914", "721921", "721922",
    "721923", "721924", "721931", "721932", "721933", "721934", "721935",
    "721990", "722011", "722012", "722020", "722090", "722100", "722211",
    "722219", "722220", "722230", "722240", "722300", "722410", "722490",
    "722511", "722519", "722530", "722540", "722550", "722591", "722592",
    "722599", "722611", "722619", "722620", "722691", "722692", "722699",
    "722710", "722720", "722790", "722810", "722820", "722830", "722840",
    "722850", "722860", "722870", "722880", "722920", "722990", "730110",

    # 730210
    "730210",

    # 730240 through 730290
    "730240", "730290",

    # 730410 through 730690
    "730411", "730419", "730422", "730423", "730424", "730429", "730431",
    "730439", "730441", "730449", "730451", "730459", "730490", "730511",
    "730512", "730519", "730520", "730531", "730539", "730590", "730611",
    "730619", "730621", "730629", "730630", "730640", "730650", "730661",
    "730669", "730690",
]

# Aluminium products
_aluminium_codes = [
    # 7601 (Unwrought aluminum)
    "760110", "760120",

    # 7604 (Aluminum bars, rods, and profiles)
    "760410", "760421", "760429",

    # 7605 (Aluminum wire)
    "760511", "760519", "760521", "760529",

    # 7606 (Aluminum plates, sheets, and strip)
    "760611", "760612", "760691", "760692",

    # 7607 (Aluminum foil)
    "760711", "760719", "760720",

    # 7608 (Aluminum tubes and pipes)
    "760810", "760820",

    # 7609 (Aluminum tube or pipe fittings)
    "760900",
]

# Steel codes first, then aluminium, in the order listed above.
alu_steel_product_codes = [*_steel_codes, *_aluminium_codes]

# Keep only the rows whose product code is not in the exclusion list.
is_alu_or_steel = pl.col('product_code').is_in(alu_steel_product_codes)
unified_lf = unified_lf.filter(~is_alu_or_steel)


## US total imports vs RoW imports - tariffed vs non-tariffed
These are total imports, **not** imports specifically from China. This is about general demand, and shows one of the diff-in-diff comparisons we're doing here.

In [6]:
# Chart indexed value, quantity and unit value for USA vs. Rest of World,
# split by whether the product is in `tariffed_products`.
#
# NOTE(review): this cell labels rows with reporter_country == USA as US
# *imports*, whereas the Chinese-export cells in this notebook treat
# reporter_country as the *exporter*. Confirm which direction the
# reporter/partner columns encode in the unified dataset.
USA_CODE = '840'

series_keys = ['region', 'tariff_status']
years_in_scope = [str(year) for year in range(2017, 2021)]

# Row-level classification flags
is_tariffed = pl.col("product_code").is_in(tariffed_products)
is_usa = pl.col("reporter_country") == USA_CODE

# Restrict to the analysis years and attach the two grouping labels
classified_lf = (
    unified_lf
    .filter(pl.col("year").is_in(years_in_scope))
    .with_columns(
        tariff_status=pl.when(is_tariffed)
                        .then(pl.lit("Tariffed"))
                        .otherwise(pl.lit("Non-Tariffed")),
        region=pl.when(is_usa)
                 .then(pl.lit("USA"))
                 .otherwise(pl.lit("Rest of World")),
    )
)

# Yearly totals per series, plus unit value = value / quantity
# (null and NaN ratios are replaced with 0)
aggregated_df = (
    classified_lf
    .group_by(series_keys + ['year'])
    .agg(
        pl.sum("value").alias("value"),
        pl.sum("quantity").alias("quantity"),
    )
    .sort('year')
    .with_columns(
        unit_value=(pl.col("value") / pl.col("quantity")).fill_null(0).fill_nan(0)
    )
)

# Rescale every metric so that each series' 2017 value equals 100
indexed_df = aggregated_df.with_columns(
    **{
        f"{metric}_indexed": (
            pl.col(metric)
            / pl.col(metric).filter(pl.col("year") == "2017").first().over(series_keys)
            * 100
        )
        for metric in ("value", "quantity", "unit_value")
    }
).collect()

# One subplot row per indexed metric: column name -> subplot title
panels = {
    "value_indexed": "Indexed Import Value",
    "quantity_indexed": "Indexed Import Quantity",
    "unit_value_indexed": "Indexed Import Unit Value",
}

fig = make_subplots(
    rows=3, cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=tuple(panels.values()),
)

# Colour encodes tariff status, dash pattern encodes region
styles = {
    ('USA', 'Tariffed'): {'color': 'crimson', 'dash': 'solid'},
    ('USA', 'Non-Tariffed'): {'color': 'royalblue', 'dash': 'solid'},
    ('Rest of World', 'Tariffed'): {'color': 'crimson', 'dash': 'dot'},
    ('Rest of World', 'Non-Tariffed'): {'color': 'royalblue', 'dash': 'dot'},
}

for (region_name, status_name), line_style in styles.items():
    series_df = indexed_df.filter(
        (pl.col("region") == region_name) & (pl.col("tariff_status") == status_name)
    )
    if series_df.is_empty():
        continue

    label = f"{region_name} - {status_name}"

    for row_no, metric_col in enumerate(panels, start=1):
        # Only the first row's trace gets a legend entry; the others share its
        # legend group so toggling the entry affects all three rows.
        legend_kwargs = {} if row_no == 1 else {"showlegend": False}
        fig.add_trace(
            go.Scatter(
                x=series_df["year"], y=series_df[metric_col],
                name=label, legendgroup=label,
                mode='lines', line=line_style,
                **legend_kwargs,
            ),
            row=row_no, col=1,
        )

# Layout and axis labels
fig.update_layout(
    height=800,
    title_text='Import Trends: USA vs. Rest of World',
    legend_title_text='Region & Product Type',
    hovermode='x unified',
    hoversubplots='axis',
)
for row_no in (1, 2, 3):
    fig.update_yaxes(title_text="Indexed to 2017", row=row_no, col=1)
fig.update_xaxes(title_text="Year", row=3, col=1)

fig.show()

## Chinese exports - US vs RoW, Tariffed vs non-Tariffed goods

Exports from China to the US and to the RoW - comparing over those goods which were tariffed and those which weren't

In [7]:
# Chart indexed value, quantity and unit value of rows reported by China,
# split by partner (USA vs. everyone else) and by whether the product is in
# `tariffed_products`.
USA_CODE = "840"
CHINA_CODE = "156"

# Years plotted. The chart title and the index base year are derived from this
# list so the labels cannot drift from the data actually shown (the title used
# to claim 2017-2023 while only 2017-2020 was plotted).
years = [str(y) for y in range(2017, 2021)]
base_year = years[0]
series_keys = ['destination', 'tariff_status']

# Prepare data: filter for Chinese exports and classify destinations/products
chinese_exports_lf = unified_lf.filter(
    (pl.col("reporter_country") == CHINA_CODE) &
    (pl.col("year").is_in(years))
).with_columns(
    tariff_status=pl.when(pl.col("product_code").is_in(tariffed_products))
                   .then(pl.lit("Tariffed by US"))
                   .otherwise(pl.lit("Non-Tariffed by US")),
    destination=pl.when(pl.col("partner_country") == USA_CODE)
                 .then(pl.lit("USA"))
                 .otherwise(pl.lit("Rest of World"))
)

# Yearly totals per (destination, tariff status), plus unit value
aggregated_df = chinese_exports_lf.group_by(
    series_keys + ['year']
).agg(
    pl.sum("value").alias("value"),
    pl.sum("quantity").alias("quantity")
).sort('year').with_columns(
    # Unit value = value / quantity; null and NaN ratios are replaced with 0
    unit_value=(pl.col("value") / pl.col("quantity")).fill_null(0).fill_nan(0)
)

# Rescale every metric so that each series' base-year value equals 100
indexed_df = aggregated_df.with_columns(
    **{
        f"{metric}_indexed": (
            pl.col(metric)
            / pl.col(metric).filter(pl.col("year") == base_year).first().over(series_keys)
            * 100
        )
        for metric in ("value", "quantity", "unit_value")
    }
).collect()

# One subplot row per indexed metric: column name -> subplot title
panels = {
    "value_indexed": "Indexed Export Value",
    "quantity_indexed": "Indexed Export Quantity",
    "unit_value_indexed": "Indexed Export Unit Value",
}

fig = make_subplots(
    rows=len(panels), cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=tuple(panels.values())
)

# Colour encodes tariff status, dash pattern encodes destination
styles = {
    ('USA', 'Tariffed by US'): {'color': 'crimson', 'dash': 'solid'},
    ('USA', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'solid'},
    ('Rest of World', 'Tariffed by US'): {'color': 'crimson', 'dash': 'dot'},
    ('Rest of World', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'dot'},
}

# Add one trace per (series, metric) to the matching subplot row
for (destination, status), style in styles.items():
    plot_data = indexed_df.filter(
        (pl.col("destination") == destination) & (pl.col("tariff_status") == status)
    )
    if plot_data.is_empty():
        continue

    legend_name = f"{destination} - {status}"

    for row, metric_col in enumerate(panels, start=1):
        # Only the first row's trace gets a legend entry; the others share its
        # legend group so toggling the entry affects all three rows.
        legend_kwargs = {} if row == 1 else {"showlegend": False}
        fig.add_trace(go.Scatter(
            x=plot_data["year"], y=plot_data[metric_col],
            name=legend_name, legendgroup=legend_name,
            mode='lines', line=style,
            **legend_kwargs,
        ), row=row, col=1)

# Update layout for a clean, readable chart
fig.update_layout(
    height=800,
    title_text=f'Chinese Export Trends: US vs. Rest of World ({years[0]}-{years[-1]})',
    legend_title_text='Destination & Product Type',
    hovermode='x unified'
)
for row in range(1, len(panels) + 1):
    fig.update_yaxes(title_text=f"Indexed to {base_year}", row=row, col=1)
fig.update_xaxes(title_text="Year", row=len(panels), col=1)

fig.show()

## Chinese+RelevantCountries Exports, US vs RoW

Adding a basket of countries which we might imagine diversion to go through: 

- China
- Vietnam
- Mexico
- Taiwan
- South Korea
- Malaysia
- Thailand
- Cambodia 


In [8]:

# Chart indexed value, quantity and unit value of rows reported by the exporter
# basket below, split by partner (USA vs. everyone else) and by whether the
# product is in `tariffed_products`.
EXPORTER_GROUP_CODES = [
    "156",  # China
    "704",  # Vietnam
    "484",  # Mexico
    "158",  # Taiwan
    "410",  # South Korea
    "458",  # Malaysia
    "764",  # Thailand
    "116",  # Cambodia
]
USA_CODE = "840"

# Years plotted. The chart title and the index base year are derived from this
# list so the labels cannot drift from the data actually shown (the title used
# to claim 2017-2023 while range(2017, 2023) stops at 2022).
years = [str(y) for y in range(2017, 2023)]
base_year = years[0]
series_keys = ['destination', 'tariff_status']

# Prepare data: filter for the exporter group and classify destinations/products
exporter_group_lf = unified_lf.filter(
    (pl.col("reporter_country").is_in(EXPORTER_GROUP_CODES)) &
    (pl.col("year").is_in(years))
).with_columns(
    tariff_status=pl.when(pl.col("product_code").is_in(tariffed_products))
                   .then(pl.lit("Tariffed by US"))
                   .otherwise(pl.lit("Non-Tariffed by US")),
    destination=pl.when(pl.col("partner_country") == USA_CODE)
                 .then(pl.lit("USA"))
                 .otherwise(pl.lit("Rest of World"))
)

# Yearly totals per (destination, tariff status), plus unit value
aggregated_df = exporter_group_lf.group_by(
    series_keys + ['year']
).agg(
    pl.sum("value").alias("value"),
    pl.sum("quantity").alias("quantity")
).sort('year').with_columns(
    # Unit value = value / quantity; null and NaN ratios are replaced with 0
    unit_value=(pl.col("value") / pl.col("quantity")).fill_null(0).fill_nan(0)
)

# Rescale every metric so that each series' base-year value equals 100
indexed_df = aggregated_df.with_columns(
    **{
        f"{metric}_indexed": (
            pl.col(metric)
            / pl.col(metric).filter(pl.col("year") == base_year).first().over(series_keys)
            * 100
        )
        for metric in ("value", "quantity", "unit_value")
    }
).collect()

# One subplot row per indexed metric: column name -> subplot title
panels = {
    "value_indexed": "Indexed Export Value",
    "quantity_indexed": "Indexed Export Quantity",
    "unit_value_indexed": "Indexed Export Unit Value",
}

fig = make_subplots(
    rows=len(panels), cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=tuple(panels.values())
)

# Colour encodes tariff status, dash pattern encodes destination
styles = {
    ('USA', 'Tariffed by US'): {'color': 'crimson', 'dash': 'solid'},
    ('USA', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'solid'},
    ('Rest of World', 'Tariffed by US'): {'color': 'crimson', 'dash': 'dot'},
    ('Rest of World', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'dot'},
}

# Add one trace per (series, metric) to the matching subplot row
for (destination, status), style in styles.items():
    plot_data = indexed_df.filter(
        (pl.col("destination") == destination) & (pl.col("tariff_status") == status)
    )
    if plot_data.is_empty():
        continue

    legend_name = f"{destination} - {status}"

    for row, metric_col in enumerate(panels, start=1):
        # Only the first row's trace gets a legend entry; the others share its
        # legend group so toggling the entry affects all three rows.
        legend_kwargs = {} if row == 1 else {"showlegend": False}
        fig.add_trace(go.Scatter(
            x=plot_data["year"], y=plot_data[metric_col],
            name=legend_name, legendgroup=legend_name,
            mode='lines', line=style,
            **legend_kwargs,
        ), row=row, col=1)

# Update layout for a clean, readable chart
fig.update_layout(
    height=800,
    title_text=f'Exports from China & Diversion Countries: US vs. RoW ({years[0]}-{years[-1]})',
    legend_title_text='Destination & Product Type',
    hovermode='x unified'
)
for row in range(1, len(panels) + 1):
    fig.update_yaxes(title_text=f"Indexed to {base_year}", row=row, col=1)
fig.update_xaxes(title_text="Year", row=len(panels), col=1)

fig.show()

## Chinese exports - tariffed vs non-tariffed to UK vs RoW

Comparing Chinese exports over two dimensions - where they went, and whether they were tariffed or not.

Our ideal is that exports of tariffed goods increased from China to both RoW and the UK, and stayed the same for non-tariffed goods.

In [9]:
# Chart indexed value, quantity and unit value of rows reported by China,
# split by partner (UK vs. everyone else) and by whether the product is in
# `tariffed_products`.
UK_CODE = "826"
# Defined here as well so this cell does not depend on an earlier cell having
# been run for CHINA_CODE.
CHINA_CODE = "156"

# Years plotted. The chart title and the index base year are derived from this
# list so the labels cannot drift from the data actually shown (the title used
# to claim 2017-2023 while only 2017-2020 was plotted).
years = [str(y) for y in range(2017, 2021)]
base_year = years[0]
series_keys = ['destination', 'tariff_status']

# Prepare data: filter for Chinese exports and classify destinations/products
chinese_exports_lf = unified_lf.filter(
    (pl.col("reporter_country") == CHINA_CODE) &
    (pl.col("year").is_in(years))
).with_columns(
    tariff_status=pl.when(pl.col("product_code").is_in(tariffed_products))
                   .then(pl.lit("Tariffed by US"))
                   .otherwise(pl.lit("Non-Tariffed by US")),
    destination=pl.when(pl.col("partner_country") == UK_CODE)
                 .then(pl.lit("UK"))
                 .otherwise(pl.lit("Rest of World"))
)

# Yearly totals per (destination, tariff status), plus unit value
aggregated_df = chinese_exports_lf.group_by(
    series_keys + ['year']
).agg(
    pl.sum("value").alias("value"),
    pl.sum("quantity").alias("quantity")
).sort('year').with_columns(
    # Unit value = value / quantity; null and NaN ratios are replaced with 0
    unit_value=(pl.col("value") / pl.col("quantity")).fill_null(0).fill_nan(0)
)

# Rescale every metric so that each series' base-year value equals 100
indexed_df = aggregated_df.with_columns(
    **{
        f"{metric}_indexed": (
            pl.col(metric)
            / pl.col(metric).filter(pl.col("year") == base_year).first().over(series_keys)
            * 100
        )
        for metric in ("value", "quantity", "unit_value")
    }
).collect()

# One subplot row per indexed metric: column name -> subplot title
panels = {
    "value_indexed": "Indexed Export Value",
    "quantity_indexed": "Indexed Export Quantity",
    "unit_value_indexed": "Indexed Export Unit Value",
}

fig = make_subplots(
    rows=len(panels), cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=tuple(panels.values())
)

# Colour encodes tariff status, dash pattern encodes destination
styles = {
    ('UK', 'Tariffed by US'): {'color': 'crimson', 'dash': 'solid'},
    ('UK', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'solid'},
    ('Rest of World', 'Tariffed by US'): {'color': 'crimson', 'dash': 'dot'},
    ('Rest of World', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'dot'},
}

# Add one trace per (series, metric) to the matching subplot row
for (destination, status), style in styles.items():
    plot_data = indexed_df.filter(
        (pl.col("destination") == destination) & (pl.col("tariff_status") == status)
    )
    if plot_data.is_empty():
        continue

    legend_name = f"{destination} - {status}"

    for row, metric_col in enumerate(panels, start=1):
        # Only the first row's trace gets a legend entry; the others share its
        # legend group so toggling the entry affects all three rows.
        legend_kwargs = {} if row == 1 else {"showlegend": False}
        fig.add_trace(go.Scatter(
            x=plot_data["year"], y=plot_data[metric_col],
            name=legend_name, legendgroup=legend_name,
            mode='lines', line=style,
            **legend_kwargs,
        ), row=row, col=1)

# Update layout for a clean, readable chart
fig.update_layout(
    height=800,
    title_text=f'Chinese Export Trends: UK vs. Rest of World ({years[0]}-{years[-1]})',
    legend_title_text='Destination & Product Type',
    hovermode='x unified'
)
for row in range(1, len(panels) + 1):
    fig.update_yaxes(title_text=f"Indexed to {base_year}", row=row, col=1)
fig.update_xaxes(title_text="Year", row=len(panels), col=1)

fig.show()

## UK import trends: China vs RoW on tariffed vs non-tariffed products

How did UK imports evolve, comparing those from RoW and those from China, over tariffed vs non-tariffed goods? 

In [10]:

# Chart indexed value, quantity and unit value of rows whose partner is the UK,
# split by reporter (China vs. everyone else) and by whether the product is in
# `tariffed_products`.
#
# NOTE(review): here UK imports are selected with partner_country == UK, while
# the "US imports" cell earlier in this notebook selects on reporter_country.
# Confirm which direction the reporter/partner columns encode.

# Defined here as well so this cell does not depend on earlier cells having
# been run for these two constants.
UK_CODE = "826"
CHINA_CODE = "156"

# Years plotted; the index base year is the first of them.
years = [str(y) for y in range(2017, 2021)]
base_year = years[0]
series_keys = ['source', 'tariff_status']

# Prepare data: filter for flows into the UK and classify sources/products
uk_imports_lf = unified_lf.filter(
    (pl.col("partner_country") == UK_CODE) &
    (pl.col("year").is_in(years))
).with_columns(
    tariff_status=pl.when(pl.col("product_code").is_in(tariffed_products))
                   .then(pl.lit("Tariffed by US"))
                   .otherwise(pl.lit("Non-Tariffed by US")),
    source=pl.when(pl.col("reporter_country") == CHINA_CODE)
                 .then(pl.lit("China"))
                 .otherwise(pl.lit("Rest of World"))
)

# Yearly totals per (source, tariff status), plus unit value
aggregated_df = uk_imports_lf.group_by(
    series_keys + ['year']
).agg(
    pl.sum("value").alias("value"),
    pl.sum("quantity").alias("quantity")
).sort('year').with_columns(
    # Unit value = value / quantity; null and NaN ratios are replaced with 0
    unit_value=(pl.col("value") / pl.col("quantity")).fill_null(0).fill_nan(0)
)

# Rescale every metric so that each series' base-year value equals 100
indexed_df = aggregated_df.with_columns(
    **{
        f"{metric}_indexed": (
            pl.col(metric)
            / pl.col(metric).filter(pl.col("year") == base_year).first().over(series_keys)
            * 100
        )
        for metric in ("value", "quantity", "unit_value")
    }
).collect()

# One subplot row per indexed metric: column name -> subplot title
panels = {
    "value_indexed": "Indexed Import Value",
    "quantity_indexed": "Indexed Import Quantity",
    "unit_value_indexed": "Indexed Import Unit Value",
}

fig = make_subplots(
    rows=len(panels), cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=tuple(panels.values())
)

# Colour encodes tariff status, dash pattern encodes source
styles = {
    ('China', 'Tariffed by US'): {'color': 'crimson', 'dash': 'solid'},
    ('China', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'solid'},
    ('Rest of World', 'Tariffed by US'): {'color': 'crimson', 'dash': 'dot'},
    ('Rest of World', 'Non-Tariffed by US'): {'color': 'royalblue', 'dash': 'dot'},
}

# Add one trace per (series, metric) to the matching subplot row
for (source, status), style in styles.items():
    plot_data = indexed_df.filter(
        (pl.col("source") == source) & (pl.col("tariff_status") == status)
    )
    if plot_data.is_empty():
        continue

    legend_name = f"{source} - {status}"

    for row, metric_col in enumerate(panels, start=1):
        # Only the first row's trace gets a legend entry; the others share its
        # legend group so toggling the entry affects all three rows.
        legend_kwargs = {} if row == 1 else {"showlegend": False}
        fig.add_trace(go.Scatter(
            x=plot_data["year"], y=plot_data[metric_col],
            name=legend_name, legendgroup=legend_name,
            mode='lines', line=style,
            **legend_kwargs,
        ), row=row, col=1)

# Update layout for a clean, readable chart
fig.update_layout(
    height=800,
    title_text='UK Import Trends: China vs. Rest of World',
    # The series are import sources, not destinations
    legend_title_text='Source & Product Type',
    hovermode='x unified'
)
for row in range(1, len(panels) + 1):
    fig.update_yaxes(title_text=f"Indexed to {base_year}", row=row, col=1)
fig.update_xaxes(title_text="Year", row=len(panels), col=1)

fig.show()