# Online Batch Selection for Faster Training of Neural Network

In [1]:
%reload_ext blackcellmagic

In [40]:
from attr import dataclass
from bpr import utils
from bpr import utils as bpr_utils
from bpr.bpr import bpr_update
from pathlib import Path
from scipy import sparse
from scipy.sparse import coo_matrix
from typing import Dict, List
import bisect
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random

In [41]:
seed = 42

os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
random.seed(seed)

In [42]:
%%time

# Load data

ratings_df = pd.read_csv('../data/ml-100k/u.data', sep='\t', names=['user_id', 'item_id', 'rating', 'timestamp'])
ratings_df = ratings_df[['user_id', 'item_id', 'rating']]
for column in ['user_id', 'item_id']:
    ratings_df[column] -= 1

CPU times: user 22.2 ms, sys: 10 ms, total: 32.2 ms
Wall time: 63 ms


In [43]:
%%time
# Split data

train_df, test_df = bpr_utils.train_test_split(df=ratings_df, loo=True)
sampler = bpr_utils.NegativeSampler(item_ids=ratings_df["item_id"].unique()).fit(
    train=train_df, test=test_df
)
train_df = sampler.transform(train_df, train=True, size=5)
test_df = sampler.transform(test_df, train=False, size=50)

CPU times: user 6.44 s, sys: 100 ms, total: 6.54 s
Wall time: 6.56 s


In [36]:
@dataclass
class BatchSelector:
    pressure_begin: float = 100
    pressure_final: float = 100
    n_samples: int = 0
    n_epochs: int = 0
    batch_size: int = 0
    c: int = 0
    c_e: int = 0
    c_s: int = 0
    loss_idxs: List[Optional[str]] = []
    loss_vals: np.ndarray = np.zeros(n_samples)
    probs: np.ndarray = np.zeros(n_samples)  # selection probabilities
    probs_cumsum: np.ndarray = np.zeros(n_samples)  # cumulative selection probabilities
    t_s: int = 0  # period of sorting

    def __attrs_post_init__(self):
        print("In post_init")
        self.loss_idxs = [None] * self.n_samples
        self.loss_vals = np.zeros(self.n_samples)

    def __call__(self):
        if self.c == 0:
            # Line 3 of Algorithm 3
            self.c_e = -1 * self.n_samples
            self.c_s = 0

        if self.c - self.c_e > self.n_samples:
            # recompute selection probabilities, once per epoch
            self.c_e = self.c
            epoch = self.c_e / self.n_samples

            # compute s_e (pressure_e) using Eq 5
            pressure_ratio = math.log(self.pressure_final / self.pressure_begin)
            pressure_coeff = math.exp(pressure_ratio / self.n_epochs)
            pressure = self.pressure_begin * math.pow(pressure_coeff, epoch)

            # compute p_i (probabilities) using Eq 4
            prob_coeff = math.exp(math.log(pressure) / self.n_samples)
            self.probs[0] = 1.0
            for i in range(1, self.n_samples):
                self.probs[i] = self.probs[i - 1] / prob_coeff
                # This commented out code should be correct implementation as per
                # the description on the paper. However, in the paper's codebase,
                # the authors left out raising the power.
                # self.probs[i] = math.pow(self.probs[i - 1]/prob_coeff, i)

            self.probs = self.probs / self.probs.sum()
            self.probs_cumsum = np.cumsum(self.probs)

        if self.c - self.c_s > self.t_s:
            # sort data points w.r.t the latest known loss
            self.c_s = self.c
            sorted_by_vals = np.argsort(self.loss_vals)[::-1]
            self.loss_idxs = [self.loss_idx[i] for i in sorted_by_vals]
            self.loss_vals = self.loss_vals[sorted_by_vals]

        sampled_idxs = []
        while len(sampled_idxs) < self.batch_size:
            r = min(random.random(), self.probs_cumsum[-1])
            index = bisect.bisect_right(self.probs_cumsum, r)
            sampled_idxs.append(index)
        return sampled_idxs


In [37]:
selector = BatchSelector(n_samples=4)

In post_init


In [38]:
selector.loss_idxs

[None, None, None, None]

In [39]:
selector()

[]

In [None]:
# probs = np.random.random(10)
# print(probs)
# batch_size = int(0.4 * probs.shape[0])
# probs = probs/probs.sum()
# print(batch_size, 'batch_size')
# probs

# Z = np.cumsum(probs)
# print(Z)

# indexes = []
# print(Z[-1])
# for i in range(probs.shape[0]):
#     r = min(random.random(), Z[-1])

#     index = bisect.bisect_right(Z, r)
#     indexes.append(index)
#     print("i={:>2} | r={:.4f} | index={:>2}".format(i, r, index))
#     if len(set(indexes)) > batch_size:
#         print('breaking')
#         break

# print(set(indexes))