# PEP-1-1 model implementation in pyomo


Reference document: https://www.pep-net.org/sites/pep-net.org/files/EXTER-en.pdf

## 1 Preparation

### 1.1 Load libraries

In [None]:
import pyomo.environ as pyo
from pyomo.environ import AbstractModel, Set, Param, Var, Constraint
import pyomo.dataportal.DataPortal as DataPortal
import pandas as pd
import json
from pyomo.core.expr.numvalue import value as constraintValue
from pyomo.core.util import prod
import idaes.core.util as idaes_utils
from pyomo.core.expr.current import identify_variables
from pyomo.common.collections import ComponentSet
import numpy as np
from openpyxl import load_workbook

### 1.2 Rewrite data

In [None]:
## Read the 'SAM' sheet: two header rows, two index columns, first three rows skipped
df = pd.read_excel(
    'data/SAM-V2_0.xls',
    sheet_name='SAM',
    header=[0, 1],
    index_col=[0, 1],
    skiprows=3,
)

## the two last rows are filler, so keep everything before them
df = df.iloc[:-2]

## store the table as JSON (this file is loaded again in section 3.1)
df.to_json('data/pep-1-1-SAM.json')

### 1.3 Helper functions

In [None]:
sam_dict = {}

## Make some helper functions
## Dictionary is organized as "FROM (C, D) TO (A, B)"
## But commands will come in form (TO (A, B) FROM (C, D))
def sam(c, over=None, prices=None):
    """Read one cell, or a dict of cells, from the module-level ``sam_dict``.

    A cell is addressed as
    ``sam_dict["('<c[2]>', '<c[3]>')"]["('<c[0]>', '<c[1]>')"]``, i.e. ``c`` is
    given as TO ``(c[0], c[1])`` FROM ``(c[2], c[3])``.

    Args:
        c: list of four labels. An entry equal to ``0`` is a placeholder
            filled from ``over[0]``; an entry equal to ``1`` is a placeholder
            filled from ``over[1]``.
        over: zero, one or two lists of category labels to expand the
            placeholders over. ``None`` (the default) means no expansion.
        prices: optional deflator. A scalar (int or float) is applied to the
            categories of the last list in ``over`` (or to the single cell
            when ``over`` is empty); a dict maps label -> price. Falsy values
            (``None``, ``0``, ``{}``) mean "no price adjustment".

    Returns:
        The raw cell (possibly divided by the scalar price) when ``over`` is
        empty; otherwise a dict keyed by category (one list) or by
        ``(category, category2)`` (two lists). Cells stored as ``None`` are
        returned as ``0`` in the dict forms.

    Raises:
        KeyError: if an addressed row/column is missing from ``sam_dict``.
        ValueError: if ``c`` lacks the ``0`` / ``1`` placeholder required by
            ``over``.
    """
    if over is None:
        over = []

    p = {}
    if prices:
        if isinstance(prices, (int, float)):
            # scalar price: spread it over the categories of the last level
            if len(over) == 1:
                p = dict.fromkeys(over[0], prices)
            elif len(over) == 2:
                p = dict.fromkeys(over[1], prices)
            elif len(over) == 0:
                p['singular'] = prices
        else:
            # NOTE(review): the caller's dict is used directly, so the 'ADM'
            # entry below is written into it. Kept as-is because later cells
            # may rely on the updated price table - confirm before copying.
            p = prices
    p['ADM'] = 1  # hack

    if len(over) == 0:
        cell = sam_dict[f"('{c[2]}', '{c[3]}')"][f"('{c[0]}', '{c[1]}')"]
        if 'singular' in p:
            return cell / p['singular']
        return cell

    d = {}
    if len(over) == 1:
        idx = c.index(0)
        for cat in over[0]:
            coord = list(c)
            coord[idx] = cat
            cell = sam(coord)
            if cell is None:
                cell = 0  # fix nonetype values
            ## adjust for prices
            if cat in p:
                cell = cell / p[cat]
            d[cat] = cell
    if len(over) == 2:
        idx = c.index(0)
        idx2 = c.index(1)
        for cat in over[0]:
            for cat2 in over[1]:
                coord = list(c)
                coord[idx] = cat
                coord[idx2] = cat2
                cell = sam(coord)
                if cell is None:
                    cell = 0  # fix nonetype values
                # NOTE(review): the price is looked up by the FIRST category
                # (cat), while a scalar price is keyed by over[1] above -
                # confirm which index is meant to carry the price.
                if cat in p:
                    cell = cell / p[cat]
                d[(cat, cat2)] = cell

    return d

## helper function for retrieving initial values
def val(obj, sub1=None, sub2=None, default=None):
    """Return the initial value(s) stored on a pyomo ``Var`` or ``Param``.

    For a ``Var`` the data is read from ``obj._value_init_value``; for a
    ``Param`` it is whatever ``obj.default()`` returns. With ``sub1`` and
    ``sub2`` the entry ``(sub1, sub2)`` is returned, with only ``sub1`` the
    entry ``sub1``, and with neither the whole stored object. Any other
    ``obj`` yields ``None``.

    A ``KeyError`` raised by the lookup is replaced by ``default`` when
    ``default`` is not ``None``; otherwise it propagates.
    """
    try:
        if isinstance(obj, Var):
            stored = obj._value_init_value
        elif isinstance(obj, Param):
            stored = obj.default()
        else:
            return None

        if sub2:
            return stored[(sub1, sub2)]
        if sub1:
            return stored[sub1]
        return stored
    except KeyError:
        if default is None:
            raise
        return default


def divide(d1, d2):
    """Return ``{key: d1[key] / d2[key]}`` for every key of ``d1``.

    Keys that exist only in ``d2`` are ignored; a key of ``d1`` missing from
    ``d2`` raises ``KeyError``.
    """
    return {key: numerator / d2[key] for key, numerator in d1.items()}


### 1.4 Initialize model

In [None]:
## Abstract pyomo model; the sets, parameters and variables below are attached to it
m = AbstractModel()

## 2 Define sets

### 2.1 Sectors and commodities

In [None]:
## Sectors (industries)
industries = [
    'AGR',  # Agriculture and other primary industries
    'IND',  # Manufacturing and construction
    'SER',  # Services
    'ADM',  # Public administration
]

## Commodities
commodities = [
    'AGR',     # Agriculture and other primary commodities
    'FOOD',    # Food and beverages
    'OTHIND',  # Other manufacturing and construction
    'SER',     # Services
    'ADM',     # Public administration
]

## Commodity subsets, derived from the full list so the order is kept
commodities_ex_agr = [i for i in commodities if i != 'AGR']
commodities_ex_adm = [i for i in commodities if i != 'ADM']

m.industries = Set(dimen=1, initialize=industries, doc='All industries')
m.commodities = Set(dimen=1, initialize=commodities, doc='All commodities')
m.commodities_ex_agr = Set(
    dimen=1,
    initialize=commodities_ex_agr,
    doc='All commodities except agriculture')
m.commodities_ex_adm = Set(
    dimen=1,
    initialize=commodities_ex_adm,
    doc='All commodities except public administration')

### 2.2 Production factors

In [None]:
## Production factors
## labour: unskilled workers, skilled workers
labour = ['USK', 'SK']
## capital: capital, land
capital = ['CAP', 'LAND']

m.labour = Set(dimen=1, initialize=labour, doc='Labor categories')
m.capital = Set(dimen=1, initialize=capital, doc='Capital categories')

### 2.3 Agents

In [None]:
## Agents
households = [
    'HRP',  # Poor rural households
    'HUP',  # Poor urban households
    'HRR',  # Rich rural households
    'HUR',  # Rich urban households
]
firms = ['FIRM']  # Firms

## all agents: households, firms, then government and rest of the world
agents = households + firms + ['GVT', 'ROW']
agents_nongov = [ag for ag in agents if ag != 'GVT']
agents_domestic = [ag for ag in agents if ag != 'ROW']

m.agents = Set(dimen=1, initialize=agents, doc='All agents')
m.agents_nongov = Set(dimen=1, initialize=agents_nongov, doc='Non governmental agents')
m.agents_domestic = Set(dimen=1, initialize=agents_domestic, doc='Domestic agents')
m.households = Set(dimen=1, initialize=households, doc='Households')
m.firms = Set(dimen=1, initialize=firms, doc='Firms')

## 3 Load data

### 3.1 Data from the SAM

In [None]:
## Rebind sam_dict (an empty dict until now) to a DataPortal and load the
## JSON file written in section 1.2; sam() reads its cells from this object
sam_dict = DataPortal()
sam_dict.load(filename='data/pep-1-1-SAM.json', encoding='utf-8')
# sam_dict.data()

### 3.2 Other data (parameters)

In [None]:
## Every table sits on the 'PAR' sheet of data/VAL_PAR.xlsx; only the row window differs
industry_indexed_params = pd.read_excel(
    'data/VAL_PAR.xlsx',
    sheet_name='PAR', header=[0], index_col=[0],
    skiprows=4, nrows=4,
)
commodity_indexed_params = pd.read_excel(
    'data/VAL_PAR.xlsx',
    sheet_name='PAR', header=[0], index_col=[0],
    skiprows=11, nrows=5,
)
## the agent table is transposed after reading
agent_indexed_params = pd.read_excel(
    'data/VAL_PAR.xlsx',
    sheet_name='PAR', header=[0], index_col=[0],
    skiprows=26, nrows=10,
).transpose()

## special treatment for sigma_X, sigma_Y: re-key the tables by tuples
industry_and_commodity_indexed_params = pd.read_excel(
    'data/VAL_PAR.xlsx',
    sheet_name='PAR', header=[0], index_col=[0],
    skiprows=19, nrows=4,
)

## sigma_X is keyed (industry, commodity)
sigma_X = {}
for i in commodities:
    for j in industries:
        sigma_X[(j, i)] = industry_and_commodity_indexed_params[i][j]

## sigma_Y is keyed (commodity, household)
sigma_Y = {}
for i in commodities:
    for h in households:
        sigma_Y[(i, h)] = agent_indexed_params[i][h]

In [None]:
## example of loading value: the first four 'frisch' entries as a dict
## (the same expression is used as the default of m.frisch below)
agent_indexed_params['frisch'][:4].to_dict()

## 4 Variables and parameters

### 4.1 Simple parameters

In [None]:
## CES and CET elasticities, taken from the parameter tables loaded above
m.sigma_KD = Param(
    m.industries,
    doc='Elasticity (CES - composite capital)',
    default=industry_indexed_params['sigma_KD'].to_dict())
m.sigma_LD = Param(
    m.industries,
    doc='Elasticity (CES - composite labor)',
    default=industry_indexed_params['sigma_LD'].to_dict())
m.sigma_M = Param(
    m.commodities,
    doc='Elasticity (CES - composite commodity)',
    default=commodity_indexed_params['sigma_M'].to_dict())
m.sigma_VA = Param(
    m.industries,
    doc='Elasticity (CES - value added)',
    default=industry_indexed_params['sigma_VA'].to_dict())
m.sigma_XT = Param(
    m.industries,
    doc='Elasticity (CET - total output)',
    default=industry_indexed_params['sigma_XT'].to_dict())
m.sigma_X = Param(
    m.industries, m.commodities,
    doc='Elasticity (CET - exports and local sales)',
    default=sigma_X)

## Elasticity of international demand for exported commodity i
m.sigma_XD = Param(
    m.commodities,
    doc='Price elasticity of the world demand for exports of product i',
    default=commodity_indexed_params['sigma_XD'].to_dict())

## LES parameters
m.frisch = Param(
    m.households,
    doc='Frisch parameter (LES function)',
    default=agent_indexed_params['frisch'][:4].to_dict())

## Price elasticity (should be set equal to one when verifying model homogeneity)
m.eta = Param(
    doc='Price elasticity of indexed transfers and parameters',
    default=1)

## Distribution shares (raw SAM values, not yet normalised)
lambda_RK_unweighted = sam(['AG', 0, 'K', 1], over=[agents, capital])
lambda_WL_unweighted = sam(['AG', 0, 'L', 1], over=[households, labour])




### 4.1 Tax components

In [None]:
## Get taxes and margins: each variable is initialised from the matching SAM cells
m.TDF = Var(
    m.firms,
    doc='Income taxes of type f businesses',
    initialize=sam(['AG', 'TD', 'AG', 0], over=[firms]))
m.TDH = Var(
    m.households,
    doc='Income taxes of type h households',
    initialize=sam(['AG', 'TD', 'AG', 0], over=[households]))
m.TIC = Var(
    m.commodities,
    doc='Government revenue from indirect taxes on product i',
    initialize=sam(['AG', 'TI', 'I', 0], over=[commodities]))
m.TIM = Var(
    m.commodities,
    doc='Government revenue from import duties on product i',
    initialize=sam(['AG', 'TM', 'I', 0], over=[commodities]))
m.TIX = Var(
    m.commodities_ex_adm,
    doc='Government revenue from export taxes on product i',
    initialize=sam(['AG', 'GVT', 'X', 0], over=[commodities_ex_adm]))
m.TIW = Var(
    m.labour, m.industries,
    doc='Government revenue from payroll taxes on type l labor in industry j',
    initialize=sam(['AG', 0, 'J', 1], over=[labour, industries]))
m.TIK = Var(
    m.capital, m.industries,
    doc='Government revenue from taxes on type k capital used by industry j',
    initialize=sam(['AG', 0, 'J', 1], over=[capital, industries]))
m.TIP = Var(
    m.industries,
    doc='Government revenue from taxes on industry j production (excluding taxes directly related to the use of capital and labor)',
    initialize=sam(['AG', 'GVT', 'J', 0], over=[industries]))
   

### 4.2 Prices

In [None]:
# note! structure here is income/row-oriented: TO (AG, CAP) FROM (AG, FIRM), TO row FROM column

# Let's get prices first: everything in this table starts at one
prices = {
    'PL': {i: 1.0 for i in commodities},
    'PE': {i: 1.0 for i in commodities},
    'e': 1,
    'PWM': {i: 1.0 for i in commodities},
    'W': {l: 1.0 for l in labour},
    'RK': {k: 1.0 for k in capital},
    'R': {kj: 1.0 for kj in val(m.TIK).keys()},
}

## Get some nominal variables
# Demand for domestic commodity i
DS_nominal = sam(['J', 0, 'I', 1], over=[industries, commodities])
DD_nominal = {
    i: sum(DS_nominal[(j, i)] for j in industries)
    for i in commodities
}
# Imports of commodity i
IM_nominal = sam(['AG', 'ROW', 'I', 0], over=[commodities])
# Exports of commodity i
EX_nominal = sam(['J', 0, 'X', 1], over=[industries, commodities_ex_adm])
# World demand for exports of product i
EXD_nominal = sam(['X', 0, 'AG', 'ROW'], over=[commodities_ex_adm])
# Intermediate consumption of commodity i by industry j
DI_nominal = sam(['I', 0, 'J', 1], over=[commodities, industries])
# Demand for type k capital by industry j
KD_nominal = sam(['K', 0, 'J', 1], over=[capital, industries])
# Demand for type l labor by industry j
LD_nominal = sam(['L', 0, 'J', 1], over=[labour, industries])
# Rate of margin i applied to commodity i
tmrg_nominal = sam(['I', 0, 'I', 1], over=[commodities, commodities])


## Consumer prices: (supply + margins + indirect taxes + import duties) over supply
prices['PC'] = {
    i: (
        DD_nominal[i] + IM_nominal[i]
        + sum(tmrg_nominal[i2, i] for i2 in commodities)
        + val(m.TIC, i) + val(m.TIM, i)
    ) / (DD_nominal[i] + IM_nominal[i])
    for i in commodities
}

## get some tax variables
# export tax rate
ttix = {
    i: val(m.TIX, i) / (sam(['X', i, 'AG', 'ROW']) - val(m.TIX, i))
    for i in commodities_ex_adm
}
# payroll tax rate; zero where there is no labour demand
ttiw = {
    lj: (val(m.TIW, lj[0], lj[1]) / LD_nominal[lj] if LD_nominal[lj] else 0)
    for lj in val(m.TIW).keys()
}
# capital tax rate; zero where there is no capital demand
ttik = {
    kj: (val(m.TIK, kj[0], kj[1]) / KD_nominal[kj] if KD_nominal[kj] else 0)
    for kj in val(m.TIK).keys()
}


## recalculate margins: SAM margin cells read again with the PC price table
tmrg_volume = sam(['I', 0, 'I', 1], over=[commodities, commodities], prices=prices['PC'])
tmrg_X_volume = sam(['I', 0, 'X', 1], over=[commodities, commodities_ex_adm], prices=prices['PC'])

## recalculate nominal variables in volumes:
DD_volume = {i: DD_nominal[i] / prices['PL'][i] for i in commodities}
IM_volume = {
    i: IM_nominal[i] / (prices['PWM'][i] * prices['e'])
    for i in commodities
}
DI_volume = {
    ij: DI_nominal[ij] / prices['PC'][ij[0]]
    for ij in DI_nominal.keys()
}
# Exports of commodity i
EX_volume = sam(['J', 0, 'X', 1], over=[industries, commodities_ex_adm], prices=prices['PE'])
DS_volume = {
    ji: DS_nominal[ji] / prices['PL'][ji[1]]
    for ji in DS_nominal.keys()
}
# output = domestic supply plus exports where an export entry exists
XS_volume = {
    ji: (DS_volume[ji] + EX_volume[ji] if ji in EX_volume.keys() else DS_volume[ji])
    for ji in DS_nominal.keys()
}
XST_volume = {
    j: sum(XS_volume[(j, i)] for i in commodities)
    for j in industries
}
CI_volume = {
    j: sum(DI_volume[(i, j)] for i in commodities)
    for j in industries
}
LD_volume = {
    lj: LD_nominal[lj] / prices['W'][lj[0]]
    for lj in LD_nominal.keys()
}
KD_volume = {
    kj: KD_nominal[kj] / prices['R'][kj]
    for kj in KD_nominal.keys()
}
LDC_volume = {
    j: sum(LD_volume[(l, j)] for l in labour)
    for j in industries
}
KDC_volume = {
    j: sum(KD_volume[(k, j)] for k in capital)
    for j in industries
}
VA_volume = {j: LDC_volume[j] + KDC_volume[j] for j in industries}

## hacks: copy of EX_volume with an explicit zero for every (industry, 'ADM') pair
EX_volume_safe = dict(EX_volume)
for j in industries:
    EX_volume_safe[(j, 'ADM')] = 0

## get another tax variable: import duty rate, zero where nothing is imported
ttim = {
    i: (
        val(m.TIM, i) / (prices['e'] * prices['PWM'][i] * IM_volume[i])
        if IM_volume[i] > 0 else 0
    )
    for i in commodities
}

## Weigh margins: margin volume per unit of (domestic demand + imports)
tmrg_weighted = {
    key: tmrg_volume[key] / (DD_volume[key[1]] + IM_volume[key[1]])
    for key in tmrg_volume.keys()
}

## export margins per unit of total exports of the commodity
tmrg_X_weighted = {
    key: tmrg_X_volume[key] / sum(EX_volume[(j, key[1])] for j in industries)
    for key in tmrg_X_volume.keys()
}

## Get another tax variable: indirect tax rate on commodities
ttic = {
    i: val(m.TIC, i) / (
        (prices['PL'][i]
         + sum(prices['PC'][i2] * tmrg_weighted[(i2, i)] for i2 in commodities)) * DD_volume[i]
        + (prices['e'] * prices['PWM'][i]
           + sum(prices['PC'][i2] * tmrg_weighted[(i2, i)] for i2 in commodities)) * IM_volume[i]
        + val(m.TIM, i)
    )
    for i in commodities
}


## Additional price calculations
prices['PD'] = {
    i: (
        prices['PL'][i]
        + sum(prices['PC'][i2] * tmrg_weighted[(i2, i)] for i2 in commodities)
    ) * (1 + ttic[i])
    for i in commodities
}

## import prices:
prices['PM'] = {
    i: (
        (1 + ttim[i]) * prices['e'] * prices['PWM'][i]
        + sum(prices['PC'][i2] * tmrg_weighted[(i2, i)] for i2 in commodities)
    ) * (1 + ttic[i])
    for i in commodities
}

## FOB export prices; 'ADM' is not in commodities_ex_adm and is set to one by hand
prices['PE_FOB'] = {
    i: (1 + ttix[i]) * (
        prices['PE'][i]
        + sum(prices['PC'][i2] * tmrg_X_weighted[(i2, i)] for i2 in commodities_ex_adm)
    )
    for i in commodities_ex_adm
}
prices['PE_FOB']['ADM'] = 1.0

prices['PWX'] = {i: prices['PE_FOB'][i] / prices['e'] for i in commodities}

## value of local sales plus exports, per unit of output where output is non-zero
prices['P'] = {
    ji: (
        (prices['PL'][ji[1]] * DS_volume[ji] + prices['PE'][ji[1]] * EX_volume_safe[ji]) / XS_volume[ji]
        if XS_volume[ji]
        else (prices['PL'][ji[1]] * DS_volume[ji] + prices['PE'][ji[1]] * EX_volume_safe[ji])
    )
    for ji in DS_nominal.keys()
}

prices['PT'] = {
    j: sum(prices['P'][(j, i)] * XS_volume[(j, i)] for i in commodities) / XST_volume[j]
    for j in industries
}

prices['PCI'] = {
    j: sum(prices['PC'][i] * DI_volume[(i, j)] for i in commodities) / CI_volume[j]
    for j in industries
}

prices['WTI'] = {
    lj: prices['W'][lj[0]] * (1 + ttiw[lj])
    for lj in val(m.TIW).keys()
}

prices['RTI'] = {
    kj: prices['R'][kj] * (1 + ttik[kj])
    for kj in val(m.TIK).keys()
}

## composite wage; zero where the industry has no composite labour
prices['WC'] = {
    j: (
        sum(prices['WTI'][(l, j)] * LD_volume[(l, j)] for l in labour) / LDC_volume[j]
        if LDC_volume[j] else 0
    )
    for j in industries
}

## composite rental rate; zero where the industry has no composite capital
prices['RC'] = {
    j: (
        sum(prices['RTI'][(k, j)] * KD_volume[(k, j)] for k in capital) / KDC_volume[j]
        if KDC_volume[j] else 0
    )
    for j in industries
}

prices['PVA'] = {
    j: (prices['WC'][j] * LDC_volume[j] + prices['RC'][j] * KDC_volume[j]) / VA_volume[j]
    for j in industries
}

## trailing tax variable: production tax rate
ttip = {
    j: val(m.TIP, j) / (
        prices['PVA'][j] * VA_volume[j]
        + sum(prices['PC'][i] * DI_volume[(i, j)] for i in commodities)
    )
    for j in industries
}

## last non-index price
prices['PP'] = {j: prices['PT'][j] / (1 + ttip[j]) for j in industries}

#  PIXGDPO is tautologically equal to 1, based on its formula
#  PIXGDPO         = {SUM[j,{(PVAO(j)*VAO(j)+TIPO(j))/VAO(j)}*VAO(j)]
#                    /SUM[j,{(PVAO(j)*VAO(j)+TIPO(j))/VAO(j)}*VAO(j)]
#                    *SUM[j,{(PVAO(j)*VAO(j)+TIPO(j))/VAO(j)}*VAO(j)]
#                    /SUM[j,{(PVAO(j)*VAO(j)+TIPO(j))/VAO(j)}*VAO(j)]}**0.5;
prices['PIXGDP'] = 1.0

#  PIXCONO is tautologically equal to 1, based on its formula
#  PIXCONO         = SUM[i,PCO(i)*SUM[h,CO(i,h)]]/SUM[i,PCO(i)*SUM[h,CO(i,h)]];
prices['PIXCON'] = 1.0

#  PIXGVTO is tautologically equal to 1, based on its formula
#  PIXGVTO         = PROD[i$gamma_GVT(i),(PCO(i)/PCO(i))**gamma_GVT(i)];
prices['PIXGVT'] = 1.0

#  PIXINVO is tautologically equal to 1, based on its formula
#  PIXINVO         = PROD[i$gamma_INV(i),(PCO(i)/PCO(i))**gamma_INV(i)];
prices['PIXINV'] = 1.0


## Assign price variables: each Var starts from the entry of the same name in `prices`
m.e = Var(
    doc='Exchange rate (price of foreign currency in local currency)',
    initialize=prices['e'])
m.P = Var(
    m.industries, m.commodities,
    doc='Basic price of industry j\'s production of commodity i',
    initialize=prices['P'])
m.PC = Var(
    m.commodities,
    doc='Purchaser price of composite comodity i (including all taxes and margins)',
    initialize=prices['PC'])
m.PCI = Var(
    m.industries,
    doc='Intermediate consumption price index of industry j',
    initialize=prices['PCI'])
m.PD = Var(
    m.commodities,
    doc='Price of local product i sold on the domestic market (including all taxes and margins)',
    initialize=prices['PD'])
m.PE = Var(
    m.commodities,
    doc='Price received for exported commodity i (excluding export taxes)',
    initialize=prices['PE'])
m.PE_FOB = Var(
    m.commodities,
    doc='FOB price of exported commodity i (in local currency)',
    initialize=prices['PE_FOB'])
m.PL = Var(
    m.commodities,
    doc='Price of local product i (excluding all taxes on products)',
    initialize=prices['PL'])
m.PM = Var(
    m.commodities,
    doc='Price of imported product i (including all taxes and tariffs)',
    initialize=prices['PM'])
m.PT = Var(
    m.industries,
    doc='Basic price of industry j\'s output',
    initialize=prices['PT'])
m.PP = Var(
    m.industries,
    doc='Industry j unit cost including taxes directly related to the use of capital and labor but excluding other taxes on production',
    initialize=prices['PP'])
m.PVA = Var(
    m.industries,
    doc='Price of industry j value added (including taxes on production directly related to the use of capital and labor)',
    initialize=prices['PVA'])
m.PWM = Var(
    m.commodities,
    doc='World price of imported product i (expressed in foreign currency)',
    initialize=prices['PWM'])
m.PWX = Var(
    m.commodities,
    doc='World price of exported product i (expressed in foreign currency)',
    initialize=prices['PWX'])
m.R = Var(
    m.capital, m.industries,
    doc='Rental rate of type k capital in industry j',
    initialize=prices['R'])
m.RC = Var(
    m.industries,
    doc='Rental rate of industry j composite capital',
    initialize=prices['RC'])
m.RK = Var(
    m.capital,
    doc='Rental rate of type k capital (if capital is mobile)',
    initialize=prices['RK'])
m.RTI = Var(
    m.capital, m.industries,
    doc='Rental rate paid by industry j for type k capital including capital taxes',
    initialize=prices['RTI'])
m.W = Var(
    m.labour,
    doc='Wage rate of type l labor',
    initialize=prices['W'])
m.WC = Var(
    m.industries,
    doc='Wage rate of industry j composite labor',
    initialize=prices['WC'])
m.WTI = Var(
    m.labour, m.industries,
    doc='Wage rate paid by industry j for type l labor including payroll taxes',
    initialize=prices['WTI'])

## Price indexes
m.PIXCON = Var(doc='Consumer price index', initialize=prices['PIXCON'])
m.PIXGDP = Var(doc='GDP deflator', initialize=prices['PIXGDP'])
m.PIXGVT = Var(doc='Public expenditures price index', initialize=prices['PIXGVT'])
m.PIXINV = Var(doc='Investment price index', initialize=prices['PIXINV'])

 


### 4.3 Volume Variables

In [None]:
## Volume variables: initial values come from the SAM (via the PC price table) or
## from the *_volume tables computed in the prices section
m.C = Var(
    m.commodities, m.households,
    doc='Consumption of commodity i by type h households',
    initialize=sam(['I', 0, 'AG', 1], over=[commodities, households], prices=prices['PC']))
m.CG = Var(
    m.commodities,
    doc='Public consumption of commodity i',
    initialize=sam(['I', 0, 'AG', 'GVT'], over=[commodities], prices=prices['PC']))
m.CI = Var(
    m.industries,
    doc='Total intermediate consumption of industry j',
    initialize={
        j: sum(DI_volume[(i, j)] for i in commodities)
        for j in industries
    })
m.DD = Var(
    m.commodities,
    doc='Domestic demand for commodity i produced locally',
    initialize=DD_volume)
m.DI = Var(
    m.commodities, m.industries,
    doc='Intermediate consumption of commodity i by industry j',
    initialize=DI_volume)
m.DIT = Var(
    m.commodities,
    doc='Total intermediate demand for commodity i',
    initialize={
        i: sum(val(m.DI, i, j) for j in industries)
        for i in commodities
    })
m.DS = Var(
    m.industries, m.commodities,
    doc='Supply of commodity i by sector j to the domestic market',
    initialize=DS_volume)
m.EX = Var(
    m.industries, m.commodities_ex_adm,
    doc='Quantity of product i exported by sector j',
    initialize=EX_volume)
m.EXD = Var(
    m.commodities_ex_adm,
    doc='World demand for exports of product i',
    initialize={
        i: EXD_nominal[i] / (prices['PWX'][i] * prices['e'])
        for i in commodities_ex_adm
    })

m.IM = Var(
    m.commodities,
    doc='Quantity of product i imported',
    initialize=IM_volume)
m.INV = Var(
    m.commodities,
    doc='Final demand of commodity i for investment purposes (GFCF)',
    initialize=sam(['I', 0, 'OTH', 'INV'], over=[commodities], prices=prices['PC']))
m.KD = Var(
    m.capital, m.industries,
    doc='Demand for type k capital by industry j',
    initialize=KD_volume)
m.KDC = Var(
    m.industries,
    doc='Industry j demand for composite capital',
    initialize=KDC_volume)
m.KS = Var(
    m.capital,
    doc='Supply of type k capital',
    initialize={
        k: sum(val(m.KD, k, j) for j in industries)
        for k in capital
    })
m.LD = Var(
    m.labour, m.industries,
    doc='Demand for type l labor by industry j',
    initialize=LD_volume)
m.LDC = Var(
    m.industries,
    doc='Industry j demand for composite labor',
    initialize=LDC_volume)
m.LS = Var(
    m.labour,
    doc='Supply of type l labor',
    initialize={
        l: sum(val(m.LD, l, j) for j in industries)
        for l in labour
    })
## margin demand: margins on domestic sales, on imports and on exports
m.MRGN = Var(
    m.commodities,
    doc='Demand for commodity i as a trade or transport margin',
    initialize={
        i: (
            sum(tmrg_weighted[(i, i2)] * DD_volume[i2] for i2 in commodities)
            + sum(tmrg_weighted[(i, i2)] * IM_volume[i2] for i2 in commodities)
            + sum(
                tmrg_X_weighted[(i, i2)] * EX_volume[(j, i2)]
                for i2 in commodities_ex_adm
                for j in industries
            )
        )
        for i in commodities
    })
m.Q = Var(
    m.commodities,
    doc='Quantity demanded of composite commodity i',
    initialize={
        i: (prices['PM'][i] * IM_volume[i] + prices['PD'][i] * DD_volume[i]) / prices['PC'][i]
        for i in commodities
    })
m.VA = Var(
    m.industries,
    doc='Value added of industry j',
    initialize=VA_volume)
m.VSTK = Var(
    m.commodities,
    doc='Inventory change of commodity i',
    initialize=sam(['I', 0, 'OTH', 'VSTK'], over=[commodities], prices=prices['PC']))
m.XS = Var(
    m.industries, m.commodities,
    doc='Industry j production of commodity i',
    initialize=XS_volume)
m.XST = Var(
    m.industries,
    doc='Total aggregate output of industry j',
    initialize=XST_volume)

### 4.4 Nominal (value) Variables

In [None]:
## Transfers
m.TR = Var(
    m.agents, m.agents,
    doc='Transfers from agent ag[1] to agent ag[0]',
    initialize=sam(['AG', 0, 'AG', 1], over=[agents, agents]))

## Households income and savings
m.YHK = Var(
    m.households,
    doc='Capital income of type h households',
    initialize={
        h: sum(lambda_RK_unweighted[(h, k)] for k in capital)
        for h in households
    })
m.YHL = Var(
    m.households,
    doc='Labor income of type h households',
    initialize={
        h: sum(lambda_WL_unweighted[(h, l)] for l in labour)
        for h in households
    })
m.YHTR = Var(
    m.households,
    doc='Transfer income of type h households',
    initialize={
        h: sum(val(m.TR, h, ag) for ag in agents)
        for h in households
    })
m.YH = Var(
    m.households,
    doc='Total income of type h households',
    initialize={
        h: val(m.YHK, h) + val(m.YHL, h) + val(m.YHTR, h)
        for h in households
    })
m.YDH = Var(
    m.households,
    doc='Disposable income of type h households',
    initialize={
        h: val(m.YH, h) - val(m.TDH, h) - val(m.TR, 'GVT', h)
        for h in households
    })
m.SH = Var(
    m.households,
    doc='Savings of type h households',
    initialize=sam(['OTH', 'INV', 'AG', 0], over=[households]))
m.CTH = Var(
    m.households,
    doc='Consumption budget of type h households',
    initialize={
        h: val(m.YDH, h) - val(m.SH, h) - sum(val(m.TR, agng, h) for agng in agents_nongov)
        for h in households
    })

## Firms income and savings
m.YFK = Var(
    m.firms,
    doc='Capital income of type f businesses',
    initialize={
        f: sum(lambda_RK_unweighted[(f, k)] for k in capital)
        for f in firms
    })
m.YFTR = Var(
    m.firms,
    doc='Transfer income of type f businesses',
    initialize={
        f: sum(val(m.TR, f, ag) for ag in agents)
        for f in firms
    })
m.YF = Var(
    m.firms,
    doc='Total income of type f businesses',
    initialize={f: val(m.YFK, f) + val(m.YFTR, f) for f in firms})
m.YDF = Var(
    m.firms,
    doc='Disposable income of type f businesses',
    initialize={f: val(m.YF, f) - val(m.TDF, f) for f in firms})
m.SF = Var(
    m.firms,
    doc='Savings of type f businesses',
    initialize=sam(['OTH', 'INV', 'AG', 0], over=[firms]))

## Government income and savings
m.YGK = Var(
    doc='Government capital income',
    initialize=sum(lambda_RK_unweighted[('GVT', k)] for k in capital))
m.YGTR = Var(
    doc='Government transfer income',
    initialize=sum(val(m.TR, 'GVT', ag) for ag in agents))
m.TDHT = Var(
    doc='Total government revenue from household income taxes',
    initialize=sum(val(m.TDH, h) for h in households))
m.TDFT = Var(
    doc='Total government revenue from business income taxes',
    initialize=sum(val(m.TDF, f) for f in firms))
m.TICT = Var(
    doc='Total government receipts of indirect taxes on commodities',
    initialize=sum(val(m.TIC, i) for i in commodities))
m.TIMT = Var(
    doc='Total government revenue from import duties',
    initialize=sum(val(m.TIM, i) for i in commodities))
m.TIXT = Var(
    doc='Total government revenue from export taxes',
    initialize=sum(val(m.TIX, i) for i in commodities_ex_adm))
m.TIWT = Var(
    doc='Total government revenue from payroll taxes',
    initialize=sum(val(m.TIW, l, j) for l in labour for j in industries))
m.TIKT = Var(
    doc='Total government revenue from from taxes on capital',
    initialize=sum(val(m.TIK, k, j) for k in capital for j in industries))
m.TIPT = Var(
    doc='Total government revenue from production taxes (excluding taxes directly related to the use of capital and labor)',
    initialize=sum(val(m.TIP, j) for j in industries))
m.TPRODN = Var(
    doc='Total government revenue from other taxes on production',
    initialize=val(m.TIKT) + val(m.TIWT) + val(m.TIPT))
m.TPRCTS = Var(
    doc='Total government revenue from taxes on products and imports',
    initialize=val(m.TICT) + val(m.TIMT) + val(m.TIXT))
m.YG = Var(
    doc='Total government income',
    initialize=(
        val(m.YGK) + val(m.TDHT) + val(m.TDFT)
        + val(m.TPRODN) + val(m.TPRCTS) + val(m.YGTR)
    ))
m.SG = Var(
    doc='Government savings',
    initialize=sam(['OTH', 'INV', 'AG', 'GVT']))
m.G = Var(
    doc='Current government expenditures on goods and services',
    initialize=sum(prices['PC'][i] * val(m.CG, i) for i in commodities))

## Rest of the world
m.YROW = Var(
    doc='Rest-of-the-world income',
    initialize=(
        sum(IM_nominal[i] for i in commodities)
        + sum(lambda_RK_unweighted[('ROW', k)] for k in capital)
        + sum(val(m.TR, 'ROW', ag) for ag in agents)
    ))
m.SROW = Var(
    doc='Rest-of-the-world savings',
    initialize=sam(['OTH', 'INV', 'AG', 'ROW']))
m.CAB = Var(
    doc='Current account balance',
    initialize=-1 * val(m.SROW))


## Investment and capital formation
m.IT = Var(
    doc='Total investment expenditures',
    initialize=(
        sum(val(m.SH, h) for h in households)
        + sum(val(m.SF, f) for f in firms)
        + val(m.SG) + val(m.SROW)
    ))
m.GFCF = Var(
    doc='Gross fixed capital formation',
    initialize=val(m.IT) - sum(prices['PC'][i] * val(m.VSTK, i) for i in commodities))


## GDP measures
m.GDP_BP = Var(
    doc='GDP at basic prices',
    initialize=sum(prices['PVA'][j] * val(m.VA, j) for j in industries) + val(m.TIPT))
m.GDP_MP = Var(
    doc='GDP at market prices',
    initialize=val(m.GDP_BP) + val(m.TPRCTS))
m.GDP_IB = Var(
    doc='GDP at market prices (income-based)',
    initialize=(
        sum(prices['W'][l] * val(m.LD, l, j) for l in labour for j in industries)
        + sum(prices['R'][(k, j)] * val(m.KD, k, j) for k in capital for j in industries)
        + val(m.TPRODN) + val(m.TPRCTS)
    ))
m.GDP_FD = Var(
    doc='GDP at purchasers\' prices from the perspective of final demand',
    initialize=(
        sum(
            prices['PC'][i] * (
                sum(val(m.C, i, h) for h in households)
                + val(m.CG, i) + val(m.INV, i) + val(m.VSTK, i)
            )
            for i in commodities
        )
        + sum(prices['PE_FOB'][i] * val(m.EXD, i) for i in commodities_ex_adm)
        - sum(prices['PWM'][i] * prices['e'] * val(m.IM, i) for i in commodities)
    ))



### 4.5 Remaining parameters

In [None]:

# ** 4.6 Calibration of function parameters

# **  4.6.1 Leontief functions
m.io = Param(
    m.industries,
    doc='Coefficient (Leontief - intermediate consumption)',
    default={j: val(m.CI, j) / val(m.XST, j) for j in industries})
m.v = Param(
    m.industries,
    doc='Coefficient (Leontief - value added)',
    default={j: val(m.VA, j) / val(m.XST, j) for j in industries})
m.aij = Param(
    m.commodities, m.industries,
    doc='Input output coefficient',
    default={
        ij: val(m.DI)[ij] / val(m.CI, ij[1])
        for ij in val(m.DI).keys()
    })

# **  4.6.2 Calibration of CET parameters
# **   4.6.2.1 CET between commodities
m.rho_XT = Param(
    m.industries,
    doc='Elasticity parameter (CET - total output)',
    default={
        j: (1 + val(m.sigma_XT, j)) / val(m.sigma_XT, j)
        for j in industries
    })
# share is zero where the industry does not produce the commodity; non-positive
# outputs are left out of the denominator
m.beta_XT = Param(
    m.industries, m.commodities,
    doc='Share parameter (CET - total output)',
    default={
        ji: (
            prices['P'][ji] * val(m.XS)[ji] ** (1 - val(m.rho_XT, ji[0]))
            / sum(
                0 if val(m.XS, ji[0], i) <= 0
                else prices['P'][(ji[0], i)] * val(m.XS, ji[0], i) ** (1 - val(m.rho_XT, ji[0]))
                for i in commodities
            )
            if val(m.XS)[ji] else 0
        )
        for ji in val(m.XS).keys()
    })
m.B_XT = Param(
    m.industries,
    doc='Scale parameter (CET - total output)',
    default={
        j: val(m.XST, j)
        / sum(
            0 if val(m.XS, j, i) <= 0
            else val(m.beta_XT, j, i) * val(m.XS, j, i) ** (val(m.rho_XT, j))
            for i in commodities
        ) ** (1 / val(m.rho_XT, j))
        for j in industries
    })


# **   4.6.2.2 CET between exports and local production
# rho_X falls back to 1 when either the export or the domestic-supply entry is
# missing or zero
m.rho_X = Param(
    m.industries, m.commodities,
    doc='Elasticity parameter (CET - exports and local sales)',
    default={
        ji: (
            1
            if (
                ji not in val(m.EX).keys()
                or val(m.EX)[ji] == 0
                or ji not in val(m.DS).keys()
                or val(m.DS)[ji] == 0
            )
            else (1 + val(m.sigma_X)[ji]) / val(m.sigma_X)[ji]
        )
        for ji in val(m.XS).keys()
    })
# only (industry, commodity) pairs with positive output get a share / scale entry
m.beta_X = Param(
    m.industries, m.commodities,
    doc='Share parameter (CET - exports and local sales)',
    default={
        ji: (
            prices['PE'][ji[1]] * val(m.EX, ji[0], ji[1], default=0) ** (1 - val(m.rho_X)[ji])
            / (
                prices['PE'][ji[1]] * val(m.EX, ji[0], ji[1], default=0) ** (1 - val(m.rho_X)[ji])
                + prices['PL'][ji[1]] * val(m.DS)[ji] ** (1 - val(m.rho_X)[ji])
            )
        )
        for ji in val(m.XS).keys()
        if val(m.XS)[ji] > 0
    })
m.B_X = Param(
    m.industries, m.commodities,
    doc='Scale parameter (CET - exports and local sales)',
    default={
        ji: val(m.XS)[ji] / (
            val(m.beta_X)[ji] * val(m.EX, ji[0], ji[1], default=0) ** val(m.rho_X)[ji]
            + (1 - val(m.beta_X)[ji]) * val(m.DS)[ji] ** val(m.rho_X)[ji]
        ) ** (1 / val(m.rho_X)[ji])
        for ji in val(m.XS).keys()
        if val(m.XS)[ji] > 0
    })

# **  4.6.3 Calibration of CES parameters
# **   4.6.3.1 Composite good
def _rho_M_calibrated(i):
    """Return (1 - sigma_M) / sigma_M, or -1 when imports or local demand are zero."""
    if val(m.IM, i, default=0) == 0 or val(m.DD, i) == 0:
        return -1
    return (1 - val(m.sigma_M, i)) / val(m.sigma_M, i)


m.rho_M = Param(
    m.commodities,
    doc='Elasticity parameter (CES - composite commodity)',
    default={i: _rho_M_calibrated(i) for i in commodities},
)


def _beta_M_calibrated(i):
    """Return the import share: import term over (import term + local term)."""
    exponent = val(m.rho_M, i) + 1
    import_term = prices['PM'][i] * val(m.IM, i, default=0)**exponent
    local_term = prices['PD'][i] * val(m.DD, i)**exponent
    return import_term / (import_term + local_term)


m.beta_M = Param(
    m.commodities,
    doc='Share parameter (CES - composite commodity)',
    default={i: _beta_M_calibrated(i) for i in commodities},
)


def _B_M_calibrated(i):
    """Return the CES scale Q[i] / aggregate(IM, DD), or 0 when Q[i] is zero."""
    if val(m.Q, i) == 0:
        return 0
    rho = val(m.rho_M, i)
    composite = val(m.Q, i)
    aggregate = (
        val(m.beta_M, i) * val(m.IM, i)**(-1 * rho)
        + (1 - val(m.beta_M, i)) * val(m.DD, i)**(-1 * rho)
    )
    return composite / aggregate**(-1 / rho)


m.B_M = Param(
    m.commodities,
    doc='Scale parameter (CES - composite commodity)',
    default={i: _B_M_calibrated(i) for i in commodities},
)


# **   4.6.3.2 Composite capital
def _rho_KD_calibrated(j):
    """Return (1 - sigma_KD) / sigma_KD, or 0 when industry j has no composite capital."""
    if val(m.KDC, j) == 0:
        return 0
    return (1 - val(m.sigma_KD, j)) / val(m.sigma_KD, j)


m.rho_KD = Param(
    m.industries,
    doc='Elasticity parameter (CES - composite capital)',
    default={j: _rho_KD_calibrated(j) for j in industries},
)


def _beta_KD_calibrated(kj):
    """Return the share of capital type kj[0] in industry kj[1]; 0 when KD[kj] is zero."""
    if val(m.KD)[kj] == 0:
        return 0
    j = kj[1]
    own_term = val(m.RTI)[kj] * val(m.KD)[kj]**(val(m.rho_KD)[j] + 1)
    # Capital types not used by the industry add nothing to the denominator.
    all_terms = [
        0 if val(m.KD, k, j) == 0
        else val(m.RTI, k, j) * val(m.KD, k, j)**(val(m.rho_KD, j) + 1)
        for k in capital
    ]
    return own_term / sum(all_terms)


m.beta_KD = Param(
    m.capital,
    m.industries,
    doc='Share parameter (CES - composite capital)',
    default={kj: _beta_KD_calibrated(kj) for kj in val(m.KD).keys()},
)


def _B_KD_calibrated(j):
    """Return the CES scale KDC[j] / aggregate(KD), or 0 when KDC[j] is zero."""
    if val(m.KDC, j) == 0:
        return 0
    composite = val(m.KDC, j)
    aggregate = sum([
        0 if val(m.KD, k, j) == 0
        else val(m.beta_KD, k, j) * val(m.KD, k, j)**(-1 * val(m.rho_KD, j))
        for k in capital
    ])
    return composite / aggregate**(-1 / val(m.rho_KD, j))


m.B_KD = Param(
    m.industries,
    doc='Scale parameter (CES - composite capital)',
    default={j: _B_KD_calibrated(j) for j in industries},
)

# **   4.6.3.3 Composite labor
# rho_LD[j] = (1 - sigma_LD[j]) / sigma_LD[j]
m.rho_LD = Param(
    m.industries,
    doc='Elasticity parameter (CES - composite labor)',
    default={
        j: (1 - val(m.sigma_LD, j)) / val(m.sigma_LD, j)
        for j in industries
    },
)


def _beta_LD_calibrated(lj):
    """Return the share of labor type lj[0] in industry lj[1]; 0 when LD[lj] is zero."""
    if val(m.LD)[lj] == 0:
        return 0
    j = lj[1]
    own_term = val(m.WTI)[lj] * val(m.LD)[lj]**(val(m.rho_LD)[j] + 1)
    # Labor types not used by the industry add nothing to the denominator.
    all_terms = [
        0 if val(m.LD, l, j) == 0
        else val(m.WTI, l, j) * val(m.LD, l, j)**(val(m.rho_LD, j) + 1)
        for l in labour
    ]
    return own_term / sum(all_terms)


m.beta_LD = Param(
    m.labour,
    m.industries,
    doc='Share parameter (CES - composite labor)',
    default={lj: _beta_LD_calibrated(lj) for lj in val(m.LD).keys()},
)


def _B_LD_calibrated(j):
    """Return the CES scale LDC[j] / aggregate(LD), or 0 when LDC[j] is zero."""
    if val(m.LDC, j) == 0:
        return 0
    composite = val(m.LDC, j)
    aggregate = sum([
        0 if val(m.LD, l, j) == 0
        else val(m.beta_LD, l, j) * val(m.LD, l, j)**(-1 * val(m.rho_LD, j))
        for l in labour
    ])
    return composite / aggregate**(-1 / val(m.rho_LD, j))


m.B_LD = Param(
    m.industries,
    doc='Scale parameter (CES - composite labor)',
    default={j: _B_LD_calibrated(j) for j in industries},
)

# **   4.6.3.4 Value added
# Value-added CES nest (composite labor LDC vs composite capital KDC).
# The doc strings below previously carried labels copied from the labor and
# composite-commodity sections; they now name the value-added nest.
def _rho_VA_calibrated(j):
    """Return (1 - sigma_VA) / sigma_VA, or -1 when KDC[j] or LDC[j] is zero."""
    if val(m.KDC, j) == 0 or val(m.LDC, j) == 0:
        return -1
    return (1 - val(m.sigma_VA, j)) / val(m.sigma_VA, j)


m.rho_VA = Param(
    m.industries,
    doc='Elasticity parameter (CES - value added)',
    default={j: _rho_VA_calibrated(j) for j in industries},
)


def _beta_VA_calibrated(j):
    """Return the labor share: labor term over (labor term + capital term)."""
    exponent = val(m.rho_VA, j) + 1
    labor_numerator = prices['WC'][j] * val(m.LDC)[j]**exponent
    labor_term = prices['WC'][j] * val(m.LDC, j)**exponent
    capital_term = prices['RC'][j] * val(m.KDC, j)**exponent
    return labor_numerator / (labor_term + capital_term)


m.beta_VA = Param(
    m.industries,
    doc='Share parameter (CES - value added)',
    default={j: _beta_VA_calibrated(j) for j in industries},
)


def _B_VA_calibrated(j):
    """Return the CES scale: VA[j] over the CES aggregate of LDC[j] and KDC[j]."""
    rho = val(m.rho_VA, j)
    value_added = val(m.VA, j)
    aggregate = (
        val(m.beta_VA, j) * val(m.LDC, j)**(-1 * rho)
        + (1 - val(m.beta_VA, j)) * val(m.KDC, j)**(-1 * rho)
    )
    return value_added / aggregate**(-1 / rho)


m.B_VA = Param(
    m.industries,
    doc='Scale parameter (CES - value added)',
    default={j: _B_VA_calibrated(j) for j in industries},
)


# *   The assigned income elasticities may not produce consumption shares
# *   that add up to 1, so this first step rescales the elasticities
# *   proportionally.
#sigma_Y(i,h)    = sigma_Y(i,h)*CTHO(h)/SUM[ij,sigma_Y(ij,h)*PCO(ij)*CO(ij,h)];
def _sigma_Y_rescaled(ih):
    """Rescale the raw elasticity sigma_Y[ih] by CTH[h] over the elasticity-weighted spending of h."""
    h = ih[1]
    weighted_spending = sum([
        sigma_Y[(i, h)] * prices['PC'][i] * val(m.C, i, h)
        for i in commodities
    ])
    return sigma_Y[ih] * val(m.CTH, h) / weighted_spending


m.sigma_Y = Param(
    m.commodities,
    m.households,
    doc='Income elasticity of consumption',
    default={ih: _sigma_Y_rescaled(ih) for ih in sigma_Y.keys()},
)

# gamma_LES[i, h] = PC[i] * C[i, h] * sigma_Y[i, h] / CTH[h]
m.gamma_LES = Param(
    m.commodities,
    m.households,
    doc='Marginal share of commodity i in household h consumption budget',
    default={
        ih: prices['PC'][ih[0]] * val(m.C)[ih] * val(m.sigma_Y)[ih] / val(m.CTH, ih[1])
        for ih in val(m.sigma_Y).keys()
    },
)

# ** 4.2 Calibration of investment and government spending shares
def _expenditure_shares(quantity):
    """Return {i: PC[i]*q[i] / sum over i2 of PC[i2]*q[i2]} for the indexed component `quantity`.

    Both EQ54 and EQ55 are written in value terms (PC[i] times volume equals
    share times the total), so the shares are calibrated on price-weighted
    values, not on raw volumes.
    """
    spending = {i: prices['PC'][i] * val(quantity, i) for i in commodities}
    total = sum([spending[i2] for i2 in commodities])
    return {i: spending[i] / total for i in commodities}


# gamma_GVT was previously computed from volumes only, which matches EQ55
# only when every PC equals 1; it is now price-weighted like gamma_INV.
m.gamma_GVT = Param(
    m.commodities,
    doc='Share of commodity i in total current public expenditures on goods and services',
    default=_expenditure_shares(m.CG),
)
m.gamma_INV = Param(
    m.commodities,
    doc='Share of commodity i in total investment expenditures',
    default=_expenditure_shares(m.INV),
)

# Transfer shares: transfer TR[ag1, ag2] as a fraction of ag2's disposable
# income (YDH for households, YDF for firms). Household rows exclude 'GVT'.
lambda_TR = {}
for ag1 in agents:
    for ag2 in agents:
        if ag1 != 'GVT' and ag2 in households:
            lambda_TR[(ag1, ag2)] = val(m.TR, ag1, ag2) / val(m.YDH, ag2)
        if ag2 in firms:
            lambda_TR[(ag1, ag2)] = val(m.TR, ag1, ag2) / val(m.YDF, ag2)


m.lambda_TR = Param(
    m.agents,
    m.agents,
    doc='Share parameter (transfer functions)',
    default=lambda_TR,
)
m.tmrg = Param(
    m.commodities,
    m.commodities,
    doc='Rate of margin i applied to commodity ij',
    default=tmrg_weighted,
)
m.tmrg_X = Param(
    m.commodities,
    m.commodities_ex_adm,
    doc='Rate of margin i applied to exported commodity i',
    default=tmrg_X_weighted,
)

# kmob: mutable flag, 1 if capital is mobile.
m.kmob = Param(
    doc='Flag parameter (1 if capital is mobile)',
    default=1,
    mutable=True,
)


### 4.6 Real values

In [None]:

## Some volume variables which are dependent on nominal variables
def _CMIN_initial(ih):
    """Return C[ih] + gamma_LES[ih] * CTH[h] / (PC[i] * frisch[h]) for ih = (i, h)."""
    i, h = ih[0], ih[1]
    return val(m.C)[ih] + val(m.gamma_LES)[ih] * (
        val(m.CTH, h) / (prices['PC'][i] * val(m.frisch, h))
    )


m.CMIN = Var(
    m.commodities,
    m.households,
    doc='Minimum consumption of commodity i by type h households',
    initialize={ih: _CMIN_initial(ih) for ih in sigma_Y.keys()},
)

# Each real aggregate starts at its nominal value divided by a price index.
m.CTH_REAL = Var(
    m.households,
    doc='Real consumption budget of type h households',
    initialize={h: val(m.CTH, h) / val(m.PIXCON) for h in households},
)

m.G_REAL = Var(
    doc='Real current government expenditures on goods and services',
    initialize=val(m.G) / val(m.PIXGVT),
)

m.GDP_BP_REAL = Var(
    doc='Real GDP at basic prices',
    initialize=val(m.GDP_BP) / val(m.PIXGDP),
)

m.GDP_MP_REAL = Var(
    doc='Real GDP at market prices',
    initialize=val(m.GDP_MP) / val(m.PIXCON),
)

m.GFCF_REAL = Var(
    doc='Real gross fixed capital formation',
    initialize=val(m.GFCF) / val(m.PIXINV),
)



### 4.7 Intercepts and rates

In [None]:
## Intercepts and rates

# ** Intercepts of transfers, direct taxes and savings
# *  One can either choose to assign a value to the intercept and calibrate
# *  the slopes accordingly, or the other way around. This type of modelling
# *  can be useful to take into account known marginal savings or taxation rates
# *  or to deal with negative average saving rates in cases where savings are
# *  negative for some household groups.
# *  When no further information is available, one can simply set the intercepts
# *  to zero and calibrate an average rate: this is what we do here.
# Each intercept is read from agent_indexed_params; the matching slope is the
# part of the benchmark flow not explained by the intercept, per unit of income.
m.sh0 = Var(
    m.households,
    doc='Intercept (type h household savings)',
    initialize=agent_indexed_params['sh0O'][:4].to_dict(),
)
m.sh1 = Var(
    m.households,
    doc='Slope (type h household savings)',
    initialize={
        h: (val(m.SH, h) - val(m.sh0, h)) / val(m.YDH, h)
        for h in households
    },
)

m.tr0 = Var(
    m.households,
    doc='Intercept (transfers by type h households to government)',
    initialize=agent_indexed_params['tr0O'][:4].to_dict(),
)
m.tr1 = Var(
    m.households,
    doc='Marginal rate of transfers by type h households to government',
    initialize={
        h: (val(m.TR, 'GVT', h) - val(m.tr0, h)) / val(m.YH, h)
        for h in households
    },
)


m.ttdf0 = Var(
    m.firms,
    doc='Intercept (income taxes of type f businesses)',
    initialize={'FIRM': agent_indexed_params['ttdf0O'][4]},
)
m.ttdf1 = Var(
    m.firms,
    doc='Marginal income tax rate of type f businesses',
    initialize={
        f: (val(m.TDF, f) - val(m.ttdf0, f)) / val(m.YFK, f)
        for f in firms
    },
)

m.ttdh0 = Var(
    m.households,
    doc='Intercept (income taxes of type h households)',
    initialize=agent_indexed_params['ttdh0O'][:4].to_dict(),
)
m.ttdh1 = Var(
    m.households,
    doc='Marginal income tax rate of type h households',
    initialize={
        h: (val(m.TDH, h) - val(m.ttdh0, h)) / val(m.YH, h)
        for h in households
    },
)


# Indirect tax rates, initialised from the rate tables of the same name.
m.ttic = Var(
    m.commodities,
    doc='Tax rate on commodity i',
    initialize=ttic,
)
m.ttik = Var(
    m.capital,
    m.industries,
    doc='Tax rate on type k capital used in industry j',
    initialize=ttik,
)
m.ttim = Var(
    m.commodities,
    doc='Rate of taxes and duties on imports of commodity i',
    initialize=ttim,
)
m.ttip = Var(
    m.industries,
    doc='Tax rate on the production of industry j',
    initialize=ttip,
)
m.ttiw = Var(
    m.labour,
    m.industries,
    doc='Tax rate on type l worker compensation in industry j',
    initialize=ttiw,
)
m.ttix = Var(
    m.commodities_ex_adm,
    doc='Export tax rate on exported commodity i',
    initialize=ttix,
)

#distribution parameters
# Distribution shares: each unweighted entry is divided by the total demand
# for that factor type summed over industries.
m.lambda_RK = Param(
    m.agents,
    m.capital,
    doc='Share of type k capital income received by agent ag',
    default={
        agk: lambda_RK_unweighted[agk] / sum(val(m.KD, agk[1], j) for j in industries)
        for agk in lambda_RK_unweighted.keys()
    },
)
m.lambda_WL = Param(
    m.households,
    m.labour,
    doc='Share of type l labor income received by type h households',
    default={
        hl: lambda_WL_unweighted[hl] / sum(val(m.LD, hl[1], j) for j in industries)
        for hl in lambda_WL_unweighted.keys()
    },
)

m.LEON = Var(
    doc='Excess supply on the last market',
    initialize=0,
)


## 5 Equations

### 5.1 Production

In [None]:
# EQ1(j)          Value added demand in industry j (Leontief)
def EQ1(m, j):
    """Set value added VA[j] equal to the coefficient v[j] times total output XST[j]."""
    leontief_va = m.v[j] * m.XST[j]
    return m.VA[j] == leontief_va


m.EQ1 = Constraint(
    m.industries,
    rule=EQ1,
    doc='Value added demand in industry j (Leontief)',
)

# constraints: 4, vars: 8, 24

#  EQ2(j)          Total intermediate consumption demand in industry j (Leontief)
def EQ2(m, j):
    """Set intermediate consumption CI[j] equal to the coefficient io[j] times XST[j]."""
    leontief_ci = m.io[j] * m.XST[j]
    return m.CI[j] == leontief_ci


m.EQ2 = Constraint(
    m.industries,
    rule=EQ2,
    doc='Total intermediate consumption demand in industry j (Leontief)',
)

# constraints: 4, vars: 12, 28

#  EQ3(j)          CES between composite labor and capital
def EQ3(m, j):
    """Express VA[j] as the CES aggregate of composite labor LDC[j] and capital KDC[j]."""
    rho = m.rho_VA[j]
    labor_term = m.beta_VA[j] * m.LDC[j]**(-1 * rho)
    capital_term = (1 - m.beta_VA[j]) * m.KDC[j]**(-1 * rho)
    return m.VA[j] == m.B_VA[j] * (labor_term + capital_term)**(-1 / rho)


m.EQ3 = Constraint(
    m.industries,
    rule=EQ3,
    doc='CES between of composite labor and capital',
)


#  EQ4(j)          Relative demand for composite labor and capital by industry j(CES)
#  Constraint only added for sectors with valid initial values
def EQ4(m, j):
    """Relate LDC[j] to KDC[j] through the RC/WC price ratio; skipped if either is initially zero."""
    if val(m.KDC, j) == 0 or val(m.LDC, j) == 0:
        return Constraint.Skip
    share_ratio = m.beta_VA[j] / (1 - m.beta_VA[j])
    price_ratio = m.RC[j] / m.WC[j]
    return m.LDC[j] == (share_ratio * price_ratio)**m.sigma_VA[j] * m.KDC[j]


m.EQ4 = Constraint(
    m.industries,
    rule=EQ4,
    doc='Relative demand for composite labor and capital by industry j(CES)',
)

#  EQ5(j)          CES between labor categories
def EQ5(m, j):
    """Express LDC[j] as the CES aggregate of the labor types with positive initial LD."""
    if val(m.LDC, j) == 0:
        return Constraint.Skip
    rho = m.rho_LD[j]
    aggregate = sum(
        m.beta_LD[(l, j)] * m.LD[(l, j)]**(-1 * rho)
        for l in m.labour
        if val(m.LD, l, j) > 0
    )
    return m.LDC[j] == m.B_LD[j] * aggregate**(-1 / rho)


m.EQ5 = Constraint(
    m.industries,
    rule=EQ5,
    doc='CES between labor categories',
)


#  EQ6(l,j)        Demand for type l labor by industry j (CES)
def EQ6(m, l, j):
    """Demand for labor type l in industry j; skipped when the initial LD[l, j] is zero."""
    if val(m.LD, l, j) == 0:
        return Constraint.Skip
    sigma = m.sigma_LD[j]
    relative_wage = m.beta_LD[(l, j)] * m.WC[j] / m.WTI[(l, j)]
    return m.LD[(l, j)] == relative_wage**sigma * m.B_LD[j]**(sigma - 1) * m.LDC[j]


m.EQ6 = Constraint(
    m.labour,
    m.industries,
    rule=EQ6,
    doc='Demand for type l labor by industry j (CES)',
)

#  EQ7(j)          CES between capital categories
def EQ7(m, j):
    """Express KDC[j] as the CES aggregate of the capital types with positive initial KD."""
    if not (val(m.KDC, j) > 0):
        return Constraint.Skip
    rho = m.rho_KD[j]
    aggregate = sum(
        m.beta_KD[(k, j)] * m.KD[(k, j)]**(-1 * rho)
        for k in m.capital
        if val(m.KD, k, j) > 0
    )
    return m.KDC[j] == m.B_KD[j] * aggregate**(-1 / rho)


m.EQ7 = Constraint(
    m.industries,
    rule=EQ7,
    doc='CES between capital categories',
)

#  EQ8(k,j)        Demand for type k capital by industry j (CES)
def EQ8(m, k, j):
    """Demand for capital type k in industry j; skipped unless the initial KD[k, j] is positive."""
    if not (val(m.KD, k, j) > 0):
        return Constraint.Skip
    sigma = m.sigma_KD[j]
    relative_rent = m.beta_KD[(k, j)] * m.RC[j] / m.RTI[(k, j)]
    return m.KD[(k, j)] == relative_rent**sigma * m.B_KD[j]**(sigma - 1) * m.KDC[j]


m.EQ8 = Constraint(
    m.capital,
    m.industries,
    rule=EQ8,
    doc='Demand for type k capital by industry j (CES)',
)

#  EQ9(i,j)        Intermediate consumption of commodity i by industry j (Leontief)
def EQ9(m, i, j):
    """Set intermediate use DI[i, j] equal to the coefficient aij[i, j] times CI[j]."""
    pair = (i, j)
    return m.DI[pair] == m.aij[pair] * m.CI[j]


m.EQ9 = Constraint(
    m.commodities,
    m.industries,
    rule=EQ9,
    doc='Intermediate consumption of commodity i by industry j (Leontief)',
)



### 5.2 Income and Savings

In [None]:
## Households
#  EQ10(h)         Total income of type h households
def EQ10(m, h):
    """Define total income YH[h] as labor, capital and transfer income added together."""
    income = m.YHL[h] + m.YHK[h] + m.YHTR[h]
    return m.YH[h] == income


m.EQ10 = Constraint(
    m.households,
    rule=EQ10,
    doc='Total income of type h households',
)

#  EQ11(h)         Labor income of type h households
def EQ11(m, h):
    """Define YHL[h] as household h's share of each labor type's wage bill."""

    def employment(l):
        # Only industries with positive initial LD[l, j] enter the total.
        return sum(m.LD[(l, j)] for j in m.industries if val(m.LD, l, j) > 0)

    return m.YHL[h] == sum(
        m.lambda_WL[(h, l)] * m.W[l] * employment(l) for l in m.labour
    )


m.EQ11 = Constraint(
    m.households,
    rule=EQ11,
    doc='Labor income of type h households',
)

#  EQ12(h)         Capital income of type h households
def EQ12(m, h):
    """Define YHK[h] as household h's share of each capital type's rental income."""

    def rental_income(k):
        # Only industries with positive initial KD[k, j] enter the total.
        return sum(
            m.R[(k, j)] * m.KD[(k, j)]
            for j in m.industries
            if val(m.KD, k, j) > 0
        )

    return m.YHK[h] == sum(
        m.lambda_RK[(h, k)] * rental_income(k) for k in m.capital
    )


m.EQ12 = Constraint(
    m.households,
    rule=EQ12,
    doc='Capital income of type h households',
)

#  EQ13(h)         Transfer income of type h households
def EQ13(m, h):
    """Define YHTR[h] as the sum of TR[h, ag] over every agent ag."""
    transfers_in = sum(m.TR[(h, ag)] for ag in m.agents)
    return m.YHTR[h] == transfers_in


m.EQ13 = Constraint(
    m.households,
    rule=EQ13,
    doc='Transfer income of type h households',
)

#  EQ14(h)         Disposable income of type h households
def EQ14(m, h):
    """Define YDH[h] as YH[h] less income taxes TDH[h] and the transfer TR['GVT', h]."""
    disposable = m.YH[h] - m.TDH[h] - m.TR[('GVT', h)]
    return m.YDH[h] == disposable


m.EQ14 = Constraint(
    m.households,
    rule=EQ14,
    doc='Disposable income of type h households',
)

#  EQ15(h)         Consumption budget of type h households
def EQ15(m, h):
    """Define CTH[h] as YDH[h] less savings SH[h] and transfers to non-government agents."""
    transfers_out = sum(m.TR[(agng, h)] for agng in m.agents_nongov)
    return m.CTH[h] == m.YDH[h] - m.SH[h] - transfers_out


m.EQ15 = Constraint(
    m.households,
    rule=EQ15,
    doc='Consumption budget of type h households',
)


#  EQ16(h)         Savings of type h households
def EQ16(m, h):
    """Define SH[h] as a PIXCON**eta-scaled intercept plus a slope times YDH[h]."""
    intercept = m.PIXCON**m.eta * m.sh0[h]
    slope_part = m.sh1[h] * m.YDH[h]
    return m.SH[h] == intercept + slope_part


m.EQ16 = Constraint(
    m.households,
    rule=EQ16,
    doc='Savings of type h households',
)



In [None]:
### Firms
#  EQ17(f)         Total income of type f businesses
def EQ17(m, f):
    """Define YF[f] as capital income YFK[f] plus transfer income YFTR[f]."""
    income = m.YFK[f] + m.YFTR[f]
    return m.YF[f] == income


m.EQ17 = Constraint(
    m.firms,
    rule=EQ17,
    doc='Total income of type f businesses',
)

#  EQ18(f)         Capital income of type f businesses
def EQ18(m, f):
    """Define YFK[f] as business f's share of each capital type's rental income."""

    def rental_income(k):
        # Only industries with positive initial KD[k, j] enter the total.
        return sum(
            m.R[(k, j)] * m.KD[(k, j)]
            for j in m.industries
            if val(m.KD, k, j) > 0
        )

    return m.YFK[f] == sum(
        m.lambda_RK[(f, k)] * rental_income(k) for k in m.capital
    )


m.EQ18 = Constraint(
    m.firms,
    rule=EQ18,
    doc='Capital income of type f businesses',
)

#  EQ19(f)         Transfer income of type f businesses
def EQ19(m, f):
    """Define YFTR[f] as the sum of TR[f, ag] over every agent ag."""
    transfers_in = sum(m.TR[(f, ag)] for ag in m.agents)
    return m.YFTR[f] == transfers_in


m.EQ19 = Constraint(
    m.firms,
    rule=EQ19,
    doc='Transfer income of type f businesses',
)

#  EQ20(f)         Disposable income of type f businesses
def EQ20(m, f):
    """Define YDF[f] as total income YF[f] less income taxes TDF[f]."""
    after_tax = m.YF[f] - m.TDF[f]
    return m.YDF[f] == after_tax


m.EQ20 = Constraint(
    m.firms,
    rule=EQ20,
    doc='Disposable income of type f businesses',
)

#  EQ21(f)         Savings of type f businesses
def EQ21(m, f):
    """Define SF[f] as YDF[f] less the sum of TR[ag, f] over every agent ag."""
    transfers_out = sum(m.TR[(ag, f)] for ag in m.agents)
    return m.SF[f] == m.YDF[f] - transfers_out


m.EQ21 = Constraint(
    m.firms,
    rule=EQ21,
    doc='Savings of type f businesses',
)


In [None]:
# ## Government
#  EQ22            Total government income
def EQ22(m):
    """Define YG as capital income, tax revenues and transfer income added together."""
    revenue = m.YGK + m.TDHT + m.TDFT + m.TPRODN + m.TPRCTS + m.YGTR
    return m.YG == revenue


m.EQ22 = Constraint(
    rule=EQ22,
    doc='Total government income',
)

#  EQ23            Government capital income
def EQ23(m):
    """Define YGK as the 'GVT' share of each capital type's rental income.

    For every capital type k, rentals R[k, j] * KD[k, j] are summed over the
    industries whose initial KD[k, j] is positive, then weighted by
    lambda_RK['GVT', k].
    """
    return m.YGK == sum(m.lambda_RK[('GVT', k)]*sum(
        m.R[(k,j)]*m.KD[(k,j)]
        for j in m.industries if val(m.KD, k, j) > 0
    ) for k in m.capital)
# The doc string previously repeated EQ22's label ('Total government income').
m.EQ23 = Constraint(rule=EQ23, doc='Government capital income')

#  EQ24            Total government revenue from household income taxes
def EQ24(m):
    """Define TDHT as the sum of TDH[h] over all households."""
    household_taxes = sum(m.TDH[h] for h in m.households)
    return m.TDHT == household_taxes


m.EQ24 = Constraint(
    rule=EQ24,
    doc='Total government revenue from household income taxes',
)

#  EQ25            Total government revenue from business income taxes
def EQ25(m):
    """Define TDFT as the sum of TDF[f] over all firms."""
    business_taxes = sum(m.TDF[f] for f in m.firms)
    return m.TDFT == business_taxes


m.EQ25 = Constraint(
    rule=EQ25,
    doc='Total government revenue from business income taxes',
)

#  EQ26            Total government revenue from other taxes on production
def EQ26(m):
    """Define TPRODN as the sum of the totals TIWT, TIKT and TIPT."""
    production_taxes = m.TIWT + m.TIKT + m.TIPT
    return m.TPRODN == production_taxes


m.EQ26 = Constraint(
    rule=EQ26,
    doc='Total government revenue from other taxes on production',
)

#  EQ27            Total government receipts of indirect taxes on wages
def EQ27(m):
    """Define TIWT as the sum of TIW[l, j] over pairs with positive initial LD[l, j]."""
    payroll_taxes = sum(
        m.TIW[(l, j)]
        for l in m.labour
        for j in m.industries
        if val(m.LD, l, j) > 0
    )
    return m.TIWT == payroll_taxes


m.EQ27 = Constraint(
    rule=EQ27,
    doc='Total government receipts of indirect taxes on wages',
)

#  EQ28            Total government receipts of indirect taxes on capital
def EQ28(m):
    """Define TIKT as the sum of TIK[k, j] over pairs with positive initial KD[k, j]."""
    capital_taxes = sum(
        m.TIK[(k, j)]
        for k in m.capital
        for j in m.industries
        if val(m.KD, k, j) > 0
    )
    return m.TIKT == capital_taxes


m.EQ28 = Constraint(
    rule=EQ28,
    doc='Total government receipts of indirect taxes on capital',
)

#  EQ29            Total government revenue from production taxes
def EQ29(m):
    """Define TIPT as the sum of TIP[j] over all industries."""
    output_taxes = sum(m.TIP[j] for j in m.industries)
    return m.TIPT == output_taxes


m.EQ29 = Constraint(
    rule=EQ29,
    doc='Total government revenue from production taxes',
)

#  EQ30            Total government revenue from taxes on products and imports
def EQ30(m):
    """Define TPRCTS as the sum of the totals TICT, TIMT and TIXT."""
    product_taxes = m.TICT + m.TIMT + m.TIXT
    return m.TPRCTS == product_taxes


m.EQ30 = Constraint(
    rule=EQ30,
    doc='Total government revenue from taxes on products and imports',
)

#  EQ31            Total government receipts of indirect taxes on commodities
def EQ31(m):
    """Define TICT as the sum of TIC[i] over all commodities."""
    commodity_taxes = sum(m.TIC[i] for i in m.commodities)
    return m.TICT == commodity_taxes


m.EQ31 = Constraint(
    rule=EQ31,
    doc='Total government receipts of indirect taxes on commodities',
)

#  EQ32            Total government revenue from import duties
def EQ32(m):
    """Define TIMT as the sum of TIM[i] over commodities with positive initial imports."""
    import_duties = sum(
        m.TIM[i]
        for i in m.commodities_ex_adm
        if val(m.IM, i) > 0
    )
    return m.TIMT == import_duties


m.EQ32 = Constraint(
    rule=EQ32,
    doc='Total government revenue from import duties',
)

#  EQ33            Total government revenue from export taxes
def EQ33(m):
    """Define TIXT as the sum of TIX[i] over commodities with positive initial EXD."""
    export_taxes = sum(
        m.TIX[i]
        for i in m.commodities_ex_adm
        if val(m.EXD, i) > 0
    )
    return m.TIXT == export_taxes


m.EQ33 = Constraint(
    rule=EQ33,
    doc='Total government revenue from export taxes',
)

#  EQ34            Government transfer income
def EQ34(m):
    """Define YGTR as the sum of TR['GVT', agng] over non-government agents."""
    transfers_in = sum(m.TR[('GVT', agng)] for agng in m.agents_nongov)
    return m.YGTR == transfers_in


m.EQ34 = Constraint(
    rule=EQ34,
    doc='Government transfer income',
)

#  EQ35(h)         Income taxes of type h households
def EQ35(m, h):
    """Define TDH[h] as a PIXCON**eta-scaled intercept plus a marginal rate times YH[h]."""
    intercept = m.PIXCON**m.eta * m.ttdh0[h]
    marginal_part = m.ttdh1[h] * m.YH[h]
    return m.TDH[h] == intercept + marginal_part


m.EQ35 = Constraint(
    m.households,
    rule=EQ35,
    doc='Income taxes of type h households',
)

#  EQ36(f)         Income taxes of type f businesses
def EQ36(m, f):
    """Define TDF[f] as a PIXCON**eta-scaled intercept plus a marginal rate times YFK[f]."""
    intercept = m.PIXCON**m.eta * m.ttdf0[f]
    marginal_part = m.ttdf1[f] * m.YFK[f]
    return m.TDF[f] == intercept + marginal_part


m.EQ36 = Constraint(
    m.firms,
    rule=EQ36,
    doc='Income taxes of type f businesses',
)

#  EQ37(l,j)       Government revenue from payroll taxes on type l labor in industry j
def EQ37(m, l, j):
    """Define TIW[l, j] as ttiw * W[l] * LD[l, j]; skipped unless the initial LD is positive."""
    if not (val(m.LD, l, j) > 0):
        return Constraint.Skip
    pair = (l, j)
    return m.TIW[pair] == m.ttiw[pair] * m.W[l] * m.LD[pair]


m.EQ37 = Constraint(
    m.labour,
    m.industries,
    rule=EQ37,
    doc='Government revenue from payroll taxes on type l labor in industry j',
)

#  EQ38(k,j)       Government revenue from taxes on type k capital used by industry j
def EQ38(m, k, j):
    """Define TIK[k, j] as ttik * R[k, j] * KD[k, j]; skipped unless the initial KD is positive."""
    if not (val(m.KD, k, j) > 0):
        return Constraint.Skip
    pair = (k, j)
    return m.TIK[pair] == m.ttik[pair] * m.R[pair] * m.KD[pair]


m.EQ38 = Constraint(
    m.capital,
    m.industries,
    rule=EQ38,
    doc='Government revenue from taxes on type k capital used by industry j',
)

#  EQ39(j)         Government revenue from taxes on industry j production
def EQ39(m, j):
    """Define TIP[j] as the rate ttip[j] applied to PP[j] * XST[j]."""
    revenue = m.ttip[j] * m.PP[j] * m.XST[j]
    return m.TIP[j] == revenue


m.EQ39 = Constraint(
    m.industries,
    rule=EQ39,
    doc='Government revenue from taxes on industry j production',
)

# #  EQ40(i)         Government revenue from indirect taxes on product i
# def EQ40(m,i):
#     return m.TIC[i] == m.ttic[i]*(
#         (m.PL[i] + sum(m.PC[i2]*m.tmrg[(i2,i)] for i2 in m.commodities))*m.DD[i] +
#         ((1+m.ttim[i])*m.e*m.PWM[i] + sum(m.PC[i3]*m.tmrg[(i3,i)] for i3 in m.commodities))*m.IM[i]
#     )
# m.EQ40 = Constraint(m.commodities, rule=EQ40, doc='Government revenue from indirect taxes on product i')


#  EQ40(i)         Government revenue from indirect taxes on product i
def EQ40(m, i):
    """Define TIC[i] as ttic[i] times the taxable value of local and imported sales.

    The local part is (PL[i] + margins) * DD[i]; the imported part is
    ((1 + ttim[i]) * e * PWM[i] + margins) * IM[i], where margins is the sum
    of PC[i2] * tmrg[i2, i] over all commodities. A part enters only when its
    initial volume (DD or IM) is positive.

    Returns Constraint.Skip when neither volume is positive. The rule used to
    fall off the end and return None in that case, which Pyomo does not
    accept as a rule result.
    """
    has_local = val(m.DD, i) > 0
    has_imports = val(m.IM, i) > 0
    if not (has_local or has_imports):
        return Constraint.Skip

    def margins():
        # A fresh expression is built on every call, as in the original branches.
        return sum(m.PC[i2] * m.tmrg[(i2, i)] for i2 in m.commodities)

    local_base = (m.PL[i] + margins()) * m.DD[i] if has_local else 0
    import_base = (
        ((1 + m.ttim[i]) * m.e * m.PWM[i] + margins()) * m.IM[i]
        if has_imports else 0
    )
    return m.TIC[i] == m.ttic[i] * (local_base + import_base)


m.EQ40 = Constraint(m.commodities, rule=EQ40, doc='Government revenue from indirect taxes on product i')


#  EQ41(i)         Government revenue from import duties on product i
def EQ41(m, i):
    """Define TIM[i] as ttim[i] * e * PWM[i] * IM[i]; skipped unless initial imports are positive."""
    if not (val(m.IM, i) > 0):
        return Constraint.Skip
    duty = m.ttim[i] * m.e * m.PWM[i] * m.IM[i]
    return m.TIM[i] == duty


m.EQ41 = Constraint(
    m.commodities,
    rule=EQ41,
    doc='Government revenue from import duties on product i',
)

#  EQ42(i)         Government revenue from export taxes on product i
def EQ42(m, i):
    """Define TIX[i] as ttix[i] times (PE[i] + export margins) times EXD[i]."""
    if not (val(m.EXD, i) > 0):
        return Constraint.Skip
    export_margins = sum(
        m.PC[i2] * m.tmrg_X[(i2, i)] for i2 in m.commodities_ex_adm
    )
    return m.TIX[i] == m.ttix[i] * (m.PE[i] + export_margins) * m.EXD[i]


m.EQ42 = Constraint(
    m.commodities_ex_adm,
    rule=EQ42,
    doc='Government revenue from export taxes on product i',
)


#  EQ43            Government savings
def EQ43(m):
    """Define SG as YG less transfers TR[agng, 'GVT'] and current expenditures G."""
    transfers_out = sum(m.TR[(agng, 'GVT')] for agng in m.agents_nongov)
    return m.SG == m.YG - transfers_out - m.G


m.EQ43 = Constraint(
    rule=EQ43,
    doc='Government savings',
)


In [None]:
## Rest of the world
#  EQ44            Rest-of-the-world income
def EQ44(m):
    """Define YROW as the import bill plus the 'ROW' capital income share plus transfers."""
    import_bill = m.e * sum(
        m.PWM[i] * m.IM[i] for i in m.commodities if val(m.IM, i) > 0
    )
    capital_income = sum(
        m.lambda_RK[('ROW', k)] * sum(
            m.R[(k, j)] * m.KD[(k, j)]
            for j in m.industries
            if val(m.KD, k, j) > 0
        )
        for k in m.capital
    )
    transfers_in = sum(m.TR[('ROW', agd)] for agd in m.agents_domestic)
    return m.YROW == import_bill + capital_income + transfers_in


m.EQ44 = Constraint(
    rule=EQ44,
    doc='Rest-of-the-world income',
)

# EQ45            Rest-of-the-world savings
def EQ45(m):
    """Define SROW as YROW less the export bill and transfers TR[agd, 'ROW'].

    The export bill sums PE_FOB[i] * EXD[i] over the commodities in
    m.commodities_ex_adm whose initial EXD is positive. The set is read from
    the model argument, like every other rule, instead of the module-level
    `commodities_ex_adm` list the original referenced.
    """
    export_bill = sum(
        m.PE_FOB[i] * m.EXD[i]
        for i in m.commodities_ex_adm
        if val(m.EXD, i) > 0
    )
    transfers_out = sum(m.TR[(agd, 'ROW')] for agd in m.agents_domestic)
    return m.SROW == m.YROW - export_bill - transfers_out


m.EQ45 = Constraint(rule=EQ45, doc='Rest-of-the-world savings')

#  EQ46            Equivalence between current account balance and ROW savings
def EQ46(m):
    """Set SROW equal to minus the current account balance CAB."""
    negated_balance = -1 * m.CAB
    return m.SROW == negated_balance


m.EQ46 = Constraint(
    rule=EQ46,
    doc='Equivalence between current account balance and ROW savings',
)

In [None]:
## Transfers
#  EQ47(agng,h)    Transfers from household h to agent agng
def EQ47(m, agng, h):
    """Define TR[agng, h] as the share lambda_TR[agng, h] of disposable income YDH[h]."""
    pair = (agng, h)
    return m.TR[pair] == m.lambda_TR[pair] * m.YDH[h]


m.EQ47 = Constraint(
    m.agents_nongov,
    m.households,
    rule=EQ47,
    doc='Transfers from household h to agent agng',
)

#  EQ48(h)         Transfers from household h to government
def EQ48(m, h):
    """Define TR['GVT', h] as a PIXCON**eta-scaled intercept plus a marginal rate times YH[h]."""
    intercept = m.PIXCON**m.eta * m.tr0[h]
    marginal_part = m.tr1[h] * m.YH[h]
    return m.TR[('GVT', h)] == intercept + marginal_part


m.EQ48 = Constraint(
    m.households,
    rule=EQ48,
    doc='Transfers from household h to government',
)

#  EQ49(ag,f)      Transfers from type f businesses to agent ag
def EQ49(m, ag, f):
    """Define TR[ag, f] as the share lambda_TR[ag, f] of disposable income YDF[f]."""
    pair = (ag, f)
    return m.TR[pair] == m.lambda_TR[pair] * m.YDF[f]


m.EQ49 = Constraint(
    m.agents,
    m.firms,
    rule=EQ49,
    doc='Transfers from type f businesses to agent ag',
)

#  EQ50(agng)      Public transfers
def EQ50(m, agng):
    """Scale the value of TR[agng, 'GVT'] read at rule-build time by PIXCON**eta."""
    # val() is evaluated when the rule runs, so it enters as a constant.
    base_transfer = val(m.TR, agng, 'GVT')
    return m.TR[(agng, 'GVT')] == m.PIXCON**m.eta * base_transfer


m.EQ50 = Constraint(
    m.agents_nongov,
    rule=EQ50,
    doc='Public transfers',
)

#  EQ51(agd)       Transfers from abroad
def EQ51(m, agd):
    """Scale the value of TR[agd, 'ROW'] read at rule-build time by PIXCON**eta."""
    # val() is evaluated when the rule runs, so it enters as a constant.
    base_transfer = val(m.TR, agd, 'ROW')
    return m.TR[(agd, 'ROW')] == m.PIXCON**m.eta * base_transfer


m.EQ51 = Constraint(
    m.agents_domestic,
    rule=EQ51,
    doc='Transfers from abroad',
)

### 5.3 Demand


In [None]:
#  EQ52(i,h)       Consumption of commodity i by type h households
def EQ52(m, i, h):
    """Spending on i by h: minimum spending plus gamma_LES times the budget left over."""
    minimum_spending = sum(m.PC[i2] * m.CMIN[(i2, h)] for i2 in m.commodities)
    leftover_budget = m.CTH[h] - minimum_spending
    return m.PC[i] * m.C[(i, h)] == (
        m.PC[i] * m.CMIN[(i, h)] + m.gamma_LES[(i, h)] * leftover_budget
    )


m.EQ52 = Constraint(
    m.commodities,
    m.households,
    rule=EQ52,
    doc='Consumption of commodity i by type h households',
)

#  EQ53            Gross fixed capital formation
def EQ53(m):
    """Define GFCF as IT less the value of VSTK summed over commodities at prices PC."""
    stock_value = sum(m.PC[i] * m.VSTK[i] for i in m.commodities)
    return m.GFCF == m.IT - stock_value


m.EQ53 = Constraint(
    rule=EQ53,
    doc='Gross fixed capital formation',
)

#  EQ54(i)         Final demand of commodity i for investment purposes (GFCF)
def EQ54(m, i):
    """Set investment spending PC[i] * INV[i] to the share gamma_INV[i] of GFCF."""
    return m.PC[i]*m.INV[i] == m.gamma_INV[i]*m.GFCF
# The doc string previously repeated EQ53's label ('Gross fixed capital formation').
m.EQ54 = Constraint(m.commodities, rule=EQ54, doc='Final demand of commodity i for investment purposes (GFCF)')

#  EQ55(i)         Public final consumption of commodity i
def EQ55(m, i):
    """Set public spending PC[i] * CG[i] to the share gamma_GVT[i] of G."""
    spending = m.PC[i] * m.CG[i]
    return spending == m.gamma_GVT[i] * m.G


m.EQ55 = Constraint(
    m.commodities,
    rule=EQ55,
    doc='Public final consumption of commodity i',
)

#  EQ56(i)         Total intermediate demand for commodity i
def EQ56(m, i):
    """Define DIT[i] as the sum of DI[i, j] over industries with positive initial DI."""
    intermediate_use = sum(
        m.DI[(i, j)]
        for j in m.industries
        if val(m.DI, i, j) > 0
    )
    return m.DIT[i] == intermediate_use


m.EQ56 = Constraint(
    m.commodities,
    rule=EQ56,
    doc='Total intermediate demand for commodity i',
)

#  EQ57(i)         Demand for commodity i as a trade or transport margin
def EQ57(m, i):
    """Define MRGN[i] as margin rates applied to local sales, imports and exports.

    Each of the three sums keeps only commodities whose initial volume
    (DD, IM or EXD respectively) is positive.
    """
    on_local = sum(
        m.tmrg[(i, i2)] * m.DD[i2]
        for i2 in m.commodities
        if val(m.DD, i2) > 0
    )
    on_imports = sum(
        m.tmrg[(i, i3)] * m.IM[i3]
        for i3 in m.commodities
        if val(m.IM, i3) > 0
    )
    on_exports = sum(
        m.tmrg_X[(i, i4)] * m.EXD[i4]
        for i4 in m.commodities_ex_adm
        if val(m.EXD, i4) > 0
    )
    return m.MRGN[i] == on_local + on_imports + on_exports


m.EQ57 = Constraint(
    m.commodities,
    rule=EQ57,
    doc='Demand for commodity i as a trade or transport margin',
)


### 5.4 International trade

In [None]:
#  EQ58(j)         CET between different commodities produced by industry j
def EQ58(m, j):
    """Express XST[j] as the CET aggregate of the products with positive initial XS[j, i]."""
    rho = m.rho_XT[j]
    aggregate = sum(
        m.beta_XT[(j, i)] * m.XS[(j, i)]**rho
        for i in m.commodities
        if val(m.XS, j, i) > 0
    )
    return m.XST[j] == m.B_XT[j] * aggregate**(1 / rho)


m.EQ58 = Constraint(
    m.industries,
    rule=EQ58,
    doc='CET between different commodities produced by industry j',
)

#  EQ59(j,i)       Industry j production of commodity i (CET)
def EQ59(m, j, i):
    """Supply of commodity i by industry j (CET).

    Skipped unless the initial XS[j, i] is positive and differs from the
    initial XST[j].
    """
    if not (val(m.XS, j, i) > 0) or val(m.XS, j, i) == val(m.XST, j):
        return Constraint.Skip
    sigma = m.sigma_XT[j]
    relative_price = m.P[(j, i)] / (m.beta_XT[(j, i)] * m.PT[j])
    return m.XS[(j, i)] == m.XST[j] / m.B_XT[j]**(1 + sigma) * relative_price**sigma


m.EQ59 = Constraint(
    m.industries,
    m.commodities,
    rule=EQ59,
    doc='Industry j production of commodity i (CET)',
)


#  EQ60(j,i)       CET between exports and local commodity
def EQ60(m, j, i):
    """Express XS[j, i] as the CET aggregate of exports EX and local sales DS.

    Only the destinations with a positive initial value enter the aggregate.
    The constraint is skipped when the initial XS[j, i] is not positive or
    when neither destination is active.
    """
    if not (val(m.XS, j, i) > 0):
        return Constraint.Skip

    pair = (j, i)
    exports_active = val(m.EX, j, i, default=0) > 0
    local_active = val(m.DS, j, i, default=0) > 0

    if exports_active and local_active:
        aggregate = (
            (m.beta_X[pair] * m.EX[pair]**m.rho_X[pair])
            + ((1 - m.beta_X[pair]) * m.DS[pair]**m.rho_X[pair])
        )
    elif exports_active:
        aggregate = m.beta_X[pair] * m.EX[pair]**m.rho_X[pair]
    elif local_active:
        aggregate = (1 - m.beta_X[pair]) * m.DS[pair]**m.rho_X[pair]
    else:
        return Constraint.Skip
    return m.XS[pair] == m.B_X[pair] * aggregate**(1 / m.rho_X[pair])


m.EQ60 = Constraint(
    m.industries,
    m.commodities,
    rule=EQ60,
    doc='CET between exports and local commodity',
)

#  EQ61(j,i)       Relative supply of exports and local commodity (CET)
def EQ61(m, j, i):
    """Relate EX[j, i] to DS[j, i] through the PE/PL price ratio; needs both initially positive."""
    exports_active = val(m.EX, j, i, default=0) > 0
    if not (exports_active and val(m.DS, j, i, default=0) > 0):
        return Constraint.Skip
    pair = (j, i)
    share_ratio = (1 - m.beta_X[pair]) / m.beta_X[pair]
    price_ratio = m.PE[i] / m.PL[i]
    return m.EX[pair] == (share_ratio * price_ratio)**m.sigma_X[pair] * m.DS[pair]


m.EQ61 = Constraint(
    m.industries,
    m.commodities_ex_adm,
    rule=EQ61,
    doc='Relative supply of exports and local commodity (CET)',
)

#  EQ62(i)         World demand for exports of product i
def EQ62(m, i):
    """Scale the EXD[i] value read at rule-build time by (e * PWX / PE_FOB)**sigma_XD."""
    if not (val(m.EXD, i) > 0):
        return Constraint.Skip
    # val() is evaluated when the rule runs, so it enters as a constant.
    base_exports = val(m.EXD, i)
    price_ratio = m.e * m.PWX[i] / m.PE_FOB[i]
    return m.EXD[i] == base_exports * price_ratio**m.sigma_XD[i]


m.EQ62 = Constraint(
    m.commodities_ex_adm,
    rule=EQ62,
    doc='World demand for exports of product i',
)

# #  EQ63(i)         CES between imports and local production
# def EQ63(m, i):
#     return m.Q[i] == m.B_M[i]*(
#         (m.beta_M[i]*m.IM[i]**(-1*m.rho_M[i])) +
#         ((1-m.beta_M[i])*m.DD[i]**(-1*m.rho_M[i]))
#     )**(-1/m.rho_M[i])
# m.EQ63 = Constraint(m.commodities, rule=EQ63, doc='CES between imports and local production')

#  EQ63(i)         CES between imports and local production
def EQ63(m, i):
    """Express Q[i] as the CES aggregate of imports IM[i] and local demand DD[i].

    A source enters the aggregate only when its initial volume is positive;
    the missing source is replaced by 0, as in the original branches.

    Returns Constraint.Skip when neither volume is positive. The rule used to
    fall off the end and return None in that case, which Pyomo does not
    accept as a rule result.
    """
    has_local = val(m.DD, i) > 0
    has_imports = val(m.IM, i) > 0
    if not (has_local or has_imports):
        return Constraint.Skip

    rho = m.rho_M[i]
    import_term = m.beta_M[i] * m.IM[i]**(-1 * rho) if has_imports else 0
    local_term = (1 - m.beta_M[i]) * m.DD[i]**(-1 * rho) if has_local else 0
    return m.Q[i] == m.B_M[i] * (import_term + local_term)**(-1 / rho)


m.EQ63 = Constraint(m.commodities, rule=EQ63, doc='CES between imports and local production')

#  EQ64(i)         Demand for imports (CES)
def EQ64(m, i):
    """Relate IM[i] to DD[i] through the PD/PM price ratio; needs both initially positive."""
    if not (val(m.IM, i) > 0 and val(m.DD, i) > 0):
        return Constraint.Skip
    share_ratio = m.beta_M[i] / (1 - m.beta_M[i])
    price_ratio = m.PD[i] / m.PM[i]
    return m.IM[i] == (share_ratio * price_ratio)**m.sigma_M[i] * m.DD[i]


m.EQ64 = Constraint(
    m.commodities,
    rule=EQ64,
    doc='Demand for imports (CES)',
)

### 5.5 Prices

In [None]:
#  EQ65(j)         Industry j unit cost
def EQ65(m, j):
    """Equate PP[j] * XST[j] with PVA[j] * VA[j] plus PCI[j] * CI[j]."""
    output_value = m.PP[j] * m.XST[j]
    input_cost = m.PVA[j] * m.VA[j] + m.PCI[j] * m.CI[j]
    return output_value == input_cost


m.EQ65 = Constraint(
    m.industries,
    rule=EQ65,
    doc='Industry j unit cost',
)

#  EQ66(j)         Basic price of industry j's production of commodity i
def EQ66(m, j):
    """Define PT[j] as PP[j] marked up by the production tax rate ttip[j]."""
    tax_factor = 1 + m.ttip[j]
    return m.PT[j] == tax_factor * m.PP[j]


m.EQ66 = Constraint(
    m.industries,
    rule=EQ66,
    doc='Basic price of industry j\'s production of commodity i',
)

#  EQ67(j)         Intermediate consumption price index of industry j
def EQ67(m, j):
    """Equate PCI[j] * CI[j] with the sum of PC[i] * DI[i, j] over all commodities."""
    input_bill = sum(m.PC[i] * m.DI[(i, j)] for i in m.commodities)
    return m.PCI[j] * m.CI[j] == input_bill


m.EQ67 = Constraint(
    m.industries,
    rule=EQ67,
    doc='Intermediate consumption price index of industry j',
)

#  EQ68(j)         Price of industry j value added
def EQ68(m, j):
    """Price of industry j value added.

    Only the factors whose composite demand (LDC, KDC) is positive according
    to ``val`` appear on the right-hand side; with neither, the constraint is
    skipped.
    """
    uses_labour = val(m.LDC, j) > 0
    uses_capital = val(m.KDC, j) > 0
    if not (uses_labour or uses_capital):
        return Constraint.Skip

    value_added = m.PVA[j]*m.VA[j]
    if uses_labour and uses_capital:
        return value_added == m.WC[j]*m.LDC[j] + m.RC[j]*m.KDC[j]
    if uses_labour:
        return value_added == m.WC[j]*m.LDC[j]
    return value_added == m.RC[j]*m.KDC[j]
m.EQ68 = Constraint(m.industries, rule=EQ68, doc='Price of industry j value added')

# * EQ69(j)         Wage rate of industry j composite labor
# Given the way equation 6 is written, equation 69 is redundant
def EQ69(m, j):
    """Wage rate of industry j composite labor (wage bill summed over labour types)."""
    wage_terms = []
    for labour_type in m.labour:
        if val(m.LD, labour_type, j) == 0:
            # Labour types with a zero val(...) contribute a literal 0.
            wage_terms.append(0)
        else:
            wage_terms.append(m.WTI[(labour_type,j)]*m.LD[(labour_type,j)])
    return m.WC[j]*m.LDC[j] == sum(wage_terms)
# m.EQ69 = Constraint(m.industries, rule=EQ69, doc='Wage rate of industry j composite labor')

#  EQ70(l,j)       Wage rate paid by industry j for type l labor including payroll taxes
def EQ70(m, l, j):
    """Wage rate paid by industry j for type l labor including payroll taxes."""
    pair = (l, j)
    payroll_markup = 1 + m.ttiw[pair]
    return m.WTI[pair] == m.W[l]*payroll_markup
m.EQ70 = Constraint(m.labour, m.industries, rule=EQ70, doc='Wage rate paid by industry j for type l labor including payroll taxes')

# * EQ71(j)         Rental rate of industry j composite capital
# Given the way equation 8 is written, equation 71 is redundant
def EQ71(m, j):
    """Rental rate of industry j composite capital.

    Generated only when ``val(m.kmob)`` and ``val(m.KDC, j)`` are both
    positive; the sum runs over capital types with positive ``val(m.KD, k, j)``.
    """
    if not (val(m.kmob) > 0 and val(m.KDC,j) > 0):
        return Constraint.Skip
    used_capital = [k for k in m.capital if val(m.KD,k,j) > 0]
    rental_bill = sum(m.RTI[(k,j)]*m.KD[(k,j)] for k in used_capital)
    return m.RC[j]*m.KDC[j] == rental_bill
# m.EQ71 = Constraint(m.industries, rule=EQ71, doc='Rental rate of industry j composite capital')

#  EQ72(k,j)       Rental rate paid by industry j for type k capital including capital taxes
def EQ72(m, k, j):
    """Rental rate paid by industry j for type k capital including capital taxes."""
    if not val(m.KD, k, j) > 0:
        return Constraint.Skip
    pair = (k, j)
    return m.RTI[pair] == m.R[pair]*(1+m.ttik[pair])
m.EQ72 = Constraint(m.capital, m.industries, rule=EQ72, doc='Rental rate paid by industry j for type k capital including capital taxes')

#  EQ73(k,j)       Rental rate of type k capital (if capital is mobile)
def EQ73(m, k, j):
    """Rental rate of type k capital (if capital is mobile)."""
    applies = val(m.kmob) > 0 and val(m.KD, k, j) > 0
    if not applies:
        return Constraint.Skip
    return m.R[(k,j)] == m.RK[k]
m.EQ73 = Constraint(m.capital, m.industries, rule=EQ73, doc='Rental rate of type k capital (if capital is mobile)')

#  EQ74(j,i)       Total producer price is equal to P if there is only one product
# Given the way equation 59 is written, equation 74 is redundant if a sector produces more than one commodity
def EQ74(m, j):
    """Value of industry j total output equals the value summed over its products."""
    product_value = sum(m.P[(j,good)]*m.XS[(j,good)] for good in m.commodities)
    return m.PT[j]*m.XST[j] == product_value
# m.EQ74 = Constraint(m.industries, rule=EQ74, doc='Total producer price is equal to P if there is only one product')

def EQ74b(m, j, i):
    """Set P[j,i] equal to PT[j] when commodity i accounts for all of industry j's output.

    The test compares ``val(m.XS, j, i)`` with ``val(m.XST, j)`` for exact
    equality; every other (j, i) pair is skipped.
    """
    single_product = val(m.XS, j, i) == val(m.XST, j)
    if not single_product:
        return Constraint.Skip
    return m.P[(j,i)] == m.PT[j]
m.EQ74 = Constraint(m.industries, m.commodities, rule=EQ74b, doc='Total producer price is equal to P if there is only one product')

#  EQ75(j,i)       Basic price of industry j's production of commodity i
def EQ75(m, j, i):
    """Basic price of industry j's production of commodity i.

    Skipped when ``val(m.XS, j, i)`` is not positive, or when neither the
    export (EX) nor the domestic-sales (DS) lookup is positive. Otherwise the
    revenue P*XS is matched against whichever of the two outlets is active.
    """
    if not val(m.XS, j, i) > 0:
        return Constraint.Skip

    sells_abroad = val(m.EX, j,i,default=0) > 0
    sells_locally = val(m.DS,j,i,default=0) > 0
    if not (sells_abroad or sells_locally):
        return Constraint.Skip

    revenue = m.P[(j,i)]*m.XS[(j,i)]
    if sells_abroad and sells_locally:
        return revenue == m.PE[i]*m.EX[(j,i)] + m.PL[i]*m.DS[(j,i)]
    if sells_abroad:
        return revenue == m.PE[i]*m.EX[(j,i)]
    return revenue == m.PL[i]*m.DS[(j,i)]
m.EQ75 = Constraint(m.industries, m.commodities, rule=EQ75, doc='Basic price of industry j\'s production of commodity i')

#  EQ76(i)         Price received for exported commodity i (excluding export taxes)
def EQ76(m, i):
    """Link PE_FOB[i] to PE[i] plus export margins, marked up by ttix[i].

    Skipped when ``val(m.EXD, i)`` is not positive.
    """
    if not val(m.EXD, i) > 0:
        return Constraint.Skip
    export_margins = sum(
        m.PC[margin_good]*m.tmrg_X[(margin_good,i)]
        for margin_good in m.commodities_ex_adm
    )
    return m.PE_FOB[i] == (1 + m.ttix[i])*( m.PE[i] + export_margins )
m.EQ76 = Constraint(m.commodities_ex_adm, rule=EQ76, doc='Price received for exported commodity i (excluding export taxes)')

#  EQ77(i)         Price of local product i sold on the domestic market (including all taxes and margins)
def EQ77(m, i):
    """Price of local product i sold on the domestic market (including all taxes and margins).

    Skipped when ``val(m.DD, i)`` is not positive.
    """
    if not val(m.DD, i) > 0:
        return Constraint.Skip
    trade_margins = sum(
        m.PC[margin_good]*m.tmrg[(margin_good,i)]
        for margin_good in m.commodities
    )
    return m.PD[i] == (1 + m.ttic[i])*( m.PL[i] + trade_margins )
m.EQ77 = Constraint(m.commodities, rule=EQ77, doc='Price of local product i sold on the domestic market (including all taxes and margins)')

#  EQ78(i)         Price of imported product i (including all taxes and tariffs)
def EQ78(m, i):
    """Price of imported product i (including all taxes and tariffs).

    Skipped when ``val(m.IM, i)`` is not positive.
    """
    if not val(m.IM, i) > 0:
        return Constraint.Skip
    trade_margins = sum(
        m.PC[margin_good]*m.tmrg[(margin_good,i)]
        for margin_good in m.commodities
    )
    return m.PM[i] == (1 + m.ttic[i])*((1+m.ttim[i])*m.e*m.PWM[i] + trade_margins )
m.EQ78 = Constraint(m.commodities, rule=EQ78, doc='Price of imported product i (including all taxes and tariffs)')

# #  EQ79(i)         Purchaser price of composite comodity i
# def EQ79(m, i):
#     return m.PC[i]*m.Q[i] == m.PM[i]*m.IM[i] + m.PD[i]*m.DD[i]
# m.EQ79 = Constraint(m.commodities, rule=EQ79, doc='Purchaser price of composite commodity i')

#  EQ79(i)         Purchaser price of composite commodity i
def EQ79(m, i):
    """Purchaser price of composite commodity i.

    The value PC*Q is matched against the import and/or domestic component
    whose ``val(...)`` lookup is positive. When neither is positive the
    constraint is skipped; previously the rule fell through and returned
    ``None``.
    """
    has_local = val(m.DD,i) > 0
    has_imports = val(m.IM,i) > 0
    if not (has_local or has_imports):
        return Constraint.Skip

    # A literal 0 stands in for the missing component, as in the original.
    import_value = m.PM[i]*m.IM[i] if has_imports else 0
    local_value = m.PD[i]*m.DD[i] if has_local else 0
    return m.PC[i]*m.Q[i] == import_value + local_value
m.EQ79 = Constraint(m.commodities, rule=EQ79, doc='Purchaser price of composite commodity i')

#  EQ80            GDP deflator (Fisher index)
def EQ80(m):
    """GDP deflator: square root of the product of two price-index ratios.

    ``val(...)`` terms are plain numbers while ``m.X[j]`` terms are model
    variables; presumably ``val`` returns the benchmark value, which would
    make the two ratios the Laspeyres and Paasche indices (confirm).
    """
    def variable_price(j):
        # (PVA*VA + TIP) / VA built from model variables
        return (m.PVA[j]*m.VA[j] + m.TIP[j]) / m.VA[j]

    def lookup_price(j):
        # the same ratio built from val(...) lookups
        return (val(m.PVA, j)*val(m.VA,j) + val(m.TIP, j)) / val(m.VA, j)

    var_price_lookup_qty = sum(variable_price(j)*val(m.VA, j) for j in m.industries)
    lookup_price_lookup_qty = sum(lookup_price(j) * val(m.VA, j) for j in m.industries)
    var_price_var_qty = sum(variable_price(j) * m.VA[j] for j in m.industries)
    lookup_price_var_qty = sum(lookup_price(j) * m.VA[j] for j in m.industries)

    return m.PIXGDP == (
        var_price_lookup_qty /
        lookup_price_lookup_qty *
        var_price_var_qty /
        lookup_price_var_qty
    )**0.5
m.EQ80 = Constraint(rule=EQ80, doc='GDP deflator (Fischer index)')

#  EQ81            Consumer price index (Laspeyres)
def EQ81(m):
    """Consumer price index (Laspeyres).

    Household consumption from ``val(m.C, i, h)`` weights the variable price
    PC[i] in the numerator and the ``val(m.PC, i)`` price in the denominator.
    """
    basket = {
        good: sum(val(m.C, good, h) for h in m.households)
        for good in m.commodities
    }
    current_cost = sum(m.PC[good]*basket[good] for good in m.commodities)
    reference_cost = sum(val(m.PC, good)*basket[good] for good in m.commodities)
    return m.PIXCON == current_cost / reference_cost
m.EQ81 = Constraint(rule=EQ81, doc='Consumer price index (Laspeyres)')

#  EQ82            Investment price index (derived from investment function)
def EQ82(m):
    """Investment price index (derived from investment function).

    Product over commodities with a positive ``val(m.gamma_INV, i)`` of the
    price relative PC[i] / val(PC, i) raised to gamma_INV[i].
    """
    price_relatives = [
        (m.PC[good]/val(m.PC, good))**m.gamma_INV[good]
        for good in m.commodities
        if val(m.gamma_INV,good) > 0
    ]
    return m.PIXINV == prod(price_relatives)
m.EQ82 = Constraint(rule=EQ82, doc='Investment price index (derived from investment function)')

#  EQ83            Public expenditures price index
def EQ83(m):
    """Public expenditures price index.

    Product over commodities with a positive ``val(m.gamma_GVT, i)`` of the
    price relative PC[i] / val(PC, i) raised to gamma_GVT[i].
    """
    price_relatives = [
        (m.PC[good]/val(m.PC, good))**m.gamma_GVT[good]
        for good in m.commodities
        if val(m.gamma_GVT,good) > 0
    ]
    return m.PIXGVT == prod(price_relatives)
m.EQ83 = Constraint(rule=EQ83, doc='Public expenditures price index')



### 5.6 Equilibrium

In [None]:
#  EQ84(i1)        Domestic absorption
def EQ84(m, i):
    """Composite supply Q[i] equals the sum of all demand components for commodity i."""
    household_demand = sum(m.C[(i,h)] for h in m.households)
    total_demand = household_demand + m.CG[i] + m.INV[i] + m.VSTK[i] + m.DIT[i] + m.MRGN[i]
    return m.Q[i] == total_demand
m.EQ84 = Constraint(m.commodities_ex_agr, rule=EQ84, doc='Domestic absorbtion')

#  EQ85(l)         Labor supply equals labor demand
def EQ85(m, l):
    """Labor supply equals labor demand, summed over industries with positive val(LD)."""
    employers = [j for j in m.industries if val(m.LD,l,j) > 0]
    return m.LS[l] == sum(m.LD[l,j] for j in employers)
m.EQ85 = Constraint(m.labour, rule=EQ85, doc='Labor supply equals labor demand')

#  EQ86(k)         Capital supply equals capital demand
def EQ86(m, k):
    """Capital supply equals capital demand, summed over industries with positive val(KD)."""
    users = [j for j in m.industries if val(m.KD, k, j) > 0]
    return m.KS[k] == sum(m.KD[k,j] for j in users)
m.EQ86 = Constraint(m.capital, rule=EQ86, doc='Capital supply equals capital demand')

#  EQ87            Total investment equals total savings
def EQ87(m):
    """Total investment equals total savings (households, firms, SG and SROW)."""
    household_savings = sum(m.SH[h] for h in m.households)
    firm_savings = sum(m.SF[f] for f in m.firms)
    return m.IT == household_savings + firm_savings + m.SG + m.SROW
m.EQ87 = Constraint(rule=EQ87, doc='Total investment equals total savings')

#  EQ88            Supply of domestic production equals local demand
def EQ88(m, i):
    """Supply of domestic production equals local demand.

    Skipped when ``val(m.DD, i)`` is not positive; only industries with a
    positive ``val(m.DS, j, i)`` are summed.
    """
    if not val(m.DD, i) > 0:
        return Constraint.Skip
    suppliers = [j for j in m.industries if val(m.DS,j,i) > 0]
    return m.DD[i] == sum(m.DS[(j,i)] for j in suppliers)
m.EQ88 = Constraint(m.commodities, rule=EQ88, doc='Supply of domestic production equals local demand')

#  EQ89(i)         Supply of exports equals international world demand
def EQ89(m, i):
    """Supply of exports equals international world demand.

    Skipped when ``val(m.EXD, i)`` is not positive; only industries with a
    positive ``val(m.EX, j, i)`` are summed.
    """
    if not val(m.EXD, i) > 0:
        return Constraint.Skip
    exporters = [j for j in m.industries if val(m.EX,j,i) > 0]
    return m.EXD[i] == sum(m.EX[(j,i)] for j in exporters)
m.EQ89 = Constraint(m.commodities_ex_adm, rule=EQ89, doc='Supply of exports equals international world demand')


### 5.7 Gross Domestic Product

In [None]:
#  EQ90            GDP at basic prices
def EQ90(m):
    """GDP at basic prices: value added over all industries plus TIPT."""
    total_value_added = sum(m.PVA[j]*m.VA[j] for j in m.industries)
    return m.GDP_BP == total_value_added + m.TIPT
m.EQ90 = Constraint(rule=EQ90, doc='GDP at basic prices')

#  EQ91            GDP at market prices
def EQ91(m):
    """GDP at market prices: GDP at basic prices plus TPRCTS."""
    market_price_gdp = m.GDP_BP + m.TPRCTS
    return m.GDP_MP == market_price_gdp
m.EQ91 = Constraint(rule=EQ91, doc='GDP at market prices')

#  EQ92            GDP at market prices (income-based)
def EQ92(m):
    """GDP at market prices (income-based).

    Labour and capital income are summed over the (factor, industry) pairs
    whose ``val(...)`` demand is positive, then TPRODN and TPRCTS are added.
    """
    labour_income = sum(
        m.W[l]*m.LD[l,j]
        for l in m.labour for j in m.industries if val(m.LD,l,j) > 0
    )
    capital_income = sum(
        m.R[(k,j)]*m.KD[(k,j)]
        for k in m.capital for j in m.industries if val(m.KD,k,j) > 0
    )
    return m.GDP_IB == labour_income + capital_income + m.TPRODN + m.TPRCTS
m.EQ92 = Constraint(rule=EQ92, doc='GDP at market prices (income-based)')

#  EQ93            GDP at purchasers' prices from the perspective of final demand
def EQ93(m):
    """GDP at purchasers' prices from the perspective of final demand.

    Domestic final demand valued at PC, plus exports valued at PE_FOB, minus
    imports valued at PWM*e. Export and import terms are restricted to
    commodities with a positive ``val(...)`` lookup.
    """
    domestic_final_demand = sum(
        m.PC[i]*(sum(m.C[(i,h)] for h in m.households) + m.CG[i] + m.INV[i] + m.VSTK[i])
        for i in m.commodities
    )
    export_value = sum(
        m.PE_FOB[i]*m.EXD[i]
        for i in m.commodities_ex_adm if val(m.EXD,i) > 0
    )
    import_value = sum(
        m.PWM[i]*m.e*m.IM[i]
        for i in m.commodities_ex_adm if val(m.IM, i) > 0
    )
    return m.GDP_FD == domestic_final_demand + export_value - import_value
m.EQ93 = Constraint(rule=EQ93, doc='GDP at purchasers\' prices from the perspective of final demand')



### 5.8 Real variables

In [None]:
#  EQ94(h)         Real consumption budget of type h households
def EQ94(m, h):
    """Real consumption budget of type h households (CTH deflated by PIXCON)."""
    deflated_budget = m.CTH[h] / m.PIXCON
    return m.CTH_REAL[h] == deflated_budget
m.EQ94 = Constraint(m.households, rule=EQ94, doc='Real consumption budget of type h households')

#  EQ95            Real current government expenditures on goods and services
def EQ95(m):
    """Real current government expenditures on goods and services (G deflated by PIXGVT)."""
    deflated_spending = m.G / m.PIXGVT
    return m.G_REAL == deflated_spending
m.EQ95 = Constraint(rule=EQ95, doc='Real current government expenditures on goods and services')

#  EQ96            Real GDP at basic prices
def EQ96(m):
    """Real GDP at basic prices (GDP_BP deflated by PIXGDP)."""
    deflated_gdp = m.GDP_BP / m.PIXGDP
    return m.GDP_BP_REAL == deflated_gdp
m.EQ96 = Constraint(rule=EQ96, doc='Real GDP at basic prices')

#  EQ97            Real GDP at market prices
def EQ97(m):
    """Real GDP at market prices (GDP_MP deflated by PIXCON)."""
    deflated_gdp = m.GDP_MP / m.PIXCON
    return m.GDP_MP_REAL == deflated_gdp
m.EQ97 = Constraint(rule=EQ97, doc='Real GDP at market prices')

#  EQ98            Real gross fixed capital formation
def EQ98(m):
    """Real gross fixed capital formation (GFCF deflated by PIXINV)."""
    deflated_investment = m.GFCF / m.PIXINV
    return m.GFCF_REAL == deflated_investment
m.EQ98 = Constraint(rule=EQ98, doc='Real gross fixed capital formation')


### 5.9 Walras

In [None]:
#  WALRAS          Walras law verification
def WALRAS(m):
    """Walras law verification: LEON is the excess supply on the 'AGR' market.

    This is the absorption identity of EQ84 written for 'AGR', with LEON
    absorbing any residual. Households are taken from ``m.households``, as in
    every other rule, rather than from a module-level ``households`` name.
    """
    household_demand = sum(m.C[('AGR', h)] for h in m.households)
    return m.LEON == m.Q['AGR'] - household_demand - m.CG['AGR'] - m.INV['AGR'] - m.VSTK['AGR'] - m.DIT['AGR'] - m.MRGN['AGR']
m.WALRAS = Constraint(rule=WALRAS, doc='Walras law verification')

## 6 Create instances

### 6.1 Helper functions

In [None]:
def count_nonfixed(i, active=True):
    """Count the variables that appear in the active constraints of ``i``.

    Args:
        i: object exposing ``block_data_objects`` (a model instance or block).
        active: forwarded to ``block_data_objects`` to select blocks.

    Returns:
        Tuple ``(num_non_fixed, num_fixed, num_vars)`` in which each distinct
        variable is counted once, so ``num_non_fixed + num_fixed == num_vars``.
    """
    # One set shared across all blocks: the original rebuilt it per block,
    # which overwrote num_vars and re-counted variables for every block visited.
    var_set = ComponentSet()
    for block in i.block_data_objects(active=active):
        for c in block.component_data_objects(
                ctype=Constraint, active=True, descend_into=True):
            var_set.update(identify_variables(c.body))

    num_fixed = sum(1 for v in var_set if v.is_fixed())
    num_vars = len(var_set)
    num_non_fixed = num_vars - num_fixed

    return num_non_fixed, num_fixed, num_vars

def count_constraints(i, active=True):
    """Return the number of entries stored in the Constraint components of ``i``.

    ``active`` is forwarded both to the block iterator and to the component
    lookup on each block.
    """
    return sum(
        len(data._data.keys())
        for block in i.block_data_objects(active=active)
        for data in block.component_map(pyo.Constraint, active=active).values()
    )

def checkSquareness(inst):
    """Print whether the number of unfixed variables equals ``inst.nconstraints()``."""
    unfixed_count = count_nonfixed(inst)[0]
    is_square = unfixed_count == inst.nconstraints()
    print('Instance IS SQUARE.' if is_square else 'Instance is NOT SQUARE!')

def checkFeasibility(inst, limit=1e-9):
    """Report every active constraint whose evaluated value exceeds ``limit``.

    Args:
        inst: object exposing ``block_data_objects`` (a model instance or block).
        limit: absolute tolerance applied to ``constraintValue(constraint[index])``.

    Prints one 'Infeasibility:' line per offending constraint, the number of
    constraints checked, and a confirmation when none exceeded the limit.
    """
    feasible = True
    num_checked = 0
    for block in inst.block_data_objects(active=True):
        for constraint in block.component_map(pyo.Constraint, active=True).values():
            # Indices skipped by the rule have no entry here. Membership is
            # tested on the mapping itself instead of rebuilding a key list
            # for every index.
            constructed = constraint._data
            for index in constraint.index_set():
                if index not in constructed:
                    continue
                num_checked += 1
                # Evaluate once and reuse the value for the test and the report.
                residual = constraintValue(constraint[index])
                if abs(residual) > limit:
                    feasible = False
                    print('Infeasibility:', constraint, index, residual)
    print(f'Checked {num_checked} constraints.')

    if feasible:
        print('All equations balanced. Instance is feasible')

def has_value(obj, key):
    """Return True when ``obj[key].value`` is neither None nor zero."""
    current = obj[key].value
    if current is None:
        return False
    return bool(current != 0)

In [None]:
def applyFixes(inst, m):
    """Fix the exogenous variables of ``inst`` (the model closure).

    ``val(m.kmob)`` selects the capital closure: 0 fixes KD, 1 fixes KS.
    Every other listed variable is fixed unconditionally, in the listed order.
    """
    ## fix exogenous variables
    capital_mobility = val(m.kmob)
    if capital_mobility == 0:
        inst.KD.fix()
    if capital_mobility == 1:
        inst.KS.fix()

    always_fixed = (
        ## exchange rate and CAB
        'e', 'CAB',
        ## minimum consumption and government budget
        'CMIN', 'G',
        ## labour supply
        'LS',
        ## prices on the world market
        'PWM', 'PWX',
        ## VSTK (original note: "stock variance (not sure why)")
        'VSTK',
        ## distribution shares
        'sh0', 'sh1', 'tr0', 'tr1', 'ttdf0', 'ttdf1', 'ttdh0', 'ttdh1',
        ## tax rates
        'ttic', 'ttik', 'ttim', 'ttip', 'ttiw', 'ttix',
    )
    for var_name in always_fixed:
        getattr(inst, var_name).fix()

    ## fix zero volumes for some variables
    # for block in inst.block_data_objects(active=True):
    #     for data in block.component_map(pyo.Var, active=True).values():
    #         if (data.getname() in some_volume_vars):
    #             for key in data.keys():
    #                 if not has_value(data, key):
    #                     data[key].fix()
    # print('applied fixes')

### 6.2 Define optimization objective

In [None]:
def obj_expression(m):
    """Objective: sum of C over every (commodity excluding ADM, household) pair."""
    consumption_pairs = m.commodities_ex_adm*m.households
    return sum(m.C[(good, household)] for good, household in consumption_pairs)

m.OBJ = pyo.Objective(rule=obj_expression, sense=pyo.maximize)

### 6.3 Create base instance

In [None]:
# Build the base-case instance from model m, then fix its exogenous
# variables with applyFixes.
inst_base = m.create_instance()
applyFixes(inst_base, m)

In [None]:
# Print size statistics and run the squareness / feasibility checks for
# every active block of the base instance.
for block in inst_base.block_data_objects(active=True):
    model_stats = idaes_utils.model_statistics
    stat_rows = (
        ('variables:          ', model_stats.number_variables_in_activated_equalities),
        ('  fixed:            ', model_stats.number_fixed_variables_in_activated_equalities),
        ('  unfixed:          ', model_stats.number_unfixed_variables_in_activated_equalities),
        ('constraints:        ', model_stats.number_activated_equalities),
        ('degrees of freedom: ', model_stats.degrees_of_freedom),
    )
    for stat_label, stat_function in stat_rows:
        print(stat_label, stat_function(block))
    checkSquareness(block)
    checkFeasibility(block)
    model_stats.report_statistics(block)


### 6.4 Create scenarios

In [None]:
## simulations
# Every scenario starts from a fresh instance with the same closure as the
# base run; exactly one exogenous value is then shocked per scenario.
inst_scenario_1 = m.create_instance()
inst_scenario_2 = m.create_instance()
inst_scenario_3 = m.create_instance()

applyFixes(inst_scenario_1, m)
applyFixes(inst_scenario_2, m)
applyFixes(inst_scenario_3, m)

## Scenario 1
# 25% increase of international import price of AGR
inst_scenario_1.PWM['AGR'] = 1.25*inst_scenario_1.PWM['AGR'].value

## Scenario 2
# 25% decrease of the indirect tax rates on all commodities
# ttic is the commodity tax rate entering EQ77/EQ78. The shock was previously
# applied to ttix, the export tax rate of EQ76, which does not match the
# scenario description.
for commodity in ('AGR', 'FOOD', 'OTHIND', 'SER'):
    inst_scenario_2.ttic[commodity] = 0.75*inst_scenario_2.ttic[commodity].value

## Scenario 3
#  20% increase of public expenditures
inst_scenario_3.G = 1.2*inst_scenario_3.G.value



## 7 Solve instances

In [None]:
# Create the solver through pyo.SolverFactory('ipopt') and set the options
# that every solve call in this section shares.
opt = pyo.SolverFactory('ipopt')
opt.options['max_cpu_time'] = 120 #seconds
opt.options['warm_start_init_point'] = 'yes'
opt.options['halt_on_ampl_error'] = 'yes'

In [None]:
# Solve the base instance with solver output echoed, then report the solver
# status and the termination condition separately.
res_base = opt.solve(inst_base, tee=True)
if res_base.Solver.Status != 'ok':
    print('Status NOT ok!')
print(
    'Optimal solution NOT found!'
    if res_base.Solver[0]['Termination condition'] != 'optimal'
    else 'Optimal solution found!'
)
res_base

In [None]:
# Solve scenario 1 quietly, then report the solver status and the
# termination condition separately.
res_scenario_1 = opt.solve(inst_scenario_1, tee=False)
if res_scenario_1.Solver.Status != 'ok':
    print('Status NOT ok!')
print(
    'Optimal solution NOT found!'
    if res_scenario_1.Solver[0]['Termination condition'] != 'optimal'
    else 'Optimal solution found!'
)
res_scenario_1

In [None]:
# Solve scenario 2 quietly, then report the solver status and the
# termination condition separately.
res_scenario_2 = opt.solve(inst_scenario_2, tee=False)
if res_scenario_2.Solver.Status != 'ok':
    print('Status NOT ok!')
print(
    'Optimal solution NOT found!'
    if res_scenario_2.Solver[0]['Termination condition'] != 'optimal'
    else 'Optimal solution found!'
)
res_scenario_2

In [None]:
# Solve scenario 3 quietly, then report the solver status and the
# termination condition separately.
res_scenario_3 = opt.solve(inst_scenario_3, tee=False)
if res_scenario_3.Solver.Status != 'ok':
    print('Status NOT ok!')
print(
    'Optimal solution NOT found!'
    if res_scenario_3.Solver[0]['Termination condition'] != 'optimal'
    else 'Optimal solution found!'
)
res_scenario_3

## 8 Compare solutions

### 8.1 Helper functions

In [None]:
def compare_scenarios(instances_dict):
    """Collect the value of every variable entry across several instances.

    Args:
        instances_dict: mapping of scenario label -> solved model instance.

    Returns:
        Nested dict ``{variable name: {index: {scenario label: value}}}``.
    """
    results_obj = {}
    for instance_key, instance in instances_dict.items():
        for block in instance.block_data_objects(active=True):
            for v in block.component_map(pyo.Var, active=True).values():
                per_index = results_obj.setdefault(v._name, {})
                for key in v.keys():
                    per_index.setdefault(key, {})[instance_key] = v[key].value

    return results_obj

def results_dict_to_pandas_dict(results_obj):
    """Turn the nested results dict into one DataFrame per variable.

    The first column is the reference. For every other column a
    ``'<column> %diff'`` column holds ``column / reference - 1``, with 0.0
    where both the column and the reference are zero.
    """
    frames = {}
    for var_name, per_index in results_obj.items():
        frame = pd.DataFrame.from_dict(per_index, orient='index')
        scenario_columns = list(frame.columns)
        for scenario in scenario_columns[1:]:
            reference = scenario_columns[0]
            both_zero = (frame[scenario] == 0) & (frame[reference] == 0)
            relative_change = frame[scenario]/frame[reference] - 1
            frame[f'{scenario} %diff'] = np.where(both_zero, 0.0, relative_change)
        frames[var_name] = frame
    return frames

def write_pandas_dict_to_file(pandas_dict, wb_path):
    """Write one worksheet per DataFrame and format the '%diff' columns.

    Args:
        pandas_dict: mapping of sheet name -> DataFrame whose columns are the
            reference, N scenario columns, then N '%diff' columns.
        wb_path: path of the Excel workbook to create.

    The '%diff' columns get a percentage number format and a 3-colour scale;
    the header row is wrapped. The workbook formatting calls used here are
    XlsxWriter API, so that engine is requested explicitly.
    """
    # The context manager saves and closes the workbook even if a sheet
    # fails; ExcelWriter.save() is no longer available in pandas >= 2.0.
    with pd.ExcelWriter(wb_path, engine='xlsxwriter') as writer:
        workbook = writer.book
        percent_fmt = workbook.add_format({'num_format': '0.00%'})
        wrapped_text = workbook.add_format({
            'text_wrap': True
        })

        for key, df in pandas_dict.items():
            df.to_excel(writer, sheet_name=key, index=True)
            worksheet = writer.sheets[key]
            worksheet.set_row(0, 30, wrapped_text)

            num_scenarios = (len(df.columns) - 1) // 2
            if num_scenarios < 1:
                # Only the reference column exists: nothing to format.
                continue

            # Zero-based (row, col) addressing avoids building column letters
            # by hand, which stopped working past column Z.
            num_index_cols = df.index.nlevels
            first_diff_col = num_index_cols + 1 + num_scenarios
            last_diff_col = num_index_cols + 2*num_scenarios
            first_row = 1          # row 0 holds the header
            last_row = len(df)

            # Apply a conditional format to the cell range.
            worksheet.conditional_format(
                first_row, first_diff_col, last_row, last_diff_col,
                {'type': '3_color_scale'})
            worksheet.set_column(first_diff_col, last_diff_col, None, percent_fmt)

    print('done.')

### 8.2 Store comparisons

In [None]:
# Label each solved instance; the first entry ('base') becomes the reference
# column of the comparison tables.
scenario_instances = {
    'base': inst_base,
    'scenario 1': inst_scenario_1,
    'scenario 2': inst_scenario_2,
    'scenario 3': inst_scenario_3,
}
results_obj = compare_scenarios(scenario_instances)
    


In [None]:
# Convert the nested results dict into one DataFrame per variable.
pandas_dict = results_dict_to_pandas_dict(results_obj)
    

In [None]:
# Write the comparison tables to the results workbook.
write_pandas_dict_to_file(pandas_dict, 'output/results_pyomo.xlsx')