In [1]:
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt

In [2]:
def remove_par(df, feat):
    df[feat] = df[feat].str.replace('(','')
    df[feat] = df[feat].str.replace(')','')
    
def string2int(df, feat):
    df[feat] = df[feat].astype(int)
    
def polish_trade_feature(df, feat):
    remove_par(df, feat)
    string2int(df, feat)

In [3]:
# correct the name of the country to facilitate the merge betwen sambanis and geodata
dic_country_conv = dict({
    "Northern Alliance (Afghanistan)*":"Afghanistan",
    "South Vietnam":"Republic of Vietnam",
    'UAE':'United Arab Emirates',
    'UNITA (Angola)*':'Angola',
    'DR Congo':'Democratic Republic of the Congo',
    'United States':'United States of America',
    'Iran (Islamic Republic of)':'Iran',
    'Cote d’Ivoire':'Ivory Coast',
    "Lao People's Democratic Republic":'Laos',
    "Moldova (the Republic of)":'Moldova',
    'Viet Nam':'Republic of Vietnam',
    'Russian Federation':'Russia',
    'Korea (the Republic of)':'South Korea',
    'Eswatini':'Swaziland',
    'Syrian Arab Republic':'Syria',
    'Taiwan (Province of China)':'Taiwan',
    'Tanzania (the United Republic of)':'Tanzania',
    'United Kingdom of Great Britain and Northern Ireland':'United Kingdom',
    'Venezuela (Bolivarian Republic of)':'Venezuela',
    
    "Yemen People's Republic":'Yemen',
    'Yemen Arab Republic':'Yemen',
    '???':'Czechoslovakia',
    "???":'Yugoslavia',
    "???":'Zanzibar',
    "???":'Kosovo',
    "???":'Macedonia'
})
#left a space between the countries that required manual handling

#update the name of the countries in the geodata to allow a merge
#geodata["country_name"].replace(to_replace=dic_country_conv, inplace=True)
#geodata["country_border_name"].replace(to_replace=dic_country_conv, inplace=True)

In [4]:
trade = pd.read_csv("trade_reg.csv",sep='\t')

cols_sambanis = ["warstds", "ager", "agexp", "anoc", "army85", "autch98", "auto4",
        "autonomy", "avgnabo", "centpol3", "cowcode", "coldwar", "decade1", "decade2",
        "decade3", "decade4", "dem", "dem4", "demch98", "dlang", "drel",
        "durable", "ef", "ef2", "ehet", "elfo", "elfo2", "etdo4590",
        "expgdp", "exrec", "fedpol3", "fuelexp", "gdpgrowth", "geo1", "geo2",
        "geo34", "geo57", "geo69", "geo8", "illiteracy", "incumb", "infant",
        "inst", "inst3", "life", "lmtnest", "ln_gdpen", "lpopns", "major", "manuexp", "milper",
        "mirps0", "mirps1", "mirps2", "mirps3", "nat_war", "ncontig",
        "nmgdp", "nmdp4_alt", "numlang", "nwstate", "oil", "p4mchg",
        "parcomp", "parreg", "part", "partfree", "plural", "plurrel",
        "pol4", "pol4m", "pol4sq", "polch98", "polcomp", "popdense",
        "presi", "pri", "proxregc", "ptime", "reg", "regd4_alt", "relfrac", "seceduc",
        "second", "semipol3", "sip2", "sxpnew", "sxpsq", "tnatwar", "trade",
        "warhist", "xconst", "year", "country"]

sambanis = pd.read_csv("sambanis_extented.csv", usecols=cols_sambanis)

#keep only the year value for date (drop 1-january, dummy)
#sambanis["year"] = pd.DatetimeIndex(sambanis["year"]).year

#civil_war contains only the observations with a civil war
civil_war = sambanis[sambanis["warstds"]==1]

In [6]:
trade.loc[trade.recipient=="Viet Nam"]["no_ordered"].sum()
trade["recipient"].replace(to_replace=dic_country_conv, inplace=True)
trade["supplier"].replace(to_replace=dic_country_conv, inplace=True)
civil_war['year'] = pd.DatetimeIndex(civil_war['year']).year

In [7]:
import plotly.express as px
import plotly.graph_objects as go

In [20]:
types = {'aircraft':0, 'helicopter':0, 'gun':0, 'missile':0,
         'tank':0, 'engine':0, 'car':0, 'ac':0, 'apc':0}

def check_weapon_type(import_df):
    import_df["weapon_description"] = import_df[
                "weapon_description"].str.lower()
    
    # Get the wepon descriptions for given country's weapon import
    desc = import_df.weapon_description 
    tmp = types.copy()
    for t in tmp:
        idx = desc.loc[desc.str.count(t) == 1].index
        
        if t != 'missile':
          
            id_tmp = []
            anti_idx = desc.loc[desc.str.count('missile') == 1].index
            [id_tmp.append(x) for x in idx if x not in anti_idx]
            idx = id_tmp

        count = import_df.loc[idx].no_ordered.sum()

        tmp[t] = count
    total = import_df['no_ordered'].sum()
    tfound =sum(tmp.values())
    tmp['other'] = total-tfound

    return tmp


def imp_number(df, country, onset=None, buffer=5):
    
    country_imp = df.loc[trade.recipient == country]

    polish_trade_feature(country_imp, 'no_ordered')
    polish_trade_feature(country_imp, 'year_order')
    
    
    country_imp_pre = country_imp.loc[(country_imp.year_order < onset) & 
                                (country_imp.year_order > onset - buffer)]
    country_imp_post = country_imp.loc[(country_imp.year_order >= onset) & 
                                (country_imp.year_order < onset + buffer)]
    
 
    pre_onset_order = check_weapon_type(country_imp_pre).copy()
    pre_onset_order["when"]='pre onset'
    #total = country_imp_pre['no_ordered'].sum()

    #tfound =sum(pre_onset_order.values())

    #pre_onset_order["other"] = total - tfound
    
    
    
    
    post_onset_order = check_weapon_type(country_imp_post).copy()
    post_onset_order["when"]='post onset'
    
    onset_orders = pd.DataFrame.from_dict([pre_onset_order]).append(
        pd.DataFrame.from_dict([post_onset_order]))
    
    return onset_orders

In [35]:
cambodia_import = imp_number(trade, "Cambodia", onset = 1970)
cambodia_import

Unnamed: 0,aircraft,helicopter,gun,missile,tank,engine,car,ac,apc,other,when
0,52,2,8,0,0,0,0,6,0,3,pre onset
0,87,38,43,0,0,0,0,107,30,22,post onset


In [71]:
fig = go.Figure()
fig.add_trace(go.Bar(cambodia_import, x="when", y=cambodia_import.columns))

ValueError: The first argument to the plotly.graph_objs.Bar 
constructor must be a dict or 
an instance of :class:`plotly.graph_objs.Bar`

In [36]:
fig = px.bar(cambodia_import, x="when", y=cambodia_import.columns)

fig.update_layout(
    width=500,
    height=600,
)
fig.update_yaxes(type="log")
fig.show()

In [63]:
list_countries = ['Afghanistan', 'Laos', 'Cambodia',
                 'Azerbaijan', 'Rwanda']
onset = [1978, 1960, 1970, 1991, 1990]
imp = []


for c, o in zip(list_countries, onset):
    country_import = imp_number(trade, c, onset=o, buffer=5)
    imp.append(country_import)

d={'country':list_countries, 'onset_import':imp}
df = pd.DataFrame(data=d)
df[df.country==list_countries[0]]['onset_import'].iloc[0]




Unnamed: 0,aircraft,helicopter,gun,missile,tank,engine,car,ac,apc,other,when
0,62,0,0,0,0,0,0,0,0,46,pre onset
0,46,190,628,7000,850,0,0,0,1160,1400,post onset


In [69]:
list_countries = ['Afghanistan', 'Laos', 'Cambodia',
                 'Azerbaijan', 'Rwanda']
onset = [1978, 1960, 1970, 1991, 1990]
imp = []


for c, o in zip(list_countries, onset):
    country_import = imp_number(trade, c, onset=o, buffer=5)
    imp.append(country_import)

d={'country':list_countries, 'onset_import':imp}
df = pd.DataFrame(data=d)



fig = go.Figure() 
fig.add_trace(go.Bar(x=df[df.country==list_countries[0]]['onset_import'].iloc[0]["when"],
                     y=df[df.country==list_countries[0]]['onset_import'].iloc[0].columns))


#pre, post = imp_number(trade, c, onset = o, buffer = 5)
#fig.add_trace(go.Bar(x=[list_countries[0]], y=[pre]))
fig.update_layout(barmode='group')
#fig.update_yaxes(type="log")

fig.update_layout(title='Weapons import in',
               xaxis_title='Country',
               yaxis_title='Weapon import')

button_country = list([dict(
                 args = [{"x": df[df.country==list_countries[0]]['onset_import'].iloc[0]["when"],
                         "y": df[df.country==list_countries[k]]['onset_import'].iloc[0]  
                         }],
                 label = list_countries[k],
                 method = "restyle"
                ) for k in range(len(list_countries))])

fig.update_layout(
    width=800,
    height=800,
    updatemenus=[
        go.layout.Updatemenu(
            buttons=button_country,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=1.235,
            yanchor="top"
        ),
    ]
)
fig.show()

TypeError: Object of type DataFrame is not JSON serializable