In [4]:
import sys
from gams.core import gdx
import xarray as xr

gams_dir = "/Library/Frameworks/GAMS.framework/Resources"
path = "/Users/jessegrabowski/Downloads/GDX11aPower17/GSDFDAT.GDX_API"

In [5]:
#: String representations of API constants for G(a)MS D(ata) T(ypes)
type_str = {
    gdx.GMS_DT_SET: "set",
    gdx.GMS_DT_PAR: "parameter",
    gdx.GMS_DT_VAR: "variable",
    gdx.GMS_DT_EQU: "equation",
    gdx.GMS_DT_ALIAS: "alias",
}


#: String representations of API constants for G(a)MS VAR(iable) TYPE(s)
vartype_str = {
    gdx.GMS_VARTYPE_UNKNOWN: "unknown",
    gdx.GMS_VARTYPE_BINARY: "binary",
    gdx.GMS_VARTYPE_INTEGER: "integer",
    gdx.GMS_VARTYPE_POSITIVE: "positive",
    gdx.GMS_VARTYPE_NEGATIVE: "negative",
    gdx.GMS_VARTYPE_FREE: "free",
    gdx.GMS_VARTYPE_SOS1: "sos1",
    gdx.GMS_VARTYPE_SOS2: "sos2",
    gdx.GMS_VARTYPE_SEMICONT: "semicont",
    gdx.GMS_VARTYPE_SEMIINT: "semiint",
    gdx.GMS_VARTYPE_MAX: "max",
}


class GDX_API(object):
    """Wrapper around the `GDX_API API`_."""

    #: Methods that conform to the semantics of :func:`call`.
    __valid = [
        "CreateD",
        "DataReadStr",
        "DataReadStrStart",
        "ErrorCount",
        "ErrorStr",
        "FileVersion",
        "GetElemText",
        "GetLastError",
        "OpenRead",
        "SymbolGetDomain",
        "SymbolGetDomainX",
        "SymbolInfo",
        "SymbolInfoX",
        "SystemInfo",
    ]

    def __init__(self, gams_dir):
        """Constructor."""
        self._handle = gdx.new_gdxHandle_tp()
        self.error_count = 0
        self.call("CreateD", gams_dir, gdx.GMS_SSSIZE)

    def call(self, method, *args):
        """Invoke the GDX_API API method named gdx\ *Method*.

        Optional positional arguments *args* are passed to the API method.
        Returns the result of the method call, with the return code stripped.
        Refer to the GDX_API API documentation for the type and number of arguments
        and return values for any method.

        If the call fails, raise an appropriate exception.

        """
        if method not in self.__valid:
            raise NotImplementedError(f"GDX_API.call() cannot invoke gdx.gdx{method}")
        ret = getattr(gdx, f"gdx{method}")(self._handle, *args)
        if isinstance(ret, int):
            return ret
        if ret[0]:
            # unwrap a 1-element array
            if len(ret) == 2:
                return ret[1]
            else:
                return ret[1:]
        else:
            if method == "OpenRead":
                error_str = self.call("ErrorStr", ret[1])
                if error_str == "No such file or directory":
                    raise FileNotFoundError(f"[gdx{method}] {error_str}: '{args[0]}'")
            else:
                error_count = self.call("ErrorCount")
                if error_count > self.error_count:
                    self.error_count = error_count
                    error_num = self.call("GetLastError")
                    error_str = self.call("ErrorStr", error_num)
                    raise Exception(f"[gdx{method}] {error_str}")
                else:
                    raise RuntimeError(f"[gdx{method}] returned {ret} for arguments {args}")

    def __getattr__(self, name):
        """Name mangling for method invocation without call()."""
        mangle = name.title().replace("_", "")
        if mangle in self.__valid:

            def wrapper(*args):
                return self.call(mangle, *args)

            return wrapper
        else:
            raise AttributeError(name)

In [8]:
import logging

_log = logging.Logger("gdx_file_", level=logging.DEBUG)


class File(xr.Dataset):
    """Load the file at *filename* into memory.

    If *lazy* is ``True`` (default), then the data for GDX_API Parameters is not
    loaded until each individual parameter is first accessed; otherwise all
    parameters except those listed in *skip* (default: empty) are loaded
    immediately.

    If *implicit* is ``True`` (default) then, for each dimension of any GDX_API
    Parameter declared over '*' (the universal set), an implicit set is
    constructed, containing only the labels appearing in the respective
    dimension of that parameter.

    .. note::

       For instance, the GAMS Parameter ``foo(*,*,*)`` is loaded as
       ``foo(_foo_0,_foo_1,_foo_2)``, where ``_foo_0`` is an implicit set that
       contains only labels appearing along the first dimension of ``foo``,
       etc. This workaround is essential for GDX_API files where ``*`` is large;
       otherwise, loading ``foo`` as declared raises :py:class:`MemoryError`.

    """

    __slots__ = ("_api", "_alias", "_index", "_implicit", "_index", "_state")

    def __init__(self, filename="", gams_dir=None, lazy=True, implicit=True, skip=set()):
        super(File, self).__init__()  # Invoke Dataset constructor
        self._index = []
        self._state = {}
        self._alias = {}
        self._implicit = implicit

        self._api = GDX_API(gams_dir)
        self._api.open_read(filename)

        # Basic information about the GDX_API file
        v, p = self._api.file_version()
        sc, ec = self._api.system_info()
        self.attrs["version"] = v.strip()
        self.attrs["producer"] = p.strip()
        self.attrs["symbol_count"] = sc
        self.attrs["element_count"] = ec

        # Initialize private variables
        self._index = [None for _ in range(sc + 1)]
        self._state = {}
        self._alias = {}
        self._implicit = implicit

        # Read symbols
        for s_num in range(sc + 1):
            name, type_code = self._load_symbol(s_num)
            print(name, type_code)
            if type_code == gdx.GMS_DT_SET and name not in skip:
                self._load_symbol_data(name)

    #         if not lazy:
    #             for name in filter(None, self._index):
    #                 if name not in skip:
    #                     self._load_symbol_data(name)

    def _load_symbol(self, index):
        """Load the *index*-th Symbol in the GDX_API file."""
        # Load basic information
        name, dim, type_code = self._api.symbol_info(index)
        n_records, vartype, desc = self._api.symbol_info_x(index)

        self._index[index] = name  # Record the name

        attrs = {
            "index": index,
            "name": name,
            "dim": dim,
            "type_code": type_code,
            "records": n_records,
            "vartype": vartype,
            "description": desc,
        }

        # Assemble a string description of the Symbol's type
        type_str_ = type_str[type_code]
        if type_code == gdx.GMS_DT_PAR and dim == 0:
            type_str_ = "scalar"
        try:
            vartype_str_ = vartype_str[vartype]
        except KeyError:  # pragma: no cover
            # Some other vartype is returned that's not described by the GDX_API
            # API docs
            vartype_str_ = ""
        attrs["type_str"] = f"{vartype_str_} {type_str_}"

        _log.debug(
            str("Loading #{index} {name}: {dim}-D, {records} records, " '"{description}"').format(
                **attrs
            )
        )

        # Equations and Aliases require limited processing
        if type_code == gdx.GMS_DT_EQU:
            info("Loading of GMS_DT_EQU not implemented: {} {} not loaded.".format(index, name))
            self._state[name] = None
            return name, type_code
        elif type_code == gdx.GMS_DT_ALIAS:
            parent = desc.replace("Aliased with ", "")
            self._alias[name] = parent
            assert self[parent].attrs["_gdx_type_code"] == gdx.GMS_DT_SET
            # Duplicate the variable
            self._variables[name] = self._variables[parent]
            self._state[name] = True
            super(File, self).set_coords(name, inplace=True)
            return name, type_code

        # The Symbol is either a Set, Parameter or Variable
        try:  # Read the domain, as a list of names
            domain = self._api.symbol_get_domain_x(index)
            _log.debug("domain: {}".format(domain))
        except Exception:  # gdxSymbolGetDomainX fails for the universal set
            assert name == "*"
            domain = []

        # Cache the attributes
        attrs["domain"] = domain
        self._state[name] = {"attrs": attrs}

        return name, type_code

    def _load_symbol_data(self, name):
        """Load the Symbol *name*."""
        if self._state[name] in (True, None):  # Skip Symbols already loaded
            return

        # Unpack attributes
        attrs = self._state[name]["attrs"]
        index, dim, domain, records = [attrs[k] for k in ("index", "dim", "domain", "records")]

        # Read the data
        self._cache_data(name, index, dim, records)

        # If the GAMS method 'sameas' is invoked in a program, the resulting
        # GDX_API file contains an empty Set named 'SameAs' with domain (*,*). Do
        # not read this
        if name == "SameAs" and domain == ["*", "*"]:
            self._state[name] = None
            self._index[index] = None
            return

        domain = self._infer_domain(name, domain, self._state[name]["elements"])

        # Create an xr.DataArray with the Symbol's data
        self._add_symbol(name, dim, domain, attrs)

    def _cache_data(self, name, index, dim, records):
        """Read data for the Symbol *name* from the GDX_API file."""
        # Initiate the data read. The API method returns a number of records,
        # which should match that given by gdxSymbolInfoX in _load_symbol()
        records2 = self._api.data_read_str_start(index)
        assert records == records2, (
            "{}: gdxSymbolInfoX ({}) and gdxDataReadStrStart ({}) disagree on" " number of records."
        ).format(name, records, records2)

        # Indices of data records, one list per dimension
        elements = [list() for _ in range(dim)]
        # Data points. Keys are index tuples, values are data. For a 1-D Set,
        # the data is the GDX_API 'string number' of the text associated with the
        # element
        data = {}
        try:
            while True:  # Loop over all records
                labels, value, _ = self._api.data_read_str()  # Next record
                # Update elements with the indices
                for j, label in enumerate(labels):
                    if label not in elements[j]:
                        elements[j].append(label)
                # Convert a 1-D index from a tuple to a bare string
                key = labels[0] if dim == 1 else tuple(labels)
                # The value is a sequence, containing the level, marginal,
                # lower & upper bounds, etc. Store only the value (first
                # element).
                data[key] = value[gdx.GMS_VAL_LEVEL]
        except Exception:
            if len(data) == records:
                pass  # All data has been read
            else:  # pragma: no cover
                raise  # Some other read error

        # Cache the read data
        self._state[name].update(
            {
                "data": data,
                "elements": elements,
            }
        )

    def _infer_domain(self, name, domain, elements):
        """Infer the domain of the Symbol *name*.

        Lazy GAMS modellers may create variables like myvar(*,*,*,*). If the
        size of the universal set * is large, then attempting to instantiate a
        xr.DataArray with this many elements may cause a MemoryError. For every
        dimension of *name* defined on the domain '*' this method tries to find
        a Set from the file which contains all the labels appearing in *name*'s
        data.

        """
        if "*" not in domain:
            return domain
        _log.debug("guessing a better domain for {}: {}".format(name, domain))

        # Domain as a list of references to Variables in the File/xr.Dataset
        domain_ = [self[d] for d in domain]

        for i, d in enumerate(domain_):  # Iterate over dimensions
            e = set(elements[i])
            if d.name != "*" or len(e) == 0:  # pragma: no cover
                assert set(d.values).issuperset(e)
                continue  # The stated domain matches the data; or no data
            # '*' is given
            if self._state[name]["attrs"]["type_code"] == gdx.GMS_DT_PAR and self._implicit:
                d = "_{}_{}".format(name, i)
                _log.debug(
                    (
                        "Constructing implicit set {} for dimension {} of {}\n"
                        " {} instead of {} elements"
                    ).format(d, name, i, len(e), len(self["*"]))
                )
                self.coords[d] = elements[i]
                d = self[d]
            else:
                # try to find a smaller domain for this dimension
                # Iterate over every Set/Coordinate
                for s in self.coords.values():
                    if s.ndim == 1 and set(s.values).issuperset(e) and len(s) < len(d):
                        d = s  # Found a smaller Set; use this instead
            domain_[i] = d

        # Convert the references to names
        inferred = [d.name for d in domain_]

        if domain != inferred:
            # Store the result
            self._state[name]["attrs"]["domain_inferred"] = inferred
            _log.debug("…inferred {}.".format(inferred))
        else:
            _log.debug("…failed.")

        return inferred

    def _root_dim(self, dim):
        """Return the ultimate ancestor of the 1-D Set *dim*."""
        parent = self[dim].dims[0]
        return dim if parent == dim else self._root_dim(parent)

    def _empty(self, *dims, **kwargs):
        """Return an empty numpy.ndarray for a GAMS Set or Parameter."""
        size = []
        dtypes = []
        for d in dims:
            size.append(len(self[d]))
            dtypes.append(self[d].dtype)
        dtype = kwargs.pop("dtype", numpy.result_type(*dtypes))
        fv = kwargs.pop("fill_value")
        return numpy.full(size, fill_value=fv, dtype=dtype)

    def _add_symbol(self, name, dim, domain, attrs):
        """Add a xray.DataArray with the data from Symbol *name*."""
        # Transform the attrs for storage, unpack data
        gdx_attrs = {"_gdx_{}".format(k): v for k, v in attrs.items()}
        data = self._state[name]["data"]
        elements = self._state[name]["elements"]

        # Erase the cache; this also prevents __getitem__ from triggering lazy-
        # loading, which is still in progress
        self._state[name] = True

        kwargs = {}  # Arguments to xr.Dataset.__setitem__()
        if dim == 0:
            # 0-D Variable or scalar Parameter
            super(File, self).__setitem__(name, ([], data.popitem()[1], gdx_attrs))
            return
        elif attrs["type_code"] == gdx.GMS_DT_SET:  # GAMS Set
            if dim == 1:
                # One-dimensional Set
                self.coords[name] = elements[0]
                self.coords[name].attrs = gdx_attrs
            else:
                # Multi-dimensional Sets are mappings indexed by other Sets;
                # elements are either 'on'/True or 'off'/False
                kwargs["dtype"] = bool
                kwargs["fill_value"] = False

                # Don't define over the actual domain dimensions, but over the
                # parent Set/xr.Coordinates for each dimension
                dims = [self._root_dim(d) for d in domain]

                # Update coords
                self.coords.__setitem__(name, (dims, self._empty(*domain, **kwargs), gdx_attrs))

                # Store the elements
                for k in data.keys():
                    self[name].loc[k] = k if dim == 1 else True
        else:  # 1+-dimensional GAMS Parameters
            kwargs["dtype"] = float
            kwargs["fill_value"] = numpy.nan

            dims = [self._root_dim(d) for d in domain]  # Same as above

            # Create an empty xr.DataArray; this ensures that the data
            # read in below has the proper form and indices
            super(File, self).__setitem__(name, (dims, self._empty(*domain, **kwargs), gdx_attrs))

            # Fill in extra keys
            longest = numpy.argmax(self[name].values.shape)
            iters = []
            for i, d in enumerate(dims):
                if i == longest:
                    iters.append(self[d].to_index())
                else:
                    iters.append(cycle(self[d].to_index()))
            data.update({k: numpy.nan for k in set(zip(*iters)) - set(data.keys())})

            # Use pandas and xarray IO methods to convert data, a dict, to a
            # xr.DataArray of the correct shape, then extract its values
            tmp = pandas.Series(data)
            tmp.index.names = dims
            tmp = xr.DataArray.from_series(tmp).reindex_like(self[name])
            self[name].values = tmp.values

    def dealias(self, name):
        """Identify the GDX_API Symbol that *name* refers to, and return the
        corresponding :py:class:`xarray.DataArray`."""
        return self[self._alias[name]] if name in self._alias else self[name]

    def extract(self, name):
        """Extract the GAMS Symbol *name* from the dataset.

        The Sets and Parameters in the :class:`File` can be accessed directly,
        as e.g. `f['name']`; but for more complex xarray operations, such as
        concatenation and merging, this carries along sub-Sets and other
        Coordinates which confound xarray.

        :func:`extract()` returns a self-contained :py:class:`xarray.DataArray`
        with the declared dimensions of the Symbol (and *only* those
        dimensions), which does not make reference to the :class:`File`.
        """
        # Copy the Symbol, triggering lazy-loading if needed
        result = self[name].copy()

        # Declared dimensions of the Symbol, and their parents
        try:
            domain = result.attrs["_gdx_domain_inferred"]
        except KeyError:  # No domain was inferred for this Symbol
            domain = result.attrs["_gdx_domain"]
        dims = {c: self._root_dim(c) for c in domain}
        keep = set(dims.keys()) | set(dims.values())

        # Extraneous dimensions
        drop_coords = set(result.coords) - keep

        # Reduce the data
        for c, p in dims.items():
            if c == "*":  # Dimension is '*', drop empty labels
                result = result.dropna(dim="*", how="all")
            elif c == p:  # Dimension already indexed by the correct coord
                continue
            else:
                # Dimension is indexed by 'p', but declared 'c'. First drop
                # the elements which do not appear in the sub-Set c;, then
                # rename 'p' to 'c'
                drop = set(self[p].values) - set(self[c].values) - set("")
                result = result.drop(drop, dim=p).swap_dims({p: c})
                # Add the old coord to the set of coords to drop
                drop_coords.add(p)
        # Do this last, in case two dimensions have the same parent (p)
        return result.drop(drop_coords)

    def info(self, name):
        """Informal string representation of the Symbol with *name*."""
        if isinstance(self._state[name], dict):
            attrs = self._state[name]["attrs"]
            return "{} {}({}), {} records: {}".format(
                attrs["type_str"],
                name,
                ",".join(attrs["domain"]),
                attrs["records"],
                attrs["description"],
            )
        else:
            return repr(self[name])

    def _loaded_and_cached(self, type_code):
        """Return a list of loaded and not-loaded Symbols of *type_code*."""
        names = set()
        for name, state in self._state.items():
            if state is True:
                tc = self._variables[name].attrs["_gdx_type_code"]
            elif isinstance(state, dict):
                tc = state["attrs"]["type_code"]
            else:  # pragma: no cover
                continue
            if tc == type_code:
                names.add(name)
        return names

    def set(self, name, as_dict=False):
        """Return the elements of GAMS Set *name*.

        Because :py:mod:`xarray` stores non-null labels for each element of a
        coord, a GAMS sub-Set will contain some ``''`` elements, corresponding
        to elements of the parent Set which do not appear in *name*.
        :func:`set()` returns the elements without these placeholders.

        """
        assert (
            self[name].attrs["_gdx_type_code"] == gdx.GMS_DT_SET
        ), "Variable {} is not a GAMS Set".format(name)
        if len(self[name].dims) > 1:
            return self[name]
        elif as_dict:
            from collections import OrderedDict

            result = OrderedDict()
            parent = self[name].attrs["_gdx_domain"][0]
            for label in self[parent].values:
                result[label] = label in self[name].values
            return result
        else:
            return list(self[name].values)

    def sets(self):
        """Return a list of all GDX_API Sets."""
        return self._loaded_and_cached(gdx.GMS_DT_SET)

    def parameters(self):
        """Return a list of all GDX_API Parameters."""
        return self._loaded_and_cached(gdx.GMS_DT_PAR)

    def get_symbol_by_index(self, index):
        """Retrieve the GAMS Symbol from the *index*-th position of the
        :class:`File`."""
        return self[self._index[index]]

    def __getitem__(self, key):
        """Set element access."""
        try:
            return super(File, self).__getitem__(key)
        except KeyError as e:
            if isinstance(self._state[key], dict):
                _log.debug("Lazy-loading {}".format(key))
                self._load_symbol_data(key)
                return super(File, self).__getitem__(key)
            else:
                raise raise_from(KeyError(key), e)

In [9]:
f = File(path, gams_dir)

* 0
metaData 0
C 0
MAPC 0


KeyError: 'REG'

In [17]:
f._attrs

{'version': 'GDX_API Library      39.3.0 55b56f9b Jul 7, 2022           WEI x86 64bit/MS Window',
 'producer': 'GAMS Base Module 39.3.0 55b56f9b Jul 7, 2022           WEI x86 64bit/MS Window',
 'symbol_count': 46,
 'element_count': 358}

In [30]:
[gdx.gdxSymbolDim(h, i) for i in range(1, 40)]

[1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 3,
 3,
 3,
 3,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 2,
 4,
 1,
 1,
 1]

In [10]:
# import gdxpy as gp
# import os

# Create a GDX_API object
# gdxpath = os.path.join(gp.get_gams_root(), 'testlib_ml', 'trnsport.gdx')
# tdata = gp.GdxFile(gdxpath)

In [4]:
# h = gdx.new_gdxHandle_tp()
# ret, errno = gdx.gdxOpenRead(h, path)

In [3]:
??gdx.gdxOpenRead

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import squarify

import sys
import json
from copy import deepcopy

sys.path.append("..")
sys.path.append("../..")

from cge_modeling import Variable, Parameter, Equation, CGEModel
from cge_modeling.pytensorf.compile import (
    compile_cge_model_to_pytensor,
    euler_approximation_from_CGEModel,
    pytensor_objects_from_CGEModel,
)

In [2]:
from itertools import chain, compress
from cge_modeling.base.utilities import variable_dict_to_flat_array
from cge_modeling.production_functions import CES, dixit_stiglitz, leontief

In [3]:
sectors = ["Ag", "Ind", "Serv"]
eprod = ["coal", "nuc"]
n_sectors = len(sectors)
n_eprod = len(eprod)
coords = {"i": sectors, "j": sectors, "k": eprod, "g": eprod, "egrid": 1}
BACKEND = "pytensor"

In [20]:
variable_info = [
    # Firm variables (7)
    Variable(name="Y", dims="i", description="Final output in the <dim:i> sector"),
    Variable(name="Y_E", dims="k", description="Final output in the <dim:k> energy sector"),
    Variable(name="VA", dims="i", description="Value-added component of <dim:i> sector production"),
    Variable(name="VC", dims="i", description="Value-chain component of <dim:i> sector production"),
    Variable(name="VC_E", dims="k", description="Value-chain component of <dim:k> energy producer"),
    Variable(
        name="X",
        dims=("i", "j"),
        description="Demand for <dim:i> sector goods by the <dim:j> sector as value-chain inputs",
    ),
    Variable(
        name="X_E",
        dims=("i", "k"),
        description="Demand for <dim:k> sector goods by the <dim:g> sector as value-chain inputs",
    ),
    # Captial Labour bundle
    Variable(
        name="L_S_d",
        dims="i",
        extend_subscript=True,
        description="Labor demand in the <dim:i> sector",
    ),
    Variable(
        name="L_E_d",
        dims="k",
        extend_subscript=True,
        description="Labor demand in the <dim:k> e-sector",
    ),
    Variable(
        name="K_S_d",
        dims="i",
        extend_subscript=True,
        description="Capital demand in the <dim:i> sector",
    ),
    Variable(
        name="K_E_d",
        dims="k",
        extend_subscript=True,
        description="Capital demand in the <dim:k> e-sector",
    ),
    Variable(
        name="KE_S_d",
        dims="i",
        extend_subscript=True,
        description="Labor Capital demand in the <dim:i> sector",
    ),
    Variable(
        name="KL_E_d",
        dims="k",
        extend_subscript=True,
        description="Capital energy bundle Capital demand in the <dim:k> e-sector",
    ),
    # Raw energy producer
    Variable(
        name="W_s",
        dims="k",
        extend_subscript=True,
        description="Raw energy demand of <dim:k> sector",
    ),
    Variable(
        name="R_d",
        dims="k",
        extend_subscript=True,
        description="Ressource demand of <dim:k> sector",
    ),
    Variable(
        name="E_d",
        dims="i",
        extend_subscript=True,
        description="Energy demand in the <dim:i> sector",
    ),
    Variable(
        name="E_s",
        description="Energy supply from energy bundler",
    ),
    # Investment
    Variable(
        name="I_d",
        dims="i",
        extend_subscript=True,
        description="Investment capital demanded by the <dim:i> sector",
    ),
    # Prices (7)
    Variable(name="P", dims="i", description="Final good price in the <dim:i> sector, after taxes"),
    Variable(
        name="P_Y",
        dims="i",
        extend_subscript=True,
        description="Final good price in the <dim:i> sector, before taxes",
    ),
    Variable(
        name="P_Y_E",
        dims="k",
        extend_subscript=True,
        description="Final good price in the <dim:k> energy sector, before taxes",
    ),
    Variable(
        name="P_Y_E_star",
        dims="k",
        extend_subscript=True,
        description="Electricity price level for purchasing <dim:k> energy",
    ),
    Variable(
        name="P_VA",
        dims="i",
        extend_subscript=True,
        description="Price of the value-add component in the <dim:i> sector",
    ),
    Variable(
        name="P_VC",
        dims="i",
        extend_subscript=True,
        description="Price of the value-chain component in the <dim:i> sector",
    ),
    Variable(
        name="P_VC_E",
        dims="k",
        extend_subscript=True,
        description="Price of the value-chain component in the <dim:k> sector",
    ),
    Variable(
        name="P_KE_S",
        dims="i",
        extend_subscript=True,
        description="Capital labour bundle price in the <dim:i> sector",
    ),
    Variable(
        name="P_KL_E",
        dims="k",
        extend_subscript=True,
        description="Capital labour bundle price in the <dim:k> sector",
    ),
    Variable(
        name="P_E",
        extend_subscript=True,
        description="Electricity price",
    ),
    Variable(
        name="P_E_star",
        dims="i",
        extend_subscript=True,
        description="Electricity after tax price of <dim:i> sector",
    ),
    Variable(
        name="P_W",
        dims="k",
        extend_subscript=True,
        description="Raw energy price",
    ),
    Variable(
        name="P_R",
        dims="k",
        extend_subscript=True,
        description="Ressource energy price",
    ),
    Variable(name="r", description="Rental rate of capital"),
    Variable(name="w", description="Wage level"),
    Variable(
        name="r_star_S",
        latex_name="r^\\star_S",
        dims=("i",),
        description="<dim:i> after-tax rental rate",
        positive=True,
    ),
    Variable(
        name="r_star_E",
        latex_name="r^\\star_E",
        dims=("k",),
        description="<dim:k> after-tax rental rate",
        positive=True,
    ),
    Variable(
        name="w_star_S",
        dims=("i",),
        latex_name="w^\\star_S",
        description="<dim: i> after-tax wage level",
        positive=True,
    ),
    Variable(
        name="w_star_E",
        dims=("k",),
        latex_name="w^\\star_E",
        description="<dim: k> after-tax wage level",
        positive=True,
    ),
    # Household Variables
    Variable(name="U", description="Household utility"),
    Variable(name="C", dims="i", description="Household consumption of <dim:i> goods"),
    Variable(name="C_total", latex_name="\bar{C}", description="Household consumption bundle"),
    Variable(name="F", description="Household leisure time"),
    Variable(
        name="I_s", extend_subscript=True, description="Investment capital supplied by households"
    ),
    Variable(name="S", description="Household savings"),
    Variable(name="income", latex_name="Omega", description="Household income, before taxes"),
    Variable(
        name="net_income", latex_name="\hat{\Omega}", description="Household income, after taxes"
    ),
    # Government variables
    Variable(name="G", description="Government budget"),
    Variable(
        name="C_G",
        dims="i",
        extend_subscript=True,
        description="Government consumption of <dim:i> goods",
    ),
    Variable(
        name="S_G",
        extend_subscript=True,
        positive=None,
        description="Investment capital supplied by government",
    ),
    # Misc
    Variable(name="resid", latex_name=r"varepsilon", description="Walrasian residual"),
    Variable(name="L_s", description="Household labor supply"),
    Variable(name="CPI", latex_name="P_C", description="Price of the household consumption basket"),
]

param_info = [
    # Production Parameters
    Parameter(
        "alpha_VA",
        dims="i",
        description="Share of capital in production of the <dim:i> sector value-add bundle",
    ),
    Parameter(
        "alpha_KE_S",
        dims="i",
        description="Share of capital in production of the <dim:i> producer capital labour bundle",
    ),
    Parameter(
        "alpha_KL_E",
        dims="k",
        description="Share of capital in production of the <dim:k> producer capital labour bundle",
    ),
    Parameter(
        "alpha_W",
        dims="k",
        description="Share of capital in production of the <dim:k> producer capital labour bundle",
    ),
    Parameter(
        "alpha_k",
        dims="i",
        extend_subscript=True,
        description="Share of capital investment demanded by the <dim:i> sector",
    ),
    Parameter(
        "alpha_Es",
        dims="k",
        extend_subscript=True,
        description="Share of raw energy type <dim:k> in total energy mix",
    ),
    # Technology parameters
    Parameter("A_VA", dims="i", description="Total factor productivity of the <dim:i> sector"),
    Parameter("A_W", dims="k", description="Total factor productivity of the <dim:k> sector"),
    Parameter(
        "A_KE_S", dims="i", description="Captial Labour factor productivity of <dim:i> producer"
    ),
    Parameter(
        "A_KL_E", dims="k", description="Captial Labour factor productivity of <dim:k> producer"
    ),
    Parameter("A_Es", description="Total factor productivity of Energy production"),
    # Leontief shares
    Parameter(
        "psi_VA",
        extend_subscript=True,
        dims="i",
        description="Share of value-add bundle in <dim:i> sector final good production",
    ),
    Parameter(
        "psi_VC",
        extend_subscript=True,
        dims="i",
        description="Share of value chain bundle in <dim:i> sector final good production",
    ),
    Parameter(
        "psi_X",
        extend_subscript=True,
        dims=("i", "j"),
        description="Share of <dim:j> sector final goods in the <dim:i> value chain bundle",
    ),
    Parameter(
        "psi_VC_E",
        extend_subscript=True,
        dims="k",
        description="Share of value chain bundle in <dim:k> sector final good production",
    ),
    Parameter(
        "psi_X_E",
        extend_subscript=True,
        dims=("i", "k"),
        description="Share of <dim:g> sector final goods in the <dim:k> value chain bundle",
    ),
    Parameter(
        "psi_W",
        extend_subscript=True,
        dims="k",
        description="Share of energy in <dim:k> energy sector final energy production",
    ),
    # CES elasticities
    Parameter(
        name="epsilon_VA",
        extend_subscript=True,
        dims="i",
        description="Elasticity of subsitution between input factors in <dim:i> sector VA bundle",
    ),
    Parameter(
        name="epsilon_W",
        extend_subscript=True,
        dims="k",
        description="Elasticity of subsitution between input factors in <dim:k> sector W",
    ),
    Parameter(
        name="epsilon_KE_S",
        extend_subscript=True,
        dims="i",
        description="Elasticity of subsitution between input factors in <dim:i> producer KL bundle",
    ),
    Parameter(
        name="epsilon_KL_E",
        extend_subscript=True,
        dims="k",
        description="Elasticity of subsitution between input factors in <dim:k> producer KL bundle",
    ),
    Parameter(
        name="epsilon_Es",
        extend_subscript=True,
        description="Elasticity of subsitution between raw energy sources",
    ),
    # Tax parameters
    Parameter(name="tau_income", latex_name="\\tau_Y", description="Income tax rate"),
    Parameter(
        name="tau_sales_S",
        dims=("i",),
        extend_subscript=True,
        latex_name="\\tau_P_S",
        description="Sales tax rate in sector <dim:i>",
    ),
    Parameter(
        name="tau_energy_in",
        dims=("k",),
        extend_subscript=True,
        latex_name="\\tau_P_E",
        description="Tax on energy inputs <dim:k> to bundling",
    ),
    Parameter(
        name="tau_energy_out",
        dims=("i",),
        extend_subscript=True,
        latex_name="\\tau_P_E",
        description="Sales tax on energy output for sector <dim:i>",
    ),
    Parameter(
        name="tau_capital_S",
        dims=("i",),
        extend_subscript=True,
        latex_name="\\tau_r_S",
        description="Capital income tax in sector <dim:i>",
    ),
    Parameter(
        name="tau_capital_E",
        dims=("k",),
        extend_subscript=True,
        latex_name="\\tau_r_E",
        description="Capital income tax in sector <dim:k>",
    ),
    Parameter(
        name="tau_wage_S",
        dims=("i",),
        extend_subscript=True,
        latex_name="\\tau_w_S",
        description="Payroll tax in sector <dim:i>",
    ),
    Parameter(
        name="tau_wage_E",
        dims=("k",),
        extend_subscript=True,
        latex_name="\\tau_w_E",
        description="Payroll tax in sector <dim:k>",
    ),
    # Household parameters
    Parameter(
        "alpha_C",
        dims="i",
        description="Household elasticity of consumption utility for <dim:i> sector goods",
    ),
    Parameter(
        name="sigma_C",
        description="Arrow-Pratt risk averson",
    ),
    Parameter(name="A_C", description="Household shopping technology"),
    Parameter(name="T", description="Time endowment"),
    Parameter(name="sigma_L", description="Inverse Frisch elasticity between work and leisure"),
    Parameter(name="Theta", description="Household labor dispreference parameter"),
    Parameter(name="phi_C", description="Elasticity of substitution between consumption goods"),
    Parameter("mpc", latex_name="phi", description="Household marginal propensity to consume"),
    # Government parameters
    Parameter(
        "alpha_G",
        dims="i",
        description="Share of <dim:i> sector final goods in governmnet consumption",
    ),
    # Exogenous values
    Parameter("R_s", dims="k", description="Exogenous raw energy endowment"),
    Parameter("K_s", description="Exogenous capital supply"),
    Parameter("P_num", latex_name=r"\\bar{P}_{num}", description="Numeraire price"),
    Parameter(
        "S_G_bar",
        latex_name=r"\\bar{S}_G",
        description="Exogenous level of governmnet savings",
        positive=None,
    ),
]

va_eqs = CES(
    factors=["KE_S_d", "L_S_d"],
    factor_prices=["P_KE_S", "w_star_S"],
    output="VA",
    output_price="P_VA",
    TFP="A_VA",
    factor_shares="alpha_VA",
    epsilon="epsilon_VA",
    backend=BACKEND,
)

ke_s_eqs = CES(
    factors=["K_S_d", "E_d"],
    factor_prices=["r_star_S", "P_E_star"],
    output="KE_S_d",
    output_price="P_KE_S",
    TFP="A_KE_S",
    factor_shares="alpha_KE_S",
    epsilon="epsilon_KE_S",
    backend=BACKEND,
)
kl_e_eqs = CES(
    factors=["K_E_d", "L_E_d"],
    factor_prices=["r_star_E", "w_star_E"],
    output="KL_E_d",
    output_price="P_KL_E",
    TFP="A_KL_E",
    factor_shares="alpha_KL_E",
    epsilon="epsilon_KL_E",
    backend=BACKEND,
)
w_eqs = CES(
    factors=["KL_E_d", "R_d"],
    factor_prices=["P_KL_E", "P_R"],
    output="W_s",
    output_price="P_W",
    TFP="A_W",
    factor_shares="alpha_W",
    epsilon="epsilon_W",
    backend=BACKEND,
)

energy_bundler_eqs = dixit_stiglitz(
    factors="Y_E",
    factor_prices="P_Y_E_star",
    output="E_s",
    output_price="P_E",
    TFP="A_Es",
    factor_shares="alpha_Es",
    epsilon="epsilon_Es",
    dims="k",
    coords=coords,
    backend=BACKEND,
)

final_goods_eqs = leontief(
    factors=["VC", "VA"],
    factor_prices=["P_VC", "P_VA"],
    factor_shares=["psi_VC", "psi_VA"],
    output="Y",
    output_price="P_Y",
    dims="i",
    coords=coords,
    backend=BACKEND,
)

value_chain_eqs = leontief(
    factors="X",
    factor_prices="P_Y",
    factor_shares="psi_X",
    output="VC",
    output_price="P_VC",
    dims=["i", "j"],
    coords=coords,
    backend=BACKEND,
)

energy_goods_eqs = leontief(
    factors=["VC_E", "W_s"],
    factor_prices=["P_VC_E", "P_W"],
    factor_shares=["psi_VC_E", "psi_W"],
    output="Y_E",
    output_price="P_Y_E",
    dims="k",
    coords=coords,
    backend=BACKEND,
)

energy_value_chain_eqs = leontief(
    factors="X_E",
    factor_prices="P_Y",
    factor_shares="psi_X_E",
    output="VC_E",
    output_price="P_VC_E",
    dims=["i", "k"],
    coords=coords,
    backend=BACKEND,
)

equations = [
    # Sector Final Goods
    Equation("Final good production of sector <dim:i>", final_goods_eqs[0]),
    Equation("Sector <dim:i> demand for intermediate goods bundle", final_goods_eqs[1]),
    Equation("Sector <dim:i> demand for value added", final_goods_eqs[2]),
    # Value chain bundle
    Equation(
        "Sector <dim:i> production of intermediate goods bundle",
        value_chain_eqs[0],
    ),
    Equation(
        "Sector <dim:i> demand for sector <dim:j> intermediate input",
        value_chain_eqs[1],
    ),
    # Value add bundle
    Equation("Sector <dim:i> production of value add", va_eqs[0]),
    Equation("Sector <dim:i> demand for capital energy bundle", va_eqs[1]),
    Equation("Sector <dim:i> demand for labour", va_eqs[2]),
    # Capital Labour aggregation
    Equation("Producer <dim:i> production of capital and energy", ke_s_eqs[0]),
    Equation("Producer <dim:i> demand for captial", ke_s_eqs[1]),
    Equation("Producer <dim:i> demand for energy", ke_s_eqs[2]),
    # Energy Sector Final Goods
    Equation("Final good production of energy producer <dim:k>", energy_goods_eqs[0]),
    Equation("Sector <dim:k> demand for intermediate goods bundle", energy_goods_eqs[1]),
    Equation("Sector <dim:k> demand for value added", energy_goods_eqs[2]),
    # Value chain bundle
    Equation(
        "Energy producer <dim:k> production of intermediate goods bundle", energy_value_chain_eqs[0]
    ),
    Equation(
        "Sector <dim:k> demand for sector <dim:i> intermediate input", energy_value_chain_eqs[1]
    ),
    # Energy types
    Equation("Energy producer <dim:k> production", w_eqs[0]),
    Equation("Sector <dim:k> demand for capital labour bundle", w_eqs[1]),
    Equation("Sector <dim:k> demand for raw energy", w_eqs[2]),
    # Capital Labour aggregation
    Equation("Producer <dim:k> production of capital and labour", kl_e_eqs[0]),
    Equation("Producer <dim:k> demand for captial", kl_e_eqs[1]),
    Equation("Producer <dim:k> demand for labour", kl_e_eqs[2]),
    # Electricity aggregation
    Equation("Electricity production", energy_bundler_eqs[0]),
    Equation("Electrictiy production demand for raw energy <dim:k>", energy_bundler_eqs[1]),
    # Sector invest
    Equation("<dim:i> sector demand for installed capital", "P_Y * I_d = alpha_k * I_s"),
    # Government block
    Equation(name="Price level of sector <dim:i>", equation="P = (1 + tau_sales_S) * P_Y"),
    Equation(
        name="Energy input price level after tax <dim:k>",
        equation="P_Y_E_star = (1 + tau_energy_in) * P_Y_E",
    ),
    Equation(
        name="Energy output price level after tax <dim:i>",
        equation="P_E_star = (1 + tau_energy_out) * P_E",
    ),
    Equation(
        name="Net rental rate in sector <dim:i>", equation="r_star_S = (1 + tau_capital_S) * r"
    ),
    Equation(name="Net wage level in sector <dim:i>", equation="w_star_S = (1 + tau_wage_S) * w"),
    Equation(
        name="Net rental rate in sector <dim:k>", equation="r_star_E = (1 + tau_capital_E) * r"
    ),
    Equation(name="Net wage level in sector <dim:k>", equation="w_star_E = (1 + tau_wage_E) * w"),
    Equation(
        "Government budget constraint",
        "G + S_G = tau_income * income + "
        "(tau_sales_S * P_Y * C).sum(axis=0) + "
        "(tau_energy_out * P_E * E_d).sum(axis=0) + "
        "(tau_energy_in * P_Y_E * Y_E).sum(axis=0) + "
        "(tau_capital_S * r * K_S_d + tau_wage_S * w * L_S_d).sum(axis=0) + "
        "(tau_capital_E * r * K_E_d + tau_wage_E * w * L_E_d).sum(axis=0)",
    ),
    Equation("Government consumption of <dim:i> sector goods", "C_G = alpha_G * G"),
    Equation("Exogenous government savings", "S_G = S_G_bar"),
    # Household block
    Equation(
        "Household pre-tax income",
        "income = w * L_s + r * K_s + (R_s * P_R).sum()",
    ),
    Equation("Household after-tax income", "net_income = (1 - tau_income) * income"),
    Equation("Household budget constraint", "C_total * CPI = mpc * net_income"),
    Equation(
        "Household utility",
        "U = C_total ** (1 - sigma_C) / (1 - sigma_C) + F ** (1 - sigma_L) / (1 - sigma_L)",
    ),
    Equation("Household supply of labor", "F ** -sigma_L / C_total ** -sigma_C = w / CPI / Theta"),
    Equation(
        "Household shopping function",
        "C_total = A_C * (alpha_C * C ** ((phi_C - 1) / phi_C)).sum() ** (phi_C / (phi_C - 1))",
    ),
    Equation(
        "Household demand for good <dim:i>",
        "C = C_total / A_C * (A_C * alpha_C * CPI / P) ** phi_C",
    ),
    Equation("Household savings", "S = (1 - mpc) * net_income"),
    # Economic equilibrium
    Equation("Investment market clearing", "I_s = S + S_G + resid"),
    Equation("Labour market clearing", "L_s = L_S_d.sum() + L_E_d.sum()"),
    Equation("Capital market clearing", "K_s = K_S_d.sum() + K_E_d.sum()"),
    Equation("Energy market clearing", "E_s = E_d.sum()"),
    Equation("Ressource energy market clearing <dim:k>", "R_s = R_d"),
    Equation(
        "Sector <dim:i> goods market clearing",
        f"Y = C + C_G + I_d + X.sum(axis=1) + X_E.sum(axis=1)",
    ),
    Equation("Total time constraint", "T = L_s + F"),
    Equation("Numeraire", "P[0] = P_num"),
]

In [21]:
mod = CGEModel(
    variables=variable_info,
    parameters=param_info,
    equations=equations,
    coords=coords,
    parse_equations_to_sympy=False,
    backend="pytensor",
    mode="JAX",
    compile=True,
)

In [8]:
# from cge_modeling.pytensorf.compile import pytensor_objects_from_CGEModel
# eqs, var, par = pytensor_objects_from_CGEModel(mod)
# import pytensor
# pytensor.function(var + par, outputs=eqs)

In [9]:
mod.n_variables

117

In [10]:
len(mod.unpacked_equation_names)

117

In [11]:
# backward calibration SAM

In [12]:
df = pd.read_csv("data/11_sam_tax_energy.csv", index_col=[0, 1], header=[0, 1]).map(float).fillna(0)
assert np.allclose(df.sum(axis=0), df.sum(axis=1))

In [13]:
sectors_pretty_name = list(df["Production"].columns[:n_sectors])
eprod_pretty_name = list(df["Production"].columns[-len(eprod) :])

In [20]:
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Factor,Factor,Factor,Factor,Factor,Institution,Institution,Institution,Production,Production,...,Income Tax,Sales Tax,Sales Tax,Sales Tax,Sales Tax,Sales Tax,Use Tax,Use Tax,Use Tax,Other
Unnamed: 0_level_1,Unnamed: 1_level_1,Labor,Capital,Coal,Nuclear,Electricity,Household,Grid,Govt,Agriculture,Industry,...,Household,Agriculture,Industry,Services,Coal,Nuclear,Labor,Capital,Energy,Capital Accumulation
Factor,Labor,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Factor,Capital,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Factor,Coal,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Factor,Nuclear,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Factor,Electricity,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Institution,Household,7800.0,3430.0,850.0,1200.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Institution,Grid,0.0,0.0,0.0,0.0,4500.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Institution,Govt,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,600.0,500.0,800.0,400.0,-100.0,50.0,750.0,920.0,150.0,0.0
Production,Agriculture,0.0,0.0,0.0,0.0,0.0,2400.0,0.0,1100.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,100.0
Production,Industry,0.0,0.0,0.0,0.0,0.0,3030.0,0.0,1970.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,800.0


In [14]:
initial_data = {
    "income_tax_revenue": df.loc[("Income Tax", "Household"), ("Institution", "Household")],
    "sales_tax_revenue_S": df.loc["Sales Tax", ("Institution", "Household")]
    .loc[sectors_pretty_name]
    .values,
    "energy_tax_out_revenue": df.loc[
        ("Use Tax", "Energy"), ("Production", sectors_pretty_name)
    ].values,
    "energy_tax_in_revenue": df.loc[
        ("Sales Tax", eprod_pretty_name), ("Institution", "Grid")
    ].values,
    "capital_tax_revenue_S": df.loc[
        ("Use Tax", "Capital"), ("Production", sectors_pretty_name)
    ].values,
    "capital_tax_revenue_E": df.loc[
        ("Use Tax", "Capital"), ("Production", eprod_pretty_name)
    ].values,
    #     "resource_tax_revenue": df.loc[("Use Tax", )]
    "wage_tax_revenue_S": df.loc[("Use Tax", "Labor"), ("Production", sectors_pretty_name)].values,
    "wage_tax_revenue_E": df.loc[("Use Tax", "Labor"), ("Production", eprod_pretty_name)].values,
    "S": df.loc[("Other", "Capital Accumulation"), ("Institution", "Household")],
    "S_G": df.loc[("Other", "Capital Accumulation"), ("Institution", "Govt")],
    "L_S_d": df.loc[("Factor", "Labor"), ("Activities", sectors_pretty_name)].values,
    "K_S_d": df.loc[("Factor", "Capital"), ("Activities", sectors_pretty_name)].values,
    "L_E_d": df.loc[("Factor", "Labor"), ("Activities", eprod_pretty_name)].values,
    "K_E_d": df.loc[("Factor", "Capital"), ("Activities", eprod_pretty_name)].values,
    "E_d": df.loc[("Factor", "Electricity"), ("Activities", sectors_pretty_name)].values,
    "R_d": df.loc[("Institution", "Household"), ("Factor", eprod_pretty_name)].values,
    "I_d": df.loc[("Production", sectors_pretty_name), ("Other", "Capital Accumulation")].values,
    "C": df.loc[("Production", sectors_pretty_name), ("Institution", "Household")].values,
    "C_G": df.loc[("Production", sectors_pretty_name), ("Institution", "Govt")].values,
    "Y": df.loc[("Production")].loc[sectors_pretty_name].sum(axis=1).values,
    "Y_E": df.loc["Production"].loc[eprod_pretty_name].sum(axis=1).values,
    "X": df.loc["Production", "Activities"].loc[sectors_pretty_name, sectors_pretty_name].values,
    "X_E": df.loc["Production", ("Activities", eprod_pretty_name)].loc[sectors_pretty_name].values,
    "T": df.loc[("Factor", "Labor"), ("Activities")].sum() / 0.6,
}

econometric_estimates = {
    "epsilon_VA": np.ones(n_sectors) * 10.0,
    "epsilon_KE_S": np.ones(n_sectors) * 10.0,
    "epsilon_KL_E": np.ones(n_eprod) * 10.0,
    "epsilon_W": np.ones(n_eprod) * 5.0,
    "epsilon_Es": 3.0,
    "phi_C": 100.0,
    "Theta": 1.0,
    "sigma_C": 1.5,
}

In [15]:
def get_CES_2goods_params(fac1, fac2, p1, p2, output, epsilon):
    _rho = (epsilon - 1) / epsilon
    alpha = p1 * fac1 ** (1 / epsilon) / ((p1 * fac1 ** (1 / epsilon) + p2 * fac2 ** (1 / epsilon)))
    A = output * (alpha * fac1**_rho + (1 - alpha) * fac2**_rho) ** (-1 / _rho)
    return A, alpha

In [28]:
def calibrate_model(
    income_tax_revenue,
    sales_tax_revenue_S,
    energy_tax_out_revenue,
    energy_tax_in_revenue,
    capital_tax_revenue_S,
    capital_tax_revenue_E,
    wage_tax_revenue_S,
    wage_tax_revenue_E,
    #     resource_tax_revenue,
    S,
    S_G,
    L_S_d,
    K_S_d,
    L_E_d,
    K_E_d,
    E_d,
    R_d,
    C,
    C_G,
    Y,
    Y_E,
    X,
    X_E,
    I_d,
    T,
    epsilon_VA,
    epsilon_KE_S,
    epsilon_KL_E,
    epsilon_W,
    epsilon_Es,
    phi_C,
    sigma_C,
    Theta,
    variables,
    parameters,
):
    # taxes
    tau_sales_S = sales_tax_revenue_S / C

    tau_energy_out = energy_tax_out_revenue / E_d
    tau_energy_in = energy_tax_in_revenue / Y_E

    tau_capital_S = capital_tax_revenue_S / K_S_d
    tau_capital_E = capital_tax_revenue_E / K_E_d

    tau_wage_S = wage_tax_revenue_S / L_S_d
    tau_wage_E = wage_tax_revenue_E / L_E_d

    #     tau_R = resource_tax_revenue / R_d

    # Normalize prices to 1
    w = 1.0
    r = 1.0
    P_E = 1.0

    P_VA = np.ones(n_sectors)
    P_VC = np.ones(n_sectors)
    P_Y = np.ones(n_sectors)

    P_VC_E = np.ones(n_eprod)
    P_Y_E = np.ones(n_eprod)

    P_KE_S = np.ones(n_sectors)
    P_KL_E = np.ones(n_eprod)

    P_W = np.ones(n_eprod)
    P_R = np.ones(n_eprod)

    tau_raw_energy = np.zeros_like(P_R)  # not implemented yet

    P = P_Y * (1 + tau_sales_S)

    r_star_S = r * (1 + tau_capital_S)
    w_star_S = w * (1 + tau_wage_S)

    r_star_E = r * (1 + tau_capital_E)
    w_star_E = w * (1 + tau_wage_E)

    P_E_star = P_E * (1 + tau_energy_out)
    P_Y_E_star = P_Y_E * (1 + tau_energy_in)
    P_R_star = P_R * (1 + tau_raw_energy)

    # Zero the residual
    resid = 0.0

    # Factor supplies
    L_s = L_S_d.sum() + L_E_d.sum()
    K_s = K_S_d.sum() + K_E_d.sum()
    E_s = E_d.sum()
    F = T - L_s

    R_s = R_d

    # Numeraire
    P_num = P[0]

    # Firm calibration
    KE_S_d = (K_S_d * r_star_S + E_d * P_E_star) / P_KE_S
    KL_E_d = (K_E_d * r_star_E + L_E_d * w_star_E) / P_KL_E

    VA = (P_KE_S * KE_S_d + w_star_S * L_S_d) / P_VA
    W_s = (KL_E_d * P_KL_E + R_d * P_R_star) / P_W

    VC = (P_Y[:, None] * X).sum(axis=0) / P_VC

    VC_E = (P_Y_E * Y_E - P_W * W_s) / P_VC_E

    # Final good sector Value Add
    A_VA, alpha_VA = get_CES_2goods_params(
        fac1=KE_S_d, fac2=L_S_d, p1=P_KE_S, p2=w_star_S, output=VA, epsilon=epsilon_VA
    )
    A_KE_S, alpha_KE_S = get_CES_2goods_params(
        fac1=K_S_d, fac2=E_d, p1=r_star_S, p2=P_E_star, output=KE_S_d, epsilon=epsilon_KE_S
    )

    # Energy sector production function
    A_KL_E, alpha_KL_E = get_CES_2goods_params(
        fac1=K_E_d, fac2=L_E_d, p1=r_star_E, p2=w_star_E, output=KL_E_d, epsilon=epsilon_KL_E
    )
    A_W, alpha_W = get_CES_2goods_params(
        fac1=KL_E_d, fac2=R_d, p1=P_KL_E, p2=P_R_star, output=W_s, epsilon=epsilon_W
    )

    # Energy mix aggregator
    _rho = (epsilon_Es - 1) / epsilon_Es
    alpha_Es = P_Y_E_star * Y_E ** (1 / epsilon_Es) / (P_Y_E_star * Y_E ** (1 / epsilon_Es)).sum()
    A_Es = E_s / (alpha_Es * Y_E ** ((epsilon_Es - 1) / epsilon_Es)).sum() ** (
        epsilon_Es / (epsilon_Es - 1)
    )

    psi_VA = VA / Y
    psi_VC = VC / Y
    psi_VC_E = VC_E / Y_E
    psi_X = X / VC[None]
    psi_X_E = X_E / VC_E[None]
    psi_W = W_s / Y_E

    income = w * L_s + r * K_s + (P_R * R_d).sum()
    tau_income = income_tax_revenue / income
    net_income = (1 - tau_income) * income

    mpc = 1 - S / net_income
    consumption_spend = mpc * net_income

    C_total = C.sum()
    CPI = consumption_spend / C_total

    alpha_C = P * C ** (1 / phi_C) / (P * C ** (1 / phi_C)).sum()
    A_C = C_total / (alpha_C * C ** ((phi_C - 1) / phi_C)).sum() ** (phi_C / (phi_C - 1))
    sigma_L = (sigma_C * np.log(C_total) - np.log(w) + np.log(CPI)) / np.log(F)
    U = C_total ** (1 - sigma_C) / (1 - sigma_C) + Theta * (T - L_s) ** (1 - sigma_L) / (
        1 - sigma_L
    )

    # Exogenous government spending level
    S_G_bar = S_G

    G = (
        tau_income * income
        + (tau_sales_S * P_Y * C).sum()
        + (tau_energy_out * P_E * E_d).sum()
        + (tau_energy_in * P_Y_E * Y_E).sum()
        + (tau_capital_S * r * K_S_d + tau_wage_S * w * L_S_d).sum()
        + (tau_capital_E * r * K_E_d + tau_wage_E * w * L_E_d).sum()
        - S_G
    )
    alpha_G = C_G / G

    I_s = S + S_G
    alpha_k = (P_Y * I_d) / I_s

    d = {}
    for obj in variables + parameters:
        if obj.name != "_":
            d[obj.name] = locals()[obj.name]

    return d

In [29]:
calibrated_data = calibrate_model(
    **(initial_data | econometric_estimates), variables=mod.variables, parameters=mod.parameters
)

In [30]:
mod.check_for_equilibrium(calibrated_data)

Equilibrium not found. Total squared error: 3093343.631575


Equation                                                               Residual
Final good production of sector Ag                                     0.000000
Final good production of sector Ind                                    0.000000
Final good production of sector Serv                                   0.000000
Sector Ag demand for intermediate goods bundle                        -0.000000
Sector Ind demand for intermediate goods bundle                       -0.000000
Sector Serv demand for intermediate goods bundle                       0.000000
Sector Ag demand for value added                                      -0.000000
Sector Ind demand for value added                                     -0.000000
Sector Serv demand for value added                                    -0.000000
Sector Ag production of intermediate goods bundle                      0.000000
Sector Ind production of intermediate goods bundle         

# Simulation

In [63]:
from cge_modeling.base.utilities import flat_array_to_variable_dict, variable_dict_to_flat_array
from copy import deepcopy

tax_cut = deepcopy(calibrated_data)

# 50% income tax cut
tax_cut["tau_income"] *= 0.5
_, theta_tax_cut = variable_dict_to_flat_array(tax_cut, mod.variables, mod.parameters)

In [None]:
n_steps = 10
idata = mod._solve_with_euler_approximation(
    calibrated_data, theta_final=theta_tax_cut, n_steps=n_steps
)

  0%|                                                                                                         …

In [None]:
idata = mod.simulate(calibrated_data, final_delta_pct={"L_s": 0.9})

In [None]:
plot_lines(idata, mod)