# Gaston work

Some additional work I'm doing for Gaston, namely...

1. Exporting a cleaned version of the data to CSV and making some small edits to the written PSET
2. Seeing if I can get autograd working for my BLP

## 1: Cleaning data

In [None]:
# Pre-question, loading libraries and data
import pandas as pd
from scipy.io import loadmat
import jaxopt
from PSET4_functions.misc import *
from PSET4_functions.shares import *
from PSET4_functions.delta import *
from PSET4_functions.moments import * 
from PSET4_functions.mpec_wrapper import *

# Exporting cleaned data as a CSV, one file per (markets, products) dataset
for n_markets, n_products in [(100, 3), (10, 3), (100, 5)]:
    stem = f"data/{n_markets}markets{n_products}products"
    dat = clean_data(loadmat(f"{stem}.mat"), n_products)
    dat.to_csv(f"{stem}.csv", index=False)

## 2: Working on autograd

Not going so well, for now. I rewrote my shares functions in shares_autograd.py to make them JAX-compilable. They work, and jax.jacfwd produces the correct Jacobian, but it takes **way** too long: about 2 minutes for the 10×3 dataset (you can see this in my output below).

What's costing me here (I think) is how I structured my shares functions: autograd is taking the derivative at the market level **for each person**, i.e. for each lognormal draw. That should be solvable, and I'll work on it; it's just taken much longer than I expected to get it running at all!
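One possible restructuring, sketched under assumptions: compute a single market's shares with all draws handled as one array operation, differentiate that per-market function once with `jacfwd`, and `vmap` the result across markets. The utility form `delta_j - sigma * v_i * p_j` and the names `market_shares`, `market_jac`, and `all_jacs` are hypothetical stand-ins, not the functions in shares_autograd.py. Because a market's shares depend only on that market's deltas, this returns only the per-market J×J blocks rather than a full cross-market Jacobian.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jit, vmap


def market_shares(sigma, delta_m, prices_m, draws):
    """Simulated shares for ONE market, averaged over all draws at once.

    ASSUMPTION: utility is delta_j - sigma * v_i * p_j with an outside good
    normalized to 0. Swap in the real specification as needed.
    """
    util = delta_m[None, :] - sigma * draws[:, None] * prices_m[None, :]  # (n_draws, J)
    # Prepend the outside good so the softmax denominator is 1 + sum_j exp(u_ij)
    util0 = jnp.concatenate([jnp.zeros((util.shape[0], 1)), util], axis=1)
    probs = jax.nn.softmax(util0, axis=1)[:, 1:]  # drop the outside good
    return probs.mean(axis=0)  # (J,)


# Jacobian of one market's shares with respect to its own deltas: shape (J, J)
market_jac = jacfwd(market_shares, argnums=1)

# Map over markets; sigma and the draws are shared (in_axes=None), then compile once
all_jacs = jit(vmap(market_jac, in_axes=(None, 0, 0, None)))

# Toy inputs with the same shape as the 10x3 dataset (random, not the real data)
key_d, key_p, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
delta = jax.random.normal(key_d, (10, 3))
prices = jax.random.uniform(key_p, (10, 3), minval=1.0, maxval=2.0)
draws = jnp.exp(jax.random.normal(key_v, (1000,)))  # lognormal(0, 1) draws

jacs = all_jacs(1.0, delta, prices, draws)
print(jacs.shape)
```

Whether this is faster than the current approach depends on how the real shares function loops over people and markets, so it is worth checking against `s.shares_ddelta` before relying on it.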

In [16]:
# Pre-question, loading libraries and data
import importlib
import numpy as np  # used below for np.random and np.allclose
import pandas as pd
from scipy.io import loadmat
import jaxopt
from PSET4_functions.misc import *
import PSET4_functions.shares_autograd as s_auto 
import PSET4_functions.shares as s
from PSET4_functions.delta import *
from jax import jacfwd, jit
from time import perf_counter

# importlib.reload(s_auto)
# importlib.reload(s)

m10_j3 = loadmat("data/10markets3products.mat")
dat = clean_data(m10_j3, 3)

shares_data_long = dat[['sjm']].to_numpy()
prices_data_long = dat[['pjm']].to_numpy() 
x_data_long = dat[['X1jm', 'X2jm', 'X3jm']].to_numpy()
w_data_long = dat[['wj']].to_numpy()
supply_features_data_long = dat[['wj', 'zjm', 'etajm']].to_numpy()

shares_data_wide = shares_data_long.reshape(10,3)

prices_data_wide = prices_data_long.reshape(10,3)
delta_0 = logit_delta(shares_data_wide)

np.random.seed(456)
random_vs = np.random.lognormal(0, 1, 1000)

og_shares = s.shares(1, delta_0, prices_data_wide, random_vs)
print("JAX shares time:")
%time auto_shares = s_auto.shares(1, delta_0, prices_data_wide, random_vs)
print("Checking shares equal (1e-10):", np.allclose(og_shares, auto_shares, atol=1e-10))

# Forward-mode Jacobian of shares with respect to delta (positional argument 1)
autograd_shares = jacfwd(s_auto.shares, argnums=1)
og_ds_ddelta = s.shares_ddelta(1, delta_0, prices_data_wide, random_vs)
print("JAX ds_ddelta time:")
%time auto_ds_ddelta = autograd_shares(1, delta_0, prices_data_wide, random_vs)
print("Checking derivatives equal (1e-10):", np.allclose(og_ds_ddelta, auto_ds_ddelta, atol=1e-10))

JAX shares time:
CPU times: user 16.7 s, sys: 64.9 ms, total: 16.8 s
Wall time: 17.4 s
Checking shares equal (1e-10): True
JAX ds_ddelta time:
CPU times: user 2min 4s, sys: 430 ms, total: 2min 4s
Wall time: 2min 6s
Checking derivatives equal (1e-10): True
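
A note on reading these timings, with a small sketch. JAX dispatches work asynchronously, and the first call to a `jit`-compiled function also pays for tracing and compilation, so a single timed call may not reflect steady-state cost. I can't tell from the output above how much of the 17 s or the 2 min is one-off overhead; the sketch below only shows one way to separate the two. The `toy` function and the `timed` helper are hypothetical stand-ins, not part of PSET4_functions.

```python
from time import perf_counter

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Run fn(*args), wait for the result to be computed, return (result, seconds)."""
    start = perf_counter()
    # block_until_ready forces the asynchronous computation to finish before timing stops
    result = jax.block_until_ready(fn(*args))
    return result, perf_counter() - start


@jax.jit
def toy(x):
    # Hypothetical stand-in for an expensive JAX function
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((200, 200))
out, first = timed(toy, x)    # includes tracing and compilation
out, second = timed(toy, x)   # reuses the cached compiled function
print(f"first call: {first:.4f}s, second call: {second:.4f}s")
```

Timing the second call of `s_auto.shares` and of the `jacfwd` function the same way would show whether the slowdown is compilation or the computation itself.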
