In [6]:
import pandas as pd
import numpy as np

from typing import List

import plotly.graph_objects as go

### Load Flat File

In [3]:
# Location of the raw CSV export, relative to this notebook.
FILE = "../data/coffee_exports.csv"

# Read the wide table: one row per country, one column per year plus an
# `all_years` column (see the preview below).
data = pd.read_csv(filepath_or_buffer=FILE)
data.head(n=5)


Unnamed: 0,country,1990,1991,1992,1993,1994,1995,1996,1997,1998,...,2011,2012,2013,2014,2015,2016,2017,2018,2019,all_years
0,Angola,84,71,80,39,8,41,52,50,54,...,8,8,6,9,11,11,9,9,23,722
1,Bolivia (Plurinational State of),156,74,96,47,84,94,123,111,97,...,74,63,55,62,30,22,26,22,20,2291
2,Brazil,16936,21183,18791,17838,17273,14468,15251,16801,18144,...,33806,28549,31651,37335,37563,34269,30925,35637,40698,786432
3,Burundi,585,688,646,418,508,528,224,529,374,...,218,392,195,252,230,204,169,202,293,10770
4,Cameroon,2611,1752,1646,705,546,407,564,1368,746,...,490,622,272,375,390,281,245,287,250,23332


### Reformat Data

In [8]:
def re4mat(df:pd.DataFrame, index:List[str]=["country"])->pd.DataFrame:
    """Reshape a wide table (one column per year) into long format.

    Parameters
    ----------
    df : pd.DataFrame
        Wide table holding the identifier column(s) named in ``index`` plus one
        column per remaining label (here: each year and ``all_years``).
    index : list of str
        Column(s) identifying a row; they become the leading output columns.
        The default list is only read, never mutated.

    Returns
    -------
    pd.DataFrame
        One row per (index..., former column label) pair, with columns
        ``[*index, "year", "counts"]``.
    """
    # Name the column axis before stacking so the stacked level is always called
    # "year". Renaming "level_1" after reset_index only works with a single id
    # column; with two id columns the unnamed level would come out as "level_2".
    # NOTE(review): the legacy stack() implementation drops missing cells
    # (dropna=True) — confirm that is acceptable if the export can contain gaps.
    stacked = (
        df.set_index(index)
        .rename_axis(columns="year")
        .stack()
        .rename("counts")
    )
    return stacked.reset_index()

# wide -> long: one row per (country, year) pair
data = data.pipe(re4mat)
data.head()

Unnamed: 0,country,year,counts
0,Angola,1990,84
1,Angola,1991,71
2,Angola,1992,80
3,Angola,1993,39
4,Angola,1994,8


### Filter out the aggregate rows: the 'Total' country and the 'all_years' year

In [13]:
# Preview the aggregate entries that are dropped below: rows for the "Total"
# pseudo-country and rows for the "all_years" pseudo-year.
is_total_country = data.country == "Total"
is_all_years = data.year == "all_years"

display(data[is_total_country].tail())
display(data[is_all_years].head())

def __filter_list_from_df(data:pd.DataFrame, col:str="", target_list:List[str]=None)->pd.DataFrame:
    """Return the rows of ``data`` whose ``col`` value is NOT in ``target_list``.

    Parameters
    ----------
    data : pd.DataFrame
        Frame to filter; it is not modified.
    col : str
        Name of the column to test; must be a column of ``data``.
    target_list : list of str, optional
        Values to drop. A single string is treated as a one-element list.
        ``None`` (the default) or an empty list drops nothing.

    Returns
    -------
    pd.DataFrame
        The rows of ``data`` selected by the boolean mask.

    Raises
    ------
    KeyError
        If ``col`` is not a column of ``data``.
    """
    # TODO: decide whether an empty result should warn or raise.
    # NOTE: matching is exact, so values must have the column's type
    # (e.g. year labels read from the CSV header are strings, not ints).

    # Fail with a message naming the bad column instead of pandas' bare KeyError.
    if col not in data.columns:
        raise KeyError(f"column {col!r} not found; available columns: {list(data.columns)}")

    # None replaces the former mutable [] default; a bare string would make
    # isin() raise TypeError, so wrap it.
    if target_list is None:
        target_list = []
    elif isinstance(target_list, str):
        target_list = [target_list]

    return data[~data[col].isin(target_list)]

def filter_out_countries(data:pd.DataFrame, leave_out:List[str]=["Total"])->pd.DataFrame:
    """Remove every row whose ``country`` appears in ``leave_out``.

    By default this drops the ``"Total"`` pseudo-country. The default list is
    only read, never mutated.
    """
    return __filter_list_from_df(data=data, col="country", target_list=leave_out)

def filter_out_years(data:pd.DataFrame, leave_out:List[str]=["all_years"])->pd.DataFrame:
    """Remove every row whose ``year`` appears in ``leave_out``.

    By default this drops the ``"all_years"`` pseudo-year. The default list is
    only read, never mutated.
    """
    return __filter_list_from_df(data=data, col="year", target_list=leave_out)

# Drop the aggregate rows ("Total" country, "all_years" year) before analysis.
data = (
    data
    .pipe(filter_out_countries)
    .pipe(filter_out_years)
)
data.head()



Unnamed: 0,country,year,counts
1731,Total,2016,121334
1732,Total,2017,119519
1733,Total,2018,126598
1734,Total,2019,131694
1735,Total,all_years,2836697


Unnamed: 0,country,year,counts
30,Angola,all_years,722
61,Bolivia (Plurinational State of),all_years,2291
92,Brazil,all_years,786432
123,Burundi,all_years,10770
154,Cameroon,all_years,23332


Unnamed: 0,country,year,counts
0,Angola,1990,84
1,Angola,1991,71
2,Angola,1992,80
3,Angola,1993,39
4,Angola,1994,8


### Visualize Data

In [21]:
# Total exports per country over all years, largest first.
countries = data.groupby("country")["counts"].sum().sort_values(ascending=False)

# Show Brazil's total, then take it out of the ranking so the remaining
# exporters are easier to compare.
brazil_total = countries.pop("Brazil")
display(brazil_total)

countries.head()

786432

country
Viet Nam     415408
Colombia     331308
Indonesia    176799
India        115566
Guatemala    107739
Name: counts, dtype: int64

In [27]:

# High producers: the N_TOP_PRODUCERS countries with the largest total exports,
# one line trace per country.
N_TOP_PRODUCERS = 5
fig = go.Figure()

# countries ranked by total exports over all years, largest first
countries = data.groupby("country").counts.sum().sort_values(ascending=False).index.tolist()
# Slicing copes with fewer than N_TOP_PRODUCERS countries (a fixed-count pop
# loop would raise IndexError). `countries` keeps the remainder, which the next
# cell plots.
top_countries, countries = countries[:N_TOP_PRODUCERS], countries[N_TOP_PRODUCERS:]
for i, country in enumerate(top_countries):
    # plain equality mask; np.in1d is deprecated as of NumPy 2.0
    this_data = data[data.country == country]

    fig.add_trace(go.Scatter(
        x = this_data.year,
        y = this_data.counts,
        name = country,
        opacity=0.5
    ))
fig.update_layout(plot_bgcolor="white").show()

In [28]:

# Remaining producers: every country still listed in `countries` after the top
# producers were split off in the previous cell, one line trace each.
fig = go.Figure()
# Iterate instead of popping, so `countries` is left intact and re-running this
# cell redraws the same figure instead of an empty one.
for country in countries:
    # plain equality mask; np.in1d is deprecated as of NumPy 2.0
    this_data = data[data.country == country]

    fig.add_trace(go.Scatter(
        x = this_data.year,
        y = this_data.counts,
        name = country,
        opacity=0.5
    ))
fig.update_layout(plot_bgcolor="white").show()

In [32]:
def n_lag(data:pd.DataFrame, lags:List[int]=[1, 3, 5, 7])->pd.DataFrame:
    """Add per-country lagged copies of the ``counts`` column.

    For each ``k`` in ``lags`` a column ``lag_k`` is added. It holds the
    ``counts`` value ``k`` rows earlier within the same country, or NaN for
    the first ``k`` rows of each country.

    Parameters
    ----------
    data : pd.DataFrame
        Long table with ``country`` and ``counts`` columns. It is not modified.
    lags : list of int
        Row offsets to shift by. The default list is only read, never mutated.

    Returns
    -------
    pd.DataFrame
        A copy of ``data`` with one extra ``lag_k`` column per lag.
    """
    # Work on a copy. The caller's frame may itself be a boolean-mask slice of
    # another frame, so adding columns in place would mutate the caller's
    # object and trigger pandas' SettingWithCopyWarning.
    lagged = data.copy()
    # shift() is positional, so this assumes rows are already in chronological
    # order within each country (as in the long table built above).
    by_country = lagged.groupby("country")["counts"]
    for k in lags:
        lagged[f"lag_{k}"] = by_country.shift(k)

    return lagged

# add per-country lagged `counts` columns (lag_1, lag_3, lag_5, lag_7)
data = data.pipe(n_lag)
data.head()

Unnamed: 0,country,year,counts,lag_1,lag_3,lag_5,lag_7
0,Angola,1990,84,,,,
1,Angola,1991,71,84.0,,,
2,Angola,1992,80,71.0,,,
3,Angola,1993,39,80.0,84.0,,
4,Angola,1994,8,39.0,71.0,,


In [34]:
# Pairwise (Pearson) correlation between `counts` and its lag columns.
# Restrict to numeric columns: `country` and `year` hold strings (year labels
# come from the CSV header). Since pandas 2.0, DataFrame.corr() raises on
# non-numeric columns instead of silently dropping them.
corr_matrix = data.select_dtypes("number").corr()
fig = go.Figure()

fig.add_trace(go.Heatmap(
    x = corr_matrix.columns,
    y = corr_matrix.index,
    z = corr_matrix.values
))
fig.update_layout(title="Correlation of counts with its lags")
fig.show()

In [36]:
# Distribution of yearly export counts for each country (one box per country).
box = go.Box(
    x=data.country,
    y=data.counts,
)
fig = go.Figure()
fig.add_trace(box)