In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
import datasets
from tqdm.notebook import tqdm

import utils

In [2]:
# Load three dataset objects through the project-local `utils` helper
# (presumably the train / validation / test splits -- confirm in utils).
# Only `test_ds` is read by the cells visible in this notebook.
train_ds, valid_ds, test_ds = utils.load_dataset()

Found cached dataset civil_comments (/home/johnny/.cache/huggingface/datasets/civil_comments/default/0.9.0/e7a3aacd2ab7d135fa958e7209d10b1fa03807d44c486e3c34897aa08ea8ffab)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Path of the saved NumPy array of scores (read with np.load when building `df`).
score_npy = 'scores/roberta_balanced.npy'
# DataFrame column that the binning code reads the scores from.
score_column = 'scores'

# Number of strata (bins) the scores are cut into.
bins = 8
# Binning strategy; the "Cutting" cell handles 'eqwidth', 'quantile' and 'oracle'.
cut = 'quantile'
# Source of the per-bin standard deviation; the "Allocation" cell handles
# 'pilot' (small random sample per bin) and 'optimal' (all labels in the bin).
allocation = 'pilot'
# Base number of labels sampled per bin for a pilot estimate.
pilot_size = 50

In [4]:
# Assemble one frame holding, per row, the `toxicity` and `label` fields of
# `test_ds` next to the scores loaded from disk.
scores = np.load(score_npy)
labels = test_ds['toxicity']
df = pd.DataFrame(
    {
        'toxicity': labels,
        'toxic': test_ds['label'],
        'scores': scores,
    }
)

In [5]:
# Peek at the first row to check the assembled columns.
df.head(1)

Unnamed: 0,toxicity,toxic,scores
0,0.0,0,0.001761


In [6]:
# Sum and mean of the `toxic` column (count and share of positives,
# assuming the column is 0/1 -- TODO confirm).
df['toxic'].sum(), df['toxic'].mean()

(106438, 0.05897253769515213)

In [7]:
# Total sample budget: used as n in the random-sampling standard error and
# split across bins by the allocation formulas below.
size = 12192

In [8]:
# Baseline: standard error of the mean of `toxic` when `size` rows are drawn
# by simple random sampling, using the Bernoulli variance p * (1 - p) and no
# finite-population correction.
p = df['toxic'].mean()
random_sampling_var = p * (1 - p)
np.sqrt(random_sampling_var / size)

0.002133480020786072

### Cutting

In [9]:
def get_error(df, sample_size=None):
    """Standard error of the stratified mean of ``toxic`` under Neyman allocation.

    Rows are grouped by the ``bin`` column. Each stratum ``h`` with ``N_h``
    rows and population standard deviation ``sigma_h`` receives
    ``n_h = sample_size * N_h * sigma_h / sum_k(N_k * sigma_k)`` samples, and
    the variance of the stratified estimator is accumulated as
    ``sum_h (N_h / N)^2 * var_h / n_h`` (no finite-population correction,
    i.e. an approximation for very large strata).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``bin`` (stratum label) and ``toxic``.
    sample_size : number, optional
        Total sample budget. Defaults to the notebook-level ``size``.

    Returns
    -------
    float
        Square root of the stratified variance. Strata whose labels are all
        identical receive no samples and contribute no variance; if every
        stratum is like that (or ``df`` is empty) the result is ``0.0``.
    """
    if sample_size is None:
        # Fall back to the notebook-level sampling budget.
        sample_size = size

    population = len(df)

    # Group once and keep only the per-stratum statistics that are needed.
    # observed=True skips categorical bins that contain no rows, which would
    # otherwise produce NaN statistics and poison the whole sum.
    strata = [
        (len(group), np.std(group['toxic']), group['toxic'].var())
        for _label, group in df.groupby('bin', observed=True)
    ]

    denominator = np.sum([n_h * sigma_h for n_h, sigma_h, _var_h in strata])
    if denominator == 0:
        # No stratum has any spread: the estimator has no sampling error.
        return 0.0

    stratified_var = 0
    for n_h, sigma_h, var_h in strata:
        if sigma_h == 0:
            # A constant stratum gets a zero allocation; dividing its (zero)
            # variance by that allocation would give 0/0 = NaN.
            continue
        n_from_bin = sample_size * n_h * sigma_h / denominator
        # approximation when the groups are very large
        stratified_var += np.square(n_h / population) * (var_h / n_from_bin)
    return np.sqrt(stratified_var)

In [10]:
def oracle_bins(df, depth=4):
    """Greedily search quantile edges that minimise the stratified standard error.

    Starting from the single interval ``[0, 1]`` (in quantile space), every
    level tries to split each interval that exists at the start of that level.
    For one interval, ``steps`` candidate split points strictly inside it are
    evaluated with ``get_error``; a candidate is kept only if it lowers the
    best error found so far. When every split is accepted, ``depth`` levels
    yield ``2 ** depth`` bins.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the score column named by the notebook-level
        ``score_column`` and the ``toxic`` column. It is not modified.
    depth : int
        Number of splitting levels.

    Returns
    -------
    list
        Increasing quantile edges, starting with 0 and ending with 1, suitable
        as the ``q`` argument of ``pd.qcut``.
    """
    steps = 10
    best_edges = [0, 1]
    best_err = 1

    # Evaluate candidates on a scratch frame so the caller's DataFrame does not
    # get a 'bin' column written into it as a side effect.
    scores = df[score_column]
    scratch = pd.DataFrame({'toxic': df['toxic']})

    for level in range(1, depth + 1):
        # Visit every interval present when the level starts. `pos` is the
        # index in `best_edges` of the current interval's upper edge.
        intervals_at_level_start = len(best_edges) - 1
        pos = 1
        for _interval in range(intervals_at_level_start):
            lb, ub = best_edges[pos - 1], best_edges[pos]

            trial_edges = best_edges.copy()
            trial_edges.insert(pos, lb)  # placeholder, overwritten per candidate
            accepted = False

            for candidate in np.linspace(lb + (ub - lb) / steps, ub, steps, endpoint=False):
                trial_edges[pos] = candidate
                scratch['bin'] = pd.qcut(scores, trial_edges)
                new_err = get_error(scratch)
                if new_err < best_err:
                    best_err = new_err
                    best_edges = trial_edges.copy()
                    accepted = True

            # An accepted split adds one edge before the next interval, so
            # step over it; otherwise just move to the next interval.
            pos += 2 if accepted else 1

    return best_edges

In [11]:
# Assign every row to a stratum (`bin` column) according to the `cut` strategy.
if cut == 'eqwidth':
    # Equal-width intervals spanning the observed score range.
    minimum, maximum = df[score_column].min(), df[score_column].max()
    df['bin'] = pd.cut(df[score_column], np.linspace(minimum, maximum, num=bins+1), include_lowest=True)
elif cut == 'quantile':
    # Equal-population intervals.
    df['bin'] = pd.qcut(df[score_column], np.linspace(0, 1, num=bins+1))
elif cut == 'oracle':
    # Each level of the search doubles the number of bins, so the depth is
    # floor(log2(bins)). np.log2 avoids the rounding error that
    # log(bins) / log(2) can pick up just below an integer before truncation.
    depth = int(np.log2(bins))
    b = oracle_bins(df, depth = depth)
    df['bin'] = pd.qcut(df[score_column], b)
else:
    # Fail loudly instead of continuing with a missing or stale `bin` column.
    raise ValueError(f"unknown cut strategy: {cut!r}")

### Allocation

In [12]:
# Estimate the standard deviation of `toxic` inside every bin, then split the
# total budget `size` across bins with Neyman allocation:
#     n_h = size * N_h * sigma_h / sum_k(N_k * sigma_k)
if allocation not in ('pilot', 'optimal'):
    # Fail loudly instead of reusing a stale `pilot` from a previous run.
    raise ValueError(f"unknown allocation strategy: {allocation!r}")

# Seeded generator so the pilot draws, and everything derived from them, are
# identical on every Restart & Run All.
pilot_rng = np.random.default_rng(0)

sizes_sigmas = []
for i, group in df.groupby('bin'):
    if allocation == 'pilot':
        pilot = np.array(group['toxic'].sample(pilot_size + 2, random_state=pilot_rng))
        # Overwrite the last two draws with one positive and one negative
        # label so the estimated sigma is never 0 and every bin receives a
        # non-zero share of the budget.
        pilot[-1] = 1
        pilot[-2] = 0
    else:
        # 'optimal': use every label in the bin.
        pilot = group['toxic']
    sizes_sigmas.append((len(group), np.std(pilot)))

allocations = []
denominator = np.sum([ n_h * sigma_h for n_h, sigma_h in sizes_sigmas ])
for n_h, sigma_h in sizes_sigmas:
    n_from_bin = size * n_h * sigma_h / denominator
    print(n_h, sigma_h, n_from_bin)
    allocations.append(n_from_bin)

225612 0.09900990099009901 903.0895550780031
225607 0.09900990099009901 903.0695408598968
225609 0.09900990099009901 903.0775465471393
225609 0.09900990099009901 903.0775465471393
225609 0.09900990099009901 903.0775465471393
225609 0.09900990099009901 903.0775465471393
225609 0.27006300390071003 2463.2671332556174
225610 0.4725583985656127 4310.2635846179255


In [13]:
# Express every bin's allocation relative to the smallest allocation, so the
# smallest bin has multiplier 1 and the others say how many times more
# samples they receive.
minimum = np.min(allocations)
multipliers = [ i / minimum for i in allocations ]
multipliers

[1.000022162432903,
 1.0,
 1.000008864973161,
 1.000008864973161,
 1.000008864973161,
 1.000008864973161,
 2.7276605198201134,
 4.772903292158122]

In [14]:
# Standard error of the stratified estimator for the allocations computed
# above: sum over bins of (N_h / N)^2 * var_h / n_h.
total_rows = len(df)
stratified_var = 0
for (i, group), n_from_bin in zip(df.groupby('bin'), allocations):
    bin_toxic = group['toxic']
    bin_var = bin_toxic.var()
    print(i, len(group), n_from_bin, bin_toxic.mean(), bin_var)

    # The per-bin mean is deliberately not stored in `p`: that would overwrite
    # the notebook-level prevalence with the last bin's mean.

    # approximation when the groups are very large
    # (the finite-population correction is omitted)
    stratified_var += np.square(len(group) / total_rows) * (bin_var / n_from_bin)
'stderr: ', np.sqrt(stratified_var)

(0.0005200000000000001, 0.0018] 225612 903.0895550780031 0.0001772955339255004 0.0001772648859266667
(0.0018, 0.00198] 225607 903.0695408598968 0.0005186009299356846 0.000518334280519157
(0.00198, 0.0023] 225609 903.0775465471393 0.0008909219047112482 0.0008901321083340531
(0.0023, 0.00306] 225609 903.0775465471393 0.0015336267613437406 0.0015312815376297766
(0.00306, 0.00635] 225609 903.0775465471393 0.0032135242831624624 0.003203211742919281
(0.00635, 0.0573] 225609 903.0775465471393 0.010473872939466068 0.010364216863925926
(0.0573, 0.916] 225609 2463.2671332556174 0.04925335425448453 0.04682766891045865
(0.916, 0.997] 225610 4310.2635846179255 0.4057178316563982 0.241111941443611


('stderr: ', 0.0012082062756475468)

In [15]:
# Accumulate sum_h (N_h / N)^2 * var_h / multiplier_h. Because
# multiplier_h = n_h / n_min, the stratified variance equals
# numerator / n_min, so the next cell can solve for the smallest per-bin
# sample size that reaches a target variance.
numerator = 0
for (i, group), multip_h in zip(df.groupby('bin'), multipliers):
    print(i, len(group), multip_h, group['toxic'].mean(), group['toxic'].var())
    
    # approximation when the groups are very large
    # (no finite-population correction is applied)
    numerator += np.square(len(group) / len(df)) * (group['toxic'].var() / multip_h)

(0.0005200000000000001, 0.0018] 225612 1.000022162432903 0.0001772955339255004 0.0001772648859266667
(0.0018, 0.00198] 225607 1.0 0.0005186009299356846 0.000518334280519157
(0.00198, 0.0023] 225609 1.000008864973161 0.0008909219047112482 0.0008901321083340531
(0.0023, 0.00306] 225609 1.000008864973161 0.0015336267613437406 0.0015312815376297766
(0.00306, 0.00635] 225609 1.000008864973161 0.0032135242831624624 0.003203211742919281
(0.00635, 0.0573] 225609 1.000008864973161 0.010473872939466068 0.010364216863925926
(0.0573, 0.916] 225609 2.7276605198201134 0.04925335425448453 0.04682766891045865
(0.916, 0.997] 225610 4.772903292158122 0.4057178316563982 0.241111941443611


In [16]:
# Total sample size needed so that the 95% confidence interval half-width is
# within 20% / 10% / 5% of the prevalence, keeping the allocation ratios
# stored in `multipliers`.

# Take the prevalence from the data rather than a pasted constant, so it
# cannot go stale when the dataset or labels change.
p = df['toxic'].mean()
print(p, end=' ')

# The z-score depends only on alpha, so compute it once.
alpha = 0.05
z_statistic = stats.norm.ppf(1 - (alpha / 2))

for within in [0.2, 0.1, 0.05]:
    # Target half-width of the interval, relative to the prevalence.
    desired_ci = p * within
    desired_var = np.square(desired_ci / z_statistic)

    # variance = numerator / n_min  =>  n_min = numerator / desired_var
    minimum = numerator / desired_var
    n = np.sum([minimum * multip for multip in multipliers])
    print(int(n+1), end=' ')

0.05897253769515213 492 1966 7864 

In [17]:
# Materialise the (bin, group) pairs once so the simulation below can iterate
# over them repeatedly without re-running the groupby.
cached_groupby = list(df.groupby('bin'))

In [18]:
# Monte Carlo estimate of the total sample size needed when the allocation is
# driven by a random pilot sample: for each target precision, repeat the pilot
# draw many times and report the mean (rounded up) and the spread of the
# resulting sample sizes.
# NOTE(review): this pilot forces only one positive label (pilot_size + 1
# draws), whereas the "Allocation" cell forces a positive and a negative
# (pilot_size + 2 draws) -- confirm the difference is intended.

# Take the prevalence from the data rather than a pasted constant.
p = df['toxic'].mean()
print(p, end=' ')

# Quantities that do not change between repetitions are computed once.
alpha = 0.05
z_statistic = stats.norm.ppf(1 - (alpha / 2))
n_trials = 1000

# Per bin: size N_h, squared population weight (N_h / N)^2, variance of the
# labels, and the labels themselves for pilot sampling. Computing the variance
# here avoids rescanning every bin in every repetition.
strata_stats = [
    (len(group), np.square(len(group) / len(df)), group['toxic'].var(), group['toxic'])
    for bin_label, group in cached_groupby
]

# Seeded generator: one stream shared by all draws, so repetitions differ from
# each other but the whole simulation is reproducible on re-run.
pilot_rng = np.random.default_rng(0)

for within in [0.2, 0.1, 0.05]:
    desired_ci = p * within
    desired_var = np.square(desired_ci / z_statistic)

    ns = []
    for trial in tqdm(range(0, n_trials)):
        sizes_sigmas = []
        for n_h, weight_sq, var_h, bin_toxic in strata_stats:
            pilot = np.array(bin_toxic.sample(pilot_size+1, random_state=pilot_rng))
            # Overwrite the last draw with a positive label so a pilot of all
            # zeros does not yield sigma == 0 (and a zero allocation).
            pilot[-1] = 1

            sizes_sigmas.append((n_h, np.std(pilot)))

        # Neyman allocation expressed as proportions of the total sample.
        denominator = np.sum([ n_h * sigma_h for n_h, sigma_h in sizes_sigmas ])
        allocations = [ n_h * sigma_h / denominator for n_h, sigma_h in sizes_sigmas ]

        numerator = 0
        for (n_h, weight_sq, var_h, bin_toxic), multip_h in zip(strata_stats, allocations):
            # approximation when the groups are very large
            numerator += weight_sq * (var_h / multip_h)

        # variance = numerator / n_total  =>  n_total = numerator / desired_var
        minimum = numerator / desired_var
        # Every bin is sampled at least `pilot_size` times.
        per_bin = [ max(pilot_size, minimum * multip) for multip in allocations ]
        n = np.sum(per_bin)
        ns.append(n)
    print(int(np.mean(ns)+1), np.std(ns), end=' ')

0.05897253769515213 

  0%|          | 0/1000 [00:00<?, ?it/s]

571 15.260134694649501 

  0%|          | 0/1000 [00:00<?, ?it/s]

2230 78.78822477266212 

  0%|          | 0/1000 [00:00<?, ?it/s]

8921 317.4244427512398 