# Build sequence data loaders for Skip Gram

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import sys
from loguru import logger
from pydantic import BaseModel
import pandas as pd
from torch.utils.data import DataLoader

sys.path.insert(0, "..")
from src.skipgram.dataset import SkipGramDataset

# Controller

In [3]:
class Args(BaseModel):
    num_negative_samples: int = 5
    window_size: int = 1
    batch_size: int = 2

    user_col: str = 'user_id'
    item_col: str = 'parent_asin'

args = Args()
print(args.model_dump_json(indent=2))

{
  "num_negative_samples": 5,
  "window_size": 1,
  "batch_size": 2,
  "user_col": "user_id",
  "item_col": "parent_asin"
}


# Test implementation

In [4]:
# Example sequences of item IDs
# sequences = [
#     [1, 2, 3, 4, 0],
#     [5, 1, 1, 1, 10],
#     [6, 12, 10, 11, 7],
#     [1, 2, 10],
#     [9, 8, 2],
# ]

# val_sequences = [
#     [5, 11, 12],
#     [8, 7],
#     [9, 4, 0]
# ]

sequences = [
    ['b', 'c', 'd', 'e', 'a'],
    ['f', 'b', 'b', 'b', 'k'],
    ['g', 'm', 'k', 'l', 'h'],
    ['b', 'c', 'k'],
    ['j', 'i', 'c']
]

val_sequences = [
    ['f', 'l', 'm'],
    ['i', 'h'],
    ['j', 'e', 'a']
]

# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(sequences, window_size=1, negative_samples=2)
val_dataset = SkipGramDataset(val_sequences, interacted=dataset.interacted, item_freq=dataset.item_freq, window_size=1, negative_samples=2)

# Example of getting an item
inp = dataset[0]
target_items = inp['target_items']
context_items = inp['context_items']
labels = inp['labels']
print(target_items, context_items, labels)

[32m2024-09-24 08:34:00.597[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m53[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/5 [00:00<?, ?it/s]

[32m2024-09-24 08:34:00.608[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m53[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/3 [00:00<?, ?it/s]

tensor([0, 0, 0]) tensor([1, 8, 7]) tensor([1., 0., 0.])


In [5]:
dataset.sampling_probs

array([0.13537454, 0.13537454, 0.05938764, 0.05938764, 0.05938764,
       0.05938764, 0.13537454, 0.05938764, 0.05938764, 0.05938764,
       0.05938764, 0.05938764, 0.05938764])

In [6]:
dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 6},
             1: {0, 1, 2, 3, 4, 6, 11, 12},
             2: {0, 1, 2, 3, 4},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             5: {0, 5, 6},
             6: {0, 1, 5, 6, 7, 8, 9, 10},
             7: {6, 7, 8, 9, 10},
             8: {6, 7, 8, 9, 10},
             9: {6, 7, 8, 9, 10},
             10: {6, 7, 8, 9, 10},
             11: {1, 11, 12},
             12: {1, 11, 12}})

In [7]:
for i in zip(target_items, context_items):
    print(i)

(tensor(0), tensor(1))
(tensor(0), tensor(8))
(tensor(0), tensor(7))


In [8]:
val_dataset.sampling_probs

array([0.12918587, 0.12918587, 0.07681438, 0.07681438, 0.07681438,
       0.07681438, 0.12918587, 0.07681438, 0.0456741 , 0.0456741 ,
       0.0456741 , 0.0456741 , 0.0456741 ])

In [9]:
val_dataset.interacted

defaultdict(set,
            {0: {0, 1, 2, 3, 4, 5, 6},
             1: {0, 1, 2, 3, 4, 6, 11, 12},
             2: {0, 1, 2, 3, 4},
             3: {0, 1, 2, 3, 4},
             4: {0, 1, 2, 3, 4},
             5: {0, 5, 6, 7},
             6: {0, 1, 5, 6, 7, 8, 9, 10},
             7: {5, 6, 7, 8, 9, 10},
             8: {6, 7, 8, 9, 10},
             9: {6, 7, 8, 9, 10},
             10: {6, 7, 8, 9, 10},
             11: {1, 11, 12},
             12: {1, 11, 12}})

# Load data

In [10]:
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")

In [11]:
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence
0,AEN7JFLQCURF54WR5OHY7HOWWMSQ,B08FC5TTBF,5.0,1628644724721,2174,1363,Video Games,Demon's Souls - PlayStation 5,[From Bluepoint Games comes a remake of the Pl...,"[Video Games, PlayStation 5, Games]",29.99,"[-1, -1, -1, -1, 855, 817, 1661, 3931, 4525, 1..."
1,AELH2ZF5QSSIFBF6WXAZLCF7JIWA,B0C6DH316S,2.0,1628653733506,8913,1709,Computers,Logitech G PRO X Wireless Lightspeed Gaming He...,[],"[Video Games, PC, Accessories, Headsets]",253.82,"[-1, -1, -1, -1, 741, 1446, 2299, 3658, 2239, ..."
2,AGD4QHNPSC45XTUPSUE6TYQOF3WQ,B0BN5DC36N,5.0,1628679010802,11028,344,Computers,Seagate Horizon Forbidden West Limited Edition...,[Discover new worlds with the officially-licen...,"[Video Games, Legacy Systems, PlayStation Syst...",89.99,"[3771, 4691, 1536, 3627, 3531, 575, 3542, 4404..."
3,AFMOSTKHH2HFLI35E3YMI7GLYDCQ,B07KRWJCQW,5.0,1628687441776,8097,514,Video Games,$40 Xbox Gift Card [Digital Code],[Buy an Xbox Gift Card for yourself or a frien...,"[Video Games, Online Game Services, Xbox Live,...",40.0,"[-1, -1, 4625, 4540, 1625, 2877, 2586, 1797, 6..."
4,AGK34QNFABMBLRESDKG2VRC3VIIQ,B0BL65X86R,5.0,1628702768435,12761,4246,Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0,"[3882, 3574, 1410, 3630, 3089, 4330, 4077, 367..."
...,...,...,...,...,...,...,...,...,...,...,...,...
944,AEKYV77UMZZGHT4PZIETDQ6ELJBQ,B08F4C6HCD,5.0,1657816667680,3609,3765,Video Games,Legend of Zelda Link's Awakening - Nintendo Sw...,"[“Castaway, you should know the truth!” As Lin...","[Video Games, Nintendo Switch, Games]",59.88,"[3085, 650, 4695, 4243, 4067, 992, 3646, 1609,..."
945,AGUFCRCH7HOUQ5FQYSJETEEFAYOA,B00DBDPOZ4,5.0,1657855227062,7385,2174,Video Games,Xbox One Play and Charge Kit,[Keep the action going with the Xbox One Play ...,"[Video Games, Xbox One, Accessories]",34.99,"[-1, -1, -1, -1, -1, 3613, 604, 1230, 3026, 2596]"
946,AHJUZFMUESAEQBPC2QQMBDVUBYFQ,B0B1PB5L93,4.0,1657883331431,14871,1187,Computers,Razer Viper Ultimate Lightweight Wireless Gami...,[Forget about average and claim the unfair adv...,"[Video Games, PC, Accessories, Gaming Mice]",89.99,"[-1, -1, -1, 3002, 2444, 601, 4161, 3940, 3084..."
947,AE5UUBPDQX4MRFFDW7D3IKHQYIEQ,B00ZJBSBD8,5.0,1657945454164,18008,3094,Video Games,Trackmania Turbo-Nla,[Step into the wild car fantasy world of Track...,"[Video Games, PlayStation 4, Games]",13.68,"[4420, 4027, 1762, 3130, 2766, 4588, 2672, 141..."


In [12]:
def get_sequence(df, user_col=args.user_col, item_col=args.item_col):
    return (
        df.groupby(user_col)[item_col].agg(list)
        .loc[lambda s: s.apply(len) > 1]  # Remove sequence with only one item
    ).values.tolist()

In [13]:
item_sequence = train_df.pipe(get_sequence)
len(item_sequence)

20366

In [14]:
val_item_sequence = val_df.pipe(get_sequence)
len(val_item_sequence)

147

## Persist

In [15]:
with open("../data/item_sequence.json", "w") as f:
    json.dump(item_sequence, f)
with open("../data/val_item_sequence.json", "w") as f:
    json.dump(val_item_sequence, f)

with open("../data/item_sequence.json", "r") as f:
    item_sequence = json.load(f)
with open("../data/val_item_sequence.json", "r") as f:
    val_item_sequence = json.load(f)

logger.info(f"{len(item_sequence)=:,.0f} {len(val_item_sequence)=:,.0f}")

[32m2024-09-24 08:34:03.321[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mlen(item_sequence)=20,366 len(val_item_sequence)=147[0m


# Run with all data

In [16]:
# Create dataset with frequency-based negative sampling
dataset = SkipGramDataset(item_sequence, window_size=args.window_size, negative_samples=args.num_negative_samples)

# Example of getting the first item
target_items, context_items, labels = dataset[0]
print(target_items, context_items, labels)

[32m2024-09-24 08:34:03.341[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m53[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/20366 [00:00<?, ?it/s]

target_items context_items labels


In [17]:
dataloader = DataLoader(
    dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, collate_fn=dataset.collate_fn
)

In [18]:
for batch_input in dataloader:
    print(batch_input)
    break

{'target_items': tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([   1, 1841, 3506,   16,   99,  582,    0,    2,  864, 1148, 2394,  911,
        3947, 2080,  469, 1328, 4366,  635]), 'labels': tensor([1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}


In [19]:
item_sequence[0]

['B00029QOQS',
 'B0006B7DXA',
 'B001LETH2Q',
 'B0009XEC02',
 'B000NNDN1M',
 'B00136MBHA',
 'B007VTVRFA',
 'B0053BCML6']

In [20]:
dataset.interacted[2314]

{28,
 34,
 80,
 82,
 99,
 139,
 160,
 200,
 232,
 249,
 269,
 294,
 327,
 342,
 370,
 378,
 488,
 496,
 534,
 543,
 585,
 590,
 609,
 611,
 629,
 731,
 747,
 756,
 858,
 955,
 987,
 1014,
 1111,
 1116,
 1148,
 1183,
 1194,
 1267,
 1366,
 1373,
 1456,
 1490,
 1509,
 1513,
 1539,
 1615,
 1884,
 1935,
 1973,
 2190,
 2228,
 2251,
 2314,
 2323,
 2332,
 2548,
 2595,
 2663,
 2679,
 2682,
 2717,
 2721,
 2764,
 2787,
 2813,
 2833,
 3000,
 3002,
 3079,
 3152,
 3169,
 3214,
 3225,
 3307,
 3531,
 3562,
 3633,
 3899,
 4005,
 4034,
 4046,
 4087,
 4288,
 4333,
 4355,
 4660}

In [21]:
val_dataset = SkipGramDataset(val_item_sequence, interacted=dataset.interacted, item_freq=dataset.item_freq, window_size=args.window_size, negative_samples=args.num_negative_samples)
val_dataloader = DataLoader(
    val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, collate_fn=val_dataset.collate_fn
)

[32m2024-09-24 08:34:04.684[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m53[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions:   0%|          | 0/147 [00:00<?, ?it/s]

In [22]:
for batch_input in val_dataloader:
    print(batch_input)
    break

{'target_items': tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([   1, 3473,  152,  838,  213, 4682,    0, 2833,  202, 1100,  127, 1138]), 'labels': tensor([1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])}


In [23]:
val_dataset[0]

{'target_items': tensor([0, 0, 0, 0, 0, 0]),
 'context_items': tensor([   1, 3728, 2429,  471, 2520,  189]),
 'labels': tensor([1., 0., 0., 0., 0., 0.])}