# Data Management Tutorial
**Author: Tianyu Du (tianyudu@stanford.edu)**

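This notebook aims to help users understand the functionality of the `ChoiceDataset` object. `ChoiceDataset` is the core class that holds the data of a choice modelling problem: which item was chosen in each record, by which user, in which session, and the observable features of users, items, and sessions.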

Since this package was initially proposed for modelling consumer choices, the naming convention follows the consumer choice
scenario.

Now imagine that you have access to purchase records of customers of a supermarket.

## Core Components of the Dataset
We begin with the notation and the essential components of the prediction problem. Suppose there are $I$ **items** under consideration, indexed by $i \in \{1,2,\dots,I\}$, and that they can be grouped into $C$ **categories** $c \in \{1,2,\dots,C\}$. Let $I_c$ denote the set of items in category $c$ (**NOTE**: category information is only needed if you are running the *nested* logit model).

Since we will be using PyTorch to train our model, we represent the identities of items, categories, users, and sessions as integers. This document uses lowercase letters such as $i$ and $c$ to index items and categories, respectively.

Let $B$ denote the number of purchase records in the dataset. Each record $b \in \{1,2,\dots, B\}$ is associated with a **user** $u \in \{1,2,\dots,U\}$ and a **session** $s \in \{1,2,\dots, S\}$. When multiple items are bought in the same shopping trip, there are multiple rows in the dataset with the same $(u, s)$.

One canonical example of session $s$ is the date of purchase or the shopping trip. 

The `ChoiceDataset` data manager is initialized with the following PyTorch tensors:

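1. `label` $\in \{1,2,\dots,I\}^B$: the ID of the item bought in each record.
2. `user_index` $\in \{1,2,\dots,U\}^B$: the ID of the user (shopper) in each record.
3. `session_index` $\in \{1,2,\dots,S\}^B$: the ID of the session in each record.
4. `item_availability` $\in \{\texttt{True}, \texttt{False}\}^{S\times I}$: identifies which items are available in each session; the model ignores unavailable items when making predictions.
5. `user_obs` $\in \mathbb{R}^{U\times K_{user}}$: $K_{user}$ observable features of each user.
6. `item_obs` $\in \mathbb{R}^{I\times K_{item}}$: $K_{item}$ observable features of each item.
7. `session_obs` $\in \mathbb{R}^{S \times K_{session}}$: $K_{session}$ observable features of each session.
8. `price_obs` $\in \mathbb{R}^{S \times I \times K_{price}}$: $K_{price}$ observable features of each (session, item) pair, for example the price of the item in that session.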
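Item 4 says that the model ignores unavailable items. A common way to do this, sketched below as an assumption rather than a description of `torch_choice` internals, is to look up each record's session row of `item_availability` and set the utilities of unavailable items to $-\infty$ before the softmax, so that they receive probability exactly zero. The utility values are made up for illustration.

```python
import torch

# Toy setup: S=2 sessions, I=3 items, B=3 records (all values hypothetical).
item_availability = torch.tensor([[True, True, False],   # item 2 unavailable in session 0
                                  [True, True, True]])   # shape (S, I)
session_index = torch.tensor([0, 1, 0])                   # shape (B,)
utility = torch.tensor([[1.0, 2.0, 3.0],
                        [0.5, 0.5, 0.5],
                        [0.0, 1.0, 5.0]])                 # shape (B, I)

# Availability row of each record's session: shape (B, I).
available = item_availability[session_index]

# Unavailable items get utility -inf, hence probability exactly 0 after softmax.
masked_utility = utility.masked_fill(~available, float('-inf'))
prob = torch.softmax(masked_utility, dim=1)
print(prob)
```

Because the mask is applied per record through `session_index`, the same item can be available in one session and unavailable in another.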

## Example
Suppose we have a dataset of purchase history from two stores (Stores A and B) on two dates (Sep 16 and 17). Both stores sell {apple, banana, orange} (`num_items=3`), and three people visited those stores between Sep 16 and 17.

| user_index | session_index       | label  |
| ---------- | ------------------- | ------ |
| Amy        | Sep-17-2021-Store-A | banana |
| Ben        | Sep-17-2021-Store-B | apple  |
| Ben        | Sep-16-2021-Store-A | orange |
| Charlie    | Sep-16-2021-Store-B | apple  |
| Charlie    | Sep-16-2021-Store-B | orange |

**NOTE**: For demonstration purposes, the example dataset stores `user_index`, `session_index`, and `label` as strings; in practice they should be consecutive integers starting from 0. Encoding each column by order of first appearance gives `user_index=[0,1,1,2,2]`, `session_index=[0,1,2,3,3]`, and `label=[0,1,2,1,2]` (so banana=0, apple=1, orange=2). `sklearn.preprocessing.LabelEncoder` can also do the conversion, but it sorts the values first, so it produces different (equally valid) integers. Suppose we believe that people's purchasing decisions depend on the nutrition level of each fruit, with apple the highest and banana the lowest. We can then add `item_obs=[[1.5], [12.0], [3.3]]`, a $3 \times 1$ tensor whose rows follow the item encoding above (banana, apple, orange). These numbers are arbitrary.

**NOTE**: If someone went to one store and bought multiple items (e.g., Charlie bought both apple and orange at Store B on Sep-16), we include them as separate rows in the dataset and model them independently.
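The encoding described in the notes above can be reproduced in plain Python. This sketch assigns integers in order of first appearance, which yields exactly the `user_index`, `session_index`, and `label` values quoted above; the helper name `encode_by_first_appearance` is hypothetical and not part of `torch_choice`.

```python
def encode_by_first_appearance(values):
    """Map each distinct value to an integer, in order of first appearance."""
    mapping = {}
    for v in values:
        if v not in mapping:
            mapping[v] = len(mapping)
    return [mapping[v] for v in values], mapping

# The five rows of the example table; Charlie's two purchases are two rows.
users = ['Amy', 'Ben', 'Ben', 'Charlie', 'Charlie']
sessions = ['Sep-17-2021-Store-A', 'Sep-17-2021-Store-B', 'Sep-16-2021-Store-A',
            'Sep-16-2021-Store-B', 'Sep-16-2021-Store-B']
items = ['banana', 'apple', 'orange', 'apple', 'orange']

user_index, user_map = encode_by_first_appearance(users)
session_index, session_map = encode_by_first_appearance(sessions)
label, item_map = encode_by_first_appearance(items)

print(user_index)     # [0, 1, 1, 2, 2]
print(session_index)  # [0, 1, 2, 3, 3]
print(label)          # [0, 1, 2, 1, 2]

# One row per encoded item (banana=0, apple=1, orange=2), K_item=1 column
# holding the arbitrary nutrition levels from the note above.
item_obs = [[1.5], [12.0], [3.3]]
```

Charlie's two rows share the same pair `(user_index, session_index) == (2, 3)`, which is how multiple purchases in one trip appear in the dataset.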


```python
# import required dependencies.
import numpy as np
import torch
from torch_choice.data import ChoiceDataset, JointDataset
```

```python
def print_dict_shape(d):
    """Print the shape of every tensor value in dictionary `d`."""
    for key, val in d.items():
        if torch.is_tensor(val):
            print(f'dict.{key}.shape={val.shape}')
```

## Creating a `ChoiceDataset` Object

```python
# Sizes of the fake dataset; feel free to modify them.
num_users = 10
num_items = 4
num_sessions = 500

length_of_dataset = 10000  # number of records B
```

### Step 1: Gather Information.
The first step is to create tensors encompassing information about users, items, and sessions. Here we provide random tensors for demonstration purposes only. Please check out the tutorials for the logit model and the nested logit model for real examples.

```python
# Create observables (features); the numbers of features are arbitrary.
user_obs = torch.randn(num_users, 128)  # 128 features for each user, e.g., race, gender.
item_obs = torch.randn(num_items, 64)  # 64 features for each item, e.g., quality.
session_obs = torch.randn(num_sessions, 10)  # 10 features for each session, e.g., weekday indicator.
price_obs = torch.randn(num_sessions, num_items, 12)  # 12 features for each (session, item) pair, e.g., the price of that item in that session.
```

```python
# Record-level tensors: one entry per record, drawn at random.
label = torch.LongTensor(np.random.choice(num_items, size=length_of_dataset))
user_index = torch.LongTensor(np.random.choice(num_users, size=length_of_dataset))
session_index = torch.LongTensor(np.random.choice(num_sessions, size=length_of_dataset))

# assume all items are available in all sessions.
item_availability = torch.ones(num_sessions, num_items).bool()
```

### Step 2: Initialize the `ChoiceDataset`.
You can construct a `ChoiceDataset` using the following code; the object manages all the information for you.

```python
dataset = ChoiceDataset(
    # pre-specified keywords of __init__
    label=label,  # required.
    # optional:
    user_index=user_index,
    session_index=session_index,
    item_availability=item_availability,
    # additional keywords of __init__
    user_obs=user_obs,
    item_obs=item_obs,
    session_obs=session_obs,
    price_obs=price_obs)
```

The `__repr__` string of a `ChoiceDataset` object shows the shapes of the tensors the dataset holds.

```python
dataset
```

```
ChoiceDataset(label=[10000], user_index=[10000], session_index=[10000], item_availability=[500, 4], observable_prefix=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[500, 10], price_obs=[500, 4, 12], device=cpu)
```

The `num_users`, `num_items`, and `num_sessions` attributes give the numbers of users, items, and sessions. They are determined automatically from the first dimension of the `user_obs`, `item_obs`, and `session_obs` tensors provided when initializing the dataset object. `len(dataset)` returns the number of records $B$.

```python
print(f'{dataset.num_users=:}')
print(f'{dataset.num_items=:}')
print(f'{dataset.num_sessions=:}')
print(f'{len(dataset)=:}')
```

```
dataset.num_users=10
dataset.num_items=4
dataset.num_sessions=500
len(dataset)=10000
```


```python
# clone() returns an independent copy of the dataset.
print(dataset.label[:10])
dataset_cloned = dataset.clone()
dataset_cloned.label = 99 * torch.ones(len(dataset_cloned), dtype=torch.long)
print(dataset_cloned.label[:10])
print(dataset.label[:10])  # the original dataset is unchanged.
```

```
tensor([2, 1, 2, 1, 3, 0, 0, 0, 3, 2])
tensor([99, 99, 99, 99, 99, 99, 99, 99, 99, 99])
tensor([2, 1, 2, 1, 3, 0, 0, 0, 3, 2])
```


```python
# Move the dataset to another device; this cell requires a CUDA-capable GPU.
print(f'{dataset.device=:}')
print(f'{dataset.label.device=:}')
print(f'{dataset.user_index.device=:}')
print(f'{dataset.session_index.device=:}')

dataset = dataset.to('cuda')

print(f'{dataset.device=:}')
print(f'{dataset.label.device=:}')
print(f'{dataset.user_index.device=:}')
print(f'{dataset.session_index.device=:}')
```

```
dataset.device=cpu
dataset.label.device=cpu
dataset.user_index.device=cpu
dataset.session_index.device=cpu
dataset.device=cuda:0
dataset.label.device=cuda:0
dataset.user_index.device=cuda:0
dataset.session_index.device=cuda:0
```


```python
# Check that all tensors in the dataset are on the same device; raises an error otherwise.
dataset._check_device_consistency()
```

```python
# # NOTE: uncommenting this cell raises an error, which is intentional:
# # only `label` is moved back to the CPU.
# dataset.label = dataset.label.to('cpu')
# dataset._check_device_consistency()
```
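The consistency check above can be pictured as collecting the device of every tensor and complaining if there is more than one. The function below is a hypothetical stand-in written for illustration, not the actual implementation of `_check_device_consistency()`; non-tensor values are ignored.

```python
import torch

def check_device_consistency(tensors):
    """Hypothetical stand-in: raise if the tensors in `tensors` live on more
    than one device; otherwise return that device (None if no tensors)."""
    devices = {t.device for t in tensors.values() if torch.is_tensor(t)}
    if len(devices) > 1:
        raise RuntimeError(f'Tensors are on multiple devices: {devices}')
    return devices.pop() if devices else None

# All CPU tensors: passes and reports the common device.
print(check_device_consistency({'label': torch.zeros(3, dtype=torch.long),
                                'user_obs': torch.randn(2, 2)}))  # cpu
```

On a machine with a GPU, placing one tensor on `'cuda'` and the rest on `'cpu'` would make this sketch raise, mirroring the intentional error in the commented-out cell.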

```python
# x_dict collects the observables into a dictionary of tensors for model.forward();
# each entry has shape (num_records, num_items, num_features).
print_dict_shape(dataset.x_dict)
```

```
dict.user_obs.shape=torch.Size([10000, 4, 128])
dict.item_obs.shape=torch.Size([10000, 4, 64])
dict.session_obs.shape=torch.Size([10000, 4, 10])
dict.price_obs.shape=torch.Size([10000, 4, 12])
```
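The shapes printed above show that `x_dict` turns every observable into a tensor of shape `(B, I, K)`, with one feature vector per (record, item) pair. The sketch below reproduces those shapes with indexing and `Tensor.expand`; it is an assumption about how such a dictionary could be built, not `torch_choice`'s actual code, and it uses a small `B` so that it runs quickly.

```python
import torch

U, I, S, B = 10, 4, 500, 6   # B kept small for the demo
user_obs = torch.randn(U, 128)
item_obs = torch.randn(I, 64)
session_obs = torch.randn(S, 10)
price_obs = torch.randn(S, I, 12)
user_index = torch.randint(0, U, (B,))
session_index = torch.randint(0, S, (B,))

x_dict = {
    # (B, K) -> (B, 1, K) -> (B, I, K): the record's user features, repeated for every item.
    'user_obs': user_obs[user_index].unsqueeze(1).expand(-1, I, -1),
    # (I, K) -> (1, I, K) -> (B, I, K): the same item table for every record.
    'item_obs': item_obs.unsqueeze(0).expand(B, -1, -1),
    # (B, K) -> (B, I, K): the record's session features, repeated for every item.
    'session_obs': session_obs[session_index].unsqueeze(1).expand(-1, I, -1),
    # (S, I, K) indexed by session is already (B, I, K).
    'price_obs': price_obs[session_index],
}
for key, val in x_dict.items():
    print(key, tuple(val.shape))
```

`expand` returns views without copying data, so repeating user or item features along the item dimension does not allocate extra memory.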


```python
# __getitem__ returns a mini-batch as a new ChoiceDataset.
# Pick 5 random records as the mini-batch.
dataset = dataset.to('cpu')
indices = torch.LongTensor(np.random.choice(len(dataset), size=5, replace=False))
print(indices)
subset = dataset[indices]
print(dataset)
print(subset)
print_dict_shape(subset.x_dict)

assert torch.all(dataset.x_dict['price_obs'][indices, :, :] == subset.x_dict['price_obs'])
assert torch.all(dataset.label[indices] == subset.label)
```

```
tensor([7117, 4566, 6732, 7475, 3698])
ChoiceDataset(label=[10000], user_index=[10000], session_index=[10000], item_availability=[500, 4], observable_prefix=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[500, 10], price_obs=[500, 4, 12], device=cpu)
ChoiceDataset(label=[5], user_index=[5], session_index=[5], item_availability=[500, 4], observable_prefix=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[500, 10], price_obs=[500, 4, 12], device=cpu)
dict.user_obs.shape=torch.Size([5, 4, 128])
dict.item_obs.shape=torch.Size([5, 4, 64])
dict.session_obs.shape=torch.Size([5, 4, 10])
dict.price_obs.shape=torch.Size([5, 4, 12])
```
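The two `repr` lines above show that `dataset[indices]` only slices the record-level tensors (`label`, `user_index`, `session_index`) down to length 5, while `item_availability` and the `*_obs` tables keep their full shapes. The cells below also show that modifying the batch leaves the original untouched, so this sketch copies the lookup tables with `clone()`. The helper `subset_records` and the `RECORD_KEYS` set are hypothetical, chosen to mimic the printed shapes rather than copied from `torch_choice`.

```python
import torch

RECORD_KEYS = {'label', 'user_index', 'session_index'}  # one entry per record

def subset_records(data, indices):
    """Hypothetical: index record-level tensors, copy lookup tables whole."""
    return {key: (val[indices] if key in RECORD_KEYS else val.clone())
            for key, val in data.items()}

data = {
    'label': torch.randint(0, 4, (100,)),
    'user_index': torch.randint(0, 10, (100,)),
    'session_index': torch.randint(0, 50, (100,)),
    'item_availability': torch.ones(50, 4, dtype=torch.bool),
    'user_obs': torch.randn(10, 128),
}
indices = torch.tensor([7, 3, 42, 0, 99])
subset = subset_records(data, indices)
for key, val in subset.items():
    print(key, tuple(val.shape))

subset['user_obs'] += 1  # modifies the copy only
print(torch.equal(subset['user_obs'], data['user_obs'] + 1))  # True
```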


```python
print(subset.label)
print(dataset.label[indices])

subset.label += 1  # modifying the batch does not change the original dataset.

print(subset.label)
print(dataset.label[indices])
```

```
tensor([3, 1, 1, 0, 2])
tensor([3, 1, 1, 0, 2])
tensor([4, 2, 2, 1, 3])
tensor([3, 1, 1, 0, 2])
```


```python
# Modifying the batch's item_obs does not change the original dataset either.
print(subset.item_obs[0, 0])
print(dataset.item_obs[0, 0])
subset.item_obs += 1
print(subset.item_obs[0, 0])
print(dataset.item_obs[0, 0])
```

```
tensor(-0.4046)
tensor(-0.4046)
tensor(0.5954)
tensor(-0.4046)
```


```python
# id() shows that the batch and the original dataset hold distinct tensor objects.
print(id(subset.label))
print(id(dataset.label[indices]))
```

```
139651371619408
139651371561184
```
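Different `id()` values only show that two Python objects exist; they do not tell whether the tensors share memory. The behavior in the cells above is consistent with PyTorch's indexing rules: indexing with an integer tensor (advanced indexing) returns a copy, while basic slicing returns a view. `Tensor.data_ptr()` makes the difference visible:

```python
import torch

a = torch.arange(10)
copy = a[torch.tensor([1, 2, 3])]   # advanced indexing: new storage
view = a[1:4]                       # basic slicing: shares storage with `a`

print(copy.data_ptr() == a.data_ptr())                      # False: separate memory
print(view.data_ptr() == a.data_ptr() + a.element_size())   # True: view starts at a[1]

copy += 100    # does not affect `a`
print(a[1:4])  # tensor([1, 2, 3])
view += 100    # writes through to `a`
print(a[1:4])  # tensor([101, 102, 103])
```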


## Using the PyTorch `DataLoader` for the Training Loop

```python
# from torch.utils.data.sampler import BatchSampler, SequentialSampler, RandomSampler
# shuffle = False
# batch_size = 32

# sampler = BatchSampler(
#     RandomSampler(dataset) if shuffle else SequentialSampler(dataset),
#     batch_size=batch_size,
#     drop_last=False)

# dataloader = torch.utils.data.DataLoader(dataset,
#                                          sampler=sampler,
#                                          num_workers=0,  # 0 if dataset.device == 'cuda' else os.cpu_count(),
#                                          collate_fn=lambda x: x[0],
#                                          pin_memory=(dataset.device == 'cpu'))
```
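The commented-out cell above passes a `BatchSampler` as the `sampler`, so each index the `DataLoader` draws is already a whole list of record indices. With the default `batch_size=1`, the loader wraps each of those lists in an outer list of length 1, and `collate_fn=lambda x: x[0]` unwraps it. A self-contained sketch of the pattern, using a hypothetical `ToyDataset` in place of `ChoiceDataset` (whose `__getitem__` is assumed here to accept a list of indices):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

class ToyDataset:
    """Hypothetical stand-in for ChoiceDataset: indexing with a list returns a batch."""
    def __init__(self, n):
        self.label = torch.arange(n)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, indices):
        # `indices` is a list of ints produced by BatchSampler.
        return self.label[indices]

toy = ToyDataset(10)
sampler = BatchSampler(SequentialSampler(toy), batch_size=4, drop_last=False)
# Each fetched item is already a batch; collate_fn removes the length-1 wrapper.
loader = DataLoader(toy, sampler=sampler, collate_fn=lambda x: x[0])
batches = [batch.tolist() for batch in loader]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Fetching a whole batch with one indexing call avoids calling `__getitem__` once per record and then stacking the results.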


```python
print(f'{item_obs.shape=:}')
# Repeat item_obs once per record so it can be compared with batch.x_dict['item_obs'].
item_obs_all = item_obs.view(1, num_items, -1).expand(len(dataset), -1, -1)
item_obs_all = item_obs_all.to(dataset.device)
label_all = label.to(dataset.device)
print(f'{item_obs_all.shape=:}')
```

```
item_obs.shape=torch.Size([4, 64])
item_obs_all.shape=torch.Size([10000, 4, 64])
```


```python
# for i, batch in enumerate(dataloader):
#     # check consistency.
#     first, last = i * batch_size, min(len(dataset), (i + 1) * batch_size)
#     idx = torch.arange(first, last)
#     assert torch.all(item_obs_all[idx, :, :] == batch.x_dict['item_obs'])
#     assert torch.all(label_all[idx] == batch.label)
```

## Chaining Multiple Datasets: `JointDataset` Examples

```python
dataset1 = dataset.clone()
dataset2 = dataset.clone()
joint_dataset = JointDataset(the_dataset=dataset1, another_dataset=dataset2)
```

```python
joint_dataset
```

```
JointDataset with 2 sub-datasets: (
	the_dataset: ChoiceDataset(label=[10000], user_index=[10000], session_index=[10000], item_availability=[500, 4], observable_prefix=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[500, 10], price_obs=[500, 4, 12], device=cpu)
	another_dataset: ChoiceDataset(label=[10000], user_index=[10000], session_index=[10000], item_availability=[500, 4], observable_prefix=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[500, 10], price_obs=[500, 4, 12], device=cpu)
)
```