# Build sequence data loaders for Skip Gram

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import string
import sys

import pandas as pd
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

sys.path.insert(0, "..")

from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset

# Controller

In [3]:
class Args(BaseModel):
    """Notebook-level settings shared by the dataset, data loader and sequence cells."""

    # Passed to SkipGramDataset as `negative_samples` in the full-data run.
    num_negative_samples: int = 5
    # Passed to SkipGramDataset as `window_size` in the full-data run.
    window_size: int = 1
    # Batch size handed to the DataLoaders that are configured from `args`.
    batch_size: int = 16

    # Column names used as the defaults of `get_sequence` (grouping key / collected items).
    user_col: str = "user_id"
    item_col: str = "parent_asin"


# Instantiate with the defaults and echo the effective configuration as JSON.
args = Args()
print(args.model_dump_json(indent=2))

{
  "num_negative_samples": 5,
  "window_size": 1,
  "batch_size": 16,
  "user_col": "user_id",
  "item_col": "parent_asin"
}


# Test implementation

In [4]:
# Toy corpora used to exercise SkipGramDataset: every inner list is one sequence of item ids.
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "b", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]
val_sequences = [
    ["f", "l", "m"],
    ["i", "h"],
    ["j", "e", "a"],
]

sequences_fp = "../data/sequences.jsonl"
val_sequences_fp = "../data/val_sequences.jsonl"

# Write each corpus as JSON Lines: one JSON-encoded sequence per line.
for out_path, rows in ((sequences_fp, sequences), (val_sequences_fp, val_sequences)):
    with open(out_path, "w") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)

In [5]:
# Simulate a pre-configured id -> index mapper over the letters "a".."m".
# "a" and "b" get swapped indices so the mapping is not simply alphabetical.
id_to_idx = dict(zip(string.ascii_letters[:13], range(13)))
id_to_idx.update({"a": 1, "b": 0})

# Create datasets with frequency-based negative sampling; both toy datasets share
# the same window, number of negatives and id mapping.
toy_kwargs = dict(window_size=1, negative_samples=2, id_to_idx=id_to_idx)
dataset = SkipGramDataset(sequences_fp, **toy_kwargs)
# The validation dataset is seeded with the training dataset's interactions and frequencies.
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    **toy_kwargs,
)

# Example of getting an item: fetch and print only the first one.
inp = next(iter(dataset), None)
if inp is not None:
    target_items, context_items, labels = (
        inp[key] for key in ("target_items", "context_items", "labels")
    )
    print(target_items, context_items, labels)

[32m2024-10-20 00:45:12.407[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-10-20 00:45:12.421[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([0, 0, 0]) tensor([ 2,  7, 11]) tensor([1., 0., 0.])


In [6]:
# Inspect the mapping held by the dataset; the output matches the mapper that was
# passed in (note "a" -> 1 and "b" -> 0).
dataset.id_to_idx

{'a': 1,
 'b': 0,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12}

In [7]:
# Per-index sampling probabilities exposed by the dataset (presumably the distribution
# negatives are drawn from — confirm in src/skipgram/dataset). In the output, indices
# 0, 2 and 10 ("b", "c", "k") share the highest value.
dataset.sampling_probs

array([0.13537454, 0.05938764, 0.13537454, 0.05938764, 0.05938764,
       0.05938764, 0.05938764, 0.05938764, 0.05938764, 0.05938764,
       0.13537454, 0.05938764, 0.05938764])

In [8]:
# Inspect the co-occurrence map: in the output each item index maps to the set of item
# indices that share at least one training sequence with it (the item itself included).
dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 10, 11, 12},
             11: {6, 7, 10, 11, 12},
             12: {6, 7, 10, 11, 12},
             8: {2, 8, 9},
             9: {2, 8, 9}})

In [9]:
# Print the (target, context) pairs of the example currently held in
# target_items / context_items, one pair per line.
for pair in zip(target_items, context_items):
    print(pair)

(tensor(0), tensor(2))
(tensor(0), tensor(7))
(tensor(0), tensor(11))


In [10]:
# Same inspection for the validation dataset, which was constructed with the training
# dataset's item_freq; the values differ from the training dataset's probabilities.
val_dataset.sampling_probs

array([0.10225277, 0.07544086, 0.10225277, 0.0448574 , 0.07544086,
       0.07544086, 0.0448574 , 0.07544086, 0.07544086, 0.07544086,
       0.10225277, 0.07544086, 0.07544086])

In [11]:
# The validation co-occurrence map starts from the training map and, per the output,
# additionally contains the pairs that co-occur in the validation sequences
# (e.g. index 5 gains 11 and 12).
val_dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4, 9},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4, 9},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10, 11, 12},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 8, 10, 11, 12},
             11: {5, 6, 7, 10, 11, 12},
             12: {5, 6, 7, 10, 11, 12},
             8: {2, 7, 8, 9},
             9: {1, 2, 4, 8, 9}})

## Test no conflicting labels

In [12]:
# Loader feeding the label-consistency check in the next cell.
# drop_last=False so that every example reaches the check: with drop_last=True the
# trailing partial batch (or the whole toy dataset, if it yields fewer than
# batch_size examples) would be silently skipped.
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    drop_last=False,
    collate_fn=dataset.collate_fn,
)

In [13]:
# Flatten every batch into one frame with a row per (target, context, label) triple.
# The values are accumulated in a local dict so this cell does not rebind the
# target_items / context_items / labels names used by other cells.
batch_keys = ("target_items", "context_items", "labels")
collected = {key: [] for key in batch_keys}

for batch_input in tqdm(dataloader, desc="Collecting batches"):
    for key in batch_keys:
        collected[key].extend(batch_input[key].cpu().detach().numpy())

test_df = pd.DataFrame(collected)

# Guard against a vacuous pass: with zero rows the uniqueness check is trivially true.
assert not test_df.empty, "Dataloader produced no batches; nothing was checked!"

# A (target, context) pair must never appear with more than one distinct label,
# i.e. a sampled negative must not collide with a real positive pair.
conflicting_pairs = (
    test_df.groupby(["target_items", "context_items"])["labels"]
    .nunique()
    .loc[lambda s: s > 1]
)
assert conflicting_pairs.empty, "Conflicting labels!"

0it [00:00, ?it/s]

# Load data

In [14]:
# Load the train / validation feature tables and the persisted ID mapper.
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")
# NOTE(review): relies on IDMapper.load returning the loaded mapper (its
# item_to_index is used later) — confirm in src/id_mapper.
idm = IDMapper().load("../data/idm.json")

In [15]:
# Preview the validation frame (the rendered output is truncated to the first and last rows).
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence
0,AE2AB5V6OZNXHFVEOHBKYPUVQBAA,B0BKSLJMS7,5.0,1635137091196,13808,318,Computers,Logitech G29 Driving Force Racing Wheel and Fl...,[],"[Video Games, PC, Accessories, Controllers, Ra...",352.1,"[-1, -1, -1, 3010, 2852, 892, 206, 1298, 3070,..."
1,AE2B3KJWAXWZKZRZX6LIXYAMF5MA,B09CG15F86,2.0,1652370727452,8584,1316,Computers,Razer Doubleshot PBT Keycap Upgrade Set for Me...,[Enjoy durability without cramping your style....,"[Video Games, PC, Accessories, Gaming Keyboards]",,"[-1, -1, -1, -1, -1, 3673, 839, 46, 1292, 269]"
2,AE2FHQDFEPXKMKIZG2T3RDQIOOUA,B08DF248LD,4.0,1650551522850,3462,63,Video Games,Xbox Core Wireless Controller – Carbon Black,[Experience the modernized design of the Xbox ...,[],45.5,"[-1, -1, -1, -1, 2309, 2088, 1121, 3826, 1796,..."
3,AE2KSKDHIBIBGNZNOUPVPZI4DEOQ,B0C1K1R6HK,2.0,1652300179778,13459,3183,Video Games,Xbox Series X,"[Xbox Series X, the fastest, most powerful Xbo...","[Video Games, Legacy Systems, Xbox Systems, Xb...",499.99,"[-1, 2981, 45, 2935, 2654, 3337, 4201, 575, 28..."
4,AE2MZXV6MN7FGQAWXW3QRCVLTOEQ,B0BL65X86R,5.0,1655492544369,1990,4142,Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0,"[-1, -1, -1, 1332, 1709, 1068, 2293, 4250, 290..."
...,...,...,...,...,...,...,...,...,...,...,...,...
957,AHZKTZHKO3Z6UYWEYMH4YL52K3LA,B0716CXJ1R,5.0,1636865685568,9989,853,Video Games,Darksiders III - Collector's Edition - Xbox One,[],"[Video Games, Xbox One, Games]",149.99,"[741, 4427, 664, 2712, 2086, 2299, 949, 611, 4..."
958,AHZKTZHKO3Z6UYWEYMH4YL52K3LA,B07SM7G9CN,5.0,1636865734529,9989,4431,Video Games,Donkey Kong Country: Tropical Freeze - Nintend...,[Barrel-blast into a critically acclaimed Donk...,"[Video Games, Nintendo Switch, Games]",52.49,"[4427, 664, 2712, 2086, 2299, 949, 611, 4352, ..."
959,AHZKTZHKO3Z6UYWEYMH4YL52K3LA,B081W1VBKN,5.0,1644053506803,9989,3222,Video Games,Darksiders 2: Deathinitive Edition - Xbox One ...,"[What starts with War, ends in Death. Awakened...","[Video Games, Xbox One, Games]",14.99,"[664, 2712, 2086, 2299, 949, 611, 4352, 1900, ..."
960,AHZLVBGFP4FNOJGC33CZQSHUQXWA,B07H53PZY8,4.0,1634514696799,4179,1645,Video Games,Mudrunner - American Wilds Edition - PlayStati...,[Mud Runner - American Wilds is the ultimate v...,"[Video Games, PlayStation 4, Games]",23.98,"[4317, 2926, 4329, 4188, 1203, 4197, 4447, 154..."


In [16]:
def get_sequence(df, user_col=args.user_col, item_col=args.item_col):
    """Collect each user's items into a list and keep the lists with more than one item.

    Args:
        df: Interaction frame containing `user_col` and `item_col`.
        user_col: Column to group rows by.
        item_col: Column whose values are gathered into each group's list.

    Returns:
        A list of lists, one per group that has at least two items.
    """
    # NOTE(review): items keep the row order of `df` within each group — presumably
    # `df` is already sorted chronologically; confirm upstream.
    items_per_user = df.groupby(user_col)[item_col].agg(list)
    has_multiple_items = items_per_user.map(len) > 1  # drop single-item sequences
    return items_per_user[has_multiple_items].tolist()

In [17]:
# Build the training sequences and show how many there are.
item_sequence = get_sequence(train_df)
len(item_sequence)

19578

In [18]:
# Build the validation sequences and show how many there are.
val_item_sequence = get_sequence(val_df)
len(val_item_sequence)

157

## Persist

In [19]:
sequences_fp = "../data/item_sequence.jsonl"
val_sequences_fp = "../data/val_item_sequence.jsonl"

# Persist both splits as JSON Lines: one JSON-encoded sequence per line.
for out_path, rows in (
    (sequences_fp, item_sequence),
    (val_sequences_fp, val_item_sequence),
):
    with open(out_path, "w") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)

# Log how many sequences were written for each split.
logger.info(f"{len(item_sequence)=:,.0f} {len(val_item_sequence)=:,.0f}")

[32m2024-10-20 00:45:13.244[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mlen(item_sequence)=19,578 len(val_item_sequence)=157[0m


## Persist a small dataset for overfitting

In [20]:
# Keep only the first few training sequences and write them to their own JSONL file.
num_sequences = 2
batch_sequences_fp = "../data/batch_item_sequence.jsonl"
batch_item_sequence = item_sequence[:num_sequences]

with open(batch_sequences_fp, "w") as f:
    f.writelines(json.dumps(row) + "\n" for row in batch_item_sequence)

# Run with all data

In [21]:
# Create dataset with frequency-based negative sampling over the persisted training
# sequences, configured from `args` and indexed with the ID mapper's item mapping.
dataset = SkipGramDataset(
    sequences_fp,
    id_to_idx=idm.item_to_index,
    negative_samples=args.num_negative_samples,
    window_size=args.window_size,
)

# Print just the first example as a sanity check.
inp = next(iter(dataset), None)
if inp is not None:
    target_items, context_items, labels = (
        inp[key] for key in ("target_items", "context_items", "labels")
    )
    print(target_items, context_items, labels)

[32m2024-10-20 00:45:13.265[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([394, 394, 394, 394, 394, 394]) tensor([3760,  547,  524, 2360, 3228, 2516]) tensor([1., 0., 0., 0., 0., 0.])


In [22]:
# For easier testing: size the batch to the length of the first sequence so the
# first batch can be compared against item_sequence[0].
batch_size = len(item_sequence[0])

dataloader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    drop_last=True,
    batch_size=batch_size,
)

In [23]:
# Test index mapping matches input id_mapper: the target indices seen in the first
# batch must equal the mapped indices of the items in the first sequence.
targets_items_idx_item_sequence = {
    idm.item_to_index[item_id] for item_id in item_sequence[0]
}

target_items_idx_dataloader = set()
for batch_input in dataloader:
    first_batch_targets = batch_input["target_items"].detach().numpy()
    target_items_idx_dataloader |= set(first_batch_targets)
    break  # only the first batch is needed

assert target_items_idx_dataloader == targets_items_idx_item_sequence

In [24]:
# Validation dataset: reuses the training dataset's interactions and item frequencies,
# with the same window / negative-sampling settings and item-id mapping.
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
# drop_last=False: evaluation should cover every validation example; dropping the
# trailing partial batch would silently exclude up to batch_size - 1 examples.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=val_dataset.collate_fn,
)

[32m2024-10-20 00:45:13.779[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [25]:
# Show the first validation batch only.
batch_input = next(iter(val_dataloader), None)
if batch_input is not None:
    print(batch_input)

{'target_items': tensor([3070, 3070, 3070, 3070, 3070, 3070,  892,  892,  892,  892,  892,  892,
        3266, 3266, 3266, 3266, 3266, 3266,  483,  483,  483,  483,  483,  483,
         483,  483,  483,  483,  483,  483,  611,  611,  611,  611,  611,  611,
         611,  611,  611,  611,  611,  611,  556,  556,  556,  556,  556,  556,
         556,  556,  556,  556,  556,  556, 1589, 1589, 1589, 1589, 1589, 1589,
        1589, 1589, 1589, 1589, 1589, 1589, 1515, 1515, 1515, 1515, 1515, 1515,
        3032, 3032, 3032, 3032, 3032, 3032,  142,  142,  142,  142,  142,  142,
         142,  142,  142,  142,  142,  142, 2067, 2067, 2067, 2067, 2067, 2067,
        2136, 2136, 2136, 2136, 2136, 2136, 2075, 2075, 2075, 2075, 2075, 2075,
        3955, 3955, 3955, 3955, 3955, 3955,  756,  756,  756,  756,  756,  756,
        2689, 2689, 2689, 2689, 2689, 2689]), 'context_items': tensor([ 892,  749, 4522, 3117, 1689, 1817, 3070, 4198, 4432, 2420, 4284, 2509,
         483,  341, 2867, 1117, 3903,  6