In [1]:
# ============================================================================
# FILE PATHS - edit PROJECT_DIR to match your local setup
# ============================================================================
# Single project root: every path below is derived from it, so relocating the
# project only requires changing this one line.
PROJECT_DIR = '/Users/leoss/Desktop/Portfolio/Website-/projects/export'

# Atlas trade data (Stata .dta file) and the directory that receives outputs.
DATA_PATH = f'{PROJECT_DIR}/data/atlas_2022.dta'
SAVE_DIR  = f'{PROJECT_DIR}/outputs'

# External product/sector lookup (from PP413 project data)
SECTOR_LOOKUP_PATH = f'{PROJECT_DIR}/data/unique_hs_codes_and_sectors.csv'


In [2]:
import os
import warnings
from io import StringIO
from typing import Dict, List, Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from adjustText import adjust_text
from matplotlib.ticker import FuncFormatter

warnings.filterwarnings('ignore')


# Trade Analysis Framework - Italy
Configurable pipeline for trade data from the Atlas of Economic Complexity.  
Produces product-level and sector-level bar charts, time-series trends, opportunity scatter plots, and treemaps.

**Changelog vs. previous version:**
- Fixed: strategy weights now match the report formulas (LHF, Balanced, Long Jumps)
- Fixed: product names no longer stuck in lowercase (uses external CSV lookup instead of manual dict)
- Fixed: reversed normalized distance (the `norm_proximity` column) now uses the standard `1 - norm_distance` formula
- Fixed: sector PCI aggregation handles zero-weight edge cases
- Fixed: sector_mapping filtered to base year to prevent merge duplicates
- Fixed: net_export time-series PCI uses `export_value` weights instead of clipped net exports
- Improved: removed manual 30-product name mapping; all 1241 products resolved via lookup CSV


In [3]:
# ============================================================================
# CONFIGURATION
# ============================================================================
# Single nested dict consumed by every class in this notebook. Each top-level
# key groups the settings for one stage of the pipeline.

CONFIG = {
    # ---- input data ----
    'data': {
        'path': DATA_PATH,                    # Stata file read with pd.read_stata
        'country_code': 'ITA',                # compared against the iso3_code column
        'country_name': 'Italy',              # used in printed messages and chart titles
        'sector_lookup': SECTOR_LOOKUP_PATH,  # external clean product names
    },
    # ---- analysis parameters ----
    'analysis': {
        'base_year': 2022,                    # year used for the cross-section and sector_mapping
        # NOTE(review): 'time_range' is not read by the code visible in this
        # section - confirm whether it is used elsewhere.
        'time_range': (1995, 2022),
        'value_metrics': ['export_value', 'net_export'],  # each metric is analysed in turn
        'rolling_window': 5,                  # window (in rows per sector) for rolling means
        'top_n': 10,                          # rows kept per opportunity strategy
    },
    # ---- opportunity screening and scoring ----
    'opportunities': {
        'enabled': True,
        'min_pci': 1.49,                      # candidates must have pci strictly above this
        'max_rca': 1.0,                       # candidates must have export_rca strictly below this
        # FIX: weights now match report Appendix 3F formulas exactly
        # The 'distance' weight is applied to norm_proximity (1 - norm_distance).
        'strategies': {
            'lhf': {'distance': 0.75, 'pci': 0.10, 'cog': 0.15},   # Low Hanging Fruit
            'bs':  {'distance': 0.50, 'pci': 0.25, 'cog': 0.25},   # Balanced Strategy
            'lj':  {'distance': 0.40, 'pci': 0.20, 'cog': 0.40},   # Long Jumps / Strategic Bets
        }
    },
    # ---- chart styling ----
    'visualization': {
        # NOTE(review): 'save_format' and 'figure_sizes' are not read by the
        # plotting code visible in this section (filenames end in .png and
        # figure sizes are set inline) - confirm before relying on them.
        'save_format': 'png',
        'dpi': 300,                           # passed to fig.savefig
        'figure_sizes': {
            'bar': (14, 8),
            'line': (16, 8),
            'scatter': (14, 8.5),
        },
        'color_palette': 'husl',              # seaborn palette for the fallback sector colours
        'font_family': 'Arial',               # applied through matplotlib rcParams
    },
    # ---- outputs ----
    'output': {
        'directory': SAVE_DIR,                # created if missing; all files are written here
        'create_treemap': True,
        'treemap_year': 1995,                 # year shown in the treemap (independent of base_year)
        'export_csv': True,
    },
    # ---- filters ----
    'advanced': {
        'exclude_sectors': ['Other'],         # rows with these sector_name values are dropped on load
        # NOTE(review): 'min_export_threshold' is not read by the code visible
        # in this section - confirm whether it is applied elsewhere.
        'min_export_threshold': 1e6,
    }
}


In [4]:
# ============================================================================
# DATA LOADING AND PREPARATION
# ============================================================================

class TradeDataLoader:
    """Loads the Atlas .dta file and applies clean product names from an external lookup.

    The loader reads the Stata file named in ``config['data']['path']``, derives
    ``net_export``, drops excluded sectors, optionally replaces ``product_name``
    with the names found in the lookup CSV, and splits the result into the
    configured country and the rest of the world.
    """

    # Columns kept in the rest-of-world frame returned by ``load_data``.
    _GLOBAL_COLUMNS = [
        'name_short_en', 'iso3_code', 'year',
        'sector', 'sector_name', 'product_name', 'hs_92_code', 'product_level',
        'export_value', 'import_value', 'net_export', 'global_market_share',
        'export_rca', 'pci', 'distance', 'cog'
    ]

    # Columns the lookup CSV must provide for the name replacement to work.
    _LOOKUP_REQUIRED = ('hs_92_code', 'product_name')

    def __init__(self, config: Dict):
        self.config = config
        self.data_cfg = config['data']
        self.adv_cfg  = config['advanced']

    # ------------------------------------------------------------------
    def load_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, Dict]:
        """Read, clean and split the trade data.

        Returns:
            A tuple ``(country_df, global_df, sector_mapping, sector_colors)``:
            rows of the configured country, rows of every other country
            (restricted to ``_GLOBAL_COLUMNS``), the base-year
            ``hs_92_code``/``product_name``/``sector_name`` mapping, and a
            ``{sector_name: colour}`` dict.

        Raises:
            KeyError: if the Stata file lacks a column used here.
        """
        print(f"Loading data from: {self.data_cfg['path']}")
        df = pd.read_stata(self.data_cfg['path'])

        # net exports
        df['net_export'] = df['export_value'] - df['import_value']

        if 'country_id' in df.columns:
            df = df.drop(columns='country_id')

        # Drop every excluded sector in one pass (rows with a missing sector are kept).
        excluded = list(self.adv_cfg['exclude_sectors'])
        if excluded:
            df = df[~df['sector_name'].isin(excluded)]

        df = self._apply_clean_names(df)

        # ---- split country vs rest-of-world ----
        code = self.data_cfg['country_code']
        is_country = df['iso3_code'] == code
        country_df = df[is_country].copy()
        global_df  = df[~is_country].copy()
        global_df  = global_df[self._GLOBAL_COLUMNS]

        # FIX: sector_mapping from base_year only to avoid merge duplicates
        base_year = self.config['analysis']['base_year']
        sector_mapping = (
            country_df[country_df['year'] == base_year]
            [['hs_92_code', 'product_name', 'sector_name']]
            .drop_duplicates()
        )

        sector_colors = self._sector_colors(country_df)

        print(f"  data loaded: {self.data_cfg['country_name']} "
              f"({len(country_df):,} obs), Global ({len(global_df):,} obs)")
        print(f"  years: {country_df['year'].min()}-{country_df['year'].max()}")
        print(f"  sectors: {', '.join(sorted(country_df['sector_name'].unique()))}")

        return country_df, global_df, sector_mapping, sector_colors

    # ------------------------------------------------------------------
    def _apply_clean_names(self, df: pd.DataFrame) -> pd.DataFrame:
        """Replace ``product_name`` with the lookup CSV's name, joined on ``hs_92_code``.

        Returns ``df`` unchanged (after printing a warning) when the lookup file
        is not configured, does not exist, or lacks the required columns.
        Products without a lookup entry keep their original name.
        """
        lookup_path = self.data_cfg.get('sector_lookup')
        if not (lookup_path and os.path.exists(lookup_path)):
            print("  WARNING: sector lookup CSV not found; using raw Atlas names")
            return df

        # read raw bytes, strip NUL chars, then parse
        with open(lookup_path, 'rb') as fh:
            raw = fh.read().replace(b'\x00', b'')
        lookup = pd.read_csv(
            StringIO(raw.decode('latin-1')),
            sep=',',
            engine='python',
            on_bad_lines='warn',
        )
        print(f"  lookup columns: {list(lookup.columns)}")
        print(f"  lookup shape: {lookup.shape}")

        missing = [c for c in self._LOOKUP_REQUIRED if c not in lookup.columns]
        if missing:
            print(f"  WARNING: sector lookup CSV lacks column(s) {missing}; "
                  f"using raw Atlas names")
            return df

        # rename to avoid collision during merge
        lookup = lookup.rename(columns={
            'product_name': 'clean_name',
            'sector_name':  'clean_sector',
        })
        lookup['hs_92_code'] = lookup['hs_92_code'].astype(str)

        # One name per code: a repeated code in the lookup would otherwise
        # multiply the matching trade rows in the left merge below and inflate
        # every downstream sum.
        names = (
            lookup[['hs_92_code', 'clean_name']]
            .dropna(subset=['clean_name'])
            .drop_duplicates(subset='hs_92_code', keep='first')
        )

        # assign() returns a new frame, so a filtered slice is never written to.
        df = df.assign(hs_92_code=df['hs_92_code'].astype(str))
        df = df.merge(names, on='hs_92_code', how='left')
        # use clean name where available, else keep original
        df['product_name'] = df['clean_name'].fillna(df['product_name'])
        df = df.drop(columns='clean_name')
        print(f"  applied clean product names from {os.path.basename(lookup_path)}")
        return df

    # ------------------------------------------------------------------
    def _sector_colors(self, df: pd.DataFrame) -> Dict:
        """Map each sector name (sorted) to a colour from the configured seaborn palette."""
        sectors = sorted(df['sector_name'].unique())
        pal = sns.color_palette(self.config['visualization']['color_palette'],
                                n_colors=len(sectors))
        return dict(zip(sectors, pal))


In [5]:
# ============================================================================
# METRICS CALCULATION
# ============================================================================

class MetricsCalculator:
    """Product-level and sector-level trade metrics."""

    def __init__(self, config: Dict):
        self.config = config

    # ------------------------------------------------------------------
    def calculate_product_metrics(self, country_data, global_ref, value_col):
        """Aggregate the country's rows per product and attach world totals.

        Args:
            country_data: country rows with ``hs_92_code``, ``product_name``,
                ``sector_name``, ``export_rca``, ``pci``, ``distance``, ``cog``
                and ``value_col``.
            global_ref: output of ``prepare_global_reference`` (needs
                ``hs_92_code``, ``global_value`` and ``pci_rank``).
            value_col: name of the value column to sum.

        Returns:
            One row per product with the summed value, ``rca`` (max of
            ``export_rca``), ``pci``/``distance``/``cog`` (first value),
            ``global_value``, ``market_share`` in percent and ``pci_percentile``.
        """
        agg = (
            country_data
            .groupby(['hs_92_code', 'product_name', 'sector_name'])
            .agg({
                value_col: 'sum',
                'export_rca': 'max',
                'pci': 'first',
                'distance': 'first',
                'cog': 'first',
            })
            .reset_index()
            .rename(columns={'export_rca': 'rca'})
        )
        ref = global_ref[['hs_92_code', 'global_value', 'pci_rank']].copy()
        merged = agg.merge(ref, on='hs_92_code', how='left')
        merged['market_share'] = (merged[value_col] / merged['global_value']) * 100
        # pop() moves the rank column to the end under its public name.
        merged['pci_percentile'] = merged.pop('pci_rank')
        return merged

    # ------------------------------------------------------------------
    def prepare_global_reference(self, country_data, global_data, year, value_col):
        """Build the per-product world reference for one year.

        Returns:
            One row per ``hs_92_code`` with ``global_value`` (sum of
            ``value_col`` over the country and the rest of the world), ``pci``
            and ``pci_rank`` (percentile rank of ``pci``, 0-100).
        """
        cy = country_data[country_data['year'] == year]
        gy = global_data[global_data['year'] == year]

        combined = pd.concat([gy, cy])
        gref = combined.groupby('hs_92_code')[value_col].sum().reset_index(name='global_value')

        # Exactly one PCI per product. drop_duplicates() on (code, pci) kept
        # several rows for a code whenever reporters disagreed or one had a
        # missing PCI, which duplicated products after the merge. first() takes
        # the first non-missing value per code instead.
        gpci = gy.groupby('hs_92_code', as_index=False)['pci'].first()
        gref = gref.merge(gpci, on='hs_92_code', how='left')
        gref['pci_rank'] = gref['pci'].rank(pct=True) * 100
        return gref

    # ------------------------------------------------------------------
    def calculate_sector_metrics_timeseries(self, country_df, global_df, value_col):
        """Sector-level metrics over time with rolling averages.

        FIX: PCI weighting always uses export_value (positive) even when
        value_col is net_export, to avoid zero/negative weight issues.

        Products with a missing PCI or weight are left out of the weighted
        average instead of turning the whole sector-year into NaN.

        Returns:
            One row per (sector_name, year) with ``value_col``, ``global_value``,
            ``pci``, ``market_share`` (percent), ``rca`` and the ``rolling_*``
            columns averaged over ``config['analysis']['rolling_window']`` rows.
        """
        rw = self.config['analysis']['rolling_window']
        weight_col = 'export_value'  # always positive
        keys = ['sector_name', 'year']

        country_export = (
            country_df.groupby(keys)[value_col]
            .sum().reset_index(name='country_value')
        )
        global_export = (
            global_df.groupby(keys)[value_col]
            .sum().reset_index(name='global_value')
        )

        # weighted PCI - use export_value weights for stability
        def safe_wavg(g):
            w = g[weight_col].clip(lower=0)
            usable = g['pci'].notna() & w.notna()
            if not w[usable].sum() > 0:
                return pd.Series({'pci': np.nan})
            return pd.Series({'pci': np.average(g.loc[usable, 'pci'],
                                                weights=w[usable])})

        # Selecting the two needed columns keeps the grouping keys out of the
        # frames handed to safe_wavg (newer pandas warns otherwise).
        pci_metrics = (
            country_df.groupby(keys)[['pci', weight_col]]
            .apply(safe_wavg)
            .reset_index()
        )

        metrics = (
            country_export
            .merge(global_export, on=keys)
            .merge(pci_metrics, on=keys)
        )

        # RCA
        tc = country_df.groupby('year')[value_col].sum().reset_index(name='total_country')
        tg = global_df.groupby('year')[value_col].sum().reset_index(name='total_global')
        metrics = (
            metrics.merge(tc, on='year').merge(tg, on='year')
            .assign(
                market_share=lambda x: (x['country_value'] / x['global_value']) * 100,
                rca=lambda x: (x['country_value'] / x['total_country']) /
                              (x['global_value'] / x['total_global']),
            )
        )

        # rolling averages (per sector, in year order)
        metrics = metrics.sort_values(keys)
        for src, dst in [('country_value', 'rolling_value'),
                         ('market_share',  'rolling_market_share'),
                         ('pci',           'rolling_pci'),
                         ('rca',           'rolling_rca')]:
            metrics[dst] = (
                metrics.groupby('sector_name', group_keys=False)[src]
                .transform(lambda s: s.rolling(rw, min_periods=1).mean())
            )
        metrics = metrics.rename(columns={'country_value': value_col})
        return metrics


In [6]:
# ============================================================================
# OPPORTUNITY ANALYSIS
# ============================================================================

class OpportunityAnalyzer:
    """Screens and ranks candidate export products under several weighting strategies."""

    def __init__(self, config: Dict):
        self.config = config
        self.opp = config['opportunities']

    # ------------------------------------------------------------------
    def analyze_opportunities(self, country_data, global_data):
        """Filter, normalize and score opportunity products.

        Candidates are the rows of ``country_data`` whose ``export_rca`` is
        below ``max_rca`` and whose ``pci`` is above ``min_pci``.

        Returns:
            ``(scored, top)`` where ``scored`` holds every candidate with its
            normalized columns, ``global_export_value`` and one score column per
            strategy, and ``top`` maps each strategy name to its ``top_n``
            highest-scoring rows. Both are empty when the analysis is disabled
            or nothing passes the filter.
        """
        if not self.opp['enabled']:
            return pd.DataFrame(), {}

        low_rca = country_data['export_rca'] < self.opp['max_rca']
        high_pci = country_data['pci'] > self.opp['min_pci']
        candidates = country_data.loc[low_rca & high_pci].copy()

        if candidates.empty:
            print("Warning: no opportunities found with current criteria")
            return candidates, {}

        candidates = self._normalize(candidates)

        # world export total per product, used later for bubble sizing
        world_totals = (
            global_data.groupby('hs_92_code')['export_value']
            .sum()
            .reset_index()
            .rename(columns={'export_value': 'global_export_value'})
        )
        candidates = candidates.merge(world_totals, on='hs_92_code', how='left')
        candidates = self._score(candidates)

        limit = self.config['analysis']['top_n']
        ranked = {}
        for strategy in self.opp['strategies']:
            ranked[strategy] = candidates.nlargest(limit, strategy)
        return candidates, ranked

    # ------------------------------------------------------------------
    def _normalize(self, df):
        """Add min-max scaled ``norm_*`` columns to ``df`` in place and return it.

        A column with no spread is scaled to all zeros. ``norm_proximity`` is
        the reversed distance, so 1 = closest and 0 = furthest.
        """
        for source, target in (('distance', 'norm_distance'),
                               ('pci',      'norm_pci'),
                               ('cog',      'norm_cog')):
            column = df[source]
            low = column.min()
            spread = column.max() - low
            if spread > 0:
                df[target] = (column - low) / spread
            else:
                df[target] = column * 0
        df['norm_proximity'] = 1 - df['norm_distance']
        return df

    # ------------------------------------------------------------------
    def _score(self, df):
        """Add one weighted-sum score column per configured strategy and return ``df``."""
        for strategy, weights in self.opp['strategies'].items():
            proximity_part = df['norm_proximity'] * weights['distance']
            pci_part = df['norm_pci'] * weights['pci']
            cog_part = df['norm_cog'] * weights['cog']
            df[strategy] = proximity_part + pci_part + cog_part
        return df


In [7]:
# ============================================================================
# VISUALIZATION (redesigned - unified theme)
# ============================================================================
#
# Changes from original:
#   - Curated 9-sector palette replacing random husl
#   - Consistent typography hierarchy: title 16pt semibold, axis 11pt, ticks 10pt
#   - Uniform grid: light gray dashed, alpha 0.25
#   - White background with subtle spine styling (left + bottom only, gray)
#   - Consistent legend style: framed, outside upper-right for line/scatter charts, inside lower-right for bar charts
#   - Bar value annotations outside bars (dark gray) instead of inside (white)
#   - Line charts: thicker lines (2.2), smaller markers (4), area fill at alpha 0.04
#   - Scatter: consistent with line/bar chrome
#   - Fixed: net export time-series filenames now include prefix (no overwrite)
#   - Fixed: NaN/inf filter before bar plotting to suppress posx/posy warnings
#   - Treemap: matching color palette via rgb strings
# ============================================================================

class Visualizer:
    """All chart types for the trade analysis - unified theme.

    Bar, line and scatter charts are drawn with matplotlib and saved as PNG
    files; the treemap is drawn with plotly and saved as HTML. Every file is
    written to ``config['output']['directory']``.
    """

    # ---- curated palette: 9 distinct, colourblind-friendly, muted tones ----
    SECTOR_PALETTE = {
        'Agriculture':  '#5B8C5A',   # sage green
        'Chemicals':    '#E07A5F',   # terracotta
        'Electronics':  '#3D405B',   # charcoal blue
        'Machinery':    '#81B29A',   # seafoam
        'Metals':       '#F2CC8F',   # warm sand
        'Minerals':     '#A8896C',   # taupe
        'Stone':        '#B0A8B9',   # lavender gray
        'Textiles':     '#C97C5D',   # burnt sienna
        'Vehicles':     '#577399',   # steel blue
    }

    # ---- theme constants ----
    BG          = '#FFFFFF'
    GRID_COLOR  = '#D5D8DC'
    SPINE_COLOR = '#B0B3B8'
    TEXT_DARK   = '#2C3E50'
    TEXT_MID    = '#5D6D7E'
    TEXT_LIGHT  = '#95A5A6'
    ANNOT_COLOR = '#34495E'

    FONT_TITLE  = 16
    FONT_AXIS   = 11
    FONT_TICK   = 10
    FONT_ANNOT  = 9
    FONT_LEGEND = 9.5

    def __init__(self, config: Dict, sector_colors: Dict):
        """Store the config, build the sector colour map and apply the theme.

        Args:
            config: the notebook's configuration dict.
            sector_colors: ``{sector_name: colour}`` fallback colours, used for
                sectors that are not in ``SECTOR_PALETTE``. Values may be
                colour strings or RGB float tuples.
        """
        self.config  = config
        # override the husl palette with curated one; fall back for unknown sectors.
        # Fallback RGB tuples are converted to hex so that every value in
        # self.colors is a string usable by both matplotlib and the treemap.
        self.colors  = {s: self._to_hex(self.SECTOR_PALETTE.get(s, c))
                        for s, c in sector_colors.items()}
        self.viz     = config['visualization']
        self.out_dir = config['output']['directory']
        os.makedirs(self.out_dir, exist_ok=True)
        self._apply_rcparams()

    def _apply_rcparams(self):
        """Push the theme constants into matplotlib's global rcParams."""
        plt.rcParams.update({
            'font.family':       self.viz['font_family'],
            'figure.facecolor':  self.BG,
            'axes.facecolor':    self.BG,
            'axes.edgecolor':    self.SPINE_COLOR,
            'axes.linewidth':    0.8,
            'axes.grid':         True,
            'grid.color':        self.GRID_COLOR,
            'grid.linestyle':    '--',
            'grid.linewidth':    0.5,
            'grid.alpha':        0.25,
            'axes.spines.top':   False,
            'axes.spines.right': False,
            'axes.labelcolor':   self.TEXT_DARK,
            'xtick.color':       self.TEXT_MID,
            'ytick.color':       self.TEXT_MID,
            'xtick.labelsize':   self.FONT_TICK,
            'ytick.labelsize':   self.FONT_TICK,
            'legend.frameon':    True,
            'legend.edgecolor':  self.GRID_COLOR,
            'legend.facecolor':  self.BG,
            'legend.fontsize':   self.FONT_LEGEND,
        })

    # ---- colour helpers ----
    @staticmethod
    def _to_hex(color):
        """Return ``color`` as a string; RGB(A) float sequences become ``#RRGGBB``."""
        if isinstance(color, str):
            return color
        channels = [max(0, min(255, int(round(float(v) * 255))))
                    for v in tuple(color)[:3]]
        return '#{:02X}{:02X}{:02X}'.format(*channels)

    @staticmethod
    def _plotly_color(color):
        """Convert a 3- or 6-digit hex colour to an ``rgb(r,g,b)`` string.

        Anything that is not a hex colour is returned unchanged.
        """
        digits = str(color).lstrip('#')
        if len(digits) == 3:
            digits = ''.join(ch * 2 for ch in digits)
        if len(digits) != 6:
            return color
        try:
            r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            return color
        return f'rgb({r},{g},{b})'

    # ---- save helper ----
    def _save(self, fig, filename):
        """Write ``fig`` to the output directory, close it and report the filename."""
        fig.savefig(os.path.join(self.out_dir, filename),
                    bbox_inches='tight', dpi=self.viz['dpi'],
                    facecolor=self.BG, edgecolor='none')
        plt.close(fig)
        print(f"  saved: {filename}")

    # ---- sector legend (shared across chart types) ----
    def _sector_legend(self, ax, sectors, style='patch', loc='outside_right'):
        """Attach a sector legend to ``ax``.

        Args:
            ax: target matplotlib axes.
            sectors: sector names, in legend order.
            style: ``'patch'`` for filled rectangles, anything else for
                line-with-marker handles.
            loc: ``'outside_right'`` (right of the axes, top-aligned) or
                ``'lower_right'`` (inside the axes); other values use
                matplotlib's default placement.
        """
        if style == 'patch':
            handles = [plt.Rectangle((0, 0), 1, 1, color=self.colors.get(s, '#CCC'))
                       for s in sectors]
        else:
            handles = [plt.Line2D([0], [0], marker='o', color=self.colors.get(s, '#CCC'),
                                  markersize=8, linewidth=2.2) for s in sectors]

        kwargs = dict(handles=handles, labels=list(sectors),
                      title='Sector', title_fontsize=self.FONT_LEGEND,
                      fontsize=self.FONT_LEGEND, frameon=True,
                      edgecolor=self.GRID_COLOR, facecolor=self.BG)
        if loc == 'outside_right':
            kwargs.update(bbox_to_anchor=(1.02, 1), loc='upper left')
        elif loc == 'lower_right':
            kwargs.update(loc='lower right')
        ax.legend(**kwargs)

    # ---- value formatter ----
    @staticmethod
    def _fmt_value(val, is_usd=False):
        """Compact number formatter for annotations.

        Non-finite values give an empty string. USD values are abbreviated to
        B/M with the minus sign placed before the dollar sign (``-$1.2B``).
        """
        if not np.isfinite(val):
            return ''
        if is_usd:
            sign = '-' if val < 0 else ''
            mag = abs(val)
            if mag >= 1e9:
                return f'{sign}${mag/1e9:.1f}B'
            if mag >= 1e6:
                return f'{sign}${mag/1e6:.0f}M'
            return f'{sign}${mag:,.0f}'
        if abs(val) >= 100:
            return f'{val:,.0f}'
        if abs(val) >= 1:
            return f'{val:.1f}'
        return f'{val:.2f}'

    # ---- bar axis geometry ----
    @staticmethod
    def _bar_extent(vmin, vmax):
        """Return ``(left, right, offset)`` for a horizontal bar chart.

        The x-range always contains the zero baseline, with 15 % of the span
        added on each side that carries value labels. ``offset`` is the gap
        between a bar end and its label (1 % of the span).
        """
        lo, hi = min(0.0, float(vmin)), max(0.0, float(vmax))
        span = (hi - lo) or 1.0          # all-zero data: avoid a degenerate axis
        pad = 0.15 * span
        left = lo - pad if vmin < 0 else lo
        right = hi + pad if (vmax > 0 or vmin >= 0) else hi
        return left, right, 0.01 * span

    # ======================================================================
    # BAR CHARTS
    # ======================================================================
    def plot_product_bars(self, top_products, value_col, labels):
        """Draw one bar chart per metric for the top products.

        Args:
            top_products: dict with keys ``'volume'``, ``'rca'``,
                ``'market_share'`` and ``'pci'``, each a DataFrame to plot.
            value_col: value column plotted for the ``'volume'`` chart;
                ``'net_export'`` adds a ``net_`` filename prefix.
            labels: dict of axis labels with the same four keys.
        """
        prefix = 'net_' if value_col == 'net_export' else ''
        is_usd = value_col in ('export_value', 'net_export')
        specs = [
            ('volume',       value_col,      labels['volume'],       f'{prefix}exports_value_product.png',  is_usd),
            ('rca',          'rca',          labels['rca'],          f'{prefix}exports_rca_product.png',    False),
            ('market_share', 'market_share', labels['market_share'], f'{prefix}exports_ms_product.png',     False),
            ('pci',          'pci',          labels['pci'],          f'{prefix}exports_pci_product.png',    False),
        ]
        for cat, xcol, xlab, fname, usd in specs:
            data = top_products[cat]
            if not data.empty:
                self._bar(data, xcol, xlab, fname, is_usd=usd)

    def plot_sector_bars(self, sector_metrics, value_col, labels):
        """Draw one sector-level bar chart per metric (value, RCA, market share, PCI)."""
        prefix = 'net_' if value_col == 'net_export' else ''
        name = self.config['data']['country_name']
        is_usd = value_col in ('export_value', 'net_export')
        for metric, lab, fname, usd in [
            (value_col,      labels['volume'],       f'{prefix}sector_value.png', is_usd),
            ('rca',          labels['rca'],          f'{prefix}sector_rca.png',   False),
            ('market_share', labels['market_share'], f'{prefix}sector_ms.png',    False),
            ('pci',          labels['pci'],          f'{prefix}sector_pci.png',   False),
        ]:
            d = sector_metrics.sort_values(metric, ascending=False)
            self._bar(d, metric, lab, fname, is_usd=usd, y_col='sector_name',
                      title=f"{name} Sectors by {lab}")

    def _bar(self, data, x, xlabel, filename, is_usd=False,
             y_col='product_name', title=None):
        """Draw and save one horizontal bar chart coloured by sector.

        Args:
            data: DataFrame with columns ``x``, ``y_col`` and ``sector_name``.
            x: column holding the bar lengths.
            xlabel: x-axis label, also used in the default title.
            filename: output filename inside the output directory.
            is_usd: format value labels as US dollars.
            y_col: column holding the bar labels.
            title: chart title; defaults to ``"Top {n} by {xlabel}"``.

        Nothing is drawn when no finite value remains.
        """
        # filter non-finite values to avoid posx/posy warnings
        data = data[np.isfinite(data[x])].copy()
        if data.empty:
            return

        data = data.sort_values(x, ascending=True)  # ascending for horizontal bars
        n = len(data)
        fig_h = max(4, 0.55 * n + 1.8)
        fig, ax = plt.subplots(figsize=(12, fig_h))

        colors = [self.colors.get(s, '#CCC') for s in data['sector_name']]
        y_pos = np.arange(n)

        ax.barh(y_pos, data[x].values, color=colors, height=0.65,
                edgecolor='white', linewidth=0.5, zorder=3)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(data[y_col].values, fontsize=self.FONT_TICK)

        # Axis range and label gap are derived from the full value span so that
        # negative bars (net exports, PCI) keep the zero baseline in view.
        left, right, offset = self._bar_extent(data[x].min(), data[x].max())

        # value annotations outside bars: right of positive bars, left of negative ones
        for i, val in enumerate(data[x].values):
            label = self._fmt_value(val, is_usd=is_usd)
            if val < 0:
                x_text, align = val - offset, 'right'
            else:
                x_text, align = val + offset, 'left'
            ax.text(x_text, i, label,
                    va='center', ha=align, fontsize=self.FONT_ANNOT,
                    color=self.ANNOT_COLOR, fontweight='medium')

        ax.set_xlim(left=left, right=right)

        ax.set_xlabel(xlabel, fontsize=self.FONT_AXIS, color=self.TEXT_DARK)
        ax.set_title(title or f"Top {n} by {xlabel}",
                     fontsize=self.FONT_TITLE, fontweight='semibold',
                     color=self.TEXT_DARK, pad=12)
        ax.tick_params(axis='y', length=0)
        ax.grid(axis='x', alpha=0.2)
        ax.grid(axis='y', visible=False)

        sectors = data['sector_name'].unique()
        self._sector_legend(ax, sectors, loc='lower_right')

        fig.tight_layout()
        self._save(fig, filename)

    # ======================================================================
    # TIME TRENDS
    # ======================================================================
    def plot_time_trends(self, sector_metrics, value_col):
        """Draw one line chart per rolling metric, with one line per sector.

        ``sector_metrics`` needs ``sector_name``, ``year`` and the four
        ``rolling_*`` columns. ``'net_export'`` adds a ``net_`` filename prefix.
        """
        # FIX: prefix ALL filenames for net export to avoid overwriting gross
        prefix = 'net_' if value_col == 'net_export' else ''
        rw = self.config['analysis']['rolling_window']
        sectors = sorted(sector_metrics['sector_name'].unique())

        for ycol, title, ylab, fname in [
            ('rolling_value',        'Export Volume',      value_col.replace('_', ' ').title(),
             f'{prefix}export_value_trends.png'),
            ('rolling_pci',          'Product Complexity', 'Weighted PCI',
             f'{prefix}pci_trends.png'),
            ('rolling_rca',          'Revealed Comparative Advantage', 'RCA Score',
             f'{prefix}rca_trends.png'),
            ('rolling_market_share', 'Global Market Share', 'Market Share (%)',
             f'{prefix}market_share_trends.png'),
        ]:
            fig, ax = plt.subplots(figsize=(14, 6.5))

            for sector in sectors:
                sd = sector_metrics[sector_metrics['sector_name'] == sector].sort_values('year')
                c = self.colors.get(sector, '#CCC')
                ax.plot(sd['year'], sd[ycol], color=c, linewidth=2.2,
                        marker='o', markersize=4, markeredgecolor='white',
                        markeredgewidth=0.8, label=sector, zorder=3)
                ax.fill_between(sd['year'], sd[ycol], alpha=0.04, color=c, zorder=1)

            ax.set_title(f'{title} ({rw}-Year Rolling Average)',
                         fontsize=self.FONT_TITLE, fontweight='semibold',
                         color=self.TEXT_DARK, pad=12)
            ax.set_ylabel(ylab, fontsize=self.FONT_AXIS, color=self.TEXT_DARK)
            ax.set_xlabel('')
            ax.tick_params(axis='both', labelsize=self.FONT_TICK)

            self._sector_legend(ax, sectors, style='line', loc='outside_right')

            fig.tight_layout()
            self._save(fig, fname)

    # ======================================================================
    # OPPORTUNITY SCATTER
    # ======================================================================
    def plot_opportunities(self, top_opps, strategy_titles):
        """Draw one scatter plot per strategy that has at least one row."""
        for strat, data in top_opps.items():
            if data.empty:
                continue
            self._opp_scatter(data, strategy_titles[strat])

    def _opp_scatter(self, data, title):
        """Draw and save the distance-vs-complexity bubble chart for one strategy.

        Bubble area scales with ``global_export_value``; the file is named
        after ``title`` with spaces replaced by underscores.
        """
        fig, ax = plt.subplots(figsize=(13, 8))

        # size scaling
        gmin, gmax = data['global_export_value'].min(), data['global_export_value'].max()
        rng = gmax - gmin if gmax > gmin else 1
        sizes = 120 + (data['global_export_value'] - gmin) / rng * 1400
        # products without a world export total still get the smallest bubble
        sizes = sizes.fillna(120)

        for sector in data['sector_name'].unique():
            sd = data[data['sector_name'] == sector]
            idx = sd.index
            ax.scatter(sd['norm_distance'], sd['norm_pci'],
                       s=sizes.loc[idx], color=self.colors.get(sector, '#CCC'),
                       alpha=0.65, edgecolors='white', linewidth=0.8, zorder=3)

        # labels
        texts = []
        for _, row in data.iterrows():
            texts.append(ax.text(
                row['norm_distance'], row['norm_pci'],
                row['product_name'], fontsize=8, color=self.TEXT_DARK,
                ha='center', va='center',
                bbox=dict(facecolor='white', edgecolor='none',
                          alpha=0.85, pad=1.2, boxstyle='round,pad=0.3')
            ))
        try:
            adjust_text(texts, arrowprops=dict(arrowstyle='-', color=self.TEXT_LIGHT, lw=0.6),
                        expand_points=(1.3, 1.5), expand_text=(1.2, 1.4),
                        force_text=(0.5, 0.8))
        except Exception as exc:
            # Label de-overlapping is cosmetic: keep the chart, but report why
            # labels may collide instead of failing silently.
            print(f"  WARNING: label adjustment skipped ({exc})")

        ax.set_title(title, fontsize=self.FONT_TITLE, fontweight='semibold',
                     color=self.TEXT_DARK, pad=14)
        ax.set_xlabel('Normalized Distance (lower = closer to current exports)',
                      fontsize=self.FONT_AXIS, color=self.TEXT_DARK)
        ax.set_ylabel('Normalized Product Complexity',
                      fontsize=self.FONT_AXIS, color=self.TEXT_DARK)

        sectors = data['sector_name'].unique()
        self._sector_legend(ax, sectors, style='patch', loc='outside_right')

        fig.tight_layout()
        fname = f"{title.replace(' ', '_')}.png"
        self._save(fig, fname)

    # ======================================================================
    # TREEMAP
    # ======================================================================
    def create_treemap(self, country_df, year):
        """Write an interactive export-composition treemap for ``year`` as HTML.

        Does nothing when ``config['output']['create_treemap']`` is false, and
        prints a warning when ``country_df`` has no rows for ``year``.
        """
        if not self.config['output']['create_treemap']:
            return
        dy = country_df[country_df['year'] == year].copy()
        if dy.empty:
            print(f"Warning: no data for treemap year {year}")
            return
        dy = dy.dropna(subset=['export_value'])

        stot = dy.groupby('sector_name')['export_value'].sum().reset_index()
        total = stot['export_value'].sum()
        stot['sector_share'] = (stot['export_value'] / total) * 100

        dm = dy.merge(stot[['sector_name', 'sector_share']], on='sector_name')
        dm['product_share_in_sector'] = (
            dm['export_value'] /
            dm.groupby('sector_name')['export_value'].transform('sum')
        ) * 100
        dm['overall_share'] = (dm['export_value'] / total) * 100
        dm['parent'] = dm['sector_name']

        # convert hex palette to rgb() strings for plotly
        color_map = {s: self._plotly_color(c) for s, c in self.colors.items()}

        fig = px.treemap(
            dm, path=['parent', 'product_name'], values='export_value',
            color='sector_name', color_discrete_map=color_map,
            hover_data={
                'sector_share': ':.1f%',
                'product_share_in_sector': ':.1f%',
                'overall_share': ':.1f%',
                'export_value': ':,',
            },
        )
        name = self.config['data']['country_name']
        fig.update_layout(
            title={'text': f'{year} {name} Export Composition',
                   'y': 0.95, 'x': 0.5, 'xanchor': 'center', 'yanchor': 'top',
                   'font': {'size': 20, 'color': self.TEXT_DARK, 'family': 'Arial'}},
            margin=dict(t=100, l=0, r=0, b=0),
            paper_bgcolor=self.BG,
            font={'family': 'Arial', 'color': self.TEXT_DARK},
        )
        fname = f'exports_{year}_treemap.html'
        fig.write_html(os.path.join(self.out_dir, fname))
        print(f"  saved: {fname}")

In [8]:
# ============================================================================
# MAIN PIPELINE
# ============================================================================

class TradeAnalysisPipeline:
    """Orchestrates the full analysis workflow.

    Wires together the data loader, metrics calculator, opportunity analyzer
    and visualizer, and drives them in order from :meth:`run`.
    """

    def __init__(self, config):
        """Create the pipeline components from ``config``.

        Parameters
        ----------
        config : dict
            Nested configuration dictionary. This class reads the ``data``,
            ``analysis``, ``advanced``, ``opportunities`` and ``output``
            sections; the same dict is handed to every component.
        """
        self.config  = config
        self.loader  = TradeDataLoader(config)
        self.metrics = MetricsCalculator(config)
        self.opps    = OpportunityAnalyzer(config)
        # The visualizer needs the sector colours returned by the loader, so
        # it is only created inside run(); define the attribute up front so it
        # always exists on the instance.
        self.viz     = None

    # ------------------------------------------------------------------
    def run(self):
        """Execute the whole analysis and write every configured output.

        Steps: load data, analyse each configured value metric, optionally
        run the opportunity analysis (plots + CSV export), then build the
        treemap. Progress is reported with ``print``. Returns ``None``.
        """
        print("=" * 70)
        print(f"TRADE ANALYSIS: {self.config['data']['country_name']}")
        print("=" * 70)

        country_df, global_df, sector_mapping, sector_colors = self.loader.load_data()
        self.viz = Visualizer(self.config, sector_colors)

        base_year    = self.config['analysis']['base_year']
        value_metrics = self.config['analysis']['value_metrics']

        # NOTE(review): merge() without `on=` joins on every column the two
        # frames share - presumably only the product identifier; verify
        # against the loader's sector_mapping schema.
        country_year = country_df[country_df['year'] == base_year].merge(sector_mapping)
        global_year  = global_df[global_df['year'] == base_year].merge(sector_mapping)

        for vcol in value_metrics:
            self._analyze_metric(country_df, global_df, country_year,
                                 global_year, vcol, base_year)

        # opportunity analysis (gross exports only)
        if self.config['opportunities']['enabled']:
            print(f"\n{'='*70}\nNEW EXPORT OPPORTUNITIES\n{'='*70}")
            opps, top = self.opps.analyze_opportunities(country_year, global_year)
            if not opps.empty:
                titles = {
                    'lhf': 'Low Hanging Fruit Strategy',
                    'bs':  'Balanced Strategy',
                    'lj':  'Long Jumps Strategy',
                }
                self.viz.plot_opportunities(top, titles)
                if self.config['output']['export_csv']:
                    self._export_csv(top)

        treemap_yr = self.config['output']['treemap_year']
        self.viz.create_treemap(country_df, treemap_yr)

        print(f"\n{'='*70}\nANALYSIS COMPLETE\n{'='*70}")
        print(f"Outputs in: {self.config['output']['directory']}")

    # ------------------------------------------------------------------
    def _analyze_metric(self, country_df, global_df, country_year,
                        global_year, value_col, base_year):
        """Produce product, sector and time-series charts for one metric.

        Parameters
        ----------
        country_df, global_df : pandas.DataFrame
            Multi-year frames, forwarded to the time-series calculation.
        country_year, global_year : pandas.DataFrame
            Base-year slices already merged with the sector mapping.
        value_col : str
            Name of the value column to analyse; also used (title-cased,
            underscores replaced by spaces) in headings and axis labels.
        base_year : int
            Year forwarded to ``prepare_global_reference``.
        """
        name = value_col.replace('_',' ').title()
        print(f"\n{'='*70}\nANALYZING: {name}\n{'='*70}")

        gref = self.metrics.prepare_global_reference(
            country_year, global_year, base_year, value_col)

        pmets = self.metrics.calculate_product_metrics(
            country_year, gref, value_col)

        top = self._top_products(pmets, value_col)
        labels = {
            'volume':       f'{name} (USD)',
            'rca':          'Revealed Comparative Advantage',
            'market_share': 'Global Market Share (%)',
            'pci':          'Product Complexity Index',
        }
        self.viz.plot_product_bars(top, value_col, labels)

        # sector aggregation
        smets = self._aggregate_sectors(pmets, value_col)
        self.viz.plot_sector_bars(smets, value_col, labels)

        # time series
        ts = self.metrics.calculate_sector_metrics_timeseries(
            country_df, global_df, value_col)
        self.viz.plot_time_trends(ts, value_col)

    # ------------------------------------------------------------------
    def _aggregate_sectors(self, pmets, value_col):
        """Collapse product-level metrics to one row per sector.

        ``value_col`` and ``global_value`` are summed, ``rca`` and
        ``pci_percentile`` are plain means, and ``pci`` is an average
        weighted by ``value_col`` clipped at zero. ``market_share`` is the
        sector value as a percentage of the sector's global value.

        Returns a DataFrame with ``sector_name`` as a regular column.
        """
        # Negative values (possible for net measures) cannot act as weights.
        weights = pmets[value_col].clip(lower=0)

        def safe_wpci(grp):
            # Drop products with a missing PCI or a missing weight: np.average
            # propagates NaN, so one gap would otherwise blank the sector.
            w = weights.loc[grp.index]
            usable = grp.notna() & w.notna()
            if not usable.any():
                return np.nan
            w = w[usable]
            if w.sum() == 0:
                # No positive weight left - the weighted mean is undefined.
                return np.nan
            return np.average(grp[usable], weights=w)

        smets = pmets.groupby('sector_name').agg({
            value_col:       'sum',
            'global_value':  'sum',
            'rca':           'mean',
            'pci':           safe_wpci,
            'pci_percentile':'mean',
        }).reset_index()
        # A zero global total would yield +/-inf; treat it as missing instead.
        smets['market_share'] = (
            smets[value_col] / smets['global_value'].replace(0, np.nan)
        ) * 100
        return smets

    # ------------------------------------------------------------------
    def _top_products(self, pmets, value_col):
        """Return the top-N product tables used by the product bar charts.

        N is ``config['analysis']['top_n']``. The result is a dict with keys
        ``volume``, ``rca``, ``market_share`` and ``pci``; each value holds
        the N largest rows by the relevant column, sorted descending. The
        ``rca`` table only considers products with ``rca > 1`` and the
        ``pci`` table only products whose ``value_col`` exceeds
        ``config['advanced']['min_export_threshold']``. An empty input or a
        missing ranking column yields an empty DataFrame for that key.
        """
        n   = self.config['analysis']['top_n']
        thr = self.config['advanced']['min_export_threshold']
        def top(df, col):
            if df.empty or col not in df.columns:
                return pd.DataFrame()
            return df.sort_values(col, ascending=False).head(n).reset_index(drop=True)
        return {
            'volume':       top(pmets, value_col),
            'rca':          top(pmets[pmets.rca > 1], 'rca'),
            'market_share': top(pmets, 'market_share'),
            'pci':          top(pmets[pmets[value_col] > thr], 'pci'),
        }

    # ------------------------------------------------------------------
    def _export_csv(self, top_opps):
        """Write one ``opportunities_<strategy>.csv`` per non-empty strategy.

        Parameters
        ----------
        top_opps : dict[str, pandas.DataFrame]
            Maps a strategy key to its ranked opportunities. Each frame must
            contain ``product_name``, ``sector_name``, ``pci``, ``distance``,
            ``cog`` and a score column named after the strategy key.
        """
        out = self.config['output']['directory']
        # Make sure the target exists so to_csv does not fail on a fresh setup.
        os.makedirs(out, exist_ok=True)
        for strat, data in top_opps.items():
            if data.empty:
                continue
            cols = ['product_name','sector_name','pci','distance','cog', strat]
            fname = f'opportunities_{strat}.csv'
            data[cols].to_csv(os.path.join(out, fname), index=False)
            print(f"  saved: {fname}")


In [9]:
# ============================================================================
# RUN
# ============================================================================

# Build the pipeline from the notebook-level CONFIG dict and execute every
# stage. run() prints its progress and, on completion, the output directory.
pipeline = TradeAnalysisPipeline(CONFIG)
pipeline.run()


TRADE ANALYSIS: Italy
Loading data from: /Users/leoss/Desktop/Portfolio/Website-/projects/export/data/atlas_2022.dta


KeyboardInterrupt: 

## Summary

**Methods:** The pipeline loads Atlas of Economic Complexity data (HS-92, 4-digit product level) for Italy and all other countries. It computes product-level and sector-level metrics (export volume, net exports, RCA, global market share, PCI) for the base year (2022), and time-series trends (5-year rolling averages, 1995-2022). The opportunity analysis identifies products where Italy currently lacks comparative advantage (RCA < 1) but which exceed a minimum complexity threshold (PCI > 1.49), scoring them under three strategies with different distance/complexity/opportunity-gain weights.

**Strategy formulas:**
- Low-hanging fruit: 0.75 * proximity + 0.10 * normalized PCI + 0.15 * normalized COG
- Balanced: 0.50 * proximity + 0.25 * normalized PCI + 0.25 * normalized COG
- Long jumps (strategic bets): 0.40 * proximity + 0.20 * normalized PCI + 0.40 * normalized COG

**Outputs:** 20+ PNG charts (product bars, sector bars, time trends, opportunity scatters), 3 opportunity CSVs, 1 interactive treemap (HTML). All saved to the configured output directory.

**Key fixes in this version:**
1. Strategy weights aligned to report formulas (pci/cog were swapped in LHF; balanced and long jumps had entirely different weights)
2. Product names resolved via external CSV lookup (1,241 products) instead of 30-entry manual dict
3. Distance reversal uses standard `1 - norm` formula instead of non-standard negative offset
4. PCI weighting in time series uses `export_value` (never negative) instead of clipped net exports, avoiding the zero-weight edge cases that clipping produces
5. Sector mapping filtered to base year to prevent merge duplicates
