In [1]:
import pathlib
from argparse import ArgumentParser

import numpy as np
import pandas as pd
import torch as th
import sklearn as sk

from pytorch_lightning import seed_everything, Trainer
from ranking_metrics_torch.precision_recall import precision_at, recall_at
from ranking_metrics_torch.cumulative_gain import ndcg_at
from torch_factorization_models.implicit_mf import ImplicitMatrixFactorization
from torch_factorization_models.movielens import MovielensDataset, MovielensDataModule

from practicalrecs_examples.matrix_factorization import load_model

In [2]:
seed_everything(42)  # same seed used to create splits in training

42

In [3]:
movielens_module = MovielensDataModule("/home/karl/Projects/datasets/ml-25m/")
movielens_module.setup()

In [4]:
movielens = movielens_module.dataset

### Interaction Counts

In [5]:
interaction_counts = th.bincount(movielens.data[:,0])

In [6]:
interaction_counts.max()

tensor(5525)

In [7]:
interaction_counts.median()

tensor(40)

In [8]:
interaction_counts.to(dtype=th.float32).mean()

tensor(76.7073)

In [9]:
interaction_counts.sort(descending=True)

torch.return_types.sort(
values=tensor([5525, 4733, 3215,  ...,    1,    1,    1]),
indices=tensor([ 75212,  72226, 110836,  ...,  59032,  40527, 147085]))

In [24]:
interaction_counts.sort(descending=True)[0][2]

tensor(3215)

### Bloom filter sizes

In [11]:
import math

def compute_bytes(capacity, error_rate):
    num_hashes = max(math.floor(math.log2(1 / error_rate)), 1)
    bits_per_hash = math.ceil(
                capacity * abs(math.log(error_rate)) /
                (num_hashes * (math.log(2) ** 2)))
    num_bits = max(num_hashes * bits_per_hash,128)
    return num_bits//8

In [12]:
compute_bytes(28, 0.1)

16

In [13]:
compute_bytes(100, 0.1)

60

In [14]:
compute_bytes(100, 0.01)

120

In [15]:
compute_bytes(100, 0.001)

180

In [16]:
compute_bytes(1000, 0.01)

1198

In [17]:
compute_bytes(10000, 0.1)

5991

In [18]:
compute_bytes(10000, 0.01)

11982

### Filter Size Plots

In [None]:
flattened_sizes = list(filtering_plot_sizes.flatten())
flattened_recalls = list(filtering_plot_recalls.flatten())

unsorted_points = list(zip(flattened_sizes, flattened_recalls))

sorted_points = sorted(unsorted_points, key = lambda x: x[1])

sorted_sizes, sorted_recalls = list(zip(*sorted_points))

plt.figure(num=1, dpi=150, facecolor='w', edgecolor='k')
plt.plot(sorted_recalls, sorted_sizes, label="Flattened")
plt.vlines(0.18225, filtering_plot_sizes[0][0], filtering_plot_sizes[-1][-1], colors='k', linestyles='dashed', label='Ideal')
plt.yscale("log")
plt.ylabel("Filter Size (Bytes)")
plt.xlabel("Recall@100")
plt.legend()