In [238]:
import pandas as pd
import matplotlib
from pylab import *
import networkx as nx

In [239]:
# Dataframe of edges
df= pd.read_csv('toy model trade.csv')
df.rename(columns= {'Country_A': 'country_A', 'Country_B': 'country_B','Partners_consump':'Partn_consump', 'Income_level': 'Group', 'Convergence rate': 'Beta', 'Imports growth rate': 'Alpha'}, inplace=True)
# dataframe of nodes
nodes= df.drop_duplicates(subset=['country_A'], keep='first')
df

Unnamed: 0,country_A,country_B,Exports,Imports,Group,Consumption,Partn_consump,Exports_share,Imports_share,Trade relation,Consumption transmission,Beta,Consumption growth,Alpha
0,1,2,40,160,1,100,50.0,0.118,0.326,2.770,0.163,0.01,0.002,0.05
1,1,3,50,52,1,100,25.0,0.147,0.106,0.720,0.026,0.01,0.000,0.05
2,1,4,20,40,1,100,25.0,0.059,0.081,1.385,0.020,0.01,0.000,0.05
3,1,5,30,32,1,100,10.0,0.088,0.065,0.739,0.007,0.01,0.000,0.05
4,1,6,30,25,1,100,10.0,0.088,0.051,0.577,0.005,0.01,0.000,0.05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
191,15,10,3,4,3,2,,,,,,0.01,,0.05
192,15,11,0,0,3,2,,,,,,0.01,,0.05
193,15,12,3,4,3,2,,,,,,0.01,,0.05
194,15,13,2,2,3,2,,,,,,0.01,,0.05


In [None]:
############# MAP ###############

def initialize():
    global G, nextg, pos
    # Create graph with edge attributes: Imports, parteners_consumption, and Imports growth rate= alpha
    G = nx.from_pandas_edgelist(df[['country_A','country_B','Imports','Partn_consump','Alpha']], 'country_A', 'country_B', edge_attr=['Imports','Partn_consump','Alpha'], create_using= nx.MultiDiGraph())
        # set graph layout
    pos = nx.spring_layout(G)
        # convert dataframe to dictionary of attributes
    node_attr = nodes[['country_A','Consumption','Beta','Group']].set_index('country_A').to_dict('index')
        # incorporate dictionary of attributes: Consumption, Convergence_rate= Beta, Group= income level
    nx.set_node_attributes(G, node_attr)

    nextg= G.copy()
    
#def observe():
    #global G, nextg, pos
    #cla()
    #plt.figure(figsize=(14,14))
    #nx.draw(G, cmap = cm.winter, vmin = 0, vmax = 1, pos = pos, node_size=12, width=0.1,
            #node_color = [G.node[i]['Consumption'] for i in G.nodes()])
 
    #plt.savefig("network.png", dpi=1000)
    
    
def update():
    global G, nextg
    
    # Compute Imports variation 

    for i, j, k, weight in nextg.edges(data="weight", keys=True): 
        #nextg.add_edge(i, j, 0, Imports=(1 + data['Alpha'])* data['Imports'])
        nextg[i][j][0]['Imports']= (1 + data['Alpha'])* data['Imports'] 

    # Compute Imports Share for each edge
    for i, j, k, weight in nextg.edges(data="weight", keys=True):
        tot_M=nextg.out_degree(i,'Imports') # total imports
        nextg[i][j][0]['Import_share'] = data['Imports']/tot_M # Imports share

    # Compute Consumption transmission using= Import_share * Partner_consump

    for i, j, k, weight in nextg.edges(data="weight", keys=True):
        nextg[i][j][0]['Consump_trans'] = data['Import_share']* data['Partn_consump']

        # Compute Consumption growth using= Consumption_transmission * Convergence rate Beta
        nextg[i][j][0]['Consump_growth']= data['Consump_trans']* nextg.nodes(data='Beta')[i]

    # Compute total consumption in next period
    for i in nextg.nodes():
        nextg.node[i]['Consumption'] = nextg.out_degree(i,'Consump_growth')+ nextg.nodes(data='Consumption')[i] 
        #Update partener_consumption
        for j in nextg.nodes():
            if i in nextg.neighbors(j):
                nextg[j][i][0]['Partn_consump']= nextg.node[i]['Consumption']
    
    
    
# Total consumption by group of countries (by income-level)    
def consumption_patterns():
    global G
    high_income=0
    middle_income= 0
    low_income= 0
    for i in G.nodes():
        if G.node[i]['Group']==1:
            high_income+= G.node[i]['Consumption']
        elif G.node[i]['Group']==2
            middle_income+= G.node[i]['Consumption']
        else:
            low_income+= G.node[i]['Consumption']
    return high_income, middle_income, low_income



N= 5
initialize()
consumption_high=[]
consumption_middle=[]
consumption_low=[]

for i in range(N):

    H, M, L = consumption_patterns() # Get consumption by groups at the end of period
    consumption_high.append(H) # array of total consumption for high income countries
    consumption_middle.append(M) # array of total consumption for middle income countries
    consumption_low.append(L) # array of total consumption for middle income countries
        
    update()
    

    
plt.plot(range(N),consumption_high, label='Consumption High-income countries')
plt.plot(range(N),consumption_high, label='Consumption Middle-income countries')
plt.plot(range(N),consumption_high, label='Consumption Low-income countries')

plt.xlabel('Time')
plt.ylabel('Total energy consumption')

In [None]:
for i, data in nextg.nodes.data():
    print(nextg.nodes(data=True))

In [None]:
for i, j, data in nextg.edges.data():
    print(nextg.edges(data=True))
    break