In [90]:
from typing import Literal
from datasets import Dataset, DatasetDict, load_dataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from collections import defaultdict
import numpy as np
from networkx.algorithms import bipartite
import random
import tqdm
import scipy.sparse as sp
from collections import defaultdict
import heapq

In [3]:
class YambdaDataset:
  """Accessor for the `yandex/yambda` dataset loaded through `load_dataset`.

  An instance remembers which layout (`dataset_type`) and which size
  (`dataset_size`) to read; the interaction tables are then fetched from the
  `<dataset_type>/<dataset_size>` directory, while the embedding and mapping
  tables are fetched with an empty `data_dir`.
  """

  # Names of every interaction table accepted by `interaction()`.
  INTERACTIONS = frozenset(["likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"])

  def __init__(
      self,
      dataset_type: Literal["flat", "sequential"] = "flat",
      dataset_size: Literal["50m", "500m", "5b"] = "50m"
  ):
    """Validate and store the dataset layout and size.

    Raises:
      AssertionError: if either argument is not one of the supported values.
    """
    type_is_known = dataset_type in {"flat", "sequential"}
    assert type_is_known
    size_is_known = dataset_size in {"50m", "500m", "5b"}
    assert size_is_known
    self.dataset_type, self.dataset_size = dataset_type, dataset_size

  def interaction(self, event_type: Literal["likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"]) -> Dataset:
    """Return the table for one interaction type of the configured layout/size.

    Raises:
      AssertionError: if `event_type` is not listed in `INTERACTIONS`.
    """
    assert event_type in self.INTERACTIONS
    data_dir = f"{self.dataset_type}/{self.dataset_size}"
    return self._download(data_dir, event_type)

  def audio_embeddings(self) -> Dataset:
    """Return the `embeddings` table."""
    return self._download("", "embeddings")

  def album_item_mapping(self) -> Dataset:
    """Return the `album_item_mapping` table."""
    return self._download("", "album_item_mapping")

  def artist_item_mapping(self) -> Dataset:
    """Return the `artist_item_mapping` table."""
    return self._download("", "artist_item_mapping")

  @staticmethod
  def _download(data_dir: str, file: str) -> Dataset:
    """Load `<file>.parquet` from `data_dir` of `yandex/yambda`; return its "train" split."""
    loaded = load_dataset("yandex/yambda", data_dir=data_dir, data_files=f"{file}.parquet")
    # The code below indexes by split name, so a DatasetDict is required here.
    assert isinstance(loaded, DatasetDict)
    return loaded["train"]



In [4]:
# Fetch the listens, likes and dislikes tables of the flat 50m variant.
dataset = YambdaDataset(dataset_type="flat", dataset_size="50m")
listens, likes, dislikes = (
    dataset.interaction(event) for event in ("listens", "likes", "dislikes")
)

README.md: 0.00B [00:00, ?B/s]

flat/50m/listens.parquet:   0%|          | 0.00/369M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

flat/50m/likes.parquet:   0%|          | 0.00/7.18M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

flat/50m/dislikes.parquet:   0%|          | 0.00/990k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [5]:
# Convert the likes table to a pandas DataFrame; the recommender call further
# down in this notebook receives `likes_df`.
likes_df = likes.to_pandas()


In [6]:
def calculate_items_similarity(item_i, item_j, interaction_df): # s(i, j)
    """Jaccard similarity s(i, j) between the user sets of two items.

    Args:
        item_i: first item id, matched against the 'item_id' column.
        item_j: second item id, matched against the 'item_id' column.
        interaction_df: DataFrame with at least 'item_id' and 'uid' columns.

    Returns:
        |U_i & U_j| / |U_i | U_j| as a float, where U_x is the set of 'uid'
        values on rows whose 'item_id' equals x. Returns 0.0 when neither item
        has any rows, instead of dividing by zero.
    """
    item_column = interaction_df['item_id']
    users_item_i = set(interaction_df.loc[item_column == item_i, 'uid'])
    users_item_j = set(interaction_df.loc[item_column == item_j, 'uid'])

    union_size = len(users_item_i | users_item_j)
    if union_size == 0:
        # Both items are absent from the frame: no evidence, no similarity.
        return 0.0
    return len(users_item_i & users_item_j) / union_size

In [7]:
def build_user_item_matrix(df):
    """Build a sparse binary item x user matrix from an interaction table.

    Args:
        df: DataFrame with at least 'item_id' and 'uid' columns.

    Returns:
        A 4-tuple ``(X, item_pop, item_id_to_idx, idx_to_item_id)``:
        X: CSR matrix of shape (n_items, n_users) holding ``uint8`` ones for
            every distinct (item, user) pair.
        item_pop: ``int32`` array with the number of stored entries per row,
            i.e. |U_i| for each item.
        item_id_to_idx: dict mapping an original item id to its row index.
        idx_to_item_id: array mapping a row index back to the original item id.

    Row and column indices follow order of first appearance, because the ids
    are factorized with ``sort=False``.
    """
    # NOTE(review): X is stored as uint8, so integer products such as
    # X @ X.T may also stay uint8 and wrap above 255 -- widen the dtype before
    # multiplying (confirm against SciPy's dtype rules).
    pairs = df[['item_id', 'uid']].drop_duplicates()

    row_codes, item_labels = pd.factorize(pairs['item_id'], sort=False)
    col_codes, user_labels = pd.factorize(pairs['uid'], sort=False)
    n_items = len(item_labels)
    n_users = len(user_labels)

    ones = np.ones(len(pairs), dtype=np.uint8)
    X = sp.csr_matrix((ones, (row_codes, col_codes)), shape=(n_items, n_users))

    item_pop = np.asarray(X.getnnz(axis=1), dtype=np.int32)  # |U_i|
    idx_to_item_id = np.array(item_labels)
    item_id_to_idx = dict(zip(item_labels, np.arange(n_items)))

    return X, item_pop, item_id_to_idx, idx_to_item_id

In [8]:
def nearest_items_topk_jaccard_sparse(item_original_id, X, item_pop, item_to_idx, idx_to_item_id, top_k: int = 100):
    """Return up to `top_k` items most Jaccard-similar to one item.

    Args:
        item_original_id: original id of the query item.
        X: sparse binary item x user matrix (rows are items).
        item_pop: per-row user counts |U_i| of `X`.
        item_to_idx: dict mapping original item id -> row index of `X`.
        idx_to_item_id: array mapping row index -> original item id.
        top_k: maximum number of neighbours to return.

    Returns:
        A list of ``(item_id, similarity)`` tuples sorted by decreasing
        similarity, with ``item_id`` an ``int`` and ``similarity`` a ``float``.
        Only items with a strictly positive similarity are returned, so the
        list may be shorter than `top_k` and never contains the query item.
        An empty list is returned for an unknown item, an item without users,
        or a non-positive `top_k`.
    """
    i = item_to_idx.get(item_original_id)
    if i is None or item_pop[i] == 0 or top_k <= 0:
        return []

    # Widen the row before multiplying: with a uint8 matrix the product would
    # also be uint8, and co-occurrence counts above 255 would silently wrap.
    intersection = X[i].astype(np.int64).dot(X.T).toarray().ravel()
    union = item_pop[i] + item_pop - intersection
    similarity = np.divide(
        intersection,
        union,
        out=np.zeros_like(intersection, dtype=np.float32),
        where=(union > 0),
    )
    similarity[i] = 0  # an item is not its own neighbour

    # Zero-similarity rows are not neighbours at all; dropping them also
    # guarantees the query item itself can never be returned as padding.
    candidates = np.flatnonzero(similarity > 0)
    if candidates.size == 0:
        return []

    k = min(top_k, candidates.size)
    if k < candidates.size:
        # Partial selection first, so only k entries need a full sort.
        keep = np.argpartition(similarity[candidates], -k)[-k:]
        candidates = candidates[keep]
    ranked = candidates[np.argsort(-similarity[candidates])]

    return [(int(idx_to_item_id[j]), float(similarity[j])) for j in ranked]
    

In [80]:
def build_user_items(user_id, interaction_df):
    """Return the set of distinct 'item_id' values on rows whose 'uid' equals `user_id`.

    The set is empty when the user has no rows in `interaction_df`.
    """
    belongs_to_user = interaction_df['uid'] == user_id
    return set(interaction_df.loc[belongs_to_user, 'item_id'].unique())

In [81]:
def get_item_rank(item_original_id, user_original_id, item_nearest_elements, interaction_df, user_items):
    """Score an item for a user as a similarity-weighted share of known neighbours.

    Args:
        item_original_id: id of the item being scored; a neighbour entry with
            this id is ignored.
        user_original_id: not used by this function.
        item_nearest_elements: iterable of ``(item_id, similarity)`` pairs.
        interaction_df: not used by this function.
        user_items: container of item ids the user already has.

    Returns:
        The sum of similarities of neighbours contained in `user_items`,
        divided by the sum of absolute similarities of all neighbours; 0.0
        when that denominator is zero.
    """
    matched_weight = 0
    total_weight = 0
    for neighbour_id, weight in item_nearest_elements:
        if neighbour_id != item_original_id:
            if neighbour_id in user_items:
                matched_weight += weight
            total_weight += abs(weight)

    return matched_weight / total_weight if total_weight != 0 else 0.0
    

In [98]:
def recommend_items_to_user(user_id, interactions_df, neighbors_per_liked=50, top_n=100, prebuilt=None, user_items=None):
    """Recommend items by summing Jaccard similarities to the user's known items.

    For every item the user already has, its nearest neighbours are looked up
    and their similarities are accumulated per candidate item; items the user
    already has are never recommended.

    Args:
        user_id: user to recommend for; only used to look up the user's items
            when `user_items` is missing or empty.
        interactions_df: DataFrame with 'uid' and 'item_id' columns. Used to
            build the matrix when `prebuilt` is None and to look up the user's
            items when `user_items` is missing or empty.
        neighbors_per_liked: number of neighbours requested per known item.
        top_n: maximum number of recommendations to return.
        prebuilt: optional dict with keys 'X', 'item_pop', 'item_id_to_idx' and
            'idx_to_item_id' (the values `build_user_item_matrix` returns), to
            avoid rebuilding the matrix on every call.
        user_items: optional iterable of item ids the user already has.

    Returns:
        A list of up to `top_n` ``(item_id, score)`` tuples sorted by
        decreasing score; empty when the user has no items or no candidate
        was found.
    """
    # Normalise to a set: `not user_items` raises ValueError for NumPy arrays
    # and pandas Series, and a set keeps the membership test below O(1).
    known_items = set(user_items) if user_items is not None else set()
    if not known_items:
        known_items = set(interactions_df.loc[interactions_df['uid'] == user_id, 'item_id'].unique())
        if not known_items:
            return []

    # Resolve the matrix only after we know there is something to score, so a
    # user without interactions does not trigger the expensive build.
    if prebuilt is None:
        X, item_pop, item_id_to_idx, idx_to_item_id = build_user_item_matrix(interactions_df)
    else:
        X = prebuilt['X']
        item_pop = prebuilt['item_pop']
        item_id_to_idx = prebuilt['item_id_to_idx']
        idx_to_item_id = prebuilt['idx_to_item_id']

    scores = defaultdict(float)
    for seed in known_items:
        neighbours = nearest_items_topk_jaccard_sparse(
            seed, X, item_pop, item_id_to_idx, idx_to_item_id, neighbors_per_liked
        )
        for candidate_id, similarity in neighbours:
            if candidate_id in known_items:
                continue
            # float() keeps the running sum in double precision even when the
            # similarity is a float32 scalar.
            scores[candidate_id] += float(similarity)

    if not scores:
        return []

    return heapq.nlargest(top_n, scores.items(), key=lambda kv: kv[1])
        

In [99]:
# Example run: recommendations for user 100 from the likes table. No `prebuilt`
# matrices are passed, so the user-item matrix is built inside this call.
recommend_items_to_user(100, likes_df)

[(1731125, 0.5493582114577293),
 (2605971, 0.5000000149011612),
 (959902, 0.5000000149011612),
 (3873995, 0.5000000149011612),
 (9095148, 0.5000000149011612),
 (676744, 0.5000000149011612),
 (1582448, 0.5000000149011612),
 (1153922, 0.5000000149011612),
 (3070991, 0.5000000149011612),
 (2557126, 0.5000000149011612),
 (7720625, 0.5000000149011612),
 (1885682, 0.5000000149011612),
 (6807557, 0.5000000149011612),
 (7434085, 0.5000000149011612),
 (6018231, 0.5000000149011612),
 (3949326, 0.5000000149011612),
 (4924097, 0.5000000149011612),
 (3380472, 0.5000000149011612),
 (6400232, 0.5000000149011612),
 (8702829, 0.5000000149011612),
 (8197758, 0.47287582606077194),
 (7092180, 0.4712700694799423),
 (5257351, 0.4527124911546707),
 (4463945, 0.4304981082677841),
 (3713399, 0.4285714328289032),
 (5861543, 0.42364533245563507),
 (9236339, 0.4119230657815933),
 (7534851, 0.4096774160861969),
 (8424907, 0.40789473056793213),
 (3590342, 0.4068613648414612),
 (8276189, 0.4064931571483612),
 (23025