In [2]:
# import required dependencies.
import numpy as np
import pandas as pd
import torch
from typing import Optional, List, Dict
from sklearn.preprocessing import LabelEncoder
from torch_choice.data import ChoiceDataset 

# 
This tutorial shows how to manage choice datasets in the `torch-choice` package. We follow the Stata documentation [here](https://www.stata.com/manuals/cm.pdf) so that users can easily transfer prior knowledge from other packages to ours.

*From Stata Documentation*: Choice models (CM) are models for data with outcomes that are choices. The choices are selected by a decision maker, such as a person or a business, from a set of possible alternatives. For instance, we could model choices made by consumers who select a breakfast cereal from several different brands. Or we could model choices made by businesses who chose whether to buy TV, radio, Internet, or newspaper advertising.

Models for choice data come in two varieties: models for discrete choices and models for rank-ordered alternatives. When each individual selects a single alternative, say, he or she purchases one box of cereal, the data are discrete choice data. When each individual ranks the choices, say, he or she orders cereals from most favorite to least favorite, the data are rank-ordered data. Stata has commands for fitting both discrete choice models and rank-ordered models.

Our `torch-choice` package handles the **discrete choice** models in the Stata document above.

## Data Layout

Typically in Stata, a single Stata observation corresponds to a single statistical observation—that is why Stata calls rows in a Stata dataset “observations”.

So as not to confuse statistical observations with Stata observations, the Stata CM manual calls a single statistical observation a **"case"**. In `torch-choice` we also call it a **"purchase record"**, and this tutorial uses the two terms interchangeably.
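
To make the distinction concrete, here is a minimal sketch on a small hypothetical long-format table (the values are made up for illustration and are not taken from `carchoice.dta`). It counts rows (Stata "observations") separately from cases.

```python
import pandas as pd

# Hypothetical long-format data: one row per (case, alternative) pair.
toy = pd.DataFrame({
    'consumerid': [1, 1, 1, 2, 2, 2],
    'car': ['American', 'Japanese', 'European'] * 2,
    'purchase': [1, 0, 0, 0, 1, 0],
})

num_rows = len(toy)                               # rows, i.e., Stata "observations"
num_cases = toy['consumerid'].nunique()           # cases, i.e., purchase records
rows_per_case = toy.groupby('consumerid').size()  # number of alternatives listed per case

print(f'{num_rows} rows describe {num_cases} cases')
print(rows_per_case)
```

Here six rows describe only two cases, because each case contributes one row per available alternative.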

We load the artificial dataset from the Stata website. Below we reproduce the description of the dataset reported by Stata's `describe` command.

```
Contains data from https://www.stata-press.com/data/r17/carchoice.dta
 Observations:         3,160                  Car choice data
    Variables:             6                  30 Jul 2020 14:58
---------------------------------------------------------------------------------------------------------------------------------------------------
Variable      Storage   Display    Value
    name         type    format    label      Variable label
---------------------------------------------------------------------------------------------------------------------------------------------------
consumerid      int     %8.0g                 ID of individual consumer
car             byte    %9.0g      nation     Nationality of car
purchase        byte    %10.0g                Indicator of car purchased
gender          byte    %9.0g      gender     Gender: 0 = Female, 1 = Male
income          float   %9.0g                 Income (in $1,000)
dealers         byte    %9.0g                 No. of dealerships in community
---------------------------------------------------------------------------------------------------------------------------------------------------
Sorted by: consumerid  car

```

In [86]:
df = pd.read_stata('https://www.stata-press.com/data/r17/carchoice.dta')

In [88]:
df.head(10)

Unnamed: 0,consumerid,car,purchase,gender,income,dealers
0,1,American,1,Male,46.699997,9
1,1,Japanese,0,Male,46.699997,11
2,1,European,0,Male,46.699997,5
3,1,Korean,0,Male,46.699997,1
4,2,American,1,Male,26.1,10
5,2,Japanese,0,Male,26.1,7
6,2,European,0,Male,26.1,2
7,2,Korean,0,Male,26.1,1
8,3,American,0,Male,32.700001,8
9,3,Japanese,1,Male,32.700001,6


## Main Dataset
The wrapper we built requires several data frames. All we need to do in this tutorial is provide the correct information; the data wrapper then handles the construction of the `ChoiceDataset` for you.

The dataset in this tutorial is simplified: each user makes exactly one choice in exactly one session, so the `consumerid` column simultaneously identifies the user, the session, and the purchase record. Because the number of dealers for the same type of car differs across purchase records, we give each purchase record its own session instead of assigning all purchase records to the same session.

The **main dataset** consists of the following columns:
1. (`consumerid`) A column identifying the **case** (in Stata's language) or **purchase record** (in our language). In this tutorial, `consumerid` is this identifier. For example, the first 4 rows of the dataset (see above) have `consumerid == 1`, so these 4 rows should be read together: they constitute the first case/purchase record.
2. (`car`) A column identifying the names of alternatives (i.e., the items available in each purchase record).
3. (`purchase`) A column identifying the choice made by the consumer in each purchase record. Exactly one row per purchase record (i.e., among rows with the same `consumerid`) should be 1; all other rows should be 0.
4. (Optional, `consumerid` here) A column identifying the user making the choice.
5. (Optional, `consumerid` here) A column identifying the session of the choice.
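
The requirement in item 3 is easy to check before building the dataset. The helper below is a hypothetical sketch, not part of `torch-choice`; it returns the cases that do not have exactly one chosen alternative, and it is demonstrated on made-up data.

```python
import pandas as pd

def find_invalid_cases(df, case_column, choice_column):
    """Return identifiers of cases that do not have exactly one row with choice == 1."""
    chosen_per_case = df.groupby(case_column)[choice_column].sum()
    return chosen_per_case[chosen_per_case != 1].index.tolist()

# Hypothetical data: case 1 is valid, case 2 has two choices, case 3 has none.
toy_main = pd.DataFrame({
    'consumerid': [1, 1, 2, 2, 3, 3],
    'car': ['American', 'Japanese'] * 3,
    'purchase': [1, 0, 1, 1, 0, 0],
})

bad_cases = find_invalid_cases(toy_main, 'consumerid', 'purchase')
print(bad_cases)
```

On the tutorial data, the equivalent call would be `find_invalid_cases(df, 'consumerid', 'purchase')`, and an empty list would indicate that every purchase record has exactly one chosen car.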

In [101]:
df_main = df[['consumerid', 'car', 'purchase']]
df_main

Unnamed: 0,consumerid,car,purchase
0,1,American,1
1,1,Japanese,0
2,1,European,0
3,1,Korean,0
4,2,American,1
...,...,...,...
3155,884,Japanese,1
3156,884,European,0
3157,885,American,1
3158,885,Japanese,0


## Datasets of Observables
We now construct data frames for different observables.
**Note**: the **index** of these data frames matters a lot! You can use pandas' [`set_index`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.set_index.html) method to set the index of a data frame.
1. user-specific observables (e.g., gender and income) should be indexed by user names: from 1 to 885 in this tutorial.
2. item-specific observables (not shown in this tutorial) should be indexed by item names: American, Japanese, European, and Korean in this tutorial.
3. session-specific observables (not shown in this tutorial) should be indexed by session names: from 1 to 885 in this tutorial.
4. session-and-item-specific observables (e.g., dealers) should be indexed by both session names and item names (i.e., multi-indexing): from (1, American) to (885, Korean) in this example.
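
As a rough illustration of the indexing rules above, the sketch below builds two small data frames with `set_index`. The `is_domestic` column is a hypothetical item-specific observable invented for this example; it does not exist in `carchoice.dta`.

```python
import pandas as pd

# Hypothetical item-specific observable, indexed by item name.
item_obs = pd.DataFrame({
    'car': ['American', 'Japanese', 'European', 'Korean'],
    'is_domestic': [1, 0, 0, 0],
}).set_index('car')

# Session-and-item-specific observable in long format,
# indexed by (session name, item name) through a MultiIndex.
long_obs = pd.DataFrame({
    'consumerid': [1, 1, 2, 2],
    'car': ['American', 'Japanese', 'American', 'Japanese'],
    'dealers': [9, 11, 10, 7],
})
session_item_obs = long_obs.set_index(['consumerid', 'car'])

# Rows are now looked up by name instead of by position.
print(item_obs.loc['Korean', 'is_domestic'])
print(session_item_obs.loc[(2, 'Japanese'), 'dealers'])
```

Because lookups go through the index, the order of rows in these frames does not matter as long as every name is present.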

In [102]:
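# user-specific observables: one row per consumer, indexed by `consumerid`.
gender = pd.get_dummies(df.groupby('consumerid')['gender'].first().to_frame())
income = df.groupby('consumerid')['income'].first().to_frame()
# session-and-item-specific observable: one-hot encoded dealer counts, indexed by (`consumerid`, `car`).
dealers = pd.get_dummies(df.set_index(['consumerid', 'car'])['dealers'])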

In [103]:
gender

consumerid,gender_Female,gender_Male
1,0,1
2,0,1
3,0,1
4,1,0
5,0,1
...,...,...
881,0,1
882,0,1
883,1,0
884,0,1


In [96]:
income

consumerid,income
1,46.699997
2,26.100000
3,32.700001
4,49.199997
5,24.299999
...,...
881,45.700001
882,69.800003
883,45.599998
884,20.900000


In [100]:
dealers

consumerid,car,0,1,2,3,4,5,6,7,8,9,10,11,12,13
1,American,0,0,0,0,0,0,0,0,0,1,0,0,0,0
1,Japanese,0,0,0,0,0,0,0,0,0,0,0,1,0,0
1,European,0,0,0,0,0,1,0,0,0,0,0,0,0,0
1,Korean,0,1,0,0,0,0,0,0,0,0,0,0,0,0
2,American,0,0,0,0,0,0,0,0,0,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
884,Japanese,0,0,0,0,0,0,0,0,0,0,1,0,0,0
884,European,0,0,0,0,1,0,0,0,0,0,0,0,0,0
885,American,0,0,0,0,0,0,0,0,0,0,1,0,0,0
885,Japanese,0,0,0,0,0,1,0,0,0,0,0,0,0,0
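
A frame indexed by (`consumerid`, `car`) like the one above is later reshaped by the wrapper into a 3-D tensor with one slice per feature column. The sketch below mimics that pivot-and-stack step on made-up data, using NumPy arrays instead of PyTorch tensors to stay lightweight. The column names `f0` and `f1` are placeholders.

```python
import numpy as np
import pandas as pd

# Hypothetical (session, item)-indexed observable with two feature columns.
# Session 2 has no Japanese car available, so that pair is absent.
index = pd.MultiIndex.from_tuples(
    [(1, 'American'), (1, 'Japanese'), (2, 'American')],
    names=['consumerid', 'car'])
obs = pd.DataFrame({'f0': [1.0, 2.0, 3.0], 'f1': [10.0, 20.0, 30.0]}, index=index)

sessions = [1, 2]
items = ['American', 'Japanese']

# Reindex to the full session-by-item grid; missing pairs become NaN.
complete = pd.MultiIndex.from_product([sessions, items], names=['consumerid', 'car'])
obs = obs.reindex(complete)

# One (num_sessions, num_items) matrix per feature, stacked along the last axis.
slices = [
    obs[col].unstack('car').reindex(index=sessions, columns=items).to_numpy()
    for col in obs.columns
]
array = np.stack(slices, axis=-1)
print(array.shape)
```

The resulting array has shape (number of sessions, number of items, number of features), and entries for unavailable (session, item) pairs are NaN.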


In [82]:
class EasyDatasetWrapper():
    SUPPORTED_FORMATS = ['stata']
    def __init__(self,
                 main_data: pd.DataFrame,
                 case_index_column: str,  # also known as the purchase record.
                 alternative_column: str,
                 choice_column: str,  # purchase, binary.
                 user_index_column: Optional[str] = None,  # consumer idx.
                 session_index_column: Optional[str] = None,
                 user_observable_data: Optional[Dict[str, pd.DataFrame]] = dict(),
                 item_observable_data: Optional[Dict[str, pd.DataFrame]] = dict(),
                 session_observable_data: Optional[Dict[str, pd.DataFrame]] = dict(),
                 price_observable_data: Optional[Dict[str, pd.DataFrame]] = dict(),
                 format: str = 'stata'):

        if format not in self.SUPPORTED_FORMATS:
            raise ValueError(f'Format {format} is not supported, only {self.SUPPORTED_FORMATS} are supported.') 
        
        self.raw_data = main_data.copy()
        self.main_data = main_data
    
        # check compatibility of the data.
        # TODO:
        self.case_index_column = case_index_column
        self.case_index = main_data[case_index_column].unique()
        self.alternative_column = alternative_column
        self.choice_column = choice_column
        self.user_index_column = user_index_column
        self.session_index_column = session_index_column
        
        self.encode()

        self.align_observable_data(item_observable_data, user_observable_data, session_observable_data, price_observable_data)
        
        self.observable_data_to_observable_tensors()
        
        self.create_choice_dataset_from_stata()

    def encode(self) -> None:
        """Encodes item/alternative names, user names, and session names to {0, 1, 2, ...} integers."""
        self.item_name_encoder = LabelEncoder().fit(self.main_data[self.alternative_column].unique())
        
        if self.user_index_column is not None:
            self.user_name_encoder = LabelEncoder().fit(self.main_data[self.user_index_column].unique())
        
        if self.session_index_column is not None:
            self.session_name_encoder = LabelEncoder().fit(self.main_data[self.session_index_column].unique())

    def align_observable_data(self, item_observable_data, user_observable_data, session_observable_data, price_observable_data) -> None:
        self.item_observable_data = dict()
        for key, val in item_observable_data.items():
            self.item_observable_data['item_' + key] = val.loc[self.item_name_encoder.classes_]
        
        self.user_observable_data = dict() 
        for key, val in user_observable_data.items():
            self.user_observable_data['user_' + key] = val.loc[self.user_name_encoder.classes_]
    
        self.session_observable_data = dict()
        for key, val in session_observable_data.items():
            self.session_observable_data['session_' + key] = val.loc[self.session_name_encoder.classes_]
        
        self.price_observable_data = dict()
        for key, val in price_observable_data.items():
            # price observable data is reshaped later, in observable_data_to_observable_tensors.
            self.price_observable_data['price_' + key] = val

    def observable_data_to_observable_tensors(self):
        """Convert all self.*_observable_data to self.*_observable_tensors for PyTorch."""
        self.item_observable_tensors = dict()
        for key, val in self.item_observable_data.items():
            self.item_observable_tensors[key] = torch.tensor(val.loc[self.item_name_encoder.classes_].values, dtype=torch.float32)

        self.user_observable_tensors = dict()
        for key, val in self.user_observable_data.items():
            self.user_observable_tensors[key] = torch.tensor(val.loc[self.user_name_encoder.classes_].values, dtype=torch.float32)
        
        self.session_observable_tensors = dict()
        for key, val in self.session_observable_data.items():
            self.session_observable_tensors[key] = torch.tensor(val.loc[self.session_name_encoder.classes_].values, dtype=torch.float32)
        
        self.price_observable_tensors = dict()
        for key, val in self.price_observable_data.items():
            val = val.copy()

            column_list = val.columns
            complete_index = pd.MultiIndex.from_product([self.session_name_encoder.classes_, self.item_name_encoder.classes_],
                                                        names=[self.session_index_column, self.alternative_column])
            val = val.reindex(complete_index)
            # convert item index and session names to the encoded values, and add indices to columns.
            val = val.reset_index()
            val[self.session_index_column] = self.session_name_encoder.transform(val[self.session_index_column].values)
            val[self.alternative_column] = self.item_name_encoder.transform(val[self.alternative_column].values)

            tensor_slice = list()
            for column in column_list:
                df_slice = val.pivot(index=self.session_index_column, columns=self.alternative_column, values=column)
                
                assert np.all(df_slice.index == np.arange(len(self.session_name_encoder.classes_)))
                assert np.all(df_slice.columns == np.arange(len(self.item_name_encoder.classes_)))
                tensor_slice.append(torch.Tensor(df_slice.values).float())

            tensor = torch.stack(tensor_slice, dim=-1)
            self.price_observable_tensors[key] = tensor

    def create_choice_dataset_from_stata(self):
        print('Creating choice dataset from stata format data-frames...')
        choice_set_size = self.main_data.groupby(self.case_index_column)[self.alternative_column].nunique()
        s = choice_set_size.value_counts()
        rep = dict(zip([f'size {x}' for x in s.index], [f'occurrence {x}' for x in s.values]))
        # default to None so the attribute exists even when all choice sets have the same size;
        # this assumes ChoiceDataset treats None as "all items available".
        self.item_availability = None
        if len(np.unique(choice_set_size)) > 1:
            print(f'Note: choice sets of different sizes found in different purchase records: {rep}')
            self.item_availability = self.get_item_availability_tensor()
 
        item_bought = self.main_data[self.main_data[self.choice_column] == 1].set_index(self.case_index_column).loc[self.case_index, self.alternative_column].values
        self.item_index = self.item_name_encoder.transform(item_bought)

        # user index
        if self.user_index_column is None:
            self.user_index = None
        else:
            # get the user index of each purchase record.
            self.user_index = self.main_data.groupby(self.case_index_column)[self.user_index_column].first().loc[self.case_index].values
            self.user_index = self.user_name_encoder.transform(self.user_index)

        # session index
        if self.session_index_column is None:
            # print('Note: no session index provided, assign each case/purchase record to a unique session index.')
            self.session_index = None
        else:
            self.session_index = self.session_name_encoder.transform(self.main_data.groupby(self.case_index_column)[self.session_index_column].first().loc[self.case_index].values)
        
        self.choice_dataset = ChoiceDataset(item_index=torch.LongTensor(self.item_index),
                                            user_index=torch.LongTensor(self.user_index) if self.user_index is not None else None,
                                            session_index=torch.LongTensor(self.session_index) if self.session_index is not None else None,
                                            item_availability=self.item_availability,
                                            **self.item_observable_tensors,
                                            **self.user_observable_tensors,
                                            **self.session_observable_tensors,
                                            **self.price_observable_tensors)

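    def get_item_availability_tensor(self) -> torch.BoolTensor:
        if self.session_index_column is None:
            raise ValueError('Item availability cannot be constructed without session index column.')
        # pivot to a (session, item) table; (session, item) pairs absent from the data become NaN.
        # keyword arguments are used because recent pandas versions no longer accept them positionally.
        A = self.main_data.pivot(index=self.session_index_column, columns=self.alternative_column, values=self.choice_column)
        # align rows and columns with the encoded session and item orders.
        A = A.reindex(index=self.session_name_encoder.classes_, columns=self.item_name_encoder.classes_)
        return torch.BoolTensor(~np.isnan(A.values))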

    def __len__(self):
        return len(self.item_index)
    
    def summary(self):
        print(f'* Space of {len(self.item_name_encoder.classes_)} items:\n', pd.DataFrame(data={'item name': self.item_name_encoder.classes_}, index=np.arange(len(self.item_name_encoder.classes_))).T)
        print(f'* Number of purchase records/cases: {len(self)}.')
        print('* Preview of main data frame:')
        print(self.main_data)
        print('* Preview of ChoiceDataset:')
        print(self.choice_dataset)

# Build Datasets using `EasyDatasetWrapper`
We first need to provide the main dataset to the wrapper, then we need to tell the wrapper a bit of information about the data.
In our example, the `consumerid` column in the main dataset serves as all of `case_index_column`, `session_index_column`, and `user_index_column`.

In [104]:
data = EasyDatasetWrapper(main_data=df_main,
                          case_index_column='consumerid',
                          alternative_column='car',
                          choice_column='purchase',
                          session_index_column='consumerid',
                          user_index_column='consumerid',
                          user_observable_data={'gender': gender, 'income': income},
                          price_observable_data={'dealer': dealers})

Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
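
The note above reports how many purchase records have each choice-set size. A minimal sketch of the same counting, which the wrapper performs with `groupby(...).nunique()` followed by `value_counts()`, is shown here on made-up data.

```python
import pandas as pd

# Hypothetical data: consumers 1 and 3 see four cars, consumer 2 sees only three.
toy = pd.DataFrame({
    'consumerid': [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
    'car': ['American', 'Japanese', 'European', 'Korean',
            'American', 'Japanese', 'European',
            'American', 'Japanese', 'European', 'Korean'],
})

# Number of distinct alternatives in each purchase record.
choice_set_size = toy.groupby('consumerid')['car'].nunique()
# How often each choice-set size occurs, as {size: number of purchase records}.
size_counts = choice_set_size.value_counts().to_dict()
print(size_counts)
```

When more than one size occurs, as in the tutorial data, the wrapper builds an item availability tensor so that missing alternatives can be marked as unavailable.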


In [105]:
# Use summary to see what's inside the data wrapper.
data.summary()

* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      consumerid       car  purchase
0              1  American         1
1              1  Japanese         0
2              1  European         0
3              1    Korean         0
4              2  American         1
...          ...       ...       ...
3155         884  Japanese         1
3156         884  European         0
3157         885  American         1
3158         885  Japanese         0
3159         885  European         0

[3160 rows x 3 columns]
* Preview of ChoiceDataset:
ChoiceDataset(label=[], item_index=[885], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender=[885, 2], user_income=[885, 1], price_dealer=[885, 4, 14], device=cpu)
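
The `item_availability=[885, 4]` entry is a boolean matrix with one row per session and one column per item. The sketch below rebuilds such a matrix on made-up data with NumPy and pandas only. It assumes items are ordered alphabetically, which is how `LabelEncoder` orders its `classes_` and matches the item order printed by `summary()`; `np.unique` reproduces that order here.

```python
import numpy as np
import pandas as pd

# Hypothetical main data: consumer 2 has no Korean car available.
toy = pd.DataFrame({
    'consumerid': [1, 1, 1, 2, 2],
    'car': ['American', 'Japanese', 'Korean', 'American', 'Japanese'],
    'purchase': [1, 0, 0, 0, 1],
})

# Sorted unique names play the role of the encoders' classes_ in this sketch.
item_names = np.unique(toy['car'].to_numpy())
session_names = np.unique(toy['consumerid'].to_numpy())

# Pivot to a session-by-item table; absent (session, item) pairs become NaN.
wide = toy.pivot(index='consumerid', columns='car', values='purchase')
wide = wide.reindex(index=session_names, columns=item_names)

# True where the item was offered in that session.
availability = wide.notna().to_numpy()
print(availability)
```

Each `True` marks an item that appears in the session's choice set, whether or not it was purchased.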


In [75]:
len(data)

885