# Easy Data Wrapper Tutorial
The data construction covered in the Data Management tutorial might be too complicated for users with limited experience in PyTorch.
This tutorial offers a helper class that wraps the dataset; all the user needs to know is (1) how to manipulate CSV files and (2) basic usage of pandas.

**Note**: this tutorial assumes the reader has already read the first part of the *Data Management* tutorial and is familiar with the terminology of `torch-choice`. For example, the reader should know what a session is in our framework and what a price observable is.

Author: Tianyu Du

Date: May 20, 2022

```python
import numpy as np
import pandas as pd
import torch
from typing import Optional, Dict
from sklearn.preprocessing import LabelEncoder
from torch_choice.data import ChoiceDataset
```

This tutorial aims to show how to manage choice datasets in the `torch-choice` package. We follow the Stata documentation [here](https://www.stata.com/manuals/cm.pdf) so that users can transfer their prior knowledge of other packages to ours.

*From Stata Documentation*: Choice models (CM) are models for data with outcomes that are choices. The choices are selected by a decision maker, such as a person or a business, from a set of possible alternatives. For instance, we could model choices made by consumers who select a breakfast cereal from several different brands. Or we could model choices made by businesses who chose whether to buy TV, radio, Internet, or newspaper advertising.

Models for choice data come in two varieties: models for discrete choices and models for rank-ordered alternatives. When each individual selects a single alternative, say, he or she purchases one box of cereal, the data are discrete choice data. When each individual ranks the choices, say, he or she orders cereals from most favorite to least favorite, the data are rank-ordered data. Stata has commands for fitting both discrete choice models and rank-ordered models.

Our `torch-choice` package handles the **discrete choice** models in the Stata document above.

## Data Layout

To build our dataset, we need to collect a few CSV files (or data frames, if they are already loaded into memory) that contain the essential information.

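Typically in Stata, a single Stata observation corresponds to a single statistical observation, which is why Stata calls rows in a Stata dataset "observations". Choice data break this correspondence: a single statistical observation spans multiple rows, one for each alternative available to the decision maker.

So as not to confuse statistical observations with Stata observations, we call a single statistical observation a **"case"** (Stata's term) or a **"purchase record"**, and we use these two terms interchangeably throughout this tutorial.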
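To make the distinction concrete, the sketch below builds a small hypothetical frame in the same long layout as the car data (one row per available alternative) and counts rows versus cases. The toy values are invented for illustration only.

```python
import pandas as pd

# Hypothetical toy data in the long layout: one row per alternative in each case.
toy = pd.DataFrame({
    'consumerid': [1, 1, 1, 1, 2, 2, 2],
    'car': ['American', 'Japanese', 'European', 'Korean',
            'American', 'Japanese', 'European'],
    'purchase': [1, 0, 0, 0, 0, 1, 0],
})

n_rows = len(toy)                                  # Stata observations (rows)
n_cases = toy['consumerid'].nunique()              # statistical observations (cases)
rows_per_case = toy.groupby('consumerid').size()   # rows spanned by each case

print(n_rows, n_cases)         # 7 2
print(rows_per_case.tolist())  # [4, 3]
```

Seven rows describe only two cases, and the two cases need not have the same number of rows.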

```python
df = pd.read_stata('https://www.stata-press.com/data/r17/carchoice.dta')
```

We load an artificial car-choice dataset from the Stata website. Below is the description of the dataset reported by Stata's `describe` command.

```
Contains data from https://www.stata-press.com/data/r17/carchoice.dta
 Observations:         3,160                  Car choice data
    Variables:             6                  30 Jul 2020 14:58
---------------------------------------------------------------------------------------------------------------------------------------------------
Variable      Storage   Display    Value
    name         type    format    label      Variable label
---------------------------------------------------------------------------------------------------------------------------------------------------
consumerid      int     %8.0g                 ID of individual consumer
car             byte    %9.0g      nation     Nationality of car
purchase        byte    %10.0g                Indicator of car purchased
gender          byte    %9.0g      gender     Gender: 0 = Female, 1 = Male
income          float   %9.0g                 Income (in $1,000)
dealers         byte    %9.0g                 No. of dealerships in community
---------------------------------------------------------------------------------------------------------------------------------------------------
Sorted by: consumerid  car

```

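```python
df.head(10)
```

| | consumerid | car | purchase | gender | income | dealers |
|---|---|---|---|---|---|---|
| 0 | 1 | American | 1 | Male | 46.699997 | 9 |
| 1 | 1 | Japanese | 0 | Male | 46.699997 | 11 |
| 2 | 1 | European | 0 | Male | 46.699997 | 5 |
| 3 | 1 | Korean | 0 | Male | 46.699997 | 1 |
| 4 | 2 | American | 1 | Male | 26.1 | 10 |
| 5 | 2 | Japanese | 0 | Male | 26.1 | 7 |
| 6 | 2 | European | 0 | Male | 26.1 | 2 |
| 7 | 2 | Korean | 0 | Male | 26.1 | 1 |
| 8 | 3 | American | 0 | Male | 32.700001 | 8 |
| 9 | 3 | Japanese | 1 | Male | 32.700001 | 6 |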


## Main Dataset
The wrapper requires several data frames. Providing them with the correct information is all we need to do in this tutorial; the wrapper will handle the construction of the `ChoiceDataset` for you.

**Note**: The dataset in this tutorial is somewhat simplified: each user makes only one choice, so the `consumerid` column identifies the user, the session, and the purchase record at the same time. Because the number of dealers for the same type of car differs across purchase records, we give each purchase record its own session instead of assigning all purchase records to the same session.
That is, a single user makes a single choice in each session.

The **main dataset** should contain the following columns:

1. A column identifying the **purchase record** (also called a **case** in Stata syntax). In this tutorial, the `consumerid` column is the identifier. For example, the first 4 rows of the dataset (see above) have `consumerid == 1`, which means these 4 rows together constitute the first case/purchase record.
2. A column identifying the **names of alternatives**, which is `car` in the dataset above.
3. A column identifying the **choice** made by the consumer in each purchase record, which is `purchase` in our case. Within each purchase record (i.e., rows with the same `consumerid`), exactly one row should have the value 1 and all other rows should be 0.
4. An *optional* column identifying the **user** making the choice, which is also `consumerid` in our case.
5. An *optional* column identifying the **session** of the choice, which is also `consumerid` in our case.
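Requirement 3 is easy to get wrong when preparing data by hand. Below is a minimal sanity-check sketch written with plain pandas; the helper `check_choice_column` and the toy frames are hypothetical and not part of `torch-choice`.

```python
import pandas as pd


def check_choice_column(df_main: pd.DataFrame, case_col: str, choice_col: str) -> bool:
    """Return True if `choice_col` holds only 0/1 and every case has exactly one 1."""
    only_binary = df_main[choice_col].isin([0, 1]).all()
    ones_per_case = df_main.groupby(case_col)[choice_col].sum()
    return bool(only_binary) and bool((ones_per_case == 1).all())


# Hypothetical toy main data: two cases, each with exactly one purchase.
good = pd.DataFrame({
    'consumerid': [1, 1, 1, 2, 2],
    'car': ['American', 'Japanese', 'Korean', 'American', 'Japanese'],
    'purchase': [1, 0, 0, 0, 1],
})
# Same layout, but case 2 has no purchase at all.
bad = good.assign(purchase=[1, 0, 0, 0, 0])

print(check_choice_column(good, 'consumerid', 'purchase'))  # True
print(check_choice_column(bad, 'consumerid', 'purchase'))   # False
```

Once `df_main` is built in the next cell, the same check would be `check_choice_column(df_main, 'consumerid', 'purchase')`.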

Since `consumerid` identifies multiple pieces of information, the `df_main` data-frame below only has three columns.

```python
df_main = df[['consumerid', 'car', 'purchase']]
df_main
```

| | consumerid | car | purchase |
|---|---|---|---|
| 0 | 1 | American | 1 |
| 1 | 1 | Japanese | 0 |
| 2 | 1 | European | 0 |
| 3 | 1 | Korean | 0 |
| 4 | 2 | American | 1 |
| ... | ... | ... | ... |
| 3155 | 884 | Japanese | 1 |
| 3156 | 884 | European | 0 |
| 3157 | 885 | American | 1 |
| 3158 | 885 | Japanese | 0 |


## Datasets of Observables
We now construct data frames for different observables.

**Note**: the **index** (including the name of the index) of these data frames matters a lot! You can use pandas' [`set_index`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.set_index.html) method to set the index of a data frame.

**Note**: the **names** of the indices should match the names of the corresponding columns in the main dataset. For example, the index of the user observables should be named `consumerid`.
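A minimal sketch of two ways to get a correctly named index with standard pandas methods (`set_index` and `rename_axis`). The small frame is illustrative, with income values rounded from the data shown above. Note that `groupby('consumerid')` already names the resulting index `consumerid`, which is why the `gender` and `income` frames built below need no extra step.

```python
import pandas as pd

# Option 1: promote the identifier column to the index with set_index.
user_obs = pd.DataFrame({'consumerid': [1, 2, 3],
                         'income': [46.7, 26.1, 32.7]}).set_index('consumerid')

# Option 2: the index already holds the right values but has no name; name it.
unnamed = pd.DataFrame({'income': [46.7, 26.1, 32.7]}, index=[1, 2, 3])
named = unnamed.rename_axis('consumerid')

print(unnamed.index.name)                     # None
print(user_obs.index.name, named.index.name)  # consumerid consumerid
```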

### Suggested Procedure of Storing and Loading Data
1. Suppose the `SESSION_INDEX` column in `df_main` identifies the session and the `ALTERNATIVES` column identifies the car.
2. For user-specific observables, you should have a CSV file on disk with columns {`consumerid`, `var_1`, `var_2`, ...}.
3. Load the user-specific dataset with `user_obs = pd.read_csv(..., index_col='consumerid')`.
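The procedure above can be sketched end to end as follows. To keep the example self-contained, the CSV text lives in an in-memory buffer (`io.StringIO`) instead of a file on disk; the column names follow the placeholder names `var_1` and `var_2` from step 2, and the values are made up.

```python
import io

import pandas as pd

# Hypothetical CSV text standing in for a user-observable file on disk.
csv_text = """consumerid,var_1,var_2
1,0.5,10
2,1.5,20
3,2.5,30
"""

# index_col turns the identifier column into the (named) index.
user_obs = pd.read_csv(io.StringIO(csv_text), index_col='consumerid')

print(user_obs.index.name)        # consumerid
print(user_obs.columns.tolist())  # ['var_1', 'var_2']
print(user_obs.loc[2, 'var_2'])   # 20
```

With a real file, pass its path to `pd.read_csv` instead of the buffer.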


### Allowed Types of Observables
1. user-specific observables (e.g., gender and income) should be indexed by user names: from 1 to 885 in this tutorial.
2. item-specific observables (not shown in this tutorial) should be indexed by item names: American, Japanese, European, and Korean in this tutorial.
3. session-specific observables (not shown in this tutorial) should be indexed by session names: from 1 to 885 in this tutorial.
4. session-and-item-specific observables, also called price observables (e.g., dealers), should be indexed by both session names and item names using a MultiIndex: (1, American), (1, Japanese), ..., (885, European) in this example.
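The sketch below shows an item-specific frame (not used in this tutorial) indexed by item name, next to a session-and-item-specific frame with a two-level MultiIndex. The `dealers` values are copied from the first rows of the data above; the `is_domestic` item feature is hypothetical.

```python
import pandas as pd

# Long-format rows, one per (session, item) pair; dealers copied from the data above.
long_df = pd.DataFrame({
    'consumerid': [1, 1, 2, 2],
    'car': ['American', 'Japanese', 'American', 'Japanese'],
    'dealers': [9, 11, 10, 7],
})

# Session-and-item-specific observable: MultiIndex (consumerid, car).
session_item_obs = long_df.set_index(['consumerid', 'car'])[['dealers']]

# Item-specific observable: indexed by item name only (hypothetical feature).
item_obs = pd.DataFrame({'car': ['American', 'Japanese'],
                         'is_domestic': [1, 0]}).set_index('car')

print(list(session_item_obs.index.names))                # ['consumerid', 'car']
print(session_item_obs.loc[(2, 'Japanese'), 'dealers'])  # 7
print(item_obs.index.name)                               # car
```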

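```python
# User-specific observables: the value repeats on every row of a case, so keep the first one per consumer.
# dtype=int keeps 0/1 indicators (pandas 2.0 and later return booleans by default).
gender = pd.get_dummies(df.groupby('consumerid')['gender'].first().to_frame(), dtype=int)
income = df.groupby('consumerid')['income'].first().to_frame()
# Session-and-item-specific observable, indexed by (consumerid, car) and one-hot encoded.
dealers = pd.get_dummies(df.set_index(['consumerid', 'car'])['dealers'], dtype=int)
```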
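To see what the cell above does, here is a miniature version on a hypothetical four-row frame. `groupby(...).first()` collapses the rows of each case to one row per user, and `pd.get_dummies` turns each distinct value into its own indicator column; `dtype=int` is an existing `get_dummies` parameter that yields 0/1 integers.

```python
import pandas as pd

# Hypothetical long-format rows mimicking the car data layout.
long_df = pd.DataFrame({
    'consumerid': [1, 1, 2, 2],
    'car': ['American', 'Japanese', 'American', 'Japanese'],
    'gender': ['Male', 'Male', 'Female', 'Female'],
    'dealers': [9, 11, 10, 7],
})

# One row per user: the user-level value repeats within a case, so keep the first.
gender = long_df.groupby('consumerid')['gender'].first().to_frame()
gender_dummies = pd.get_dummies(gender, dtype=int)

# A Series of integers gets one indicator column per distinct value (sorted).
dealer_dummies = pd.get_dummies(
    long_df.set_index(['consumerid', 'car'])['dealers'], dtype=int)

print(gender_dummies.columns.tolist())  # ['gender_Female', 'gender_Male']
print(gender_dummies.shape)             # (2, 2)
print(dealer_dummies.columns.tolist())  # [7, 9, 10, 11]
```

On the real data the dealer counts take the values 0 through 13, which is why `dealers` has 14 indicator columns. One-hot encoding treats the number of dealerships as a category; keeping the raw count with `df.set_index(['consumerid', 'car'])[['dealers']]` would be an alternative if a single numeric feature is preferred.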

```python
gender
```

| consumerid | gender_Female | gender_Male |
|---|---|---|
| 1 | 0 | 1 |
| 2 | 0 | 1 |
| 3 | 0 | 1 |
| 4 | 1 | 0 |
| 5 | 0 | 1 |
| ... | ... | ... |
| 881 | 0 | 1 |
| 882 | 0 | 1 |
| 883 | 1 | 0 |
| 884 | 0 | 1 |


```python
income
```

| consumerid | income |
|---|---|
| 1 | 46.699997 |
| 2 | 26.100000 |
| 3 | 32.700001 |
| 4 | 49.199997 |
| 5 | 24.299999 |
| ... | ... |
| 881 | 45.700001 |
| 882 | 69.800003 |
| 883 | 45.599998 |
| 884 | 20.900000 |


```python
dealers
```

| consumerid | car | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | American | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 1 | Japanese | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 1 | European | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | Korean | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | American | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 884 | Japanese | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
| 884 | European | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 885 | American | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
| 885 | Japanese | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |


## Build Datasets Using `EasyDatasetWrapper`
We first provide the main dataset to the wrapper, then tell the wrapper a bit of information about the data.
In our example, the `consumerid` column in the main dataset is passed as `purchase_record_column`, `session_index_column`, and `user_index_column`.

```python
import torch_choice
from torch_choice.utils.easy_data_wrapper import EasyDatasetWrapper
```

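```python
data = EasyDatasetWrapper(main_data=df_main,
                          # Column identifying each purchase record (case).
                          purchase_record_column='consumerid',
                          # Column holding the names of the alternatives.
                          item_name_column='car',
                          # 0/1 column marking the chosen alternative.
                          choice_column='purchase',
                          # In this dataset, each case is also its own session and user.
                          session_index_column='consumerid',
                          user_index_column='consumerid',
                          # User-specific observables, indexed by consumerid.
                          user_observable_data={'gender': gender, 'income': income},
                          # Session-and-item-specific observables, indexed by (consumerid, car).
                          price_observable_data={'dealer': dealers})
```

```
Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
```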
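The note above reports that 505 purchase records offer four alternatives and 380 offer three. The sketch below shows one way such a ragged layout could be turned into a fixed-width availability matrix plus an index of the chosen item. It is a hypothetical illustration of the idea, not the wrapper's actual implementation; the alphabetical item ordering mirrors the item space reported by `data.summary()` below.

```python
import numpy as np
import pandas as pd

# Hypothetical toy main data: case 1 offers four cars, case 2 only three.
toy_main = pd.DataFrame({
    'consumerid': [1, 1, 1, 1, 2, 2, 2],
    'car': ['American', 'Japanese', 'European', 'Korean',
            'American', 'Japanese', 'European'],
    'purchase': [1, 0, 0, 0, 0, 1, 0],
})

# Encode item names as integers in alphabetical order (an illustrative choice).
item_names = sorted(toy_main['car'].unique())
item_to_idx = {name: j for j, name in enumerate(item_names)}
case_ids = sorted(toy_main['consumerid'].unique())
case_to_row = {cid: r for r, cid in enumerate(case_ids)}

# availability[r, j] is True when item j belongs to the choice set of case r.
availability = np.zeros((len(case_ids), len(item_names)), dtype=bool)
for cid, car in zip(toy_main['consumerid'], toy_main['car']):
    availability[case_to_row[cid], item_to_idx[car]] = True

# Integer index of the chosen item (the row with purchase == 1) in each case.
chosen = toy_main.loc[toy_main['purchase'] == 1].set_index('consumerid')['car']
item_index = np.array([item_to_idx[chosen.loc[cid]] for cid in case_ids])

# Number of cases for each choice-set size.
sizes = pd.Series(availability.sum(axis=1)).value_counts().sort_index()
size_counts = {int(size): int(count) for size, count in sizes.items()}

print(availability.astype(int).tolist())  # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(item_index.tolist())                # [0, 2]
print(size_counts)                        # {3: 1, 4: 1}
```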


```python
# Use summary to see what's inside the data wrapper.
data.summary()
```

```
* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      consumerid       car  purchase
0              1  American         1
1              1  Japanese         0
2              1  European         0
3              1    Korean         0
4              2  American         1
...          ...       ...       ...
3155         884  Japanese         1
3156         884  European         0
3157         885  American         1
3158         885  Japanese         0
3159         885  European         0

[3160 rows x 3 columns]
* Preview of ChoiceDataset:
ChoiceDataset(label=[], item_index=[885], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender=[885, 2], user_income=[885, 1], price_dealer=[885, 4, 14], device=cpu)
```
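The preview lists the resulting shapes: 885 purchase records, 4 items, 2 gender indicators, 1 income column, and 14 dealer indicators. The sketch below shows how a session-and-item-specific frame could be densified into a `(num_sessions, num_items, num_features)` array like `price_dealer`. Filling the pairs missing from a choice set with zeros is an assumption made for illustration; the wrapper may treat them differently, since unavailable items are tracked separately in `item_availability`.

```python
import numpy as np
import pandas as pd

# Hypothetical session-and-item observable with 2 features; session 2 lacks 'Korean'.
obs = pd.DataFrame({
    'consumerid': [1, 1, 2],
    'car': ['American', 'Korean', 'American'],
    'f0': [1, 0, 0],
    'f1': [0, 1, 1],
}).set_index(['consumerid', 'car'])

sessions = [1, 2]
items = ['American', 'Korean']

# Reindex onto the full (session, item) grid; missing pairs are filled with 0.
full_index = pd.MultiIndex.from_product([sessions, items], names=['consumerid', 'car'])
dense = obs.reindex(full_index, fill_value=0)

# Row order is session-major, so a plain reshape gives (sessions, items, features).
array = dense.to_numpy().reshape(len(sessions), len(items), dense.shape[1])

print(array.shape)        # (2, 2, 2)
print(array[1].tolist())  # [[0, 1], [0, 0]]
```

`torch.tensor(array)` would then turn the result into a tensor.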
