# `ChoiceDataset` Examples
This notebook aims to help users understand the functionality of `ChoiceDataset` object.
**TODO**: add more detailed descriptions explaining features.

In [22]:
import numpy as np
import torch
from choice_dataset import ChoiceDataset
from joint_dataset import JointDataset

In [23]:
def print_dict_shape(d):
    for key, val in d.items():
        if torch.is_tensor(val):
            print(f'dict.{key}.shape={val.shape}')

## Create Aritfical Data

In [24]:
# Creates some fake input features, feel free to modify it as you want.
num_users = 10
num_user_features = 128
num_items = 4
num_item_features = 64
num_sessions = 10000

In [25]:
# create observables/features/covariates, the number of parameters are
# arbitrarily chosen.
user_obs = torch.randn(num_users, 128)  # generate 128 features for each user.
item_obs = torch.randn(num_items, 64)
session_obs = torch.randn(num_sessions, 234)
taste_obs = torch.randn(num_users, num_items, 567)
price_obs = torch.randn(num_sessions, num_items, 12)

In [26]:
label = torch.LongTensor(np.random.choice(num_items, size=num_sessions))

user_onehot = torch.zeros(num_sessions, num_users)
user_idx = torch.LongTensor(np.random.choice(num_users, size=num_sessions))
user_onehot[torch.arange(num_sessions), user_idx] = 1

item_availability = torch.ones(num_sessions, num_items).bool()

In [27]:
dataset = ChoiceDataset(
    # pre-specified keywords of __init__
    label=label.long(),  # required.
    # optional:
    user_onehot=user_onehot.long(),
    item_availability=item_availability.bool(),
    # additional keywords of __init__
    user_obs=user_obs,
    item_obs=item_obs,
    session_obs=session_obs,
    taste_obs=taste_obs,
    price_obs=price_obs)

In [28]:
dataset

ChoiceDataset(label=[10000], user_onehot=[10000, 10], item_availability=[10000, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[10000, 234], taste_obs=[10, 4, 567], price_obs=[10000, 4, 12], device=cpu)

In [29]:
print(f'{dataset.num_users=:}')
print(f'{dataset.num_items=:}')
print(f'{dataset.num_sessions=:}')
print(f'{len(dataset)=:}')

dataset.num_users=10
dataset.num_items=4
dataset.num_sessions=10000
len(dataset)=10000


In [30]:
# clone
print(dataset.label[:10])
dataset_cloned = dataset.clone()
dataset_cloned.label = 99 * torch.ones(num_sessions)
print(dataset_cloned.label[:10])
print(dataset.label[:10])  # does not change the original dataset.

tensor([1, 2, 3, 3, 2, 0, 3, 0, 1, 1])
tensor([99., 99., 99., 99., 99., 99., 99., 99., 99., 99.])
tensor([1, 2, 3, 3, 2, 0, 3, 0, 1, 1])


In [31]:
# move to device
print(f'{dataset.device=:}')
print(f'{dataset.label.device=:}')
print(f'{dataset.taste_obs.device=:}')
print(f'{dataset.user_onehot.device=:}')

dataset = dataset.to('cuda')

print(f'{dataset.device=:}')
print(f'{dataset.label.device=:}')
print(f'{dataset.taste_obs.device=:}')
print(f'{dataset.user_onehot.device=:}')

dataset.device=cpu
dataset.label.device=cpu
dataset.taste_obs.device=cpu
dataset.user_onehot.device=cpu
dataset.device=cuda:0
dataset.label.device=cuda:0
dataset.taste_obs.device=cuda:0
dataset.user_onehot.device=cuda:0


In [32]:
dataset._check_device_consistency()

In [33]:
# # NOTE: this cell will result errors, this is intentional.
# dataset.label = dataset.label.to('cpu')
# dataset._check_device_consistency()

In [34]:
# create dictionary inputs for model.forward()
print_dict_shape(dataset.x_dict)

dict.user_obs.shape=torch.Size([10000, 4, 128])
dict.item_obs.shape=torch.Size([10000, 4, 64])
dict.session_obs.shape=torch.Size([10000, 4, 234])
dict.taste_obs.shape=torch.Size([10000, 4, 567])
dict.price_obs.shape=torch.Size([10000, 4, 12])


In [35]:
# __getitem__ to get batch.
# pick 5 random sessions as the mini-batch.
indices = torch.LongTensor(np.random.choice(len(dataset), size=5, replace=False))
print(indices)
subset = dataset[indices]
print(dataset)
print(subset)
print_dict_shape(subset.x_dict)

assert torch.all(dataset.x_dict['price_obs'][indices, :, :] == subset.x_dict['price_obs'])
assert torch.all(dataset.label[indices] == subset.label)

tensor([4985, 4855, 9243, 1328, 4047])
ChoiceDataset(label=[10000], user_onehot=[10000, 10], item_availability=[10000, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[10000, 234], taste_obs=[10, 4, 567], price_obs=[10000, 4, 12], device=cuda:0)
ChoiceDataset(label=[5], user_onehot=[5, 10], item_availability=[5, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[5, 234], taste_obs=[10, 4, 567], price_obs=[5, 4, 12], device=cuda:0)
dict.user_obs.shape=torch.Size([5, 4, 128])
dict.item_obs.shape=torch.Size([5, 4, 64])
dict.session_obs.shape=torch.Size([5, 4, 234])
dict.taste_obs.shape=torch.Size([5, 4, 567])
dict.price_obs.shape=torch.Size([5, 4, 12])


In [36]:
print(subset.label)
print(dataset.label[indices])

subset.label += 1  # modifying the batch does not change the original dataset.

print(subset.label)
print(dataset.label[indices])

tensor([2, 0, 0, 1, 3], device='cuda:0')
tensor([2, 0, 0, 1, 3], device='cuda:0')
tensor([3, 1, 1, 2, 4], device='cuda:0')
tensor([2, 0, 0, 1, 3], device='cuda:0')


In [37]:
print(id(subset.label))
print(id(dataset.label[indices]))

140292701334144
140292701228608


## Using Pytorch dataloader for the training loop.

In [38]:
from torch.utils.data.sampler import BatchSampler, SequentialSampler, RandomSampler
shuffle = False
batch_size = 32

sampler = BatchSampler(
    RandomSampler(dataset) if shuffle else SequentialSampler(dataset),
    batch_size=batch_size,
    drop_last=False)

dataloader = torch.utils.data.DataLoader(dataset,
                                         sampler=sampler,
                                         num_workers=0,  # 0 if dataset.device == 'cuda' else os.cpu_count(),
                                         collate_fn=lambda x: x[0],
                                         pin_memory=(dataset.device == 'cpu'))


In [39]:
print(f'{item_obs.shape=:}')
item_obs_all = item_obs.view(1, num_items, -1).expand(num_sessions, -1, -1)
item_obs_all = item_obs_all.to(dataset.device)
label_all = label.to(dataset.device)
print(f'{item_obs_all.shape=:}')

item_obs.shape=torch.Size([4, 64])
item_obs_all.shape=torch.Size([10000, 4, 64])


In [40]:
for i, batch in enumerate(dataloader):
    # check consistency.
    first, last = i * batch_size, min(len(dataset), (i + 1) * batch_size)
    idx = torch.arange(first, last)
    assert torch.all(item_obs_all[idx, :, :] == batch.x_dict['item_obs'])
    assert torch.all(label_all[idx] == batch.label)

# `JointDataset` Examples

In [41]:
dataset1 = dataset.clone()
dataset2 = dataset.clone()
joint_dataset = JointDataset(the_dataset=dataset1, another_dataset=dataset2)

In [21]:
joint_dataset

JointDataset with 2 sub-datasets: (
	the_dataset: ChoiceDataset(label=[10000], user_onehot=[10000, 10], item_availability=[10000, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[10000, 234], taste_obs=[10, 4, 567], price_obs=[10000, 4, 12], device=cuda:0)
	another_dataset: ChoiceDataset(label=[10000], user_onehot=[10000, 10], item_availability=[10000, 4], variable_types=[5], user_obs=[10, 128], item_obs=[4, 64], session_obs=[10000, 234], taste_obs=[10, 4, 567], price_obs=[10000, 4, 12], device=cuda:0)
)