In [1]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import os
from typing import Optional

## Load data

In [2]:
def load(folder: str, debug: bool = False) -> pd.DataFrame:
    """
    Load every tab-separated trade file found in `folder` into a single DataFrame.

    For each file the duplicated rows and the rows with `trade_value` 0 are removed,
    `month_id` is parsed into a monthly period (renamed to `period`), `product_id`
    is turned into a six character `hs6_id`, `hs2_id` / `hs4_id` are derived from it,
    `product_name` is dropped and `year` / `month` are inserted as the first columns.

    Parameters:
        - `folder` - the location of all the .tsv files
        - `debug` - if `True` the `print` messages will be displayed

    Output:
        - the whole DataFrame

    Raises:
        - `ValueError` - if no file in `folder` contributed any row
    """

    all_data = []

    # Sorted so that the row order of the result does not depend on the file system.
    for file_name in sorted(os.listdir(folder)):

        file_path = os.path.join(folder, file_name)
        if not os.path.isfile(file_path):
            # A sub-folder cannot be parsed as a table - skip it instead of crashing.
            if debug:
                print(f"Skipped (not a file):\t{file_name}")

            continue

        data = pd.read_csv(file_path, sep="\t").drop_duplicates()

        # There is no useful information with trade_value 0
        data = data.loc[data["trade_value"] != 0].reset_index(drop=True)
        if data.empty:

            if debug:
                print(f"No data found for:\t{file_name}")

            continue

        data["month_id"] = pd.to_datetime(data["month_id"].astype(str), format="%Y%m").dt.to_period("M")

        # The product ID is actually the HS6 - HS6, denoted by the first six digits of the HS code,
        # provides the most detailed classification level. Codes with five characters get a "0"
        # in front of them (vectorised, no per-row Python call).
        product_id = data["product_id"].astype(str)
        data["product_id"] = product_id.where(product_id.str.len() != 5, "0" + product_id)

        # Creating HS2 and HS4 are based on - https://www.icustoms.ai/blogs/hs-code/
        data = (
            data.assign(
                hs2_id=data["product_id"].str[:2],  # first two characters of the HS6 code
                hs4_id=data["product_id"].str[:4],  # first four characters of the HS6 code
            )
            .rename(columns={"product_id": "hs6_id", "month_id": "period"})
            .drop(columns=["product_name"])
        )

        data.insert(0, "year", data["period"].dt.year)
        data.insert(1, "month", data["period"].dt.month)

        all_data.append(data)

        if debug:
            print(f"Loaded:\t{file_name}")

    if not all_data:
        # pd.concat([]) would fail with "No objects to concatenate", which hides the real cause.
        raise ValueError(f"No trade data could be loaded from: {folder}")

    return pd.concat(all_data, ignore_index=True).drop_duplicates().reset_index(drop=True)

In [3]:
def split(data: pd.DataFrame, debug: bool = False) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Separate the trade rows into an Imports frame and an Exports frame.

    Parameters:
        - `data` - the data from `load`
        - `debug` - if `True` the row counts will be printed

    Output:
        - the rows whose `trade_flow_name` is "Imports"
        - all the remaining rows (treated as Exports)
    """
    is_import = data["trade_flow_name"].eq("Imports")

    # Both frames get a fresh 0..n-1 index and are independent of `data`.
    imports = data[is_import].reset_index(drop=True)
    exports = data[~is_import].reset_index(drop=True)

    if debug:
        print(f"Data:    {len(data):,}")
        print(f"Exports: {len(exports):,}")
        print(f"Imports: {len(imports):,}")

    return imports, exports

In [4]:
# Location of the raw "US Trade flow" files. Set the US_TRADE_FLOW_DIR environment
# variable to run the notebook on another machine; the path below is only the fallback.
folder = os.environ.get(
    "US_TRADE_FLOW_DIR",
    r"C:\Users\Nikolai\Documents\GitHub\Valor-Real-Estate-Partners-Data-Engineer\US Trade flow",
)
data = load(folder)
imports, exports = split(data, True)

# Lookup built from the `state` and `code` columns of us_codes.csv
us_state_mapping = pd.read_csv("us_codes.csv").set_index("state")["code"].to_dict()

# `state_name` values of the trade data that are missing from the lookup
query = ~data["state_name"].isin(list(us_state_mapping))
non_states = data.loc[query, "state_name"].drop_duplicates().to_list()

Data:    18,984,187
Exports: 9,336,951
Imports: 9,647,236


In [5]:
# Named hex colours used by the figures. The insertion order matters:
# `create_usa_fig` passes `list(COLORS.values())` as its continuous colour scale.
COLORS = dict(
    platinum="#E9E9E5",
    ash_gray="#9EAA91",
    olivine="#91AC66",
    feldgrau="#3C503C",
)

## 1.1 How do total exports and imports change by state (`state_name`) from 2021 to 2022?

In [36]:
def create_usa_fig(data: pd.DataFrame, trade_type: str, year: Optional[int] = None):
    """
    Display a choropleth of the US states coloured by trade value or by its change.

    Parameters:
        - `data` - the rows to plot; `state_code` and `state_name` are required, plus
          `year` and `trade_value` when `year` is given, or `change` when it is not.
          Rows containing any missing value are dropped before plotting
        - `trade_type` - the label used in the title (title-cased)
        - `year` - if given only the rows of that year are plotted and `trade_value`
          is shown, otherwise the `change` column is shown

    Output:
        - nothing is returned, the figure is displayed with `fig.show()`
    """

    local = data.dropna().reset_index(drop=True)

    # `is not None` so that the decision does not depend on the truthiness of the number.
    has_year = year is not None
    if has_year:
        local = local.loc[local["year"] == year].reset_index(drop=True)

    column_name = "trade_value" if has_year else "change"

    fig = px.choropleth(
        local.rename(columns={column_name: "Trade Value"}),
        locations="state_code",
        locationmode="USA-states",
        color="Trade Value",
        hover_name="state_name",
        scope="usa",
        title=f"US States {trade_type.title()} Trade Value" + (f" for {year}" if has_year else " change"),
        color_continuous_scale=list(COLORS.values())
    )

    fig.update_layout(
        title={
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top"
        },
        title_font_size=40,
        width=1600,
        height=1200,
    )

    # With no rows left there is nothing to label (and the maximum would be NaN).
    if not local.empty:
        # Values above 95% of the maximum sit on the last colour of the scale,
        # so their labels are drawn in white instead of black.
        threshold = round(local[column_name].max() * 0.95)
        label_colors = ["white" if value > threshold else "black" for value in local[column_name]]

        # A single text trace for all the states instead of one trace per row.
        fig.add_trace(go.Scattergeo(
            locationmode="USA-states",
            locations=local["state_code"],
            text=local["state_code"],
            mode="text",
            textfont=dict(size=12, color=label_colors),
            showlegend=False
        ))

    fig.show()

In [None]:
def create_bar(df: pd.DataFrame, trade_type: str):
    """
    Display a bar chart with one bar per `state_name`, its height being `change`.

    NOTE: later cells of this notebook define `create_bar` again, so the definition
    that is active depends on which cell was executed last.

    Parameters:
        - `df` - the data with the columns `state_name` and `change`
        - `trade_type` - the label used in the title (title-cased)

    Output:
        - nothing is returned, the figure is displayed with `fig.show()`
    """

    # Largest absolute change - used for a y-axis that is symmetric around zero.
    extent = df["change"].abs().max()

    transparent = "rgba(0,0,0,0)"
    title = dict(
        text=f"Change in Trade Value by {trade_type.title()}",
        x=0.5,
        xanchor="center",
        yanchor="top",
    )

    fig = px.bar(df, x="state_name", y="change", text="change")

    fig.update_layout(
        title=title,
        title_font_size=20,
        xaxis_title="Non US state",
        yaxis_title="Change (%)",
        width=800,
        height=500,
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
        yaxis={"range": [-extent, extent]},
    )

    # Bar labels: the value with two decimals and a percent sign, centred inside the bar.
    fig.update_traces(
        marker_color=COLORS["ash_gray"],
        texttemplate="%{text:.2f}%",
        insidetextanchor="middle",
        textfont={"color": "black"},
    )

    fig.show()


In [22]:
def get_difference(
    data: pd.DataFrame,
    pct_change: bool = False,
    base_year: int = 2021,
    target_year: int = 2022,
) -> pd.DataFrame:
    """
    Compute the change of `trade_value` per state between two years.

    Parameters:
        - `data` - the yearly totals with the columns `state_name`, `year` and
          `trade_value`; `DataFrame.pivot` requires each state / year pair to be unique
        - `pct_change` - if `True` the change is a percentage rounded to two decimals,
          otherwise it is the absolute difference
        - `base_year` - the year the change is measured from
        - `target_year` - the year the change is measured to

    Output:
        - a DataFrame with the columns `state_name` and `change`

    Raises:
        - `KeyError` - if `data` has no rows for one of the two years
    """

    local = data.pivot(index="state_name", columns="year", values="trade_value").reset_index()

    local.columns.name = ""

    missing = [year for year in (base_year, target_year) if year not in local.columns]
    if missing:
        raise KeyError(f"No trade data for the year(s): {missing}")

    if pct_change:
        local["change"] = (((local[target_year] / local[base_year]) - 1) * 100).round(2)
    else:
        local["change"] = local[target_year] - local[base_year]

    return local[["state_name", "change"]].copy()

In [23]:
def to_yearly_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Sum `trade_value` per `state_name` and `year` and attach the state code.

    Parameters:
        - `df` - the data with the columns `state_name`, `year` and `trade_value`

    Output:
        - the totals sorted by `year`, with `state_code` looked up in
          `us_state_mapping` (names missing from the mapping get NaN)
    """
    totals = df.groupby(["state_name", "year"])["trade_value"].sum().reset_index()
    totals = totals.assign(state_code=totals["state_name"].map(us_state_mapping))

    return totals.sort_values("year").reset_index(drop=True)


yearly_imports: pd.DataFrame = to_yearly_data(imports)
yearly_exports: pd.DataFrame = to_yearly_data(exports)



In [24]:
# The yearly totals grouped by year, keyed by trade type. The next cell assigns
# a new value to `groups`.
groups = dict(
    exports=yearly_exports.groupby("year"),
    imports=yearly_imports.groupby("year"),
)

# for trade_type, groupby in groups.items():

#     for year, group in groupby:
#         create_usa_fig(group, trade_type, year)

In [28]:
# The yearly totals keyed by trade type
groups = dict(exports=yearly_exports, imports=yearly_imports)

for trade_type, group in groups.items():
    # `pct_change=True` requests the percentage change instead of the absolute one
    difference = get_difference(group, pct_change=True)

    # create_bar(difference.loc[difference["state_name"].isin(non_states)].reset_index(drop=True).copy(), trade_type)
    # difference = difference.assign(
    #     state_code = difference["state_name"].map(us_state_mapping)
    # ).dropna().reset_index(drop=True)

    # create_usa_fig(difference, trade_type)

## 1.2 How do total exports and imports change by HS2-level product (use the mapping table to map `product_id` to the HS2 level) in the US from 2021 to 2022?

In [6]:
def load_hs_mappings(file_path: str = "HS mapping.xlsx") -> pd.DataFrame:
    """
    Read the HS mapping workbook and derive the HS2 / HS4 IDs from the HS6 ID
    (based on https://www.icustoms.ai/blogs/hs-code/).

    The column names are lower-cased with the spaces replaced by underscores, and
    the `product_id` column is removed from the result.

    Parameters:
        - `file_path` - the location of the hs mapping

    Output:
        - the hs mappings with `hs6_id`, `hs2_id` and `hs4_id` as strings
    """

    def pad_hs6(value: str) -> str:
        # Codes with five characters get a "0" in front of them
        return f"0{value}" if len(value) == 5 else value

    mappings = pd.read_excel(file_path)
    mappings = mappings.rename(columns=lambda name: name.lower().replace(" ", "_"))

    hs6 = mappings["hs6_id"].astype(str).map(pad_hs6)

    mappings = mappings.assign(
        hs6_id=hs6,
        hs2_id=hs6.str[:2],  # HS2 - the first two characters of the code
        hs4_id=hs6.str[:4],  # HS4 - the first four characters of the code
    )

    return mappings.drop(columns="product_id")

In [22]:
def get_hs2_difference(
    data: pd.DataFrame,
    pct_change: bool = False,
    base_year: int = 2021,
    target_year: int = 2022,
) -> pd.DataFrame:
    """
    Compute the change of the total `trade_value` per HS2 ID between two years.

    Parameters:
        - `data` - the data with the columns `year`, `hs2_id` and `trade_value`
        - `pct_change` - if `True` the change is a percentage rounded to two decimals,
          otherwise it is the absolute difference
        - `base_year` - the year the change is measured from
        - `target_year` - the year the change is measured to

    Output:
        - a DataFrame with the columns `hs2_id` and `change`

    Raises:
        - `KeyError` - if `data` has no rows for one of the two years
    """

    # Total per year and HS2 ID first, then one column per year
    totals = data.groupby(["year", "hs2_id"], as_index=False)["trade_value"].sum()
    local = totals.pivot(index="hs2_id", columns="year", values="trade_value").reset_index()

    local.columns.name = ""

    missing = [year for year in (base_year, target_year) if year not in local.columns]
    if missing:
        raise KeyError(f"No trade data for the year(s): {missing}")

    if pct_change:
        local["change"] = (((local[target_year] / local[base_year]) - 1) * 100).round(2)
    else:
        local["change"] = local[target_year] - local[base_year]

    return local[["hs2_id", "change"]].copy()

In [53]:
def create_bar(df: pd.DataFrame, trade_type: str):
    """
    Display a bar chart with one bar per `section`, its height being `change`.

    NOTE: `create_bar` is defined more than once in this notebook, so the definition
    that is active depends on which cell was executed last.

    Parameters:
        - `df` - the data with the columns `section` and `change`
        - `trade_type` - the label used in the title (title-cased)

    Output:
        - nothing is returned, the figure is displayed with `fig.show()`
    """

    # Largest absolute change - used for a y-axis that is symmetric around zero.
    max_change = max(abs(df['change'].min()), df['change'].max())

    fig = px.bar(
        df,
        x="section",
        y="change",
        text="change"
    )

    fig.update_layout(
        title={
            # The title had no text, which left `trade_type` unused.
            "text": f"US States {trade_type.title()} change by Section",
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top"
        },
        title_font_size=40,
        width=1600,
        height=1200,
        # The x-axis shows the sections, not the non US states.
        xaxis_title="Section",
        yaxis_title="Change (%)",
        # plot_bgcolor='rgba(0,0,0,0)',
        # paper_bgcolor='rgba(0,0,0,0)',
        yaxis=dict(range=[-max_change, max_change])
    )

    # Bar labels: the value with two decimals and a percent sign, centred inside the bar.
    fig.update_traces(
        marker_color=COLORS["ash_gray"],
        texttemplate='%{text:.2f}%',
        insidetextanchor='middle',
        textfont=dict(color="black")
    )

    fig.show()


In [82]:
def create_bar(df: pd.DataFrame, title: str):
    """
    Display a horizontal bar chart with one bar per `section`, its length being `change`.

    NOTE: earlier cells of this notebook define `create_bar` with a `trade_type`
    parameter instead of `title`; this definition replaces them once it is executed.

    Parameters:
        - `df` - the data with the columns `section` and `change`
        - `title` - the text shown as the title of the figure

    Output:
        - nothing is returned, the figure is displayed with `fig.show()`
    """

    transparent = "rgba(0,0,0,0)"

    # `change` on the x-axis (bar length), one `section` per row of the y-axis
    fig = px.bar(df, x="change", y="section", text="change", orientation="h")

    title_options = dict(
        text=title,
        x=0.5,
        xanchor="center",
        yanchor="top",
        pad=dict(t=20, b=30),
    )

    fig.update_layout(
        title=title_options,
        title_font_size=40,
        width=1600,
        height=1200,
        xaxis_title="Change (%)",
        # yaxis_title="Section",
        yaxis={"autorange": "reversed"},  # first row of `df` at the top
        yaxis_tickangle=0,
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
    )

    # Bar labels: the value with two decimals and a percent sign, centred inside the bar.
    fig.update_traces(
        marker_color=COLORS["ash_gray"],
        texttemplate="%{text:.2f}%",
        textposition="inside",
        insidetextanchor="middle",
        textfont={"color": "black"},
    )

    fig.show()


In [83]:
hs_mappings = load_hs_mappings()


def get_us_data(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only the rows whose `state_name` is a key of `us_state_mapping`."""
    is_us_state = df["state_name"].isin(us_state_mapping.keys())

    return df.loc[is_us_state].reset_index(drop=True)

In [85]:
for trade_type, group in {
    "exports": get_us_data(exports).copy(),
    "imports": get_us_data(imports).copy(),
}.items():

    # Percentage change per HS2 ID, joined with the `hs2` and `section` columns of the mapping
    temp = (
        get_hs2_difference(group, pct_change=True)
        .merge(
            hs_mappings[["hs2_id", "hs2", "section"]].drop_duplicates(),
            how="left",
            on="hs2_id",
        )
        .copy()
    )

    print(trade_type)
    # create_bar(temp.groupby("section")["change"].sum().reset_index(), "" ) #f"US States {trade_type.title()} change by Section"

exports
imports


## 1.3 What are the top 10 partner countries (`country_name`) for imports and exports in 2022, and how did their share (as part of the total import and export value) change YoY?

In [6]:
def get_difference(
    data: pd.DataFrame,
    pct_change: bool = False,
    base_year: int = 2021,
    target_year: int = 2022,
) -> pd.DataFrame:
    """
    Compute the change of the total `trade_value` per country between two years.

    NOTE: an earlier cell of this notebook defines `get_difference` per `state_name`;
    the definition that is active depends on which cell was executed last.

    Parameters:
        - `data` - the data with the columns `year`, `country_name` and `trade_value`
        - `pct_change` - if `True` the change is a percentage rounded to two decimals,
          otherwise it is the absolute difference
        - `base_year` - the year the change is measured from
        - `target_year` - the year the change is measured to

    Output:
        - a DataFrame with the columns `country_name` and `change`

    Raises:
        - `KeyError` - if `data` has no rows for one of the two years
    """

    # Total per year and country first, then one column per year
    totals = data.groupby(["year", "country_name"], as_index=False)["trade_value"].sum()
    local = totals.pivot(index="country_name", columns="year", values="trade_value").reset_index()

    local.columns.name = ""

    missing = [year for year in (base_year, target_year) if year not in local.columns]
    if missing:
        raise KeyError(f"No trade data for the year(s): {missing}")

    if pct_change:
        local["change"] = (((local[target_year] / local[base_year]) - 1) * 100).round(2)
    else:
        local["change"] = local[target_year] - local[base_year]

    return local[["country_name", "change"]].copy()

In [7]:
def get_data(data: pd.DataFrame) -> pd.DataFrame:
    """
    Percentage change from `get_difference` joined with the contents of country_codes.csv.

    Parameters:
        - `data` - the data passed on to `get_difference`

    Output:
        - the changes merged on `country_name`, sorted ascending by `change`,
          without the rows that contain a missing value
    """
    changes = get_difference(data, True)
    country_codes = pd.read_csv("country_codes.csv")

    merged = changes.merge(country_codes, how="left", on="country_name")

    return merged.sort_values("change").dropna().reset_index(drop=True)
   

In [73]:
def create_world_map(data: pd.DataFrame, trade_type: str):
    """
    Export the rows at both ends of `data` to Excel and display a world choropleth of `change`.

    Parameters:
        - `data` - the data with the columns `country_name`, `country_code` and `change`.
          The first and last ten rows are exported, so the caller is expected to pass
          the rows sorted by `change` (`get_data` does that)
        - `trade_type` - used in the title and as the name of the exported file

    Output:
        - nothing is returned; `data/{trade_type}.xlsx` is written and the figure
          is displayed with `fig.show()`
    """

    # Make a copy of the dataframe to avoid modifying the original
    df = data.copy()

    # With 20 rows or fewer head(10) and tail(10) overlap, which would export duplicated rows.
    extremes = df if len(df) <= 20 else pd.concat([df.tail(10), df.head(10)])

    # `to_excel` does not create missing folders.
    os.makedirs("data", exist_ok=True)
    extremes.sort_values("change", ascending=False).drop("country_code", axis=1).to_excel(f"data/{trade_type}.xlsx", index=False)

    # Sequential colour scale: white at the lowest change, green at the highest
    fig = px.choropleth(
        df,
        locations="country_code",
        color="change",
        hover_name="country_name",
        title=f"World Map Showing Change in {trade_type.title()}",
        color_continuous_scale=["white", "green"],
        range_color=(df['change'].min(), df['change'].max())  # Ensure color range covers the full data range
    )

    # Update layout for better presentation and larger size
    fig.update_layout(
        title_font_size=20,
        width=1400,  # Set width to make the map larger
        height=800,  # Set height to make the map larger
        geo=dict(showframe=False, showcoastlines=False),
        coloraxis_colorbar=dict(title="Change (%)"),
    )

    # Show the figure
    fig.show()

In [74]:
create_world_map(get_data(imports), "imports")

In [76]:
create_world_map(get_data(exports), "exports")