# Build sequence data loaders for Skip Gram

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import sys
import string

import pandas as pd
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

sys.path.insert(0, "..")
from src.skipgram.dataset import SkipGramDataset
from src.id_mapper import IDMapper

# Controller

In [3]:
class Args(BaseModel):
    num_negative_samples: int = 5
    window_size: int = 1
    batch_size: int = 16

    user_col: str = "user_id"
    item_col: str = "parent_asin"


args = Args()
print(args.model_dump_json(indent=2))

{
  "num_negative_samples": 5,
  "window_size": 1,
  "batch_size": 16,
  "user_col": "user_id",
  "item_col": "parent_asin"
}


# Test implementation

In [4]:
# Example sequences of item IDs
# sequences = [
#     [1, 2, 3, 4, 0],
#     [5, 1, 1, 1, 10],
#     [6, 12, 10, 11, 7],
#     [1, 2, 10],
#     [9, 8, 2],
# ]

# val_sequences = [
#     [5, 11, 12],
#     [8, 7],
#     [9, 4, 0]
# ]

sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "b", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

val_sequences = [["f", "l", "m"], ["i", "h"], ["j", "e", "a"]]

# Simulate pre-configured id_to_idx mapper 
id_to_idx = {id_: idx for id_, idx in zip(list(string.ascii_letters[:13]), list(range(13)))}
id_to_idx['a'] = 1
id_to_idx['b'] = 0

# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(sequences, window_size=1, negative_samples=2, id_to_idx=id_to_idx)
val_dataset = SkipGramDataset(
    val_sequences,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=1,
    negative_samples=2,
    id_to_idx=id_to_idx
)

# Example of getting an item
inp = dataset[0]
target_items = inp["target_items"]
context_items = inp["context_items"]
labels = inp["labels"]
print(target_items, context_items, labels)

[32m2024-10-18 18:09:49.164[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m59[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/5 [00:00<?, ?it/s]

[32m2024-10-18 18:09:49.180[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m59[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/3 [00:00<?, ?it/s]

tensor([0, 0, 0]) tensor([2, 6, 7]) tensor([1., 0., 0.])


In [5]:
dataset.id_to_idx

{'a': 1,
 'b': 0,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12}

In [6]:
dataset.sampling_probs

array([0.13537454, 0.05938764, 0.13537454, 0.05938764, 0.05938764,
       0.05938764, 0.05938764, 0.05938764, 0.05938764, 0.05938764,
       0.13537454, 0.05938764, 0.05938764])

In [7]:
dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 10, 11, 12},
             11: {6, 7, 10, 11, 12},
             12: {6, 7, 10, 11, 12},
             8: {2, 8, 9},
             9: {2, 8, 9}})

In [8]:
for i in zip(target_items, context_items):
    print(i)

(tensor(0), tensor(2))
(tensor(0), tensor(6))
(tensor(0), tensor(7))


In [9]:
val_dataset.sampling_probs

array([0.10225277, 0.07544086, 0.10225277, 0.0448574 , 0.07544086,
       0.07544086, 0.0448574 , 0.07544086, 0.07544086, 0.07544086,
       0.10225277, 0.07544086, 0.07544086])

In [10]:
val_dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4, 9},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4, 9},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10, 11, 12},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 8, 10, 11, 12},
             11: {5, 6, 7, 10, 11, 12},
             12: {5, 6, 7, 10, 11, 12},
             8: {2, 7, 8, 9},
             9: {1, 2, 4, 8, 9}})

## Test no conflicting labels

In [11]:
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [12]:
target_items = []
context_items = []
labels = []

for i, batch_input in tqdm(enumerate(dataloader), total=len(dataloader)):
    _target_items = batch_input["target_items"].cpu().detach().numpy()
    _context_items = batch_input["context_items"].cpu().detach().numpy()
    _labels = batch_input["labels"].cpu().detach().numpy()

    target_items.extend(_target_items)
    context_items.extend(_context_items)
    labels.extend(_labels)

test_df = pd.DataFrame(
    {"target_items": target_items, "context_items": context_items, "labels": labels}
)
assert (
    test_df.groupby(["target_items", "context_items"])["labels"]
    .nunique()
    .loc[lambda s: s > 1]
    .shape[0]
    == 0
), "Conflicting labels!"

  0%|          | 0/1 [00:00<?, ?it/s]

# Load data

In [13]:
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")
idm = IDMapper().load("../data/idm.json")

In [14]:
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence
0,AEN7JFLQCURF54WR5OHY7HOWWMSQ,B08FC5TTBF,5.0,1628644724721,17723,493,Video Games,Demon's Souls - PlayStation 5,[From Bluepoint Games comes a remake of the Pl...,"[Video Games, PlayStation 5, Games]",29.99,"[-1, -1, -1, -1, 1033, 2858, 2715, 2154, 4157,..."
1,AELH2ZF5QSSIFBF6WXAZLCF7JIWA,B0C6DH316S,2.0,1628653733506,9366,2158,Computers,Logitech G PRO X Wireless Lightspeed Gaming He...,[],"[Video Games, PC, Accessories, Headsets]",253.82,"[-1, -1, -1, -1, 962, 485, 3662, 1772, 1450, 877]"
2,AGD4QHNPSC45XTUPSUE6TYQOF3WQ,B0BN5DC36N,5.0,1628679010802,10050,1348,Computers,Seagate Horizon Forbidden West Limited Edition...,[Discover new worlds with the officially-licen...,"[Video Games, Legacy Systems, PlayStation Syst...",89.99,"[1798, 3008, 3372, 2187, 612, 2198, 3363, 4036..."
3,AFMOSTKHH2HFLI35E3YMI7GLYDCQ,B07KRWJCQW,5.0,1628687441776,19905,2867,Video Games,$40 Xbox Gift Card [Digital Code],[Buy an Xbox Gift Card for yourself or a frien...,"[Video Games, Online Game Services, Xbox Live,...",40.0,"[-1, -1, 3765, 2762, 3671, 3502, 1888, 3530, 2..."
4,AGK34QNFABMBLRESDKG2VRC3VIIQ,B0BL65X86R,5.0,1628702768435,8627,2698,Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0,"[2488, 3199, 2018, 2368, 2930, 3410, 4532, 232..."
...,...,...,...,...,...,...,...,...,...,...,...,...
944,AEKYV77UMZZGHT4PZIETDQ6ELJBQ,B08F4C6HCD,5.0,1657816667680,5408,2929,Video Games,Legend of Zelda Link's Awakening - Nintendo Sw...,"[“Castaway, you should know the truth!” As Lin...","[Video Games, Nintendo Switch, Games]",59.88,"[2108, 4347, 1392, 2244, 2194, 1986, 44, 3931,..."
945,AGUFCRCH7HOUQ5FQYSJETEEFAYOA,B00DBDPOZ4,5.0,1657855227062,5315,2713,Video Games,Xbox One Play and Charge Kit,[Keep the action going with the Xbox One Play ...,"[Video Games, Xbox One, Accessories]",34.99,"[-1, -1, -1, -1, -1, 1917, 2652, 3712, 4346, 3..."
946,AHJUZFMUESAEQBPC2QQMBDVUBYFQ,B0B1PB5L93,4.0,1657883331431,15104,2067,Computers,Razer Viper Ultimate Lightweight Wireless Gami...,[Forget about average and claim the unfair adv...,"[Video Games, PC, Accessories, Gaming Mice]",89.99,"[-1, -1, -1, 4682, 9, 155, 1559, 3146, 3065, 770]"
947,AE5UUBPDQX4MRFFDW7D3IKHQYIEQ,B00ZJBSBD8,5.0,1657945454164,16935,2936,Video Games,Trackmania Turbo-Nla,[Step into the wild car fantasy world of Track...,"[Video Games, PlayStation 4, Games]",13.68,"[18, 2498, 2002, 1603, 3391, 1963, 1763, 525, ..."


In [15]:
def get_sequence(df, user_col=args.user_col, item_col=args.item_col):
    return (
        df.groupby(user_col)[item_col]
        .agg(list)
        .loc[lambda s: s.apply(len) > 1]  # Remove sequence with only one item
    ).values.tolist()

In [16]:
item_sequence = train_df.pipe(get_sequence)
len(item_sequence)

20366

In [17]:
val_item_sequence = val_df.pipe(get_sequence)
len(val_item_sequence)

147

## Persist

In [18]:
with open("../data/item_sequence.json", "w") as f:
    json.dump(item_sequence, f)
with open("../data/val_item_sequence.json", "w") as f:
    json.dump(val_item_sequence, f)

with open("../data/item_sequence.json", "r") as f:
    item_sequence = json.load(f)
with open("../data/val_item_sequence.json", "r") as f:
    val_item_sequence = json.load(f)

logger.info(f"{len(item_sequence)=:,.0f} {len(val_item_sequence)=:,.0f}")

[32m2024-10-18 18:09:50.118[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mlen(item_sequence)=20,366 len(val_item_sequence)=147[0m


# Run with all data

In [19]:
# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(
    item_sequence,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index
)

inp = dataset[0]
target_items = inp["target_items"]
context_items = inp["context_items"]
labels = inp["labels"]
print(target_items, context_items, labels)

[32m2024-10-18 18:09:50.128[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m59[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/20366 [00:00<?, ?it/s]

tensor([4529, 4529, 4529, 4529, 4529, 4529]) tensor([4054, 2873,  673, 4411, 3857, 1164]) tensor([1., 0., 0., 0., 0., 0.])


In [20]:
batch_size = len(item_sequence[0])  # for easier testing

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [21]:
# Test index mapping matches input id_mapper
target_items_idx_dataloader = set()
for batch_input in dataloader:
    target_items_idx_dataloader.update(batch_input['target_items'].detach().numpy())
    break

targets_items_idx_item_sequence = set()
for item_id in item_sequence[0]:
    idx = idm.item_to_index[item_id]
    targets_items_idx_item_sequence.add(idx)

assert target_items_idx_dataloader == targets_items_idx_item_sequence

In [22]:
val_dataset = SkipGramDataset(
    val_item_sequence,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
)

[32m2024-10-18 18:09:50.707[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m59[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/147 [00:00<?, ?it/s]

In [23]:
for batch_input in val_dataloader:
    print(batch_input)
    break

{'target_items': tensor([ 564,  564,  564,  564,  564,  564, 1531, 1531, 1531, 1531, 1531, 1531,
        3345, 3345, 3345, 3345, 3345, 3345, 1222, 1222, 1222, 1222, 1222, 1222,
        1222, 1222, 1222, 1222, 1222, 1222, 4483, 4483, 4483, 4483, 4483, 4483,
        4483, 4483, 4483, 4483, 4483, 4483, 4062, 4062, 4062, 4062, 4062, 4062,
        4062, 4062, 4062, 4062, 4062, 4062, 1062, 1062, 1062, 1062, 1062, 1062,
        1062, 1062, 1062, 1062, 1062, 1062, 2867, 2867, 2867, 2867, 2867, 2867,
        2867, 2867, 2867, 2867, 2867, 2867,   21,   21,   21,   21,   21,   21,
        3124, 3124, 3124, 3124, 3124, 3124,  358,  358,  358,  358,  358,  358,
        3843, 3843, 3843, 3843, 3843, 3843, 2184, 2184, 2184, 2184, 2184, 2184,
        2184, 2184, 2184, 2184, 2184, 2184, 4646, 4646, 4646, 4646, 4646, 4646,
        4646, 4646, 4646, 4646, 4646, 4646, 4619, 4619, 4619, 4619, 4619, 4619,
        4619, 4619, 4619, 4619, 4619, 4619,  157,  157,  157,  157,  157,  157,
         157,  157,  15

In [24]:
val_dataset[0]

{'target_items': tensor([564, 564, 564, 564, 564, 564]),
 'context_items': tensor([1531, 1760,  768, 2050, 3408, 1500]),
 'labels': tensor([1., 0., 0., 0., 0., 0.])}