# `README.md`

# A High-Resolution Digital Twin of the Global Production Network

<!-- PROJECT SHIELDS -->
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Python Version](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://www.python.org/downloads/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![Type Checking: mypy](https://img.shields.io/badge/type_checking-mypy-blue)](http://mypy-lang.org/)
[![Pandas](https://img.shields.io/badge/pandas-%23150458.svg?style=flat&logo=pandas&logoColor=white)](https://pandas.pydata.org/)
[![NumPy](https://img.shields.io/badge/numpy-%23013243.svg?style=flat&logo=numpy&logoColor=white)](https://numpy.org/)
[![SciPy](https://img.shields.io/badge/SciPy-%23025596?style=flat&logo=scipy&logoColor=white)](https://scipy.org/)
[![NetworkX](https://img.shields.io/badge/NetworkX-blue.svg?style=flat&logo=networkx&logoColor=white)](https://networkx.org/)
[![Statsmodels](https://img.shields.io/badge/Statsmodels-150458.svg?style=flat&logo=python-social-auth&logoColor=white)](https://www.statsmodels.org/stable/index.html)
[![Scikit-learn](https://img.shields.io/badge/scikit--learn-%23F7931E.svg?style=flat&logo=scikit-learn&logoColor=white)](https://scikit-learn.org/)
[![Jupyter](https://img.shields.io/badge/Jupyter-%23F37626.svg?style=flat&logo=Jupyter&logoColor=white)](https://jupyter.org/)
[![arXiv](https://img.shields.io/badge/arXiv-2508.12315-b31b1b.svg)](https://arxiv.org/abs/2508.12315)
[![Research](https://img.shields.io/badge/Research-Supply%20Chain%20Economics-green)](https://github.com/chirindaopensource/global_production_network_mapping)
[![Discipline](https://img.shields.io/badge/Discipline-Network%20Science%20%26%20Econometrics-blue)](https://github.com/chirindaopensource/global_production_network_mapping)
[![Methodology](https://img.shields.io/badge/Methodology-Network%20Inference-orange)](https://github.com/chirindaopensource/global_production_network_mapping)
[![Year](https://img.shields.io/badge/Year-2025-purple)](https://github.com/chirindaopensource/global_production_network_mapping)

**Repository:** `https://github.com/chirindaopensource/global_production_network_mapping`

**Owner:** 2025 Craig Chirinda (Open Source Projects)

This repository contains an **independent**, professional-grade Python implementation of the research methodology from the 2025 paper entitled **"Deciphering the global production network from cross-border firm transactions"** by:

*   Neave O'Clery
*   Ben Radcliffe-Brown
*   Thomas Spencer
*   Daniel Tarling-Hunter

The project provides a complete, end-to-end computational framework for transforming massive-scale, firm-level transaction data into a high-resolution, computable "digital twin" of the global production economy. It implements the paper's novel network inference algorithm, a full suite of network and econometric analyses, and a comprehensive set of validation and robustness checks. The goal is to provide a transparent, robust, and computationally efficient toolkit for researchers and policymakers to replicate, validate, and extend the paper's findings on global supply chain structures and economic diversification.

## Table of Contents

- [Introduction](#introduction)
- [Theoretical Background](#theoretical-background)
- [Features](#features)
- [Methodology Implemented](#methodology-implemented)
- [Core Components (Notebook Structure)](#core-components-notebook-structure)
- [Key Callable: main_analysis_orchestrator](#key-callable-main_analysis_orchestrator)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Input Data Structure](#input-data-structure)
- [Usage](#usage)
- [Output Structure](#output-structure)
- [Project Structure](#project-structure)
- [Customization](#customization)
- [Contributing](#contributing)
- [License](#license)
- [Citation](#citation)
- [Acknowledgments](#acknowledgments)

## Introduction

This project provides a Python implementation of the methodology presented in the 2025 paper "Deciphering the global production network from cross-border firm transactions." The core of this repository is the Jupyter notebook `global_production_network_mapping_draft.ipynb`. It contains the functions needed to replicate the paper's findings, from initial data validation through to a full suite of robustness checks.

The study of global supply chains has historically been constrained by a lack of granular data. This project implements the paper's approach, which uses a dataset of 1 billion cross-border firm-to-firm transactions covering 20 million firms to infer a detailed, product-level production network.

This codebase enables users to:
-   Rigorously validate and cleanse massive-scale transaction and firm metadata.
-   Implement the core network inference algorithm to build a weighted, directed product network.
-   Analyze the network's structure using community detection and centrality measures.
-   Perform multi-pronged validation of the inferred network against external benchmarks and statistical null models.
-   Engineer network-based econometric features to predict national economic diversification.
-   Execute the full Probit regression analysis with fixed effects.
-   Conduct a comprehensive suite of robustness checks to test the stability of the findings.

## Theoretical Background

The implemented methods are grounded in network science, econometrics, and economic complexity theory.

**1. Network Inference from Revealed Preference:**
The core of the methodology is to infer an input-output link from product `i` to product `j` not from direct input tables, but from the observed behavior of firms. The weight of a link, `A_ij`, is a measure of "excess purchase" or revealed preference. It is calculated as the ratio of the probability that a producer of `j` buys `i` to the baseline probability that any firm buys `i`.

$$
A_{i,j} = \frac{|S_i^j|/|S_j|}{|S_i^\dagger|/|S|}
$$

An `A_ij > 1` indicates that producers of `j` have a revealed preference for input `i`, suggesting a production linkage. This method effectively filters out ubiquitous inputs (like packaging) and highlights specific, technologically relevant connections.
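
To make the formula concrete, the sketch below evaluates `A_{i,j}` directly from Python sets on a toy example. It is a hypothetical, loop-based illustration: the function name `excess_purchase` and the product labels are made up, and it is not the notebook's vectorized implementation. It also assumes that `S` is the set of all firms in the sample.

```python
from typing import Dict, Set, Tuple


def excess_purchase(
    producers: Dict[str, Set[str]],
    purchasers: Dict[str, Set[str]],
    all_firms: Set[str],
) -> Dict[Tuple[str, str], float]:
    """Hypothetical helper: A[(i, j)] = (|S_i^j| / |S_j|) / (|S_i^dagger| / |S|)."""
    n_firms = len(all_firms)
    weights: Dict[Tuple[str, str], float] = {}
    for j, s_j in producers.items():              # j: output product, S_j: its producers
        if not s_j:
            continue
        for i, s_i_dagger in purchasers.items():  # i: candidate input, S_i^dagger: its buyers
            s_i_j = s_j & s_i_dagger              # producers of j that also buy i
            if not s_i_j:
                continue                          # no evidence for a link i -> j
            share_in_producers = len(s_i_j) / len(s_j)
            baseline_share = len(s_i_dagger) / n_firms
            weights[(i, j)] = share_in_producers / baseline_share
    return weights


# Toy economy: 8 firms, 2 of which make washing machines.
firms = {f"f{k}" for k in range(1, 9)}
producers = {"washing_machines": {"f1", "f2"}}
purchasers = {
    "aluminium_panels": {"f1", "f2", "f3"},  # bought mostly by washer makers
    "cardboard_boxes": set(firms),           # bought by everyone
}
weights = excess_purchase(producers, purchasers, firms)
for (i, j), a in sorted(weights.items()):
    print(f"{i} -> {j}: A = {a:.2f}")
```

Aluminium panels get `A = (2/2) / (3/8) ≈ 2.67`, while cardboard boxes get exactly `A = (2/2) / (8/8) = 1`. An `A > 1` rule therefore keeps the first link and drops the ubiquitous one.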

**2. Network Density and Economic Diversification:**
The project implements the "density" metric, a concept from economic complexity that measures a country's existing capabilities relevant to a new product. The network-derived upstream and downstream densities measure the proportion of a target product's key suppliers or customers, respectively, that a country already has a comparative advantage in.

$$
d_{p,c} = \frac{\sum_{j \in J_p} I(A_{p,j}) \cdot M_{j,c}}{\sum_{j \in J_p} I(A_{p,j})}
$$

where `J_p` is the set of top-k downstream partners of product `p` (the products `j` that use `p` as an input), and `M_{j,c}` is an indicator of country `c`'s presence in product `j`.
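
The following is a minimal NumPy sketch of the formula. It relies on two assumptions that the text leaves open. First, `J_p` is taken as the `k` strongest outgoing links of `p`, ranked by `A_{p,j}`. Second, `I(A_{p,j})` is set to 1 for every retained partner, so `d_{p,c}` becomes a plain share. Applying the same function to the transposed matrix gives the supplier-side (upstream) variant described above. The function name `downstream_density` is hypothetical.

```python
import numpy as np


def downstream_density(A: np.ndarray, M: np.ndarray, k: int) -> np.ndarray:
    """d[p, c]: share of p's top-k downstream partners in which country c is present.

    A[p, j] is the weight of the link from input p to output j (0 = no link);
    M[j, c] is 1 if country c is present in product j, else 0.
    """
    n_products, n_countries = A.shape[0], M.shape[1]
    d = np.zeros((n_products, n_countries))
    for p in range(n_products):
        row = A[p].astype(float)                  # astype returns a copy
        row[p] = 0.0                              # ignore self-links
        partners = np.flatnonzero(row > 0)        # products j with a link p -> j
        if partners.size == 0:
            continue                              # no partners: density left at 0
        J_p = partners[np.argsort(row[partners])[::-1][:k]]  # k strongest links
        indicator = np.ones(J_p.size)             # assumption: I(A_pj) = 1 for kept links
        d[p] = indicator @ M[J_p] / indicator.sum()
    return d


A = np.array([[0.0, 2.0, 1.5],
              [0.0, 0.0, 3.0],
              [1.2, 0.0, 0.0]])
M = np.array([[1, 0],
              [1, 1],
              [0, 1]])
print(downstream_density(A, M, k=2))    # downstream: customers of each product
print(downstream_density(A.T, M, k=2))  # upstream: suppliers of each product
```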

**3. Probit Model for Diversification:**
The final analysis uses a Probit model to test the hypothesis that higher network density predicts the probability of a country developing a new export capability in a product. The model includes country and product fixed effects to control for unobserved heterogeneity.

$$
\Pr(R_{p,c} = 1) = \Phi(\alpha + \beta_d d_{p,c} + \gamma_p + \eta_c)
$$
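
In this model, Φ is the standard normal CDF and `R_{p,c}` is a binary appearance indicator. The sketch below fits this specification with `statsmodels` on synthetic data. Country and product fixed effects enter as dummy variables through `C(...)` in the formula, and the script reports the average marginal effect of density. The column names and the data-generating process are invented for illustration, so this is not the notebook's estimation code.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)

# Synthetic country-product panel; all names and coefficients are made up.
countries = [f"c{k}" for k in range(20)]
products = [f"p{k}" for k in range(30)]
panel = pd.DataFrame([(c, p) for c in countries for p in products],
                     columns=["country", "product"])
panel["density"] = rng.uniform(0.0, 1.0, len(panel))
eta = dict(zip(countries, rng.normal(0.0, 0.2, len(countries))))   # country effects
gamma = dict(zip(products, rng.normal(0.0, 0.2, len(products))))   # product effects
latent = (-1.0 + 2.0 * panel["density"] + panel["country"].map(eta)
          + panel["product"].map(gamma) + rng.standard_normal(len(panel)))
panel["appeared"] = (latent > 0).astype(int)

# Pr(R_pc = 1) = Phi(alpha + beta_d * d_pc + gamma_p + eta_c); FE as dummy variables.
result = smf.probit("appeared ~ density + C(country) + C(product)", data=panel).fit(disp=0)
ame = result.get_margeff(at="overall").summary_frame()
print(f"beta_d = {result.params['density']:.2f}, AME = {ame.loc['density', 'dy/dx']:.3f}")
```

The raw coefficient `beta_d` is on the latent scale. The average marginal effect (`dy/dx`) is the change in predicted appearance probability per unit of density, which is the quantity usually reported for interpretation.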

## Features

The provided Jupyter notebook (`global_production_network_mapping_draft.ipynb`) implements the full research pipeline, including:

-   **Modular, Task-Based Architecture:** The entire pipeline is broken down into 11 distinct, modular tasks, from data validation to robustness checks.
-   **High-Performance Data Engineering:** Utilizes vectorized `pandas` and `numpy` operations to efficiently process and transform large datasets.
-   **Efficient Network Inference:** Implements the core `A_ij` formula and sparsification rules using performant, vectorized calculations.
-   **State-of-the-Art Network Analysis:** Employs the `leidenalg` library for robust community detection and `networkx` for standard centrality measures.
-   **Rigorous Statistical Validation:** Includes a parallelized Monte Carlo simulation framework for testing subgraph modularity against the configuration model.
-   **Professional-Grade Econometrics:** Implements the Probit model with fixed effects using `statsmodels`, including correct calculation of Average Marginal Effects for interpretation.
-   **Comprehensive Robustness Suite:** A full suite of advanced robustness checks to analyze the framework's sensitivity to parameters, temporal windows, and methodological choices.

## Methodology Implemented

The core analytical steps directly implement the methodology from the paper:

1.  **Input Data Validation (Task 1):** Ingests and rigorously validates all raw data and configuration files.
2.  **Data Preprocessing (Task 2):** Cleanses the transaction log and performs firm entity resolution.
3.  **Firm Classification (Task 3):** Identifies significant producer and purchaser sets for each product.
4.  **Network Inference (Task 4):** Computes the `A_ij` matrix and constructs the network objects.
5.  **Structural Analysis (Task 5):** Performs community detection and topological validation.
6.  **Centrality Calculation (Task 6):** Computes Betweenness and Hub Score centralities.
7.  **Network Validation (Task 7):** Validates the network against external data and a statistical null model.
8.  **Feature Engineering (Task 8):** Computes Rpop, the diversification outcome, and network density metrics.
9.  **Econometric Analysis (Task 9):** Estimates the final Probit models.
10. **Orchestration (Task 10):** Provides a master function to run the entire end-to-end pipeline.
11. **Robustness Analysis (Task 11):** Provides a master function to run the full suite of robustness checks.
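
Task 8 relies on Rpop. The pandas sketch below assumes the population-normalised definition common in the economic-complexity literature, `Rpop_{c,p} = (X_{c,p} / Σ_c X_{c,p}) / (N_c / Σ_c N_c)`. Here `X_{c,p}` is the export value of country `c` in product `p` and `N_c` is its population. The appearance rule and its thresholds (`start_max`, `end_min`) are hypothetical placeholders, because the paper's exact cut-offs are not stated here.

```python
import pandas as pd


def rpop(exports: pd.DataFrame, population: pd.Series) -> pd.DataFrame:
    """exports: countries x products export values; population: indexed by country."""
    product_share = exports / exports.sum(axis=0)       # X_cp / sum_c X_cp
    population_share = population / population.sum()   # N_c / sum_c N_c
    return product_share.div(population_share.reindex(exports.index), axis=0)


def appeared(rpop_start: pd.DataFrame, rpop_end: pd.DataFrame,
             start_max: float = 0.5, end_min: float = 1.0) -> pd.DataFrame:
    """Hypothetical outcome: 1 if Rpop was low at the start and high at the end."""
    return ((rpop_start < start_max) & (rpop_end >= end_min)).astype(int)


population = pd.Series({"A": 1.0, "B": 3.0})
exports_2016 = pd.DataFrame({"p1": [10.0, 10.0], "p2": [0.0, 20.0]}, index=["A", "B"])
exports_2021 = pd.DataFrame({"p1": [10.0, 10.0], "p2": [15.0, 20.0]}, index=["A", "B"])
r16, r21 = rpop(exports_2016, population), rpop(exports_2021, population)
print(r16.round(2))
print(appeared(r16, r21))
```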

## Core Components (Notebook Structure)

The `global_production_network_mapping_draft.ipynb` notebook is structured as a logical pipeline with modular orchestrator functions for each of the 11 major tasks.

## Key Callable: main_analysis_orchestrator

The central function in this project is `main_analysis_orchestrator`. It orchestrates the entire analytical workflow, providing a single entry point for running the baseline study replication and the advanced robustness checks.

```python
from typing import Any, Dict

import pandas as pd


def main_analysis_orchestrator(
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    # ... other data inputs
    base_manifest: Dict[str, Any],
    run_robustness_checks: bool = True,
    # ... other robustness configurations
) -> Dict[str, Any]:
    """
    Serves as the top-level entry point for the entire research project.
    """
    # ... (implementation is in the notebook)
```

## Prerequisites

-   Python 3.8+
-   Core dependencies: `pandas`, `numpy`, `scipy`, `networkx`, `statsmodels`, `scikit-learn`, `python-igraph`, `leidenalg`, `joblib`.

## Installation

1.  **Clone the repository:**
    ```sh
    git clone https://github.com/chirindaopensource/global_production_network_mapping.git
    cd global_production_network_mapping
    ```

2.  **Create and activate a virtual environment (recommended):**
    ```sh
    python -m venv venv
    source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
    ```

3.  **Install Python dependencies:**
    ```sh
    pip install pandas numpy scipy networkx statsmodels scikit-learn python-igraph leidenalg joblib
    ```

## Input Data Structure

The pipeline requires four `pandas` DataFrames and two Python dictionaries with specific structures, which are rigorously validated by the first task.
1.  `transactions_log_frame`: Contains transaction-level data.
2.  `firm_metadata_frame`: Contains firm-level metadata.
3.  `comtrade_exports_frame`: Contains country-product level export data.
4.  `country_data_frame`: Contains country-level population data.
5.  `supply_chains_definitions_dict`: Defines product sets for validation.
6.  `replication_manifest`: A comprehensive dictionary controlling all parameters.

A fully specified example of all inputs is provided in the main notebook.

## Usage

The `global_production_network_mapping_draft.ipynb` notebook provides a complete, step-by-step guide. The core workflow is:

1.  **Prepare Inputs:** Load your data `DataFrame`s and define your configuration dictionaries. A complete template is provided.
2.  **Execute Pipeline:** Call the master orchestrator function.

    ```python
    # This single call runs the baseline analysis and all configured robustness checks.
    final_results = main_analysis_orchestrator(
        transactions_log_frame=transactions_df,
        firm_metadata_frame=firms_df,
        comtrade_exports_frame=comtrade_df,
        country_data_frame=country_df,
        supply_chains_definitions_dict=supply_chains,
        base_manifest=replication_manifest,
        run_robustness_checks=True,
        parameter_grid=param_grid,
        methods_to_test=methods_list
    )
    ```
3.  **Inspect Outputs:** Programmatically access any result from the returned dictionary. For example, to view the temporal robustness results:
    ```python
    temporal_df = final_results['robustness_results']['temporal_robustness']
    print(temporal_df.head())
    ```

## Output Structure

The `main_analysis_orchestrator` function returns a single, comprehensive dictionary with two top-level keys:
-   `baseline_results`: A dictionary containing all artifacts from the primary study replication (network objects, analysis DataFrames, econometric models, etc.).
-   `robustness_results`: A dictionary containing the summary DataFrames from each of the executed robustness checks.

## Project Structure

```
global_production_network_mapping/
│
├── global_production_network_mapping_draft.ipynb  # Main implementation notebook
├── requirements.txt                                 # Python package dependencies
├── LICENSE                                          # MIT license file
└── README.md                                        # This documentation file
```

## Customization

The pipeline is highly customizable via the `replication_manifest` dictionary and the arguments to the `main_analysis_orchestrator`. Users can easily modify all relevant parameters for the baseline run or define custom scenarios for the robustness checks.

## Contributing

Contributions are welcome. Please fork the repository, create a feature branch, and submit a pull request with a clear description of your changes. Adherence to PEP 8, type hinting, and comprehensive docstrings is required.

## License

This project is licensed under the MIT License. See the `LICENSE` file for details.

## Citation

If you use this code or the methodology in your research, please cite the original paper:

```bibtex
@article{oclery2025deciphering,
  title={Deciphering the global production network from cross-border firm transactions},
  author={O'Clery, Neave and Radcliffe-Brown, Ben and Spencer, Thomas and Tarling-Hunter, Daniel},
  journal={arXiv preprint arXiv:2508.12315},
  year={2025}
}
```

For the implementation itself, you may cite this repository:
```
Chirinda, C. (2025). A Python Implementation of "Deciphering the global production network from cross-border firm transactions".
GitHub repository: https://github.com/chirindaopensource/global_production_network_mapping
```

## Acknowledgments

-   Credit to Neave O'Clery, Ben Radcliffe-Brown, Thomas Spencer, and Daniel Tarling-Hunter for their innovative and clearly articulated research.
-   Thanks to the developers of the scientific Python ecosystem (`numpy`, `pandas`, `scipy`, `networkx`, `statsmodels`, etc.) for their powerful open-source tools.

---

*This README was generated based on the structure and content of `global_production_network_mapping_draft.ipynb` and follows best practices for research software documentation.*

# Paper

Title: "*Deciphering the global production network from cross-border firm transactions*"

Authors: Neave O'Clery, Ben Radcliffe-Brown, Thomas Spencer, Daniel Tarling-Hunter

E-Journal Submission Date: 17 August 2025

Link: https://arxiv.org/abs/2508.12315

Abstract:

Critical for policy-making and business operations, the study of global supply chains has been severely hampered by a lack of detailed data. Here we harness global firm-level transaction data covering 20m global firms, and 1 billion cross-border transactions, to infer key inputs for over 1200 products. Transforming this data to a directed network, we find that products are clustered into three large groups including textiles, chemicals and food, and machinery and metals. European industrial nations and China dominate critical intermediate products in the network such as metals, common components and tools, while industrial complexity is correlated with embeddedness in densely connected supply chains. To validate the network, we find structural similarities with two alternative product networks, one generated via LLM queries and the other derived by NAFTA to track product origins. We further detect linkages between products identified in manually mapped single sector supply chains, including electric vehicle batteries and semi-conductors. Finally, metrics derived from network structure capturing both forward and backward linkages are able to predict country-product diversification patterns with high accuracy.

# Summary

### Summary of "Deciphering the global production network from cross-border firm transactions"

**1. Core Problem & Novel Contribution:**
*   **Problem:** Traditional global supply chain analysis is hampered by highly aggregated input-output (IO) tables (e.g., WIOD, WIOT) or fragmented, manually mapped sector-specific data. This limits granular insights into inter-product linkages and the propagation of shocks.
*   **Novelty:** The authors address this by constructing a *directed, weighted product-level supply chain network* derived from *global, firm-level, cross-border transaction data*. This moves beyond aggregate sector-level or single-country firm-level analyses.

**2. Data & Pre-processing (Computational & Data Engineering Aspect):**
*   **Source:** UK Government Global Supply Chain Intelligence Programme (GSCIP) data.
*   **Scale:** Covers 20 million global firms and 1 billion cross-border transactions from 2021-2023. This is orders of magnitude larger than typical datasets.
*   **Attributes:** Includes detailed product codes (Harmonised System - HS 2022), transaction values/weights, firm location, industry code, ownership structure.
*   **Key Pre-processing:**
    *   Conversion of all transactions to HS 2022 for consistency.
    *   **Crucially, aggregation of firms to an 'owner-country' level** using GSCIP's ownership hierarchy data. This addresses multinational structures and avoids double-counting within a conglomerate, though it's noted that local linkages (within a country) are not captured.
    *   Filtering out firms exporting products in more than five HS Sections to mitigate noise from multi-product firms (e.g., wholesalers).
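
The five-Section rule reduces to a group-wise count. This sketch makes three assumptions. Each transaction row already carries a hypothetical `hs_section` column (the HS chapter-to-Section mapping is not shown). Entities are the owner-country keys. Only the flagged entities' sales are removed, because whether they should also be dropped as buyers is not specified above.

```python
import pandas as pd


def drop_multi_section_sellers(tx: pd.DataFrame, max_sections: int = 5) -> pd.DataFrame:
    """Drop transactions whose seller exports in more than `max_sections` HS Sections."""
    sections = tx.groupby("seller_entity")["hs_section"].nunique()
    keep = sections.index[sections <= max_sections]
    return tx[tx["seller_entity"].isin(keep)].reset_index(drop=True)


tx = pd.DataFrame({
    "seller_entity": ["wholesaler|NL"] * 6 + ["maker|DE"] * 2,
    "hs_section": ["I", "IV", "VI", "XI", "XV", "XVI", "XV", "XVI"],
    "value": [1.0] * 8,
})
filtered = drop_multi_section_sellers(tx)
print(filtered["seller_entity"].unique())  # only 'maker|DE' remains
```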

**3. Methodology for Inferring Product Linkages (Econometrics & Network Science):**
*   **"Excess Purchase" Metric (A_i,j):** This is the core innovation for inferring input importance. For any product *j* (the output) and potential input *i*, the metric `A_i,j` quantifies the "excess" purchase of *i* by firms that produce *j*, relative to the purchase of *i* by an average firm.
    *   **Formula:** `A_i,j = (|S_i^j|/|S_j|) / (|S_i^†|/|S|)`
        *   `S_j`: Set of firms producing product *j*.
        *   `S_i^j`: Subset of firms in `S_j` that also purchase product *i*.
        *   `S_i^†`: Set of all firms that purchase product *i*.
        *   `S`: Set of all firms.
    *   **Interpretation:** `A_i,j > 1` indicates product *i* is more prevalent in the purchasing basket of firms producing *j* than in the general firm population. This effectively down-weights ubiquitous inputs (e.g., cardboard boxes) and highlights specific inputs (e.g., aluminum panels for washing machines).
*   **Network Construction:** A directed, weighted network of 1228 products (nodes) is formed, where `A_i,j` represents the edge weight from input *i* to output *j*. Thresholds are applied for `A_i,j > 1` and a minimum `firmcount` (number of firms providing evidence for the link) to control sparsity.
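
Both thresholds amount to a boolean mask over an edge list. The sketch assumes a table with three hypothetical columns: `weight` (`A_i,j`), `firmcount` (the number of firms supporting the link) and `value` (for the value-based filter that the notebook's Task 4 also applies). The default cut-offs are placeholders, not the paper's values.

```python
import pandas as pd


def sparsify(edges: pd.DataFrame, min_weight: float = 1.0,
             min_firmcount: int = 5, min_value: float = 0.0) -> pd.DataFrame:
    """Keep edges with A_ij > min_weight, enough supporting firms, and enough value."""
    mask = (
        (edges["weight"] > min_weight)
        & (edges["firmcount"] >= min_firmcount)
        & (edges["value"] >= min_value)
    )
    return edges.loc[mask].reset_index(drop=True)


edges = pd.DataFrame({
    "input": ["aluminium", "boxes", "chips"],
    "output": ["washer", "washer", "washer"],
    "weight": [2.7, 1.0, 3.5],
    "firmcount": [40, 900, 2],
    "value": [1e6, 5e5, 1e4],
})
print(sparsify(edges))  # only aluminium -> washer survives
```

Boxes fail the `A > 1` test despite heavy evidence, while chips pass the weight test but rest on only two firms. This is the "weak link" trade-off that the firmcount threshold controls.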

**4. Key Findings & Network Structure (Network Science & Economic Geography):**
*   **Community Structure:** Using a multi-scale community detection algorithm (based on random walkers, Delvenne et al. [2010]), the network broadly segregates into three main clusters:
    1.  **Textiles:** Highly isolated.
    2.  **Chemicals & Food:** Interconnected.
    3.  **Machinery & Metals:** Interconnected.
*   **Critical Intermediate Products ("Choke Points"):**
    *   **Betweenness Centrality (BC):** Products with high BC (e.g., various machinery, common components like pumps/blades/taps/motors, plastics, chemical binders) act as key junctures in many supply paths.
    *   **Country Dominance:** European industrial nations (Germany, Italy, Austria) and China dominate the export of products with high BC, indicating their control over critical intermediate goods.
*   **Industrial Complexity & Embeddedness:**
    *   **Hub Score:** The hub score of the HITS algorithm, a measure related to eigenvector centrality that captures "neighbor of neighbor" linkages and embeddedness in dense supply chains.
    *   **Correlation:** Product Complexity Index (PCI) correlates positively with hub score (rank correlation 0.46). At the country level, Economic Complexity Index (ECI) correlates strongly with mean hub score (rank correlation 0.85). This suggests complex products/countries are deeply embedded in densely connected global supply chains.
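
Both centralities can be computed with standard tools, as the hedged sketch below shows. Betweenness uses `networkx.betweenness_centrality`, which treats edge weights as distances, so the sketch assumes distance `= 1 / A_i,j` (stronger links are shorter hops). The hub score is the principal eigenvector of `A Aᵀ`, which is the HITS hub vector. It is computed here directly with NumPy rather than with the notebook's own routine. The product names are illustrative.

```python
import networkx as nx
import numpy as np


def hub_scores(G: nx.DiGraph, weight: str = "weight") -> dict:
    """HITS hub scores: principal eigenvector of A @ A.T, normalised to sum to 1."""
    nodes = list(G)
    A = nx.to_numpy_array(G, nodelist=nodes, weight=weight)
    _, eigvecs = np.linalg.eigh(A @ A.T)  # symmetric matrix; eigenvalues ascending
    h = np.abs(eigvecs[:, -1])            # Perron vector, sign fixed with abs
    return dict(zip(nodes, h / h.sum()))


G = nx.DiGraph()
G.add_weighted_edges_from([
    ("steel", "pump", 2.0), ("steel", "valve", 1.5), ("plastic", "pump", 2.5),
    ("pump", "washer", 3.0), ("valve", "washer", 1.2),
])
for _, _, data in G.edges(data=True):
    data["distance"] = 1.0 / data["weight"]  # assumption: strong link = short hop

betweenness = nx.betweenness_centrality(G, weight="distance")
hubs = hub_scores(G)
print(max(betweenness, key=betweenness.get))  # 'pump': steel->washer and plastic->washer paths pass through it
```

Rank correlations with PCI or ECI, as reported above, could then be computed with `scipy.stats.spearmanr` on the resulting scores.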

**5. Validation of Network Structure (Computer Science & Econometrics):**
*   **Comparison with AI-generated Network:** Structural similarities were found with a product network generated via LLM queries (Fetzer et al. [2024]). While the AI network is sparser, both show similar meso-scale structures and significant in/out-degree correlations (peaking at ~0.68 for in-degree).
*   **Comparison with NAFTA IO Network:** Similar aggregate clusters (food, textiles, chemicals, metals/machinery) were observed, despite lower direct edge correlations.
*   **Validation against Manually Mapped Supply Chains:** The network successfully picks up dense linkages within manually constructed single-sector supply chains (e.g., electric vehicle batteries, semiconductors). A rigorous statistical test (comparing subgraph modularity to 100k random networks) yielded extremely low p-values (~1e-27 to 1e-35), indicating these observed linkages are highly unlikely to occur by chance.
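
The following is a minimal NumPy sketch of such a test. It assumes an unweighted edge list and a product set of interest, and it counts internal edges rather than computing weighted modularity. Shuffling the edge targets while keeping the sources fixed is stub matching for the directed configuration model: every node keeps its in- and out-degree, although self-loops and multi-edges may appear. The directed modularity contribution of the subset is `(e_S - K_out(S) K_in(S) / m) / m`. Because the degrees are fixed, this contribution changes only through the internal edge count `e_S`, so the sketch tests `e_S` directly. An empirical p-value from `n` draws cannot fall below `1/(n+1)`, so p-values as small as those quoted above cannot come from raw draw counts alone. This sketch stops at the empirical p-value. The function names are hypothetical.

```python
import numpy as np


def internal_edges(src: np.ndarray, dst: np.ndarray, members: np.ndarray) -> int:
    """Number of edges with both endpoints inside the product set."""
    return int(np.count_nonzero(members[src] & members[dst]))


def configuration_null_test(src, dst, members, n_samples=1000, seed=0):
    """Empirical p-value of the observed internal edge count vs. stub-matched nulls."""
    rng = np.random.default_rng(seed)
    observed = internal_edges(src, dst, members)
    null = np.array([
        internal_edges(src, rng.permutation(dst), members) for _ in range(n_samples)
    ])
    p_value = (1 + np.count_nonzero(null >= observed)) / (1 + n_samples)
    return observed, null.mean(), p_value


# Toy network: products 0-7 form a fully connected "supply chain",
# products 8-39 are wired in a sparse circulant pattern.
members = np.zeros(40, dtype=bool)
members[:8] = True
pairs = [(i, j) for i in range(8) for j in range(8) if i != j]
pairs += [(i, 8 + (i - 8 + step) % 32) for i in range(8, 40) for step in (1, 5, 11)]
src, dst = map(np.array, zip(*pairs))
observed, null_mean, p_value = configuration_null_test(src, dst, members)
print(observed, round(null_mean, 1), p_value)
```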

**6. Predictive Power for Economic Diversification (Econometrics):**
*   **Density Metrics:** "Downstream" and "Upstream" density metrics are constructed, capturing the share of top in-degree (downstream) or out-degree (upstream) neighboring products present in a country.
*   **Probit Regression:** A Probit model predicts country-product export diversification patterns (product appearance, defined by Rpop increase from 2016 to 2021).
*   **Results:** Both downstream and upstream linkages show high predictive power (AUC values of 0.86-0.87), consistent with existing literature on capability overlap. Interestingly, predictive power is highest for "weak links" (low firmcount threshold) and for predicting smaller increases in Rpop, suggesting that even less certain connections are informative.
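
AUC is the probability that the model ranks a randomly chosen product that did appear above one that did not. The notebook imports `sklearn.metrics.roc_auc_score` for this. The hypothetical helper below computes the same pairwise quantity by hand, with ties counting one half, which makes the meaning of values like 0.86-0.87 explicit.

```python
import numpy as np


def pairwise_auc(y_true, scores) -> float:
    """P(score of a random positive > score of a random negative), ties = 1/2."""
    y = np.asarray(y_true, dtype=bool)
    s = np.asarray(scores, dtype=float)
    diff = s[y][:, None] - s[~y][None, :]  # every positive-negative pair
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


outcome = [0, 0, 1, 1, 0, 1]
predicted = [0.10, 0.40, 0.35, 0.80, 0.20, 0.90]
print(round(pairwise_auc(outcome, predicted), 3))  # 8 of 9 pairs ranked correctly
```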

**7. Limitations & Future Directions:**
*   **Static Analysis:** The current network is aggregated over 2021-2023, precluding temporal analysis of supply chain evolution.
*   **Cross-Border Focus:** Misses crucial local linkages within countries.
*   **Data Noise:** Acknowledges inherent noise from uneven firm coverage, multi-product firms (despite mitigation efforts), and idiosyncratic purchasing patterns (e.g., semiconductor firms buying washing machines for chips during shortages).
*   **Granularity:** Currently at HS 4-digit; moving to 6-digit codes is computationally challenging because the number of potential edges grows quadratically with the number of products.
*   **Future Work:** Temporal analysis, regional-specific networks, higher granularity, and more sophisticated validation.

**Overall Significance:**
This paper provides an unprecedented, data-driven, and rigorously validated tool for understanding the complex global production network. Its insights into critical intermediate products, the relationship between industrial complexity and supply chain embeddedness, and its predictive power for diversification patterns have significant implications for policy-making (e.g., identifying choke points, assessing vulnerabilities, guiding industrial policy) and for advancing economic geography and network science research.

# Import Essential Modules

```python
#!/usr/bin/env python3
# ==============================================================================#
#
#  A High-Resolution Digital Twin of the Global Production Network
#
#  This module provides a complete, professional-grade, and end-to-end
#  implementation of the research pipeline presented in "Deciphering the global
#  production network from cross-border firm transactions" by O'Clery et al. (2025).
#  It transforms massive-scale, firm-level transaction data into a computable,
#  directed, and weighted product-space network. This "digital twin" of the
#  global production economy enables quantitative analysis of supply chain
#  structures, identification of critical products, and econometric prediction
#  of national economic diversification patterns.
#
#  Core Methodological Components:
#  • Network Inference from Firm Behavior: A novel statistical method to infer
#    product-level input-output links from the "excess purchase" patterns of
#    producer firms, using a dataset of 1 billion transactions.
#  • Network Structure Analysis: Application of community detection (Leiden
#    algorithm) and centrality measures (Betweenness, HITS) to uncover the
#    meso- and micro-scale architecture of the global production system.
#  • Multi-Pronged Validation: Rigorous validation of the inferred network via
#    comparison with external benchmarks and statistical testing against a
#    configuration model null hypothesis for known supply chains (EVs, semiconductors).
#  • Econometric Prediction: Engineering of network-based "density" metrics
#    that capture a country's proximity to new products in the network, and
#    using these metrics in a Probit model to predict economic diversification.
#
#  Technical Implementation Features:
#  • High-performance, vectorized data processing using Pandas and NumPy.
#  • Efficient sparse matrix algebra (SciPy) for network construction and analysis.
#  • State-of-the-art community detection using the Leiden algorithm.
#  • Robust econometric modeling with fixed effects using Statsmodels.
#  • A comprehensive, parallelized robustness analysis suite to test sensitivity
#    to key methodological parameters and temporal windows.
#
#  Paper Reference:
#  O'Clery, N., Radcliffe-Brown, B., Spencer, T., & Tarling-Hunter, D. (2025).
#  Deciphering the global production network from cross-border firm transactions.
#  arXiv preprint arXiv:2508.12315.
#  https://arxiv.org/abs/2508.12315
#
#  Author: CS Chirinda
#  License: MIT
#  Version: 1.0.0
#
# ==============================================================================#

# --- Standard Library Imports ---
import logging
import warnings
import itertools
from copy import deepcopy
from multiprocessing import Pool, cpu_count
from typing import Dict, Any, List, Tuple, Set, Optional

# --- Third-Party Library Imports ---

# Core Data Manipulation and Numerical Computing
import numpy as np
import pandas as pd

# Network Analysis
import networkx as nx
import igraph as ig
import leidenalg as la

# Scientific and Statistical Computing
from scipy.sparse import csr_matrix
from scipy.stats import pearsonr

# Econometric Modeling
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.discrete.discrete_model import ProbitResults

# Machine Learning and Model Evaluation
from sklearn.metrics import roc_auc_score, roc_curve

# Parallel Processing
from joblib import Parallel, delayed
```


# Implementation

## Draft 1

### Documentation of Inputs, Processes and Outputs (IPO Analysis) of Key Pipeline Callables

### **Module: Core Pipeline and Analysis**

#### **Callable: `validate_and_assess_inputs`**
*   **Inputs:** `transactions_log_frame`, `firm_metadata_frame`, `comtrade_exports_frame`, `country_data_frame`, `replication_manifest`. These are the raw, unprocessed data sources and the main configuration file.
*   **Process:** This function is a non-mutating gatekeeper. It performs a series of checks:
    1.  Validates the structure and values of the `replication_manifest` dictionary.
    2.  Validates the schema (column names and dtypes) of all input DataFrames.
    3.  Assesses data quality by checking for missing values in critical columns, verifying the completeness of ownership data, and confirming the temporal and product code ranges in the transaction data.
*   **Outputs:** `None`. The function's purpose is to either pass silently, confirming the inputs are valid, or to raise a `DataValidationError`, halting execution before any processing begins.
*   **Data Transformation:** No data is transformed. This is a read-only validation step.
*   **Role in Research Pipeline:** This function implements the implicit but essential first step of any rigorous quantitative research: **Data Integrity Verification**. It ensures that the raw data conforms to the expected format and quality standards before being used in the analysis, preventing a wide range of downstream errors.
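
The gatekeeper pattern can be sketched as follows. The column names, the dtype vocabulary (`numeric`, `datetime`, `any`) and the helper `validate_frame` are hypothetical. Only the name `DataValidationError` comes from the description above, and its definition here is assumed.

```python
from typing import Callable, Dict, Iterable

import pandas as pd


class DataValidationError(ValueError):
    """Raised when an input frame fails a schema or quality check."""


DTYPE_CHECKS: Dict[str, Callable] = {
    "numeric": pd.api.types.is_numeric_dtype,
    "datetime": pd.api.types.is_datetime64_any_dtype,
    "any": lambda dtype: True,
}


def validate_frame(df: pd.DataFrame, schema: Dict[str, str],
                   critical: Iterable[str] = ()) -> None:
    """Pass silently or raise DataValidationError; never mutates `df`."""
    missing = sorted(set(schema) - set(df.columns))
    if missing:
        raise DataValidationError(f"missing columns: {missing}")
    for column, kind in schema.items():
        if not DTYPE_CHECKS[kind](df[column].dtype):
            raise DataValidationError(f"column {column!r} is not {kind}")
    for column in critical:
        n_null = int(df[column].isna().sum())
        if n_null:
            raise DataValidationError(f"column {column!r} has {n_null} missing values")


schema = {"transaction_date": "datetime", "seller_firm_id": "any", "value_usd": "numeric"}
good = pd.DataFrame({
    "transaction_date": pd.to_datetime(["2021-05-01", "2022-07-15"]),
    "seller_firm_id": ["f1", "f2"],
    "value_usd": [1200.0, 350.0],
})
validate_frame(good, schema, critical=["seller_firm_id", "value_usd"])  # passes silently
```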

#### **Callable: `preprocess_and_cleanse_data`**
*   **Inputs:** `transactions_log_frame`, `firm_metadata_frame`, `replication_manifest`.
*   **Process:** This function executes the primary data engineering and entity resolution steps.
    1.  **Filtering:** It filters the transaction log to the specified date range (2021-2023) and removes records with non-positive values.
    2.  **Entity Resolution:** It maps raw `firm_id`s to `ultimate_owner_id`s, handling missing owners by treating firms as self-owned.
    3.  **Augmentation:** It merges the transaction log with firm metadata to attach the country of domicile to both seller and buyer.
    4.  **Composite Key Creation:** It creates the primary analytical unit, the `owner_country_entity`, by concatenating owner ID and country.
    5.  **Business Rule Filtering:** It removes domestic transactions and transactions from highly diversified firms (wholesalers), as specified in the "Methods" section (Page 14).
*   **Outputs:** A single, cleansed `pd.DataFrame` (`cleansed_df`) where each row is a valid, cross-border transaction between two well-defined owner-country entities.
*   **Data Transformation:** This is a heavy transformation step involving row filtering, value mapping, column joining, and string manipulation.
*   **Role in Research Pipeline:** This function implements the **Data Preparation and Entity Resolution** stage. It transforms the raw transaction log into the specific, analysis-ready format required by the paper's novel network inference methodology.
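
Steps 1-4 and the domestic-trade part of step 5 map onto a few pandas operations. The column names (`date`, `seller_id`, `buyer_id`, `value`, `country`), the `|` separator and the function name are assumptions for illustration. The wholesaler filter is left out.

```python
import pandas as pd


def to_owner_country_transactions(tx: pd.DataFrame, firms: pd.DataFrame,
                                  start: str = "2021-01-01",
                                  end: str = "2023-12-31") -> pd.DataFrame:
    # 1. Keep the analysis window and strictly positive values.
    in_window = tx["date"].between(pd.Timestamp(start), pd.Timestamp(end))
    tx = tx[in_window & (tx["value"] > 0)]
    # 2. Resolve owners; firms without an owner are treated as self-owned.
    owners = firms["ultimate_owner_id"].fillna(firms["firm_id"])
    lookup = firms.assign(owner=owners).set_index("firm_id")[["owner", "country"]]
    # 3. Attach owner and country of domicile to both sides of each transaction.
    tx = (tx.join(lookup.add_prefix("seller_"), on="seller_id")
            .join(lookup.add_prefix("buyer_"), on="buyer_id"))
    # 4. Composite owner-country keys.
    tx["seller_entity"] = tx["seller_owner"] + "|" + tx["seller_country"]
    tx["buyer_entity"] = tx["buyer_owner"] + "|" + tx["buyer_country"]
    # 5. Drop domestic trade (wholesaler filter omitted in this sketch).
    return tx[tx["seller_country"] != tx["buyer_country"]].reset_index(drop=True)


firms = pd.DataFrame({
    "firm_id": ["a1", "a2", "b1", "c1"],
    "ultimate_owner_id": ["A", "A", None, "C"],
    "country": ["GB", "DE", "GB", "FR"],
})
tx = pd.DataFrame({
    "date": pd.to_datetime(["2020-06-01", "2021-03-01", "2022-05-01",
                            "2023-01-01", "2022-02-01", "2022-08-01"]),
    "seller_id": ["a1", "a1", "a2", "b1", "b1", "a1"],
    "buyer_id": ["c1", "b1", "c1", "a1", "a2", "c1"],
    "value": [10.0, 5.0, 7.0, 3.0, 4.0, -1.0],
})
out = to_owner_country_transactions(tx, firms)
print(out[["seller_entity", "buyer_entity", "value"]])
```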

#### **Callable: `classify_firms_and_construct_sets`**
*   **Inputs:** The `cleansed_df` from the previous step.
*   **Process:**
    1.  It calls `_identify_significant_entities` twice. This helper function first aggregates transaction values to the (entity, product) level. It then calculates the per-product average of these total transaction values. Finally, it identifies an entity as a "significant" producer (or purchaser) of a product if its total value exceeds this average.
    2.  It then calls `_compute_intersection_metrics`, which calculates for every product pair `(i, j)` the size of the intersection set `S_i^j` (firms that produce `j` and purchase `i`) and the total value of product `i` purchased by firms in that set.
*   **Outputs:** A tuple containing three data structures: `producer_sets` (Dict), `purchaser_sets` (Dict), and `intersection_metrics` (DataFrame).
*   **Data Transformation:** This function transforms the transaction log into aggregated, set-based structures that form the direct inputs for the network inference formula.
*   **Role in Research Pipeline:** This function implements the **Firm Classification** logic described in the "Results" section (Page 5). It operationalizes the core assumption of the paper: that the purchasing behavior of significant producers can be used to reveal input-output relationships.
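
The "significant entity" rule and the intersection set `S_i^j` can be illustrated with two pandas groupbys. The helpers below are hypothetical stand-ins for `_identify_significant_entities` and `_compute_intersection_metrics`, not their actual implementations. The entity column names are assumed, and the strict "exceeds the average" comparison follows the description above.

```python
from typing import Dict, Hashable, Set

import pandas as pd


def significant_entities(tx: pd.DataFrame, entity_col: str) -> Dict[Hashable, Set[str]]:
    """Per product, the entities whose total value exceeds that product's average entity total."""
    totals = (
        tx.groupby([entity_col, "product_hs_code"])["transaction_value_usd"]
        .sum()
        .reset_index()
    )
    # Average of the (entity, product) totals, computed within each product.
    product_avg = totals.groupby("product_hs_code")["transaction_value_usd"].transform("mean")
    significant = totals[totals["transaction_value_usd"] > product_avg]
    return significant.groupby("product_hs_code")[entity_col].apply(set).to_dict()


def intersection_size(
    producers: Dict[Hashable, Set[str]],
    purchasers: Dict[Hashable, Set[str]],
    i: Hashable,
    j: Hashable,
) -> int:
    """|S_i^j|: significant producers of j that are also significant purchasers of i."""
    return len(producers.get(j, set()) & purchasers.get(i, set()))
```

Calling `significant_entities` once with the seller key and once with the buyer key mirrors the "called twice" structure described above.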

#### **Callable: `infer_and_construct_network`**
*   **Inputs:** `cleansed_df`, `producer_sets`, `purchaser_sets`, `intersection_metrics`, `replication_manifest`.
*   **Process:**
    1.  It calculates the raw edge weight `A_ij` for every potential link using the paper's core formula.
    2.  It applies the three sparsification filters: `firmcount` threshold, edge weight threshold, and minimum transaction value threshold.
    3.  It constructs the final network representations from the filtered list of edges.
*   **Outputs:** A tuple containing the final network objects: `adj_matrix` (a `csr_matrix`), `graph` (a `networkx.DiGraph`), and `product_to_idx` (a mapping dictionary).
*   **Data Transformation:** This function transforms the set-based data into a weighted adjacency matrix and a graph object.
*   **Role in Research Pipeline:** This function is the heart of the paper. It implements the **Network Inference Formula** from the "Results" section (Page 6):
    $$
    A_{i,j} = \frac{|S_i^j|/|S_j|}{|S_i^\dagger|/|S|}
    $$
    It also implements the **Network Sparsification** rules described in the "Methods" and "Results" sections (Pages 7 & 15) to refine the network.
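
As an illustrative worked example, suppose 12 of the 40 significant producers of product `j` are also significant purchasers of `i`, while 150 of 1,000 entities overall purchase `i`. Then `A_ij = (12/40) / (150/1000) = 0.3 / 0.15 = 2.0`, so producers of `j` buy `i` at twice the baseline rate. The sketch below encodes the formula and the three filters. Several parts are our assumptions rather than confirmed behaviour: reading `S_i^†` as the significant purchasers of `i` and `S` as all entities, treating the firmcount as `|S_i^j|`, using inclusive `>=` comparisons, and taking the default thresholds from the example manifest (`primary_edge_weight_threshold`, `primary_firmcount_threshold`, `min_aggregated_link_value_usd`).

```python
def edge_weight(n_ij: int, n_j: int, n_i_dagger: int, n_total: int) -> float:
    """A_ij = (|S_i^j| / |S_j|) / (|S_i^dagger| / |S|); returns 0.0 when undefined."""
    if n_j == 0 or n_i_dagger == 0 or n_total == 0:
        return 0.0
    return (n_ij / n_j) / (n_i_dagger / n_total)


def keep_edge(
    weight: float,
    firmcount: int,
    link_value_usd: float,
    min_weight: float = 2.0,
    min_firms: int = 2,
    min_value_usd: float = 1000.0,
) -> bool:
    """Apply the three sparsification filters (inclusive thresholds are an assumption)."""
    return (
        weight >= min_weight
        and firmcount >= min_firms
        and link_value_usd >= min_value_usd
    )
```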

#### **Callable: `analyze_network_structure`**
*   **Inputs:** `graph`, `adj_matrix`, `product_to_idx`, `replication_manifest`.
*   **Process:**
    1.  It performs multi-scale community detection using the Leiden algorithm, which this implementation uses in place of the "Stability" algorithm mentioned in the paper.
    2.  It characterizes the resulting communities with descriptive labels.
    3.  It validates the network's topology by calculating basic statistics and comparing the intra-sector vs. inter-sector link densities.
*   **Outputs:** A `pd.Series` (`community_labels`) mapping each product to its assigned community.
*   **Data Transformation:** It transforms the graph structure into a node-level feature (community membership).
*   **Role in Research Pipeline:** This function implements the **Meso-scale Structural Analysis** described in the "Results" section (Page 7). It replicates the finding of the three main product clusters (Machinery/Metals, Chemicals/Food, Textiles) and validates the structural properties shown in Figure 3A.
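
The intra- versus inter-sector comparison in step 3 reduces to comparing the share of possible directed links that are present inside communities with the share present between them. The minimal NumPy sketch below assumes a dense adjacency matrix and one integer community label per product. The function name is hypothetical.

```python
from typing import Tuple

import numpy as np


def intra_inter_density(adj: np.ndarray, labels: np.ndarray) -> Tuple[float, float]:
    """Share of possible directed links present within vs. between communities."""
    present = adj > 0
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)  # self-loops are ignored
    intra_pairs = same & not_self
    inter_pairs = ~same
    intra = present[intra_pairs].sum() / max(int(intra_pairs.sum()), 1)
    inter = present[inter_pairs].sum() / max(int(inter_pairs.sum()), 1)
    return float(intra), float(inter)
```

A clear sector structure shows up as an intra-community density well above the inter-community density.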

#### **Callable: `calculate_centralities`**
*   **Inputs:** `graph`, `replication_manifest`.
*   **Process:**
    1.  It calculates the weighted betweenness centrality for each node, inverting the edge weights so that stronger links correspond to shorter "distances".
    2.  It calculates the hub scores for each node using the HITS algorithm.
    3.  It aggregates these scores into a single DataFrame, ranks them, and computes their correlation.
*   **Outputs:** A `pd.DataFrame` (`centrality_df`) containing the centrality scores and ranks for each product.
*   **Data Transformation:** It transforms the graph structure into node-level importance scores.
*   **Role in Research Pipeline:** This function implements the **Node Centrality Analysis** described in the "Results" section (Page 7). It identifies the "choke point" products (high betweenness) and the products embedded in complex supply chains (high hub score), as shown in Figure 4.
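
Both ingredients can be sketched with NumPy alone. For betweenness, each positive weight `w` is mapped to the distance `1 / w`. The reciprocal is one common inversion, and the pipeline's exact transform is an assumption here. Hub scores follow the HITS power iteration `a = A^T h`, `h = A a`, renormalised at each step. The function names are hypothetical. In practice, `networkx.betweenness_centrality` (with a distance attribute as `weight`) and `networkx.hits` compute these quantities on a `DiGraph`.

```python
import numpy as np


def weights_to_distances(adj: np.ndarray) -> np.ndarray:
    """Reciprocal weights as distances; absent links get infinite distance."""
    dist = np.full(adj.shape, np.inf)
    linked = adj > 0
    dist[linked] = 1.0 / adj[linked]
    return dist


def hits_hub_scores(adj: np.ndarray, max_iter: int = 100, tol: float = 1e-10) -> np.ndarray:
    """HITS hub scores by power iteration, normalised to sum to 1 (assumes at least one edge)."""
    hubs = np.full(adj.shape[0], 1.0 / adj.shape[0])
    for _ in range(max_iter):
        authorities = adj.T @ hubs
        authorities /= authorities.sum()
        new_hubs = adj @ authorities
        new_hubs /= new_hubs.sum()
        if np.abs(new_hubs - hubs).sum() < tol:
            return new_hubs
        hubs = new_hubs
    return hubs
```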

#### **Callable: `run_validation_procedures`**
*   **Inputs:** `adj_matrix`, `product_to_idx`, `supply_chains_definitions_dict`, `replication_manifest`, and optional `external_networks`.
*   **Process:**
    1.  If external networks are provided, it aligns them to the same node set and computes edge-level and degree-level correlations.
    2.  For each manually defined supply chain (EVs, semiconductors), it calculates the empirical directed modularity of the corresponding subgraph.
    3.  It runs a large-scale Monte Carlo simulation, generating thousands of random graphs with the same degree sequence (the configuration model).
    4.  It calculates the p-value of the empirical modularity against the null distribution from the random graphs.
*   **Outputs:** A dictionary (`validation_results`) containing the correlation metrics and the modularity test results (empirical modularity and p-value) for each supply chain.
*   **Data Transformation:** This function transforms the network structure into statistical evidence of its validity.
*   **Role in Research Pipeline:** This function implements the rigorous **Network Validation** procedures described throughout the "Results" section (Pages 9-11). It specifically replicates the comparison to external networks (Figure 5) and the statistical significance test for manually mapped supply chains (Figure 6), including the directed modularity formula from the "Methods" section (Page 15):
    $$
    M_G = \frac{1}{m}\sum_{i,j \in G}\left(X_{i,j} - \frac{Out_i \cdot In_j}{m}\right)
    $$
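
The modularity test can be sketched with NumPy. In this sketch, `X` is assumed to be the binary adjacency matrix, `m` its total number of links, and `Out_i`/`In_j` the out- and in-degrees in the full network. With these definitions, the expected term factorises into the product of the subgraph's total out-degree and total in-degree. The null model uses directed stub matching, which is one simple way to sample from the configuration model; it can create self-loops and multi-links. The add-one p-value estimator and all function names are our assumptions, not the pipeline's.

```python
from typing import Sequence, Tuple

import numpy as np


def directed_modularity(x: np.ndarray, nodes: Sequence[int]) -> float:
    """M_G = (1/m) * sum_{i,j in G} (X_ij - Out_i * In_j / m)."""
    m = x.sum()
    out_deg, in_deg = x.sum(axis=1), x.sum(axis=0)
    idx = np.asarray(nodes)
    observed = x[np.ix_(idx, idx)].sum()
    # sum_{i,j in G} Out_i * In_j = (sum of Out over G) * (sum of In over G)
    expected = out_deg[idx].sum() * in_deg[idx].sum() / m
    return float((observed - expected) / m)


def stub_matching_null(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Random multigraph with exactly the same in- and out-degree sequence as x."""
    n = x.shape[0]
    out_stubs = np.repeat(np.arange(n), x.sum(axis=1))
    in_stubs = np.repeat(np.arange(n), x.sum(axis=0))
    rng.shuffle(in_stubs)  # random pairing of out-stubs with in-stubs
    null = np.zeros_like(x)
    np.add.at(null, (out_stubs, in_stubs), 1)
    return null


def modularity_test(
    adj: np.ndarray, nodes: Sequence[int], n_sims: int = 1000, seed: int = 0
) -> Tuple[float, float]:
    """Empirical modularity and a one-sided Monte Carlo p-value."""
    x = (adj > 0).astype(np.int64)
    rng = np.random.default_rng(seed)
    empirical = directed_modularity(x, nodes)
    null = np.array(
        [directed_modularity(stub_matching_null(x, rng), nodes) for _ in range(n_sims)]
    )
    p_value = (1 + np.sum(null >= empirical)) / (1 + n_sims)
    return empirical, float(p_value)
```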

#### **Callable: `engineer_economic_features`**
*   **Inputs:** `comtrade_exports_frame`, `country_data_frame`, `adj_matrix`, `product_to_idx`, `replication_manifest`.
*   **Process:**
    1.  It calculates the Revealed Population-Adjusted Comparative Advantage (Rpop) for each country-product-year.
    2.  It uses Rpop to construct the binary diversification outcome variable (absent in 2016, present in 2021).
    3.  It calculates the upstream and downstream network density metrics for each country-product pair using efficient matrix algebra.
    4.  It assembles and filters the final long-format dataset for the regression analysis.
*   **Outputs:** A `pd.DataFrame` (`econometric_df`) ready for modeling.
*   **Data Transformation:** This is a major feature engineering step, transforming raw trade data and the network structure into the specific dependent and independent variables for the econometric model.
*   **Role in Research Pipeline:** This function implements the **Economic Feature Engineering** required for the final analysis. It calculates the Rpop metric from the "Methods" section (Page 15) and the crucial network density metrics from the "Results" section (Page 11):
    $$
    d_{p,c} = \frac{\sum_{j \in J_p} I(A_{p,j}) \cdot M_{j,c}}{\sum_{j \in J_p} I(A_{p,j})}
    $$
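
The feature-engineering steps can be sketched as follows. For Rpop, this sketch uses a country's per-capita exports of a product relative to the world's per-capita exports of it. That is a common definition, and we assume it matches the paper's. The strict `<` comparison against `rpop_absence_threshold` and the inclusive `>=` comparison against `rpop_presence_threshold` are also assumptions. For the density, `I(A_{p,j})` is taken to select the `density_metric_top_k_edges` heaviest positive links in row `p`, and `M_{j,c}` is taken to be a 0/1 presence indicator. The upstream variant would presumably apply the same function to the transposed matrix. All function names are hypothetical.

```python
import numpy as np
import pandas as pd


def rpop(exports: pd.DataFrame, population: pd.Series) -> pd.DataFrame:
    """exports: country x product values for one year; population: indexed by country."""
    per_capita = exports.div(population, axis=0)
    world_per_capita = exports.sum(axis=0) / population.sum()
    return per_capita.div(world_per_capita, axis=1)


def diversified(
    r_start: pd.DataFrame, r_end: pd.DataFrame, absent: float = 0.05, present: float = 0.1
) -> pd.DataFrame:
    """1 if the product was absent in the start year and present in the end year."""
    return ((r_start < absent) & (r_end >= present)).astype(int)


def network_density(adj: np.ndarray, presence: np.ndarray, top_k: int) -> np.ndarray:
    """d[p, c] = sum_j I(A_pj) M_jc / sum_j I(A_pj); NaN when p has no links."""
    indicator = np.zeros(adj.shape)
    for p in range(adj.shape[0]):
        linked = np.flatnonzero(adj[p] > 0)
        # Keep the top_k heaviest positive links of product p.
        heaviest = linked[np.argsort(adj[p, linked])[::-1][:top_k]]
        indicator[p, heaviest] = 1.0
    n_links = indicator.sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(n_links > 0, (indicator @ presence) / n_links, np.nan)
```

Building `I(A_{p,j})` once as a matrix turns the density for every product-country pair into a single matrix product, which is the kind of matrix algebra step 3 refers to.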

#### **Callable: `run_econometric_analysis`**
*   **Inputs:** `econometric_df`, `replication_manifest`.
*   **Process:**
    1.  It specifies and estimates the Probit models for upstream and downstream density, including country and product fixed effects.
    2.  It evaluates the models' predictive performance using AUC and ROC curves.
    3.  It extracts key results, including coefficients, p-values, and the Average Marginal Effects (AMEs) used for interpretation.
*   **Outputs:** A dictionary (`econometric_results`) containing the full results, performance metrics, and summary statistics for both models.
*   **Data Transformation:** This function transforms the engineered dataset into a statistical model and a set of interpretable economic results.
*   **Role in Research Pipeline:** This function implements the final **Econometric Analysis** described in the "Results" section (Page 12). It estimates the Probit model:
    $$
    \Pr(R_{p,c} = 1) = \Phi(\alpha + \beta_d d_{p,c} + \gamma_p + \eta_c)
    $$
    and produces the results shown in the tables in Figure 7.
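
For a continuous regressor in a Probit model, the average marginal effect is `beta_d * mean(phi(x'beta))`, where `phi` is the standard normal density. The AUC is the probability that a randomly chosen positive case scores above a randomly chosen negative one, with ties counted as one half. The NumPy sketch below computes both. It assumes a design matrix `X` that already contains the constant and the fixed-effect dummies. The function names are hypothetical, and the pairwise AUC uses memory proportional to `n_pos * n_neg`, so it suits only small samples.

```python
import numpy as np


def probit_ame(X: np.ndarray, beta: np.ndarray, col: int) -> float:
    """Average marginal effect of continuous column `col`: beta_col * mean(pdf(X @ beta))."""
    index = X @ beta
    density = np.exp(-0.5 * index**2) / np.sqrt(2.0 * np.pi)
    return float(beta[col] * density.mean())


def roc_auc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Probability that a positive outscores a negative; ties count one half."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return float(wins / (len(pos) * len(neg)))
```

Because `Phi` is monotone, ranking observations by the linear index `X @ beta` gives the same AUC as ranking them by predicted probability.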

### **Module: Orchestration and Robustness**

#### **Callable: `run_end_to_end_pipeline`**
*   **Inputs:** All raw data and configuration files.
*   **Process:** Sequentially executes the nine task-level orchestrators listed above.
*   **Outputs:** A comprehensive dictionary containing all major artifacts from the pipeline.
*   **Data Transformation:** Orchestrates the entire transformation from raw data to final conclusion.
*   **Role in Research Pipeline:** Serves as the **Master Orchestrator** for a single, complete replication of the study.

#### **Callable: `run_full_robustness_analysis`**
*   **Inputs:** All raw data and configuration files, plus configurations for the robustness checks (`parameter_grid`, `methods_to_test`).
*   **Process:** Orchestrates the execution of the three distinct robustness checks: parameter sensitivity, temporal window analysis, and alternative construction methods. It calls the appropriate specialized orchestrators for each check.
*   **Outputs:** A dictionary containing the summary DataFrames from each completed robustness check.
*   **Data Transformation:** Orchestrates multiple runs of the full pipeline under varying assumptions.
*   **Role in Research Pipeline:** Serves as the **Master Robustness Orchestrator**, executing the analyses described in Task 11 to test the stability and reliability of the baseline findings.

#### **Callable: `main_analysis_orchestrator`**
*   **Inputs:** All raw data and configuration files.
*   **Process:** Serves as the ultimate top-level entry point. It first calls `run_end_to_end_pipeline` to get the baseline results. Then, if enabled, it calls `run_full_robustness_analysis` to get the robustness results.
*   **Outputs:** A final master dictionary containing both the baseline and robustness results.
*   **Data Transformation:** Orchestrates the entire project.
*   **Role in Research Pipeline:** Acts as the **Main Entry Point** for the entire project, providing a single function to replicate the paper and perform a comprehensive stability analysis.

<br><br>
### Usage Example

This example demonstrates how to load the necessary data, define the configurations, and execute the entire research workflow—including the baseline replication and the full suite of robustness checks—using the master orchestrator function, `main_analysis_orchestrator`.

#### **1. Preamble: Loading Data and Defining Configurations**

Before calling the main function, a user must first load all required data into memory as pandas DataFrames and define the configurations for the analysis.

**Step 1.1: Load Input DataFrames**

In a real-world scenario, this data would be loaded from databases, CSV files, or other data sources. For this example, we will create placeholder DataFrames that conform *exactly* to the required schemas.

```python
# --- Assume this block is in a script named `run_experiment.py` ---

import numpy as np
import pandas as pd

# NOTE: In a real application, you would load your actual data here.
# For example:
# transactions_log_frame = pd.read_csv("path/to/transactions.csv", parse_dates=['timestamp'])
# firm_metadata_frame = pd.read_csv("path/to/firm_metadata.csv")
# ... and so on.

# For this example, we create correctly-structured placeholder DataFrames.
# This demonstrates the exact format the pipeline expects.

# i. Transaction Log
transactions_log_frame = pd.DataFrame({
    'seller_firm_id': pd.Series([101, 102, 103], dtype=pd.Int64Dtype()),
    'buyer_firm_id': pd.Series([201, 202, 201], dtype=pd.Int64Dtype()),
    'product_hs_code': pd.Series([8507, 2836, 8541], dtype=pd.Int16Dtype()),
    'transaction_value_usd': pd.Series([10000.0, 5000.0, 25000.0], dtype=np.float64),
    'timestamp': pd.to_datetime(['2022-01-15', '2022-03-10', '2023-06-01'])
})

# ii. Firm Metadata
firm_metadata_frame = pd.DataFrame({
    'firm_id': pd.Series([101, 102, 103, 201, 202], dtype=pd.Int64Dtype()),
    'legal_name': pd.Series(['SellerA', 'SellerB', 'SellerC', 'BuyerX', 'BuyerY'], dtype=pd.StringDtype()),
    'country_domicile': pd.Series(['DEU', 'CHN', 'USA', 'USA', 'MEX'], dtype='category'),
    'ultimate_owner_id': pd.Series([10, 10, 11, 20, 21], dtype=pd.Int64Dtype())
})

# iii. Comtrade Exports Data
comtrade_exports_frame = pd.DataFrame({
    'year': pd.Series([2016, 2021, 2016, 2021], dtype=np.int16),
    'reporter_iso': pd.Series(['DEU', 'DEU', 'CHN', 'CHN'], dtype='category'),
    'product_hs_code': pd.Series([8507, 8507, 2836, 2836], dtype=pd.Int16Dtype()),
    'export_value_usd': pd.Series([1e9, 1.5e9, 2.2e9, 3e9], dtype=np.float64)
})

# iv. Country Data
country_data_frame = pd.DataFrame({
    'year': pd.Series([2016, 2021, 2016, 2021], dtype=np.int16),
    'reporter_iso': pd.Series(['DEU', 'DEU', 'CHN', 'CHN'], dtype='category'),
    'population': pd.Series([82e6, 83e6, 1.37e9, 1.41e9], dtype=np.int64)
})

# For a complete run, these placeholder frames would need to be populated
# with sufficient data to avoid errors in the econometric stage (e.g.,
# not enough observations to fit the model).
```

**Step 1.2: Define Configuration Dictionaries**

Next, we define the non-data inputs: the manifest, the supply chain definitions, and the configurations for the robustness checks.

```python
# v. Supply Chain Definitions
supply_chains_definitions_dict = {
    "EV Battery": [8507, 2836, 2530, 7110, 2825, 2827],
    "Semiconductor": [8541, 8542, 3818, 2804, 8486]
}

# vi. Base Replication Manifest
# This is the primary configuration for the baseline study replication.
replication_manifest = {
    'parameters': {
        'data_ingestion': {
            'start_date': '2021-01-01',
            'end_date': '2023-12-31',
            'hs_revision_standard': 2022,
        },
        'network_inference': {
            'multi_product_firm_hs2_threshold': 5,
            'min_aggregated_link_value_usd': 1000.0,
            'primary_edge_weight_threshold': 2.0,
            'primary_firmcount_threshold': 2,
        },
        'network_analysis': {
            'community_detection_iterations': 100,
            'modularity_validation_simulations': 1000,  # Reduced for example speed
        },
        'econometric_analysis': {
            'start_year_for_diversification': 2016,
            'end_year_for_diversification': 2021,
            'rpop_absence_threshold': 0.05,
            'rpop_presence_threshold': 0.1,
            'density_metric_top_k_edges': 50,
            'min_global_trade_for_product_inclusion_usd': 2e9,
        }
    }
}

# --- Define Configurations for Robustness Checks ---

# Configuration for Task 11.1: Parameter Sensitivity Analysis
# We will test a grid of firmcount and edge weight thresholds.
# Keys are path tuples into the manifest's nested 'parameters' dictionary.
parameter_grid_for_robustness = {
    ('network_inference', 'primary_firmcount_threshold'): [2, 3, 4],
    ('network_inference', 'primary_edge_weight_threshold'): [1.5, 2.0, 2.5],
}

# Configuration for Task 11.3: Construction Robustness Analysis
# We will test the baseline 'mean' against 'median' and the 75th percentile.
methods_to_test_for_robustness = ['mean', 'median', 0.75]

# Optional: External networks would be loaded here into the required format.
# For example:
# ai_network_matrix = scipy.sparse.load_npz("path/to/ai_network.npz")
# ai_product_list = pd.read_csv("path/to/ai_products.csv")['hs_code'].tolist()
# external_networks_for_validation = {
#     "AI_Network": {
#         "matrix": ai_network_matrix,
#         "product_list": ai_product_list
#     }
# }
external_networks_for_validation = None
```
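
The path-tuple convention can be made concrete with a sketch that expands such a grid into one manifest per parameter combination. `expand_parameter_grid` is a hypothetical helper, not part of the pipeline's API. It assumes each tuple is a path below `replication_manifest['parameters']`.

```python
import copy
import itertools
from typing import Any, Dict, List, Sequence, Tuple


def expand_parameter_grid(
    base_manifest: Dict[str, Any],
    grid: Dict[Tuple[str, ...], Sequence[Any]],
) -> List[Dict[str, Any]]:
    """Return one deep-copied manifest per combination in the Cartesian product of `grid`."""
    paths = list(grid)
    manifests = []
    for combo in itertools.product(*(grid[path] for path in paths)):
        manifest = copy.deepcopy(base_manifest)  # never mutate the baseline
        for path, value in zip(paths, combo):
            target = manifest["parameters"]
            for key in path[:-1]:
                target = target[key]
            target[path[-1]] = value
        manifests.append(manifest)
    return manifests
```

With two axes of three values each, as in `parameter_grid_for_robustness`, this yields nine manifests.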

#### **2. Execution: Calling the Top-Level Orchestrator**

With all inputs prepared, the entire analysis can be launched with a single function call.

```python
# --- Assume all pipeline functions are imported ---
# from my_pipeline.main import main_analysis_orchestrator

# --- Execute the Full Analysis ---

# This single function call will:
# 1. Run the complete baseline replication of the study.
# 2. Run the full suite of robustness checks as configured.
final_results = main_analysis_orchestrator(
    # Pass all the loaded data
    transactions_log_frame=transactions_log_frame,
    firm_metadata_frame=firm_metadata_frame,
    comtrade_exports_frame=comtrade_exports_frame,
    country_data_frame=country_data_frame,
    supply_chains_definitions_dict=supply_chains_definitions_dict,
    # Pass the configurations
    base_manifest=replication_manifest,
    run_robustness_checks=True,  # Explicitly enable the robustness suite
    parameter_grid=parameter_grid_for_robustness,
    methods_to_test=methods_to_test_for_robustness,
    external_networks=external_networks_for_validation,
    n_jobs=-1,  # Use all available CPU cores for parallel tasks
)

# --- 3. Inspecting the Results ---

# The `final_results` dictionary now contains all outputs in a structured format.

# Example: Access the baseline econometric results
baseline_econometrics = final_results['baseline_results']['econometric_results']
print("\n--- Baseline Downstream Model Summary ---")
print(baseline_econometrics['downstream']['statsmodels_results'].summary())

# Example: Access the temporal robustness comparison table
temporal_robustness_df = final_results['robustness_results']['temporal_robustness']
print("\n--- Temporal Robustness Results ---")
print(temporal_robustness_df[['downstream_auc', 'upstream_auc', 'sample_size']])

# Example: Access the parameter sensitivity results
sensitivity_df = final_results['robustness_results']['parameter_sensitivity']
print("\n--- Parameter Sensitivity Results ---")
print(sensitivity_df.head())
```

This example provides a complete template for using the pipeline. It separates the data loading and configuration steps from the execution step, uses placeholder data that matches the required schemas (a real run needs enough data for the econometric stage), and demonstrates how to access the structured results for further analysis.


# Task 1: Data Validation and Quality Assurance

import logging
import warnings
from typing import Any, Dict

import numpy as np
import pandas as pd

# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# Custom Exception Class for Enhanced Error Reporting
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

class DataValidationError(Exception):
    """
    A custom exception class for handling specific data validation failures.

    Purpose:
    --------
    This exception is designed to be raised when an input DataFrame, parameter,
    or data record fails to meet the predefined structural, logical, or quality
    constraints required by the analytical pipeline. Using a custom exception
    type allows for specific `try...except` blocks that can catch validation
    errors distinctly from other runtime errors (e.g., `ValueError`, `TypeError`),
    enabling more granular error handling and clearer, more informative logging.

    Inheritance:
    ------------
    Inherits from the base `Exception` class, so it can be caught either
    specifically (`except DataValidationError`) or generically
    (`except Exception`).

    Usage:
    ------
    This class should be instantiated and raised within validation functions
    when a specific, non-recoverable data integrity issue is detected.

    Example:
    --------
    >>> if column_set != expected_column_set:
    ...     raise DataValidationError("DataFrame columns do not match schema.")
    """
    # The 'pass' statement indicates that this class inherits all the
    # behavior from its parent class, `Exception`, without adding any new
    # methods or attributes. Its primary purpose is to create a new, distinct
    # exception type for semantic clarity and targeted error handling.
    pass

# =============================================================================
# Task 1.1: Input Parameter Validation
# =============================================================================

def _validate_manifest_structure(
    manifest: Dict[str, Any]
) -> None:
    """
    Validates the hierarchical structure of the replication_manifest dictionary.

    This function ensures that all required primary and nested keys exist,
    preventing downstream errors from misconfigured or incomplete parameter
    dictionaries. It employs a recursive helper function to traverse the
    expected schema.

    Args:
        manifest (Dict[str, Any]): The replication_manifest dictionary.

    Raises:
        DataValidationError: If a required key is missing or a value is not a
                             dictionary where one is expected.
    """
    # Define the required hierarchical schema for the manifest.
    required_schema = {
        'parameters': {
            'data_ingestion': None,
            'network_inference': None,
            'network_analysis': None,
            'econometric_analysis': None,
        }
    }

    # Define a recursive helper function to traverse and validate the schema.
    def _traverse_and_check(
        sub_manifest: Dict[str, Any],
        sub_schema: Dict[str, Any],
        path: str
    ) -> None:
        # Iterate through all keys required by the current schema level.
        for key, nested_schema in sub_schema.items():
            # Check if the key is present in the manifest at the current path.
            if key not in sub_manifest:
                # Raise a specific error if a required key is missing.
                raise DataValidationError(
                    f"Missing required key '{key}' in manifest at path: '{path}'."
                )

            # If the schema expects a nested dictionary, recurse.
            if nested_schema is not None:
                # Get the nested dictionary from the manifest.
                nested_manifest = sub_manifest[key]
                # Ensure the value is indeed a dictionary before traversing.
                if not isinstance(nested_manifest, dict):
                    raise DataValidationError(
                        f"Expected a dictionary for key '{key}' at path '{path}', "
                        f"but found type {type(nested_manifest).__name__}."
                    )
                # Recursively call the traversal function for the nested structure.
                _traverse_and_check(
                    nested_manifest,
                    nested_schema,
                    f"{path}.{key}"
                )

    # Start the validation from the top level of the manifest.
    _traverse_and_check(manifest, required_schema, "manifest")


def _validate_manifest_date_parameters(
    params: Dict[str, Any]
) -> None:
    """
    Validates date parameters within the manifest for format and logical order.

    Args:
        params (Dict[str, Any]): The 'parameters' sub-dictionary from the manifest.

    Raises:
        DataValidationError: If date strings are malformed or if start_date is
                             not before end_date.
    """
    try:
        # Extract the date strings from the data_ingestion parameters.
        start_date_str = params['data_ingestion']['start_date']
        end_date_str = params['data_ingestion']['end_date']

        # Attempt to parse the date strings into pandas Timestamp objects.
        # errors='raise' will throw an exception for any parsing failure.
        start_date = pd.to_datetime(start_date_str, errors='raise')
        end_date = pd.to_datetime(end_date_str, errors='raise')

    except KeyError as e:
        # Catch missing keys and raise a more informative error.
        raise DataValidationError(
            f"Missing required date parameter in 'data_ingestion': {e}"
        ) from e
    except Exception as e:
        # Catch parsing errors and raise a specific validation error.
        raise DataValidationError(
            f"Failed to parse date parameters. Ensure 'start_date' and "
            f"'end_date' are valid date strings. Original error: {e}"
        ) from e

    # Check the logical ordering of the dates.
    if start_date >= end_date:
        # Raise an error if the start date is not strictly before the end date.
        raise DataValidationError(
            f"Logical error: 'start_date' ({start_date_str}) must be before "
            f"'end_date' ({end_date_str})."
        )


def _validate_manifest_numerical_thresholds(
    params: Dict[str, Any]
) -> None:
    """
    Validates numerical thresholds in the manifest against defined constraints.

    Args:
        params (Dict[str, Any]): The 'parameters' sub-dictionary from the manifest.

    Raises:
        DataValidationError: If any numerical parameter is not of a valid type
                             or violates its positivity constraint.
    """
    # Define a list of validation rules for numerical parameters.
    # Each tuple contains: (path_tuple, validation_lambda, error_message).
    validation_rules = [
        # Integer-valued parameters must be whole numbers (e.g. 2 or 2.0) and ≥ 1.
        (
            ('network_inference', 'multi_product_firm_hs2_threshold'),
            lambda x: float(x).is_integer() and x >= 1,
            "must be an integer ≥ 1"
        ),
        (
            ('network_inference', 'min_aggregated_link_value_usd'),
            lambda x: x > 0,
            "must be a number > 0"
        ),
        (
            ('network_inference', 'primary_edge_weight_threshold'),
            lambda x: x > 0,
            "must be a number > 0"
        ),
        (
            ('network_inference', 'primary_firmcount_threshold'),
            lambda x: float(x).is_integer() and x >= 1,
            "must be an integer ≥ 1"
        ),
    ]

    # Iterate through each defined validation rule.
    for path, rule, msg in validation_rules:
        try:
            # Safely navigate the nested dictionary to get the parameter value.
            value = params[path[0]][path[1]]

            # First, check that the value is a real number (int or float).
            # bool is a subclass of int, so it is excluded explicitly.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                # Raise a type error if it's not a number.
                raise DataValidationError(
                    f"Parameter '{'.'.join(path)}' must be a number, but found "
                    f"type {type(value).__name__}."
                )

            # Apply the specific validation rule for the parameter.
            if not rule(value):
                # Raise a value error if the rule is violated.
                raise DataValidationError(
                    f"Parameter '{'.'.join(path)}' with value {value} is invalid. "
                    f"Constraint: {msg}."
                )
        except KeyError:
            # Catch cases where the parameter path does not exist.
            raise DataValidationError(
                f"Missing required numerical parameter: '{'.'.join(path)}'."
            )

# =============================================================================
# Task 1.2: DataFrame Structure Validation
# =============================================================================

def _validate_dataframe_schemas(
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame
) -> None:
    """
    Validates the schema (columns and dtypes) of all input DataFrames, and the
    ISO 3166-1 alpha-3 format of their country-code columns.

    Args:
        transactions_log_frame (pd.DataFrame): The transactions log.
        firm_metadata_frame (pd.DataFrame): The firm metadata.
        comtrade_exports_frame (pd.DataFrame): The Comtrade export data.
        country_data_frame (pd.DataFrame): The country attribute data.

    Raises:
        DataValidationError: If any DataFrame has incorrect columns or dtypes,
                             or if a country code is not three uppercase letters.
    """
    # Define the expected schema for all input DataFrames.
    expected_schemas = {
        "transactions_log_frame": {
            'seller_firm_id': pd.Int64Dtype(),
            'buyer_firm_id': pd.Int64Dtype(),
            'product_hs_code': pd.Int16Dtype(),
            'transaction_value_usd': np.dtype('float64'),
            'timestamp': np.dtype('datetime64[ns]')
        },
        "firm_metadata_frame": {
            'firm_id': pd.Int64Dtype(),
            'legal_name': pd.StringDtype(),
            'country_domicile': pd.CategoricalDtype(),
            'ultimate_owner_id': pd.Int64Dtype()
        },
        "comtrade_exports_frame": {
            'year': np.dtype('int16'),
            'reporter_iso': pd.CategoricalDtype(),
            'product_hs_code': pd.Int16Dtype(),
            'export_value_usd': np.dtype('float64')
        },
        "country_data_frame": {
            'year': np.dtype('int16'),
            'reporter_iso': pd.CategoricalDtype(),
            'population': np.dtype('int64')
        }
    }

    # Create a dictionary mapping DataFrame names to their objects.
    data_frames = {
        "transactions_log_frame": transactions_log_frame,
        "firm_metadata_frame": firm_metadata_frame,
        "comtrade_exports_frame": comtrade_exports_frame,
        "country_data_frame": country_data_frame
    }

    # Iterate through each DataFrame and its expected schema.
    for name, schema in expected_schemas.items():
        # Get the actual DataFrame object.
        df = data_frames[name]

        # 1. Validate column set integrity.
        # Get the set of expected and actual column names.
        expected_cols = set(schema.keys())
        actual_cols = set(df.columns)

        # Check if the column sets are identical.
        if expected_cols != actual_cols:
            # Identify missing and unexpected columns for a precise error message.
            missing = sorted(list(expected_cols - actual_cols))
            extra = sorted(list(actual_cols - expected_cols))
            raise DataValidationError(
                f"Schema validation failed for '{name}'.\n"
                f"Missing columns: {missing if missing else 'None'}\n"
                f"Unexpected columns: {extra if extra else 'None'}"
            )

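        # 2. Validate dtypes for each column.
        for col, expected_dtype in schema.items():
            # Get the actual dtype of the column.
            actual_dtype = df[col].dtype
            # Categorical columns are checked by dtype family only: whether a
            # bare CategoricalDtype() equals one with concrete categories
            # depends on the pandas version.
            if isinstance(expected_dtype, pd.CategoricalDtype):
                dtype_matches = isinstance(actual_dtype, pd.CategoricalDtype)
            else:
                dtype_matches = actual_dtype == expected_dtype
            # Raise if the actual dtype does not match the expected dtype.
            if not dtype_matches:
                raise DataValidationError(
                    f"Dtype validation failed for '{name}'.\n"
                    f"Column '{col}' has dtype '{actual_dtype}', but "
                    f"expected '{expected_dtype}'."
                )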

    # 3. Validate ISO code format for relevant columns.
    # Check 'country_domicile' in firm_metadata_frame.
    invalid_iso_domicile = firm_metadata_frame[
        ~firm_metadata_frame['country_domicile'].astype(str).str.match(r'^[A-Z]{3}$')
    ]
    if not invalid_iso_domicile.empty:
        raise DataValidationError(
            "Invalid ISO 3166-1 alpha-3 format found in "
            "'firm_metadata_frame.country_domicile'."
        )

    # Check 'reporter_iso' in comtrade_exports_frame.
    invalid_iso_comtrade = comtrade_exports_frame[
        ~comtrade_exports_frame['reporter_iso'].astype(str).str.match(r'^[A-Z]{3}$')
    ]
    if not invalid_iso_comtrade.empty:
        raise DataValidationError(
            "Invalid ISO 3166-1 alpha-3 format found in "
            "'comtrade_exports_frame.reporter_iso'."
        )

# =============================================================================
# Task 1.3: Data Completeness Assessment
# =============================================================================

def _assess_missing_values(
    transactions_log_frame: pd.DataFrame
) -> None:
    """
    Assesses missing value percentages in critical columns and issues warnings.

    Args:
        transactions_log_frame (pd.DataFrame): The transactions log.
    """
    # Define critical columns and the warning threshold.
    critical_columns = [
        'seller_firm_id', 'buyer_firm_id',
        'product_hs_code', 'transaction_value_usd'
    ]
    threshold = 0.01  # 1%

    # Calculate the percentage of missing values for each column.
    missing_pct = transactions_log_frame[critical_columns].isnull().mean()

    # Iterate through the critical columns to check against the threshold.
    for col, pct in missing_pct.items():
        # If the percentage of missing values exceeds the threshold...
        if pct > threshold:
            # ...issue a user warning that does not halt execution.
            warnings.warn(
                f"Column '{col}' has {pct:.2%} missing values, which "
                f"exceeds the {threshold:.0%} threshold.",
                UserWarning
            )


def _assess_ownership_completeness(
    firm_metadata_frame: pd.DataFrame
) -> float:
    """
    Calculates the share of firms with unresolved ownership.

    Args:
        firm_metadata_frame (pd.DataFrame): The firm metadata.

    Returns:
        float: The fraction (between 0 and 1) of firms whose
               'ultimate_owner_id' is null.
    """
    # Calculate the fraction of nulls in the 'ultimate_owner_id' column.
    unresolved_pct = firm_metadata_frame['ultimate_owner_id'].isnull().mean()

    # Log the finding for informational purposes.
    logging.info(
        f"Ownership resolution completeness: {1 - unresolved_pct:.2%} of firms "
        f"have an 'ultimate_owner_id'. {unresolved_pct:.2%} are unresolved."
    )

    # Return the fraction of unresolved firms.
    return unresolved_pct


def _assess_temporal_and_hs_code_coverage(
    transactions_log_frame: pd.DataFrame,
    params: Dict[str, Any]
) -> None:
    """
    Validates temporal coverage and HS code range integrity.

    Args:
        transactions_log_frame (pd.DataFrame): The transactions log.
        params (Dict[str, Any]): The 'parameters' sub-dictionary from the manifest.

    Raises:
        DataValidationError: If temporal coverage is incorrect or HS codes are
                             outside the valid 4-digit range.
    """
    # 1. Validate temporal coverage.
    # Get expected start and end dates from parameters.
    expected_start = pd.to_datetime(params['data_ingestion']['start_date'])
    expected_end = pd.to_datetime(params['data_ingestion']['end_date'])
    # The end date covers the whole calendar day, so compare against the first
    # instant of the following day. A plain `<= expected_end` would reject any
    # timestamp later than midnight on the final day.
    expected_end_exclusive = expected_end.normalize() + pd.Timedelta(days=1)

    # Get actual min and max timestamps from the data.
    actual_min_date = transactions_log_frame['timestamp'].min()
    actual_max_date = transactions_log_frame['timestamp'].max()

    # Check that the actual date range lies within the expected range.
    if not (
        actual_min_date >= expected_start
        and actual_max_date < expected_end_exclusive
    ):
        raise DataValidationError(
            f"Temporal coverage validation failed. Data spans from "
            f"{actual_min_date.date()} to {actual_max_date.date()}, which is "
            f"outside the expected range of {expected_start.date()} to "
            f"{expected_end.date()}."
        )

    # 2. Validate HS code range.
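    # Create a boolean mask for valid 4-digit HS headings. The first two digits
    # are the HS chapter (01-99); because codes are stored as integers, the
    # leading zero is dropped, so heading 0101 appears as 101 and the valid
    # integer range is 101-9999.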
    is_valid_hs = (
        (transactions_log_frame['product_hs_code'] >= 101) &
        (transactions_log_frame['product_hs_code'] <= 9999)
    )

    # Count the number of invalid HS codes.
    invalid_hs_count = (~is_valid_hs).sum()

    # If any invalid codes are found, raise an error.
    if invalid_hs_count > 0:
        raise DataValidationError(
            f"Found {invalid_hs_count} records with invalid 4-digit HS codes "
            f"(outside the range 0101-9999)."
        )
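
# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the original pipeline).
# Because 4-digit HS headings are held as integers, heading 0101 is stored as
# 101, which is why the range check above starts at 101. Series.between is
# inclusive on both ends by default and returns False for NaN, so missing
# codes also land on the invalid side. This sketch turns valid integer codes
# back into zero-padded 4-character labels for reporting; invalid or missing
# codes are left as NaN.
# -----------------------------------------------------------------------------
import pandas as pd


def _sketch_hs4_labels(codes: pd.Series) -> pd.Series:
    """Return zero-padded 4-digit HS labels, NaN where the code is invalid."""
    numeric = pd.to_numeric(codes, errors="coerce")
    # Same bounds as the validation above: 101 (i.e. 0101) to 9999 inclusive.
    is_valid = numeric.between(101, 9999)
    labels = numeric[is_valid].astype(int).astype(str).str.zfill(4)
    # Reindex so the output lines up with the input rows.
    return labels.reindex(codes.index)


# Usage: _sketch_hs4_labels(pd.Series([101, 8703, 50, None])) gives
# '0101', '8703', NaN, NaN.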

# =============================================================================
# Task 1: Orchestrator Function
# =============================================================================

def validate_and_assess_inputs(
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    replication_manifest: Dict[str, Any]
) -> None:
    """
    Orchestrates the execution of all data validation and quality assurance tasks.

    This function serves as the single entry point for Task 1, sequentially
    calling all validation and assessment functions. It will raise a
    DataValidationError if any critical check fails, or issue UserWarnings for
    non-critical issues.

    Args:
        transactions_log_frame (pd.DataFrame): Log of firm transactions.
        firm_metadata_frame (pd.DataFrame): Firm ownership and location data.
        comtrade_exports_frame (pd.DataFrame): Country-product export data.
        country_data_frame (pd.DataFrame): Country population data.
        replication_manifest (Dict[str, Any]): Dictionary of study parameters.

    Raises:
        DataValidationError: If any validation step fails.
    """
    # Set up basic logging configuration.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # Log the start of the validation process.
    logging.info("Starting Task 1: Data Validation and Quality Assurance...")

    # --- Step 1.1: Input Parameter Validation ---
    logging.info("Step 1.1: Validating 'replication_manifest' parameters...")
    # Validate the overall dictionary structure.
    _validate_manifest_structure(manifest=replication_manifest)
    # Extract the parameters sub-dictionary for convenience.
    params = replication_manifest['parameters']
    # Validate date parameters for format and logical order.
    _validate_manifest_date_parameters(params=params)
    # Validate numerical thresholds for type and positivity constraints.
    _validate_manifest_numerical_thresholds(params=params)
    logging.info("...Manifest parameters validated successfully.")

    # --- Step 1.2: DataFrame Structure Validation ---
    logging.info("Step 1.2: Validating DataFrame schemas (columns and dtypes)...")
    # Validate the structure of all input DataFrames.
    _validate_dataframe_schemas(
        transactions_log_frame=transactions_log_frame,
        firm_metadata_frame=firm_metadata_frame,
        comtrade_exports_frame=comtrade_exports_frame,
        country_data_frame=country_data_frame
    )
    logging.info("...DataFrame schemas validated successfully.")

    # --- Step 1.3: Data Completeness Assessment ---
    logging.info("Step 1.3: Assessing data completeness and coverage...")
    # Assess missing values in critical transaction columns.
    _assess_missing_values(transactions_log_frame=transactions_log_frame)
    # Assess the completeness of the firm ownership data.
    _assess_ownership_completeness(firm_metadata_frame=firm_metadata_frame)
    # Assess temporal and HS code coverage in the transaction data.
    _assess_temporal_and_hs_code_coverage(
        transactions_log_frame=transactions_log_frame,
        params=params
    )
    logging.info("...Data completeness assessment finished.")

    # Log the successful completion of the entire task.
    logging.info("Task 1 successfully completed. All inputs are validated.")


# Task 2: Data Preprocessing and Cleansing

# =============================================================================
# Task 2.1: Temporal and Value Filtering
# =============================================================================

def _filter_transactions_by_date_and_value(
    transactions_df: pd.DataFrame,
    params: Dict[str, Any]
) -> pd.DataFrame:
    """
    Filters transactions based on a specified date range and positive value.

    This function applies two critical filters to the raw transaction log:
    1. Temporal Filter: Retains only transactions within the study period
       defined in the manifest (e.g., 2021-01-01 to 2023-12-31).
    2. Value Filter: Removes transactions with a non-positive monetary value,
       as these are considered invalid for economic analysis.

    Args:
        transactions_df (pd.DataFrame): The input transactions log.
        params (Dict[str, Any]): The 'parameters' sub-dictionary from the
                                 replication manifest.

    Returns:
        pd.DataFrame: A new DataFrame containing only the filtered transactions.
    """
    # Log the initial number of records for audit purposes.
    initial_record_count = len(transactions_df)
    logging.info(f"Initial transaction records: {initial_record_count:,}")

    # --- Temporal Filtering ---
    # Extract start and end dates from the parameters.
    start_date = pd.to_datetime(params['data_ingestion']['start_date'])
    end_date = pd.to_datetime(params['data_ingestion']['end_date'])
    # Treat the end date as a whole calendar day: keep timestamps strictly
    # before midnight of the following day.
    end_exclusive = end_date.normalize() + pd.Timedelta(days=1)

    # Create a boolean mask for transactions within the valid date range.
    # The range includes the start date and the full end date.
    date_mask = (
        (transactions_df['timestamp'] >= start_date) &
        (transactions_df['timestamp'] < end_exclusive)
    )

    # --- Value Filtering ---
    # Create a boolean mask for transactions with a positive monetary value.
    value_mask = transactions_df['transaction_value_usd'] > 0

    # Combine the masks into a single filter for efficiency.
    combined_mask = date_mask & value_mask

    # Apply the combined mask to the DataFrame.
    # Using .copy() is crucial to prevent SettingWithCopyWarning downstream.
    filtered_df = transactions_df[combined_mask].copy()

    # Log the number of records after filtering.
    final_record_count = len(filtered_df)
    records_removed = initial_record_count - final_record_count
    logging.info(
        f"Removed {records_removed:,} records due to date or invalid value. "
        f"Remaining records: {final_record_count:,}"
    )

    # Return the new, filtered DataFrame.
    return filtered_df
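
# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the original pipeline).
# When the end date is parsed from a 'YYYY-MM-DD' string it becomes midnight,
# so a comparison like `timestamp <= end_date` silently drops any record
# timestamped later on the final day. A half-open interval
# [start, end + 1 day) keeps the whole final day. The helper name and
# signature are assumptions for illustration.
# -----------------------------------------------------------------------------
import pandas as pd


def _sketch_study_period_mask(
    timestamps: pd.Series, start_date: str, end_date: str
) -> pd.Series:
    """Boolean mask for timestamps from start_date through all of end_date."""
    start = pd.Timestamp(start_date)
    # First instant after the end date; compared with a strict '<'.
    end_exclusive = pd.Timestamp(end_date).normalize() + pd.Timedelta(days=1)
    return (timestamps >= start) & (timestamps < end_exclusive)


# Usage: for timestamps 2021-01-01 00:00, 2023-12-31 15:30 and 2024-01-01 00:00
# with the period '2021-01-01' to '2023-12-31', the mask is True, True, False,
# whereas `timestamps <= pd.Timestamp('2023-12-31')` gives True, False, False.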

# =============================================================================
# Task 2.2: Firm Entity Resolution
# =============================================================================

def _resolve_firm_entities(
    transactions_df: pd.DataFrame,
    firm_metadata_df: pd.DataFrame
) -> pd.DataFrame:
    """
    Resolves firm IDs to ultimate owners and creates owner-country entities.

    This function executes a critical transformation by moving from the raw
    `firm_id` to a more analytically relevant "owner-country entity".
    The process involves:
    1. Mapping each transacting firm to its ultimate owner. If an owner is
       not specified, the firm is treated as its own ultimate owner.
    2. Augmenting the transaction log with the country of domicile of the
       transacting seller and buyer firms (not of their ultimate owners), so
       that one corporate group yields a separate entity in each country.
    3. Creating a composite string identifier ('owner_id_country_iso') for
       both seller and buyer, which serves as the primary entity identifier
       for the rest of the analysis.

    Args:
        transactions_df (pd.DataFrame): The transaction log DataFrame.
        firm_metadata_df (pd.DataFrame): The firm metadata DataFrame.

    Returns:
        pd.DataFrame: An augmented DataFrame with resolved owner-country
                      entities for both seller and buyer.

    Raises:
        DataValidationError: If any firm ID in the transaction log cannot be
                             found in the firm metadata.
    """
    # --- Step 2.2.1 & 2.2.2: Ultimate Owner Resolution ---
    # Create a working copy of the metadata to avoid modifying the original.
    metadata = firm_metadata_df.copy()

    # Per the paper's methodology, fill missing ultimate owner IDs with the
    # firm's own ID, treating them as self-owned.
    metadata['ultimate_owner_id'] = metadata['ultimate_owner_id'].fillna(
        metadata['firm_id']
    )

    # Create an efficient mapping from firm_id to ultimate_owner_id.
    # A dictionary is faster for mapping than repeated merges.
    owner_map = metadata.set_index('firm_id')['ultimate_owner_id'].to_dict()

    # Create a working copy of the transactions DataFrame.
    resolved_df = transactions_df.copy()

    # Map seller and buyer firm IDs to their ultimate owner IDs.
    resolved_df['seller_owner_id'] = resolved_df['seller_firm_id'].map(owner_map)
    resolved_df['buyer_owner_id'] = resolved_df['buyer_firm_id'].map(owner_map)

    # Validate that all mappings were successful.
    if resolved_df[['seller_owner_id', 'buyer_owner_id']].isnull().any().any():
        raise DataValidationError(
            "Entity resolution failed: Some firm IDs in the transaction log "
            "could not be mapped to an ultimate owner. Check for missing "
            "firms in the metadata."
        )

    # --- Step 2.2.3: Owner-Country Composite Entity Creation ---
    # Build a minimal firm_id -> country_domicile table. A firm listed with
    # more than one country would make the lookup ambiguous (and, with a merge,
    # would duplicate transaction rows), so reject that case explicitly.
    country_pairs = metadata[['firm_id', 'country_domicile']].drop_duplicates()
    if country_pairs['firm_id'].duplicated().any():
        raise DataValidationError(
            "Country resolution failed: Some firm IDs are assigned to more "
            "than one country of domicile in the firm metadata."
        )
    country_map = country_pairs.set_index('firm_id')['country_domicile']

    # Look up the country of domicile of the transacting seller and buyer
    # firms. Mapping through a uniquely indexed Series cannot change the row
    # count and adds no helper columns that need cleaning up afterwards.
    resolved_df['seller_country'] = resolved_df['seller_firm_id'].map(country_map)
    resolved_df['buyer_country'] = resolved_df['buyer_firm_id'].map(country_map)

    # Validate that all country lookups were successful.
    if resolved_df[['seller_country', 'buyer_country']].isnull().any().any():
        raise DataValidationError(
            "Country resolution failed: Could not find country domicile for "
            "all firms in the transaction log."
        )

    # Create the composite owner-country entity identifiers.
    # Ensure IDs are strings for concatenation.
    resolved_df['seller_owner_country_entity'] = (
        resolved_df['seller_owner_id'].astype(str) + '_' +
        resolved_df['seller_country'].astype(str)
    )
    resolved_df['buyer_owner_country_entity'] = (
        resolved_df['buyer_owner_id'].astype(str) + '_' +
        resolved_df['buyer_country'].astype(str)
    )

    logging.info("Successfully resolved firm IDs to owner-country entities.")
    return resolved_df
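
# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper and toy data, not from the paper).
# It condenses the resolution steps above on a tiny example: firms without an
# ultimate owner are treated as self-owned, and the entity key combines the
# ultimate owner with the transacting firm's own country of domicile. Two
# subsidiaries of the same parent in different countries therefore become
# distinct owner-country entities.
# -----------------------------------------------------------------------------
import pandas as pd


def _sketch_owner_country_keys(
    metadata: pd.DataFrame, firm_ids: pd.Series
) -> pd.Series:
    """Map firm IDs to '<ultimate_owner_id>_<country_domicile>' keys."""
    # Self-ownership fallback for firms with no recorded ultimate owner.
    owners = metadata["ultimate_owner_id"].fillna(metadata["firm_id"])
    owner_by_firm = pd.Series(owners.to_numpy(), index=metadata["firm_id"])
    country_by_firm = pd.Series(
        metadata["country_domicile"].to_numpy(), index=metadata["firm_id"]
    )
    return (
        firm_ids.map(owner_by_firm).astype(str)
        + "_"
        + firm_ids.map(country_by_firm).astype(str)
    )


# Usage: with F1 (owner P1, DEU), F2 (owner P1, USA) and F3 (no owner, DEU),
# the keys are 'P1_DEU', 'P1_USA' and 'F3_DEU'.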

# =============================================================================
# Task 2.3: Cross-Border and Multi-Product Firm Filtering
# =============================================================================

def _filter_by_transaction_type_and_firm_diversity(
    resolved_df: pd.DataFrame,
    params: Dict[str, Any]
) -> pd.DataFrame:
    """
    Filters for cross-border transactions and removes highly diversified firms.

    This function applies the final two cleansing steps:
    1. Cross-Border Filter: Removes all transactions where the seller and
       buyer entities operate in the same country.
    2. Multi-Product Firm Filter: Removes all transactions where the selling
       entity exports products across a large number of distinct HS 2-digit
       chapters, as these are likely wholesalers or conglomerates rather than
       specialized producers.

    Args:
        resolved_df (pd.DataFrame): The DataFrame with resolved owner-country
                                    entities.
        params (Dict[str, Any]): The 'parameters' sub-dictionary from the
                                 replication manifest.

    Returns:
        pd.DataFrame: A new, fully cleansed DataFrame ready for analysis.
    """
    # Log the record count before this filtering stage.
    initial_record_count = len(resolved_df)

    # --- Step 2.3.2: Cross-Border Transaction Filtering ---
    # Create a boolean mask to identify cross-border transactions.
    cross_border_mask = (
        resolved_df['seller_country'] != resolved_df['buyer_country']
    )
    # Apply the mask.
    cross_border_df = resolved_df[cross_border_mask].copy()

    # Log the number of domestic transactions removed.
    domestic_removed = initial_record_count - len(cross_border_df)
    logging.info(
        f"Removed {domestic_removed:,} domestic transactions. "
        f"Remaining records: {len(cross_border_df):,}"
    )

    # --- Step 2.3.3: Multi-Product Firm Filtering ---
    # Get the diversity threshold from the parameters.
    threshold = params['network_inference']['multi_product_firm_hs2_threshold']

    # Calculate the HS 2-digit chapter for each product. Integer division of
    # the 4-digit heading by 100 yields its chapter (e.g., 8703 // 100 == 87).
    cross_border_df['hs2_section'] = cross_border_df['product_hs_code'] // 100

    # Group by the selling entity and count its distinct HS2 chapters.
    seller_diversity = cross_border_df.groupby(
        'seller_owner_country_entity'
    )['hs2_section'].nunique()

    # Identify the entities whose chapter count meets or exceeds the threshold.
    diversified_sellers = seller_diversity[seller_diversity >= threshold].index

    # Create a boolean mask to filter out transactions from these sellers.
    # The .isin() method is highly efficient for this type of filtering.
    multi_product_mask = ~cross_border_df['seller_owner_country_entity'].isin(
        diversified_sellers
    )

    # Apply the mask to get the final cleansed DataFrame.
    cleansed_df = cross_border_df[multi_product_mask].copy()

    # Log the number of records removed due to firm diversity.
    multi_product_removed = len(cross_border_df) - len(cleansed_df)
    logging.info(
        f"Removed {len(diversified_sellers):,} highly diversified seller entities, "
        f"resulting in the removal of {multi_product_removed:,} transactions."
    )
    logging.info(f"Final cleansed transaction records: {len(cleansed_df):,}")

    # Return the final DataFrame, dropping the temporary hs2_section column.
    return cleansed_df.drop(columns=['hs2_section'])
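
# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper and toy threshold, not the paper's
# setting). It isolates the diversity rule above: derive each product's HS
# 2-digit chapter with // 100, count distinct chapters per seller entity, and
# flag sellers at or above the threshold (the same >= comparison as the
# pipeline).
# -----------------------------------------------------------------------------
import pandas as pd


def _sketch_diversified_sellers(tx: pd.DataFrame, threshold: int) -> set:
    """Sellers whose products span at least `threshold` HS2 chapters."""
    hs2_chapter = tx["product_hs_code"] // 100
    chapters_per_seller = hs2_chapter.groupby(
        tx["seller_owner_country_entity"]
    ).nunique()
    return set(chapters_per_seller[chapters_per_seller >= threshold].index)


# Usage: seller 'A' selling headings 0101, 8703, 8471 and 3004 spans four
# chapters (01, 87, 84, 30); seller 'B' selling 8703 and 8708 spans one (87).
# With a toy threshold of 3, only 'A' is flagged.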

# =============================================================================
# Task 2: Orchestrator Function
# =============================================================================

def preprocess_and_cleanse_data(
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    replication_manifest: Dict[str, Any]
) -> pd.DataFrame:
    """
    Orchestrates the end-to-end data preprocessing and cleansing pipeline.

    This function executes the full sequence of cleansing and transformation
    steps required to convert the raw, validated input data into an
    analysis-ready format. It chains together the filtering, entity
    resolution, and business rule application functions.

    Args:
        transactions_log_frame (pd.DataFrame): The validated raw transaction log.
        firm_metadata_frame (pd.DataFrame): The validated firm metadata.
        replication_manifest (Dict[str, Any]): The validated study parameters.

    Returns:
        pd.DataFrame: The fully preprocessed and cleansed transaction log,
                      ready for network inference.
    """
    # Log the start of the preprocessing task.
    logging.info("Starting Task 2: Data Preprocessing and Cleansing...")

    # Extract the parameters sub-dictionary for use in helper functions.
    params = replication_manifest['parameters']

    # Step 2.1: Apply temporal and transaction value filters.
    logging.info("Step 2.1: Applying temporal and value filters...")
    filtered_df = _filter_transactions_by_date_and_value(
        transactions_df=transactions_log_frame,
        params=params
    )

    # Step 2.2: Resolve firm IDs to ultimate owners and create composite entities.
    logging.info("Step 2.2: Resolving firm entities...")
    resolved_df = _resolve_firm_entities(
        transactions_df=filtered_df,
        firm_metadata_df=firm_metadata_frame
    )

    # Step 2.3: Apply cross-border and multi-product firm filters.
    logging.info("Step 2.3: Applying cross-border and firm diversity filters...")
    cleansed_df = _filter_by_transaction_type_and_firm_diversity(
        resolved_df=resolved_df,
        params=params
    )

    # Log the successful completion of the task.
    logging.info("Task 2 successfully completed. Data is cleansed and ready for analysis.")

    # Return the final, analysis-ready DataFrame.
    return cleansed_df


# Task 3: Firm Classification and Set Construction

# =============================================================================
# Task 3.1 & 3.2: Producer and Purchaser Set Identification
# =============================================================================

def _identify_significant_entities(
    cleansed_df: pd.DataFrame,
    entity_type: str
) -> Dict[int, Set[str]]:
    """
    Identifies significant producers or purchasers for each product.

    This function implements the core classification methodology from the paper.
    An entity is classified as a "significant" producer (or purchaser) of a
    product if its total sales (or purchase) value for that product exceeds
    the average total sales (or purchase) value across all entities active
    in that product market.

    Args:
        cleansed_df (pd.DataFrame): The preprocessed and cleansed transaction log.
        entity_type (str): The type of entity to identify. Must be either
                           'producer' or 'purchaser'.

    Returns:
        Dict[int, Set[str]]: A dictionary mapping each product HS code (int) to
                             a set of significant entity identifiers (str).

    Raises:
        ValueError: If entity_type is not 'producer' or 'purchaser'.
    """
    # Validate the entity_type parameter to ensure correct column selection.
    if entity_type == 'producer':
        # For producers, the relevant entity is the seller.
        entity_col = 'seller_owner_country_entity'
    elif entity_type == 'purchaser':
        # For purchasers, the relevant entity is the buyer.
        entity_col = 'buyer_owner_country_entity'
    else:
        # Raise an error for invalid entity types.
        raise ValueError("entity_type must be either 'producer' or 'purchaser'.")

    logging.info(f"Identifying significant {entity_type}s...")

    # --- Step 1: Calculate total transaction value per (entity, product) pair.
    # Group by the product and the relevant entity column, then sum the values.
    entity_product_values = cleansed_df.groupby(
        ['product_hs_code', entity_col]
    )['transaction_value_usd'].sum().reset_index()
    entity_product_values.rename(
        columns={'transaction_value_usd': 'total_value'}, inplace=True
    )

    # --- Step 2: Calculate the average total value threshold for each product.
    # We use .transform('mean') which is highly efficient. It calculates the
    # mean for each product group and broadcasts this value back to the original
    # shape of the grouped data, avoiding a costly merge operation.
    entity_product_values['avg_total_value'] = entity_product_values.groupby(
        'product_hs_code'
    )['total_value'].transform('mean')

    # --- Step 3: Filter for entities that exceed the average threshold.
    # This boolean mask identifies the significant entities for each product.
    significant_mask = (
        entity_product_values['total_value'] >
        entity_product_values['avg_total_value']
    )
    significant_entities_df = entity_product_values[significant_mask]

    # --- Step 4: Aggregate the entities into sets for each product.
    # Group the filtered DataFrame by product and apply the `set` constructor
    # to the entity column to collect all significant entities for that product.
    entity_sets_series = significant_entities_df.groupby(
        'product_hs_code'
    )[entity_col].apply(set)

    # Convert the resulting Series to a dictionary for fast lookups.
    entity_sets_dict = entity_sets_series.to_dict()

    logging.info(
        f"Identified significant {entity_type}s for "
        f"{len(entity_sets_dict):,} products."
    )

    # Return the final dictionary of sets.
    return entity_sets_dict
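
# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper and toy data). It restates the
# "above the product-level mean" rule for sellers on a small example. Because
# the comparison is strict, a product whose sellers all have equal totals has
# no significant producer and is absent from the result.
# -----------------------------------------------------------------------------
import pandas as pd


def _sketch_significant_sellers(tx: pd.DataFrame) -> dict:
    """Map each product to the sellers whose total value beats the mean."""
    totals = tx.groupby(["product_hs_code", "seller_owner_country_entity"])[
        "transaction_value_usd"
    ].sum()
    # Mean total per product, broadcast back to every (product, seller) row.
    product_mean = totals.groupby(level="product_hs_code").transform("mean")
    significant = totals[totals > product_mean].reset_index()
    return (
        significant.groupby("product_hs_code")["seller_owner_country_entity"]
        .apply(set)
        .to_dict()
    )


# Usage: for product 8703, sellers A (60 + 40 = 100), B (10) and C (10) have a
# mean total of 40, so only A is significant. For product 8471, D and E both
# total 50, equal to the mean, so 8471 has no significant seller.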

# =============================================================================
# Task 3.3: Intersection Set Construction and Value Aggregation
# =============================================================================

def _compute_intersection_metrics(
    cleansed_df: pd.DataFrame,
    producer_sets: Dict[int, Set[str]],
    purchaser_sets: Dict[int, Set[str]]
) -> pd.DataFrame:
    """
    Computes intersection set sizes and their total transaction values.

    This function calculates the two key metrics for each potential
    input-output product link (i, j) required for network inference:
    1.  |S_i^j|: The number of firms that are both significant producers of
        product j AND significant purchasers of product i.
    2.  Value(S_i^j): The total monetary value of product i purchased by the
        firms in the intersection set S_i^j.

    It uses an efficient, vectorized approach by creating tables of all
    products produced and purchased by each firm and then merging them.

    Args:
        cleansed_df (pd.DataFrame): The preprocessed transaction log.
        producer_sets (Dict[int, Set[str]]): A dictionary mapping product codes
                                             to sets of producer entities.
        purchaser_sets (Dict[int, Set[str]]): A dictionary mapping product codes
                                              to sets of purchaser entities.

    Returns:
        pd.DataFrame: A DataFrame with a MultiIndex (input_product_i,
                      output_product_j) and columns 'intersection_size' and
                      'intersection_value'.
    """
    logging.info("Computing intersection metrics for all product pairs...")

    # --- Step 1: Create long-format DataFrames of (entity, product) pairs ---
    # This "un-nests" the dictionaries into a format suitable for merging.
    producers_long = pd.DataFrame(
        [
            (entity, prod)
            for prod, entities in producer_sets.items()
            for entity in entities
        ],
        columns=['entity', 'output_product_j']
    )

    purchasers_long = pd.DataFrame(
        [
            (entity, prod)
            for prod, entities in purchaser_sets.items()
            for entity in entities
        ],
        columns=['entity', 'input_product_i']
    )

    # --- Step 2: Merge to find entities in both producer and purchaser sets ---
    # This inner merge on 'entity' is the vectorized equivalent of finding all
    # intersection sets S_i^j. The result contains every (entity, i, j) triplet.
    intersections_df = pd.merge(
        purchasers_long, producers_long, on='entity'
    )

    # --- Step 3: Calculate intersection sizes ---
    # Group by the product pair (i, j) and count the number of unique entities.
    # This directly calculates |S_i^j| for all pairs.
    intersection_sizes = intersections_df.groupby(
        ['input_product_i', 'output_product_j']
    )['entity'].nunique().to_frame(name='intersection_size')

    # --- Step 4: Calculate intersection transaction values ---
    # To get the value, we need to link back to the original transactions.
    # We merge the triplets with the transaction log.
    # The merge keys are the buyer entity and the input product.
    value_agg_df = pd.merge(
        cleansed_df,
        intersections_df,
        left_on=['buyer_owner_country_entity', 'product_hs_code'],
        right_on=['entity', 'input_product_i'],
        how='inner'
    )

    # Now, group by the product pair (i, j) and sum the transaction values.
    intersection_values = value_agg_df.groupby(
        ['input_product_i', 'output_product_j']
    )['transaction_value_usd'].sum().to_frame(name='intersection_value')

    # --- Step 5: Combine size and value metrics ---
    # Join the two resulting DataFrames to create the final output.
    # We use an outer join to ensure we don't lose any pairs, though in
    # practice they should have the same index. Fill any potential NaNs with 0.
    final_metrics_df = intersection_sizes.join(
        intersection_values, how='outer'
    ).fillna(0)

    logging.info(
        f"Computed metrics for {len(final_metrics_df):,} unique "
        f"(input, output) product links."
    )

    return final_metrics_df
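
# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the original pipeline).
# A brute-force cross-check of |S_i^j| using Python set intersection over all
# (input i, output j) product pairs. It is O(P^2) in the number of products,
# so it is only practical on small samples, but its output can be compared
# with intersection_metrics['intersection_size'].to_dict(), whose keys are
# (input_product_i, output_product_j) tuples. Like the inner merge above,
# pairs with an empty intersection are omitted.
# -----------------------------------------------------------------------------
from typing import Dict, Set, Tuple


def _sketch_intersection_sizes_bruteforce(
    producer_sets: Dict[int, Set[str]],
    purchaser_sets: Dict[int, Set[str]],
) -> Dict[Tuple[int, int], int]:
    """Return {(input_product_i, output_product_j): |S_i^j|} for non-empty pairs."""
    sizes: Dict[Tuple[int, int], int] = {}
    for input_i, buyers in purchaser_sets.items():
        for output_j, producers in producer_sets.items():
            overlap = len(buyers & producers)
            if overlap:
                sizes[(input_i, output_j)] = overlap
    return sizes


# Usage: with producer_sets {8703: {'A', 'B'}, 8471: {'C'}} and purchaser_sets
# {7208: {'A', 'C'}, 4011: {'B'}}, the non-empty pairs are (7208, 8703),
# (7208, 8471) and (4011, 8703), each with size 1.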

# =============================================================================
# Task 3: Orchestrator Function
# =============================================================================

def classify_firms_and_construct_sets(
    cleansed_df: pd.DataFrame
) -> Tuple[Dict[int, Set[str]], Dict[int, Set[str]], pd.DataFrame]:
    """
    Orchestrates the firm classification and set construction pipeline.

    This function takes the cleansed transaction data and performs the
    necessary aggregations and transformations to produce the core data
    structures needed for the network inference algorithm in Task 4.

    Args:
        cleansed_df (pd.DataFrame): The fully preprocessed and cleansed
                                    transaction log from Task 2.

    Returns:
        Tuple[Dict[int, Set[str]], Dict[int, Set[str]], pd.DataFrame]:
        A tuple containing:
        - producer_sets: A dictionary mapping product codes to sets of
                         significant producer entities.
        - purchaser_sets: A dictionary mapping product codes to sets of
                          significant purchaser entities.
        - intersection_metrics: A DataFrame containing the size and total
                                transaction value for each product-pair
                                intersection set.
    """
    logging.info("Starting Task 3: Firm Classification and Set Construction...")

    # Step 3.1: Identify significant producer entities for each product.
    producer_sets = _identify_significant_entities(
        cleansed_df=cleansed_df,
        entity_type='producer'
    )

    # Step 3.2: Identify significant purchaser entities for each product.
    purchaser_sets = _identify_significant_entities(
        cleansed_df=cleansed_df,
        entity_type='purchaser'
    )

    # Step 3.3: Compute intersection metrics (size and value).
    intersection_metrics = _compute_intersection_metrics(
        cleansed_df=cleansed_df,
        producer_sets=producer_sets,
        purchaser_sets=purchaser_sets
    )

    logging.info("Task 3 successfully completed. All sets constructed.")

    # Return the three essential data structures for the next task.
    return producer_sets, purchaser_sets, intersection_metrics


# Task 4: Network Inference Implementation

# =============================================================================
# Task 4.1: Adjacency Matrix Computation
# =============================================================================

def _compute_adjacency_weights(
    producer_sets: Dict[int, Set[str]],
    purchaser_sets: Dict[int, Set[str]],
    intersection_metrics: pd.DataFrame,
    total_unique_entities: int
) -> pd.DataFrame:
    """
    Computes the raw adjacency matrix weights (A_ij) using the core formula.

    This function implements the central equation of the paper in a fully
    vectorized manner:
    A_ij = ( |S_i^j| / |S_j| ) / ( |S_i^†| / |S| )
    where:
    - |S_i^j|: Number of firms producing j and purchasing i.
    - |S_j|: Total number of firms producing j.
    - |S_i^†|: Total number of firms purchasing i.
    - |S|: Total number of unique firms in the dataset.

    Args:
        producer_sets (Dict[int, Set[str]]): Map from product code to the set
                                             of its producer entities.
        purchaser_sets (Dict[int, Set[str]]): Map from product code to the set
                                              of its purchaser entities.
        intersection_metrics (pd.DataFrame): DataFrame with MultiIndex
                                             (input_product_i, output_product_j)
                                             and a column 'intersection_size'.
        total_unique_entities (int): The total number of unique owner-country
                                     entities in the cleansed dataset (|S|).

    Returns:
        pd.DataFrame: The intersection_metrics DataFrame augmented with a
                      new column 'A_ij' containing the computed weights.
    """
    logging.info("Computing raw adjacency matrix weights (A_ij)...")

    # Create a working copy to avoid modifying the original DataFrame.
    adj_df = intersection_metrics.copy()

    # --- Calculate each term of the formula as a pandas Series ---
    # |S_j|: Cardinality of each producer set.
    s_j_sizes = pd.Series(
        {prod: len(entities) for prod, entities in producer_sets.items()},
        name='s_j_size'
    )

    # |S_i^†|: Cardinality of each purchaser set.
    s_i_dagger_sizes = pd.Series(
        {prod: len(entities) for prod, entities in purchaser_sets.items()},
        name='s_i_dagger_size'
    )

    # --- Align sizes with the intersection DataFrame using its index ---
    # Map the producer set sizes to the 'output_product_j' level of the index.
    adj_df['s_j_size'] = adj_df.index.get_level_values('output_product_j').map(s_j_sizes)

    # Map the purchaser set sizes to the 'input_product_i' level of the index.
    adj_df['s_i_dagger_size'] = adj_df.index.get_level_values('input_product_i').map(s_i_dagger_sizes)

    # --- Compute the numerator and denominator of the formula ---
    # Numerator: (|S_i^j| / |S_j|)
    # This is the proportion of producers of j that also purchase i.
    numerator = adj_df['intersection_size'] / adj_df['s_j_size']

    # Denominator: (|S_i^†| / |S|)
    # This is the baseline proportion of all firms that purchase i.
    denominator = adj_df['s_i_dagger_size'] / total_unique_entities

    # --- Compute A_ij and handle potential division by zero ---
    # The division of the two proportions gives the final weight.
    adj_df['A_ij'] = numerator / denominator

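    # Handle edge cases: a zero denominator, or a product missing from the
    # size maps, yields `inf` or `NaN`. These are not meaningful links, so
    # their weight is set to 0. The cleaned column is assigned back explicitly
    # because chained `fillna(..., inplace=True)` on a column selection is
    # deprecated and has no effect under pandas Copy-on-Write.
    adj_df['A_ij'] = adj_df['A_ij'].replace([np.inf, -np.inf], np.nan).fillna(0.0)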

    logging.info("Successfully computed raw A_ij weights.")

    # Return the DataFrame with the new 'A_ij' column.
    return adj_df.drop(columns=['s_j_size', 's_i_dagger_size'])
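As a worked check of the formula in `_compute_adjacency_weights`, the sketch below evaluates A_ij by hand for a single (i, j) pair using plain Python sets. The firm IDs and set sizes are invented for illustration, and `toy_adjacency_weight` is a hypothetical helper, not part of the pipeline.

```python
def toy_adjacency_weight(producers_j: set, purchasers_i: set, total_entities: int) -> float:
    """Hypothetical helper: A_ij = (|S_i^j| / |S_j|) / (|S_i^†| / |S|).

    Returns 0.0 when either proportion is undefined, mirroring the pipeline's
    convention of zeroing out inf/NaN weights.
    """
    if not producers_j or not purchasers_i or total_entities == 0:
        return 0.0
    s_i_j = len(producers_j & purchasers_i)               # |S_i^j|: produce j AND buy i
    share_of_producers = s_i_j / len(producers_j)         # |S_i^j| / |S_j|
    baseline_share = len(purchasers_i) / total_entities   # |S_i^†| / |S|
    return share_of_producers / baseline_share


# Invented example: 4 producers of j, 3 of which also buy i;
# 10 buyers of i among 100 firms in total.
producers_j = {"f1", "f2", "f3", "f4"}
purchasers_i = {"f1", "f2", "f3"} | {f"g{k}" for k in range(7)}
a_ij = toy_adjacency_weight(producers_j, purchasers_i, total_entities=100)
# a_ij = 0.75 / 0.10 = 7.5 (up to floating-point rounding)
```

A value above 1 means producers of j purchase i more often than the average firm does; this over-representation is what the sparsification thresholds then act on.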

# =============================================================================
# Task 4.2: Network Sparsification
# =============================================================================

def _sparsify_network_edges(
    weighted_edges_df: pd.DataFrame,
    params: Dict[str, Any]
) -> pd.DataFrame:
    """
    Applies multiple filters to sparsify the network, retaining strong links.

    Args:
        weighted_edges_df (pd.DataFrame): DataFrame of all potential edges with
                                          their computed A_ij weights and
                                          intersection metrics.
        params (Dict[str, Any]): The 'parameters' sub-dictionary from the
                                 replication manifest.

    Returns:
        pd.DataFrame: A filtered DataFrame containing only the edges that
                      meet all specified criteria.
    """
    logging.info("Sparsifying network based on thresholds...")
    initial_edge_count = len(weighted_edges_df)

    # --- Step 4.2.1: Apply firmcount threshold ---
    # Filter based on the minimum number of firms supporting a link.
    firmcount_threshold = params['network_inference']['primary_firmcount_threshold']
    firmcount_mask = weighted_edges_df['intersection_size'] >= firmcount_threshold

    # --- Step 4.2.2: Apply edge weight threshold ---
    # Filter based on the minimum A_ij weight.
    edge_weight_threshold = params['network_inference']['primary_edge_weight_threshold']
    weight_mask = weighted_edges_df['A_ij'] >= edge_weight_threshold

    # --- Step 4.2.3: Apply minimum transaction value filter ---
    # Filter based on the minimum total monetary value of the link.
    value_threshold = params['network_inference']['min_aggregated_link_value_usd']
    value_mask = weighted_edges_df['intersection_value'] >= value_threshold

    # Combine all masks into a single filter.
    final_mask = firmcount_mask & weight_mask & value_mask

    # Apply the final mask to get the sparsified set of edges.
    sparsified_df = weighted_edges_df[final_mask].copy()

    final_edge_count = len(sparsified_df)
    edges_removed = initial_edge_count - final_edge_count
    logging.info(
        f"Sparsification complete. Removed {edges_removed:,} edges. "
        f"Final network has {final_edge_count:,} edges."
    )

    return sparsified_df
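The three filters in `_sparsify_network_edges` are combined with a logical AND, so an edge survives only if it clears every threshold. Below is a minimal sketch on four toy edges; the HS codes, values, and threshold numbers are hypothetical and are not the settings in the replication manifest.

```python
import pandas as pd

toy_edges = pd.DataFrame(
    {
        "intersection_size": [12, 3, 25, 40],
        "A_ij": [2.5, 9.0, 0.8, 4.0],
        "intersection_value": [1e6, 5e6, 2e6, 1e4],
    },
    index=pd.MultiIndex.from_tuples(
        [(8471, 8473), (3901, 3923), (5201, 5205), (7208, 8708)],
        names=["input_product_i", "output_product_j"],
    ),
)
# Hypothetical thresholds, laid out like the manifest's 'network_inference' block.
toy_cfg = {
    "primary_firmcount_threshold": 10,
    "primary_edge_weight_threshold": 1.0,
    "min_aggregated_link_value_usd": 1e5,
}
toy_mask = (
    (toy_edges["intersection_size"] >= toy_cfg["primary_firmcount_threshold"])
    & (toy_edges["A_ij"] >= toy_cfg["primary_edge_weight_threshold"])
    & (toy_edges["intersection_value"] >= toy_cfg["min_aggregated_link_value_usd"])
)
toy_kept = toy_edges[toy_mask]
# Only (8471, 8473) passes all three filters; the other edges fail the
# firm-count, edge-weight, and transaction-value filters, respectively.
```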

# =============================================================================
# Task 4.3: Network Graph Construction
# =============================================================================

def _construct_network_objects(
    final_edges_df: pd.DataFrame,
    all_products: List[int]
) -> Tuple[csr_matrix, nx.DiGraph, Dict[int, int]]:
    """
    Constructs the final network representations: a sparse matrix and a DiGraph.

    Args:
        final_edges_df (pd.DataFrame): The filtered DataFrame of network edges.
        all_products (List[int]): A sorted list of all unique product codes
                                  that will form the nodes of the network.

    Returns:
        Tuple[csr_matrix, nx.DiGraph, Dict[int, int]]: A tuple containing:
        - adj_matrix: The n x n weighted adjacency matrix in CSR format.
        - graph: The networkx DiGraph object representing the network.
        - product_to_idx: A dictionary mapping product HS codes to matrix indices.
    """
    logging.info("Constructing final network objects (sparse matrix and graph)...")

    # --- Step 1: Create the product-to-index mapping ---
    # This map is the single source of truth for node indexing.
    n_products = len(all_products)
    product_to_idx = {product: i for i, product in enumerate(all_products)}

    # --- Step 2: Prepare edge data for matrix construction ---
    # Reset index to access product codes as columns.
    edges = final_edges_df.reset_index()

    # Map product codes to their integer indices.
    row_indices = edges['input_product_i'].map(product_to_idx)
    col_indices = edges['output_product_j'].map(product_to_idx)

    # The data for the sparse matrix is the final A_ij weight.
    edge_weights = edges['A_ij']

    # --- Step 3: Construct the sparse adjacency matrix ---
    # csr_matrix accepts COO-style (data, (row, col)) triplets and converts
    # them to Compressed Sparse Row (CSR) format, summing any duplicate
    # coordinates. CSR is efficient for row slicing and matrix-vector products.
    adj_matrix = csr_matrix(
        (edge_weights, (row_indices, col_indices)),
        shape=(n_products, n_products)
    )

    # --- Step 4: Construct the networkx DiGraph object ---
    # Create an empty directed graph.
    graph = nx.DiGraph()

    # Add all products as nodes to ensure nodes without edges are included.
    graph.add_nodes_from(all_products)

    # Add the weighted, directed edges from the filtered DataFrame.
    # Zipping the columns avoids the overhead of iterrows(), which also
    # upcasts the integer HS codes to float when a row mixes dtypes.
    edge_tuples = zip(
        edges['input_product_i'],
        edges['output_product_j'],
        edges['A_ij']
    )
    graph.add_weighted_edges_from(edge_tuples, weight='weight')

    logging.info(
        f"Network construction complete. Graph has {graph.number_of_nodes()} "
        f"nodes and {graph.number_of_edges()} edges."
    )

    return adj_matrix, graph, product_to_idx
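To check that the sparse matrix and the DiGraph built in `_construct_network_objects` describe the same network, a small consistency test like the one below may help. It uses a hypothetical three-product edge list. It also shows that the `(data, (row, col))` constructor sums duplicate coordinates, so the edge list should contain each (i, j) pair at most once.

```python
import networkx as nx
from scipy.sparse import csr_matrix

toy_products = [3901, 3923, 8471]  # hypothetical HS codes (node universe)
toy_idx = {p: k for k, p in enumerate(toy_products)}
toy_edge_list = [(3901, 3923, 2.0), (8471, 3923, 1.5)]  # (i, j, A_ij)

# Sparse matrix: rows are input products i, columns are output products j.
toy_rows = [toy_idx[i] for i, _, _ in toy_edge_list]
toy_cols = [toy_idx[j] for _, j, _ in toy_edge_list]
toy_vals = [w for _, _, w in toy_edge_list]
n_toy = len(toy_products)
toy_matrix = csr_matrix((toy_vals, (toy_rows, toy_cols)), shape=(n_toy, n_toy))

# DiGraph with the same nodes and weighted edges.
toy_graph = nx.DiGraph()
toy_graph.add_nodes_from(toy_products)
toy_graph.add_weighted_edges_from(toy_edge_list)  # stored under 'weight'

# Consistency: same edge count, and every graph edge carries the matrix weight.
assert toy_matrix.nnz == toy_graph.number_of_edges()
for i, j, w in toy_graph.edges(data="weight"):
    assert toy_matrix[toy_idx[i], toy_idx[j]] == w

# Duplicate (row, col) entries are summed rather than overwritten.
dup = csr_matrix(([1.0, 2.0], ([0, 0], [1, 1])), shape=(2, 2))
```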

# =============================================================================
# Task 4: Orchestrator Function
# =============================================================================

def infer_and_construct_network(
    cleansed_df: pd.DataFrame,
    producer_sets: Dict[int, Set[str]],
    purchaser_sets: Dict[int, Set[str]],
    intersection_metrics: pd.DataFrame,
    replication_manifest: Dict[str, Any]
) -> Tuple[csr_matrix, nx.DiGraph, Dict[int, int]]:
    """
    Orchestrates the full network inference and construction pipeline.

    This function takes the outputs of Task 3 and applies the core inference
    algorithm and filtering rules from the paper to produce the final,
    analysis-ready network representations.

    Args:
        cleansed_df (pd.DataFrame): The cleansed transaction log.
        producer_sets (Dict[int, Set[str]]): Map of product codes to producers.
        purchaser_sets (Dict[int, Set[str]]): Map of product codes to purchasers.
        intersection_metrics (pd.DataFrame): Pre-computed intersection metrics.
        replication_manifest (Dict[str, Any]): The dictionary of study parameters.

    Returns:
        Tuple[csr_matrix, nx.DiGraph, Dict[int, int]]: A tuple containing:
        - adj_matrix: The final n x n weighted adjacency matrix (CSR format).
        - graph: The final networkx DiGraph object.
        - product_to_idx: The mapping from product HS codes to matrix/node indices.
    """
    logging.info("Starting Task 4: Network Inference Implementation...")

    # Extract parameters for convenience.
    params = replication_manifest['parameters']

    # Determine the total number of unique entities (|S|).
    total_unique_entities = len(
        pd.unique(cleansed_df[['seller_owner_country_entity', 'buyer_owner_country_entity']].values.ravel('K'))
    )

    # Step 4.1: Compute the raw A_ij weights for all potential edges.
    weighted_edges_df = _compute_adjacency_weights(
        producer_sets=producer_sets,
        purchaser_sets=purchaser_sets,
        intersection_metrics=intersection_metrics,
        total_unique_entities=total_unique_entities
    )

    # Step 4.2: Sparsify the network by applying all filters.
    final_edges_df = _sparsify_network_edges(
        weighted_edges_df=weighted_edges_df,
        params=params
    )

    # Step 4.3: Construct the final network objects.
    # Define the universe of nodes (all products present in sets).
    all_products = sorted(
        list(set(producer_sets.keys()) | set(purchaser_sets.keys()))
    )

    adj_matrix, graph, product_to_idx = _construct_network_objects(
        final_edges_df=final_edges_df,
        all_products=all_products
    )

    logging.info("Task 4 successfully completed. Network is constructed.")

    return adj_matrix, graph, product_to_idx
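The orchestrator counts |S| as the number of distinct owner-country entities appearing on either side of a transaction. The toy frame below, with invented entity IDs, shows that an entity that both sells and buys is counted only once.

```python
import pandas as pd

toy_log = pd.DataFrame({
    "seller_owner_country_entity": ["A_DE", "B_CN", "A_DE"],
    "buyer_owner_country_entity": ["B_CN", "C_US", "D_VN"],
})
# Flatten both columns into one array, then keep each distinct value once.
toy_entities = pd.unique(
    toy_log[["seller_owner_country_entity", "buyer_owner_country_entity"]].values.ravel("K")
)
toy_total = len(toy_entities)
# B_CN sells once and buys once but is counted once, so toy_total is 4.
```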


# Task 5: Community Detection and Structural Analysis

# =============================================================================
# Task 5.1 & 5.2: Community Detection and Analysis
# =============================================================================

def _detect_communities_leiden(
    graph: nx.DiGraph,
    resolution_parameter: float = 0.1,
    seed: Optional[int] = 42
) -> Dict[int, int]:
    """
    Detects communities in the network using the Leiden algorithm.

    The Leiden algorithm is a community detection method that improves upon
    the Louvain algorithm by guaranteeing well-connected communities. It is
    used here to identify meso-scale clusters in the directed, weighted
    production network. The resolution_parameter plays a role similar to the
    Markov "time scale" in the original Stability algorithm, but in the
    opposite direction: lower resolution values (like longer time scales)
    tend to produce fewer, larger communities.

    Args:
        graph (nx.DiGraph): The input networkx directed graph.
        resolution_parameter (float): The resolution parameter for the Leiden
                                      algorithm. Defaults to 0.1 to find a
                                      small number of large communities.
        seed (Optional[int]): A random seed for reproducibility.

    Returns:
        Dict[int, int]: A dictionary mapping each node (product HS code) to
                        its assigned community ID (integer).
    """
    logging.info(
        f"Detecting communities with Leiden algorithm (resolution={resolution_parameter})..."
    )

    # Convert the networkx graph to an igraph object for leidenalg.
    # igraph infers directedness from the networkx graph and, by default,
    # stores each original node key (HS code) in the '_nx_name' vertex
    # attribute; edge attributes such as 'weight' are copied over.
    igraph_graph = ig.Graph.from_networkx(graph)

    # Find the partition using the Leiden algorithm.
    # We use the RBConfigurationVertexPartition, which is suitable for directed graphs.
    # Edge weights from the networkx graph are used to guide the partitioning.
    partition = la.find_partition(
        igraph_graph,
        la.RBConfigurationVertexPartition,
        weights='weight',
        resolution_parameter=resolution_parameter,
        seed=seed
    )

    # Map each igraph vertex back to its HS code. partition.membership is
    # ordered by igraph vertex index, matching the order of '_nx_name'.
    node_names = igraph_graph.vs['_nx_name']
    community_map = {
        node_name: membership
        for node_name, membership in zip(node_names, partition.membership)
    }

    logging.info(
        f"Found {len(partition)} communities for {len(graph.nodes())} nodes."
    )

    return community_map
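Because `_detect_communities_leiden` depends on igraph and leidenalg, a dependency-light cross-check can be useful. The sketch below swaps in a different algorithm, networkx's Louvain implementation (`nx.community.louvain_communities`, available in networkx 2.8 and later), which also exposes a resolution parameter; it is not the method used above. It runs on an undirected copy of the graph, so reciprocal edges keep only one weight. `louvain_community_map` and the toy HS codes are hypothetical.

```python
import networkx as nx


def louvain_community_map(graph: nx.DiGraph, resolution: float = 0.1, seed: int = 42) -> dict:
    """Hypothetical helper: node -> community ID via networkx Louvain (not Leiden)."""
    # Louvain is run on an undirected copy; for reciprocal edges only one
    # of the two weights survives the conversion.
    undirected = graph.to_undirected()
    communities = nx.community.louvain_communities(
        undirected, weight="weight", resolution=resolution, seed=seed
    )
    # communities is a list of node sets; number them in the order returned.
    return {node: cid for cid, members in enumerate(communities) for node in members}


# Two disconnected toy triangles (hypothetical HS codes): machinery and textiles.
toy_net = nx.DiGraph()
toy_net.add_weighted_edges_from([
    (8471, 8473, 2.0), (8473, 8517, 1.5), (8517, 8471, 1.0),
    (5201, 5205, 3.0), (5205, 6109, 2.0), (6109, 5201, 1.0),
])
toy_louvain = louvain_community_map(toy_net)
```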

def _characterize_communities(
    community_map: Dict[int, int]
) -> pd.Series:
    """
    Assigns descriptive labels to detected communities.

    This function provides a basic characterization of the communities based on
    a simplified mapping of HS sections to industrial groups. This is a heuristic
    to replicate the paper's qualitative findings (e.g., Machinery, Textiles).

    Args:
        community_map (Dict[int, int]): A map from product HS code to community ID.

    Returns:
        pd.Series: A Series mapping each product HS code to a descriptive
                   community label (e.g., "Machinery & Metals").
    """
    # Create a DataFrame from the community map for easier analysis.
    community_df = pd.DataFrame(
        community_map.items(), columns=['product_hs_code', 'community_id']
    )

    # Define a heuristic mapping from the 2-digit HS chapter (hs_code // 100
    # for a 4-digit HS code) to a broad category.
    def hs_to_category(hs_code: int) -> str:
        section = hs_code // 100
        if 84 <= section <= 92 or 72 <= section <= 83:
            return "Machinery & Metals"
        elif 50 <= section <= 63:
            return "Textiles"
        elif 28 <= section <= 38 or 1 <= section <= 24:
            return "Chemicals & Food"
        else:
            return "Other"

    # Apply this mapping to each product.
    community_df['category'] = community_df['product_hs_code'].apply(hs_to_category)

    # Determine the dominant category for each community ID as the mode of
    # its members' categories. Series.mode() returns values in sorted order,
    # so ties resolve to the alphabetically first category.
    community_labels = community_df.groupby('community_id')['category'].agg(
        lambda x: x.mode().iloc[0]
    )

    # Map the descriptive labels back to the original DataFrame.
    community_df['community_label'] = community_df['community_id'].map(community_labels)

    # Create the final Series for output.
    final_labels = community_df.set_index('product_hs_code')['community_label']

    logging.info("Characterized communities with descriptive labels.")
    logging.info(f"Final community breakdown:\n{final_labels.value_counts()}")

    return final_labels
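The labelling step in `_characterize_communities` reduces to a group-wise mode over HS-derived categories. The sketch below repeats it on a toy community map (hypothetical HS codes and community IDs). Community 3 is a deliberate tie, showing that because `Series.mode()` returns sorted values, the tie goes to the alphabetically first label.

```python
import pandas as pd


def toy_hs_category(hs_code: int) -> str:
    """Hypothetical copy of the heuristic above: 4-digit HS code -> broad category."""
    chapter = hs_code // 100  # e.g. 8471 -> 84
    if 84 <= chapter <= 92 or 72 <= chapter <= 83:
        return "Machinery & Metals"
    if 50 <= chapter <= 63:
        return "Textiles"
    if 28 <= chapter <= 38 or 1 <= chapter <= 24:
        return "Chemicals & Food"
    return "Other"


# Hypothetical community assignment; community 3 is a deliberate tie.
toy_membership = {8471: 0, 7208: 0, 5205: 0, 5201: 1, 6109: 1, 3004: 2, 8501: 3, 6204: 3}
toy_df = pd.DataFrame(list(toy_membership.items()), columns=["hs_code", "community_id"])
toy_df["category"] = toy_df["hs_code"].map(toy_hs_category)

# Dominant category per community: Series.mode() is sorted, so .iloc[0]
# picks the alphabetically first category on a tie.
toy_labels = toy_df.groupby("community_id")["category"].agg(lambda s: s.mode().iloc[0])
```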

# =============================================================================
# Task 5.3: Network Topology Validation
# =============================================================================

def _validate_network_topology(
    graph: nx.DiGraph,
    adj_matrix: csr_matrix,
    product_to_idx: Dict[int, int]
) -> None:
    """
    Performs robust validation of the network's structural properties.

    This function conducts two key validation checks on the inferred network:
    1.  Basic Statistics: Reports fundamental network properties like node count,
        edge count, and overall density to provide a high-level summary.
    2.  HS Section Density: Validates the paper's finding that the network is
        denser within broad industrial sectors (defined by 2-digit HS codes)
        than between them. This is a test for sectoral homophily.

    This implementation is specifically designed to be robust against edge-case
    graph structures, such as those that may lack any intra- or inter-section
    links, by handling empty density lists gracefully.

    Args:
        graph (nx.DiGraph): The networkx graph object representing the network.
        adj_matrix (csr_matrix): The sparse adjacency matrix of the network.
        product_to_idx (Dict[int, int]): A dictionary mapping product HS codes
                                         to their corresponding matrix indices.
    """
    # --- Step 1: Calculate and log basic network statistics ---
    # Retrieve the number of nodes (products) from the graph.
    num_nodes = graph.number_of_nodes()
    # Retrieve the number of edges (inferred links) from the graph.
    num_edges = graph.number_of_edges()
    # Calculate the network density for a directed graph.
    # Formula: Density = E / (N * (N - 1)), where E is edges, N is nodes.
    density = num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0.0

    # Log the basic statistics for audit and review.
    logging.info("--- Network Topology Validation ---")
    logging.info(f"Node count: {num_nodes}")
    logging.info(f"Edge count: {num_edges}")
    logging.info(f"Network Density: {density:.6f}")

    # --- Step 2: Prepare HS section mappings for block density analysis ---
    # Create a DataFrame from the product-to-index map for easy manipulation.
    nodes_df = pd.DataFrame(
        product_to_idx.items(), columns=['product_hs_code', 'matrix_idx']
    )
    # Compute the 2-digit HS section for each product via integer division.
    nodes_df['hs2_section'] = nodes_df['product_hs_code'] // 100

    # Group matrix indices by their HS2 section to define the blocks.
    section_indices = nodes_df.groupby('hs2_section')['matrix_idx'].apply(list)

    # --- Step 3: Calculate block densities for all section pairs ---
    # Initialize empty lists to store the calculated densities.
    intra_section_densities = []
    inter_section_densities = []

    # Iterate over all pairs of source and target sections.
    for source_section, source_idxs in section_indices.items():
        for target_section, target_idxs in section_indices.items():
            # Calculate the number of potential edges in this block.
            potential_edges = len(source_idxs) * len(target_idxs)

            # Skip this block if no edges are possible (e.g., empty section).
            if potential_edges == 0:
                continue

            # Extract the sub-matrix corresponding to the (source, target) block.
            # Row selection is cheap on CSR; the column selection that follows is slower.
            sub_matrix = adj_matrix[source_idxs, :][:, target_idxs]

            # Calculate the density of the block.
            block_density = sub_matrix.nnz / potential_edges

            # Append the calculated density to the appropriate list.
            if source_section == target_section:
                intra_section_densities.append(block_density)
            else:
                inter_section_densities.append(block_density)

    # --- Step 4: Robustly calculate average densities and validate ---
    # Guard against empty lists: np.mean([]) returns NaN with a RuntimeWarning.

    # Calculate average intra-section density.
    if intra_section_densities:
        # If the list is not empty, compute the mean.
        avg_intra_density = np.mean(intra_section_densities)
    else:
        # If the list is empty, set the mean to 0.0 and log a warning.
        avg_intra_density = 0.0
        logging.warning(
            "No intra-section edges found. Average intra-section density is 0."
        )

    # Calculate average inter-section density.
    if inter_section_densities:
        # If the list is not empty, compute the mean.
        avg_inter_density = np.mean(inter_section_densities)
    else:
        # If the list is empty, set the mean to 0.0 and log a warning.
        avg_inter_density = 0.0
        logging.warning(
            "No inter-section edges found. Average inter-section density is 0."
        )

    # Log the final calculated average densities.
    logging.info(f"Average intra-section density: {avg_intra_density:.6f}")
    logging.info(f"Average inter-section density: {avg_inter_density:.6f}")

    # Perform the final validation check based on the paper's hypothesis.
    if avg_intra_density > avg_inter_density:
        # Log a success message if the validation passes.
        logging.info(
            "Validation PASSED: Network is denser within HS sections than between them."
        )
    else:
        # Log a warning if the validation fails.
        logging.warning(
            "Validation FAILED: Network is not denser within HS sections."
        )
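The block-density comparison in `_validate_network_topology` can be checked on a four-node toy matrix. It has two hypothetical HS chapters, each containing a reciprocal pair of links, plus one cross-chapter link. `block_density` is a hypothetical helper that mirrors the loop above.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy network: nodes 0, 1 in chapter 84 and nodes 2, 3 in chapter 52 (hypothetical).
toy_adj = csr_matrix(np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0],
], dtype=float))
toy_blocks = {84: [0, 1], 52: [2, 3]}


def block_density(adj: csr_matrix, src: list, dst: list) -> float:
    """Hypothetical helper: stored links in the (src, dst) block / block size."""
    return adj[src, :][:, dst].nnz / (len(src) * len(dst))


intra = [block_density(toy_adj, idx, idx) for idx in toy_blocks.values()]
inter = [
    block_density(toy_adj, a, b)
    for sa, a in toy_blocks.items()
    for sb, b in toy_blocks.items()
    if sa != sb
]
# Mean intra density 0.5 exceeds mean inter density 0.125, so this toy
# network would pass the homophily check.
```

As in the loop above, the intra-block denominator counts self-pairs (n * n rather than n(n - 1)), which slightly lowers intra-section densities relative to the whole-network density formula.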

# =============================================================================
# Task 5: Orchestrator Function
# =============================================================================

def analyze_network_structure(
    graph: nx.DiGraph,
    adj_matrix: csr_matrix,
    product_to_idx: Dict[int, int],
    replication_manifest: Dict[str, Any]
) -> pd.Series:
    """
    Orchestrates the community detection and structural analysis of the network.

    Args:
        graph (nx.DiGraph): The final networkx DiGraph object from Task 4.
        adj_matrix (csr_matrix): The final sparse adjacency matrix.
        product_to_idx (Dict[int, int]): Mapping from product codes to indices.
        replication_manifest (Dict[str, Any]): The dictionary of study parameters.

    Returns:
        pd.Series: A Series mapping each product HS code to its assigned
                   descriptive community label.
    """
    logging.info("Starting Task 5: Community Detection and Structural Analysis...")

    # Step 5.1 & 5.2: Detect and characterize communities.
    # The default low resolution parameter (0.1) is used, aiming for the
    # paper's 3-cluster structure. It is not read from the manifest and may
    # require empirical tuning in a real application.
    community_map = _detect_communities_leiden(graph=graph)

    # Assign descriptive labels to the found communities.
    community_labels = _characterize_communities(community_map=community_map)

    # Step 5.3: Perform topological validation.
    _validate_network_topology(
        graph=graph,
        adj_matrix=adj_matrix,
        product_to_idx=product_to_idx
    )

    logging.info("Task 5 successfully completed. Network structure analyzed.")

    return community_labels


# Task 6: Centrality Calculations

# =============================================================================
# Task 6.1: Betweenness Centrality Computation
# =============================================================================

def _calculate_betweenness_centrality(
    graph: nx.DiGraph
) -> Dict[int, float]:
    """
    Calculates the weighted betweenness centrality for each node in the network.

    Betweenness centrality identifies nodes that act as critical "choke points"
    or bridges along the shortest paths in the network. For this calculation,
    the edge weights (A_ij) must be inverted, as higher A_ij values represent
    stronger (i.e., "shorter" or "easier") paths.

    The function implements the mathematical definition:
    BC(v) = Σ_{s ≠ t ≠ v} [σ_st(v) / σ_st]
    where σ_st is the number of shortest paths between nodes s and t, and
    σ_st(v) is the number of those paths that pass through node v.

    Args:
        graph (nx.DiGraph): The input networkx directed graph with a 'weight'
                            attribute on each edge representing A_ij.

    Returns:
        Dict[int, float]: A dictionary mapping each node (product HS code) to
                          its normalized betweenness centrality score.
    """
    logging.info("Calculating weighted betweenness centrality...")

    # Create a working copy of the graph to avoid modifying the original object.
    g_copy = graph.copy()

    # --- Weight Inversion for Shortest Path Calculation ---
    # The algorithm finds shortest paths by minimizing the sum of edge lengths.
    # Since our 'weight' (A_ij) signifies strength, a stronger link should be
    # a shorter one, so each edge's length ('distance') is set to 1 / A_ij.
    # Edges with zero or negative weight carry no meaningful link and are
    # removed, rather than given an infinite length, which networkx's
    # shortest-path routine does not reliably treat as an absent edge.
    non_positive_edges = []
    for u, v, data in g_copy.edges(data=True):
        # Access the original weight.
        weight = data.get('weight', 0.0)
        if weight > 0:
            g_copy[u][v]['distance'] = 1.0 / weight
        else:
            non_positive_edges.append((u, v))
    g_copy.remove_edges_from(non_positive_edges)

    # --- Centrality Calculation ---
    # Use networkx's highly optimized function to calculate centrality.
    # 'weight="distance"' tells the algorithm to use our inverted weights.
    # 'normalized=True' divides by (n - 1)(n - 2), the number of ordered
    # (s, t) pairs excluding v in a directed graph, so scores lie in [0, 1].
    betweenness_centrality = nx.betweenness_centrality(
        g_copy,
        weight='distance',
        normalized=True
    )

    logging.info("Successfully calculated betweenness centrality.")
    return betweenness_centrality
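The reason for inverting A_ij before calling `nx.betweenness_centrality` is visible on a three-node toy graph (hypothetical node names): a strong two-step route a -> b -> c and a weak direct link a -> c. With lengths 1 / A_ij the strong route is shortest and b lies on it; using A_ij itself as a length would do the opposite.

```python
import networkx as nx

toy_bc = nx.DiGraph()
toy_bc.add_weighted_edges_from([("a", "b", 10.0), ("b", "c", 10.0), ("a", "c", 1.0)])

# Strong links become short: length = 1 / A_ij.
for u, v, w in toy_bc.edges(data="weight"):
    toy_bc[u][v]["distance"] = 1.0 / w

# a -> b -> c has length 0.1 + 0.1 = 0.2 < 1.0, so b lies on the a -> c path.
bc_inverted = nx.betweenness_centrality(toy_bc, weight="distance", normalized=True)

# Treating A_ij itself as a length makes the weak direct link (1.0 < 20.0) shortest.
bc_strength_as_length = nx.betweenness_centrality(toy_bc, weight="weight", normalized=True)
```

Here b's raw betweenness is 1 (one ordered pair routes through it), and normalization by (n - 1)(n - 2) = 2 gives 0.5.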

# =============================================================================
# Task 6.2: Hub Score Calculation (HITS Algorithm)
# =============================================================================

def _calculate_hub_scores(
    graph: nx.DiGraph
) -> Dict[int, float]:
    """
    Calculates the hub scores for each node using the HITS algorithm.

    The HITS algorithm identifies two types of important nodes: "hubs" (nodes
    that point to many important authorities) and "authorities" (nodes that
    are pointed to by many important hubs). This function calculates and
    returns the hub scores, which, in the context of the paper, identify
    products embedded in dense, complex supply chains.

    The algorithm iteratively solves, where A_ij is the weight of edge i -> j:
    h = A * a    (hub score is proportional to the sum of authority scores it points to)
    a = A^T * h  (authority score is proportional to the sum of hub scores pointing to it)

    Args:
        graph (nx.DiGraph): The input networkx directed graph.

    Returns:
        Dict[int, float]: A dictionary mapping each node (product HS code) to
                          its hub score.

    Raises:
        DataValidationError: If the HITS algorithm fails to converge.
    """
    logging.info("Calculating hub scores using the HITS algorithm...")

    try:
        # Use networkx's implementation of the HITS algorithm.
        # nx.hits reads the 'weight' edge attribute (A_ij) when it builds the
        # adjacency matrix. max_iter and tol are relaxed relative to the
        # networkx defaults (100 and 1e-8) to reduce convergence failures.
        hubs, authorities = nx.hits(
            graph,
            max_iter=1000,
            tol=1.0e-6,
            normalized=True
        )
    except nx.PowerIterationFailedConvergence as e:
        # Catch convergence errors, which indicate potential issues with the
        # network structure (e.g., disconnected components not handled).
        raise DataValidationError(
            "HITS algorithm failed to converge. The network may be unsuitable "
            "for this centrality measure."
        ) from e

    logging.info("Successfully calculated hub and authority scores.")
    return hubs
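To build intuition for the HITS updates described in `_calculate_hub_scores`, here is a plain NumPy power-iteration sketch on a dense toy adjacency matrix. It is a simplified illustration, not networkx's implementation. `hits_power_iteration` is a hypothetical helper, and it assumes the graph has at least one edge.

```python
import numpy as np


def hits_power_iteration(adj: np.ndarray, max_iter: int = 1000, tol: float = 1e-10):
    """Hypothetical helper: HITS hub and authority scores by power iteration.

    adj[i, j] is the weight of the edge i -> j (rows are sources). Returns
    (hubs, authorities), each normalized to sum to 1. Assumes adj has at
    least one nonzero entry.
    """
    n = adj.shape[0]
    h = np.full(n, 1.0 / np.sqrt(n))  # unit-norm starting hub vector
    for _ in range(max_iter):
        a = adj.T @ h                 # a = A^T h: hub scores of nodes pointing in
        a /= np.linalg.norm(a)
        h_next = adj @ a              # h = A a: authority scores of nodes pointed to
        h_next /= np.linalg.norm(h_next)
        converged = np.abs(h_next - h).sum() < tol
        h = h_next
        if converged:
            break
    a = adj.T @ h
    return h / h.sum(), a / a.sum()


# Toy graph (hypothetical): node 0 points to 1, 2, 3; node 4 points only to 1.
toy_hits_adj = np.zeros((5, 5))
for src, dst in [(0, 1), (0, 2), (0, 3), (4, 1)]:
    toy_hits_adj[src, dst] = 1.0
toy_hubs, toy_auths = hits_power_iteration(toy_hits_adj)
# Node 0 is the strongest hub; node 1, pointed to by both hubs, is the top authority.
```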

# =============================================================================
# Task 6.3: Centrality Aggregation and Validation
# =============================================================================

def _aggregate_and_validate_centralities(
    betweenness_centrality: Dict[int, float],
    hub_scores: Dict[int, float]
) -> pd.DataFrame:
    """
    Aggregates centrality scores into a single DataFrame and validates them.

    This function consolidates the computed centrality metrics, ranks them,
    and performs a correlation analysis to verify that they capture distinct
    aspects of network structure, as expected.

    Args:
        betweenness_centrality (Dict[int, float]): Node-to-betweenness map.
        hub_scores (Dict[int, float]): Node-to-hub-score map.

    Returns:
        pd.DataFrame: A DataFrame indexed by product HS code with columns for
                      centrality values and their ranks.
    """
    logging.info("Aggregating and validating centrality measures...")

    # Create a DataFrame from the centrality dictionaries.
    # Using pd.Series ensures alignment by node ID (HS code).
    centrality_df = pd.DataFrame({
        'betweenness': pd.Series(betweenness_centrality),
        'hub_score': pd.Series(hub_scores)
    })
    centrality_df.index.name = 'product_hs_code'

    # --- Rank the centrality scores ---
    # 'ascending=False' ensures that higher scores get lower rank numbers (e.g., rank 1).
    centrality_df['rank_betweenness'] = centrality_df['betweenness'].rank(
        method='average', ascending=False
    )
    centrality_df['rank_hub_score'] = centrality_df['hub_score'].rank(
        method='average', ascending=False
    )

    # --- Validate by checking correlation ---
    # We expect a low-to-moderate correlation, indicating the measures
    # capture different structural properties. Spearman is used for rank correlation.
    correlation_matrix = centrality_df[['betweenness', 'hub_score']].corr(
        method='spearman'
    )
    correlation = correlation_matrix.iloc[0, 1]
    logging.info(
        f"Spearman rank correlation between Betweenness and Hub Score: {correlation:.4f}"
    )

    return centrality_df
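The ranking and correlation step can be sanity-checked on toy scores (hypothetical values and HS codes). Spearman correlation is the Pearson correlation of the ranks, with tied values given their average rank, so computing Pearson directly on the rank columns should reproduce the pandas result. The ranking direction does not matter as long as both columns are ranked the same way.

```python
import numpy as np
import pandas as pd

toy_scores = pd.DataFrame(
    {"betweenness": [0.30, 0.05, 0.10, 0.00], "hub_score": [0.10, 0.40, 0.10, 0.40]},
    index=pd.Index([8471, 3901, 5205, 7208], name="product_hs_code"),
)
# Rank 1 = highest score; tied scores share the average of their ranks.
toy_ranks = toy_scores.rank(method="average", ascending=False)

# Spearman via pandas versus Pearson computed directly on the rank columns.
rho_pandas = toy_scores.corr(method="spearman").iloc[0, 1]
rho_from_ranks = np.corrcoef(toy_ranks["betweenness"], toy_ranks["hub_score"])[0, 1]
```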

# =============================================================================
# Task 6: Orchestrator Function
# =============================================================================

def calculate_centralities(
    graph: nx.DiGraph,
    replication_manifest: Dict[str, Any]
) -> pd.DataFrame:
    """
    Orchestrates the calculation of all specified network centrality measures.

    This function takes the final constructed network graph and computes the
    key node-level importance metrics discussed in the paper: Betweenness
    Centrality and Hub Scores. It returns a consolidated DataFrame containing
    these metrics for every product in the network.

    Args:
        graph (nx.DiGraph): The final, weighted, directed network graph from Task 4.
        replication_manifest (Dict[str, Any]): The dictionary of study parameters.
                                               (Currently unused but included for
                                               API consistency).

    Returns:
        pd.DataFrame: A DataFrame indexed by product HS code, containing columns
                      for each centrality measure and its corresponding rank.
    """
    logging.info("Starting Task 6: Centrality Calculations...")

    # Step 6.1: Calculate weighted betweenness centrality.
    betweenness_centrality = _calculate_betweenness_centrality(graph=graph)

    # Step 6.2: Calculate hub scores using the HITS algorithm.
    hub_scores = _calculate_hub_scores(graph=graph)

    # Step 6.3: Aggregate, rank, and validate the centrality scores.
    centrality_df = _aggregate_and_validate_centralities(
        betweenness_centrality=betweenness_centrality,
        hub_scores=hub_scores
    )

    logging.info("Task 6 successfully completed. Centrality metrics calculated.")

    return centrality_df


# Task 7: Validation Procedures

# =============================================================================
# Task 7.1: External Network Comparison
# =============================================================================

def _compare_networks(
    empirical_matrix: csr_matrix,
    external_matrix: csr_matrix,
    product_to_idx: Dict[int, int],
    external_product_list: List[int]
) -> Dict[str, Tuple[float, float]]:
    """
    Compares the empirical network to an external network on multiple metrics.

    This function aligns the external network to the node set of the empirical
    network and then computes three correlation metrics:
    1. Edge-level correlation: Pearson correlation of the flattened adjacency matrices.
    2. In-degree correlation: Pearson correlation of the in-degree vectors.
    3. Out-degree correlation: Pearson correlation of the out-degree vectors.

    Args:
        empirical_matrix (csr_matrix): The adjacency matrix of our inferred network.
        external_matrix (csr_matrix): The adjacency matrix of the external network.
        product_to_idx (Dict[int, int]): Mapping from product HS code to the
                                         index in our empirical matrix.
        external_product_list (List[int]): Ordered list of products corresponding
                                           to the rows/cols of the external_matrix.

    Returns:
        Dict[str, Tuple[float, float]]: A dictionary with correlation results,
                                        where each value is a tuple of
                                        (correlation_coefficient, p_value).
    """
    logging.info("Aligning and comparing to an external network...")

    # --- Step 1: Align the external network to the empirical node set ---
    # Create the product-to-index map for the external network.
    external_product_to_idx = {p: i for i, p in enumerate(external_product_list)}

    # Find the common products between both networks (sorted for a stable order).
    common_products = sorted(set(product_to_idx) & set(external_product_to_idx))

    # Get the indices corresponding to these common products for each matrix.
    empirical_indices = [product_to_idx[p] for p in common_products]
    external_indices = [external_product_to_idx[p] for p in common_products]

    # Slice both matrices to create aligned versions with only common nodes.
    aligned_empirical = empirical_matrix[empirical_indices, :][:, empirical_indices]
    aligned_external = external_matrix[external_indices, :][:, external_indices]

    # --- Step 2: Calculate correlation metrics ---
    results = {}

    # Edge-level correlation
    # Flatten the matrices to 1D arrays for correlation.
    empirical_flat = aligned_empirical.toarray().flatten()
    external_flat = aligned_external.toarray().flatten()
    results['edge_correlation'] = pearsonr(empirical_flat, external_flat)

    # In-degree correlation
    # Sum over columns (axis=0) to get in-degrees.
    empirical_indegree = np.array(aligned_empirical.sum(axis=0)).flatten()
    external_indegree = np.array(aligned_external.sum(axis=0)).flatten()
    results['indegree_correlation'] = pearsonr(empirical_indegree, external_indegree)

    # Out-degree correlation
    # Sum over rows (axis=1) to get out-degrees.
    empirical_outdegree = np.array(aligned_empirical.sum(axis=1)).flatten()
    external_outdegree = np.array(aligned_external.sum(axis=1)).flatten()
    results['outdegree_correlation'] = pearsonr(empirical_outdegree, external_outdegree)

    logging.info(f"Comparison Results: {results}")
    return results

# =============================================================================
# Task 7.2 & 7.3: Configuration Model and Modularity Validation
# =============================================================================

def _calculate_directed_modularity(
    adj_matrix: csr_matrix,
    subgraph_indices: List[int],
    in_degrees: np.ndarray,
    out_degrees: np.ndarray
) -> float:
    """
    Calculates the directed modularity of a subgraph.

    Implements the formula for the contribution of a single community (subgraph)
    to the total network modularity:
    M_G = (1/m) * Σ_{i,j ∈ G} [A_ij - (k_i^out * k_j^in) / m]
    where A is the unweighted adjacency matrix.

    Args:
        adj_matrix (csr_matrix): The UNWEIGHTED adjacency matrix of the full graph.
        subgraph_indices (List[int]): A list of integer indices for the nodes
                                      in the subgraph.
        in_degrees (np.ndarray): The in-degree sequence of the FULL graph.
        out_degrees (np.ndarray): The out-degree sequence of the FULL graph.

    Returns:
        float: The modularity score of the specified subgraph.
    """
    # Total number of edges in the full graph.
    m = adj_matrix.nnz
    if m == 0:
        return 0.0

    # Extract the part of the adjacency matrix corresponding to the subgraph.
    subgraph_matrix = adj_matrix[subgraph_indices, :][:, subgraph_indices]

    # The actual number of edges within the subgraph.
    edges_within_subgraph = subgraph_matrix.nnz

    # The expected number of edges within the subgraph under the null model.
    subgraph_out_degrees = out_degrees[subgraph_indices]
    subgraph_in_degrees = in_degrees[subgraph_indices]
    expected_edges = (subgraph_out_degrees.sum() * subgraph_in_degrees.sum()) / m

    # Calculate and return the modularity score.
    modularity = (edges_within_subgraph - expected_edges) / m
    return modularity

def _modularity_simulation_worker(
    args: Tuple[np.ndarray, np.ndarray, List[int], int]
) -> float:
    """
    A worker function for parallel modularity simulation.

    Generates one random graph using the configuration model and calculates
    the modularity of the specified subgraph within it.

    Args:
        args (Tuple): A tuple containing (in_degrees, out_degrees,
                      subgraph_indices, seed).

    Returns:
        float: The modularity score for the subgraph in one random graph.
    """
    in_degrees, out_degrees, subgraph_indices, seed = args

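    # Generate a random directed graph with the same degree sequences.
    # Building it as a simple DiGraph collapses parallel edges, so realized
    # degrees can fall slightly below the targets (self-loops are kept).
    random_graph = nx.directed_configuration_model(
        in_degree_sequence=in_degrees,
        out_degree_sequence=out_degrees,
        seed=seed,
        create_using=nx.DiGraph
    )

    # Convert to an unweighted CSR adjacency matrix with rows/columns in node
    # order 0..n-1, matching `subgraph_indices`. (`to_scipy_sparse_matrix` was
    # removed in NetworkX 3.0; `to_scipy_sparse_array` replaces it.)
    random_adj = nx.to_scipy_sparse_array(
        random_graph, nodelist=list(range(len(in_degrees))), format='csr'
    )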

    # Calculate modularity for this random instance.
    return _calculate_directed_modularity(
        random_adj, subgraph_indices, in_degrees, out_degrees
    )

def _validate_subgraph_connectivity(
    adj_matrix: csr_matrix,
    product_to_idx: Dict[int, int],
    supply_chain_products: List[int],
    supply_chain_name: str,
    num_simulations: int
) -> Tuple[float, float]:
    """
    Tests if a subgraph's connectivity is statistically significant.

    This function performs a Monte Carlo simulation to determine if the
    observed modularity of a predefined subgraph (e.g., a known supply chain)
    is significantly higher than what would be expected by chance in a random
    network with the same degree distribution (the configuration model).

    Args:
        adj_matrix (csr_matrix): The adjacency matrix of the empirical network.
        product_to_idx (Dict[int, int]): Mapping from product HS code to index.
        supply_chain_products (List[int]): List of HS codes in the subgraph.
        supply_chain_name (str): The name of the supply chain for logging.
        num_simulations (int): The number of random graphs to generate.

    Returns:
        Tuple[float, float]: A tuple of (empirical_modularity, p_value).
    """
    logging.info(f"--- Validating connectivity for '{supply_chain_name}' subgraph ---")

    # --- Step 1: Calculate empirical modularity ---
    # Create an unweighted version of the adjacency matrix (0s and 1s).
    unweighted_adj = (adj_matrix > 0).astype(int)

    # Get the degree sequences of the full graph.
    out_degrees = np.array(unweighted_adj.sum(axis=1)).flatten()
    in_degrees = np.array(unweighted_adj.sum(axis=0)).flatten()

    # Get the matrix indices for the products in this supply chain.
    subgraph_indices = [
        product_to_idx[p] for p in supply_chain_products if p in product_to_idx
    ]
    if not subgraph_indices:
        logging.warning(f"No products from '{supply_chain_name}' found in the network.")
        return 0.0, 1.0

    # Calculate the modularity of the subgraph in our actual network.
    empirical_modularity = _calculate_directed_modularity(
        unweighted_adj, subgraph_indices, in_degrees, out_degrees
    )
    logging.info(f"Empirical modularity for '{supply_chain_name}': {empirical_modularity:.6f}")

    # --- Step 2: Run Monte Carlo simulation ---
    logging.info(f"Running {num_simulations:,} simulations for the null distribution...")

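    # Prepare arguments for the parallel worker function.
    # Each simulation gets its own seed, so the random graphs are independent
    # of one another and the null distribution is reproducible.
    worker_args = [
        (in_degrees, out_degrees, subgraph_indices, seed)
        for seed in range(num_simulations)
    ]

    # Use a multiprocessing pool to parallelize the simulations.
    # Note: under the "spawn" start method (the default on Windows and macOS),
    # worker functions defined inside a notebook may not be importable by the
    # child processes; moving the worker into a module avoids this.
    with Pool(processes=max(1, cpu_count() - 1)) as pool:
        null_distribution = pool.map(_modularity_simulation_worker, worker_args)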

    # --- Step 3: Calculate the p-value ---
    # Count how many times the random modularity was >= the empirical one.
    null_distribution = np.array(null_distribution)
    count_extreme = np.sum(null_distribution >= empirical_modularity)

    # Calculate the p-value with a +1 correction to avoid p=0.
    p_value = (count_extreme + 1) / (num_simulations + 1)
    logging.info(f"P-value for '{supply_chain_name}': {p_value:.6f}")

    return empirical_modularity, p_value

# =============================================================================
# Task 7: Orchestrator Function
# =============================================================================

def run_validation_procedures(
    adj_matrix: csr_matrix,
    product_to_idx: Dict[int, int],
    supply_chains_definitions: Dict[str, List[int]],
    replication_manifest: Dict[str, Any],
    # Placeholder for external networks, which would be loaded here.
    external_networks: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Orchestrates all network validation procedures.

    Args:
        adj_matrix (csr_matrix): The inferred adjacency matrix.
        product_to_idx (Dict[int, int]): Mapping from product HS code to index.
        supply_chains_definitions (Dict[str, List[int]]): Definitions of
                                                          manual supply chains.
        replication_manifest (Dict[str, Any]): The study parameters.
        external_networks (Dict[str, Any], optional): A dictionary containing
            loaded external networks for comparison. Defaults to None.

    Returns:
        Dict[str, Any]: A dictionary containing all validation results.
    """
    logging.info("Starting Task 7: Validation Procedures...")
    validation_results = {}

    # --- Task 7.1: External Network Comparison ---
    # This part is conditional on providing external network data.
    if external_networks:
        logging.info("--- Task 7.1: External Network Comparison ---")
        validation_results['external_comparison'] = {}
        for name, data in external_networks.items():
            # Assuming data is a dict with 'matrix' and 'product_list'
            comparison = _compare_networks(
                empirical_matrix=adj_matrix,
                external_matrix=data['matrix'],
                product_to_idx=product_to_idx,
                external_product_list=data['product_list']
            )
            validation_results['external_comparison'][name] = comparison
    else:
        logging.warning("Skipping external network comparison: No data provided.")

    # --- Task 7.2 & 7.3: Manual Supply Chain Validation ---
    logging.info("--- Task 7.2 & 7.3: Manual Supply Chain Validation ---")
    validation_results['subgraph_validation'] = {}
    num_sims = replication_manifest['parameters']['network_analysis']['modularity_validation_simulations']

    for name, products in supply_chains_definitions.items():
        modularity, p_value = _validate_subgraph_connectivity(
            adj_matrix=adj_matrix,
            product_to_idx=product_to_idx,
            supply_chain_products=products,
            supply_chain_name=name,
            num_simulations=num_sims
        )
        validation_results['subgraph_validation'][name] = {
            'empirical_modularity': modularity,
            'p_value': p_value
        }

    logging.info("Task 7 successfully completed. Validation results generated.")
    return validation_results
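
The modularity test in Tasks 7.2 and 7.3 can be illustrated on a toy matrix. The sketch below is hypothetical (its function names are not part of the notebook) and uses dense NumPy arrays for readability. It applies the same per-subgraph directed modularity formula, the same `(count + 1) / (n + 1)` p-value, and a configuration-model null in which parallel edges are collapsed with `nx.DiGraph(...)`. One deliberate difference: here the degrees are recomputed from each realized random graph, whereas `_modularity_simulation_worker` passes the target degree sequences.

```python
import networkx as nx
import numpy as np


def directed_subgraph_modularity(adj: np.ndarray, nodes: list) -> float:
    """M_G = (1/m) * sum_{i,j in G} [A_ij - k_i^out * k_j^in / m] on the 0/1 version of adj."""
    a = (np.asarray(adj) > 0).astype(int)
    m = a.sum()
    if m == 0:
        return 0.0
    k_out = a.sum(axis=1)  # row sums: out-degrees
    k_in = a.sum(axis=0)   # column sums: in-degrees
    idx = np.asarray(nodes, dtype=int)
    observed = a[np.ix_(idx, idx)].sum()
    expected = k_out[idx].sum() * k_in[idx].sum() / m
    return float((observed - expected) / m)


def null_modularities(adj: np.ndarray, nodes: list, n_sims: int) -> np.ndarray:
    """Subgraph modularity in configuration-model graphs with the same degree sequences."""
    a = (np.asarray(adj) > 0).astype(int)
    k_out = a.sum(axis=1).tolist()
    k_in = a.sum(axis=0).tolist()
    n = len(k_in)
    values = []
    for seed in range(n_sims):  # one seed per draw: independent and reproducible
        multi = nx.directed_configuration_model(k_in, k_out, seed=seed)
        simple = nx.DiGraph(multi)  # collapse parallel edges (self-loops are kept)
        r = nx.to_numpy_array(simple, nodelist=list(range(n)))
        values.append(directed_subgraph_modularity(r, nodes))
    return np.array(values)


def empirical_p_value(observed: float, null: np.ndarray) -> float:
    """One-sided Monte Carlo p-value with the +1 correction: (count + 1) / (n + 1)."""
    null = np.asarray(null)
    return float((np.sum(null >= observed) + 1) / (null.size + 1))
```

On a 4-node graph with edges 0→1, 1→0 and 2→3, the subgraph {0, 1} has m = 3, two internal edges and an expected count of (2 × 2) / 3, so M_G = (2 - 4/3) / 3 = 2/9.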


In [None]:
# Task 8: Economic Feature Engineering

# =============================================================================
# Task 8.1: Country-Product Presence Matrix Construction
# =============================================================================

def _calculate_rpop(
    comtrade_df: pd.DataFrame,
    country_df: pd.DataFrame,
    years: List[int]
) -> pd.DataFrame:
    """
    Calculates Revealed Population-Adjusted Comparative Advantage (Rpop).

    Implements the formula:
    Rpop_p,c = (E_p,c / pop_c) / (E_p / pop_world)
    where E_p,c is exports of product p by country c.

    Args:
        comtrade_df (pd.DataFrame): DataFrame of country-product exports.
        country_df (pd.DataFrame): DataFrame of country populations.
        years (List[int]): The years for which to calculate Rpop.

    Returns:
        pd.DataFrame: A DataFrame with a MultiIndex (year, reporter_iso,
                      product_hs_code) and a column for Rpop.
    """
    logging.info(f"Calculating Rpop for years: {years}...")

    # Filter data for the relevant years to reduce memory footprint.
    exports = comtrade_df[comtrade_df['year'].isin(years)].copy()
    population = country_df[country_df['year'].isin(years)].copy()

    # Calculate total world exports per product, per year (E_p).
    world_exports = exports.groupby(['year', 'product_hs_code'])['export_value_usd'].sum().rename('world_export')

    # Calculate total world population per year (pop_world).
    world_population = population.groupby('year')['population'].sum().rename('world_population')

    # Merge exports with population data.
    merged_df = pd.merge(
        exports, population, on=['year', 'reporter_iso'], how='left'
    )

    # Merge with world totals.
    merged_df = pd.merge(
        merged_df, world_exports, on=['year', 'product_hs_code'], how='left'
    )
    merged_df = pd.merge(
        merged_df, world_population, on='year', how='left'
    )

    # Handle cases where population or world data is missing.
    if merged_df[['population', 'world_export', 'world_population']].isnull().any().any():
        logging.warning(
            "Missing population or world export data for some records. "
            "Rpop for these records will be set to 0."
        )

    # Calculate per capita values.
    merged_df['per_capita_export_country'] = merged_df['export_value_usd'] / merged_df['population']
    merged_df['per_capita_export_world'] = merged_df['world_export'] / merged_df['world_population']

    # Calculate Rpop. Division by zero yields +/-inf and missing inputs yield NaN;
    # both are mapped to 0. Assigning the result back (rather than calling
    # `fillna(..., inplace=True)` on a column selection) avoids pandas'
    # chained-assignment pitfall under copy-on-write.
    rpop = merged_df['per_capita_export_country'] / merged_df['per_capita_export_world']
    merged_df['rpop'] = rpop.replace([np.inf, -np.inf], np.nan).fillna(0)

    # Return a clean DataFrame with the required index and column.
    return merged_df.set_index(['year', 'reporter_iso', 'product_hs_code'])[['rpop']]


def _construct_diversification_outcome(
    rpop_df: pd.DataFrame,
    params: Dict[str, Any]
) -> pd.DataFrame:
    """
    Constructs the binary diversification outcome variable.

    A diversification event occurs if a country-product pair was "absent" in
    the start year and "present" in the end year, based on Rpop thresholds.

    Args:
        rpop_df (pd.DataFrame): DataFrame of Rpop values.
        params (Dict[str, Any]): The 'econometric_analysis' parameters.

    Returns:
        pd.DataFrame: A DataFrame indexed by (reporter_iso, product_hs_code)
                      with columns for presence and the diversification outcome.
    """
    logging.info("Constructing diversification outcome variable...")

    # Unstack the Rpop data to have years as columns.
    rpop_wide = rpop_df.unstack(level='year')
    rpop_wide.columns = rpop_wide.columns.droplevel(0) # Drop 'rpop' level

    start_year = params['start_year_for_diversification']
    end_year = params['end_year_for_diversification']

    # Fill NaNs with 0, as missing Rpop implies absence.
    rpop_wide.fillna(0, inplace=True)

    # Create binary indicators based on the specified thresholds.
    # Note the asymmetry: the start-year flag is True when the pair is NOT
    # "absent" (Rpop >= absence threshold), while the end-year flag is True
    # when the pair is "present" (Rpop >= presence threshold).
    rpop_wide[f'presence_{start_year}'] = (
        rpop_wide[start_year] >= params['rpop_absence_threshold']
    )
    rpop_wide[f'presence_{end_year}'] = (
        rpop_wide[end_year] >= params['rpop_presence_threshold']
    )

    # Define the diversification outcome variable (R_p,c):
    # R_p,c = 1 if M_start = 0 and M_end = 1 (e.g. M_2016 = 0 and M_2021 = 1), else 0.
    rpop_wide['diversification_outcome'] = (
        (~rpop_wide[f'presence_{start_year}']) & (rpop_wide[f'presence_{end_year}'])
    ).astype(int)

    logging.info(
        f"Identified {rpop_wide['diversification_outcome'].sum():,} "
        "diversification events."
    )

    return rpop_wide[[f'presence_{start_year}', 'diversification_outcome']]

# =============================================================================
# Task 8.2: Network Density Metric Calculation
# =============================================================================

def _calculate_network_density(
    adj_matrix: csr_matrix,
    presence_matrix: pd.DataFrame,
    product_to_idx: Dict[int, int],
    density_type: str,
    top_k: int
) -> pd.DataFrame:
    """
    Calculates upstream or downstream network density for each country-product.

    This function uses efficient sparse matrix algebra to compute the density
    metric, which measures the proportion of a product's top-k suppliers
    (upstream) or buyers (downstream) that a country already exports.

    Args:
        adj_matrix (csr_matrix): The n x n adjacency matrix of the network.
        presence_matrix (pd.DataFrame): A (products x countries) binary matrix
                                        indicating export presence in the base year.
        product_to_idx (Dict[int, int]): Mapping from product HS code to matrix index.
        density_type (str): Either 'upstream' or 'downstream'.
        top_k (int): The number of top partners to consider for the density calculation.

    Returns:
        pd.DataFrame: A (products x countries) DataFrame of density scores.
    """
    logging.info(f"Calculating {density_type} density (top_k={top_k})...")

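    # Determine which matrix to use based on density type.
    # Downstream density: uses transpose to find outputs for a given input.
    # Upstream density: uses the original matrix to find inputs for a given output.
    # `.tocsr()` is essential: the transpose of a CSR matrix is CSC, and walking
    # a CSC `indptr` iterates columns, which would make both densities identical.
    base = adj_matrix.T if density_type == 'downstream' else adj_matrix
    matrix = base.tocsr()

    # --- Create the binary top-k weight matrix (W) ---
    # Collect the (row, col) coordinates of W first and build it once at the end;
    # item assignment into a CSR matrix is very slow (SparseEfficiencyWarning).
    n_products = matrix.shape[0]
    w_rows: List[int] = []
    w_cols: List[int] = []

    # Iterate through each product (row) to find its top-k partners.
    for i in range(n_products):
        # Get the data and indices for the current row.
        row_start = matrix.indptr[i]
        row_end = matrix.indptr[i + 1]
        if row_start == row_end:
            continue  # Skip if no partners

        row_data = matrix.data[row_start:row_end]
        row_indices = matrix.indices[row_start:row_end]

        # Find the indices of the top-k values.
        # `np.argpartition` is faster than a full sort for finding top-k.
        k = min(top_k, len(row_data))
        top_k_local_indices = np.argpartition(row_data, -k)[-k:]
        top_k_global_indices = row_indices[top_k_local_indices]

        # Record the corresponding entries of W (all set to 1).
        w_rows.extend([i] * k)
        w_cols.extend(top_k_global_indices.tolist())

    W = csr_matrix(
        (np.ones(len(w_rows), dtype=np.float32), (w_rows, w_cols)),
        shape=(n_products, n_products)
    )

    # --- Calculate Density via Matrix Multiplication ---
    # Align the presence matrix with the adjacency matrix's indexing: row i of
    # `aligned_presence` must be the product whose matrix index is i. Casting to
    # float turns the boolean/NaN-filled object columns into 0.0/1.0 values.
    ordered_products = sorted(product_to_idx, key=product_to_idx.get)
    aligned_presence = presence_matrix.reindex(
        index=ordered_products
    ).fillna(0).astype(float)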

    # The core calculation: D = W * M
    # This multiplies the top-k links by the country presence.
    density_matrix_raw = W.dot(aligned_presence.values)

    # Normalize by the number of partners for each product.
    # The normalizer is the sum of each row in W.
    normalizer = np.array(W.sum(axis=1)).flatten()
    # Avoid division by zero for products with no partners.
    normalizer[normalizer == 0] = 1

    # Apply normalization.
    density_matrix_normalized = density_matrix_raw / normalizer[:, np.newaxis]

    # Convert the result back to a DataFrame.
    density_df = pd.DataFrame(
        density_matrix_normalized,
        index=aligned_presence.index,
        columns=aligned_presence.columns
    )

    return density_df

# =============================================================================
# Task 8.3: Econometric Dataset Preparation
# =============================================================================

def _prepare_econometric_dataset(
    diversification_df: pd.DataFrame,
    downstream_density: pd.DataFrame,
    upstream_density: pd.DataFrame,
    comtrade_df: pd.DataFrame,
    params: Dict[str, Any],
    network_products: List[int]
) -> pd.DataFrame:
    """
    Assembles and filters the final dataset for econometric analysis.

    Args:
        diversification_df (pd.DataFrame): DataFrame with diversification outcomes.
        downstream_density (pd.DataFrame): Matrix of downstream density scores.
        upstream_density (pd.DataFrame): Matrix of upstream density scores.
        comtrade_df (pd.DataFrame): Raw Comtrade data for filtering.
        params (Dict[str, Any]): The 'econometric_analysis' parameters.
        network_products (List[int]): List of products in the network's LCC.

    Returns:
        pd.DataFrame: A long-format DataFrame ready for Probit regression.
    """
    logging.info("Assembling and filtering the final econometric dataset...")

    # --- Step 8.3.1: Filter products ---
    # Filter by global trade value in the start year.
    start_year = params['start_year_for_diversification']
    trade_filter = comtrade_df[comtrade_df['year'] == start_year]
    product_trade = trade_filter.groupby('product_hs_code')['export_value_usd'].sum()

    products_above_threshold = product_trade[
        product_trade >= params['min_global_trade_for_product_inclusion_usd']
    ].index

    # Final product set is the intersection of all criteria.
    final_product_set = set(network_products) & set(products_above_threshold)
    logging.info(f"Final analysis includes {len(final_product_set)} products.")

    # --- Step 8.3.2 & 8.3.3: Melt, Merge, and Filter Sample ---
    # Melt all DataFrames to long format for merging.
    presence_col = f"presence_{start_year}"
    base_df = diversification_df.reset_index()

    downstream_long = downstream_density.melt(
        var_name='reporter_iso', value_name='downstream_density', ignore_index=False
    ).reset_index()

    upstream_long = upstream_density.melt(
        var_name='reporter_iso', value_name='upstream_density', ignore_index=False
    ).reset_index()

    # Merge all components together.
    merged = pd.merge(
        base_df, downstream_long, on=['reporter_iso', 'product_hs_code']
    )
    merged = pd.merge(
        merged, upstream_long, on=['reporter_iso', 'product_hs_code']
    )

    # Apply product filter.
    final_df = merged[merged['product_hs_code'].isin(final_product_set)]

    # Apply the final sample filter: only include observations that were
    # NOT present in the start year.
    final_df = final_df[~final_df[presence_col]].copy()

    logging.info(f"Final regression sample contains {len(final_df):,} observations.")

    return final_df.drop(columns=[presence_col])

# =============================================================================
# Task 8: Orchestrator Function
# =============================================================================

def engineer_economic_features(
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    adj_matrix: csr_matrix,
    product_to_idx: Dict[int, int],
    replication_manifest: Dict[str, Any]
) -> pd.DataFrame:
    """
    Orchestrates the full feature engineering pipeline for econometric analysis.

    Args:
        comtrade_exports_frame (pd.DataFrame): Country-product export data.
        country_data_frame (pd.DataFrame): Country population data.
        adj_matrix (csr_matrix): The inferred network adjacency matrix.
        product_to_idx (Dict[int, int]): Mapping from product HS code to index.
        replication_manifest (Dict[str, Any]): The dictionary of study parameters.

    Returns:
        pd.DataFrame: The final, long-format dataset ready for regression.
    """
    logging.info("Starting Task 8: Economic Feature Engineering...")
    params = replication_manifest['parameters']['econometric_analysis']

    # Step 8.1: Construct country-product presence and diversification outcome.
    rpop_df = _calculate_rpop(
        comtrade_df=comtrade_exports_frame,
        country_df=country_data_frame,
        years=[params['start_year_for_diversification'], params['end_year_for_diversification']]
    )
    diversification_df = _construct_diversification_outcome(rpop_df, params)

    # Step 8.2: Calculate network density metrics.
    # Create the base-year presence matrix (products x countries).
    presence_col = f"presence_{params['start_year_for_diversification']}"
    presence_matrix_base = diversification_df[[presence_col]].unstack(level='reporter_iso')
    presence_matrix_base.columns = presence_matrix_base.columns.droplevel(0)
    top_k = params['density_metric_top_k_edges']

    downstream_density = _calculate_network_density(
        adj_matrix, presence_matrix_base, product_to_idx, 'downstream', top_k
    )
    upstream_density = _calculate_network_density(
        adj_matrix, presence_matrix_base, product_to_idx, 'upstream', top_k
    )

    # Step 8.3: Assemble the final econometric dataset.
    network_products = list(product_to_idx.keys())
    econometric_df = _prepare_econometric_dataset(
        diversification_df, downstream_density, upstream_density,
        comtrade_exports_frame, params, network_products
    )

    logging.info("Task 8 successfully completed. Economic features engineered.")
    return econometric_df
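
The two core computations of Task 8 can be sketched on small dense arrays. The helpers below (`rpop`, `top_k_weight_matrix`, `density`) are hypothetical stand-ins for `_calculate_rpop` and `_calculate_network_density`, assuming an export matrix with countries as rows and products as columns, and a presence matrix with products as rows and countries as columns. The `csr_matrix(adj)` conversion matters: passing `adj.T` yields a CSC matrix, and walking its `indptr` as if it were CSR would iterate over the wrong axis.

```python
import numpy as np
from scipy.sparse import csr_matrix


def rpop(exports: np.ndarray, population: np.ndarray) -> np.ndarray:
    """Rpop[c, p] = (E[c, p] / pop[c]) / (E_p / pop_world), with 0 where undefined."""
    world_exports = exports.sum(axis=0)  # E_p: total exports of each product
    world_pop = population.sum()         # pop_world
    with np.errstate(divide="ignore", invalid="ignore"):
        r = (exports / population[:, None]) / (world_exports / world_pop)
    return np.nan_to_num(r, nan=0.0, posinf=0.0, neginf=0.0)


def top_k_weight_matrix(adj, top_k: int) -> csr_matrix:
    """Binary W with W[i, j] = 1 if j is among row i's top-k weighted partners."""
    adj = csr_matrix(adj)  # ensure CSR so that indptr walks rows
    rows, cols = [], []
    for i in range(adj.shape[0]):
        start, end = adj.indptr[i], adj.indptr[i + 1]
        if start == end:
            continue  # no partners for this product
        k = min(top_k, end - start)
        top = np.argpartition(adj.data[start:end], -k)[-k:]
        rows.extend([i] * k)
        cols.extend(adj.indices[start:end][top].tolist())
    return csr_matrix((np.ones(len(rows)), (rows, cols)), shape=adj.shape)


def density(adj, presence: np.ndarray, top_k: int) -> np.ndarray:
    """D = (W @ M) / rowsum(W): share of each product's top-k partners a country exports."""
    w = top_k_weight_matrix(adj, top_k)
    raw = w @ presence
    norm = np.asarray(w.sum(axis=1)).ravel()
    norm[norm == 0] = 1.0  # products without partners get density 0
    return raw / norm[:, None]
```

For example, with exports of 50 for two countries of population 10 and 90, world per-capita exports are 100 / 100 = 1, so Rpop is 5 for the first country and 50/90 for the second.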


In [None]:
# Task 9: Econometric Analysis

# =============================================================================
# Task 9.1: Probit Model Specification and Estimation
# =============================================================================

def _estimate_probit_model(
    econometric_df: pd.DataFrame,
    density_variable: str,
    include_fixed_effects: bool = True
) -> ProbitResults:
    """
    Estimates a Probit model for the diversification outcome.

    This function specifies and fits a Probit model according to the paper's
    methodology. The dependent variable is 'diversification_outcome', and the
    primary independent variable is a specified network density metric.
    The model can optionally include country and product fixed effects.

    The estimated model is:
    R_p,c = Φ(α + β * density_p,c + γ_p + η_c)

    Args:
        econometric_df (pd.DataFrame): The final, long-format dataset.
        density_variable (str): The name of the density column to use as the
                                predictor ('downstream_density' or 'upstream_density').
        include_fixed_effects (bool): If True, includes country and product
                                      fixed effects in the model.

    Returns:
        ProbitResults: The fitted model results object from statsmodels.

    Raises:
        DataValidationError: If the specified density_variable is not in the DataFrame.
    """
    # Validate that the required density variable exists.
    if density_variable not in econometric_df.columns:
        raise DataValidationError(f"Density variable '{density_variable}' not found in the dataset.")

    # Construct the model formula using the R-style formula API.
    if include_fixed_effects:
        # Formula with country (reporter_iso) and product fixed effects.
        # C() treats the variables as categorical for dummy creation.
        formula = (
            f"diversification_outcome ~ {density_variable} + "
            "C(reporter_iso) + C(product_hs_code)"
        )
        logging.info(f"Estimating Probit model with fixed effects: {formula}")
    else:
        # Formula without fixed effects for baseline comparison.
        formula = f"diversification_outcome ~ {density_variable}"
        logging.info(f"Estimating Probit model without fixed effects: {formula}")

    # Instantiate the Probit model.
    model = smf.probit(formula=formula, data=econometric_df)

    # Fit the model using Maximum Likelihood Estimation.
    # Note: the paper does not specify clustered standard errors. For strict
    # replication we use the default (non-robust) MLE standard errors; the code
    # can be switched to errors clustered by country, a common robustness choice:
    # groups = econometric_df['reporter_iso']
    # results = model.fit(cov_type='cluster', cov_kwds={'groups': groups}, disp=0)
    try:
        results = model.fit(disp=0) # disp=0 suppresses convergence messages
    except Exception as e:
        logging.error(f"Probit model estimation failed for '{density_variable}': {e}")
        raise DataValidationError("Probit model failed to converge.") from e

    logging.info(f"Successfully estimated Probit model for '{density_variable}'.")
    return results

# =============================================================================
# Task 9.2: Model Performance Evaluation
# =============================================================================

def _evaluate_model_performance(
    results: ProbitResults,
    econometric_df: pd.DataFrame
) -> Dict[str, Any]:
    """
    Evaluates the predictive performance of a fitted Probit model.

    This function calculates key performance metrics, including the Area Under
    the ROC Curve (AUC), and generates the data for plotting the ROC curve.

    Args:
        results (ProbitResults): The fitted model results object.
        econometric_df (pd.DataFrame): The dataset used for fitting.

    Returns:
        Dict[str, Any]: A dictionary containing performance metrics, including
                        'auc', 'roc_curve_data', and other model diagnostics.
    """
    # Get the true outcomes (dependent variable).
    y_true = econometric_df[results.model.endog_names]

    # Get the predicted probabilities from the model.
    y_pred_prob = results.predict(econometric_df)

    # --- Step 9.2.1: Calculate AUC ---
    auc = roc_auc_score(y_true, y_pred_prob)

    # --- Step 9.2.2: Generate ROC Curve data ---
    fpr, tpr, thresholds = roc_curve(y_true, y_pred_prob)

    # --- Step 9.2.3: Compute additional model diagnostics ---
    diagnostics = {
        'auc': auc,
        'roc_curve_data': (fpr, tpr),
        'log_likelihood': results.llf,
        'pseudo_r_squared': results.prsquared,
        'aic': results.aic,
        'bic': results.bic
    }

    logging.info(f"Model Performance: AUC = {auc:.4f}, Pseudo R^2 = {results.prsquared:.4f}")
    return diagnostics

# =============================================================================
# Task 9.3: Results Interpretation and Validation
# =============================================================================

def _interpret_and_summarize_results(
    results: ProbitResults,
    density_variable: str
) -> Dict[str, Any]:
    """
    Extracts, interprets, and summarizes key econometric results.

    This function extracts the coefficient, standard error, p-value, and
    calculates the Average Marginal Effect (AME) for the primary variable of
    interest (the density metric).

    Args:
        results (ProbitResults): The fitted model results object.
        density_variable (str): The name of the primary predictor variable.

    Returns:
        Dict[str, Any]: A dictionary containing the key statistical results
                        and their interpretation.
    """
    # --- Step 9.3.1: Extract coefficient estimates ---
    summary = {
        'coefficient': results.params.get(density_variable),
        'std_error': results.bse.get(density_variable),
        'p_value': results.pvalues.get(density_variable),
        'conf_int': results.conf_int().loc[density_variable].tolist()
    }

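    # --- Step 9.3.2: Calculate Average Marginal Effects (AME) ---
    # This is the most meaningful way to interpret the magnitude of the effect.
    # Look the AME up by variable name: patsy places the categorical fixed-effect
    # dummies before numeric terms, so a positional lookup into the summary table
    # would return a dummy's effect (as a string) rather than the density's.
    marginal_effects = results.get_margeff(at='overall', method='dydx')
    summary['average_marginal_effect'] = float(
        marginal_effects.summary_frame().loc[density_variable, 'dy/dx']
    )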

    logging.info(f"--- Results for '{density_variable}' ---")
    logging.info(f"  Coefficient: {summary['coefficient']:.4f}")
    logging.info(f"  P-value: {summary['p_value']:.4f}")
    logging.info(f"  Average Marginal Effect: {summary['average_marginal_effect']:.4f}")

    return summary
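

# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# the probit Average Marginal Effect computed by hand. For P(y=1|x) = Phi(x'b),
# the derivative with respect to a continuous regressor x_k is phi(x'b) * b_k,
# where phi is the standard normal density. Averaging that derivative over the
# sample gives the AME, which is what statsmodels' get_margeff() reports with
# its defaults (at='overall', method='dydx').
# -----------------------------------------------------------------------------
import math
from typing import Sequence


def _sketch_probit_ame(
    coefficients: Sequence[float],
    rows: Sequence[Sequence[float]],
    k: int,
) -> float:
    """Hypothetical helper: AME of regressor k; each row includes any constant column."""
    total = 0.0
    for row in rows:
        index = sum(b * x for b, x in zip(coefficients, row))  # linear index x'b
        density = math.exp(-0.5 * index * index) / math.sqrt(2.0 * math.pi)
        total += density * coefficients[k]  # per-observation marginal effect
    return total / len(rows)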

# =============================================================================
# Task 9: Orchestrator Function
# =============================================================================

def run_econometric_analysis(
    econometric_df: pd.DataFrame,
    replication_manifest: Dict[str, Any]
) -> Dict[str, Dict[str, Any]]:
    """
    Orchestrates the end-to-end econometric analysis pipeline.

    This function runs the Probit regressions for both upstream and downstream
    density metrics, evaluates their performance, and summarizes the results
    in a structured format, replicating the core analysis of the paper.

    Args:
        econometric_df (pd.DataFrame): The final, analysis-ready dataset from Task 8.
        replication_manifest (Dict[str, Any]): The dictionary of study parameters
            (currently unused by this function).

    Returns:
        Dict[str, Dict[str, Any]]: A nested dictionary containing the full
                                   results for both the 'downstream' and
                                   'upstream' models.
    """
    logging.info("Starting Task 9: Econometric Analysis...")

    final_results = {}
    density_variables = ['downstream_density', 'upstream_density']

    for density_var in density_variables:
        model_name = density_var.replace('_density', '')
        logging.info(f"\n===== Running Analysis for {model_name.upper()} Model =====")

        # --- Step 9.1: Estimate the Probit model ---
        # The paper's main results include fixed effects.
        model_results = _estimate_probit_model(
            econometric_df=econometric_df,
            density_variable=density_var,
            include_fixed_effects=True
        )

        # --- Step 9.2: Evaluate model performance ---
        performance_metrics = _evaluate_model_performance(
            results=model_results,
            econometric_df=econometric_df
        )

        # --- Step 9.3: Interpret and summarize results ---
        interpretation = _interpret_and_summarize_results(
            results=model_results,
            density_variable=density_var
        )

        # Consolidate all results for this model.
        final_results[model_name] = {
            'model_summary': interpretation,
            'performance': performance_metrics,
            'statsmodels_results': model_results # For full inspection
        }

    logging.info("\nTask 9 successfully completed. Econometric analysis finished.")
    return final_results


# Task 10: Orchestrator Function Creation

# =============================================================================
# Task 10: Orchestrator Function Creation
# =============================================================================


def run_end_to_end_pipeline(
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    replication_manifest: Dict[str, Any],
    external_networks: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Executes the full, end-to-end research pipeline from data validation to
    final econometric analysis.

    This master orchestrator function serves as the single entry point for
    replicating the study "Deciphering the global production network from
    cross-border firm transactions". It sequentially executes each of the nine
    major tasks, handling the flow of data between them and aggregating the
    final results into a comprehensive output dictionary.

    Args:
        transactions_log_frame (pd.DataFrame): Raw log of firm transactions.
        firm_metadata_frame (pd.DataFrame): Raw firm ownership and location data.
        comtrade_exports_frame (pd.DataFrame): Raw country-product export data.
        country_data_frame (pd.DataFrame): Raw country population data.
        supply_chains_definitions_dict (Dict[str, Any]): Definitions of
            manual supply chains for validation.
        replication_manifest (Dict[str, Any]): Dictionary of all study parameters.
        external_networks (Dict[str, Any], optional): A dictionary containing
            loaded external networks for comparison. Defaults to None.

    Returns:
        Dict[str, Any]: A nested dictionary containing all key outputs of the
                        pipeline, including the final network objects, analysis
                        results, validation metrics, and econometric findings.

    Raises:
        DataValidationError: If any critical validation or processing step fails.
        Exception: For any other unexpected runtime errors.
    """
    # --- Initialization ---
    # Set up a comprehensive logger for the entire pipeline run.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [Task Orchestrator] - %(message)s'
    )

    # Initialize a dictionary to store all final results.
    pipeline_results = {}

    logging.info("===== STARTING END-TO-END RESEARCH PIPELINE =====")

    try:
        # --- Task 1: Data Validation and Quality Assurance ---
        # This step validates all inputs. It returns None on success or raises
        # a DataValidationError on failure, halting the pipeline early.
        validate_and_assess_inputs(
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            replication_manifest=replication_manifest
        )

        # --- Task 2: Data Preprocessing and Cleansing ---
        # This step cleans the raw transaction data.
        cleansed_df = preprocess_and_cleanse_data(
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            replication_manifest=replication_manifest
        )

        # --- Task 3: Firm Classification and Set Construction ---
        # This step creates the fundamental sets for network inference.
        producer_sets, purchaser_sets, intersection_metrics = classify_firms_and_construct_sets(
            cleansed_df=cleansed_df
        )

        # --- Task 4: Network Inference Implementation ---
        # This is the core step where the network is built.
        adj_matrix, graph, product_to_idx = infer_and_construct_network(
            cleansed_df=cleansed_df,
            producer_sets=producer_sets,
            purchaser_sets=purchaser_sets,
            intersection_metrics=intersection_metrics,
            replication_manifest=replication_manifest
        )
        # Store the primary network artifacts.
        pipeline_results['network_objects'] = {
            'adjacency_matrix': adj_matrix,
            'graph': graph,
            'product_to_idx_map': product_to_idx
        }

        # --- Task 5: Community Detection and Structural Analysis ---
        # This step analyzes the network's meso-scale structure.
        community_labels = analyze_network_structure(
            graph=graph,
            adj_matrix=adj_matrix,
            product_to_idx=product_to_idx,
            replication_manifest=replication_manifest
        )
        pipeline_results['network_analysis'] = {'community_labels': community_labels}

        # --- Task 6: Centrality Calculations ---
        # This step calculates node-level importance metrics.
        centrality_df = calculate_centralities(
            graph=graph,
            replication_manifest=replication_manifest
        )
        pipeline_results['network_analysis']['centralities'] = centrality_df

        # --- Task 7: Validation Procedures ---
        # This step validates the network against statistical null models.
        validation_results = run_validation_procedures(
            adj_matrix=adj_matrix,
            product_to_idx=product_to_idx,
            supply_chains_definitions=supply_chains_definitions_dict,
            replication_manifest=replication_manifest,
            external_networks=external_networks
        )
        pipeline_results['validation_results'] = validation_results

        # --- Task 8: Economic Feature Engineering ---
        # This step prepares the data for the final econometric analysis.
        econometric_df = engineer_economic_features(
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            adj_matrix=adj_matrix,
            product_to_idx=product_to_idx,
            replication_manifest=replication_manifest
        )
        pipeline_results['econometric_dataset'] = econometric_df

        # --- Task 9: Econometric Analysis ---
        # This step runs the final Probit models.
        econometric_results = run_econometric_analysis(
            econometric_df=econometric_df,
            replication_manifest=replication_manifest
        )
        pipeline_results['econometric_results'] = econometric_results

        logging.info("===== END-TO-END RESEARCH PIPELINE COMPLETED SUCCESSFULLY =====")

    except DataValidationError as e:
        # Catch specific validation errors and log them before re-raising.
        logging.error(f"A critical data validation error occurred: {e}")
        logging.error("PIPELINE HALTED.")
        raise
    except Exception as e:
        # Catch any other unexpected errors.
        logging.error(f"An unexpected error occurred during pipeline execution: {e}", exc_info=True)
        logging.error("PIPELINE HALTED.")
        raise

    # Return the final, aggregated results.
    return pipeline_results
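

# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the original pipeline): the
# orchestrator's control flow in miniature. Each stage reads earlier outputs
# from a shared results dict and adds its own; a failure is logged and then
# re-raised with a bare `raise`, so the caller still sees the original
# exception and traceback. SketchValidationError and _sketch_run_stages are
# stand-ins for DataValidationError and run_end_to_end_pipeline.
# -----------------------------------------------------------------------------
import logging
from typing import Any, Callable, Dict, List, Tuple


class SketchValidationError(Exception):
    """Hypothetical stand-in for the pipeline's DataValidationError."""


def _sketch_run_stages(
    stages: List[Tuple[str, Callable[[Dict[str, Any]], Any]]]
) -> Dict[str, Any]:
    results: Dict[str, Any] = {}
    for name, stage in stages:
        try:
            results[name] = stage(results)  # each stage sees all prior outputs
        except SketchValidationError as exc:
            logging.error("Validation failed in stage %s: %s", name, exc)
            raise  # specific errors are handled first, then re-raised
        except Exception:
            logging.error("Unexpected failure in stage %s", name, exc_info=True)
            raise
    return results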


# Task 11: Robustness Analysis

# =============================================================================
# Task 11.1: Parameter Sensitivity Analysis
# =============================================================================

def _update_nested_dict(
    d: Dict[str, Any],
    path: Tuple[str, ...],
    value: Any
) -> None:
    """
    A robust helper utility to set a value in a nested dictionary using a path tuple.

    This function traverses a nested dictionary according to the keys provided
    in the path tuple and sets the value at the final destination. It includes
    error handling to ensure the path is valid.

    Args:
        d (Dict[str, Any]): The nested dictionary to update.
        path (Tuple[str, ...]): A tuple of keys representing the path to the
                                value to be updated.
        value (Any): The new value to set at the specified path.

    Raises:
        KeyError: If any key in the path (except the last one) does not exist
                  or does not lead to a sub-dictionary.
    """
    # Start the traversal from the top level of the dictionary.
    current_level = d
    # Iterate through the path until the second-to-last key.
    for key in path[:-1]:
        # Check if the key exists and leads to another dictionary.
        if key in current_level and isinstance(current_level[key], dict):
            # Move one level deeper into the dictionary.
            current_level = current_level[key]
        else:
            # If the path is invalid, raise a specific KeyError.
            raise KeyError(f"Invalid path in manifest: key '{key}' not found or not a dict.")

    # Get the final key in the path.
    final_key = path[-1]
    # Check if the final key exists at the target level.
    if final_key not in current_level:
        raise KeyError(f"Invalid path in manifest: final key '{final_key}' not found.")

    # Set the value of the final key.
    current_level[final_key] = value


def _run_single_pipeline_iteration(
    params_combination: Dict[Tuple[str, ...], Any],
    base_manifest: Dict[str, Any],
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    external_networks: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Worker function that executes one full pipeline run with a specific,
    well-defined parameter set.

    This function is designed for general-purpose sensitivity analysis.
    It accepts a dictionary of parameters where keys are tuples representing the
    path to the parameter in the nested manifest. It updates a deep copy of the
    base manifest, executes the end-to-end pipeline, and extracts key results.

    Args:
        params_combination (Dict[Tuple[str, ...], Any]): A dictionary where keys
            are path tuples (e.g., ('network_inference', 'primary_firmcount_threshold'))
            and values are the specific settings for this iteration.
        base_manifest (Dict[str, Any]): The original, unmodified replication manifest.
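        transactions_log_frame, firm_metadata_frame, comtrade_exports_frame,
        country_data_frame, supply_chains_definitions_dict, external_networks:
            The raw data inputs, forwarded unchanged to `run_end_to_end_pipeline`.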

    Returns:
        Dict[str, Any]: A dictionary containing the parameters for this run
                        (with flattened, dot-notation keys) and the key extracted
                        results. Returns a dictionary with an 'error' key if the
                        run fails.
    """
    try:
        # Create a deep copy of the manifest to ensure run-to-run isolation.
        run_manifest = deepcopy(base_manifest)

        # Iterate through the parameter combination and update the manifest copy.
        for path, value in params_combination.items():
            # Use the robust helper function to update the nested dictionary.
            _update_nested_dict(run_manifest, path, value)

        # Execute the full end-to-end pipeline with the modified manifest.
        pipeline_results = run_end_to_end_pipeline(
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            supply_chains_definitions_dict=supply_chains_definitions_dict,
            replication_manifest=run_manifest,
            external_networks=external_networks
        )

        # Extract the key results for this run.
        downstream_auc = pipeline_results['econometric_results']['downstream']['performance']['auc']
        upstream_auc = pipeline_results['econometric_results']['upstream']['performance']['auc']
        edge_count = pipeline_results['network_objects']['graph'].number_of_edges()

        # Flatten the parameter path tuples into dot-notation strings for the output.
        flattened_params = {'.'.join(path): value for path, value in params_combination.items()}

        # Combine the flattened parameters and results into a single record.
        result_record = {
            **flattened_params,
            'downstream_auc': downstream_auc,
            'upstream_auc': upstream_auc,
            'edge_count': edge_count,
            'error': None
        }
        return result_record

    except Exception as e:
        # If any part of the pipeline fails, log the error and return a record
        # indicating failure.
        flattened_params = {'.'.join(path): value for path, value in params_combination.items()}
        logging.error(f"Pipeline run failed for parameters {flattened_params}: {e}")
        return {**flattened_params, 'error': str(e)}
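

# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the original pipeline): why
# each run updates deepcopy(base_manifest). copy.copy duplicates only the
# top-level dict, so nested parameter dicts stay shared and an in-place update
# leaks back into the base manifest (and into every later iteration of a
# sequential sweep). The helper and the manifest keys below are made up.
# -----------------------------------------------------------------------------
from copy import copy, deepcopy
from typing import Any, Dict, Tuple


def _sketch_set_path(d: Dict[str, Any], path: Tuple[str, ...], value: Any) -> None:
    """Hypothetical minimal path setter (no validation, unlike _update_nested_dict)."""
    target = d
    for key in path[:-1]:
        target = target[key]
    target[path[-1]] = value


def _sketch_isolation_demo() -> Tuple[int, int]:
    """Return the base value after a shallow-copy update, then after a deep-copy update."""
    base = {'parameters': {'threshold': 1}}
    _sketch_set_path(copy(base), ('parameters', 'threshold'), 99)
    after_shallow = base['parameters']['threshold']  # leaked: now 99
    base['parameters']['threshold'] = 1  # reset before the second experiment
    _sketch_set_path(deepcopy(base), ('parameters', 'threshold'), 99)
    after_deep = base['parameters']['threshold']  # isolated: still 1
    return after_shallow, after_deep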


def run_parameter_sensitivity_analysis(
    parameter_grid: Dict[Tuple[str, ...], List[Any]],
    base_manifest: Dict[str, Any],
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    external_networks: Dict[str, Any] = None,
    n_jobs: int = -1
) -> pd.DataFrame:
    """
    Performs a sensitivity analysis by running the full pipeline across a grid
    of specified parameter values.

    This function systematically varies key methodological parameters (e.g.,
    network sparsification thresholds), executes the entire research pipeline
    for each combination, and collects key performance indicators (like model
    AUC and network edge count) to assess the robustness of the findings.

    Args:
        parameter_grid (Dict[Tuple[str, ...], List[Any]]): A dictionary
            defining the parameter sweep. Keys are path tuples into the
            manifest (the format `_run_single_pipeline_iteration` expects),
            and values are lists of values to test.
        base_manifest (Dict[str, Any]): The base replication manifest with
                                        default parameter settings.
        transactions_log_frame, firm_metadata_frame, comtrade_exports_frame,
        country_data_frame, supply_chains_definitions_dict, external_networks:
            The raw data inputs required by the main pipeline.
        n_jobs (int): The number of parallel jobs passed to joblib.
                      -1 uses all available cores; 1 disables parallelism.

    Returns:
        pd.DataFrame: A tidy DataFrame where each row corresponds to one full
                      pipeline run, with columns for the parameters used, the
                      resulting performance metrics, and an 'error' column
                      (None for successful runs).
    """
    logging.info("===== STARTING TASK 11.1: PARAMETER SENSITIVITY ANALYSIS =====")

    # --- Step 1: Generate all parameter combinations from the grid ---
    param_names = parameter_grid.keys()
    param_values = parameter_grid.values()

    # itertools.product creates the Cartesian product of the parameter values.
    all_combinations = [
        dict(zip(param_names, combo)) for combo in itertools.product(*param_values)
    ]
    logging.info(f"Generated {len(all_combinations)} parameter combinations to test.")

    # --- Step 2: Execute the pipeline for each combination in parallel ---
    # Use joblib.Parallel for parallel processing. `delayed` wraps the worker
    # so that each call is recorded as (function, args, kwargs) and executed
    # later by the parallel workers instead of immediately.
    if n_jobs == 1:
        logging.info("Running iterations sequentially...")
        results_list = [
            _run_single_pipeline_iteration(
                params, base_manifest, transactions_log_frame, firm_metadata_frame,
                comtrade_exports_frame, country_data_frame,
                supply_chains_definitions_dict, external_networks
            ) for params in all_combinations
        ]
    else:
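        # joblib uses n_cpus + 1 + n_jobs workers when n_jobs is negative.
        effective_jobs = n_jobs if n_jobs > 0 else cpu_count() + 1 + n_jobs
        logging.info(f"Running iterations in parallel with {effective_jobs} workers...")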
        results_list = Parallel(n_jobs=n_jobs)(
            delayed(_run_single_pipeline_iteration)(
                params, base_manifest, transactions_log_frame, firm_metadata_frame,
                comtrade_exports_frame, country_data_frame,
                supply_chains_definitions_dict, external_networks
            ) for params in all_combinations
        )

    # --- Step 3: Aggregate results into a final DataFrame ---
    # The worker never returns None: a failed run comes back as a record with
    # a non-null 'error' field. Keep every record so failures stay visible.
    results_df = pd.DataFrame(results_list)

    if results_df.empty or results_df['error'].notna().all():
        logging.error("All pipeline runs failed during sensitivity analysis.")
        return results_df

    logging.info("===== PARAMETER SENSITIVITY ANALYSIS COMPLETED =====")

    return results_df
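

# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the original pipeline): how a
# grid expands and why its keys must be path tuples. The worker passes each key
# to _update_nested_dict and names output columns with '.'.join(path); a plain
# string key would be traversed character by character ('.'.join('ab') gives
# 'a.b'). The parameter paths used in the example are placeholders.
# -----------------------------------------------------------------------------
import itertools
from typing import Any, Dict, List, Tuple


def _sketch_expand_grid(
    grid: Dict[Tuple[str, ...], List[Any]]
) -> List[Dict[Tuple[str, ...], Any]]:
    """Cartesian product of the grid values, mirroring Step 1 above."""
    names = list(grid.keys())
    return [dict(zip(names, combo)) for combo in itertools.product(*grid.values())]

# Example: a grid with 3 x 2 values yields 6 combinations, and the path
# ('parameters', 'some_threshold') becomes the column 'parameters.some_threshold'.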


# =============================================================================
# Task 11.2: Temporal Window Robustness Testing
# =============================================================================

def _extract_key_econometric_results(
    pipeline_results: Dict[str, Any]
) -> Dict[str, float]:
    """
    A helper function to extract key econometric metrics from a pipeline result.

    Args:
        pipeline_results (Dict[str, Any]): The output dictionary from a full
                                           pipeline run.

    Returns:
        Dict[str, float]: A flattened dictionary with key performance and
                          model summary statistics.
    """
    # Initialize an empty dictionary to store the extracted results.
    extracted = {}

    # Check if the expected results are present.
    if 'econometric_results' in pipeline_results:
        # Iterate through both the 'downstream' and 'upstream' models.
        for model_name, results in pipeline_results['econometric_results'].items():
            # Extract performance metrics, providing None as a default.
            performance = results.get('performance', {})
            extracted[f'{model_name}_auc'] = performance.get('auc')

            # Extract model summary statistics, providing None as a default.
            summary = results.get('model_summary', {})
            extracted[f'{model_name}_coeff'] = summary.get('coefficient')
            extracted[f'{model_name}_p_value'] = summary.get('p_value')
            extracted[f'{model_name}_ame'] = summary.get('average_marginal_effect')

        # Extract the final sample size for this econometric run.
        if 'econometric_dataset' in pipeline_results:
            extracted['sample_size'] = len(pipeline_results['econometric_dataset'])

    return extracted


def run_temporal_robustness_analysis(
    base_manifest: Dict[str, Any],
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    external_networks: Dict[str, Any] = None
) -> pd.DataFrame:
    """
    Performs a temporal robustness analysis by comparing the baseline results
    to results from an alternative, longer time window.

    This function executes the full end-to-end pipeline twice:
    1.  With the baseline econometric time window (e.g., 2016-2021).
    2.  With an alternative, longer time window (e.g., 2011-2021).

    It then extracts key econometric findings (AUC, coefficients, p-values)
    from both runs and presents them in a comparative DataFrame to assess the
    robustness of the conclusions to the choice of time period.

    Args:
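        base_manifest (Dict[str, Any]): The base replication manifest with
                                        default parameter settings.
        transactions_log_frame, firm_metadata_frame, comtrade_exports_frame,
        country_data_frame, supply_chains_definitions_dict, external_networks:
            The raw data inputs required by the main pipeline.

    Returns:
        pd.DataFrame: A DataFrame comparing the key econometric results
                      across the two temporal windows.

    Raises:
        Exception: Any error from the baseline run is re-raised. A failure in
                   the alternative-window run is only logged, and that
                   scenario's row is left with missing values.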
    """
    logging.info("===== STARTING TASK 11.2: TEMPORAL ROBUSTNESS ANALYSIS =====")

    all_results = {}

    # --- Run 1: Baseline Analysis (5-Year Window) ---
    logging.info("\n--- Running Baseline Analysis (5-Year Window: 2016-2021) ---")
    try:
        # Execute the pipeline with the original, unmodified manifest.
        baseline_results = run_end_to_end_pipeline(
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            supply_chains_definitions_dict=supply_chains_definitions_dict,
            replication_manifest=base_manifest,
            external_networks=external_networks
        )
        # Extract and store the key results.
        all_results['Baseline_5_Year_Window'] = _extract_key_econometric_results(baseline_results)
    except Exception as e:
        logging.error(f"Baseline pipeline run failed: {e}")
        # If the baseline fails, we cannot proceed.
        raise

    # --- Run 2: Alternative Analysis (10-Year Window) ---
    logging.info("\n--- Running Robustness Analysis (10-Year Window: 2011-2021) ---")
    try:
        # Create a deep copy of the manifest to modify for the alternative run.
        robustness_manifest = deepcopy(base_manifest)

        # Modify the econometric analysis years.
        robustness_manifest['parameters']['econometric_analysis']['start_year_for_diversification'] = 2011
        robustness_manifest['parameters']['econometric_analysis']['end_year_for_diversification'] = 2021

        # Execute the pipeline with the modified manifest.
        robustness_results = run_end_to_end_pipeline(
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            supply_chains_definitions_dict=supply_chains_definitions_dict,
            replication_manifest=robustness_manifest,
            external_networks=external_networks
        )
        # Extract and store the key results.
        all_results['Robustness_10_Year_Window'] = _extract_key_econometric_results(robustness_results)
    except Exception as e:
        # If the robustness run fails, log it but don't halt. We can still
        # report the baseline results.
        logging.error(f"Temporal robustness pipeline run failed: {e}")
        all_results['Robustness_10_Year_Window'] = {}

    # --- Step 3: Aggregate results into a final comparative DataFrame ---
    # Convert the nested dictionary of results into a DataFrame for easy comparison.
    results_df = pd.DataFrame(all_results).T # Transpose to have scenarios as rows.

    logging.info("\n===== TEMPORAL ROBUSTNESS ANALYSIS COMPLETED =====")

    # Display the comparative results.
    print("\n--- Comparative Temporal Robustness Results ---")
    print(results_df)
    print("\n")

    return results_df
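

# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# summarising the temporal comparison. Given the two per-scenario dictionaries
# produced by _extract_key_econometric_results, it reports, for each metric
# available in both, the alternative-window value minus the baseline value.
# Metrics missing or None in either scenario (for example after a failed
# robustness run) are skipped rather than guessed.
# -----------------------------------------------------------------------------
from typing import Dict, Optional


def _sketch_metric_shifts(
    baseline: Dict[str, Optional[float]],
    alternative: Dict[str, Optional[float]],
) -> Dict[str, float]:
    shifts: Dict[str, float] = {}
    for metric, base_value in baseline.items():
        alt_value = alternative.get(metric)
        if base_value is None or alt_value is None:
            continue  # cannot compare a metric that one run did not produce
        shifts[metric] = alt_value - base_value
    return shifts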

# =============================================================================
# Task 11.3: Alternative Network Construction Robustness
# =============================================================================

def _identify_significant_entities_refactored(
    cleansed_df: pd.DataFrame,
    entity_type: str,
    method: Any = 'mean'
) -> Dict[int, Set[str]]:
    """
    Identifies significant producers or purchasers for each product using a
    configurable statistical method. (REFACTORED for Task 11.3)

    This function implements the core classification methodology from the paper.
    An entity (a firm's ultimate owner in a specific country) is classified as
    a "significant" producer or purchaser of a product if its total sales or
    purchase value for that product exceeds a statistical threshold. This
    refactored version allows the threshold to be defined as the mean, median,
    or a specific quantile of the total values across all entities active in
    that product market.

    Args:
        cleansed_df (pd.DataFrame): The preprocessed and cleansed transaction log,
                                    containing resolved owner-country entities.
        entity_type (str): The type of entity to identify. Must be either
                           'producer' or 'purchaser'.
        method (Any): The statistical method for the threshold. Can be the
                      string 'mean', the string 'median', or a float between
                      0 and 1 for a specific quantile (e.g., 0.75 for the
                      75th percentile). Defaults to 'mean'.

    Returns:
        Dict[int, Set[str]]: A dictionary mapping each product HS code (int) to
                             a set of significant entity identifiers (str).

    Raises:
        ValueError: If `entity_type` or `method` are invalid.
    """
    # --- Step 1: Input Validation and Column Selection ---
    # Validate the entity_type parameter to ensure correct column selection.
    if entity_type == 'producer':
        # For producers, the relevant entity is the seller.
        entity_col = 'seller_owner_country_entity'
    elif entity_type == 'purchaser':
        # For purchasers, the relevant entity is the buyer.
        entity_col = 'buyer_owner_country_entity'
    else:
        # Raise an error for invalid entity types.
        raise ValueError("`entity_type` must be either 'producer' or 'purchaser'.")

    # Log the start of the process with the specified method.
    logging.info(f"Identifying significant {entity_type}s using '{method}' threshold...")

    # --- Step 2: Calculate total transaction value per (entity, product) pair ---
    # Group by the product and the relevant entity column, then sum the values.
    # This creates a summary of total activity for each entity in each product.
    entity_product_values = cleansed_df.groupby(
        ['product_hs_code', entity_col]
    )['transaction_value_usd'].sum().reset_index()

    # Rename the aggregated column for clarity.
    entity_product_values.rename(
        columns={'transaction_value_usd': 'total_value'}, inplace=True
    )

    # --- Step 3: Determine and apply the statistical threshold function ---
    # This block allows for flexible thresholding methods based on the `method` param.
    if method == 'mean':
        # Passing the string 'mean' lets pandas use its built-in group aggregation.
        transform_func = 'mean'
    elif method == 'median':
        # Passing the string 'median' lets pandas use its built-in group aggregation.
        transform_func = 'median'
    elif isinstance(method, float) and 0 < method < 1:
        # For a float, use a lambda function to calculate the specified quantile.
        transform_func = lambda x: x.quantile(method)
    else:
        # If the method is not recognized, raise a descriptive ValueError.
        raise ValueError(f"Invalid classification method: '{method}'. Must be "
                         "'mean', 'median', or a float between 0 and 1.")

    # Calculate the threshold for each product using the selected function.
    # `.transform()` is highly efficient as it computes the group-wise statistic
    # and broadcasts the result back to the original shape, avoiding a merge.
    entity_product_values['threshold'] = entity_product_values.groupby(
        'product_hs_code'
    )['total_value'].transform(transform_func)

    # --- Step 4: Filter for entities that exceed the calculated threshold ---
    # Create a boolean mask to identify the significant entities for each product.
    significant_mask = entity_product_values['total_value'] > entity_product_values['threshold']

    # Apply the mask to get the final DataFrame of significant entities.
    significant_entities_df = entity_product_values[significant_mask]

    # --- Step 5: Aggregate the significant entities into sets for each product ---
    # Group the filtered DataFrame by product code.
    # Apply the `set` constructor to the entity column for each group. This
    # efficiently collects all significant entities for a given product.
    entity_sets_series = significant_entities_df.groupby(
        'product_hs_code'
    )[entity_col].apply(set)

    # Convert the resulting pandas Series to a dictionary for fast lookups later.
    entity_sets_dict = entity_sets_series.to_dict()

    # Log a summary of the operation's result.
    logging.info(
        f"Identified significant {entity_type}s for "
        f"{len(entity_sets_dict):,} products using method='{method}'."
    )

    # Return the final dictionary of sets.
    return entity_sets_dict
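

# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# the thresholding rule above on toy data, in plain Python. Values are summed
# per (product, entity); each product's threshold is the mean, the median, or
# a linearly interpolated quantile of those sums (linear interpolation is the
# default of pandas Series.quantile); only entities strictly above the
# threshold are kept, so a product where no entity exceeds it is absent from
# the result, as in the pandas version. The tuple input format is assumed.
# -----------------------------------------------------------------------------
import math
import statistics
from collections import defaultdict
from typing import Any, Dict, Iterable, Set, Tuple


def _sketch_significant_entities(
    rows: Iterable[Tuple[int, str, float]],
    method: Any = 'mean',
) -> Dict[int, Set[str]]:
    """Hypothetical helper: rows are (product_hs_code, entity, transaction_value)."""
    if not (method in ('mean', 'median')
            or (isinstance(method, float) and 0 < method < 1)):
        raise ValueError(f"Invalid classification method: {method!r}")

    totals: Dict[int, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    for product, entity, value in rows:
        totals[product][entity] += value  # total value per (product, entity)

    result: Dict[int, Set[str]] = {}
    for product, by_entity in totals.items():
        values = sorted(by_entity.values())
        if method == 'mean':
            threshold = statistics.mean(values)
        elif method == 'median':
            threshold = statistics.median(values)
        else:
            # Linear interpolation between the two nearest order statistics.
            pos = method * (len(values) - 1)
            lo = math.floor(pos)
            hi = min(lo + 1, len(values) - 1)
            threshold = values[lo] + (values[hi] - values[lo]) * (pos - lo)
        significant = {e for e, v in by_entity.items() if v > threshold}
        if significant:
            result[product] = significant
    return result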

def classify_firms_and_construct_sets_variant(
    cleansed_df: pd.DataFrame,
    classification_method: Any
) -> Tuple[Dict[int, Set[str]], Dict[int, Set[str]], pd.DataFrame]:
    """
    Orchestrates firm classification using a specified statistical method.
    (VARIANT for Task 11.3)

    This function serves as a variant of the main Task 3 orchestrator,
    allowing the method for identifying significant entities (e.g., 'mean',
    'median', quantile) to be passed as a parameter.

    Args:
        cleansed_df (pd.DataFrame): The fully preprocessed and cleansed
                                    transaction log from Task 2.
        classification_method (Any): The statistical method for the threshold.
            Can be 'mean', 'median', or a float for a quantile.

    Returns:
        Tuple[Dict[int, Set[str]], Dict[int, Set[str]], pd.DataFrame]:
        A tuple containing producer sets, purchaser sets, and intersection metrics.
    """
    # Log the start of the task with the specified method.
    logging.info(
        f"Starting Task 3 (Variant): Firm Classification with method='{classification_method}'..."
    )

    # Step 3.1 (Variant): Identify significant producers with the specified method.
    producer_sets = _identify_significant_entities_refactored(
        cleansed_df=cleansed_df,
        entity_type='producer',
        method=classification_method
    )

    # Step 3.2 (Variant): Identify significant purchasers with the specified method.
    purchaser_sets = _identify_significant_entities_refactored(
        cleansed_df=cleansed_df,
        entity_type='purchaser',
        method=classification_method
    )

    # Step 3.3 (Standard): Compute intersection metrics based on the new sets.
    intersection_metrics = _compute_intersection_metrics(
        cleansed_df=cleansed_df,
        producer_sets=producer_sets,
        purchaser_sets=purchaser_sets
    )

    logging.info(f"Task 3 (Variant) successfully completed for method='{classification_method}'.")

    # Return the three essential data structures.
    return producer_sets, purchaser_sets, intersection_metrics
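
# --- Illustrative sketch (hypothetical; not `_compute_intersection_metrics`) ---
# The producer and purchaser sets returned above are plain dicts of sets, so
# relating them is ordinary set algebra. As one hypothetical example, the
# sketch below counts, for each ordered pair (input product j, output
# product i), the firms that are significant purchasers of j AND significant
# producers of i. The raw overlap count and the omission of self-pairs are
# assumptions for illustration. The metrics the pipeline actually computes
# may differ.
from typing import Dict, Set, Tuple


def _sketch_overlap_counts(
    producer_sets: Dict[int, Set[str]],
    purchaser_sets: Dict[int, Set[str]],
) -> Dict[Tuple[int, int], int]:
    counts: Dict[Tuple[int, int], int] = {}
    for j, buyers in purchaser_sets.items():
        for i, makers in producer_sets.items():
            if i == j:
                continue  # skip self-pairs in this sketch
            shared = len(buyers & makers)
            if shared:
                counts[(j, i)] = shared
    return counts


# Usage: firm 'C' buys 100 and makes 200; firms 'A' and 'B' buy 200 and make 100.
_toy_producers = {100: {"A", "B"}, 200: {"C"}}
_toy_purchasers = {100: {"C"}, 200: {"A", "B", "D"}}
# _sketch_overlap_counts(_toy_producers, _toy_purchasers)
# == {(100, 200): 1, (200, 100): 2}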

def run_end_to_end_pipeline_variant(
    classification_method: Any,
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    replication_manifest: Dict[str, Any],
    external_networks: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Executes the full research pipeline with a parameterizable firm
    classification method. (VARIANT for Task 11.3)

    This master orchestrator is a fully functional equivalent of the main
    `run_end_to_end_pipeline`. It is specifically designed for the construction
    robustness analysis by accepting a `classification_method` parameter. This
    parameter is injected into the Task 3 stage of the workflow, allowing the
    entire network and subsequent analyses to be rebuilt based on different
    foundational assumptions about what constitutes a significant firm.

    Args:
        classification_method (Any): The method ('mean', 'median', quantile)
            to be used in the firm classification stage (Task 3).
        transactions_log_frame (pd.DataFrame): Raw log of firm transactions.
        firm_metadata_frame (pd.DataFrame): Raw firm ownership and location data.
        comtrade_exports_frame (pd.DataFrame): Raw country-product export data.
        country_data_frame (pd.DataFrame): Raw country population data.
        supply_chains_definitions_dict (Dict[str, Any]): Definitions of
            manual supply chains for validation.
        replication_manifest (Dict[str, Any]): Dictionary of all study parameters.
        external_networks (Dict[str, Any], optional): Loaded external networks
            for comparison. Defaults to None.

    Returns:
        Dict[str, Any]: A nested dictionary containing all key outputs of the
                        pipeline for this specific, variant run.
    """
    # Initialize a dictionary to store all final results for this run.
    pipeline_results = {}

    # Log the start of the variant pipeline run, specifying the method.
    logging.info(f"===== STARTING END-TO-END PIPELINE (VARIANT: method='{classification_method}') =====")

    # --- Task 1: Data Validation and Quality Assurance ---
    # This step is independent of the classification method and runs as standard.
    validate_and_assess_inputs(
        transactions_log_frame, firm_metadata_frame, comtrade_exports_frame,
        country_data_frame, replication_manifest
    )

    # --- Task 2: Data Preprocessing and Cleansing ---
    # This step is also independent of the classification method.
    cleansed_df = preprocess_and_cleanse_data(
        transactions_log_frame, firm_metadata_frame, replication_manifest
    )

    # --- Task 3 (Variant Call) ---
    # This is the critical modification: call the new variant orchestrator for Task 3,
    # passing the specified classification method.
    producer_sets, purchaser_sets, intersection_metrics = classify_firms_and_construct_sets_variant(
        cleansed_df=cleansed_df,
        classification_method=classification_method
    )

    # --- Task 4: Network Inference Implementation ---
    # All subsequent tasks use the outputs from the variant Task 3.
    adj_matrix, graph, product_to_idx = infer_and_construct_network(
        cleansed_df, producer_sets, purchaser_sets, intersection_metrics, replication_manifest
    )
    pipeline_results['network_objects'] = {
        'adjacency_matrix': adj_matrix,
        'graph': graph,
        'product_to_idx_map': product_to_idx,
    }

    # --- Task 5: Community Detection and Structural Analysis ---
    community_labels = analyze_network_structure(
        graph, adj_matrix, product_to_idx, replication_manifest
    )

    # --- Task 6: Centrality Calculations ---
    centrality_df = calculate_centralities(graph, replication_manifest)
    pipeline_results['network_analysis'] = {
        'community_labels': community_labels,
        'centralities': centrality_df,
    }

    # --- Task 7: Validation Procedures ---
    validation_results = run_validation_procedures(
        adj_matrix, product_to_idx, supply_chains_definitions_dict, replication_manifest, external_networks
    )
    pipeline_results['validation_results'] = validation_results

    # --- Task 8: Economic Feature Engineering ---
    econometric_df = engineer_economic_features(
        comtrade_exports_frame, country_data_frame, adj_matrix, product_to_idx, replication_manifest
    )
    pipeline_results['econometric_dataset'] = econometric_df

    # --- Task 9: Econometric Analysis ---
    econometric_results = run_econometric_analysis(
        econometric_df, replication_manifest
    )
    pipeline_results['econometric_results'] = econometric_results

    # Log the successful completion of this variant run.
    logging.info(f"===== END-TO-END PIPELINE (VARIANT: method='{classification_method}') COMPLETED SUCCESSFULLY =====")

    # Return the comprehensive results dictionary.
    return pipeline_results
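
# --- Illustrative sketch (toy stand-ins, not pipeline functions) ---
# The variant above differs from the baseline only in that Task 3 receives a
# `classification_method`. Every other stage consumes whatever Task 3 returns.
# This is a minimal sketch of that idea using hypothetical toy callables. The
# Task 3 stage is passed in as a function, and the method is bound with
# `functools.partial`, so downstream stages run unchanged.
import statistics
from functools import partial
from typing import Any, Callable, Dict, List


def _toy_classify(values: List[float], method: str) -> List[float]:
    # Toy Task 3: keep values strictly above the mean or median.
    if method == "mean":
        threshold = statistics.mean(values)
    else:
        threshold = statistics.median(values)
    return [v for v in values if v > threshold]


def _toy_pipeline(
    values: List[float], classify_stage: Callable[[List[float]], List[float]]
) -> Dict[str, Any]:
    kept = classify_stage(values)  # the only stage that changes between runs
    return {"n_kept": len(kept), "total": sum(kept)}  # toy "downstream" stage


_toy_runs = {
    m: _toy_pipeline([1.0, 2.0, 3.0, 10.0], partial(_toy_classify, method=m))
    for m in ("mean", "median")
}
# _toy_runs == {"mean": {"n_kept": 1, "total": 10.0},
#               "median": {"n_kept": 2, "total": 13.0}}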

def run_construction_robustness_analysis_revised(
    methods_to_test: List[Any],
    base_manifest: Dict[str, Any],
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    external_networks: Optional[Dict[str, Any]] = None
) -> pd.DataFrame:
    """
    Performs a robustness analysis by reconstructing the network using
    alternative firm classification methods. (REVISED for Task 11.3)

    This function iterates through a list of methods (e.g., 'mean', 'median'),
    invokes a variant of the end-to-end pipeline for each, and produces a
    comparative summary of the key econometric findings.

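    Args:
        methods_to_test (List[Any]): A list of classification methods to test,
            e.g. 'mean', 'median', or a float quantile.
        base_manifest (Dict[str, Any]): The base replication manifest.
        transactions_log_frame (pd.DataFrame): Raw log of firm transactions.
        firm_metadata_frame (pd.DataFrame): Raw firm ownership and location data.
        comtrade_exports_frame (pd.DataFrame): Raw country-product export data.
        country_data_frame (pd.DataFrame): Raw country population data.
        supply_chains_definitions_dict (Dict[str, Any]): Definitions of
            manual supply chains for validation.
        external_networks (Optional[Dict[str, Any]]): Loaded external networks
            for comparison. Defaults to None.

    Returns:
        pd.DataFrame: A DataFrame with one row per construction method,
                      comparing key econometric results and edge counts.
                      Failed runs contain only an 'error' entry.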
    """
    logging.info("===== STARTING TASK 11.3: CONSTRUCTION ROBUSTNESS ANALYSIS (REVISED) =====")

    all_results = {}

    # Iterate through each specified classification method.
    for method in methods_to_test:
        # Create a descriptive name for the run.
        method_name = f"Method_{method}"
        logging.info(f"\n--- Running Analysis for Construction Method: '{method}' ---")

        try:
            # Call the pipeline variant, passing the current method.
            pipeline_results = run_end_to_end_pipeline_variant(
                classification_method=method,
                transactions_log_frame=transactions_log_frame,
                firm_metadata_frame=firm_metadata_frame,
                comtrade_exports_frame=comtrade_exports_frame,
                country_data_frame=country_data_frame,
                supply_chains_definitions_dict=supply_chains_definitions_dict,
                replication_manifest=base_manifest,
                external_networks=external_networks
            )

            # Extract and store the key results for this method.
            run_summary = _extract_key_econometric_results(pipeline_results)
            # Also capture how the network structure changes.
            run_summary['edge_count'] = pipeline_results['network_objects']['graph'].number_of_edges()
            all_results[method_name] = run_summary

        except Exception as e:
            # Log any failure for a specific method and continue.
            logging.error(
                f"Pipeline run failed for construction method '{method}': {e}",
                exc_info=True
            )
            all_results[method_name] = {'error': str(e)}

    # Aggregate results into a final comparative DataFrame.
    results_df = pd.DataFrame(all_results).T

    logging.info("\n===== CONSTRUCTION ROBUSTNESS ANALYSIS COMPLETED =====")

    # Display the comparative results.
    print("\n--- Comparative Construction Robustness Results ---")
    print(results_df)
    print("\n")

    return results_df
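
# --- Illustrative sketch (placeholder values, not study results) ---
# This shows how the per-method summaries above become one comparison table.
# `all_results` is a dict of dicts keyed by run name, so `pd.DataFrame(...)`
# places runs in columns and `.T` turns them into rows. A failed run stores
# only an 'error' key, so its metric cells become NaN, and successful runs get
# NaN in the 'error' column. The metric name 'coef' and all numbers are
# made-up placeholders.
import pandas as pd

_demo_results = {
    "Method_mean": {"coef": 0.5, "edge_count": 1200},
    "Method_median": {"coef": 0.4, "edge_count": 1800},
    "Method_0.75": {"error": "example failure message"},
}
_demo_table = pd.DataFrame(_demo_results).T
# Rows: Method_mean, Method_median, Method_0.75; columns: coef, edge_count, error.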

# =============================================================================
# Robustness Analysis Orchestrator Function
# =============================================================================

def run_full_robustness_analysis(
    # All raw data inputs required by the core pipeline
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    base_manifest: Dict[str, Any],
    external_networks: Optional[Dict[str, Any]] = None,
    # Configuration for the specific robustness checks
    parameter_grid: Optional[Dict[str, List[Any]]] = None,
    methods_to_test: Optional[List[Any]] = None,
    n_jobs: int = -1
) -> Dict[str, pd.DataFrame]:
    """
    Orchestrates the execution of a comprehensive suite of robustness analyses.

    This master function serves as the single entry point for Task 11. It
    sequentially invokes the dedicated orchestrators for each type of
    robustness check:
    1.  Parameter Sensitivity: Varies key network inference parameters.
    2.  Temporal Window: Re-runs the analysis over a different time period.
    3.  Construction Method: Re-builds the network using alternative firm
        classification rules.

    It aggregates the results from each check into a final, structured
    dictionary, providing a complete picture of the stability and reliability
    of the study's findings.

    Args:
        transactions_log_frame (pd.DataFrame): Raw log of firm transactions.
        firm_metadata_frame (pd.DataFrame): Raw firm ownership and location data.
        comtrade_exports_frame (pd.DataFrame): Raw country-product export data.
        country_data_frame (pd.DataFrame): Raw country population data.
        supply_chains_definitions_dict (Dict[str, Any]): Definitions of
            manual supply chains for validation.
        base_manifest (Dict[str, Any]): The base replication manifest with
                                        default parameter settings.
        external_networks (Optional[Dict[str, Any]]): Loaded external networks
            for comparison. Defaults to None.
        parameter_grid (Optional[Dict[str, List[Any]]]): Configuration for the
            parameter sensitivity analysis. If None, this check is skipped.
        methods_to_test (Optional[List[Any]]): Configuration for the construction
            robustness analysis. If None, this check is skipped.
        n_jobs (int): The number of CPU cores to use for parallelizable tasks.
                      -1 means using all available cores.

    Returns:
        Dict[str, pd.DataFrame]: A dictionary where keys are the names of the
                                 robustness checks and values are the
                                 corresponding result DataFrames.
    """
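    # Configure logging for the robustness suite. Without force=True this call is
    # a no-op if the root logger already has handlers (e.g. when invoked from
    # main_analysis_orchestrator, which configures logging first).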
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [Robustness Orchestrator] - %(message)s'
    )

    # Initialize a dictionary to store all robustness results.
    robustness_results = {}

    logging.info("===== STARTING FULL ROBUSTNESS ANALYSIS SUITE =====")

    # --- Task 11.1: Parameter Sensitivity Analysis ---
    # This check is run only if a parameter grid is provided.
    if parameter_grid:
        try:
            # Execute the parameter sensitivity analysis.
            sensitivity_results_df = run_parameter_sensitivity_analysis(
                parameter_grid=parameter_grid,
                base_manifest=base_manifest,
                transactions_log_frame=transactions_log_frame,
                firm_metadata_frame=firm_metadata_frame,
                comtrade_exports_frame=comtrade_exports_frame,
                country_data_frame=country_data_frame,
                supply_chains_definitions_dict=supply_chains_definitions_dict,
                external_networks=external_networks,
                n_jobs=n_jobs
            )
            # Store the resulting DataFrame.
            robustness_results['parameter_sensitivity'] = sensitivity_results_df
        except Exception as e:
            # Log a failure in this specific check but do not halt the entire suite.
            logging.error(f"Task 11.1 (Parameter Sensitivity) failed: {e}", exc_info=True)
            robustness_results['parameter_sensitivity'] = pd.DataFrame({'error': [str(e)]})

    # --- Task 11.2: Temporal Window Robustness Testing ---
    # Unlike Tasks 11.1 and 11.3, this check always runs.
    try:
        # Execute the temporal robustness analysis.
        temporal_results_df = run_temporal_robustness_analysis(
            base_manifest=base_manifest,
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            supply_chains_definitions_dict=supply_chains_definitions_dict,
            external_networks=external_networks
        )
        # Store the resulting comparative DataFrame.
        robustness_results['temporal_robustness'] = temporal_results_df
    except Exception as e:
        # Log a failure in this specific check.
        logging.error(f"Task 11.2 (Temporal Robustness) failed: {e}", exc_info=True)
        robustness_results['temporal_robustness'] = pd.DataFrame({'error': [str(e)]})

    # --- Task 11.3: Alternative Network Construction Robustness ---
    # This check is run only if a list of alternative methods is provided.
    if methods_to_test:
        try:
            # Execute the construction robustness analysis.
            construction_results_df = run_construction_robustness_analysis_revised(
                methods_to_test=methods_to_test,
                base_manifest=base_manifest,
                transactions_log_frame=transactions_log_frame,
                firm_metadata_frame=firm_metadata_frame,
                comtrade_exports_frame=comtrade_exports_frame,
                country_data_frame=country_data_frame,
                supply_chains_definitions_dict=supply_chains_definitions_dict,
                external_networks=external_networks
            )
            # Store the resulting comparative DataFrame.
            robustness_results['construction_robustness'] = construction_results_df
        except Exception as e:
            # Log a failure in this specific check.
            logging.error(f"Task 11.3 (Construction Robustness) failed: {e}", exc_info=True)
            robustness_results['construction_robustness'] = pd.DataFrame({'error': [str(e)]})

    # Log the completion of the entire suite.
    logging.info("===== FULL ROBUSTNESS ANALYSIS SUITE COMPLETED =====")

    # Return the final, aggregated robustness results.
    return robustness_results
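
# --- Illustrative sketch (hypothetical helper, toy checks) ---
# `run_full_robustness_analysis` skips checks whose configuration is falsy.
# Note that `if parameter_grid:` also skips an empty dict. It turns any
# exception into a one-row DataFrame with an 'error' column, so one failing
# check does not halt the others. This is a minimal sketch of that pattern,
# using a hypothetical helper and toy check functions.
import logging
from typing import Callable, Dict, Optional

import pandas as pd


def _sketch_run_isolated_checks(
    checks: Dict[str, Optional[Callable[[], pd.DataFrame]]]
) -> Dict[str, pd.DataFrame]:
    results: Dict[str, pd.DataFrame] = {}
    for name, check in checks.items():
        if check is None:
            continue  # not configured: skipped, and no key is recorded
        try:
            results[name] = check()
        except Exception as e:  # isolate failures per check
            logging.error(f"Check '{name}' failed: {e}")
            results[name] = pd.DataFrame({"error": [str(e)]})
    return results


def _toy_failing_check() -> pd.DataFrame:
    raise RuntimeError("toy failure")


_toy_suite = _sketch_run_isolated_checks(
    {
        "parameter_sensitivity": None,
        "temporal_robustness": lambda: pd.DataFrame({"metric": [1.0]}),
        "construction_robustness": _toy_failing_check,
    }
)
# _toy_suite has keys 'temporal_robustness' and 'construction_robustness';
# the latter holds pd.DataFrame({'error': ['toy failure']}).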


# Master Orchestrator Callable

# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# Fused High-Fidelity Constellation of Callables for Replicating the Study:
# "Deciphering the global production network from cross-border firm transactions"
#
# This script provides a complete, end-to-end implementation of the research pipeline from the paper.
# It is designed with modularity, rigor, and performance as primary objectives.
#
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-


def main_analysis_orchestrator(
    # All raw data inputs
    transactions_log_frame: pd.DataFrame,
    firm_metadata_frame: pd.DataFrame,
    comtrade_exports_frame: pd.DataFrame,
    country_data_frame: pd.DataFrame,
    supply_chains_definitions_dict: Dict[str, Any],
    # Configuration inputs
    base_manifest: Dict[str, Any],
    # Optional inputs for robustness checks
    run_robustness_checks: bool = True,
    parameter_grid: Optional[Dict[str, List[Any]]] = None,
    methods_to_test: Optional[List[Any]] = None,
    external_networks: Optional[Dict[str, Any]] = None,
    n_jobs: int = -1
) -> Dict[str, Any]:
    """
    Serves as the top-level entry point for the entire research project,
    orchestrating both the baseline replication and the full suite of
    robustness analyses.

    This function provides a single entry point to the full research
    pipeline. It first executes the end-to-end replication
    of the study's main findings. Then, if enabled, it proceeds to run a
    comprehensive set of robustness checks to assess the stability of those
    findings with respect to key methodological choices.

    Args:
        transactions_log_frame (pd.DataFrame): Raw log of firm transactions.
        firm_metadata_frame (pd.DataFrame): Raw firm ownership and location data.
        comtrade_exports_frame (pd.DataFrame): Raw country-product export data.
        country_data_frame (pd.DataFrame): Raw country population data.
        supply_chains_definitions_dict (Dict[str, Any]): Definitions of
            manual supply chains for validation.
        base_manifest (Dict[str, Any]): The base replication manifest with
                                        default parameter settings.
        run_robustness_checks (bool): A flag to enable or disable the entire
                                      robustness analysis suite. Defaults to True.
        parameter_grid (Optional[Dict[str, List[Any]]]): Configuration for the
            parameter sensitivity analysis. If None, that check is skipped.
        methods_to_test (Optional[List[Any]]): Configuration for the construction
            robustness analysis. If None, that check is skipped.
        external_networks (Optional[Dict[str, Any]]): Loaded external networks
            for comparison. Defaults to None.
        n_jobs (int): The number of CPU cores to use for parallelizable tasks.
                      -1 means using all available cores.

    Returns:
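        Dict[str, Any]: A master dictionary containing two top-level keys:
                        'baseline_results': The full output of the main study
                                            replication.
                        'robustness_results': A dictionary with the results of
                                              each completed robustness check,
                                              or None if checks were disabled.
                        If the baseline run fails, an additional 'error' key
                        holds the failure message and 'baseline_results'
                        remains None.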
    """
    # --- Initialization ---
    # Set up a global logger for the entire session.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [Main Orchestrator] - %(message)s',
        force=True,  # Re-configure even if handlers exist (requires Python 3.8+)
    )

    # Initialize the master dictionary to store all results.
    master_results = {
        'baseline_results': None,
        'robustness_results': None
    }

    logging.info("||||| COMMENCING FULL ANALYSIS AND ROBUSTNESS PIPELINE |||||")

    # --- Baseline Replication Run ---
    try:
        logging.info("\n" + "="*80 + "\n" + "||" + " "*25 + "EXECUTING BASELINE REPLICATION" + " "*25 + "||\n" + "="*80)

        # Execute the primary end-to-end pipeline.
        baseline_pipeline_results = run_end_to_end_pipeline(
            transactions_log_frame=transactions_log_frame,
            firm_metadata_frame=firm_metadata_frame,
            comtrade_exports_frame=comtrade_exports_frame,
            country_data_frame=country_data_frame,
            supply_chains_definitions_dict=supply_chains_definitions_dict,
            replication_manifest=base_manifest,
            external_networks=external_networks
        )
        # Store the complete set of results from the baseline run.
        master_results['baseline_results'] = baseline_pipeline_results

        logging.info("||" + " "*24 + "BASELINE REPLICATION COMPLETED SUCCESSFULLY" + " "*23 + "||")

    except Exception as e:
        # If the baseline run fails, it's a critical error. Log and terminate.
        logging.critical(f"The baseline pipeline run failed critically: {e}", exc_info=True)
        logging.critical("Cannot proceed to robustness checks. Terminating analysis.")
        # Return the partial results dictionary with the error.
        master_results['error'] = f"Baseline pipeline failed: {e}"
        return master_results

    # --- Full Robustness Analysis Suite ---
    # Proceed only if the baseline was successful and checks are enabled.
    if run_robustness_checks:
        try:
            logging.info("\n" + "="*80 + "\n" + "||" + " "*24 + "EXECUTING ROBUSTNESS ANALYSIS SUITE" + " "*23 + "||\n" + "="*80)

            # Execute the master orchestrator for all robustness checks.
            robustness_suite_results = run_full_robustness_analysis(
                transactions_log_frame=transactions_log_frame,
                firm_metadata_frame=firm_metadata_frame,
                comtrade_exports_frame=comtrade_exports_frame,
                country_data_frame=country_data_frame,
                supply_chains_definitions_dict=supply_chains_definitions_dict,
                base_manifest=base_manifest,
                external_networks=external_networks,
                parameter_grid=parameter_grid,
                methods_to_test=methods_to_test,
                n_jobs=n_jobs
            )
            # Store the complete set of results from the robustness suite.
            master_results['robustness_results'] = robustness_suite_results

            logging.info("||" + " "*23 + "ROBUSTNESS ANALYSIS SUITE COMPLETED SUCCESSFULLY" + " "*22 + "||")

        except Exception as e:
            # If the robustness suite fails, log the error but do not crash.
            # The baseline results are still valuable.
            logging.error(f"The robustness analysis suite encountered an error: {e}", exc_info=True)
            master_results['robustness_results'] = {'error': f"Robustness suite failed: {e}"}
    else:
        # Log if robustness checks were intentionally skipped.
        logging.info("Robustness checks were disabled. Skipping Task 11.")

    # Log the successful completion of the entire analysis.
    logging.info("\n||||| FULL ANALYSIS AND ROBUSTNESS PIPELINE CONCLUDED |||||")

    # Return the final, comprehensive results.
    return master_results
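
# --- Illustrative sketch (hypothetical consumer of `master_results`) ---
# Based on the control flow above, `master_results` can take one of four shapes:
#   1. It has an 'error' key if the baseline failed.
#   2. 'robustness_results' is None if checks were disabled.
#   3. 'robustness_results' is {'error': ...} if the whole suite raised.
#   4. 'robustness_results' is a dict of DataFrames. A frame with only an
#      'error' column means that check (or every run in it) failed. An 'error'
#      column beside other columns means some runs failed.
# The helper below is a hypothetical sketch that reports those cases.
from typing import Any, Dict

import pandas as pd


def _sketch_summarize_master_results(master_results: Dict[str, Any]) -> Dict[str, str]:
    if "error" in master_results:
        return {"baseline": "failed"}
    status = {"baseline": "ok"}
    robustness = master_results.get("robustness_results")
    if robustness is None:
        status["robustness"] = "skipped"
    elif isinstance(robustness.get("error"), str):
        status["robustness"] = "failed"
    else:
        for name, frame in robustness.items():
            cols = list(frame.columns)
            if cols == ["error"]:
                status[name] = "failed"
            elif "error" in cols:
                status[name] = "partial"
            else:
                status[name] = "ok"
    return status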
