# Build sequence data loaders for Skip Gram

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import string
import sys

import pandas as pd
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

sys.path.insert(0, "..")
from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset

# Controller

In [3]:
class Args(BaseModel):
    num_negative_samples: int = 5
    window_size: int = 1
    batch_size: int = 16

    user_col: str = "user_id"
    item_col: str = "parent_asin"


args = Args()
print(args.model_dump_json(indent=2))

{
  "num_negative_samples": 5,
  "window_size": 1,
  "batch_size": 16,
  "user_col": "user_id",
  "item_col": "parent_asin"
}


# Test implementation

In [4]:
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "b", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

val_sequences = [["f", "l", "m"], ["i", "h"], ["j", "e", "a"]]

sequences_fp = "../data/sequences.jsonl"
val_sequences_fp = "../data/val_sequences.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in sequences:
        f.write(json.dumps(sequence) + "\n")
with open(val_sequences_fp, "w") as f:
    for sequence in val_sequences:
        f.write(json.dumps(sequence) + "\n")

In [5]:
# Simulate pre-configured id_to_idx mapper
id_to_idx = {
    id_: idx for id_, idx in zip(list(string.ascii_letters[:13]), list(range(13)))
}
id_to_idx["a"] = 1
id_to_idx["b"] = 0

# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(
    sequences_fp, window_size=1, negative_samples=2, id_to_idx=id_to_idx
)
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=1,
    negative_samples=2,
    id_to_idx=id_to_idx,
)

# Example of getting an item
for inp in dataset:
    target_items = inp["target_items"]
    context_items = inp["context_items"]
    labels = inp["labels"]
    print(target_items, context_items, labels)
    break

[32m2024-09-30 09:43:34.043[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-09-30 09:43:34.055[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([0, 0, 0]) tensor([ 2, 11,  7]) tensor([1., 0., 0.])


In [6]:
dataset.id_to_idx

{'a': 1,
 'b': 0,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12}

In [7]:
dataset.sampling_probs

array([0.13537454, 0.05938764, 0.13537454, 0.05938764, 0.05938764,
       0.05938764, 0.05938764, 0.05938764, 0.05938764, 0.05938764,
       0.13537454, 0.05938764, 0.05938764])

In [8]:
dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 10, 11, 12},
             11: {6, 7, 10, 11, 12},
             12: {6, 7, 10, 11, 12},
             8: {2, 8, 9},
             9: {2, 8, 9}})

In [9]:
for i in zip(target_items, context_items):
    print(i)

(tensor(0), tensor(2))
(tensor(0), tensor(11))
(tensor(0), tensor(7))


In [10]:
val_dataset.sampling_probs

array([0.10225277, 0.07544086, 0.10225277, 0.0448574 , 0.07544086,
       0.07544086, 0.0448574 , 0.07544086, 0.07544086, 0.07544086,
       0.10225277, 0.07544086, 0.07544086])

In [11]:
val_dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4, 9},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4, 9},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10, 11, 12},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 8, 10, 11, 12},
             11: {5, 6, 7, 10, 11, 12},
             12: {5, 6, 7, 10, 11, 12},
             8: {2, 7, 8, 9},
             9: {1, 2, 4, 8, 9}})

## Test no conflicting labels

In [12]:
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [13]:
target_items = []
context_items = []
labels = []

for i, batch_input in tqdm(enumerate(dataloader)):
    _target_items = batch_input["target_items"].cpu().detach().numpy()
    _context_items = batch_input["context_items"].cpu().detach().numpy()
    _labels = batch_input["labels"].cpu().detach().numpy()

    target_items.extend(_target_items)
    context_items.extend(_context_items)
    labels.extend(_labels)

test_df = pd.DataFrame(
    {"target_items": target_items, "context_items": context_items, "labels": labels}
)
assert (
    test_df.groupby(["target_items", "context_items"])["labels"]
    .nunique()
    .loc[lambda s: s > 1]
    .shape[0]
    == 0
), "Conflicting labels!"

0it [00:00, ?it/s]

# Load data

In [14]:
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")
idm = IDMapper().load("../data/idm.json")

In [15]:
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence
0,AEN7JFLQCURF54WR5OHY7HOWWMSQ,B08FC5TTBF,5.0,1628644724721,9638,3354,Video Games,Demon's Souls - PlayStation 5,[From Bluepoint Games comes a remake of the Pl...,"[Video Games, PlayStation 5, Games]",29.99,"[-1, -1, -1, -1, 562, 4378, 4194, 1334, 3169, ..."
1,AELH2ZF5QSSIFBF6WXAZLCF7JIWA,B0C6DH316S,2.0,1628653733506,20213,2064,Computers,Logitech G PRO X Wireless Lightspeed Gaming He...,[],"[Video Games, PC, Accessories, Headsets]",253.82,"[-1, -1, -1, -1, 1172, 2202, 2609, 3980, 3547,..."
2,AGD4QHNPSC45XTUPSUE6TYQOF3WQ,B0BN5DC36N,5.0,1628679010802,9107,2562,Computers,Seagate Horizon Forbidden West Limited Edition...,[Discover new worlds with the officially-licen...,"[Video Games, Legacy Systems, PlayStation Syst...",89.99,"[2812, 2387, 4376, 1890, 403, 737, 2959, 790, ..."
3,AFMOSTKHH2HFLI35E3YMI7GLYDCQ,B07KRWJCQW,5.0,1628687441776,6689,1845,Video Games,$40 Xbox Gift Card [Digital Code],[Buy an Xbox Gift Card for yourself or a frien...,"[Video Games, Online Game Services, Xbox Live,...",40.0,"[-1, -1, 2002, 1749, 3878, 4132, 4528, 86, 418..."
4,AGK34QNFABMBLRESDKG2VRC3VIIQ,B0BL65X86R,5.0,1628702768435,9987,781,Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0,"[1454, 3730, 3509, 3333, 3118, 3298, 2008, 689..."
...,...,...,...,...,...,...,...,...,...,...,...,...
944,AEKYV77UMZZGHT4PZIETDQ6ELJBQ,B08F4C6HCD,5.0,1657816667680,5258,337,Video Games,Legend of Zelda Link's Awakening - Nintendo Sw...,"[“Castaway, you should know the truth!” As Lin...","[Video Games, Nintendo Switch, Games]",59.88,"[1164, 246, 2145, 4214, 3056, 2672, 1158, 3609..."
945,AGUFCRCH7HOUQ5FQYSJETEEFAYOA,B00DBDPOZ4,5.0,1657855227062,17774,3394,Video Games,Xbox One Play and Charge Kit,[Keep the action going with the Xbox One Play ...,"[Video Games, Xbox One, Accessories]",34.99,"[-1, -1, -1, -1, -1, 2273, 4184, 3370, 3585, 3..."
946,AHJUZFMUESAEQBPC2QQMBDVUBYFQ,B0B1PB5L93,4.0,1657883331431,4200,3136,Computers,Razer Viper Ultimate Lightweight Wireless Gami...,[Forget about average and claim the unfair adv...,"[Video Games, PC, Accessories, Gaming Mice]",89.99,"[-1, -1, -1, 3163, 3226, 1292, 833, 4476, 432,..."
947,AE5UUBPDQX4MRFFDW7D3IKHQYIEQ,B00ZJBSBD8,5.0,1657945454164,2512,2946,Video Games,Trackmania Turbo-Nla,[Step into the wild car fantasy world of Track...,"[Video Games, PlayStation 4, Games]",13.68,"[136, 991, 1692, 3380, 762, 3777, 3897, 3240, ..."


In [16]:
def get_sequence(df, user_col=args.user_col, item_col=args.item_col):
    return (
        df.groupby(user_col)[item_col]
        .agg(list)
        .loc[lambda s: s.apply(len) > 1]  # Remove sequence with only one item
    ).values.tolist()

In [17]:
item_sequence = train_df.pipe(get_sequence)
len(item_sequence)

20366

In [18]:
val_item_sequence = val_df.pipe(get_sequence)
len(val_item_sequence)

147

## Persist

In [19]:
sequences_fp = "../data/item_sequence.jsonl"
val_sequences_fp = "../data/val_item_sequence.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in item_sequence:
        f.write(json.dumps(sequence) + "\n")
with open(val_sequences_fp, "w") as f:
    for sequence in val_item_sequence:
        f.write(json.dumps(sequence) + "\n")

logger.info(f"{len(item_sequence)=:,.0f} {len(val_item_sequence)=:,.0f}")

[32m2024-09-30 09:43:34.958[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mlen(item_sequence)=20,366 len(val_item_sequence)=147[0m


## Persist a small data for overfitting

In [20]:
num_sequences = 2
batch_item_sequence = item_sequence[:num_sequences]
batch_sequences_fp = "../data/batch_item_sequence.jsonl"

with open(batch_sequences_fp, "w") as f:
    for sequence in batch_item_sequence:
        f.write(json.dumps(sequence) + "\n")

# Run with all data

In [21]:
# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(
    sequences_fp,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)

for inp in dataset:
    target_items = inp["target_items"]
    context_items = inp["context_items"]
    labels = inp["labels"]
    print(target_items, context_items, labels)
    break

[32m2024-09-30 09:43:34.978[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([4622, 4622, 4622, 4622, 4622, 4622]) tensor([3033, 4279, 4270, 2567, 2112,  346]) tensor([1., 0., 0., 0., 0., 0.])


In [22]:
batch_size = len(item_sequence[0])  # for easier testing

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [23]:
# Test index mapping matches input id_mapper
target_items_idx_dataloader = set()
for batch_input in dataloader:
    target_items_idx_dataloader.update(batch_input["target_items"].detach().numpy())
    break

targets_items_idx_item_sequence = set()
for item_id in item_sequence[0]:
    idx = idm.item_to_index[item_id]
    targets_items_idx_item_sequence.add(idx)

assert target_items_idx_dataloader == targets_items_idx_item_sequence

In [24]:
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
)

[32m2024-09-30 09:43:35.551[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [25]:
for batch_input in val_dataloader:
    print(batch_input)
    break

{'target_items': tensor([ 918,  918,  918,  918,  918,  918, 1188, 1188, 1188, 1188, 1188, 1188,
         875,  875,  875,  875,  875,  875,  287,  287,  287,  287,  287,  287,
         287,  287,  287,  287,  287,  287, 1266, 1266, 1266, 1266, 1266, 1266,
        1266, 1266, 1266, 1266, 1266, 1266,  916,  916,  916,  916,  916,  916,
         916,  916,  916,  916,  916,  916, 3985, 3985, 3985, 3985, 3985, 3985,
        3985, 3985, 3985, 3985, 3985, 3985, 1845, 1845, 1845, 1845, 1845, 1845,
        1845, 1845, 1845, 1845, 1845, 1845,  156,  156,  156,  156,  156,  156,
        4048, 4048, 4048, 4048, 4048, 4048, 4607, 4607, 4607, 4607, 4607, 4607,
        2302, 2302, 2302, 2302, 2302, 2302, 2895, 2895, 2895, 2895, 2895, 2895,
        2895, 2895, 2895, 2895, 2895, 2895, 4035, 4035, 4035, 4035, 4035, 4035,
        4035, 4035, 4035, 4035, 4035, 4035, 2805, 2805, 2805, 2805, 2805, 2805,
        2805, 2805, 2805, 2805, 2805, 2805, 4374, 4374, 4374, 4374, 4374, 4374,
        4374, 4374, 437