In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import namedtuple

from firm import supply
from agent import demand
from optimization import find_equilibrium_prices, error

%load_ext autoreload

In [2]:
# Model and solver configuration.
# ORDER MATTERS: the values are passed positionally further down via
# `find_equilibrium_prices(*args.values())`, so keep the keys in this order.
args = {
    'n_products': 3,      # total products; consumption goods = n_products - n_assets
    'n_assets': 2,        # number of asset columns held by each agent
    'n_firms': 2,
    'n_agents': 5000,
    'theta': 10.0,        # gamma shape parameter for the firm draw A
    'alpha': 5.0,         # first beta-distribution parameter for the firm draw rts
    'beta': 5.0,          # second beta-distribution parameter for the firm draw rts
    'scale': 0.75,        # asset holdings are exp(scale * standard normal)
    'T': 100,             # passed through to demand / error
    'max_iter': 1000,     # presumably solver settings -- only find_equilibrium_prices reads these
    'step_size': 1e-2,
    'tol': 1e-12,
    'key': jax.random.PRNGKey(1),  # PRNG key used for every random draw in this notebook
}



In [3]:
# Consumption goods are the products that are not assets.
n_goods = args['n_products'] - args['n_assets']

n_firms = args['n_firms']
n_agents = args['n_agents']
key = args['key']

# NOTE(review): the same PRNG key feeds every draw below, so A / rts and
# assets / sigmas are not independent samples (JAX expects keys to be split).
# Left as-is because find_equilibrium_prices receives args['key'] and may rebuild
# this economy the same way -- confirm before switching to jax.random.split.

# A firm is the tuple (A, rts, output_idx), each entry stacked over firms.
firms = (
    jax.random.gamma(key, args['theta'], shape=(n_firms,)),                 # A
    jax.random.beta(key, args['alpha'], args['beta'], shape=(n_firms,)),    # rts
    jnp.zeros((n_firms, 1), dtype=jnp.int32),                               # output_idx: all 0
)

# An agent is a row [asset holdings..., sigma]: the first n_assets columns are
# log-normal holdings, the last column is the preference parameter sigma.
assets = jnp.exp(args['scale'] * jax.random.normal(key, (n_agents, args['n_assets'])))
sigmas = 2 + 0.1 * jax.random.normal(key, (n_agents,))
agents = jnp.concatenate((assets, sigmas[:, jnp.newaxis]), axis=1)

# Error function evaluated on (log) prices.
e = error(firms, agents, n_goods, args['T'])

# Starting point: every log price is zero, i.e. every price equals 1.
prices = jnp.zeros(args['n_products'])

In [None]:
# Sanity check at the starting (log) prices: supply, demand and the error value
# before any optimisation is run.
s = supply(jnp.exp(prices), firms)
R = s[1]
# jnp.any makes the check valid whether R is a scalar or an array; `not` applied
# directly to a multi-element array raises instead of checking for NaNs.
assert not jnp.any(jnp.isnan(R)), 'supply returned NaN in its second output (R)'
d = demand(R, jnp.exp(prices), agents, n_goods, args['T'])
err = e(prices)
print(f'(Log) starting prices: {prices}\nSupply: {s[0][0]}\nDemand: {d[0]}\nTherefore, error: {err}')

### Distribution of agent asset holdings (source of heterogeneity)

In [None]:
# The first n_assets columns of `agents` are per-asset holdings (the last column is
# sigma); draw one labelled histogram per asset instead of only asset 0.
asset_holdings = pd.DataFrame(
    np.asarray(agents[:, :args['n_assets']]),
    columns=[f'asset_{i}' for i in range(args['n_assets'])],
)
_ = asset_holdings.hist(bins=int(args['n_agents'] / 20))

### Distribution of agent preferences (sigma)

In [None]:
# The last column of `agents` is the preference parameter sigma; naming the column
# titles the histogram "sigma" instead of "0".
sigma_values = pd.DataFrame({'sigma': np.asarray(agents[:, -1])})
_ = sigma_values.hist(bins=int(args['n_agents'] / 20))

Firms are also heterogeneous in terms of productivity, but there are only two of them for now, so there is no distribution to look at.

### Visual representation of equilibrium

In [None]:
# Sweep the log price of product 0 over [-1, 1] while holding every other log price
# at its starting value from `prices`, then plot supply, demand and error along it.
bins = 200  # number of grid points in the sweep
output_prices = jnp.linspace(-1, 1, bins)[:, jnp.newaxis]
input_prices = jnp.repeat(prices[1:][jnp.newaxis], bins, axis=0)
prices_mat = jnp.concatenate((output_prices, input_prices), axis=1)

supply_sweep = jax.vmap(lambda p: supply(jnp.exp(p), firms))(prices_mat)
demand_sweep = jax.vmap(
    lambda r, p: demand(r, jnp.exp(p), agents, n_goods, args['T'])
)(supply_sweep[1], prices_mat)
error_sweep = jax.vmap(e)(prices_mat)

# Column 0 of each output is plotted, matching the `s[0][0]` / `d[0]` values printed
# for single price vectors. Building the frame from named 1-D columns (instead of
# concatenating and then renaming) stays valid if demand returns more than one good.
sweep_df = pd.DataFrame({
    'Log Prices': np.asarray(prices_mat[:, 0]),
    'Supply': np.asarray(supply_sweep[0][:, 0]),
    'Demand': np.asarray(demand_sweep[:, 0]),
    'Error': np.asarray(error_sweep),
})

# (columns plotted, panel title, y-axis label) for each panel.
panels = [
    (['Supply', 'Demand'], 'Supply vs. demand', 'Quantity'),
    (['Supply'], 'Supply', 'Quantity'),
    (['Demand'], 'Demand', 'Quantity'),
    (['Error'], 'Error', 'Error'),
]
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(18, 8))
for ax, (columns, title, ylabel) in zip(axes, panels):
    sweep_df.plot(x='Log Prices', y=columns, ax=ax, title=title)
    ax.set_ylabel(ylabel)

If it looks like the equilibrium will fall in a reasonable area of the problem space, then let's search for it!

In [None]:
# NOTE(review): `args` is unpacked positionally, so its key order must match the
# parameter order of find_equilibrium_prices -- TODO confirm, or pass **args if the
# parameter names match the keys.
eql_log_prices = find_equilibrium_prices(*args.values())

In [None]:
# Evaluate supply, demand and the error at the solver's (log) equilibrium prices.
# NOTE(review): `firms` / `agents` here are the notebook's; find_equilibrium_prices only
# receives `args`, so this check is meaningful only if it builds the same economy from
# args['key'] -- confirm.
s = supply(jnp.exp(eql_log_prices), firms)
R = s[1]
# jnp.any makes the check valid whether R is a scalar or an array; `not` applied
# directly to a multi-element array raises instead of checking for NaNs.
assert not jnp.any(jnp.isnan(R)), 'supply returned NaN in its second output (R)'
d = demand(R, jnp.exp(eql_log_prices), agents, n_goods, args['T'])
err = e(eql_log_prices)
print(f'(Log) equilibrium prices: {eql_log_prices}\nSupply: {s[0][0]}\nDemand: {d[0]}\nTherefore, error: {err}')