## UN Comtrade (Comtrade+) -> features mensais para Ureia (HS 310210)

Features típicas:
- China exportações (oferta)
- Índia importações (demanda / tenders) — obtidas via WITS, pois a consulta mensal no Comtrade não retornou dados
- (opcional) Brasil importações (proxy de demanda local)

Fonte/API:
- Data API endpoint: https://comtradeapi.un.org/data/v1/get/C/{freq}/HS  (freq: A ou M)
- Anual usa period=`YYYY` e mensal usa period=`YYYYMM`; ambos podem filtrar flowCode (`X`=Export, `M`=Import)

## Importando principais bibliotecas
---

In [None]:
import os
import time
from dotenv import load_dotenv
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, List

import comtradeapicall
import pandas as pd
import requests

load_dotenv()  # Load variables from a local .env file into the process environment (e.g. the Comtrade subscription key).

### Configurações
---

In [None]:
def _safe_mkdir(path: str) -> None:
    os.makedirs(path, exist_ok=True)

# Numeric country codes used as Comtrade reporter/partner codes.
CHN = 156  # China
IND = 356  # India
BRA = 76  # Brazil
WORLD = 0  # "World" partner

# HS commodity code for urea.
HS_UREA = "310210"

# Local storage: "data" and "data/comtrade" (per-month CSV cache),
# resolved to absolute paths from the current working directory.
DATA_DIR = os.path.abspath("data")
COMTRADE_DIR = os.path.join(DATA_DIR, "comtrade")

# Create both directories at import time, parent first.
_safe_mkdir(DATA_DIR)
_safe_mkdir(COMTRADE_DIR)

@dataclass(frozen=True)
class ComtradeQuery:
    """
    Immutable parameters describing one UN Comtrade data request.

    NOTE(review): fetch_comtrade_yearly / fetch_comtrade_monthly currently
    forward only reporter_code, cmd_code, freq, flow_code and max_records to
    the API; partner_code, partner2_code, customs_code, mot_code and
    include_desc are not sent (those calls pass None / True instead).
    """

    reporter_code: int  # numeric reporter code (e.g. CHN, IND, BRA)
    # Optional: callers in this module pass None to mean "no partner filter".
    partner_code: Optional[int] = WORLD
    cmd_code: str = HS_UREA  # HS commodity code
    freq: str = "M"  # "A" (annual) or "M" (monthly)
    flow_code: Optional[str] = None  # "X" (export), "M" (import) or None (no flow filter)
    customs_code: str = "C00"
    mot_code: int = 0
    partner2_code: int = 0
    include_desc: bool = True
    max_records: int = 5000  # intended to be enough for 1 cmd + world + 1 reporter

In [None]:
def fetch_wits_data(reporter: str, product: str, start_year: int, end_year: int, flow: str = 'M') -> pd.DataFrame:
    """
    Download trade data for one product from the WITS API, one request per year.

    Parameters
    ----------
    reporter : ISO3 code of the reporting country (e.g. 'IND' for India).
    product : HS product code (e.g. '310210' for urea).
    start_year, end_year : inclusive year range.
    flow : 'M' for imports or 'X' for exports (case-insensitive).

    Returns
    -------
    DataFrame concatenating every non-empty yearly JSON-list response, or an
    empty DataFrame when no year returned data (including an empty year range).

    Raises
    ------
    ValueError : if ``flow`` is neither 'M' nor 'X'.

    Network, HTTP and JSON-decoding failures for a given year are printed and
    that year is skipped, so one bad year does not abort the whole download.
    """
    flow_norm = str(flow).strip().upper()
    if flow_norm not in ("M", "X"):
        # Reject unknown values instead of silently treating them as Export.
        raise ValueError(f"flow deve ser 'M' (import) ou 'X' (export), recebido: {flow!r}")
    trade_flow = 'Import' if flow_norm == 'M' else 'Export'

    years = range(start_year, end_year + 1)
    if not years:
        return pd.DataFrame()

    base_url = "https://wits.worldbank.org/API/V1/wits/WITSApiService.svc/data"
    dfs = []

    # A single session reuses the HTTP connection across the yearly requests.
    with requests.Session() as session:
        for year in years:
            params = {
                'Reporter': reporter,
                'Partner': 'WLD',  # World
                'Product': product,
                'Year': str(year),
                'TradeFlow': trade_flow,
                'Format': 'JSON',
            }

            try:
                response = session.post(base_url, params=params, timeout=30)
                response.raise_for_status()
                data = response.json()
            except (requests.RequestException, ValueError) as e:
                # Best effort per year: report and continue with the next one.
                print(f"Erro ao buscar dados para {year}: {e}")
                continue

            if isinstance(data, list) and data:
                dfs.append(pd.DataFrame(data))
            else:
                print(f"Nenhum dado para {year}")

    return pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()

## Métodos auxiliares
---

In [None]:
def _get_env_subscription_key() -> Optional[str]:
    # Sugestão: export COMTRADE_SUBSCRIPTION_KEY="xxxx"
    return (
        os.getenv("COMTRADE_SUBSCRIPTION_KEY")
        or os.getenv("COMTRADE_PRIMARY")
        or os.getenv("COMTRADE_KEY")
    )


def _request_with_retries(
    url: str,
    params: dict,
    subscription_key: str,
    timeout: int = 60,
    max_retries: int = 6,
    backoff_base: float = 1.5,
) -> dict:
    """
    UN Comtrade API aceita subscription key por:
    - query param: subscription-key
    - ou header: Ocp-Apim-Subscription-Key
    """
    headers = {"Ocp-Apim-Subscription-Key": subscription_key}

    # Também envia no query string (exemplos da própria documentação)
    params = dict(params)
    params["subscription-key"] = subscription_key

    last_err = None
    for attempt in range(max_retries):
        try:
            r = requests.get(url, params=params, headers=headers, timeout=timeout)
            # rate limit / temporários
            if r.status_code in (429, 500, 502, 503, 504):
                raise requests.HTTPError(f"{r.status_code}: {r.text}", response=r)

            r.raise_for_status()
            data = r.json()

            # Alguns retornos vêm com campo "error"
            if isinstance(data, dict) and data.get("error"):
                raise RuntimeError(f"Comtrade API error: {data.get('error')}")
            return data

        except Exception as e:
            last_err = e
            sleep_s = (backoff_base ** attempt) + (0.1 * attempt)
            time.sleep(sleep_s)

    raise RuntimeError(f"Falha após {max_retries} tentativas. Último erro: {last_err}")


def fetch_comtrade_yearly(
    query: ComtradeQuery,
    start_year: str,
    end_year: str,
    subscription_key: Optional[str] = None,
    columns_to_keep: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Download annual Comtrade data (freq="A"), one API call per year (period=YYYY).

    Parameters
    ----------
    query : request description; must have ``freq == "A"``. Only
        ``reporter_code``, ``cmd_code``, ``flow_code`` and ``max_records`` are
        forwarded; partner/partner2/customs/mode filters are sent as None and
        descriptions are always requested.
    start_year, end_year : inclusive year range (str or int; converted with int()).
    subscription_key : API key; when falsy, ``_get_env_subscription_key()`` is used.
    columns_to_keep : optional column subset, applied when at least one year
        returned data.

    Returns
    -------
    Long DataFrame with the API columns for all years concatenated; an empty
    DataFrame when the year range is empty.

    Raises
    ------
    ValueError : if ``query.freq != "A"`` or no subscription key is available.
    """
    if query.freq != "A":
        raise ValueError("fetch_comtrade_yearly exige query.freq == 'A'")

    key = subscription_key or _get_env_subscription_key()
    if not key:
        raise ValueError(
            "Subscription key não encontrada. Defina COMTRADE_SUBSCRIPTION_KEY no ambiente."
        )

    years = [str(y) for y in range(int(start_year), int(end_year) + 1)]
    if not years:
        # pd.concat([]) would raise "No objects to concatenate".
        return pd.DataFrame()

    frames = []
    for yyyy in years:
        year_df = comtradeapicall.getFinalData(
            subscription_key=key,
            typeCode='C',
            freqCode='A',
            clCode='HS',
            period=yyyy,
            reporterCode=query.reporter_code,
            cmdCode=query.cmd_code,
            flowCode=query.flow_code,
            partnerCode=None,
            partner2Code=None,
            customsCode=None,
            motCode=None,
            maxRecords=query.max_records,
            format_output='JSON',
            aggregateBy=None,
            breakdownMode='classic',
            countOnly=None,
            includeDesc=True,
        )
        if year_df is None:
            # Guard in case getFinalData returns None (presumably on a failed call):
            # treat the year as "no data" instead of crashing below.
            year_df = pd.DataFrame()
        frames.append(year_df)

    df = pd.concat(frames, ignore_index=True)
    if columns_to_keep and not all(frame.empty for frame in frames):
        df = df[columns_to_keep]

    return df


def fetch_comtrade_monthly(
    query: ComtradeQuery,
    start_year: str,
    end_year: str,
    subscription_key: Optional[str] = None,
    columns_to_keep: Optional[List[str]] = None,
    print_progess: bool = True,
) -> pd.DataFrame:
    """
    Download monthly Comtrade data (freq="M"), one API call per month.

    Requests use period=YYYYMM, and each month's result is cached as a CSV
    file under COMTRADE_DIR.

    Parameters
    ----------
    query : request description; must have ``freq == "M"``. Only
        ``reporter_code``, ``cmd_code``, ``flow_code`` and ``max_records`` are
        forwarded; partner/partner2/customs/mode filters are sent as None and
        descriptions are always requested.
    start_year, end_year : inclusive year range; all 12 months of every year
        are requested.
    subscription_key : API key; when falsy, ``_get_env_subscription_key()`` is used.
    columns_to_keep : optional column subset, applied when at least one month
        returned data.
    print_progess : print the cache path of every processed month (the
        misspelled name is kept for backward compatibility).

    Returns
    -------
    Long DataFrame with the API columns for all months concatenated; an empty
    DataFrame when the year range is empty.

    Raises
    ------
    ValueError : if ``query.freq != "M"`` or no subscription key is available.

    NOTE(review): the cache file name encodes only reporter_code and the
    month, so queries differing in cmd_code or flow_code for the same reporter
    reuse each other's cached files. A month cached while empty (possibly not
    yet published) is never re-downloaded; delete its CSV to refresh it.
    """
    if query.freq != "M":
        raise ValueError("fetch_comtrade_monthly exige query.freq == 'M'")

    key = subscription_key or _get_env_subscription_key()
    if not key:
        raise ValueError(
            "Subscription key não encontrada. Defina COMTRADE_SUBSCRIPTION_KEY no ambiente."
        )

    year_months = [
        f"{year}{month:02d}"
        for year in range(int(start_year), int(end_year) + 1)
        for month in range(1, 13)
    ]
    if not year_months:
        # pd.concat([]) would raise "No objects to concatenate".
        return pd.DataFrame()

    frames = []
    for yyyymm in year_months:
        file_path = os.path.join(COMTRADE_DIR, f"{query.reporter_code}_comtrade_{yyyymm}.csv")
        if os.path.exists(file_path):
            try:
                month_df = pd.read_csv(file_path)
            except pd.errors.EmptyDataError:
                # A cached CSV without columns cannot be parsed: treat as no data.
                month_df = pd.DataFrame()
        else:
            month_df = comtradeapicall.getFinalData(
                subscription_key=key,
                typeCode='C',
                freqCode='M',
                clCode='HS',
                period=yyyymm,
                reporterCode=query.reporter_code,
                cmdCode=query.cmd_code,
                flowCode=query.flow_code,
                partnerCode=None,
                partner2Code=None,
                customsCode=None,
                motCode=None,
                maxRecords=query.max_records,
                format_output='JSON',
                aggregateBy=None,
                breakdownMode='classic',
                countOnly=None,
                includeDesc=True,
            )
            if month_df is None:
                # Guard in case getFinalData returns None (presumably on a failed
                # call): do not cache it, so the month is fetched again next run.
                month_df = pd.DataFrame()
            else:
                month_df.to_csv(file_path, index=False)

        if print_progess:
            print(f"arquivo carregado/baixado em {file_path}")

        frames.append(month_df)

    df = pd.concat(frames, ignore_index=True)
    if columns_to_keep and not all(frame.empty for frame in frames):
        df = df[columns_to_keep]

    return df

def _to_tonnes(df: pd.DataFrame) -> pd.Series:
    """
    Preferência: netWgt (kg). Fallback: qty se unidade for kg.
    Retorna tonelada métrica (t).
    """
    netwgt = pd.to_numeric(df.get("netWgt"), errors="coerce")  # kg
    qty = pd.to_numeric(df.get("qty"), errors="coerce")
    unit = df.get("qtyUnitAbbr")

    tonnes = netwgt / 1000.0

    # fallback: qty em kg
    if unit is not None:
        unit_is_kg = unit.astype(str).str.lower().isin(["kg", "kilogram", "kilograms"])
        tonnes = tonnes.where(~tonnes.isna(), qty.where(unit_is_kg) / 1000.0)

    return tonnes


def agg(df: pd.DataFrame, prefix: str) -> pd.DataFrame:
    """
    Aggregate long Comtrade rows into one row per (date, flowDesc).

    Parameters
    ----------
    df : rows as returned by the fetch helpers. They must include ``flowDesc``
        and either ``refPeriodId`` (YYYYMMDD) or ``period`` (YYYYMM).
    prefix : prefix for the output columns (e.g. "CHN" -> "CHN_tonnes").

    Returns
    -------
    DataFrame with the following columns:

    - ``date`` and ``flowDesc``;
    - ``{prefix}_tonnes`` and ``{prefix}_value_usd`` (sums of the tonnes
      computed by ``_to_tonnes`` and of ``primaryValue``);
    - ``{prefix}_<col>`` for each of reporterDesc / cmdDesc / qtyUnitAbbr /
      cifvalue present in ``df``. Numeric columns are summed; the others keep
      their first value.

    An empty input yields an empty frame with just ``date`` and ``flowDesc``,
    so it can still be merged on those keys.

    Raises
    ------
    ValueError : if neither ``refPeriodId`` nor ``period`` is present.
    """
    key_cols = ["date", "flowDesc"]
    if df.empty:
        return pd.DataFrame(columns=key_cols)

    df = df.copy()

    # Build the date column from whichever period column is available.
    if "refPeriodId" in df.columns:
        df["date"] = pd.to_datetime(df["refPeriodId"].astype(str), format="%Y%m%d", errors="coerce")
    elif "period" in df.columns:
        df["date"] = pd.to_datetime(df["period"].astype(str), format="%Y%m", errors="coerce")
    else:
        raise ValueError("No suitable date column found (e.g., 'refPeriodId' or 'period')")

    # partnerCode 0 is the "World" partner (see the WORLD constant). In a
    # (date, flow) group that holds both the World row and individual-partner
    # rows, the World row is already the total; summing all of them double
    # counts. Keep only the World row(s) in such groups and leave the other
    # groups untouched.
    # NOTE(review): confirm against raw API output that World rows are
    # returned alongside partner rows when partnerCode is not filtered.
    if "partnerCode" in df.columns:
        is_world = pd.to_numeric(df["partnerCode"], errors="coerce").eq(0)
        # String keys so that NaT / NaN keys still form their own groups.
        group_keys = [df["date"].astype(str), df["flowDesc"].astype(str)]
        group_has_world = is_world.groupby(group_keys).transform("any").astype(bool)
        df = df.loc[is_world | ~group_has_world].copy()

    df["tonnes"] = _to_tonnes(df)
    df["value_usd"] = pd.to_numeric(df.get("primaryValue"), errors="coerce")

    agg_spec = {
        f"{prefix}_tonnes": ("tonnes", "sum"),
        f"{prefix}_value_usd": ("value_usd", "sum"),
    }
    # Descriptive columns: skipped when absent (e.g. removed via columns_to_keep).
    for col in ["reporterDesc", "cmdDesc", "qtyUnitAbbr", "cifvalue"]:
        if col not in df.columns:
            continue
        how = "sum" if pd.api.types.is_numeric_dtype(df[col]) else "first"
        agg_spec[f"{prefix}_{col}"] = (col, how)

    return df.groupby(key_cols, as_index=False).agg(**agg_spec)


def build_urea_trade_features(
    start_yyyy: str,
    end_yyyy: str,
    subscription_key: Optional[str] = None,
    include_brazil: bool = True,
    columns_to_keep: Optional[List[str]] = None,
    print_progess: bool = True,
) -> pd.DataFrame:
    """
    Build monthly urea trade features from UN Comtrade.

    Produced columns (one row per date and flowDesc, via ``agg``):

    - ``CHN_*``: China as reporter, no partner or flow filter.
    - ``BRA_*`` (when ``include_brazil``): Brazil as reporter, same filters.

    India is not fetched here: its monthly Comtrade query returned no data,
    and WITS (``fetch_wits_data``) is used for India instead.

    Parameters
    ----------
    start_yyyy, end_yyyy : inclusive year range passed to ``fetch_comtrade_monthly``.
    subscription_key : Comtrade API key (falls back to the environment).
    include_brazil : also fetch Brazil and outer-merge it on (date, flowDesc).
    columns_to_keep : forwarded to ``fetch_comtrade_monthly``.
    print_progess : forwarded to ``fetch_comtrade_monthly``.

    Returns
    -------
    Feature DataFrame, sorted by ``date`` when Brazil is included. If one
    reporter has no data, the other reporter's features are returned unmerged.
    """
    merge_keys = ["date", "flowDesc"]

    df_chn = fetch_comtrade_monthly(
        ComtradeQuery(reporter_code=CHN, partner_code=None, flow_code=None, freq="M"),
        start_yyyy,
        end_yyyy,
        subscription_key=subscription_key,
        columns_to_keep=columns_to_keep,
        print_progess=print_progess,
    )
    feats = agg(df_chn, "CHN")

    if include_brazil:
        df_bra = fetch_comtrade_monthly(
            ComtradeQuery(reporter_code=BRA, partner_code=None, flow_code=None, freq="M"),
            start_yyyy,
            end_yyyy,
            subscription_key=subscription_key,
            columns_to_keep=columns_to_keep,
            print_progess=print_progess,
        )
        feat_bra = agg(df_bra, "BRA")

        # An empty aggregate may lack the merge keys or have an object-typed
        # "date", so only merge when both sides actually have rows.
        if feats.empty:
            feats = feat_bra
        elif not feat_bra.empty:
            feats = feats.merge(feat_bra, on=merge_keys, how="outer")
        feats = feats.sort_values("date")

    # NOTE: feats is long (one row per date AND flowDesc), so derived MoM/YoY
    # features must be computed per flowDesc, e.g.:
    # feats.groupby("flowDesc")[col].pct_change()  /  .pct_change(12)

    return feats


def merge_with_urea_price(df_urea_price: pd.DataFrame, df_trade_feat: pd.DataFrame) -> pd.DataFrame:
    """
    Left-join trade features onto a monthly urea price table.

    Both inputs get their ``date`` column normalized to the first day of its
    month, so e.g. 2023-01-15 and 2023-01-01 match. Every price row is kept.
    It is repeated once per matching trade row; features from
    ``build_urea_trade_features`` have one row per date and flowDesc. A price
    row gets NaN features when no trade row matches. Inputs are not modified.

    Parameters
    ----------
    df_urea_price : needs a ``date`` column parseable by ``pd.to_datetime``
        plus the price column(s) (e.g. ``urea_price_usd_t``).
    df_trade_feat : trade features with a ``date`` column.
    """
    def _month_start(dates: pd.Series) -> pd.Series:
        return pd.to_datetime(dates).dt.to_period("M").dt.to_timestamp()

    prices = df_urea_price.assign(date=_month_start(df_urea_price["date"]))
    trade = df_trade_feat.assign(date=_month_start(df_trade_feat["date"]))
    return prices.merge(trade, on="date", how="left")