<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Data" data-toc-modified-id="Data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model</a></span><ul class="toc-item"><li><span><a href="#Settings" data-toc-modified-id="Settings-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Settings</a></span></li><li><span><a href="#Variables" data-toc-modified-id="Variables-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Variables</a></span></li><li><span><a href="#Parameters" data-toc-modified-id="Parameters-2.3"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>Parameters</a></span></li><li><span><a href="#Equations" data-toc-modified-id="Equations-2.4"><span class="toc-item-num">2.4&nbsp;&nbsp;</span>Equations</a></span></li><li><span><a href="#Compile-Model" data-toc-modified-id="Compile-Model-2.5"><span class="toc-item-num">2.5&nbsp;&nbsp;</span>Compile Model</a></span></li></ul></li><li><span><a href="#Calibration" data-toc-modified-id="Calibration-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Calibration</a></span><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Check-calibration" data-toc-modified-id="Check-calibration-3.0.1"><span class="toc-item-num">3.0.1&nbsp;&nbsp;</span>Check calibration</a></span></li></ul></li></ul></li><li><span><a href="#Simulations" data-toc-modified-id="Simulations-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Simulations</a></span></li></ul></div>

In [1]:
import sympy as sp
import numpy as np
import numba as nb
from scipy import optimize
import sys
from functools import partial

sys.path.append("..")
from cge_modeling import CGEModel, Variable, Parameter
from cge_modeling.sympy_tools import (
    info_to_symbols,
    enumerate_indexbase,
    sub_all_eqs,
    dict_info_to_symbols,
    symbol,
    symbols,
    remove_string_keys,
)
from cge_modeling.production_functions import leontief, cobb_douglass
from cge_modeling.base.cge import (
    compile_cge_to_numba,
    expand_compact_system,
    recursive_solve_symbolic,
    numba_linearize_cge_func,
)
from cge_modeling.numba_tools import euler_approx, numba_lambdify
from cge_modeling.output_tools import display_info_as_table, latex_print_equations

import pandas as pd

# Data

In [3]:
# Read the social accounting matrix (SAM); the row index and the column header are both two-level.
df = (
    pd.read_csv("data/lesson_5_sam.csv", index_col=[0, 1], header=[0, 1])
    .map(float)
    .fillna(0)
)
# Sanity check: column totals must match row totals (compared positionally).
assert np.allclose(df.sum(axis=0), df.sum(axis=1))

In [4]:
# Display the loaded matrix for visual inspection.
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Factor,Factor,Institution,Institution,Institution,Production,Production,Production,Activities,Activities,Activities,Other,Other
Unnamed: 0_level_1,Unnamed: 1_level_1,Labor,Capital,Household,Firms,Govt,Agriculture,Industry,Services,Agriculture,Industry,Services,Capital Accumulation,Rest of World
Factor,Labor,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1100.0,1700.0,4150.0,0.0,0.0
Factor,Capital,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,250.0,1800.0,550.0,0.0,0.0
Institution,Household,6950.0,0.0,0.0,2600.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Institution,Firms,0.0,2600.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Institution,Govt,0.0,0.0,1400.0,0.0,0.0,150.0,1320.0,280.0,0.0,0.0,0.0,0.0,0.0
Production,Agriculture,0.0,0.0,1700.0,0.0,400.0,0.0,0.0,0.0,1100.0,900.0,1100.0,200.0,250.0
Production,Industry,0.0,0.0,1200.0,0.0,900.0,0.0,0.0,0.0,2200.0,3150.0,3300.0,800.0,0.0
Production,Services,0.0,0.0,3400.0,0.0,2750.0,0.0,0.0,0.0,850.0,2250.0,1100.0,100.0,30.0
Activities,Agriculture,0.0,0.0,0.0,0.0,0.0,5500.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Activities,Industry,0.0,0.0,0.0,0.0,0.0,0.0,9800.0,0.0,0.0,0.0,0.0,0.0,0.0


# Model

## Settings

In [5]:
# Assumptions attached to every sympy symbol created through the helpers below.
default_assumptions = {"real": True}

# Short sector labels; nm1_sectors is the last zero-based sector position
# (used as the upper bound of the sums and products in the equations).
sectors = ["Ag", "Ind", "Serv"]
nm1_sectors = len(sectors) - 1

# Two index symbols, both ranging over the same sector list.
i, j = (sp.Idx(name) for name in "ij")
index_dict = {i: sectors, j: sectors}

# Pre-bind the default assumptions on the imported symbol helpers.
symbol = partial(symbol, assumptions=default_assumptions)
symbols = partial(symbols, assumptions=default_assumptions)

## Variables

In [6]:
# Small stand-alone set of symbols used to try out sympify_plus.
Y_base, C_base = [sp.IndexedBase(name) for name in ["Y", "C"]]
i = sp.Idx("i")  # rebinds the `i` created in the settings cell
Y, C, a, b = sp.symbols("Y C a b")
# Name -> object lookup; sympify_plus hands this to sp.sympify as `locals`.
var_dict = {var.name: var for var in [Y_base, C_base, a, b, i]}
# Symbols a parsed expression is allowed to contain.
known_symbols = [Y, C, a, b, Y_base[i], C_base[i], i]


def sympify_plus(expr, var_dict, known_symbols):
    """Parse a string into a sympy expression or equation and validate its symbols.

    Parameters
    ----------
    expr : str
        Text to parse. If it contains exactly one ``=``, the two sides are parsed
        separately and combined into ``sp.Eq(lhs, rhs)``; with no ``=`` the whole
        string is parsed as a single expression.
    var_dict : dict
        Mapping of names to sympy objects, passed to ``sp.sympify`` as ``locals``.
    known_symbols : collection
        Symbols that are allowed to appear in the parsed result.

    Returns
    -------
    The parsed sympy object (an ``sp.Eq`` when ``expr`` contained ``=``).

    Raises
    ------
    ValueError
        If ``expr`` contains more than one ``=``, or if the parsed result has free
        symbols that are not in ``known_symbols``.
    """
    n_equals = expr.count("=")
    if n_equals > 1:
        raise ValueError("Cannot parse expression with multiple equality relationships")

    if n_equals == 1:
        lhs, rhs = [x.strip() for x in expr.split("=")]
        eq = sp.Eq(sp.sympify(lhs, locals=var_dict), sp.sympify(rhs, locals=var_dict))
    else:
        # sympify lives on the sympy module; calling it on the string itself
        # (``expr.sympify``) raised AttributeError for every expression without "=".
        eq = sp.sympify(expr, locals=var_dict)

    unknown_symbols = [x for x in eq.free_symbols if x not in known_symbols]
    if unknown_symbols:
        atom_list = ", ".join([x.name for x in unknown_symbols])
        raise ValueError(
            "Found unknown variables in parsed expression. Ensure the following have been added to your "
            f"model: {atom_list}"
        )

    return eq


# Example: parse an indexed equation from text using the demo symbols.
eq = sympify_plus("Y[i] = a + b * C[i]", var_dict, known_symbols)

In [7]:
# 24 equations (Really 27, because 3 values are fixed)
# Model variables. Entries with ``dims`` are indexed by the named dimension(s);
# entries without ``dims`` are scalars. LaTeX names containing backslashes use raw
# strings so that Python does not treat e.g. "\O" as an (invalid) escape sequence.
variable_info = [
    # Firm variables
    Variable(
        name="Y",
        dims=("i",),
        latex_name="Y",
        description="Total output of final goods in sector i",
    ),
    Variable(
        name="VA",
        dims=("i",),
        latex_name="VA",
        description="Labor-Capital input mix used for final goods production in sector i",
    ),
    Variable(
        name="IC",
        dims=("i",),
        latex_name="IC",
        description="Final goods input mix used for final goods production in sector i",
    ),
    Variable(
        name="Kd",
        dims=("i",),
        latex_name="K^d",
        description="Quantity of capital demanded by sector i",
    ),
    Variable(
        name="Ld",
        dims=("i",),
        latex_name="L^d",
        description="Quantity of labor demanded by sector i",
    ),
    Variable(
        name="CIJ",
        dims=("i", "j"),
        latex_name="Y^d",
        description="Quantity of sector j final good demanded by sector i",
    ),
    Variable(
        name="Id",
        dims=("i",),
        latex_name="I^d",
        description="Investment capital demanded by sector i",
    ),
    # Prices
    Variable(
        name="P",
        dims=("i",),
        latex_name="P_i",
        description="Market price of final goods produced by sector i",
    ),
    Variable(
        name="P_Y",
        dims=("i",),
        latex_name="P_Y",
        description="Producer price of final goods produced by sector i",
    ),
    Variable(
        name="P_w",
        dims=("i",),
        latex_name="P_w",
        description="World price of final goods produced by sector i",
    ),
    Variable(
        name="P_VA",
        dims=("i",),
        latex_name="P_{VA}",
        description=r"Price of the labor-capital input bundle in sector i",
    ),
    Variable(
        name="P_IC",
        dims=("i",),
        latex_name="P_{IC}",
        description="Price of final goods input bundle in sector i",
    ),
    Variable(name="w", latex_name="w", description="Wage level"),
    Variable(name="r", latex_name="r", description="Rental rate for capital"),
    # Household variables
    Variable(
        name="income",
        latex_name=r"\Omega",
        description="Gross (pre-tax) household income",
    ),
    Variable(
        name="net_income",
        latex_name=r"\hat{\Omega}",
        description="Net (after-tax) household income",
    ),
    Variable(name="hh_savings", latex_name="S", description="Household savings"),
    Variable(
        name="C",
        dims=("i",),
        latex_name="C",
        description=r"Household consumption of final goods produced by sector i",
    ),
    Variable(
        name="Is",
        latex_name="I^s",
        description="Investment capital supplied by households",
    ),
    Variable(name="U", latex_name="U", description="Household utility"),
    # Government variables
    Variable(name="G", latex_name="G", description="Total government spending"),
    Variable(
        name="C_G",
        dims=("i",),
        latex_name="C_{G}",
        description="Government consumption of final goods produced by sector i",
    ),
    Variable(
        name="G_savings",
        latex_name="S_G",
        description="Investment capital supplied by government",
    ),
    # International concepts
    Variable(
        name="Ed",
        dims=("i",),
        latex_name="E_d",
        description="Excess domestic demand for final goods produced by sector i",
    ),
    Variable(name="e", latex_name="e", description="Exchange rate"),
    Variable(name="TB", latex_name="TB", description="Domestic trade balance"),
    # Misc
    Variable(
        name="walras_resid",
        latex_name=r"\varepsilon",
        description="Errors and residuals in macroeconomic equilibrium",
    ),
]

# Model parameters. Entries with ``dims`` are indexed by the named dimension(s);
# entries without ``dims`` are scalars. LaTeX names containing backslashes use raw
# strings so that Python does not treat e.g. "\p" or "\g" as (invalid) escape sequences.
parameter_info = [
    # Production parameters
    Parameter(
        name="alpha",
        dims=("i",),
        latex_name=r"\alpha_i",
        description="Share of capital in sector i production",
    ),
    Parameter(
        name="alpha_k",
        dims=("i",),
        latex_name=r"\alpha_I",
        description="Share of capital investment demanded by sector i",
    ),
    Parameter(
        name="A",
        dims=("i",),
        latex_name="A",
        description="Total factor productivity in sector i",
    ),
    Parameter(
        name="psi_VA",
        dims=("i",),
        latex_name=r"\psi_{VA}",
        description="Proportion of labor-capital bundle used in production by sector i",
    ),
    Parameter(
        name="psi_IC",
        dims=("i",),
        # Braces keep both letters in the subscript, matching \psi_{VA} above.
        latex_name=r"\psi_{IC}",
        description="Proportion of final goods bundle used in production by sector i",
    ),
    Parameter(
        name="psi_Yij",
        dims=("i", "j"),
        latex_name=r"\psi_Y",
        description="Proportion of final good produced by sector j in final goods bundle used by sector i",
    ),
    # Tax parameters
    Parameter(
        name="tau",
        dims=("i",),
        latex_name=r"\tau_i",
        description="Sales tax on final goods produced by sector i",
    ),
    Parameter(
        name="tau_income",
        latex_name=r"\tau_{\Omega}",
        description="Flat income tax on households",
    ),
    Parameter(
        name="tau_m",
        dims=("i",),
        latex_name=r"\tau_m",
        description="Import tax on final goods produced by foreign sector i",
    ),
    # Household parameters
    Parameter(
        name="mps",
        latex_name=r"\phi",
        description="Household marginal propensity to save",
    ),
    Parameter(
        name="gamma",
        dims=("i",),
        latex_name=r"\gamma",
        description="Share of good i in household consumption basket",
    ),
    # Government parameters
    Parameter(
        name="alpha_G",
        dims=("i",),
        latex_name=r"\alpha_{G}",
        description="Share of good i in government consumption basket",
    ),
    # Trade concepts
    Parameter(
        name="f",
        dims=("i",),
        latex_name=r"f_i",
        description="Scale of domestic sector i demand on world price",
    ),
    Parameter(
        name="sigma",
        dims=("i",),
        latex_name=r"\sigma_i",
        description="Elasticity of domestic demand in sector i on world price",
    ),
    # Fixed values
    Parameter(name="Ks", latex_name="K^s", description="Exogenous supply of capital"),
    Parameter(name="Ls", latex_name="L^s", description="Exogenous supply of labor"),
    Parameter(name="e_bar", latex_name=r"\bar{e}", description="Exogenous exchange rate"),
    Parameter(
        name="G_savings_bar",
        latex_name=r"\bar{G}_S",
        description="Exogenous level of government savings",
    ),
    Parameter(
        name="P_Ag_bar",
        latex_name=r"\bar{P}_{Ag}",
        description="Numeraire price of agricultural goods",
    ),
]

In [9]:
# Both model dimensions ("i" and "j") are labelled by the same sector list.
coords = {"i": sectors, "j": sectors}

# Build the model object from the variable and parameter definitions.
mod = CGEModel(coords=coords, variables=variable_info, parameters=parameter_info)

## Parameters

In [11]:
# Show the parameter definitions (the saved output is an IPython Latex object).
display_info_as_table(parameter_info)

<IPython.core.display.Latex object>

In [12]:
# Create sympy symbols from the variable/parameter definitions and inject them into the
# notebook namespace so later cells can use them by bare name (Y, P_Y, tau, ...).
# NOTE(review): the saved output of this cell is
#     AttributeError: 'Variable' object has no attribute 'get'
# variable_info / parameter_info hold Variable / Parameter objects, so it looks like
# dict_info_to_symbols expects dict entries -- confirm against the library and convert
# before calling. Every later cell depends on the names this cell is meant to create.
variables, var_updates = dict_info_to_symbols(variable_info, default_assumptions)
parameters, param_updates = dict_info_to_symbols(parameter_info, default_assumptions)
globals().update(var_updates)
globals().update(param_updates)

AttributeError: 'Variable' object has no attribute 'get'

## Equations

In [None]:
# temp idx, needed for the double-index
# temp idx, needed for the double-index
k = sp.Idx("k")

# Every entry pairs a display name with an expression written in residual form
# (left-hand side minus right-hand side). Sums and products run over sector
# positions 0..nm1_sectors.
equation_info = [
    # Firm equations (9)
    dict(name="Final goods production", equation=Y * P_Y - P_VA * VA - P_IC * IC),
    dict(name="Firm self-demand for labor-capital bundle", equation=VA - psi_VA * Y),
    dict(name="Firm self-demand for intermediate goods bundle", equation=IC - psi_IC * Y),
    dict(
        name="Intermediate goods bundle production",
        # NOTE(review): CIJ is described as demand by sector i for the good of sector j,
        # which suggests the price inside the sum should be indexed by j
        # (e.g. P.subs({i: j})) rather than i -- confirm against the model specification.
        equation=P_IC * IC - sp.Sum(P * CIJ, (j, 0, nm1_sectors)),
    ),
    dict(
        name="Labor-capital bundle production",
        equation=VA - A * Kd**alpha * Ld ** (1 - alpha),
    ),
    dict(name="Firm demand for capital", equation=alpha * P_VA * VA - r * Kd),
    dict(name="Firm demand for labor", equation=(1 - alpha) * P_VA * VA - w * Ld),
    dict(name="Firm demand for final goods of other firms", equation=CIJ - psi_Yij * IC),
    dict(name="Firm demand for investment capital", equation=P * Id - alpha_k * Is),
    # Government equations
    dict(
        name="Government budget constraint",
        equation=(
            tau_income * income
            + sp.Sum(tau * P_Y * Y, (i, 0, nm1_sectors))
            # Tariff revenue is only collected where excess demand is positive.
            + sp.Sum(sp.Max(0, Ed) * tau_m * e * P_w, (i, 0, nm1_sectors))
            - G
            - G_savings
        ),
    ),
    dict(name="Government consumption", equation=P * C_G - alpha_G * G),
    # Market price = producer price marked up by the sales tax. The "-" was missing
    # here, which made the whole cell a syntax error.
    dict(name="After-tax consumer price", equation=P - P_Y * (1 + tau)),
    dict(name="Exogenous government savings", equation=G_savings - G_savings_bar),
    # Household block
    dict(name="Gross household income", equation=income - w * Ls - r * Ks),
    dict(name="Net household income", equation=net_income - (1 - tau_income) * income),
    dict(
        name="Household utility",
        equation=U - sp.Product(C**gamma, (i, 0, nm1_sectors)),
    ),
    dict(name="Household savings", equation=hh_savings - mps * net_income),
    dict(name="Household consumption", equation=gamma * (1 - mps) * net_income - P * C),
    # International concepts
    dict(name="Law of one price", equation=P - e * P_w * (1 + tau_m)),
    dict(name="World price", equation=Ed - f * P_w**sigma),
    dict(name="Trade balance", equation=TB + e * sp.Sum(P_w * Ed, (i, 0, nm1_sectors))),
    dict(name="Exogenous exchange rate", equation=e - e_bar),
    # Market clearing
    dict(
        name="Goods market equlibrium",
        equation=hh_savings - Is + G_savings - TB - walras_resid,
    ),
    dict(name="Labor market clearing", equation=Ls - sp.Sum(Ld, (i, 0, nm1_sectors))),
    dict(name="Capital market clearing", equation=Ks - sp.Sum(Kd, (i, 0, nm1_sectors))),
    dict(
        name="Final goods market clearing",
        # Subbing j to i is a bit weird here, but the indices on the rest of the variables need to match
        equation=C + sp.Sum(CIJ.subs({i: k}), (k, 0, nm1_sectors)).subs({j: i}) + Id + C_G - Ed - Y,
    ),
    dict(
        name="Numeraire price level",
        # Pin the agricultural price (selected by its position in `sectors`).
        equation=P.subs({i: sectors.index("Ag")}) - P_Ag_bar,
    ),
]

In [None]:
# Render the named equations for inspection.
latex_print_equations(equation_info, variables, variable_info, parameters, parameter_info)

## Compile Model

In [None]:
# Pull the expressions out of equation_info; .doit() evaluates the unevaluated
# sp.Sum / sp.Product objects.
equations = [d.get("equation").doit() for d in equation_info]

# Expand the indexed (compact) system over the sector labels in index_dict.
full_system, named_variables, named_params = expand_compact_system(
    compact_equations=equations,
    compact_variables=variables,
    compact_params=parameters,
    index_dict=index_dict,
)

# Number of expanded parameters; used later to split a combined result vector.
n_params = len(named_params)

In [None]:
# Compile the system to numba functions: a loss with gradient and Hessian, plus the
# system itself with its Jacobian.
loss_funcs, root_funcs, ordered_inputs = compile_cge_to_numba(
    compact_equations=equations,
    compact_variables=variables,
    compact_params=parameters,
    index_dict=index_dict,
)
f_loss, f_grad, f_hess = loss_funcs
f_system, f_jac = root_funcs

In [None]:
# Linearized system function; passed to euler_approx in the simulation section.
f_dX = numba_linearize_cge_func(equations, variables, parameters, index_dict)

# Calibration

In [None]:
# Map the short sector labels used in the model to the account names used in the SAM.
short_to_long = {"Ag": "Agriculture", "Ind": "Industry", "Serv": "Services"}

# Normalize prices
initial_values = {w: 1, r: 1, e: 1, walras_resid: 0}
initial_values.update(symbols("P_VA", 1, sectors))
initial_values.update(symbols("P_IC", 1, sectors))
initial_values.update(symbols("P", 1, sectors))

# Values provided by the exercise (from where?)
# String keys are scratch values used only inside this cell; they are passed through
# remove_string_keys at the end (presumably dropping them -- confirm).
initial_values["import_tax_Ag"] = 0
initial_values["import_tax_Ind"] = 300
initial_values["import_tax_Serv"] = 0

initial_values[symbol("sigma", "Ag")] = -3
initial_values[symbol("sigma", "Ind")] = 3
initial_values[symbol("sigma", "Serv")] = -3

# Enter data from SAM
initial_values["income_tax"] = df.loc[("Institution", "Govt"), ("Institution", "Household")]
initial_values[hh_savings] = df.loc[("Other", "Capital Accumulation"), ("Institution", "Household")]
initial_values[G_savings] = df.loc[("Other", "Capital Accumulation"), ("Institution", "Govt")]

# Aggregates
initial_values[Ls] = df.loc[("Institution", "Household"), ("Factor", "Labor")]
initial_values[Ks] = df.loc[("Institution", "Firms"), ("Factor", "Capital")]

initial_values[income] = (
    initial_values[w] * initial_values[Ls] + initial_values[r] * initial_values[Ks]
)
# Income tax rate = income tax receipts / gross income.
initial_values[tau_income] = initial_values["income_tax"] / initial_values[income]
initial_values[net_income] = (1 - initial_values[tau_income]) * initial_values[income]

for sector in sectors:
    long_sector = short_to_long[sector]
    # Factor payments divided by the factor price give factor quantities.
    initial_values[symbol("Ld", sector)] = (
        df.loc[("Factor", "Labor"), ("Activities", long_sector)] / initial_values[w]
    )
    initial_values[symbol("Kd", sector)] = (
        df.loc[("Factor", "Capital"), ("Activities", long_sector)] / initial_values[r]
    )
    initial_values[symbol("C_G", sector)] = df.loc[
        ("Production", long_sector), ("Institution", "Govt")
    ]

    # The SAM cell holds total government receipts from this sector; the remainder
    # after the import tax is booked as sales tax.
    import_tax = initial_values[f"import_tax_{sector}"]
    initial_values[f"sales_tax_{sector}"] = (
        df.loc[("Institution", "Govt"), ("Production", long_sector)] - import_tax
    )
    initial_values[symbol("C", sector)] = df.loc[
        ("Production", long_sector), ("Institution", "Household")
    ]

    # Excess demand -- What is this in terms of imports/exports?
    world_supply = df.loc[("Other", "Rest of World"), ("Production", long_sector)]
    world_demand = df.loc[("Production", long_sector), ("Other", "Rest of World")]
    initial_values[symbol("Ed", sector)] = (world_supply - world_demand) + import_tax

    initial_values[symbol("Id", sector)] = df.loc[
        ("Production", long_sector), ("Other", "Capital Accumulation")
    ]

    # This computes Y, but what does this ultimately work out to? There is a lot of cancellations happening.
    total_i = (
        df.sum(axis=0).loc[("Production", long_sector)]
        - df.loc[("Production", long_sector), ("Other", "Rest of World")]
    )
    initial_values[symbol("Y", sector)] = total_i - initial_values[symbol("Ed", sector)]

    # Work out the tax rate from tax receipts
    # NOTE: P_sector is the *symbol* for the producer price P_Y, while P_sec is the
    # *value* of the market price P.
    P_sector = symbol("P_Y", sector)
    tax_rate = symbol("tau", sector)
    tariff_rate = symbol("tau_m", sector)
    T_sec = initial_values[f"sales_tax_{sector}"]
    Y_sec = initial_values[symbol("Y", sector)]
    P_sec = initial_values[symbol("P", sector)]
    Ed_sec = initial_values[symbol("Ed", sector)]

    # Rates are receipts divided by the tax-exclusive base; the prices then follow from
    # P = P_Y * (1 + tau) and (with e = 1) P = P_w * (1 + tau_m).
    initial_values[tax_rate] = T_sec / (Y_sec - T_sec)
    initial_values[P_sector] = P_sec / (1 + initial_values[tax_rate])
    # No tariff is assigned when excess demand is not positive.
    initial_values[tariff_rate] = import_tax / (Ed_sec * P_sec - import_tax) if Ed_sec > 0 else 0
    initial_values[symbol("P_w", sector)] = P_sec / (1 + initial_values[tariff_rate])

# Intermediate demands: CIJ[sector, sector_j] is read from the SAM cell in row
# (Production, sector_j) and column (Activities, sector). No price adjustment is applied.
for sector in sectors:
    long_sector = short_to_long[sector]
    for sector_j in sectors:
        long_sector_j = short_to_long[sector_j]
        SAM_idx = ("Production", long_sector_j), ("Activities", long_sector)
        initial_values[symbol("CIJ", sector, sector_j)] = df.loc[SAM_idx]


initial_values = remove_string_keys(initial_values)

In [None]:
def state_dict_to_input_arrays(state_dict, named_variables, named_params):
    """Split a state mapping into float arrays of variable and parameter values.

    Values are looked up in ``state_dict`` in the order given by
    ``named_variables`` and then ``named_params``; a missing key raises
    ``KeyError``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, theta)``: variable values and parameter values, both ``dtype=float``.
    """

    def _gather(keys):
        # Preserve the caller's ordering; it defines the array layout.
        return np.array([state_dict[key] for key in keys], dtype=float)

    return _gather(named_variables), _gather(named_params)

In [None]:
# Solve the expanded system for the values not set during calibration, then split
# the full state into variable (x0) and parameter (theta0) arrays.
state_0 = recursive_solve_symbolic(full_system, initial_values)
x0, theta0 = state_dict_to_input_arrays(state_0, named_variables, named_params)

### Check calibration

In [None]:
# Loss at the calibrated point; it should be (numerically) zero if the calibration
# reproduces the equilibrium -- this is only displayed, not asserted.
f_loss(x0, theta0)

# Simulations

In [None]:
# One row per expanded variable; each scenario adds a column next to "Initial".
scenario_df = pd.DataFrame(x0, index=[x.name for x in named_variables], columns=["Initial"])

In [None]:
# Scenario: copy the calibrated state and halve the import-tax rate tau_m in every sector.
# Despite its name, tariff_cut_theta holds the full state (variables and parameters).
tariff_cut_theta = state_0.copy()
tariff_cut_theta.update(
    {k: state_0[k] * 0.5 for k in [symbol("tau_m", sector) for sector in sectors]}
)
# tariff_x is not used by the cells that follow; only tariff_theta is.
tariff_x, tariff_theta = state_dict_to_input_arrays(tariff_cut_theta, named_variables, named_params)

In [None]:
# Approximate the new equilibrium by stepping the linearized system from theta0 to
# tariff_theta (10_000 is presumably the number of Euler steps -- confirm).
# The slicing below treats the result as one vector whose last n_params entries are
# the parameters and whose leading entries are the variables.
tariff_cut_scenario = euler_approx(f_dX, x0, theta0, tariff_theta, 10_000)
print(
    f"Linear Loss: {f_loss(tariff_cut_scenario[:-n_params], tariff_cut_scenario[-n_params:]):0.5}"
)

In [None]:
# Refine the Euler approximation by minimizing the loss at the new parameter values,
# starting from the variable part of the approximate solution.
tariff_cut_res = optimize.minimize(
    f_loss,
    tariff_cut_scenario[:-n_params],
    # `args` is documented as a tuple of extra arguments; pass theta as a 1-tuple.
    args=(tariff_theta,),
    jac=f_grad,
    hess=f_hess,
    method="trust-exact",
    tol=1e-4,
)

assert tariff_cut_res.success, (tariff_cut_res.message, tariff_cut_res.fun)
# The last variable is walras_resid (last entry of variable_info, assuming the expanded
# ordering follows it). It must be close to zero in either direction, so compare its
# magnitude: a plain `<` would accept an arbitrarily large negative residual.
assert abs(tariff_cut_res.x[-1]) < 1e-4, tariff_cut_res.x[-1]
print(f"Final Loss: {f_loss(tariff_cut_res.x, tariff_theta):0.5}")
scenario_df["tariff_cut"] = tariff_cut_res.x

In [None]:
# Variables whose level rose by more than 100 relative to "Initial".
# NOTE(review): this filter is one-sided (decreases are not shown), unlike the
# percentage filter in the next cell, which uses .abs() -- confirm this is intended.
scenario_df[scenario_df.diff(axis=1)["tariff_cut"] > 100]

In [None]:
# Bar chart of the variables that moved by more than 1% (in either direction)
# between "Initial" and "tariff_cut".
scenario_df[scenario_df.pct_change(axis=1)["tariff_cut"].abs() > 0.01].plot.bar()