# Build sequence data loaders for Skip Gram

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import string
import sys

import pandas as pd
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

sys.path.insert(0, "..")
from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset

# Controller

In [3]:
class Args(BaseModel):
    # Notebook-level settings, gathered in one model so they can be echoed as JSON below.

    # Passed to SkipGramDataset as `negative_samples` in the full-data run.
    num_negative_samples: int = 5
    # Passed to SkipGramDataset as `window_size` in the full-data run.
    window_size: int = 1
    # Number of dataset samples per DataLoader batch.
    batch_size: int = 16

    # Column names used by get_sequence to group interactions into per-user item lists.
    user_col: str = "user_id"
    item_col: str = "parent_asin"


# Instantiate with the defaults and print the effective configuration.
args = Args()
print(args.model_dump_json(indent=2))

{
  "num_negative_samples": 5,
  "window_size": 1,
  "batch_size": 16,
  "user_col": "user_id",
  "item_col": "parent_asin"
}


# Test implementation

In [4]:
# Toy fixtures: five training and three validation sequences of single-letter
# item ids, written as JSON Lines (one sequence per line).
sequences_fp = "sequences.jsonl"
val_sequences_fp = "val_sequences.jsonl"

sequences = [
    list("bcdea"),
    list("fbbbk"),
    list("gmklh"),
    list("bck"),
    list("jic"),
]
val_sequences = [list("flm"), list("ih"), list("jea")]

with open(sequences_fp, "w") as f:
    f.writelines(json.dumps(sequence) + "\n" for sequence in sequences)
with open(val_sequences_fp, "w") as f:
    f.writelines(json.dumps(sequence) + "\n" for sequence in val_sequences)

In [5]:
# Simulate pre-configured id_to_idx mapper: the letters "a".."m" are numbered
# 0..12 in alphabetical order ...
id_to_idx = {
    letter: position for position, letter in enumerate(string.ascii_letters[:13])
}
# ... then "a" and "b" trade indices, so the mapping differs from plain
# alphabetical numbering.
id_to_idx["a"], id_to_idx["b"] = 1, 0

# Create dataset with frequency-based negative sampling.
# Both toy datasets share the same hyper-parameters and id mapping.
toy_dataset_kwargs = dict(window_size=1, negative_samples=2, id_to_idx=id_to_idx)

dataset = SkipGramDataset(sequences_fp, **toy_dataset_kwargs)
# The validation dataset is additionally handed the training dataset's
# interaction sets and item frequencies.
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    **toy_dataset_kwargs,
)

# Example of getting an item: fetch only the first sample and show its tensors.
inp = next(iter(dataset), None)
if inp is not None:
    target_items, context_items, labels = (
        inp[key] for key in ("target_items", "context_items", "labels")
    )
    print(target_items, context_items, labels)

[32m2024-09-29 16:59:49.930[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m57[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-09-29 16:59:49.951[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m57[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([0, 0, 0]) tensor([ 2,  7, 11]) tensor([1., 0., 0.])


In [6]:
# Inspect the item-id -> index mapping held by the dataset.
dataset.id_to_idx

{'a': 1,
 'b': 0,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12}

In [7]:
# Inspect the per-index sampling probabilities of the training dataset
# (presumably the distribution negatives are drawn from — confirm in SkipGramDataset).
dataset.sampling_probs

array([0.13537454, 0.05938764, 0.13537454, 0.05938764, 0.05938764,
       0.05938764, 0.05938764, 0.05938764, 0.05938764, 0.05938764,
       0.13537454, 0.05938764, 0.05938764])

In [8]:
# Inspect the interaction sets of the training dataset: item index -> set of item indices.
dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 10, 11, 12},
             11: {6, 7, 10, 11, 12},
             12: {6, 7, 10, 11, 12},
             8: {2, 8, 9},
             9: {2, 8, 9}})

In [9]:
# Show each (target, context) index pair of the sample fetched earlier.
for target_context_pair in zip(target_items, context_items):
    print(target_context_pair)

(tensor(0), tensor(2))
(tensor(0), tensor(7))
(tensor(0), tensor(11))


In [10]:
# Inspect the validation dataset's sampling probabilities; the recorded values
# differ from the training dataset's.
val_dataset.sampling_probs

array([0.10225277, 0.07544086, 0.10225277, 0.0448574 , 0.07544086,
       0.07544086, 0.0448574 , 0.07544086, 0.07544086, 0.07544086,
       0.10225277, 0.07544086, 0.07544086])

In [11]:
# Inspect the validation dataset's interaction sets; in the recorded output they
# contain the training entries plus indices that only co-occur in validation sequences.
val_dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 10},
             1: {0, 1, 2, 3, 4, 9},
             2: {0, 1, 2, 3, 4, 8, 9, 10},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4, 9},
             10: {0, 2, 5, 6, 7, 10, 11, 12},
             5: {0, 5, 10, 11, 12},
             6: {6, 7, 10, 11, 12},
             7: {6, 7, 8, 10, 11, 12},
             11: {5, 6, 7, 10, 11, 12},
             12: {5, 6, 7, 10, 11, 12},
             8: {2, 7, 8, 9},
             9: {1, 2, 4, 8, 9}})

## Test no conflicting labels

In [12]:
# Batch the toy dataset for the label-consistency check that follows.
# drop_last=False: this loader feeds a test, so the final, smaller batch must be
# checked too. With drop_last=True any samples beyond the last full batch were
# silently skipped, and a dataset smaller than one batch yielded nothing at all.
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    drop_last=False,
    collate_fn=dataset.collate_fn,
)

In [13]:
# Gather every (target, context, label) triple the dataloader emits, then verify
# that no (target, context) pair is labelled both positive and negative.
target_items = []
context_items = []
labels = []

for batch_input in tqdm(dataloader, total=len(dataloader)):
    target_items.extend(batch_input["target_items"].cpu().detach().numpy())
    context_items.extend(batch_input["context_items"].cpu().detach().numpy())
    labels.extend(batch_input["labels"].cpu().detach().numpy())

test_df = pd.DataFrame(
    {"target_items": target_items, "context_items": context_items, "labels": labels}
)

# Without this guard the check below passes vacuously when the dataloader
# yields no batch (e.g. fewer samples than one batch with drop_last=True).
assert not test_df.empty, "Dataloader produced no batches; nothing was checked!"

# A pair is conflicting when it appears with more than one distinct label.
labels_per_pair = test_df.groupby(["target_items", "context_items"])["labels"].nunique()
conflicting_pairs = labels_per_pair.loc[labels_per_pair > 1]
assert conflicting_pairs.shape[0] == 0, "Conflicting labels!"

  0%|          | 0/1 [00:00<?, ?it/s]

# Load data

In [14]:
# Load the train/validation feature tables and the persisted id mapper.
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")
# `load` must return the mapper itself: `idm.item_to_index` is read further down.
idm = IDMapper().load("../data/idm.json")

In [15]:
# Display the validation frame (pandas truncates the rendering to the first and last rows).
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence
0,AEN7JFLQCURF54WR5OHY7HOWWMSQ,B08FC5TTBF,5.0,1628644724721,23927,8911,Video Games,Demon's Souls - PlayStation 5,[From Bluepoint Games comes a remake of the Pl...,"[Video Games, PlayStation 5, Games]",29.99,"[-1, -1, 8943, 10693, 7298, 9143, 2204, 3729, ..."
1,AFRYPSUWNLA3JHPZ2JFJ2UPHICAA,B094KW45JB,3.0,1628645945696,61868,5984,Video Games,Resident Evil Village Standard - Xbox [Digital...,[Experience survival horror like never before ...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",,"[564, 7914, 5508, 7864, 5269, 9195, 9820, 1000..."
2,AH6IDYC5I3UWRBLDFRFHYQUCDRWA,B08F52Y2PN,4.0,1628652575492,22181,724,All Electronics,"ZIUMIER Camo Gaming Headset for PS4, PS5, Xbox...",[],"[Video Games, PlayStation 4, Accessories, Head...",33.99,"[-1, -1, -1, -1, 6097, 9166, 3645, 8674, 2755,..."
3,AELH2ZF5QSSIFBF6WXAZLCF7JIWA,B0C6DH316S,2.0,1628653733506,32966,6990,Computers,Logitech G PRO X Wireless Lightspeed Gaming He...,[],"[Video Games, PC, Accessories, Headsets]",253.82,"[-1, -1, -1, 6699, 3162, 5730, 6422, 338, 8304..."
4,AHU5XXAXE4IFARJKTC4E266FDVAQ,B07PMFPQBC,5.0,1628656981268,5485,1821,Computers,"Suncala 256MB Memory Card for Playstation 2, H...",[],"[Video Games, Legacy Systems, PlayStation Syst...",9.98,"[-1, -1, -1, -1, 4844, 2327, 4718, 1414, 1643,..."
...,...,...,...,...,...,...,...,...,...,...,...,...
5243,AE5UUBPDQX4MRFFDW7D3IKHQYIEQ,B0B5SWS9ZW,5.0,1657945351808,67730,6287,Video Games,Gran Turismo Sport Hits - PlayStation 4,[The new standard in racing - introducing the ...,"[Video Games, PlayStation 4, Games]",16.99,"[6657, 7159, 1220, 2330, 3893, 10784, 3661, 10..."
5244,AE5UUBPDQX4MRFFDW7D3IKHQYIEQ,B00ZJBSBD8,5.0,1657945454164,67730,5393,Video Games,Trackmania Turbo-Nla,[Step into the wild car fantasy world of Track...,"[Video Games, PlayStation 4, Games]",13.68,"[7159, 1220, 2330, 3893, 10784, 3661, 10945, 4..."
5245,AF54JR3WONKVAUZUYDLOOPZN7NFQ,B0C37RBK2R,5.0,1657950291114,44112,2240,Video Games,Xbox Series S,"[Introducing the Xbox Series S, the smallest, ...",[],279.0,"[-1, -1, 6960, 8595, 8166, 11167, 1617, 4347, ..."
5246,AG3N3EMFIFGW66WOIC4MW55IUS6Q,B0BXQH38S6,4.0,1657965636367,53031,9169,Computers,Logitech G G703 6-Button Wireless Gaming Mouse...,[Logitech G703 Lightspeed Wireless Gaming Mous...,"[Video Games, PC, Accessories, Gaming Keyboards]",144.79,"[-1, -1, 2809, 6812, 2083, 5686, 363, 683, 873..."


In [16]:
def get_sequence(df: pd.DataFrame, user_col=None, item_col=None) -> list:
    """Collect each user's items into a list, keeping users with more than one item.

    Args:
        df: Interaction table with one row per (user, item) event.
        user_col: Name of the user column. Defaults to ``args.user_col``,
            looked up when the function is called.
        item_col: Name of the item column. Defaults to ``args.item_col``,
            looked up when the function is called.

    Returns:
        A list of item lists, one per user, ordered by sorted user key.
        Inside each list the items keep the row order of ``df`` (groupby
        preserves row order within a group), so ``df`` must already be in the
        desired (e.g. chronological) order. Users with a single item are dropped.
    """
    # Resolve the defaults at call time instead of definition time, so re-running
    # the config cell with other column names takes effect without having to
    # re-run this definition as well.
    if user_col is None:
        user_col = args.user_col
    if item_col is None:
        item_col = args.item_col

    items_per_user = df.groupby(user_col)[item_col].agg(list)
    # Remove sequence with only one item
    has_multiple_items = items_per_user.map(len) > 1
    return items_per_user.loc[has_multiple_items].tolist()

In [17]:
# Per-user item lists from the training frame; the cell shows how many there are.
item_sequence = get_sequence(train_df)
len(item_sequence)

68572

In [18]:
# Per-user item lists from the validation frame; the cell shows how many there are.
val_item_sequence = get_sequence(val_df)
len(val_item_sequence)

871

## Persist

In [19]:
# Persist the train/validation item sequences as JSON Lines (one list per line).
sequences_fp = "item_sequence.jsonl"
val_sequences_fp = "val_item_sequence.jsonl"

with open(sequences_fp, "w") as f:
    f.writelines(json.dumps(sequence) + "\n" for sequence in item_sequence)
with open(val_sequences_fp, "w") as f:
    f.writelines(json.dumps(sequence) + "\n" for sequence in val_item_sequence)

# Log how many sequences went into each file.
logger.info(f"{len(item_sequence)=:,.0f} {len(val_item_sequence)=:,.0f}")

[32m2024-09-29 16:59:55.029[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mlen(item_sequence)=68,572 len(val_item_sequence)=871[0m


## Persist a small dataset for overfitting

In [20]:
# Write only the first `num_sequences` training sequences to their own JSON Lines file.
num_sequences = 2
batch_sequences_fp = "batch_item_sequence.jsonl"
batch_item_sequence = item_sequence[:num_sequences]

with open(batch_sequences_fp, "w") as f:
    f.writelines(json.dumps(sequence) + "\n" for sequence in batch_item_sequence)

# Run with all data

In [21]:
# Create dataset with frequency-based negative sampling, this time from the
# persisted training sequences and the project's item-id mapping.
dataset = SkipGramDataset(
    sequences_fp,
    id_to_idx=idm.item_to_index,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
)

# Fetch only the first sample and show its tensors.
inp = next(iter(dataset), None)
if inp is not None:
    target_items, context_items, labels = (
        inp[key] for key in ("target_items", "context_items", "labels")
    )
    print(target_items, context_items, labels)

[32m2024-09-29 16:59:55.082[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m57[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

tensor([432, 432, 432, 432, 432, 432]) tensor([11177,  8558,  1042,  6995,  3396,  5669]) tensor([1., 0., 0., 0., 0., 0.])


In [22]:
# for easier testing: one batch spans as many samples as the first sequence has
# items, so the mapping check that follows only needs the first batch.
batch_size = len(item_sequence[0])

# shuffle=False is the DataLoader default; it is spelled out because the check
# on the first batch relies on the dataset order.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)

In [23]:
# Test index mapping matches input id_mapper:
# the target indices in the first batch ...
batch_input = next(iter(dataloader), None)
target_items_idx_dataloader = (
    set()
    if batch_input is None
    else set(batch_input["target_items"].detach().numpy())
)

# ... must equal the indices the id mapper assigns to the first sequence's items.
targets_items_idx_item_sequence = {
    idm.item_to_index[item_id] for item_id in item_sequence[0]
}

assert target_items_idx_dataloader == targets_items_idx_item_sequence

In [24]:
# Build the validation dataset from the persisted validation sequences, handing it
# the training dataset's interaction sets and item frequencies, with the same
# window / negative-sample settings and id mapping as the training dataset.
val_dataset = SkipGramDataset(
    val_sequences_fp,
    interacted=dataset.interacted,
    item_freq=dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
# NOTE(review): drop_last=True discards a final partial batch, so up to
# batch_size - 1 validation samples are never seen — confirm this is intended.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
)

[32m2024-09-29 17:00:00.442[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m57[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [25]:
# Print only the first collated validation batch.
batch_input = next(iter(val_dataloader), None)
if batch_input is not None:
    print(batch_input)

{'target_items': tensor([  49,   49,   49,   49,   49,   49, 6670, 6670, 6670, 6670, 6670, 6670,
        6670, 6670, 6670, 6670, 6670, 6670, 5718, 5718, 5718, 5718, 5718, 5718,
        5718, 5718, 5718, 5718, 5718, 5718,  970,  970,  970,  970,  970,  970,
         970,  970,  970,  970,  970,  970, 9766, 9766, 9766, 9766, 9766, 9766,
        9766, 9766, 9766, 9766, 9766, 9766, 1371, 1371, 1371, 1371, 1371, 1371,
        1853, 1853, 1853, 1853, 1853, 1853,  844,  844,  844,  844,  844,  844,
        5346, 5346, 5346, 5346, 5346, 5346, 8471, 8471, 8471, 8471, 8471, 8471,
        1108, 1108, 1108, 1108, 1108, 1108, 7425, 7425, 7425, 7425, 7425, 7425,
        1738, 1738, 1738, 1738, 1738, 1738, 6363, 6363, 6363, 6363, 6363, 6363,
        6059, 6059, 6059, 6059, 6059, 6059, 3741, 3741, 3741, 3741, 3741, 3741]), 'context_items': tensor([ 6670, 10589,  8622,  3326,  8507,  7438,    49,  5718,  7513,  1055,
         9189,  1659,  4112,  6316,  1035,  2911, 10587,  9093,  6670,   970,
        