In [4]:
# comtrade_api.py

import json
import logging
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import pandas as pd
import requests
from dotenv import load_dotenv

# Load variables from a local .env file (when one exists) into the process
# environment, so that ComtradeAPI.__init__ can pick up the API key through
# os.getenv without the key being hard-coded in the notebook.
load_dotenv()

# Configure the root logger once for the whole notebook: INFO and above, each
# record rendered as "<timestamp> - <level> - <message>".
# Note: logging.basicConfig does nothing if the root logger already has
# handlers, so re-running this cell in a live kernel will not change the format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger shared by every function and class below.
logger = logging.getLogger(__name__)

class ComtradeAPI:
    """Small client for the UN Comtrade data endpoint referenced by ``BASE_URL``.

    The client authenticates with a subscription key sent in the
    ``Ocp-Apim-Subscription-Key`` header and exposes a single call,
    :meth:`fetch_data`, that downloads import and export records for one
    reporter and one period.
    """

    BASE_URL = "https://comtradeapi.un.org/data/v1/get/C/M/HS"

    # Seconds to pause after an HTTP 429 (rate limited) response before retrying.
    RETRY_DELAY_SECONDS = 3
    # Default cap on retries after HTTP 429, so a persistent rate limit cannot
    # turn into unbounded retrying.
    MAX_RETRIES = 5
    # Seconds after which an unresponsive request is abandoned. Without a
    # timeout, requests can block forever on a stalled connection.
    REQUEST_TIMEOUT_SECONDS = 120

    def __init__(self, api_key: Optional[str] = None):
        """Store the API key and build the request headers.

        Args:
            api_key: Subscription key. When omitted, the key is read from the
                ``COMTRADE_API_KEY`` environment variable, falling back to the
                legacy ``API_KEY`` variable for existing setups.

        Raises:
            ValueError: If no key is passed and neither environment variable
                is set (or all of them are empty).
        """
        # COMTRADE_API_KEY is the name the error message has always advertised;
        # API_KEY is kept as a fallback because it is what was actually read before.
        self.api_key = api_key or os.getenv('COMTRADE_API_KEY') or os.getenv('API_KEY')
        if not self.api_key:
            raise ValueError("API key is required. Set it as COMTRADE_API_KEY environment variable or pass it to the constructor.")
        self.headers = {
            "Cache-Control": "no-cache",
            "Ocp-Apim-Subscription-Key": self.api_key
        }

    def fetch_data(self, reporter_code: str, period: str,
                   max_retries: Optional[int] = None) -> Optional[Dict]:
        """Download import and export records for one reporter and period.

        Args:
            reporter_code: Value sent as the ``reporterCode`` query parameter.
            period: Value sent as the ``period`` query parameter.
            max_retries: How many times to retry after an HTTP 429 response.
                ``None`` (the default) uses ``MAX_RETRIES``.

        Returns:
            The decoded JSON body on success, or ``None`` when the request
            fails, the body is not valid JSON, or the rate limit is still
            being hit after ``max_retries`` retries. Failures are logged
            rather than raised.
        """
        if max_retries is None:
            max_retries = self.MAX_RETRIES

        params = {
            "reporterCode": reporter_code,
            "period": period,
            "partnerCode": "0",  # 0 for World (all partners)
            "cmdCode": "All",    # To get all HS codes
            "flowCode": "M,X"    # M for imports, X for exports
        }

        # One initial attempt plus up to max_retries retries.
        for attempt in range(max_retries + 1):
            try:
                response = requests.get(
                    self.BASE_URL,
                    params=params,
                    headers=self.headers,
                    timeout=self.REQUEST_TIMEOUT_SECONDS,
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                # Read the status from the exception itself: the local
                # `response` is not guaranteed to be bound on every path.
                status_code = e.response.status_code if e.response is not None else None
                if status_code != 429:
                    logger.error(f"HTTP Error: {e}")
                    return None
                if attempt >= max_retries:
                    logger.error(f"Rate limit still exceeded after {max_retries} retries. Giving up.")
                    return None
                logger.warning(
                    f"Rate limit exceeded. Waiting for {self.RETRY_DELAY_SECONDS} seconds before retrying."
                )
                time.sleep(self.RETRY_DELAY_SECONDS)
            except requests.exceptions.RequestException as e:
                logger.error(f"Request failed: {e}")
                return None
            except ValueError as e:
                # Some requests versions raise a plain ValueError for a body
                # that is not valid JSON.
                logger.error(f"Invalid JSON in response: {e}")
                return None
        return None

def process_data(data: Dict) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Sum trade values per 2-character commodity-code prefix, split by flow.

    Args:
        data: Decoded API response. Records are read from its ``'data'`` list;
            each record is expected to carry ``cmdCode``, ``flowCode`` and
            ``primaryValue``. A missing or null ``'data'`` entry is treated as
            an empty list.

    Returns:
        ``(import_values, export_values)``: two ``defaultdict(float)`` objects
        mapping the first two characters of ``cmdCode`` to the summed
        ``primaryValue`` for flow ``'M'`` and flow ``'X'`` respectively.
        Records with any other flow code are ignored.

    NOTE(review): the request in ``ComtradeAPI.fetch_data`` asks for
    ``cmdCode=All``. If the API answers with rows at several aggregation
    levels (or a total row), grouping by prefix would count the same trade more
    than once. Confirm against the response schema before relying on the sums.
    """
    import_values = defaultdict(float)
    export_values = defaultdict(float)
    totals_by_flow = {'M': import_values, 'X': export_values}

    # `.get(key, default)` only covers a missing key; `or` also covers an
    # explicit JSON null, which would otherwise raise TypeError below.
    for item in (data or {}).get('data') or []:
        totals = totals_by_flow.get(item.get('flowCode'))
        if totals is None:
            continue

        hs_code = (item.get('cmdCode') or '')[:2]  # First 2 characters of the HS code
        trade_value = item.get('primaryValue') or 0  # Null value counts as zero
        totals[hs_code] += trade_value

    return import_values, export_values

def create_dataframe(results: List[Dict]) -> pd.DataFrame:
    """Turn per-HS-code result records into a tidy DataFrame.

    Args:
        results: Records with the keys ``'Country'``, ``'Period'``,
            ``'HS_Code'`` plus value columns named ``'<Flow>_<Metric>'``
            (``fetch_and_process_data`` supplies ``'Import_Value'`` and
            ``'Export_Value'``).

    Returns:
        A DataFrame with one row per (Country, Period, HS_Code, Metric) and one
        column per flow. For the records built in this file, the columns are
        ``Country, Period, HS_Code, Metric, Export, Import``. An empty
        ``results`` list yields an empty DataFrame with those same columns.
    """
    id_columns = ['Country', 'Period', 'HS_Code']

    # With no records there are no columns to melt on, and melt would raise
    # KeyError; return an empty frame with the usual layout instead.
    if not results:
        return pd.DataFrame(columns=id_columns + ['Metric', 'Export', 'Import'])

    df = pd.DataFrame(results)

    # Melt the DataFrame to create separate rows for imports and exports
    df_melted = df.melt(id_vars=id_columns,
                        var_name='Flow',
                        value_name='Value')

    # Split the 'Flow' column into separate 'Flow' and 'Metric' columns
    # (e.g. 'Import_Value' -> Flow 'Import', Metric 'Value').
    df_melted[['Flow', 'Metric']] = df_melted['Flow'].str.split('_', expand=True)

    # Pivot the DataFrame so each flow becomes its own column. pivot_table
    # aggregates with the mean by default, so duplicate keys are averaged.
    df_final = df_melted.pivot_table(values='Value',
                                     index=id_columns + ['Metric'],
                                     columns='Flow',
                                     fill_value=0).reset_index()

    # Drop the leftover 'Flow' label on the column index.
    df_final.columns.name = None
    # Only matters for callers that pass 'M_*' / 'X_*' value columns; the
    # records built in this file already produce 'Import' / 'Export'.
    df_final = df_final.rename(columns={'M': 'Import', 'X': 'Export'})

    return df_final

def fetch_and_process_data(api: ComtradeAPI, country_codes: List[str], periods: List[str],
                           request_delay: float = 4) -> pd.DataFrame:
    """Fetch and aggregate trade data for every country/period combination.

    Args:
        api: Client used to download the raw records.
        country_codes: Reporter codes to query, one request per code and period.
        periods: Periods to query for each reporter code.
        request_delay: Seconds to wait between consecutive requests to stay
            under the API rate limit. Defaults to 4.

    Returns:
        The DataFrame built by ``create_dataframe`` from one record per
        (country, period, HS code). Combinations whose download fails are
        logged and left out.
    """
    results = []
    # Materialise the combinations so the final request is known in advance
    # and no delay is spent after it.
    combinations = [(country_code, period)
                    for country_code in country_codes
                    for period in periods]
    last_index = len(combinations) - 1

    for index, (country_code, period) in enumerate(combinations):
        logger.info(f"Fetching data for country code {country_code} and period {period}")
        data = api.fetch_data(country_code, period)

        if data:
            import_values, export_values = process_data(data)

            # Sorted union, so record order does not depend on set ordering.
            for hs_code in sorted(set(import_values) | set(export_values)):
                results.append({
                    'Country': country_code,
                    'Period': period,
                    'HS_Code': hs_code,
                    'Import_Value': import_values.get(hs_code, 0),
                    'Export_Value': export_values.get(hs_code, 0)
                })
        else:
            logger.error(f"Failed to fetch data for country code {country_code} and period {period}")

        # Add a delay to avoid hitting rate limits (not needed after the last request)
        if index < last_index and request_delay > 0:
            time.sleep(request_delay)

    return create_dataframe(results)

def main(country_codes: List[str], periods: List[str]) -> pd.DataFrame:
    """Run the full download-and-aggregate pipeline with a default client.

    A ``ComtradeAPI`` is created without an explicit key, so the key must be
    available through the environment. The client is then handed to
    ``fetch_and_process_data`` together with the requested codes and periods.

    Args:
        country_codes: Reporter codes to query.
        periods: Periods to query for each reporter code.

    Returns:
        The DataFrame produced by ``fetch_and_process_data``.
    """
    return fetch_and_process_data(ComtradeAPI(), country_codes, periods)


In [6]:
# Demo cell: downloads two periods for four reporters and previews the result.
# This performs live API calls (one per reporter/period pair) and needs the API
# key to be available in the environment.
if __name__ == "__main__":
    # Example usage
    country_codes = ["894", "716", "288", "566"]  # Zambia, Zimbabwe, Ghana, Nigeria
    # Sent as the `period` query parameter; presumably YYYYMM months — TODO confirm.
    periods = ["202201", "202202"]
    
    result_df = main(country_codes, periods)
    # Preview only the first rows rather than dumping the whole frame.
    print(result_df.head())
    # Optional export, left disabled: uncomment both lines to write the result to CSV.
    #result_df.to_csv("comtrade_data.csv", index=False)
    #logger.info("Data has been saved to comtrade_data.csv")

2024-08-13 13:38:41,570 - INFO - Fetching data for country code 894 and period 202201
2024-08-13 13:39:03,052 - INFO - Fetching data for country code 894 and period 202202
2024-08-13 13:39:22,600 - INFO - Fetching data for country code 716 and period 202201
2024-08-13 13:39:30,201 - INFO - Fetching data for country code 716 and period 202202
2024-08-13 13:39:42,596 - INFO - Fetching data for country code 288 and period 202201
2024-08-13 13:39:58,944 - INFO - Fetching data for country code 288 and period 202202
2024-08-13 13:40:28,866 - INFO - Fetching data for country code 566 and period 202201
2024-08-13 13:40:40,511 - INFO - Fetching data for country code 566 and period 202202


  Country  Period HS_Code Metric        Export        Import
0     288  202201      01  Value  1.164238e+05  2.705067e+07
1     288  202201      02  Value  7.194600e+02  2.750473e+08
2     288  202201      03  Value  2.726731e+07  1.813444e+08
3     288  202201      04  Value  1.874746e+07  1.051209e+08
4     288  202201      05  Value  9.241175e+04  4.951798e+07
