In [1]:
import sys
import os
# Add the parent directory to the import path so that the `runs.*` and `src.*`
# imports used in later cells resolve when the kernel runs from this folder.
sys.path.append('../')
# Relative data directory exported for the project code.
# NOTE(review): presumably read by the data-loading helpers — confirm where it is consumed.
os.environ['SEQ_SPLITS_DATA_PATH'] = "../data/"

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
from hydra import compose, initialize
from omegaconf import OmegaConf


In [4]:
from runs.train import prepare_data, create_dataloaders


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Location of the pickled item-id -> semantic-id mapping.
# NOTE(review): absolute, machine-specific path — adjust it for your environment.
SEMANTIC_IDS_MAP_PATH = "/home/jovyan/gusak/semantic_seqrec/data/item_sem_id_modified.pkl"

# Compose the training config through Hydra so that the overrides below are applied.
# `version_base="1.1"` pins the compatibility level Hydra was already falling back to,
# which silences the "version_base parameter is not specified" warning.
# (The previous plain `OmegaConf.load(...)` was removed: its result was overwritten
# by `compose` before ever being used.)
with initialize(config_path="../runs/configs/", version_base="1.1"):
    config = compose(
        config_name="train",
        overrides=[
            "quantile=0.9",
            "split_subtype=val_by_time",
            "dataset=Beauty",
            "model=GPT2",
            "trainer_params.max_epochs=10",
            "use_semantic_ids=True",
            f"semantic_ids_map_path={SEMANTIC_IDS_MAP_PATH}",
            "model.mode=greedy",
        ],
        return_hydra_config=False,
    )

# Print the fully resolved config for the record.
print(OmegaConf.to_yaml(config, resolve=True))

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../runs/configs/"):


cuda_visible_devices: 0
random_state: 101
clearml_project_folder: null
clearml_task_name: null
use_pretrained_embeddings: false
pretrained_embeddings:
  add_padding_emb: true
  freeze: false
use_semantic_ids: true
semantic_ids_map_path: /home/jovyan/gusak/semantic_seqrec/data/item_sem_id_modified.pkl
semantic_ids_len: 4
split_type: global_timesplit
split_subtype: val_by_time
quantile: 0.9
validation_quantile: 0.9
dataset_params:
  max_length: 128
dataloader:
  batch_size: 128
  test_batch_size: 256
  num_workers: 8
  validation_size: 2048
seqrec_module:
  lr: 0.001
  predict_top_k: 10
  filter_seen: false
trainer_params:
  max_epochs: 10
  accelerator: gpu
patience: 20
load_if_possible: false
evaluator:
  successive_val: false
  successive_test: false
  successive_test_retrained: false
  calc_successive_metrics_val: true
  calc_successive_metrics_test: true
  calc_successive_metrics_test_retrained: true
  successive_replay_metrics: false
  metrics:
  - NDCG
  - HitRate
  - MRR
  - Cove

In [6]:
# Replace the configured device list with an empty string; a later cell copies this
# value into the CUDA_VISIBLE_DEVICES environment variable.
config.cuda_visible_devices = ''

In [7]:
import os
import time

import hydra
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from clearml import Task
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import (EarlyStopping, ModelCheckpoint,
                                         ModelSummary, TQDMProgressBar)
from torch.utils.data import DataLoader

from src.datasets import (CausalLMDataset, CausalLMPredictionDataset,
                          PaddingCollateFn)
from src.metrics import Evaluator
from src.models import SASRec
from src.modules import SeqRec, SeqRecHuggingface
from src.postprocess import preds2recs
from src.prepr import last_item_split

import itertools
import pickle
from tqdm.auto import tqdm

### Prepare data

In [8]:
# Export the configured device list (as a string) before any data/model code runs.
if hasattr(config, 'cuda_visible_devices'):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(config.cuda_visible_devices)


# Project helper from runs.train: returns the three splits plus the maximum item id
# and two global time points (one of them for validation).
train, validation, test, max_item_id, global_timepoint, global_timepoint_val = prepare_data(config)

# Project helper from runs.train: builds the training and evaluation loaders.
train_loader, eval_loader = create_dataloaders(train, validation, config)

train shape (160178, 4)
validation shape (66297, 4)
test shape (70263, 4)
   user_id               item_id   timestamp            user_id_old
0        1  [117, 318, 621, 768]  1357430400  A00473363TJ8YSZ3YAGG9
1        1  [114, 386, 594, 768]  1384387200  A00473363TJ8YSZ3YAGG9
2        1    [1, 470, 745, 768]  1384387200  A00473363TJ8YSZ3YAGG9
3        1   [91, 411, 726, 768]  1387843200  A00473363TJ8YSZ3YAGG9
4        2   [6, 377, 571, 1013]  1385337600  A00700212KB3K0MVESPIY
   user_id              item_id   timestamp            user_id_old
0        0  [67, 376, 750, 768]  1405296000  A00414041RD0BXM6WK0GX
1        0  [67, 390, 715, 768]  1405296000  A00414041RD0BXM6WK0GX
2        0  [67, 313, 751, 768]  1405296000  A00414041RD0BXM6WK0GX
3        0  [67, 365, 605, 768]  1405296000  A00414041RD0BXM6WK0GX
4        0  [67, 482, 721, 768]  1405296000  A00414041RD0BXM6WK0GX
   user_id               item_id   timestamp            user_id_old
0        1  [117, 318, 621, 768]  1357430400  A0

   user_id item_id   timestamp            user_id_old
0        1     117  1357430400  A00473363TJ8YSZ3YAGG9
0        1     318  1357430400  A00473363TJ8YSZ3YAGG9
0        1     621  1357430400  A00473363TJ8YSZ3YAGG9
0        1     768  1357430400  A00473363TJ8YSZ3YAGG9
1        1     114  1384387200  A00473363TJ8YSZ3YAGG9
   user_id item_id   timestamp            user_id_old
0        0      67  1405296000  A00414041RD0BXM6WK0GX
0        0     376  1405296000  A00414041RD0BXM6WK0GX
0        0     750  1405296000  A00414041RD0BXM6WK0GX
0        0     768  1405296000  A00414041RD0BXM6WK0GX
1        0      67  1405296000  A00414041RD0BXM6WK0GX
   user_id item_id   timestamp            user_id_old
0        1     117  1357430400  A00473363TJ8YSZ3YAGG9
0        1     318  1357430400  A00473363TJ8YSZ3YAGG9
0        1     621  1357430400  A00473363TJ8YSZ3YAGG9
0        1     768  1357430400  A00473363TJ8YSZ3YAGG9
1        1     114  1384387200  A00473363TJ8YSZ3YAGG9
Test global timepoint 139993

In [10]:
max_item_id

1706

### Batch example

In [11]:
batch = next(iter(eval_loader))

print(batch)

{'input_ids': tensor([[162, 466, 580,  ...,   0,   0,   0],
        [120, 460, 664,  ...,   0,   0,   0],
        [128, 407, 661,  ...,   0,   0,   0],
        ...,
        [ 75, 400, 715,  ...,   0,   0,   0],
        [167, 419, 751,  ...,   0,   0,   0],
        [221, 495, 615,  ...,   0,   0,   0]]), 'user_id': tensor([   4,   11,   24,   28,   42,   47,   48,   69,   75,   82,   90,   93,
         102,  115,  123,  130,  133,  138,  144,  153,  168,  184,  191,  192,
         196,  198,  200,  201,  205,  206,  221,  229,  256,  259,  263,  268,
         278,  279,  294,  295,  297,  298,  334,  349,  351,  354,  365,  370,
         381,  386,  387,  424,  450,  454,  468,  469,  484,  486,  497,  506,
         528,  539,  545,  570,  580,  594,  600,  610,  616,  617,  628,  643,
         652,  661,  663,  678,  688,  702,  718,  720,  737,  748,  769,  792,
         806,  829,  830,  831,  835,  881,  884,  890,  891,  901,  915,  916,
         920,  921,  924,  926,  932,  946, 

### Model creation

In [12]:
pwd

'/home/jovyan/gusak/semantic_seqrec/notebooks'

In [12]:
from runs.train import create_model

In [28]:
# Set n_positions to max_length * semantic_ids_len (128 * 4 = 512 with the config above).
# NOTE(review): presumably because every item is expanded into semantic_ids_len tokens — confirm.
config.model.model_params.n_positions = config.dataset_params.max_length * config.semantic_ids_len

In [29]:
config.model.model_params.n_positions

512

In [30]:
model = create_model(config, item_count=max_item_id)

In [31]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(1707, 64)
    (wpe): Embedding(512, 64)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-1): 2 x GPT2Block(
        (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=192, nx=64)
          (c_proj): Conv1D(nf=64, nx=64)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=256, nx=64)
          (c_proj): Conv1D(nf=64, nx=256)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=64, out_features=1707, bias=False)
)

In [32]:
# Build the checkpoint directory:
#   <parent of cwd>/models/<split_type>/<split_subtype>/<dataset>/<quantile tag>/<model_class>/[retrain_with_val]
retrain = False
split_subtype = config.split_subtype or ''
# Quantile tag, e.g. 0.9 -> 'q09'; only used for the global time split.
q = 'q0' + str(config.quantile)[2:] if config.split_type == 'global_timesplit' else ''
model_path = os.path.join(
    os.path.dirname(os.path.abspath('.')), 'models', config.split_type,
    split_subtype, config.dataset.name, q, config.model.model_class, 'retrain_with_val' if retrain else '')

print(model_path)
# exist_ok=True makes the cell safely re-runnable and avoids the check-then-create race.
os.makedirs(model_path, exist_ok=True)

/home/jovyan/gusak/semantic_seqrec/models/global_timesplit/val_by_time/Beauty/q09/GPT-2/


In [33]:
# Force CPU execution: the Trainer created below is built from config['trainer_params'].
config.trainer_params.accelerator = 'cpu'

In [34]:
# --- Checkpoint file stem: encodes the hyper-parameters that identify a run. ---
if config.model.model_class == 'SASRec':
    file_name = (
        f"{config.model.model_params.hidden_units}_"
        f"{config.model.model_params.num_blocks}_"
        f"{config.model.model_params.num_heads}_"
        f"{config.model.model_params.dropout_rate}_"
        f"{config.model.model_params.maxlen}_"
        f"{config.dataloader.batch_size}_"
        f"{config.random_state}"
    )
elif config.model.model_class == 'GPT-2':
    file_name = (
        f"{config.model.model_params.n_embd}_"
        f"{config.model.model_params.n_layer}_"
        f"{config.model.model_params.n_head}_"
        f"{config.dataloader.batch_size}_"
        f"{config.random_state}"
    )
else:
    # Without this branch `file_name` stays undefined and the next statement
    # fails with an unhelpful NameError.
    raise ValueError(
        f"Unsupported model_class {config.model.model_class!r}: "
        "expected 'SASRec' or 'GPT-2'")

checkpoint_file = os.path.join(model_path, file_name + ".ckpt")

# --- Lightning module wrapping the model. ---
if config.model.model_class == 'GPT-2':
    seqrec_module = SeqRecHuggingface(model, **config['seqrec_module'])
    if config.model.generation:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(config.semantic_ids_map_path, 'rb') as f:
            index2semid = pickle.load(f)
        # Inverse lookup: tuple of semantic ids -> original item id.
        inv_map = {tuple(sem_ids): item_id for item_id, sem_ids in index2semid.items()}
        seqrec_module.set_predict_mode(generate=True, mode=config.model.mode,
                                        N=config.semantic_ids_len,
                                        inv_map=inv_map,
                                        **config.model.generation_params)
else:
    seqrec_module = SeqRec(model, **config['seqrec_module'])

# --- Callbacks: summary, progress bar, best-checkpoint saving and early stopping,
# the last two both driven by the "val_ndcg" metric (higher is better). ---
model_summary = ModelSummary(max_depth=1)
progress_bar = TQDMProgressBar(refresh_rate=20)

checkpoint = ModelCheckpoint(
    dirpath=model_path,
    filename='_' + file_name,
    save_top_k=1,
    monitor="val_ndcg",
    mode="max",
    save_weights_only=True
)
early_stopping = EarlyStopping(monitor="val_ndcg", mode="max",
                            patience=config.patience, verbose=False)
callbacks = [early_stopping, model_summary, checkpoint, progress_bar]

trainer = pl.Trainer(callbacks=callbacks, enable_checkpointing=True,
                        **config['trainer_params'])

# Timestamp taken right after the trainer is created.
start_time = time.time()


Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### Training

In [59]:
next(iter(train_loader))

{'input_ids': tensor([[ 37, 371, 559,  ...,   0,   0,   0],
         [  1, 407, 520,  ...,   0,   0,   0],
         [239, 408, 644,  ...,   0,   0,   0],
         ...,
         [211, 470, 548,  ...,   0,   0,   0],
         [236, 280, 648,  ...,   0,   0,   0],
         [ 83, 408, 701,  ...,   0,   0,   0]]),
 'labels': tensor([[  37,  371,  559,  ..., -100, -100, -100],
         [   1,  407,  520,  ..., -100, -100, -100],
         [ 239,  408,  644,  ..., -100, -100, -100],
         ...,
         [ 211,  470,  548,  ..., -100, -100, -100],
         [ 236,  280,  648,  ..., -100, -100, -100],
         [  83,  408,  701,  ..., -100, -100, -100]]),
 'attention_mask': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         ...,
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]])}

In [36]:
batch_val = next(iter(eval_loader))
batch_val

{'input_ids': tensor([[  0,   0,   0,  ..., 316, 677, 768],
         [  0,   0,   0,  ..., 269, 696, 768],
         [  0,   0,   0,  ..., 345, 520, 768],
         ...,
         [  0,   0,   0,  ..., 313, 669, 768],
         [  0,   0,   0,  ..., 346, 566, 768],
         [  0,   0,   0,  ..., 295, 644, 768]]),
 'user_id': tensor([  22,   24,   27,   44,   46,   47,   55,   63,   66,   69,  102,  115,
          140,  144,  152,  159,  184,  192,  194,  198,  205,  206,  241,  248,
          249,  258,  259,  275,  278,  279,  285,  287,  289,  294,  295,  298,
          303,  311,  334,  354,  359,  360,  381,  386,  394,  396,  402,  404,
          413,  414,  419,  421,  449,  450,  454,  468,  469,  476,  499,  500,
          503,  505,  506,  521,  528,  530,  539,  556,  560,  570,  606,  617,
          631,  635,  639,  644,  655,  661,  672,  682,  684,  702,  737,  743,
          748,  756,  759,  762,  770,  784,  798,  814,  826,  831,  835,  858,
          868,  884,  908,  92

In [37]:
# Ad-hoc call of the validation step outside of a Trainer loop.
# In the recorded run this raised MisconfigurationException from `self.log()`,
# because no loop result collection is registered in this context.
seqrec_module.validation_step(batch_val, None)

0.894140625


MisconfigurationException: You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging

In [43]:
# Map every target semantic-id row back to an item id (0 when the tuple is unknown),
# then score the predictions against them.
# NOTE: `preds` must already exist in the kernel — it is produced by the manual
# generation cell, so run that cell first.
targets = batch_val['target'].detach().cpu().numpy()
targets = [seqrec_module.inv_map.get(tuple(seq), 0) for seq in targets]
metrics = seqrec_module.compute_val_metrics(targets, preds)
metrics

{'ndcg': 0.0, 'hit_rate': 0.0, 'mrr': 0.0}

In [108]:
inv_map[(110, 393, 628, 768)]

6172

In [112]:
tuple(batch_val['target'][0].detach().cpu().numpy())

(110, 393, 628, 768)

In [61]:
seqrec_module.make_prediction_generate(batch_val)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [82]:
# Switch the module to beam search; cells that branch on `seqrec_module.mode`
# take the 'beamsearch' path from now on.
seqrec_module.mode = 'beamsearch'

In [40]:
# Reproduce generation-based prediction by hand for one validation batch.
B = batch_val['input_ids'].size(0)   # number of sequences in the batch
K = seqrec_module.predict_top_k      # number of items to predict per sequence
N = seqrec_module.N                  # number of tokens that encode one item

if seqrec_module.mode == 'greedy':
    max_new_tokens = N * K
    # Greedy decoding: no sampling, a single beam and a single returned sequence.
    # These are set explicitly so that leftover entries in generate_params cannot
    # turn this branch into beam search with several returned sequences.
    params = dict(seqrec_module.generate_params)
    params["do_sample"] = False
    params["num_beams"] = 1
    params["num_return_sequences"] = 1
    params["max_new_tokens"] = max_new_tokens

    # Trim the prompt so that prompt + generated tokens fit into n_positions.
    seq = seqrec_module.model.generate(
        batch_val['input_ids'][:, -seqrec_module.model.config.n_positions + max_new_tokens:].to(seqrec_module.model.device),
        pad_token_id=seqrec_module.padding_idx,
        **params
    )

    cont = seq[:, -max_new_tokens:]  # (B, N*K): K items of N tokens, one after another

elif seqrec_module.mode == 'beamsearch':
    # Beam search: K beams, K returned sequences, each N new tokens long (one item).
    params = dict(seqrec_module.generate_params)
    params["do_sample"] = False
    params["num_beams"] = K
    params["num_return_sequences"] = K
    params["max_new_tokens"] = N
    seq = seqrec_module.model.generate(
        batch_val['input_ids'][:, -seqrec_module.model.config.n_positions + N:].to(seqrec_module.model.device),
        pad_token_id=seqrec_module.padding_idx,
        **params
    )

    cont = seq[:, -N:]  # (B*K, N)

else:
    # Previously an unknown mode left `cont` undefined and failed later with a NameError.
    raise ValueError(f"Unknown prediction mode: {seqrec_module.mode!r}")

# The slices above are not contiguous in memory, so `.view` raises
# "view size is not compatible with input tensor's size and stride";
# `.reshape` copies when needed.
cont = cont.reshape(B, K, N)  # -> (B, K, N)

# Translate every generated N-token tuple back to an item id (0 = no such item).
cont_flat = cont.reshape(-1, N).tolist()
preds_flat = [seqrec_module.inv_map.get(tuple(seq), 0) for seq in cont_flat]
preds = torch.tensor(preds_flat, dtype=torch.long).view(B, K)  # -> (B, K)

In [96]:
cont.shape

torch.Size([256, 10, 4])

In [41]:
preds

tensor([[   0,    0,    0,  ...,    0,    0,    0],
        [9289, 9289, 9289,  ..., 9289, 9289, 9289],
        [9289, 9289, 9289,  ..., 9289, 9289, 9289],
        ...,
        [9289, 9289, 9289,  ..., 9289, 9289, 9289],
        [9289, 9289, 9289,  ..., 9289, 9289, 9289],
        [9289, 9289, 9289,  ..., 9289, 9289, 9289]])

In [86]:
# Share of predictions that were mapped to a known item
# (0 is the fallback id used when a generated tuple is not found in inv_map).
mask = preds != 0
print(mask.sum().item() / mask.numel())

# Rank-based pseudo scores: K-1 for the first prediction down to 0 for the last,
# repeated for each of the B rows.
scores = torch.arange(K-1, -1, -1)   # [K-1, ... 0]
scores = scores.unsqueeze(0).expand(B, -1)
scores

0.0


tensor([[9, 8, 7,  ..., 2, 1, 0],
        [9, 8, 7,  ..., 2, 1, 0],
        [9, 8, 7,  ..., 2, 1, 0],
        ...,
        [9, 8, 7,  ..., 2, 1, 0],
        [9, 8, 7,  ..., 2, 1, 0],
        [9, 8, 7,  ..., 2, 1, 0]])

In [76]:
cont

tensor([[[ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         ...,
         [ 768,  768,  768,  768],
         [ 768,  768,  768, 1462],
         [1462, 1462, 1462, 1462]],

        [[ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         ...,
         [ 347,  347,  347,  347],
         [ 347,  347,  347,  347],
         [ 347, 1324, 1324, 1324]],

        [[ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         ...,
         [ 768,  768,  768,  768],
         [ 768,  768, 1469, 1469],
         [1469, 1469, 1469, 1469]],

        ...,

        [[1658, 1658, 1658, 1658],
         [1658, 1658, 1658, 1658],
         [1658, 1658, 1658, 1658],
         ...,
         [1023, 1023, 1023, 1023],
         [1023, 1023, 1023, 1023],
         [1023, 1023, 1023, 1023]],

        [[ 768,  768,  768,  768],
         [ 768,  768,  768,  768],
         [ 7

In [69]:
N

4

In [75]:
cont.reshape(-1, N)

tensor([[ 768,  768,  768,  768],
        [ 768,  768,  768,  768],
        [ 768,  768,  768,  768],
        ...,
        [ 768,  768,  768,  768],
        [ 768, 1224, 1224, 1224],
        [1224, 1224, 1224, 1224]])

In [64]:
batch_val['target']

tensor([[110, 393, 628, 768],
        [ 23, 293, 548, 768],
        [ 43, 366, 756, 768],
        ...,
        [120, 387, 567, 768],
        [193, 390, 621, 768],
        [ 83, 302, 716, 769]])

In [33]:
batch_val

tensor([[110, 393, 628, 768],
        [ 53, 261, 585, 768],
        [135, 386, 589, 768],
        ...,
        [126, 378, 609, 768],
        [240, 360, 765, 768],
        [104, 450, 701, 768]])

In [18]:
seqrec_module.model.config.vocab_size

1707

In [35]:
# Train with validation on eval_loader.
# NOTE: in the recorded run this failed after epoch 0 with
# "RuntimeError: a Tensor with 10 elements cannot be converted to Scalar".
trainer.fit(model=seqrec_module,
                    train_dataloaders=train_loader,
                    val_dataloaders=eval_loader)

/home/jovyan/.mlspace/envs/splits/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/jovyan/gusak/semantic_seqrec/models/global_timesplit/val_by_time/Beauty/q09/GPT-2 exists and is not empty.

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 242 K 
------------------------------------------
242 K     Trainable params
0         Non-trainable params
242 K     Total params
0.968     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

0.0
0.0
Epoch 0:   0%|          | 0/159 [00:00<?, ?it/s]                           

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Epoch 0: 100%|██████████| 159/159 [01:05<00:00,  2.42it/s, v_num=6]0.894140625
0.879296875
0.903125


RuntimeError: a Tensor with 10 elements cannot be converted to Scalar

### Predict

In [16]:
# Build the prediction loader for the test split and put the module into the
# matching predict mode. Only the GPT-2 model class is handled in this cell.
if config.model.model_class == 'GPT-2':
    if config.model.generation:
        # Generation mode: shorten the history by the largest evaluated top-k and
        # pad on the left so that generated tokens directly follow the history.
        # NOTE(review): relies on config.evaluator.top_k being defined — confirm.
        predict_dataset = CausalLMPredictionDataset(
            test, max_length=config.dataset_params.max_length - max(config.evaluator.top_k))

        predict_loader = DataLoader(
                predict_dataset, shuffle=False,
                collate_fn=PaddingCollateFn(left_padding=True),
                batch_size=config.dataloader.test_batch_size,
                num_workers=config.dataloader.num_workers)

        # Pass N and inv_map exactly like the set-up cell does, so that switching
        # the predict mode here does not drop the semantic-id decoding settings.
        # `inv_map` is created by the set-up cell when config.model.generation is set.
        seqrec_module.set_predict_mode(generate=True, mode=config.model.mode,
                                       N=config.semantic_ids_len,
                                       inv_map=inv_map,
                                       **config.model.generation_params)

    else:
        # Scoring mode: full-length history with the collate function's default padding.
        predict_dataset = CausalLMPredictionDataset(test, max_length=config.dataset_params.max_length)

        predict_loader = DataLoader(
                predict_dataset, shuffle=False,
                collate_fn=PaddingCollateFn(),
                batch_size=config.dataloader.test_batch_size,
                num_workers=config.dataloader.num_workers)

        seqrec_module.set_predict_mode(generate=False)

In [17]:
# Run prediction over the whole predict loader, then convert the raw predictions
# with the project helper `preds2recs` (non-successive mode).
# NOTE: this rebinds the name `preds`, which other cells also use.
preds = trainer.predict(model=seqrec_module, dataloaders=predict_loader)
recs = preds2recs(preds, successive=False)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 96/96 [00:01<00:00, 48.90it/s]


In [18]:
recs

Unnamed: 0,user_id,item_id,prediction
0,0,10148,1.000000
1,0,10163,0.500000
2,0,7849,0.333333
3,0,9509,0.250000
4,0,10609,0.200000
...,...,...,...
61095,22361,7906,0.166667
61096,22361,9062,0.142857
61097,22361,4973,0.125000
61098,22361,5689,0.111111


In [19]:
batch = next(iter(predict_loader))

In [62]:
# Mutates the module's generation settings in place (they stay changed for later cells):
# return a single sequence and use no_repeat_ngram_size=1.
seqrec_module.generate_params['num_return_sequences'] = 1
seqrec_module.generate_params['no_repeat_ngram_size'] = 1
seqrec_module.generate_params

{'early_stopping': False,
 'num_return_sequences': 1,
 'do_sample': False,
 'no_repeat_ngram_size': 1,
 'num_beams': 10}

### HuggingFace Generate

In [63]:
# Generate predict_top_k * 4 new tokens per sequence in a single generate() call.
# NOTE(review): the factor 4 is hard-coded; it likely should come from
# config.semantic_ids_len — confirm.
gen_len = seqrec_module.predict_top_k * 4
hf_model = seqrec_module.model

# Keep only as much history as fits next to the new tokens within n_positions.
prompt = batch['input_ids'][:, gen_len - hf_model.config.n_positions:]

seq = hf_model.generate(
    prompt.to(hf_model.device),
    pad_token_id=seqrec_module.padding_idx,
    max_new_tokens=gen_len,
    **seqrec_module.generate_params,
)

# The newly generated tokens are the last gen_len columns.
preds = seq[:, -gen_len:]

In [68]:
seq.shape

torch.Size([64, 68])

In [65]:
batch['input_ids'].shape

torch.Size([64, 28])

In [66]:
seq.shape

torch.Size([64, 68])

In [69]:
preds.shape

torch.Size([64, 40])

In [70]:
def chunker(seq, size):
    """Split ``seq`` into consecutive slices of at most ``size`` elements.

    Works on anything that supports ``len()`` and slicing. The last slice is
    shorter when ``len(seq)`` is not a multiple of ``size``.

    Returns:
        A list of slices of ``seq`` in their original order.
    """
    chunks = []
    for start in range(0, len(seq), size):
        chunks.append(seq[start:start + size])
    return chunks

In [87]:
preds.view(64,-1, 4)

tensor([[[10148, 10163, 10632,  3496],
         [ 7532,  8124,  8824,  8969],
         [10100, 10106,  1880,  5636],
         ...,
         [ 5631,  5633,  5634,  5638],
         [ 8601,  8635,  8636,  8701],
         [10206, 10208, 10244, 10247]],

        [[ 1390,  1660,  3519,  1679],
         [ 5516,  5695,  9568, 10646],
         [11540, 11517, 11630, 10442],
         ...,
         [10123, 10206, 10208, 10244],
         [10247, 10285,  5637,  6288],
         [11932, 11665, 11315, 11609]],

        [[10350,  8507,  8968,  9767],
         [10632,  6237,  7532,  8136],
         [ 8969,  8824, 10000, 10100],
         ...,
         [ 9205, 10159, 10189, 10770],
         [10313, 11215, 10619, 10166],
         [10746, 10248, 10354, 11141]],

        ...,

        [[11290,  1679,  6136, 11690],
         [11691, 11790, 11932,  5752],
         [ 5921,  6293,  6371,  6661],
         ...,
         [11844, 11770, 11766, 11775],
         [11774, 11715, 11763, 11778],
         [11765, 11773,  89

In [74]:
preds[0]

tensor([10148, 10163, 10632,  3496,  7532,  8124,  8824,  8969, 10100, 10106,
         1880,  5636,  5647,  5752,  5921,  6293,  6371,  6661,  7608,  8092,
         8518,  8596,  8705,  8858,  8897,  8744,  4858,  5266,  5631,  5633,
         5634,  5638,  8601,  8635,  8636,  8701, 10206, 10208, 10244, 10247])

In [40]:
preds

tensor([[ 7977,  9509,  7849,  ...,  9120,  9622, 10000],
        [10214,  4343,  6727,  ...,  7847,  5764,  4185],
        [10632,  6237,  8136,  ..., 10156,  9205,  9110],
        ...,
        [ 6046,  3578,  9668,  ..., 11949, 11972, 11770],
        [ 2648,  5095,  2319,  ...,  7650,  7652, 10284],
        [ 8229, 11054,  8978,  ..., 11699, 11803, 11236]])

In [88]:
# Mutates the module's generation settings in place: use predict_top_k beams and
# return all of them, so one generate() call yields predict_top_k sequences per input.
seqrec_module.generate_params['num_return_sequences'] = seqrec_module.predict_top_k
seqrec_module.generate_params['no_repeat_ngram_size'] = 1
seqrec_module.generate_params['num_beams'] = seqrec_module.predict_top_k
seqrec_module.generate_params

{'early_stopping': False,
 'num_return_sequences': 10,
 'do_sample': False,
 'no_repeat_ngram_size': 1,
 'num_beams': 10}

In [89]:
# Generate 4 new tokens per returned sequence using the generate_params set above;
# the history is trimmed so that history + 4 tokens fits within n_positions.
# NOTE(review): 4 is hard-coded here — confirm it matches the tokens-per-item setting.
seq = seqrec_module.model.generate(
                batch['input_ids'][:, -seqrec_module.model.config.n_positions + 4:].to(seqrec_module.model.device),
                pad_token_id=seqrec_module.padding_idx,
                max_new_tokens=4,
                **seqrec_module.generate_params,
            )

# Keep only the last 4 (newly generated) tokens of every returned sequence.
preds = seq[:, -4:]

In [90]:
preds.shape

torch.Size([640, 4])

In [93]:
seq.shape

torch.Size([640, 32])

In [92]:
preds

tensor([[10148, 10163,  7849,  7977],
        [10148, 10163,  7849,  9509],
        [10148, 10163,  7849, 10003],
        ...,
        [ 2618,  2620,  2617,  9250],
        [ 2618,  2619,  4439,  9250],
        [ 2618,  4439,  2617,  9250]])

In [96]:
preds.view(64, -1, 4)

tensor([[[10148, 10163,  7849,  7977],
         [10148, 10163,  7849,  9509],
         [10148, 10163,  7849, 10003],
         ...,
         [10148, 10163, 10632,  6237],
         [10148, 10163,  7532,  8824],
         [10148, 10163, 10632,  8136]],

        [[ 2648,  3145, 11940, 11941],
         [ 2648,  3145, 11938, 11942],
         [ 2648,  3145, 11938, 11940],
         ...,
         [ 2648,  2673,  2826,  7987],
         [ 5040,  7637,  6914,  8326],
         [ 9165, 10481,  9259,  9084]],

        [[11215,  3496,  7532,  8824],
         [11215,  3496,  6237,  9109],
         [11215,  3496,  6237,  7532],
         ...,
         [10100, 10163,  6237,  7532],
         [11215,  3496,  7532,  8124],
         [11215,  3496,  6237,  8968]],

        ...,

        [[ 5752,  5921,  6293,  6371],
         [ 5752,  5921,  6293,  6661],
         [10742,  8139, 10345, 11408],
         ...,
         [ 5752,  5921,  6293,  5131],
         [ 5752,  5921,  6293, 10350],
         [10742,  8139,  61

In [100]:
# Rank-based pseudo scores for 64 rows of 10 predictions each: 9 for the first
# position down to 0 for the last. Both sizes are hard-coded in this cell.
scores = torch.arange(10-1, -1, -1)   # [K-1, ... 0]
scores = scores.unsqueeze(0).expand(64, -1)
scores

tensor([[9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],


In [None]:
# `pickle` is already imported in the main import cell; this alias only adds a
# second name (`pkl`) for the same module.
import pickle as pkl

# NOTE(review): absolute, user-specific path — the same file is referenced by the
# `semantic_ids_map_path` config override.
# SECURITY: pickle can execute arbitrary code on load; only open trusted files.
with open('/home/jovyan/gusak/semantic_seqrec/data/item_sem_id_modified.pkl', 'rb') as f:
    item_sem_id = pkl.load(f)

# Displays the whole mapping; per the recorded output it maps an item id to an
# array of 4 integers.
item_sem_id


{9450: array([104, 431, 695, 768]),
 9840: array([211, 332, 621, 768]),
 10077: array([193, 307, 697, 768]),
 11156: array([ 84, 430, 657, 768]),
 11753: array([193, 436, 645, 768]),
 11864: array([  1, 295, 720, 768]),
 3310: array([120, 291, 761, 768]),
 4573: array([235, 506, 751, 768]),
 4137: array([177, 323, 750, 768]),
 9080: array([177, 323, 677, 768]),
 10956: array([140, 396, 585, 768]),
 4387: array([ 15, 386, 559, 768]),
 6363: array([ 15, 269, 696, 768]),
 454: array([254, 289, 674, 768]),
 4209: array([ 92, 463, 696, 768]),
 59: array([107, 300, 715, 768]),
 5666: array([211, 470, 675, 768]),
 8650: array([  2, 434, 724, 768]),
 9405: array([  6, 377, 571, 768]),
 8877: array([  1, 378, 609, 768]),
 8824: array([ 37, 314, 743, 768]),
 8651: array([ 37, 374, 573, 768]),
 10131: array([ 75, 378, 751, 768]),
 10254: array([  1, 508, 688, 768]),
 8969: array([  1, 378, 548, 768]),
 5098: array([211, 378, 751, 768]),
 2714: array([ 74, 348, 644, 768]),
 10321: array([  2, 408,

In [105]:
# Invert the mapping: semantic-id tuple -> item id. The arrays are converted to
# tuples so they can serve as dictionary keys. The result is only displayed, not stored.
{tuple(sem_ids): item_id 
           for item_id, sem_ids in item_sem_id.items()}

{(104, 431, 695, 768): 9450,
 (211, 332, 621, 768): 9840,
 (193, 307, 697, 768): 10077,
 (84, 430, 657, 768): 11156,
 (193, 436, 645, 768): 11753,
 (1, 295, 720, 768): 11864,
 (120, 291, 761, 768): 3310,
 (235, 506, 751, 768): 4573,
 (177, 323, 750, 768): 4137,
 (177, 323, 677, 768): 9080,
 (140, 396, 585, 768): 10956,
 (15, 386, 559, 768): 4387,
 (15, 269, 696, 768): 6363,
 (254, 289, 674, 768): 454,
 (92, 463, 696, 768): 4209,
 (107, 300, 715, 768): 59,
 (211, 470, 675, 768): 5666,
 (2, 434, 724, 768): 8650,
 (6, 377, 571, 768): 9405,
 (1, 378, 609, 768): 8877,
 (37, 314, 743, 768): 8824,
 (37, 374, 573, 768): 8651,
 (75, 378, 751, 768): 10131,
 (1, 508, 688, 768): 10254,
 (1, 378, 548, 768): 8969,
 (211, 378, 751, 768): 5098,
 (74, 348, 644, 768): 2714,
 (2, 408, 512, 768): 10321,
 (60, 290, 751, 768): 10163,
 (2, 500, 644, 768): 8072,
 (200, 427, 645, 768): 9767,
 (200, 365, 761, 768): 3496,
 (60, 277, 697, 768): 7450,
 (74, 296, 675, 768): 8030,
 (211, 378, 720, 768): 9622,
 (126,

In [19]:
# Reload the item-id -> semantic-id mapping under the name `index2semid`.
# NOTE(review): absolute, user-specific path; pickle.load must only be used on trusted files.
with open('/home/jovyan/gusak/semantic_seqrec/data/item_sem_id_modified.pkl', 'rb') as f:
    index2semid = pickle.load(f)
# Disabled: the inverse map is not rebuilt in this cell.
# inv_map = {tuple(sem_ids): item_id for item_id, sem_ids in index2semid.items()}

In [24]:
# Largest semantic id in the mapping. `torch.Tensor(...)` builds a float tensor,
# hence the float result (1706.) in the recorded output.
torch.Tensor(list(index2semid.values())).max()

tensor(1706.)

In [114]:
list(index2semid.values())[0]

array([104, 431, 695, 768])