# Update the regression setup once more

We need to do the following:

1. Improve the regression setup to better isolate our treatment. Currently the control has treated data within it (i.e. imports of Chinese goods with tariffs applied by other countries).
2. Estimate only a single number - the cross-elasticity using a weighted average of 2018, 2019, and maybe 2020. 
3. Construct a comparison of a broad range of countries.
4. Include a measure of the effects of exposure, and the effects of protectionist policy to explain the *why*

In [38]:
# Setup
import re
from typing import Optional, List, Tuple


import marimo as mo
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import polars as pl
import pycountry
import pyfixest
import plotly.io as pio
import numpy as np

import pickle
from pathlib import Path

pio.renderers.default = "notebook"

In [2]:
# Analysis function - all country regression
def run_saturated_regression(
    data: pl.LazyFrame,
    formula: str,
    year_range: list[str],
    countries_to_exclude: Optional[List[str]] = None,
    vcov: Optional[str | dict] = "hetero",
    filter_expression: Optional[pl.Expr] = None,
):
    """
    Runs a fully saturated regression to estimate country-specific effects.

    This function automatically creates an interaction term for each importing
    country to measure its unique response to the US-China tariffs.

    Args:
        data: Pre-filtered LazyFrame.
        formula: The regression formula base, e.g., "log(value) | fe1 + fe2".
                 The function will add the saturated terms to this formula.
        year_range: A list of years (as strings) for the analysis.
        countries_to_exclude: A list of country codes (as strings) to
                              exclude from the saturated model (e.g., US, China).
        vcov: The variance-covariance matrix specification for pyfixest.
        filter_expression: An optional Polars expression to filter the data
                           after tariff calculation but before regression.

    Returns:
        A tuple containing the model object, the etable, and the coefficient plot.
    """
    USA_CC = "840"
    CHINA_CC = "156"

    # 1. Prepare the base data with tariff information
    tariff_expr = (
        pl.col("average_tariff_official")
        .filter((pl.col("partner_country") == USA_CC) & (pl.col("reporter_country") == CHINA_CC))
        .mean()
        .over(["year", "product_code"])
        .alias("tariff_us_china")
    )

    base_lf = data.with_columns(
        pl.col("partner_country").alias("importer"),
        pl.col("reporter_country").alias("exporter"),
        tariff_expr,
    )

    # 2. Get all unique importer countries from the data
    if countries_to_exclude is None:
        countries_to_exclude = [CHINA_CC, USA_CC]  # Default to excluding China and USA 

    importers = base_lf.select("importer").unique().filter(~pl.col("importer").is_in(countries_to_exclude)).collect().to_series().to_list()

    # 3. Create a specific interaction term for each importer
    interaction_expressions = []
    formula_terms = []
    for country_code in importers:
        for year in year_range:
            # Create a pyfixest-friendly term name, e.g., "gamma_826_2017" for GBR
            term_name = f"gamma_{country_code}_{year}"
            formula_terms.append(term_name)

            # This term is active only for the specific country's imports from China, in the specific year
            interaction_filter = (pl.col("importer") == country_code) & (pl.col("exporter") == CHINA_CC) & (pl.col("year") == year)

            # The interaction value is the US-China tariff when conditions are met
            expression = pl.when(interaction_filter).then(pl.col("tariff_us_china")).otherwise(0.0).alias(term_name)
            interaction_expressions.append(expression)
    
    final_lf = base_lf.with_columns(*interaction_expressions)

    # Drop all imports by the USA -> we handle the USA separately
    final_lf = final_lf.filter(~pl.col("importer").is_in([USA_CC]))

    final_lf = final_lf.drop(
        [
            "tariff_rate_pref",
            "min_rate_pref",
            "max_rate_pref",
            "tariff_rate_mfn",
            "min_rate_mfn",
            "max_rate_mfn",
            "average_tariff",
            "value_global_trend",
            "value_detrended",
            "quantity_global_trend",
            "quantity_detrended",
            "price_global_trend",
            "unit_value_detrended",
            "value_global_trend_right",
            "quantity_global_trend_right",
            "price_global_trend_right",
            "reporter_country",
            "partner_country",
        ]
    )
    # 4. Dynamically construct the full formula string
    # Joins all "gamma_XXX" terms together
    saturated_effects = " + ".join(formula_terms)

    # Inserts the saturated effects into the user-provided formula
    if "|" in formula:
        parts = formula.split("|", 1)
        # Formula becomes: y ~ gamma_1 + gamma_2 + ... | fe1 + fe2
        final_formula = f"{parts[0].strip()} ~ {saturated_effects} | {parts[1].strip()}"
    else:
        final_formula = f"{formula.strip()} ~ {saturated_effects}"

    print("--- Generated Saturated Formula ---")
    print(final_formula)
    print("---------------------------------")

    # return final_lf

    # 5. Run the regression with the new data and formula
    if filter_expression is not None:
        final_lf = final_lf.filter(filter_expression)

    dependent_var_col = re.findall(r"\b\w+\b", formula.split("~")[0].strip())[-1]

    print(f"Checking for nulls in '{dependent_var_col}' and 'tariff_us_china'.")
    clean_df = final_lf.drop_nulls(subset=[dependent_var_col, "tariff_us_china"]).collect()
    # clean_df = final_lf.collect()
    
    model = pyfixest.feols(fml=final_formula, data=clean_df, vcov=vcov, store_data=False, lean=True)
    etable = pyfixest.etable(model)
    coefplot = model.coefplot(drop="Intercept", plot_backend="matplotlib")

    return model, etable, coefplot

In [3]:
# Prepare analysis data function
def prepare_analysis_data(
    source_lf: pl.LazyFrame,
    top_n: int | None = None,
    selection_year: str | None = None,
    year_range_to_keep: list[str] | None = None,
    selection_method: str = "total_trade",
    oil_export_threshold: float | None = 50.0,
    countries_to_exclude: list[str] | None = None,
    countries_to_include: list[str] | None = None,
    product_codes_to_exclude: list[str] | None = None,
) -> pl.LazyFrame:
    """
    Filters and subsets the main trade dataset to create a sample for analysis.

    Args:
        source_lf: The initial LazyFrame of the full trade dataset.
        top_n: The number of top countries to select (e.g., 40).
        selection_year: The year used to determine the top countries.
        year_range_to_keep: A list of years to keep in the final dataset.
        selection_method: Method to determine top countries ("total_trade" or "importers").
        oil_export_threshold: Export percentage to classify a country as an oil exporter.
        countries_to_exclude: A list of country codes to remove.
        countries_to_include: A specific list of country codes to use for the analysis.
        product_codes_to_exclude: A list of product code prefixes (e.g., HS chapters)
                                    to remove from the dataset.

    Returns:
        A filtered LazyFrame.
    """
    if countries_to_include and (top_n or selection_year):
        raise ValueError("'countries_to_include' cannot be used with 'top_n' or 'selection_year'.")
    if not countries_to_include and not (top_n and selection_year):
        raise ValueError("Either 'countries_to_include' or both 'top_n' and 'selection_year' must be provided.")

    print("--- Cleaning data ---")

    lf = source_lf

    if product_codes_to_exclude:
        print(f"Excluding product codes starting with: {product_codes_to_exclude}")
        exclusion_expr = pl.any_horizontal(pl.col("product_code").str.starts_with(code) for code in product_codes_to_exclude)
        lf = lf.filter(~exclusion_expr)

    if oil_export_threshold is not None:
        oil_countries = get_oil_exporting_countries(lf, oil_export_threshold)
        lf = lf.filter(~pl.col("reporter_country").is_in(oil_countries))

    if countries_to_include:
        top_countries_list = countries_to_include
    else:
        trade_in_year_lf = lf.filter(pl.col("year") == selection_year)

        if selection_method == "importers":
            top_countries_df = (
                trade_in_year_lf.group_by("partner_country")
                .agg(pl.sum("value").alias("import_value"))
                .sort("import_value", descending=True)
                .head(top_n)
                .collect()
            )
            top_countries_list = top_countries_df["partner_country"].to_list()

        elif selection_method == "total_trade":
            exports_lf = trade_in_year_lf.select(pl.col("reporter_country").alias("country"), "value")
            imports_lf = trade_in_year_lf.select(pl.col("partner_country").alias("country"), "value")

            top_countries_df = (
                pl.concat([exports_lf, imports_lf])
                .group_by("country")
                .agg(pl.sum("value").alias("total_trade"))
                .sort("total_trade", descending=True)
                .head(top_n)
                .collect()
            )
            top_countries_list = top_countries_df["country"].to_list()
        else:
            raise ValueError("selection_method must be 'importers' or 'total_trade'")

    if countries_to_exclude:
        top_countries_list = [c for c in top_countries_list if c not in countries_to_exclude]

    print(f"Final sample includes {len(top_countries_list)} countries.")

    analysis_lf = lf.filter(pl.col("reporter_country").is_in(top_countries_list) & pl.col("partner_country").is_in(top_countries_list))

    if year_range_to_keep:
        analysis_lf = analysis_lf.filter(pl.col("year").is_in(year_range_to_keep))

    return analysis_lf

In [4]:
# Get oil exporting countries
def get_oil_exporting_countries(lzdf: pl.LazyFrame, oil_export_percentage_threshold: float) -> list[str]:
    """
    Finds countries where oil products (HS code '27') exceed a certain
    percentage of their total export value.
    """
    print("--- Filtering out oil countries ---")

    total_exports = lzdf.group_by("reporter_country").agg(pl.sum("value").alias("total_value"))

    oil_exports = lzdf.filter(pl.col("product_code").str.starts_with("27")).group_by("reporter_country").agg(pl.sum("value").alias("oil_value"))

    summary = total_exports.join(oil_exports, on="reporter_country", how="left").with_columns(pl.col("oil_value").fill_null(0.0))

    summary = summary.with_columns(((pl.col("oil_value") / pl.col("total_value")) * 100).alias("oil_export_percentage"))

    filtered_countries = summary.filter(pl.col("oil_export_percentage") > oil_export_percentage_threshold)

    return filtered_countries.collect()["reporter_country"].to_list()

In [5]:
# Alumunium and Steel Product Codes
alu_steel_product_codes = [
    # Steel Products
    # 720610 through 721650
    "720610",
    "720690",
    "720711",
    "720712",
    "720719",
    "720720",
    "720810",
    "720825",
    "720826",
    "720827",
    "720836",
    "720837",
    "720838",
    "720839",
    "720840",
    "720851",
    "720852",
    "720853",
    "720854",
    "720890",
    "720915",
    "720916",
    "720917",
    "720918",
    "720925",
    "720926",
    "720927",
    "720928",
    "720990",
    "721011",
    "721012",
    "721020",
    "721030",
    "721041",
    "721049",
    "721050",
    "721061",
    "721069",
    "721070",
    "721090",
    "721113",
    "721114",
    "721119",
    "721123",
    "721129",
    "721190",
    "721210",
    "721220",
    "721230",
    "721240",
    "721250",
    "721260",
    "721310",
    "721320",
    "721391",
    "721399",
    "721410",
    "721420",
    "721430",
    "721491",
    "721499",
    "721510",
    "721550",
    "721590",
    "721610",
    "721621",
    "721622",
    "721631",
    "721632",
    "721633",
    "721640",
    "721650",
    # 721699 through 730110
    "721699",
    "721710",
    "721720",
    "721730",
    "721790",
    "721810",
    "721891",
    "721899",
    "721911",
    "721912",
    "721913",
    "721914",
    "721921",
    "721922",
    "721923",
    "721924",
    "721931",
    "721932",
    "721933",
    "721934",
    "721935",
    "721990",
    "722011",
    "722012",
    "722020",
    "722090",
    "722100",
    "722211",
    "722219",
    "722220",
    "722230",
    "722240",
    "722300",
    "722410",
    "722490",
    "722511",
    "722519",
    "722530",
    "722540",
    "722550",
    "722591",
    "722592",
    "722599",
    "722611",
    "722619",
    "722620",
    "722691",
    "722692",
    "722699",
    "722710",
    "722720",
    "722790",
    "722810",
    "722820",
    "722830",
    "722840",
    "722850",
    "722860",
    "722870",
    "722880",
    "722920",
    "722990",
    "730110",
    # 730210
    "730210",
    # 730240 through 730290
    "730240",
    "730290",
    # 730410 through 730690
    "730411",
    "730419",
    "730422",
    "730423",
    "730424",
    "730429",
    "730431",
    "730439",
    "730441",
    "730449",
    "730451",
    "730459",
    "730490",
    "730511",
    "730512",
    "730519",
    "730520",
    "730531",
    "730539",
    "730590",
    "730611",
    "730619",
    "730621",
    "730629",
    "730630",
    "730640",
    "730650",
    "730661",
    "730669",
    "730690",
    # Aluminum Products
    # 7601 (Unwrought aluminum)
    "760110",
    "760120",
    # 7604 (Aluminum bars, rods, and profiles)
    "760410",
    "760421",
    "760429",
    # 7605 (Aluminum wire)
    "760511",
    "760519",
    "760521",
    "760529",
    # 7606 (Aluminum plates, sheets, and strip)
    "760611",
    "760612",
    "760691",
    "760692",
    # 7607 (Aluminum foil)
    "760711",
    "760719",
    "760720",
    # 7608 (Aluminum tubes and pipes)
    "760810",
    "760820",
    # 7609 (Aluminum tube or pipe fittings)
    "760900",
]

In [6]:
# Country Codes
USA_CC = pycountry.countries.search_fuzzy("USA")[0].numeric
CHINA_CC = pycountry.countries.search_fuzzy("China")[0].numeric
BRAZIL_CC = pycountry.countries.search_fuzzy("Brazil")[0].numeric
IRELAND_CC = pycountry.countries.search_fuzzy("Ireland")[0].numeric
JAPAN_CC = pycountry.countries.search_fuzzy("Japan")[0].numeric
ITALY_CC = pycountry.countries.search_fuzzy("Italy")[0].numeric
SOUTHAFRICA_CC = pycountry.countries.search_fuzzy("South Africa")[0].numeric
UK_CC = pycountry.countries.search_fuzzy("United Kingdom")[0].numeric
GERMANY_CC = pycountry.countries.search_fuzzy("Germany")[0].numeric
FRANCE_CC = pycountry.countries.search_fuzzy("France")[0].numeric
KOREA_CC = pycountry.countries.search_fuzzy("Korea")[0].numeric
TURKEY_CC = pycountry.countries.search_fuzzy("Turkiye")[0].numeric
AUSTRALIA_CC = pycountry.countries.search_fuzzy("Australia")[0].numeric
SAUDI_CC = pycountry.countries.search_fuzzy("Saudi Arabia")[0].numeric
MEXICO_CC = pycountry.countries.search_fuzzy("Mexico")[0].numeric
CANADA_CC = pycountry.countries.search_fuzzy("Canada")[0].numeric
INDONESIA_CC = pycountry.countries.search_fuzzy("Indonesia")[0].numeric
INDIA_CC = pycountry.countries.search_fuzzy("India")[0].numeric
VIETNAM_CC = pycountry.countries.search_fuzzy("Vietnam")[0].numeric
RUSSIA_CC = pycountry.countries.search_fuzzy("Russia")[0].numeric
HONGKONG_CC = pycountry.countries.search_fuzzy("Hong Kong")[0].numeric

In [7]:
# Load raw data
unified_lf_path: str = "/Users/lukasalemu/Documents/00. Bank of England/03. MPIL/tariff_trade_analysis/data/final/unified_trade_tariff_partitioned"

unified_lf: pl.LazyFrame = pl.scan_parquet(unified_lf_path)

# unified_lf.head().collect()

In [8]:
# Clean raw data
analysis_lf = prepare_analysis_data(
    source_lf=unified_lf,
    top_n=34,
    selection_year="2017",
    year_range_to_keep=[str(y) for y in range(2016, 2021)],  # 2016 - 2020
    selection_method="total_trade",
    oil_export_threshold=50.0,
    countries_to_exclude=[RUSSIA_CC, HONGKONG_CC, ITALY_CC, IRELAND_CC],
    product_codes_to_exclude=alu_steel_product_codes,
)

--- Cleaning data ---
Excluding product codes starting with: ['720610', '720690', '720711', '720712', '720719', '720720', '720810', '720825', '720826', '720827', '720836', '720837', '720838', '720839', '720840', '720851', '720852', '720853', '720854', '720890', '720915', '720916', '720917', '720918', '720925', '720926', '720927', '720928', '720990', '721011', '721012', '721020', '721030', '721041', '721049', '721050', '721061', '721069', '721070', '721090', '721113', '721114', '721119', '721123', '721129', '721190', '721210', '721220', '721230', '721240', '721250', '721260', '721310', '721320', '721391', '721399', '721410', '721420', '721430', '721491', '721499', '721510', '721550', '721590', '721610', '721621', '721622', '721631', '721632', '721633', '721640', '721650', '721699', '721710', '721720', '721730', '721790', '721810', '721891', '721899', '721911', '721912', '721913', '721914', '721921', '721922', '721923', '721924', '721931', '721932', '721933', '721934', '721935', '721990'

In [9]:
analysis_lf.head().collect()

year,reporter_country,partner_country,product_code,value,quantity,tariff_rate_pref,min_rate_pref,max_rate_pref,tariff_rate_mfn,min_rate_mfn,max_rate_mfn,average_tariff,unit_value,value_global_trend,value_detrended,quantity_global_trend,quantity_detrended,price_global_trend,unit_value_detrended,official_tariff,average_tariff_official,value_global_trend_right,quantity_global_trend_right,price_global_trend_right
str,str,str,str,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""2016""","""156""","""392""","""911280""",712.544983,39.585999,,,,0.0,0.0,0.0,0.0,17.999924,11.291309,701.253662,11.291299,28.294701,11.291302,6.708622,0.0,0.0,11.291303,11.2913,11.2913
"""2016""","""156""","""250""","""911280""",9.439,0.208,,,,2.7,2.7,2.7,2.7,45.379807,11.291309,-1.852309,11.291299,-11.083299,11.291302,34.088505,0.0,2.7,11.291303,11.2913,11.2913
"""2016""","""156""","""410""","""482340""",140.945007,38.558998,,,,0.0,0.0,0.0,0.0,3.655308,4.32161,136.623398,4.321611,34.237389,4.32161,-0.666302,0.0,0.0,4.321611,4.321609,4.32161
"""2016""","""156""","""392""","""910919""",772.275024,11.607,,,,0.0,0.0,0.0,0.0,66.535278,14.145529,758.129517,14.145526,-2.538526,14.14553,52.389748,0.0,0.0,14.145522,14.145525,14.14553
"""2016""","""156""","""410""","""130213""",0.6,0.01,,,,30.0,30.0,30.0,30.0,60.000004,26.585186,-25.985186,26.585188,-26.575188,26.585186,33.414818,0.0,30.0,26.585184,26.585176,26.58519


In [10]:
def plot_country_coefficients(model, country_codes: List[str]) -> go.Figure:
    """
    Generates a plot of estimated coefficients over time for selected countries,
    showing 3-letter country codes in the legend.

    Args:
        model: A fitted PyFixest feols model object.
        country_codes: A list of 3-digit numeric country codes (as strings).

    Returns:
        A Plotly Figure object showing the coefficient paths with translucent
        confidence intervals.
    """
    tidy_df = model.tidy().reset_index().rename(columns={"index": "Coefficient"})

    extracted_data = tidy_df["Coefficient"].str.extract(r"gamma_(\d+)_(\d{4})")
    tidy_df[["country_code", "year"]] = extracted_data

    plot_df = tidy_df[tidy_df["country_code"].isin(country_codes)].copy()
    plot_df["year"] = pd.to_numeric(plot_df["year"])

    fig = go.Figure()
    colors = px.colors.qualitative.Plotly
    unique_countries = plot_df["country_code"].unique()

    # First, draw all the confidence interval areas
    for i, country_code in enumerate(unique_countries):
        country_df = plot_df[plot_df["country_code"] == country_code].sort_values("year")
        if country_df.empty:
            continue

        hex_color = colors[i % len(colors)]
        # Correctly convert HEX to RGBA for translucency
        h = hex_color.lstrip("#")
        r, g, b = tuple(int(h[j : j + 2], 16) for j in (0, 2, 4))
        rgba_color = f"rgba({r}, {g}, {b}, 0.2)"

        fig.add_trace(go.Scatter(x=country_df["year"], y=country_df["97.5%"], mode="lines", line=dict(width=0), showlegend=False, hoverinfo="none"))
        fig.add_trace(
            go.Scatter(
                x=country_df["year"],
                y=country_df["2.5%"],
                mode="lines",
                line=dict(width=0),
                fillcolor=rgba_color,
                fill="tonexty",
                showlegend=False,
                hoverinfo="none",
            )
        )

    # Then, draw all the coefficient lines on top
    for i, country_code in enumerate(unique_countries):
        country_df = plot_df[plot_df["country_code"] == country_code].sort_values("year")
        if country_df.empty:
            continue

        try:
            country_name = pycountry.countries.get(numeric=country_code).alpha_3
        except AttributeError:
            print(f"Warning: Country with numeric code {country_code} not found. Using numeric code.")
            country_name = country_code

        hex_color = colors[i % len(colors)]

        fig.add_trace(
            go.Scatter(
                x=country_df["year"],
                y=country_df["Estimate"],
                name=country_name,
                mode="lines+markers",
                line=dict(color=hex_color),
                marker=dict(size=8),
            )
        )

    fig.update_layout(
        xaxis_title="Year",
        yaxis_title="Estimated Coefficient",
        hovermode="x unified",
        legend_title="Country",
        xaxis=dict(tickmode="linear"),
    )

    return fig

## 1. Fixing the regression specification

Our current specification has the problem that the control set includes treated data. That is, it includes imports of Chinese goods which were tariffed by the US, by countries other than the UK. Our beta coefficient only absorbs those cases where the UK is the importer, meaning all those other cases are in the control set. 

So what we're actually estimating, when we isolate out the UK, is how different is the UK to RoW in it's response. Which we find it's not that different (in fact a little less diversion than other countries). And that's why when we isolate the RoW, we do get an effect (because the baseline is then just US - China).

I have a few ideas for approaches we could take:

### Add a parameter for every country in the sample, which is active for imports from China of tariffed goods in that year.
The equation would look something like this:

log(Imports_ijpt) = Σ_k^t [γ_k^t * (Is_Country_k_i * Is_China_j * Tariff_pt * t)] + Fixed Effects + ε

This would be a much more saturated model. 

We could select out a list of countries based on other features.

### Filtering to only UK imports
An alternative approach - control the control set!


### Filtering to a subset of countries
Filter to a subset of countries - those which a-priori we believe won't experience diversion (as far as possible). This could be the US, for example.


## Adding a parameter for every country in the sample

log(Imports_ijpt) = Σ_k^t [γ_k^t * (Is_Country_k_i * Is_China_j * Tariff_pt * t)] + Fixed Effects + ε

In [11]:
# my_formula = "log(value) | importer^year^product_code + importer^exporter"
# model, _, _ = run_saturated_regression(
#     data=analysis_lf,
#     formula=my_formula,
#     year_range=[str(y) for y in range(2017, 2021)],
# )

In [None]:
# # Create models directory if it doesn't exist
# Path("models").mkdir(exist_ok=True)

# # Run models one at a time and save them
# for dep in ["quantity", "value"]: # "quantity",
#     my_formula = f"log({dep}) | importer^year^product_code + importer^exporter"
#     model, _, _ = run_saturated_regression(
#         data=analysis_lf,
#         formula=my_formula,
#         year_range=[str(y) for y in range(2017, 2021)],
#     )
#     # Save results to pickle file
#     with open(f"models/{dep}.pkl", "wb") as f:
#         pickle.dump(model, f)

#     del model  # Free up memory 

In [13]:
# # Get the average weight for each country with the following function
# def get_average_coefficient_by_country(feols_model):
#     """
#     Calculates the average of the 2018, 2019, and 2020 estimated coefficients
#     for each country from a pyfixest.feols model object.

#     Args:
#         feols_model: A fitted pyfixest.feols model object.

#     Returns:
#         A pandas DataFrame with the average coefficient for each country.
#     """

#     # The user did not specify weights, so a simple average will be computed.
#     # weights = {2018: 0.2, 2019: 0.6, 2020: 0.2}
#     weights = {2018: 0.25, 2019: 0.75, 2020: 0.0}

#     # tidy creates a summary of the model coefficients
#     df = feols_model.tidy().reset_index().rename(columns={"index": "Coefficient"})

#     df[["gamma", "country", "year"]] = df["Coefficient"].str.split("_", expand=True)
#     df["year"] = df["year"].astype(int)


#     df = df[df["year"].isin(weights.keys())]

#     df["weighted_estimate"] = df.apply(lambda row: row["Estimate"] * weights[row["year"]], axis=1) * 100
#     df["weighted_std_error"] = df.apply(lambda row: row["Std. Error"] * weights[row["year"]], axis=1) * 100

#     result = df.groupby("country")[["weighted_estimate", "weighted_std_error"]].sum().reset_index()
#     result.rename(columns={"weighted_estimate": "average_coefficient"}, inplace=True)

#     result["country"] = result["country"].apply(lambda x: pycountry.countries.get(numeric=x).name)

#     return result


# model_coefficients = get_average_coefficient_by_country(model)
# model_coefficients.head(1)

In [14]:
def get_average_coefficient_by_country_multi(models_dict):
    """
    Calculates the average of the 2018, 2019, and 2020 estimated coefficients
    for each country from multiple pyfixest.feols model objects (e.g., value and quantity).

    Args:
        models_dict: dict, keys are model types (e.g., "value", "quantity"), values are pyfixest.feols model objects.

    Returns:
        A pandas DataFrame with the average coefficient and std error for each country and model type.
    """
    weights = {2018: 0.25, 2019: 0.75, 2020: 0.0}
    results = []

    for model_type, feols_model in models_dict.items():
        df = feols_model.tidy().reset_index().rename(columns={"index": "Coefficient"})
        df[["gamma", "country", "year"]] = df["Coefficient"].str.split("_", expand=True)
        df["year"] = df["year"].astype(int)
        df = df[df["year"].isin(weights.keys())]
        df["weighted_estimate"] = df.apply(lambda row: row["Estimate"] * weights[row["year"]], axis=1) * 100
        df["weighted_std_error"] = df.apply(lambda row: row["Std. Error"] * weights[row["year"]], axis=1) * 100
        result = df.groupby("country")[["weighted_estimate", "weighted_std_error"]].sum().reset_index()
        result.rename(columns={
            "weighted_estimate": f"average_coefficient_{model_type}",
            "weighted_std_error": f"std_error_{model_type}"
        }, inplace=True)
        results.append(result)

    # Merge all results on 'country'
    from functools import reduce
    merged = reduce(lambda left, right: left.merge(right, on="country", how="outer"), results)
    merged["country"] = merged["country"].apply(lambda x: pycountry.countries.get(numeric=x).name)
    return merged

In [24]:
# Load the two models
model_quantity = pickle.load(open("models/quantity.pkl", "rb"))
model_value = pickle.load(open("models/value.pkl", "rb"))

models_dict = {
    "quantity": model_quantity,
    "value": model_value
}

model_coefficients = get_average_coefficient_by_country_multi(models_dict)
model_coefficients

Unnamed: 0,country,average_coefficient_quantity,std_error_quantity,average_coefficient_value,std_error_value
0,Australia,-0.331976,0.194549,-0.146098,0.165741
1,Austria,0.150568,0.224165,0.024664,0.20363
2,Belgium,0.503998,0.236103,0.181749,0.215697
3,Brazil,0.64801,0.220306,0.240111,0.195442
4,Canada,0.233255,0.183799,0.134401,0.16741
5,"Taiwan, Province of China",0.391282,0.194673,0.013007,0.177505
6,Czechia,0.256982,0.238307,0.107321,0.20905
7,France,0.765303,0.221376,0.295788,0.182268
8,Germany,0.236536,0.193121,0.091714,0.174726
9,Hungary,0.978309,0.24172,0.926853,0.20297


In [16]:
# fig = go.Figure()

# fig.add_trace(
#     go.Scatter(
#         x=model_coefficients["country"],
#         y=model_coefficients["average_coefficient"],
#         error_y=dict(type="data", array=model_coefficients["weighted_std_error"], visible=True),
#         mode="markers",
#         marker=dict(size=10),
#         name="Estimated Value",
#     )
# )

# fig.update_layout(
#     title_text="Average Cross Elasticity Coefficient by Country", xaxis_title="Country", yaxis_title="Average Cross Elasticity Coefficient"
# )
# fig.show()

# Run just the US, seperately with a slightly different formula


In [17]:
def run_direct_effect_regression(
    data: pl.LazyFrame,
    interaction_term_name: str,
    interaction_importers: list[str],
    interaction_exporters: list[str],
    year_range: list[str],
    formula: str,
    vcov: Optional[str | dict] = "hetero",
    filter_expression: Optional[pl.Expr] = None,
):
    USA_CC = "840"
    CHINA_CC = "156"

    try:
        dependent_var_str = formula.split("~")[0].strip()
        dependent_var_col = re.findall(r"\b\w+\b", dependent_var_str)[-1]
    except IndexError:
        raise ValueError(f"Could not parse dependent variable from formula: {formula}")

    tariff_expr = (
        pl.col("average_tariff_official")
        .filter((pl.col("partner_country") == USA_CC) & (pl.col("reporter_country") == CHINA_CC))
        .mean()
        .over(["year", "product_code"])
        .alias("tariff_us_china")
    )

    input_lf = data.with_columns(
        pl.col("partner_country").alias("importer"),
        pl.col("reporter_country").alias("exporter"),
        tariff_expr,
    )

    interaction_filter = (pl.col("importer").is_in(interaction_importers)) & (pl.col("exporter").is_in(interaction_exporters))

    interaction_expressions = [
        pl.when(interaction_filter & (pl.col("year") == str(year)))
        .then(pl.col("tariff_us_china"))
        .otherwise(0.0)
        .alias(f"{interaction_term_name}_{year}")
        for year in year_range
    ]

    final_lf = input_lf.with_columns(*interaction_expressions)

    # Apply the additional filter expression if provided
    if filter_expression is not None:
        final_lf = final_lf.filter(filter_expression)

    print(f"Checking for nulls in dependent variable '{dependent_var_col}' and 'tariff_us_china'.")
    clean_df = final_lf.drop_nulls(subset=[dependent_var_col, "tariff_us_china"]).collect()

    model = pyfixest.feols(fml=formula, data=clean_df, vcov=vcov, lean=True, store_data=False)

    return model

### First for value

In [18]:
models_us_dict = {}

In [19]:
interaction_name = "USA_from_China"
importer_list = ["840"]  # USA
exporter_list = ["156"]  # China

regressors = " + ".join(f"{interaction_name}_{year}" for year in [str(y) for y in range(2017, 2021)])
formula = f"log(value) ~ {regressors} | importer^year^product_code + importer^exporter + exporter^year^product_code"
print(f"Formula for model:\n{formula}")

# 3. Run the model
model = run_direct_effect_regression(
    data=analysis_lf,
    interaction_term_name=interaction_name,
    interaction_importers=importer_list,
    interaction_exporters=exporter_list,
    year_range=[str(y) for y in range(2017, 2021)],
    formula=formula,
)

# Get the weighted average coefficient for the US, and the weighted std error
df = model.tidy().reset_index().rename(columns={"index": "Coefficient"})

val_2018 = df[df['Coefficient'] == "USA_from_China_2018"]
val_2019 = df[df['Coefficient'] == "USA_from_China_2019"]
val_2020 = df[df['Coefficient'] == "USA_from_China_2020"]

val_average = (0.25*val_2018['Estimate'].values[0] + 0.75*val_2019['Estimate'].values[0] + 0.0*val_2020['Estimate'].values[0]) * 100
val_std_error = (0.25*val_2018['Std. Error'].values[0] + 0.75*val_2019['Std. Error'].values[0] + 0.0*val_2020['Std. Error'].values[0]) * 100

models_us_dict["model_value"] = model


Formula for model:
log(value) ~ USA_from_China_2017 + USA_from_China_2018 + USA_from_China_2019 + USA_from_China_2020 | importer^year^product_code + importer^exporter + exporter^year^product_code
Checking for nulls in dependent variable 'value' and 'tariff_us_china'.


In [20]:
formula = f"log(quantity) ~ {regressors} | importer^year^product_code + importer^exporter + exporter^year^product_code"
print(f"Formula for model:\n{formula}")

# 3. Run the model
model = run_direct_effect_regression(
    data=analysis_lf,
    interaction_term_name=interaction_name,
    interaction_importers=importer_list,
    interaction_exporters=exporter_list,
    year_range=[str(y) for y in range(2017, 2021)],
    formula=formula,
)

# Get the weighted average coefficient for the US, and the weighted std error
df = model.tidy().reset_index().rename(columns={"index": "Coefficient"})

qty_2018 = df[df['Coefficient'] == "USA_from_China_2018"]
qty_2019 = df[df['Coefficient'] == "USA_from_China_2019"]
qty_2020 = df[df['Coefficient'] == "USA_from_China_2020"]

qty_average = (0.25*qty_2018['Estimate'].values[0] + 0.75*qty_2019['Estimate'].values[0] + 0.0*qty_2020['Estimate'].values[0]) * 100
qty_std_error = (0.25*qty_2018['Std. Error'].values[0] + 0.75*qty_2019['Std. Error'].values[0] + 0.0*qty_2020['Std. Error'].values[0]) * 100

models_us_dict["model_quantity"] = model

Formula for model:
log(quantity) ~ USA_from_China_2017 + USA_from_China_2018 + USA_from_China_2019 + USA_from_China_2020 | importer^year^product_code + importer^exporter + exporter^year^product_code
Checking for nulls in dependent variable 'quantity' and 'tariff_us_china'.


In [29]:
# Add new row for United States with the calculated values if it doesn't exist
if 'United States' not in model_coefficients['country'].values:
    new_row = pd.DataFrame({
        'country': ['United States'],
        'average_coefficient_value': [val_average],
        'std_error_value': [val_std_error],
        'average_coefficient_quantity': [qty_average],
        'std_error_quantity': [qty_std_error],
    })
    model_coefficients = pd.concat([model_coefficients, new_row], ignore_index=True)

# Now create our final plots, final tables, etc.

1. Table of results using an etable or some such
2. Big plot of the results

In [28]:
model_coefficients

Unnamed: 0,country,average_coefficient_quantity,std_error_quantity,average_coefficient_value,std_error_value,average_std_error_value,average_std_error_quantity
0,Australia,-0.331976,0.194549,-0.146098,0.165741,,
1,Austria,0.150568,0.224165,0.024664,0.20363,,
2,Belgium,0.503998,0.236103,0.181749,0.215697,,
3,Brazil,0.64801,0.220306,0.240111,0.195442,,
4,Canada,0.233255,0.183799,0.134401,0.16741,,
5,"Taiwan, Province of China",0.391282,0.194673,0.013007,0.177505,,
6,Czechia,0.256982,0.238307,0.107321,0.20905,,
7,France,0.765303,0.221376,0.295788,0.182268,,
8,Germany,0.236536,0.193121,0.091714,0.174726,,
9,Hungary,0.978309,0.24172,0.926853,0.20297,,


In [None]:
df_value = model_coefficients[["country", "average_coefficient_value", "std_error_value"]].copy()
df_value.rename(columns={"average_coefficient_value": "coefficient", "std_error_value": "stderror"}, inplace=True)
df_value["Estimation"] = "Value"

df_quantity = model_coefficients[["country", "average_coefficient_quantity", "std_error_quantity"]].copy()
df_quantity.rename(columns={"average_coefficient_quantity": "coefficient", "std_error_quantity": "stderror"}, inplace=True)
df_quantity["Estimation"] = "Quantity"

df_long = pd.concat([df_value, df_quantity])

# 2. Prepare for plotting with offsets
fig = go.Figure()
countries = df_long["country"].unique()
x_numeric = np.arange(len(countries)) # Create numeric base positions: [0, 1, 2, 3, 4]

# Define the offset and B&W-friendly marker styles
offset = 0.15
marker_styles = {
    "Value": {"symbol": "circle", "color": "black", "size": 10},
    "Quantity": {"symbol": "x-thin-open", "color": "black", "size": 11},
}
offsets = {"Value": -offset, "Quantity": offset}

# 3. Add a separate trace for each estimation method with its own offset
for name in ["Value", "Quantity"]:
    df_subset = df_long[df_long["Estimation"] == name]
    fig.add_trace(go.Scatter(
        x=x_numeric + offsets[name], # Apply offset here
        y=df_subset["coefficient"],
        error_y=dict(type="data", array=df_subset["stderror"], thickness=1, width=5),
        name=name,
        mode="markers",
        marker=marker_styles[name]
    ))

# 4. Refine layout for a professional, academic look
fig.update_layout(
    template="plotly_white",
    font=dict(family="Times New Roman", size=15),
    # title=dict(text="Cross Elasticity Coefficient by Country: Value vs. Quantity", x=0.5, font=dict(size=18)),
    xaxis_title="Country",
    yaxis_title="Cross Elasticity Coefficient",
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    ),
    xaxis=dict(
        tickmode='array',
        tickvals=x_numeric, # Position ticks at the original numeric non-offset locations
        ticktext=countries, # Label ticks with country names
        showline=True,
        linecolor='black',
        mirror=True
    ),
    yaxis=dict(showline=True, linecolor='black', mirror=True),
    margin=dict(l=60, r=40, b=40, t=80),
)

# Add the same horizontal reference line
fig.add_hline(y=0, line_width=1, line_dash="dash", line_color="grey")

# To save for LaTeX, ensure you have 'kaleido' installed:
# pip install kaleido
fig.write_image(
    "/Users/lukasalemu/Documents/00. Bank of England/03. MPIL/tariff_trade_analysis/notebook_outputs/cross_elasticity_coefficient_by_country.png"
)
fig.show()

ValueError: 
Image export using the "kaleido" engine requires the Kaleido package,
which can be installed using pip:

    $ pip install --upgrade kaleido
