# Replication Materials for the Torch-Choice Paper

> Author: Tianyu Du
> 
> Email: `tianyudu@stanford.edu`

This repository contains the replication materials for the paper "Torch-Choice: A Library for Choice Models in PyTorch".

In [2]:
import warnings
warnings.filterwarnings("ignore")

from time import time
import numpy as np
import pandas as pd
import torch
import torch_choice
from torch_choice import run
from tqdm import tqdm
from typing import List
from torch_choice.data import ChoiceDataset, utils
from torch_choice.model import ConditionalLogitModel

# Data Structure

In [4]:
# Fetch the car-choice tutorial CSV hosted in the torch-choice GitHub repository.
car_choice = pd.read_csv(
    filepath_or_buffer="https://raw.githubusercontent.com/gsbDBI/torch-choice/main/tutorials/public_datasets/car_choice.csv"
)
# Keep the preview as the last expression so the notebook renders the first rows.
car_choice.head()

Unnamed: 0,record_id,session_id,consumer_id,car,purchase,gender,income,speed,discount,price
0,1,1,1,American,1,1,46.699997,10,0.94,90
1,1,1,1,Japanese,0,1,46.699997,8,0.94,110
2,1,1,1,European,0,1,46.699997,7,0.94,50
3,1,1,1,Korean,0,1,46.699997,8,0.94,10
4,2,2,2,American,1,1,26.1,10,0.95,100


## Adding Observables, Method 1: Observables Derived from Columns of the Main Dataset

In [5]:
from torch_choice.utils.easy_data_wrapper import EasyDatasetWrapper

# Columns of `car_choice` that are passed to the wrapper as user observables.
user_observable_columns = ["gender", "income"]

# Method 1: every observable is named by a column of the main data frame.
data_wrapper_from_columns = EasyDatasetWrapper(
    main_data=car_choice,
    purchase_record_column='record_id',
    choice_column='purchase',
    item_name_column='car',
    user_index_column='consumer_id',
    session_index_column='session_id',
    # Reuse the list defined above rather than repeating the literal, so the
    # user observables are declared in exactly one place.
    user_observable_columns=user_observable_columns,
    item_observable_columns=['speed'],
    session_observable_columns=['discount'],
    itemsession_observable_columns=['price'])

# Print the wrapper's summary (item space, number of purchase records, data preview).
data_wrapper_from_columns.summary()
# Keep a handle on the ChoiceDataset built by the wrapper.
dataset = data_wrapper_from_columns.choice_dataset
# ChoiceDataset(label=[], item_index=[885], provided_num_items=[], user_index=[885], session_index=[885], item_availability=[885, 4], item_speed=[4, 1], user_gender=[885, 1], user_income=[885, 1], session_discount=[885, 1], itemsession_price=[885, 4, 1], device=cpu)

Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.
* purchase record index range: [1 2 3] ... [883 884 885]
* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      record_id  session_id  consumer_id       car  purchase  gender  \
0             1           1            1  American         1       1   
1             1           1            1  Japanese         0       1   
2             1           1            1  European         0       1   
3             1           1            1    Korean         0       1   
4             2           2            2  American         1       1   
...         ...         ...          ...       ...       ...     ...   
3155        884         884  

## Adding Observables, Method 2: Added as Separate DataFrames

In [6]:
# Create data frames for gender and income. A data frame holding user-specific
# observables needs to have the `consumer_id` column.
gender = car_choice.groupby('consumer_id')['gender'].first().reset_index()
income = car_choice.groupby('consumer_id')['income'].first().reset_index()
# Alternatively, put gender and income in the same data frame.
gender_and_income = car_choice.groupby('consumer_id')[['gender', 'income']].first().reset_index()
# Speed as an item observable; the data frame requires a `car` column.
speed = car_choice.groupby('car')['speed'].first().reset_index()
# Discount as a session observable; the data frame requires a `session_id` column.
discount = car_choice.groupby('session_id')['discount'].first().reset_index()
# Price as an item-session observable; the data frame requires both `car` and
# `session_id` columns.
price = car_choice[['car', 'session_id', 'price']]
# Fill in NaNs for (session, item) pairs where the item was not available in that
# session: pivot to a car-by-session grid, then melt back to long format.
# `DataFrame.pivot` takes keyword-only arguments in pandas >= 2.0, so the
# arguments are named explicitly (same mapping as the former positional call).
price = (
    price
    .pivot(index='car', columns='session_id', values='price')
    .melt(ignore_index=False)
    .reset_index()
)

In [7]:
# Method 2: every observable is supplied as its own data frame, keyed by name.
data_wrapper_from_dataframes = EasyDatasetWrapper(
    main_data=car_choice,
    # columns of the main data frame identifying records, choices and items
    purchase_record_column='record_id',
    choice_column='purchase',
    item_name_column='car',
    # columns of the main data frame identifying users and sessions
    user_index_column='consumer_id',
    session_index_column='session_id',
    # observables passed as separate data frames
    itemsession_observable_data={'price': price},
    session_observable_data={'discount': discount},
    item_observable_data={'speed': speed},
    # gender and income could instead be supplied together in one data frame:
    # user_observable_data={'gender_and_income': gender_and_income},
    user_observable_data={'gender': gender, 'income': income},
)

# Sanity check: building from separate data frames must give the same
# ChoiceDataset as building from columns of the main data frame.
assert data_wrapper_from_dataframes.choice_dataset == data_wrapper_from_columns.choice_dataset

Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.


In [8]:
# Mixed construction: user, item and session observables come from separate
# data frames, while the item-session observable is read from a column of
# the main data frame.
data_wrapper_mixed = EasyDatasetWrapper(
    main_data=car_choice,
    # columns of the main data frame identifying records, choices and items
    purchase_record_column='record_id',
    choice_column='purchase',
    item_name_column='car',
    # columns of the main data frame identifying users and sessions
    user_index_column='consumer_id',
    session_index_column='session_id',
    # observable taken from a column of the main data frame
    itemsession_observable_columns=['price'],
    # observables passed as separate data frames
    session_observable_data={'discount': discount},
    item_observable_data={'speed': speed},
    user_observable_data={'gender': gender, 'income': income},
)

# All three construction methods must produce the same choice dataset; the
# two pairwise checks below are the expanded form of `a == b == c`.
assert data_wrapper_mixed.choice_dataset == data_wrapper_from_columns.choice_dataset
assert data_wrapper_from_columns.choice_dataset == data_wrapper_from_dataframes.choice_dataset

Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.
