In [1]:
import h5py
import torch
import numpy as np

from tqdm import tqdm

import pandas as pd

import json

import random

from torch.utils.data import Dataset

In [2]:
def save_user_track_interactions_to_hdf5(df: pd.DataFrame, h5_path: str):
    """
    Given a dataframe with columns: user_id, track_id, timestamp,
    store grouped interactions into HDF5 with flat + offset layout.
    """
    assert set(df.columns) >= {"user_id", "track_id", "ts"}, "Missing required columns"

    # Group by (user_id, track_id)
    df_sorted = df.sort_values(["user_id", "track_id", "ts"])
    grouped = df_sorted.groupby(["user_id", "track_id"])

    # Prepare storage arrays
    timestamps_flat = []
    offsets = []
    user_item = []

    current_offset = 0
    for (user, item), group in tqdm(grouped):
        ts = group["ts"].to_numpy(dtype=np.int64)
        timestamps_flat.append(ts)
        length = len(ts)
        offsets.append((current_offset, length))
        user_item.append((user, item))
        current_offset += length

    # Convert to flat arrays
    timestamps_flat = np.concatenate(timestamps_flat)
    offsets = np.array(offsets, dtype=np.int64)
    user_item = np.array(user_item, dtype=np.int64)

    # Save to HDF5
    with h5py.File(h5_path, "w") as f:
        f.create_dataset("timestamps_flat", data=timestamps_flat, compression="gzip", chunks=True)
        f.create_dataset("offsets", data=offsets, compression="gzip", chunks=True)
        f.create_dataset("user_item", data=user_item, compression="gzip", chunks=True)

    print(f"Saved {len(user_item)} user-item pairs, {len(timestamps_flat)} timestamps to {h5_path}")

In [3]:
df = pd.read_parquet('sorted_data.parquet')

save_user_track_interactions_to_hdf5(df, "interactions.h5")

100%|██████████████████████████████████████████████████████████████████████| 4892757/4892757 [02:37<00:00, 31026.90it/s]


Saved 4892757 user-item pairs, 16658781 timestamps to interactions.h5


In [2]:
file = h5py.File('interactions.h5', 'r')

In [3]:
file.keys()

<KeysViewHDF5 ['offsets', 'timestamps_flat', 'user_item']>

In [4]:
file['timestamps_flat']

<HDF5 dataset "timestamps_flat": shape (16658781,), type "<i8">

In [5]:
file['offsets'][0]

array([0, 2])

In [6]:
file['user_item']

<HDF5 dataset "user_item": shape (4892757, 2), type "<i8">

In [25]:
user = 1
items = [1,2,3,4,5,6,7,8,9,10,23,45,56,19,63]



In [26]:
pos_dict = {tuple(x) : i for i,x in enumerate(tqdm(file['user_item']))}

100%|█████████████████████████████████████████████████████████████████████| 4892757/4892757 [00:31<00:00, 154718.48it/s]


In [27]:
pos_dict

{(np.int64(1), np.int64(1)): 0,
 (np.int64(1), np.int64(2)): 1,
 (np.int64(1), np.int64(3)): 2,
 (np.int64(1), np.int64(4)): 3,
 (np.int64(1), np.int64(5)): 4,
 (np.int64(1), np.int64(6)): 5,
 (np.int64(1), np.int64(7)): 6,
 (np.int64(1), np.int64(8)): 7,
 (np.int64(1), np.int64(9)): 8,
 (np.int64(1), np.int64(10)): 9,
 (np.int64(1), np.int64(11)): 10,
 (np.int64(1), np.int64(12)): 11,
 (np.int64(1), np.int64(13)): 12,
 (np.int64(1), np.int64(14)): 13,
 (np.int64(1), np.int64(15)): 14,
 (np.int64(1), np.int64(16)): 15,
 (np.int64(1), np.int64(17)): 16,
 (np.int64(1), np.int64(18)): 17,
 (np.int64(1), np.int64(19)): 18,
 (np.int64(1), np.int64(20)): 19,
 (np.int64(1), np.int64(21)): 20,
 (np.int64(1), np.int64(22)): 21,
 (np.int64(1), np.int64(23)): 22,
 (np.int64(1), np.int64(24)): 23,
 (np.int64(1), np.int64(25)): 24,
 (np.int64(1), np.int64(26)): 25,
 (np.int64(1), np.int64(27)): 26,
 (np.int64(1), np.int64(28)): 27,
 (np.int64(1), np.int64(29)): 28,
 (np.int64(1), np.int64(30)): 29,

In [29]:
poss = [pos_dict[(user, i)] for i in items]
poss.sort()

In [30]:
poss

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 18, 22, 44, 55, 62]

In [32]:
boundaries = file['offsets'][poss]

In [None]:
def clear_pad(l, curr):
    

In [36]:
[file['timestamps_flat'][start:end] for start, end in boundaries]

[array([1654072390, 1655197911]),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64)]

In [None]:
output = []

for i in items:
    
    tds = []
    weights = []
    