# Estimating large demand systems using JAX

Jeremy Large and Emmet Hall-Hoffarth

In [None]:
import os 
import sys
# Emulate setting PYTHONPATH from inside the notebook: put ../../lib (resolved
# against the current working directory) at the front of the module search
# path - presumably where the local `rube` package imported below lives.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.pardir, os.pardir, 'lib')),
)

In [None]:
import logging
import warnings

In [None]:
import numpy as np
import pandas as pd
import pylab as plt
from sklearn import decomposition

In [None]:
# Suppress every warning category for the rest of the session (this also
# silences warnings raised while importing the modules below).
warnings.simplefilter(action="ignore")
from rube.model.model import RubeJaxModel, load_params, positivize
import rube.data.clean

from rube.utils import nearest_neigbours
from rube.model.model import save_embeddings_tsv
from rube.data import uci

In [None]:
# Configure the root logger: INFO and above, timestamped records.
# force=True replaces any handlers already attached to the root logger
# (e.g. ones installed by the notebook kernel).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    force=True,
)

### Background

This notebook demonstrates code that implements the algorithms in Lanier, Large and Quah (2022).

The main way of running the code is via `/scripts/main.py`. However, the same workflow can also be run step by step, as in this notebook. We start by importing `jax`:

In [None]:
import jax
import jax.numpy as jnp

### Load retail dataset from UCI

In [None]:
# --- Settings consumed by uci.UCIGenerator (data loading / batching) ---
BATCH_SIZE = 1_024
NEGATIVE_SAMPLES = 99
MAX_QUANTITY = 6            # passed as max_accepted_quantity
STOCK_VOCAB_SIZE = 2_000    # passed as stock_vocab_size
USER_VOCAB_SIZE = 2_200     # passed as user_vocab_size
TIMEPERIOD_LENGTH = 4       # weeks - the length of time applied to each seasonal dummy

# --- Settings consumed by RubeJaxModel ---
K = 12                      # passed as embedding_dim
STEP_SIZE = 0.01            # passed as step_size
SEED = 42                   # passed as seed

Load data:

In [None]:
# Batch generator over the UCI retail data. save_raw=True presumably keeps the
# raw transactions around as gen.raw_data, which the price-sensitivity section
# reads - TODO confirm against rube.data.uci.
gen = uci.UCIGenerator(
    BATCH_SIZE,
    NEGATIVE_SAMPLES,
    max_accepted_quantity=MAX_QUANTITY,
    stock_vocab_size=STOCK_VOCAB_SIZE,
    user_vocab_size=USER_VOCAB_SIZE,
    period_in_weeks=TIMEPERIOD_LENGTH,
    save_raw=True,
)

Create a model:

In [None]:
# Size the model from what the generator actually built: the stock dimension
# uses the realised vocabulary length, the user dimension the generator's
# user_vocab_size, and one seasonal period per gen.get_n_periods().
model = RubeJaxModel(
    stock_vocab_size=len(gen.stock_vocab),
    user_vocab_size=gen.user_vocab_size,
    embedding_dim=K,
    n_periods=gen.get_n_periods(),
    step_size=STEP_SIZE,
    seed=SEED,
)

### Now fit the model

In [None]:
# Number of epochs passed to model.training_loop below. Smaller values
# (e.g. 50) give a noticeably weaker fit. This one is still kept moderate
# so that the notebook runs in a manageable time.
N_EPOCHS = 150

In [None]:
# Fit the model on batches drawn from `gen` for N_EPOCHS epochs. The fitted
# parameters are read back afterwards from model.params (see Results);
# presumably training updates them in place - TODO confirm in rube.model.model.
model.training_loop(gen, N_EPOCHS)

### Results

Obtain the model's fitted parameters (and remove unnecessary minus-signs):

In [None]:
# Pull the fitted parameters out of the model, then flip signs so the
# reported values carry no unnecessary minus-signs.
params = load_params(model.params)
params = positivize(params)

Now visualize `d_1`:

In [None]:
# Histogram of the estimated d_1 across users; the transpose turns the array
# into a single column (assumes params['d_1'] is 1 x n_users - TODO confirm).
_ = pd.DataFrame(
    params['d_1'].T,
    columns=['histogram of estimated d_1 across users'],
).plot.hist(
    bins=50,
    figsize=(9, 6),
    fontsize=14,
    grid=True,
)

Next, look at `d_2` and `d_3`:

In [None]:
# Histogram of the estimated d_2 across users; the transpose turns the array
# into a single column (assumes params['d_2'] is 1 x n_users - TODO confirm).
_ = pd.DataFrame(
    params['d_2'].T,
    columns=['histogram of estimated d_2 across users'],
).plot.hist(
    bins=50,
    figsize=(9, 6),
    fontsize=14,
    grid=True,
)

In [None]:
# Histogram of the estimated d_3 across users, with finer binning (100 bins)
# than the d_1/d_2 plots. The transpose turns the array into a single column
# (assumes params['d_3'] is 1 x n_users - TODO confirm).
_ = pd.DataFrame(
    params['d_3'].T,
    columns=['histogram of estimated d_3 across users'],
).plot.hist(
    bins=100,
    grid=True,
    figsize=(9, 6),
    fontsize=14,
)

We can examine correlations across latent dimensions in the `b` matrix, which records users' preferences:

In [None]:
b = pd.DataFrame(params['b'])
# scatter_matrix compares columns pairwise, so transpose b to make each latent
# dimension a column (assumes rows of params['b'] index latent dimensions).
_ = pd.plotting.scatter_matrix(
    b.T,
    alpha=0.2,
    figsize=(12, 12),
    color='g',
    diagonal="kde",
)

In [None]:
b = pd.DataFrame(params['b'])
# Zoom in on the first four rows of b only (presumably the first four latent
# dimensions), transposed so that each becomes a column of the scatter matrix.
first_rows = b.iloc[:4]
_ = pd.plotting.scatter_matrix(
    first_rows.T,
    alpha=0.2,
    figsize=(8, 8),
    color='g',
    diagonal="kde",
)

We can also examine correlations across latent dimensions in the `A` matrix, which records stock items' features:

In [None]:
A = pd.DataFrame(params['A'])
# Plot at most the first 5000 rows (items) of A; each column is compared
# pairwise with every other column.
_ = pd.plotting.scatter_matrix(
    A.iloc[:5000],
    alpha=0.2,
    figsize=(12, 12),
    color='k',
    diagonal="kde",
)

Note the special behaviour of the first dimension, along which values are constrained to be positive. This dimension has a particular interpretation related to price sensitivity.

In [None]:
# PCA of the fitted item-feature matrix A: plot the proportion of variance
# explained by each principal component, keeping one component per column of A.
pca = decomposition.PCA(n_components=A.shape[1])
pca.fit(params['A'])
_title = 'Deviation in A explained by each principal component'
to_plot = pd.DataFrame(pca.explained_variance_ratio_, columns=['explained proportion of variance (counting components from 0 up)'])
# One x-tick per component. Derive the count from the data rather than
# hard-coding 12 (= K), so the axis stays correct if K is changed.
_ = to_plot.plot(grid=True, title=_title, figsize=(9, 6), fontsize=14, marker='o', xticks=range(len(to_plot)))
# Draw the x- and y-axes through the origin for reference.
plt.axhline(color='k')
_ = plt.axvline(color='k')

Save the results for further assessment:

In [None]:
save_embeddings_tsv(params, gen)
# One-column table of stock codes in vocabulary order. The price-sensitivity
# section joins it by position with the rows of params['A'], which assumes the
# two share the same ordering - TODO confirm.
vocab = pd.DataFrame(data=list(gen.stock_vocab), columns=['StockCode'])

### Price sensitivity

Let's do a simple study of price sensitivities, ranking items by their value in the first dimension of `A`:

In [None]:
# Map each product_token to one Description. When a token has several
# descriptions, .max() keeps the greatest one (lexicographically, assuming the
# Description column holds strings).
desc_lookup = (
    gen.raw_data[['Description', 'product_token']]
    .groupby('product_token')
    .max()
)

In [None]:
# Attach each item's first-dimension value of A (the price-sensitivity
# dimension), sort from most to least sensitive, then attach descriptions.
# Both joins are on the row index, which assumes that vocab position,
# row of A and product_token coincide - TODO confirm.
sensitivities = (
    vocab
    .join(pd.DataFrame(params['A'][:, 0]))
    .sort_values(by=0, ascending=False)
    .join(desc_lookup)
)

Items where consumers exhibit the greatest price sensitivity:

In [None]:
# Top ten rows of the descending sort: the most price-sensitive items.
sensitivities.head(n=10)

Items where consumers exhibit the least price sensitivity (these are usually drawn from quite a tightly packed field of similar values):

In [None]:
# Bottom ten rows of the descending sort: the least price-sensitive items.
sensitivities.tail(n=10)