# `ChoiceDataset` Examples
This notebook aims to help users understand the functionality of `ChoiceDataset` object.
**TODO**: add more detailed descriptions explaining features.

In [20]:
import numpy as np
import torch
from choice_dataset import ChoiceDataset

In [21]:
def print_dict_shape(d):
    for key, val in d.items():
        if torch.is_tensor(val):
            print(f'dict.{key}.shape={val.shape}')

## Create Aritfical Data

In [22]:
# Creates some fake input features, feel free to modify it as you want.
num_users = 10
num_user_features = 128
num_items = 4
num_item_features = 64
num_sessions = 10000

In [23]:
# create observables/features/covariates, the number of parameters are
# arbitrarily chosen.
user_obs = torch.randn(num_users, 128)  # generate 128 features for each user.
item_obs = torch.randn(num_items, 64)
session_obs = torch.randn(num_sessions, 234)
taste_obs = torch.randn(num_users, num_items, 567)
price_obs = torch.randn(num_sessions, num_items, 12)

In [24]:
label = torch.LongTensor(np.random.choice(num_items, size=num_sessions))

user_onehot = torch.zeros(num_sessions, num_users)
user_idx = torch.LongTensor(np.random.choice(num_users, size=num_sessions))
user_onehot[torch.arange(num_sessions), user_idx] = 1

item_availability = torch.ones(num_sessions, num_items).bool()

In [25]:
dataset = ChoiceDataset(
    # required keywords of __init__
    label=label.long(),
    user_onehot=user_onehot.long(),
    item_availability=item_availability.bool(),
    # optional keywords of __init__
    user_obs=user_obs,
    item_obs=item_obs,
    session_obs=session_obs,
    taste_obs=taste_obs,
    price_obs=price_obs)

In [26]:
dataset

ChoiceDataset(label=[10000], user_onehot=[10000, 10], item_availability=[10000, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[10000, 234], taste_obs=[10, 4, 567], price_obs=[10000, 4, 12], device=cpu)

In [27]:
print(f'{dataset.num_users=:}')
print(f'{dataset.num_items=:}')
print(f'{dataset.num_sessions=:}')
print(f'{len(dataset)=:}')

dataset.num_users=10
dataset.num_items=4
dataset.num_sessions=10000
len(dataset)=10000


In [28]:
# clone
print(dataset.label[:10])
dataset_cloned = dataset.clone()
dataset_cloned.label = 99 * torch.ones(num_sessions)
print(dataset_cloned.label[:10])
print(dataset.label[:10])  # does not change the original dataset.

tensor([1, 2, 0, 0, 1, 1, 3, 0, 0, 1])
tensor([99., 99., 99., 99., 99., 99., 99., 99., 99., 99.])
tensor([1, 2, 0, 0, 1, 1, 3, 0, 0, 1])


In [29]:
# move to device
print(f'{dataset.device=:}')
print(f'{dataset.label.device=:}')
print(f'{dataset.taste_obs.device=:}')
print(f'{dataset.user_onehot.device=:}')

dataset = dataset.to('cuda')

print(f'{dataset.device=:}')
print(f'{dataset.label.device=:}')
print(f'{dataset.taste_obs.device=:}')
print(f'{dataset.user_onehot.device=:}')

dataset.device=cpu
dataset.label.device=cpu
dataset.taste_obs.device=cpu
dataset.user_onehot.device=cpu
dataset.device=cuda:0
dataset.label.device=cuda:0
dataset.taste_obs.device=cuda:0
dataset.user_onehot.device=cuda:0


In [35]:
dataset._check_device_consistency()

In [36]:
# NOTE: this cell will result errors, this is intentional.
dataset.label = dataset.label.to('cpu')
dataset._check_device_consistency()

Exception: ("Found tensors on different devices: {device(type='cuda', index=0), device(type='cpu')}.", 'Use dataset.to() method to align devices.')

In [30]:
# create dictionary inputs for model.forward()
print_dict_shape(dataset.x_dict)

dict.user_obs.shape=torch.Size([10000, 4, 128])
dict.item_obs.shape=torch.Size([10000, 4, 64])
dict.session_obs.shape=torch.Size([10000, 4, 234])
dict.taste_obs.shape=torch.Size([10000, 4, 567])
dict.price_obs.shape=torch.Size([10000, 4, 12])


In [34]:
# __getitem__ to get batch.
# pick 5 random sessions as the mini-batch.
indices = torch.LongTensor(np.random.choice(len(dataset), size=5, replace=False))
subset = dataset[indices]
print(subset)
print_dict_shape(subset.x_dict)

assert torch.all(dataset.x_dict['price_obs'][indices, :, :] == subset.x_dict['price_obs'])

ChoiceDataset(label=[5], user_onehot=[5, 10], item_availability=[5, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[5, 234], taste_obs=[10, 4, 567], price_obs=[5, 4, 12], device=cuda:0)
dict.user_obs.shape=torch.Size([5, 4, 128])
dict.item_obs.shape=torch.Size([5, 4, 64])
dict.session_obs.shape=torch.Size([5, 4, 234])
dict.taste_obs.shape=torch.Size([5, 4, 567])
dict.price_obs.shape=torch.Size([5, 4, 12])
