In [1]:
import time
import numpy as onp
from jax import numpy as np
from functools import partial
from jax.ops import segment_sum
from jax import random, grad, jit, vmap, remat, lax

def softmax(utilities):
    exp_utility = np.exp(utilities)
    sum_expu_across_submodels = np.sum(exp_utility, axis=1, keepdims=True)
    proba = exp_utility / sum_expu_across_submodels
    return proba

def mse(target, predicted):
    error = target - predicted
    squared_error = error**2
    return np.mean(squared_error)

# mock data

In [2]:
target = onp.load('target.npy')
x = onp.load('x.npy')
print(x.shape)
w = np.array([[ 0.00304216,  0.01455319, -0.0026763 ,  0.02046709,
               0.00563778, -0.00192821, -0.01059241,  0.00204556,
               0.0079378 ,  0.00027923,  0.01584745]])

# 10k alternatives
x = np.tile(x, 54)
target = np.tile(target, 54)

chooser_incomes = onp.random.normal(60000, 15000, 100000) ## 100k chooser incomes
chooser_incomes = chooser_incomes / 100000
chooser_incomes_reshaped = chooser_incomes.reshape((40, 2500))

(11, 187)


In [3]:
x.shape

(11, 10098)

# functions

In [3]:
@jit
@remat
def income_interaction(income, x, idx_alt_income=1, w=w):
    """For a single chooser, interact their income with
       the alternatives' income-related attribute (e.g. mean income).
       Then calculate probabilities across alternatives for this
       single chooser. Other interactions, or interaction types,
       could be added here-  e.g. enforcement of budget constaints.
       
       Parameters
       ----------
       income : float
           Scalar income value for the single chooser.
       x : np.array
           Array of alternatives' explanatory variables.  Should be of
           shape (num_expvars, num_alts)
       idx_alt_income : int
           Index location of the explanatory variable in x that pertains
           to income.
        w : np.array
            Weights (parameter values).  Of shape (1, num_expvars)
       """
    income_interacted = x[idx_alt_income] * income
    x2 = x.at[idx_alt_income].set(income_interacted)
    
    logits = np.dot(w, x2)
    probas = softmax(logits)
    probas = probas.flatten()
    return probas

#Partialed, 1-argument form of the income_interaction func
income_interaction2 = partial(income_interaction, x=x, idx_alt_income=1, w=w)

def loss_disagg(weights):
    """VMAP the income interaction function over all choosers' incomes"""
    income_interaction3 = partial(income_interaction, x=x, idx_alt_income=1, w=weights)
    probas_all = vmap(income_interaction3)(chooser_incomes)
    proba_sum = np.sum(probas_all, axis=0)
    return mse(proba_sum, target)

def loss_disagg_lax(weights):
    """lax.map over batches of choosers-  lower memory version of loss_disagg"""
    income_interaction3 = partial(income_interaction, x=x, idx_alt_income=1, w=weights)
    ## Map over batches of choosers, VMAP each batch of choosers. #Reshaped choosers has dim == num_batches
    probas_all = lax.map(vmap(income_interaction3), chooser_incomes_reshaped)

    ## Reshape output to remove the batch dimension
    num_alts = x.shape[1]
    num_total_choosers = chooser_incomes_reshaped.size
    probas_all = probas_all.reshape((num_total_choosers, num_alts))
    
    ## Sum probas for each alternative
    proba_sum = np.sum(probas_all, axis=0)
    return mse(proba_sum, target)

# run

In [4]:
print('probas for single chooser')
income_interaction2(0.5)

probas for single chooser


DeviceArray([9.9656630e-05, 9.8727876e-05, 9.9612160e-05, ...,
             9.2831062e-05, 1.1771860e-04, 9.2662689e-05], dtype=float32)

In [5]:
print('loss_disagg')
start_time = time.time()

loss_disagg(w)

end_time = time.time()
time_elapsed = end_time - start_time
print(time_elapsed)

loss_disagg
0.5063130855560303


In [6]:
print('loss_disagg_lax')
start_time = time.time()

loss_disagg_lax(w)

end_time = time.time()
time_elapsed = end_time - start_time
print(time_elapsed)

loss_disagg_lax
0.4029693603515625


In [7]:
print('grad of loss')
loss_disagg_lax_grad = grad(loss_disagg_lax)
print(loss_disagg_lax_grad(w))

end_time = time.time()
time_elapsed = end_time - start_time
print(time_elapsed)

grad of loss
[[ 0.12902279 -0.3715989  -0.9149651   4.1221437   2.4685009   0.35714096
  -1.8347561   0.13082723  3.187939   -0.6747903  -1.4099519 ]]
77.58303189277649
