# Build sequence data loaders for Skip Gram

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import string
import sys

import pandas as pd
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

sys.path.insert(0, "..")

from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset

# Controller

In [3]:
class Args(BaseModel):
    num_negative_samples: int = 5
    window_size: int = 1
    batch_size: int = 16

    user_col: str = "user_id"
    item_col: str = "parent_asin"


args = Args()
print(args.model_dump_json(indent=2))

{
  "num_negative_samples": 5,
  "window_size": 1,
  "batch_size": 16,
  "user_col": "user_id",
  "item_col": "parent_asin"
}


# Test implementation

In [4]:
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "b", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

val_sequences = [["f", "l", "m"], ["i", "h"], ["j", "e", "a"]]

sequences_fp = "../data/sequences.jsonl"
val_sequences_fp = "../data/val_sequences.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in sequences:
        f.write(json.dumps(sequence) + "\n")
with open(val_sequences_fp, "w") as f:
    for sequence in val_sequences:
        f.write(json.dumps(sequence) + "\n")

In [5]:
# Simulate pre-configured id_to_idx mapper
id_to_idx = {
    id_: idx for id_, idx in zip(list(string.ascii_letters[:13]), list(range(13)))
}
id_to_idx["a"] = 1
id_to_idx["b"] = 0

# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(
    sequences_fp, window_size=1, negative_samples=2, id_to_idx=id_to_idx
)
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=1,
    negative_samples=2,
    id_to_idx=id_to_idx,
)

# Example of getting an item
for inp in dataset:
    target_items = inp["target_items"]
    context_items = inp["context_items"]
    labels = inp["labels"]
    print(target_items, context_items, labels)
    break

[32m2024-10-10 11:18:42.863[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-10-10 11:18:42.881[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([0, 0, 0]) tensor([2, 7, 9]) tensor([1., 0., 0.])


In [6]:
dataset.id_to_idx

{'a': 1,
 'b': 0,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12}

In [7]:
dataset.sampling_probs

array([0.13537454, 0.05938764, 0.13537454, 0.05938764, 0.05938764,
       0.05938764, 0.05938764, 0.05938764, 0.05938764, 0.05938764,
       0.13537454, 0.05938764, 0.05938764])

In [8]:
dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 10, 11, 12},
             11: {6, 7, 10, 11, 12},
             12: {6, 7, 10, 11, 12},
             8: {2, 8, 9},
             9: {2, 8, 9}})

In [9]:
for i in zip(target_items, context_items):
    print(i)

(tensor(0), tensor(2))
(tensor(0), tensor(7))
(tensor(0), tensor(9))


In [10]:
val_dataset.sampling_probs

array([0.10225277, 0.07544086, 0.10225277, 0.0448574 , 0.07544086,
       0.07544086, 0.0448574 , 0.07544086, 0.07544086, 0.07544086,
       0.10225277, 0.07544086, 0.07544086])

In [11]:
val_dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4, 9},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4, 9},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10, 11, 12},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 8, 10, 11, 12},
             11: {5, 6, 7, 10, 11, 12},
             12: {5, 6, 7, 10, 11, 12},
             8: {2, 7, 8, 9},
             9: {1, 2, 4, 8, 9}})

## Test no conflicting labels

In [12]:
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [13]:
target_items = []
context_items = []
labels = []

for i, batch_input in tqdm(enumerate(dataloader)):
    _target_items = batch_input["target_items"].cpu().detach().numpy()
    _context_items = batch_input["context_items"].cpu().detach().numpy()
    _labels = batch_input["labels"].cpu().detach().numpy()

    target_items.extend(_target_items)
    context_items.extend(_context_items)
    labels.extend(_labels)

test_df = pd.DataFrame(
    {"target_items": target_items, "context_items": context_items, "labels": labels}
)
assert (
    test_df.groupby(["target_items", "context_items"])["labels"]
    .nunique()
    .loc[lambda s: s > 1]
    .shape[0]
    == 0
), "Conflicting labels!"

0it [00:00, ?it/s]

# Load data

In [14]:
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")
idm = IDMapper().load("../data/idm.json")

In [15]:
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence
0,AGD4QHNPSC45XTUPSUE6TYQOF3WQ,B0BN5DC36N,5.0,1628679010802,8301,2189,Computers,Seagate Horizon Forbidden West Limited Edition...,[Discover new worlds with the officially-licen...,"[Video Games, Legacy Systems, PlayStation Syst...",89.99,"[1855, 4190, 2358, 4136, 2954, 802, 2776, 2147..."
1,AGK34QNFABMBLRESDKG2VRC3VIIQ,B0BL65X86R,5.0,1628702768435,18134,98,Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0,"[3039, 3175, 1599, 1609, 1490, 2307, 1888, 323..."
2,AHRIM42WD5BZX6T37TUR3Z73DUXQ,B07VHHPJBV,4.0,1628711687953,1912,1979,Computers,"Redragon M602 Griffin RGB Gaming Mouse, RGB Sp...",[Redragon M602 (White) GRIFFIN High-Precision ...,"[Video Games, PC, Accessories, Gaming Mice]",19.99,"[-1, -1, 3095, 1325, 936, 380, 3628, 270, 3487..."
3,AFZDMJ2J5XHVZDKUWFKF346LA7WA,B07R6YZQB8,3.0,1628718971833,12732,3113,All Electronics,Glistco Simple Feet- Horizontal Stand/Feet Com...,[],"[Video Games, PlayStation 4]",12.6,"[-1, -1, -1, -1, 624, 2620, 2252, 3122, 1737, ..."
4,AE5OKUZ7Y5RVG6WUBTHJPECSU6ZA,B094Y1G4F4,4.0,1628744762627,14142,4404,Video Games,"PowerA Enhanced - Seafoam Fade, Gamepad, Wired...",[This PowerA Enhanced Wired Controller is offi...,"[Video Games, Xbox One, Accessories, Controlle...",37.99,"[1680, 3356, 1142, 4549, 1393, 2937, 1839, 393..."
...,...,...,...,...,...,...,...,...,...,...,...,...
957,AHJUZFMUESAEQBPC2QQMBDVUBYFQ,B0B1PB5L93,4.0,1657883331431,6294,3694,Computers,Razer Viper Ultimate Lightweight Wireless Gami...,[Forget about average and claim the unfair adv...,"[Video Games, PC, Accessories, Gaming Mice]",89.99,"[-1, -1, -1, 4336, 4609, 3873, 2245, 2869, 320..."
958,AFIXV7CY3OC6WI5DXCS3JAGP5SQA,B0C37RBK2R,5.0,1657887021161,9028,4103,Video Games,Xbox Series S,"[Introducing the Xbox Series S, the smallest, ...",[],279.0,"[639, 1294, 627, 3644, 4468, 3016, 2610, 2932,..."
959,AFDL3ZQE4ARYEEBBH2KAPMP4NSHQ,B0795GHTBC,5.0,1657910674213,8973,3116,All Electronics,ivoler [3 Pack Screen Protector Tempered Glass...,[],"[Video Games, Nintendo Switch, Accessories, Fa...",9.39,"[-1, -1, -1, 1259, 1366, 2240, 3784, 1073, 570..."
960,AEE72HLCWIZT2GKD7UZRXN36T27A,B0CB8LZT7K,5.0,1657928730786,18414,424,Video Games,Daydayup Switch Carrying Case Compatible with ...,[],"[Video Games, Legacy Systems, Nintendo Systems...",21.99,"[-1, 2370, 1236, 2881, 4266, 3449, 131, 2739, ..."


In [16]:
def get_sequence(df, user_col=args.user_col, item_col=args.item_col):
    return (
        df.groupby(user_col)[item_col]
        .agg(list)
        .loc[lambda s: s.apply(len) > 1]  # Remove sequence with only one item
    ).values.tolist()

In [17]:
item_sequence = train_df.pipe(get_sequence)
len(item_sequence)

19578

In [18]:
val_item_sequence = val_df.pipe(get_sequence)
len(val_item_sequence)

157

## Persist

In [19]:
sequences_fp = "../data/item_sequence.jsonl"
val_sequences_fp = "../data/val_item_sequence.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in item_sequence:
        f.write(json.dumps(sequence) + "\n")
with open(val_sequences_fp, "w") as f:
    for sequence in val_item_sequence:
        f.write(json.dumps(sequence) + "\n")

logger.info(f"{len(item_sequence)=:,.0f} {len(val_item_sequence)=:,.0f}")

[32m2024-10-10 11:18:44.002[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mlen(item_sequence)=19,578 len(val_item_sequence)=157[0m


## Persist a small data for overfitting

In [20]:
num_sequences = 2
batch_item_sequence = item_sequence[:num_sequences]
batch_sequences_fp = "../data/batch_item_sequence.jsonl"

with open(batch_sequences_fp, "w") as f:
    for sequence in batch_item_sequence:
        f.write(json.dumps(sequence) + "\n")

# Run with all data

In [21]:
# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(
    sequences_fp,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)

for inp in dataset:
    target_items = inp["target_items"]
    context_items = inp["context_items"]
    labels = inp["labels"]
    print(target_items, context_items, labels)
    break

[32m2024-10-10 11:18:44.025[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([1712, 1712, 1712, 1712, 1712, 1712]) tensor([4219, 4424, 1022, 4557, 3810, 3951]) tensor([1., 0., 0., 0., 0., 0.])


In [22]:
batch_size = len(item_sequence[0])  # for easier testing

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [23]:
# Test index mapping matches input id_mapper
target_items_idx_dataloader = set()
for batch_input in dataloader:
    target_items_idx_dataloader.update(batch_input["target_items"].detach().numpy())
    break

targets_items_idx_item_sequence = set()
for item_id in item_sequence[0]:
    idx = idm.item_to_index[item_id]
    targets_items_idx_item_sequence.add(idx)

assert target_items_idx_dataloader == targets_items_idx_item_sequence

In [24]:
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
)

[32m2024-10-10 11:18:44.661[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [25]:
for batch_input in val_dataloader:
    print(batch_input)
    break

{'target_items': tensor([1462, 1462, 1462, 1462, 1462, 1462, 4525, 4525, 4525, 4525, 4525, 4525,
        1010, 1010, 1010, 1010, 1010, 1010, 1137, 1137, 1137, 1137, 1137, 1137,
        1137, 1137, 1137, 1137, 1137, 1137, 4422, 4422, 4422, 4422, 4422, 4422,
        4422, 4422, 4422, 4422, 4422, 4422, 1910, 1910, 1910, 1910, 1910, 1910,
        1910, 1910, 1910, 1910, 1910, 1910, 2243, 2243, 2243, 2243, 2243, 2243,
        2243, 2243, 2243, 2243, 2243, 2243, 1762, 1762, 1762, 1762, 1762, 1762,
        3840, 3840, 3840, 3840, 3840, 3840, 1526, 1526, 1526, 1526, 1526, 1526,
        1526, 1526, 1526, 1526, 1526, 1526,  917,  917,  917,  917,  917,  917,
        3248, 3248, 3248, 3248, 3248, 3248, 2348, 2348, 2348, 2348, 2348, 2348,
        2843, 2843, 2843, 2843, 2843, 2843, 2023, 2023, 2023, 2023, 2023, 2023,
        1110, 1110, 1110, 1110, 1110, 1110]), 'context_items': tensor([4525, 1804, 3125, 1189, 2529,  103, 1462,  100, 4390, 1324, 3347, 3203,
        1137, 1090, 2446, 2723, 3484, 16