In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import ReciprocalRank

from sklearn.metrics import ndcg_score

import pandas as pd
from tqdm import tqdm

import random

import numpy as np

import json

import h5py

In [2]:
def sample_excluding(n, x, a):
    if x == -1:
        return [num for num in range(1,n+1) if num != a ]
    
    if x > n - 1:
        raise ValueError("Cannot sample more elements than available excluding 'a'")
    
    # Sample x numbers from 1 to n-1
    sampled = random.sample(range(1, n), x)

    # Map values >= a to skip 'a'
    return [num if num < a else num + 1 for num in sampled]
# ok, lets look at what we need right here and there

class Ex2VecDataset(Dataset):
    def __init__(self, data_path, usage_dict_path, timedeltas_list_path, history_size=3500, sample_negative=999, max_padding=256):
        self.data_path = data_path
        self.usage_dict_path = usage_dict_path
        self.history_size = history_size
        self.sample_negative = sample_negative

        self.data = pd.read_parquet(self.data_path)

        self.max_user = self.data['user_id'].max()
        self.max_item = self.data['track_id'].max()

        self.data.set_index(['user_id', 'track_id'], inplace=True, drop=False)

        self.max_padding=max_padding

        with open(usage_dict_path) as file:
            self.use_dict = {int(key) : set(value) for key, value in json.load(file).items()}

        with h5py.File(timedeltas_list_path, 'r') as f:
            self.offsets = f['offsets'][:]
            self.timestamps_flat = f['timestamps_flat'][:]

            self.pos_dict = {tuple(x) : i for i,x in enumerate(tqdm(f['user_item']))}
    
            total_size = (self.max_user + 1) * (self.max_item + 1)
            self.pos_array = np.full(total_size, -1, dtype=np.int32)
            for i, (user, item) in enumerate(tqdm(f['user_item'])):
                flat_index = user * (self.max_item + 1) + item
                self.pos_array[flat_index] = i

        

    def __len__(self):
        return len(self.data)

    # @profile
    def __getitem__(self, idx):

        user = self.data.iloc[idx]['user_id']
        pred_item = self.data.iloc[idx]['track_id']
        ts = self.data.iloc[idx]['ts']

        if pred_item not in self.use_dict[user]:
            return None

        if self.sample_negative != -1:
            true_vals = np.zeros(self.sample_negative + 1, dtype=np.float32)
            samples = np.empty(self.sample_negative + 1, dtype=np.int32)
            timedeltas = np.zeros((self.sample_negative + 1, self.max_padding), dtype=np.float32)
        else:
            true_vals = np.zeros(self.max_item)
            samples = np.empty(self.max_item, dtype=np.int32)
            timedeltas = np.zeros((self.max_item, self.max_padding), dtype=np.float32)
        true_vals[-1] = 1.0
    
        
        samples[:-1] = sample_excluding(self.max_item, self.sample_negative, pred_item)
        samples[-1] = pred_item
    
        # Vectorized flat index computation
        flat_indices = user * (self.max_item + 1) + samples
        idx_items = self.pos_array[flat_indices]
    
        
        valid_mask = idx_items != -1
        valid_indices = np.nonzero(valid_mask)[0]
        valid_pos = idx_items[valid_mask]
    
        starts = self.offsets[valid_pos, 0]
        lengths = self.offsets[valid_pos, 1]
        ends = starts + lengths
    
        for i, (start, end, length, sample_idx) in enumerate(zip(starts, ends, lengths, valid_indices)):
            timedeltas[sample_idx, :length] = ts - self.timestamps_flat[start:end]

        weights = timedeltas > 0
        
        return {
            'user_id' : torch.tensor(user),
            'predict_items' : torch.tensor(samples),
            'real_values' : torch.tensor(true_vals),
            'timedeltas' : torch.from_numpy(timedeltas),
            'weights' : torch.from_numpy(weights.astype(np.float32))
        }
        
# @profile
def collate_fn(batch):
    # Remove None entries
    batch = [x for x in batch if x is not None]
    
    if not batch:
        return None  # Signal to skip this batch
    
    # Stack each field in the batch
    collated_batch = {}
    keys = batch[0].keys()
    for key in keys:
        collated_batch[key] = torch.stack([sample[key] for sample in batch])

    return collated_batch

class Ex2VecOriginal(torch.nn.Module):
    def __init__(self, config):
        super(Ex2VecOriginal, self).__init__()
        self.config = config
        self.n_users = config['n_users']
        self.n_items = config['n_items']
        self.latend_d = config['latent_d']

        self.global_lamb = torch.nn.Parameter(torch.tensor(1.0))

        self.user_lamb = torch.nn.Embedding(self.n_users+1, 1)

        self.user_bias = torch.nn.Embedding(self.n_users+1, 1)
        self.item_bias = torch.nn.Embedding(self.n_items+1, 1)

        self.alpha = torch.nn.Parameter(torch.tensor(1.0))
        self.beta = torch.nn.Parameter(torch.tensor(-0.065))
        self.gamma = torch.nn.Parameter(torch.tensor(0.5))

        self.cutoff = torch.nn.Parameter(torch.tensor(3.0))

        self.embedding_user = torch.nn.Embedding(
            num_embeddings=self.n_users+1, embedding_dim=self.latend_d
        )

        self.embedding_item = torch.nn.Embedding(
            num_embeddings=self.n_items+1, embedding_dim=self.latend_d
        )

        self.logistic = torch.nn.Sigmoid()

    
    def forward(self, user_id, item_id, timedeltas, weights):
        user_emb = self.embedding_user(user_id).unsqueeze(1)
        item_emb = self.embedding_item(item_id)

        u_bias = self.user_bias(user_id)
        i_bias = self.item_bias(item_id).squeeze(-1)

        # base_dist = torch.sqrt((user_emb - item_emb)**2).sum(dim=2)
        base_dist = torch.norm(user_emb - item_emb, dim=-1)

        lamb = self.global_lamb + self.user_lamb(user_id).unsqueeze(-1)


        # print('ORIG TIMEDELTAS', timedeltas)

        # print('PRE POW DELTAS', (timedeltas + self.cutoff) * weights)

        timedeltas = torch.pow(torch.clamp(timedeltas + self.cutoff, min=1e-6), -0.5)
        # print('TIMEDELTAS 1', timedeltas)
        # if torch.isnan(timedeltas).any():
        #     print(f'crashing {15324/0}')

        timedeltas = timedeltas * weights
        # print('TIMEDELTAS 2', timedeltas)

        # if torch.isnan(timedeltas).any():
        #     print(f'crashing {15324/0}')

        timedeltas = timedeltas * weights

        base_level = lamb * timedeltas

        # print('LAMB', lamb)
        # print('TIMEDELTAS', timedeltas)

        # if torch.isnan(base_level).any():
        #     print('NONE FOUND BASE LEVEL 1 ', base_level)
        #     print(f'crashing {15324/0}')

        # print('PRE SUM', base_level)

        base_level = torch.sum(base_level, axis=2)

        # if torch.isnan(base_level).any:
        #     print('NONE FOUND BASE LEVEL 2 ', base_level)
        #     print(f'crashing {15324/0}')

        output = torch.maximum(torch.zeros_like(base_dist), base_dist - base_level)

        # if torch.isnan(output).any():
        #     print('NONE FOUND OUTPUT', output)
        #     print(f'crashing {15324/0}')

        I = self.alpha * output  + self.beta * torch.pow(output, 2) + self.gamma + u_bias + i_bias
      
        return I

In [3]:
# need to generate all the results i need 
dataset_test = Ex2VecDataset('sorted_data.parquet', 'test_dict.json', 'interactions.h5', sample_negative=-1)

100%|███████████████████████████████████████████████████| 4892757/4892757 [00:30<00:00, 159774.12it/s]
100%|███████████████████████████████████████████████████| 4892757/4892757 [00:30<00:00, 162081.64it/s]


In [4]:
loader_test = DataLoader(dataset_test, batch_size=1024, num_workers = 8, shuffle=False, collate_fn=collate_fn, pin_memory=True)

In [5]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    
config = {
    'n_users' : dataset_test.max_user,
    'n_items' : dataset_test.max_item,
    'latent_d' : 64
}

model = Ex2VecOriginal(config).to(device)

checkpoint = torch.load('orig_model_epoch_29.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [None]:
# df_all = pd.DataFrame()

model.eval()
all_predictions = []

pbar_test = tqdm(enumerate(loader_test), total=len(loader_test))

with torch.no_grad():
    for i, batch in pbar_test:
        
        
        if batch is None:
            pbar_test.update(1)
            continue
            
        real = batch['real_values'].to(device)
        user_id = batch['user_id'].to(device)
        predict_items = batch['predict_items'].to(device)
        timedeltas = batch['timedeltas'].to(device)
        weights = batch['weights'].to(device)

        output = model(user_id, predict_items, timedeltas, weights)

        # Convert to numpy
        scores = output.cpu().numpy()
        item_indices = predict_items.cpu().numpy()

        batch_size, item_count = scores.shape
        max_item_index = item_indices.max()

        result = np.full((batch_size, max_item_index + 1), np.nan)  # NaN for missing values

        rows = np.arange(batch_size).reshape(-1, 1)
        result[rows, item_indices] = scores  # Assign each score to the right item index column

        df_batch = pd.DataFrame(result)
        df_batch[0] = predict_items[:, -1].cpu().detach().numpy()

        # # Optionally add user ID to keep track of rows
        # df_batch["user_id"] = user_id.cpu().numpy()
        # df_batch.set_index("user_id", inplace=True)

        all_predictions.append(df_batch)

        # df_all = pd.concat([df_all, df_batch], axis=0, ignore_index=True, copy=False)

        if i % 1000 == 0:
            df_all = pd.concat(all_predictions, axis=0, ignore_index=True, copy=False)
            df_all.to_parquet(f'./original_results/original_results_ordered_{i//1000}.parquet')
            all_predictions = []

        pbar_test.update(1)

df_all = pd.concat(all_predictions, axis=0, ignore_index=True, copy=False)
df_all.to_parquet(f'./original_results/original_results_ordered_17last.parquet')

 42%|█████████████████████████▍                                  | 6910/16269 [34:22<37:27,  4.16it/s]