In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random
from os.path import isfile
from typing import List
from collections import Counter
import ipywidgets as widgets
from ipywidgets import interact

from src.sampling import reservoir_sampling, reservoir_sampling_original

In [None]:
# Un-sampled test fixture used by every experiment below. The path is relative to the
# current working directory, so the notebook must be started from its own folder.
# NOTE(review): plot_count assumes the file holds the integers 0..99 — TODO confirm.
path_original_data = "../src/tests/data/test_data_original/test.jsonl"
assert isfile(path_original_data), (
    f"test data not found at {path_original_data!r} "
    "(path is relative to the current working directory)"
)

# Functions

In [None]:
def lines2int(_lines: List[str]) -> List[int]:
    r"""Parse lines that each hold a single integer into a list of ints.

    Newline characters at either end of a line are stripped before parsing,
    e.g. ["28\n", "11\n"] -> [28, 11].
    """
    numbers = []
    for raw in _lines:
        numbers.append(int(raw.strip("\n")))
    return numbers

def plot_count(_counter: Counter, _exclude: List[int], _sample_size: int, _sample_number: int, _n_items: int = 100):
    """Plot how often each item 0.._n_items-1 was sampled and print summary statistics.

    The empirical per-item counts are drawn against the uniform expectation,
    i.e. the total number of drawn items spread evenly over the non-excluded items.

    Args:
        _counter: maps item value -> number of times it was drawn. Keys outside
            range(_n_items) are ignored.
        _exclude: items that were excluded from sampling; only its length is used
            (to compute the expectation).
        _sample_size: items per sample; only reported in the printed summary.
        _sample_number: number of samples drawn; only reported in the printed summary.
        _n_items: size of the item universe (defaults to 100).

    Returns:
        The matplotlib Axes containing the plot.
    """
    counter_list = [_counter[i] for i in range(_n_items)]
    item_number = sum(counter_list)

    # Uniform expectation over the items that could be drawn at all; guard against
    # every item being excluded (division by zero / negative denominator).
    n_candidates = _n_items - len(_exclude)
    expectation = item_number / n_candidates if n_candidates > 0 else float("nan")

    _, ax = plt.subplots(1, 1, figsize=(8, 6))
    # Span exactly the x-range of the empirical curve (indices 0.._n_items-1).
    ax.plot([0, _n_items - 1], [expectation, expectation], color="k", linestyle=":", label="expectation value")
    ax.plot(counter_list, color="k", label="empirical value")
    ax.set_title("sampled data")
    ax.set_xlabel("item")
    ax.set_ylabel("count")
    ax.set_ylim([0, None])
    ax.legend()

    counter_list_nonzero = [elem for elem in counter_list if elem > 0]
    print(f"> sample size = {_sample_size}, #samples = {_sample_number} => #items={item_number}")
    print("")
    print("Output:")
    if counter_list_nonzero:
        print(f"> min = {min(counter_list_nonzero)}, max = {max(counter_list_nonzero)}, mean = {np.mean(counter_list_nonzero):.1f}")
    else:
        # Nothing was drawn: min()/max() would raise ValueError on an empty list.
        print("> min = n/a, max = n/a, mean = n/a")
    print(f"> non-zero counts = {len(counter_list_nonzero)}, zero counts = {_n_items - len(counter_list_nonzero)}")
    return ax

# Reservoir Sampling

In [None]:
@interact
def function_reservoir_sampling(sample_size=[20, 40], sample_number=[10000, 50000], evaluation=[False, True]):
    """Empirically check `reservoir_sampling` for uniformity and for honouring `exclude`.

    `interact` turns each list default into a dropdown widget and re-runs this
    function whenever a widget value changes.

    Args:
        sample_size: number of lines drawn per reservoir sample.
        sample_number: number of independent samples to draw.
        evaluation: if True, `sample_size` random indices from 0..99 are passed as
            `exclude`, and the run asserts that none of them was ever drawn.
    """
    # NOTE(review): the excluded set is not seeded, so every run differs.
    exclude = sorted(random.sample(range(0, 100), k=sample_size)) if evaluation else []
    print("> Input:")
    print(f"> {len(exclude)} excluded indices: {exclude}")

    data_train = []
    for _ in range(sample_number):
        # Re-open the file for every draw so each sample reads the data from the top.
        with open(path_original_data) as infile:
            sample, _sample_indices = reservoir_sampling(infile, sample_size, exclude=exclude)
            data_train.extend(lines2int(sample))

    counter = Counter(data_train)

    # Excluded indices must never show up in any sample. This compares indices with
    # parsed line values, so it assumes line i holds the integer i — TODO confirm.
    leaked = [index for index in exclude if counter[index] != 0]
    assert not leaked, f"excluded indices were sampled: {leaked}"

    plot_count(counter, exclude, sample_size, sample_number)

# Reservoir Sampling Original

In [None]:
@interact
def function_reservoir_sampling_original(sample_size=[20, 40], sample_number=[10000, 50000]):
    """Empirically check that `reservoir_sampling_original` draws items uniformly.

    Draws `sample_number` independent reservoir samples of `sample_size` lines each,
    parses the drawn lines to ints and plots how often each value occurred against
    the uniform expectation. Nothing is excluded in this variant.
    """
    print("> Input:")
    drawn_values = []
    for _ in range(sample_number):
        # A fresh file handle per draw, exactly as in the exclusion experiment.
        with open(path_original_data) as handle:
            drawn_values += lines2int(reservoir_sampling_original(handle, sample_size))

    plot_count(Counter(drawn_values), [], sample_size, sample_number)