In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
import tqdm.notebook  # explicit: `import tqdm` alone does not guarantee the `notebook` submodule is loaded

from src import datasets_loader

In [2]:
def get_loader(
    max_raw_len,
    max_token_len,
    batch_size,
    *,
    dataset_name="bigcode/the-stack-march-sample",
    path_to_cache="/mnt/colab_public/datasets/joao/bigcode/the-stack-march-sample",
    tokenizer_path="bigcode/tokenizer-the-stack-march-sample",
):
    """Build an unshuffled DataLoader over the training split of the dataset.

    Args:
        max_raw_len: forwarded to ``datasets_loader.get_dataset`` as
            ``maximum_raw_length``.
        max_token_len: forwarded to ``datasets_loader.Collator`` as
            ``maximum_length``.
        batch_size: number of examples per batch.
        dataset_name: dataset identifier passed to ``get_dataset``.
        path_to_cache: cache directory passed to ``get_dataset``. The default is
            an absolute path on the original author's machine; override it elsewhere.
        tokenizer_path: tokenizer identifier passed to the ``Collator``.

    Returns:
        torch.utils.data.DataLoader: iterates the ``"train"`` split in order
        (``shuffle=False``) and keeps the final, possibly smaller, batch
        (``drop_last=False``).
    """
    train_data = datasets_loader.get_dataset(
        dataset_name=dataset_name,
        path_to_cache=path_to_cache,
        split="train",
        maximum_raw_length=max_raw_len,
    )

    # The masking settings are fixed as in the original experiment. This
    # notebook only inspects padding, so they do not affect the measurement
    # as long as batch[0] is the unmasked input ids (TODO confirm in Collator).
    collate_fn = datasets_loader.Collator(
        tokenizer_path=tokenizer_path,
        maximum_length=max_token_len,
        mlm_masking_probability=0.15,
        contrastive_masking_probability=0.3,
        ignore_contrastive_loss_data=True,
    )

    return torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
    )

In [3]:
def get_token_fraction_series(max_raw_len, max_token_len, batch_size):
    """Return the fraction of padding positions in each batch of the training loader.

    Args:
        max_raw_len: forwarded to ``get_loader`` (raw-length cap for the dataset).
        max_token_len: forwarded to ``get_loader`` (collator's maximum token length).
        batch_size: number of examples per batch.

    Returns:
        list[float]: one entry per non-empty batch. Each entry is the number of
        positions in ``batch[0]`` equal to the collator's ``pad_token_id``,
        divided by the total number of positions in ``batch[0]``.
    """
    loader = get_loader(max_raw_len, max_token_len, batch_size)
    pad_token_id = loader.collate_fn.pad_token_id  # loop-invariant

    padding_ratio_series = []
    for batch in loader:
        input_ids = batch[0]  # assumes the collator returns input ids first (TODO confirm)
        total = input_ids.numel()
        if total == 0:
            # An empty batch would give 0/0; there is nothing to measure.
            continue
        # Sum the boolean mask directly to get an exact integer count. Summing a
        # float32 mask cannot represent every count above 2**24, and a batch
        # of 8000 x 4096 holds ~32.8M positions.
        padding_count = int((input_ids == pad_token_id).sum().item())
        padding_ratio_series.append(padding_count / total)

    return padding_ratio_series

In [4]:
# Sweep grid: every max_token_len is paired with every max_raw_len.
# NOTE(review): 784 is unusual next to 512/1024/2048 — possibly meant 768; confirm before re-running.
max_raw_len_list = [1_000, 5_000, 10_000, 20_000, 50_000]
max_token_len_list = [512, 784, 1_024, 2_048, 4_096]
batch_size = 8_000  # examples per DataLoader batch

In [5]:
# Run the sweep. Each (max_token_len, max_raw_len) pair builds a fresh loader and
# records the per-batch padding fractions:
#   results_dict[max_token_len][max_raw_len] -> list of per-batch fractions
results_dict = {}

for max_token_len in tqdm.notebook.tqdm(max_token_len_list, desc="Token len"):
    for max_raw_len in tqdm.notebook.tqdm(max_raw_len_list, desc="Raw len"):
        padding_fraction_series = get_token_fraction_series(max_raw_len, max_token_len, batch_size)
        # setdefault creates the inner dict on first use for this max_token_len.
        results_dict.setdefault(max_token_len, {})[max_raw_len] = padding_fraction_series


Token len:   0%|          | 0/5 [00:00<?, ?it/s]

Raw len:   0%|          | 0/5 [00:00<?, ?it/s]

Found cached dataset parquet (/mnt/colab_public/datasets/joao/bigcode/the-stack-march-sample/bigcode___parquet/bigcode--the-stack-march-sample-ba0d1a1a229e8720/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /mnt/colab_public/datasets/joao/bigcode/the-stack-march-sample/bigcode___parquet/bigcode--the-stack-march-sample-ba0d1a1a229e8720/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8516320185f20da3.arrow
Found cached dataset parquet (/mnt/colab_public/datasets/joao/bigcode/the-stack-march-sample/bigcode___parquet/bigcode--the-stack-march-sample-ba0d1a1a229e8720/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /mnt/colab_public/datasets/joao/bigcode/the-stack-march-sample/bigcode___parquet/bigcode--the-stack-march-sample-ba0d1a1a229e8720/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-85e33cff9f85f34b.arrow


In [None]:
def get_confidence_interval(data_series, z=1.96):
    """Return the mean of ``data_series`` and the half-width of a CI for that mean.

    The half-width is ``z * s / sqrt(n)``, where ``s`` is the sample standard
    deviation (``ddof=1``) and ``n`` the number of values. That is the
    normal-approximation interval for the mean; the default ``z=1.96`` gives ~95%.

    Args:
        data_series: sequence of numbers (anything ``np.asarray`` can turn into
            floats, e.g. a list of floats or of 0-d CPU tensors).
        z: critical value of the standard normal distribution.

    Returns:
        tuple[float, float]: ``(mean, half_width)``. ``half_width`` is NaN when
        fewer than two values are given, because the spread cannot be estimated.
        Both are NaN for empty input.
    """
    values = np.asarray(data_series, dtype=float).ravel()
    n = values.size
    if n == 0:
        return float("nan"), float("nan")

    center = float(values.mean())
    if n < 2:
        # A single observation gives no spread estimate. Report "unknown" rather
        # than a zero-width interval that looks perfectly certain.
        return center, float("nan")

    # Standard error of the mean uses the *sample* std (ddof=1). The population
    # std (numpy's default ddof=0) underestimates it.
    half_width = z * values.std(ddof=1) / np.sqrt(n)
    return center, float(half_width)

In [None]:
# Mean padding fraction per batch vs. max raw length, one curve per max token length.
# Error bars are the CI half-widths from get_confidence_interval.
fig, ax = plt.subplots(facecolor="white")
for max_token_len in max_token_len_list:
    ci_center_series = []
    ci_half_width_series = []
    for max_raw_len in max_raw_len_list:
        center, half_width = get_confidence_interval(results_dict[max_token_len][max_raw_len])
        ci_center_series.append(center)
        ci_half_width_series.append(half_width)

    # A single errorbar call draws the line, markers and error bars in one colour.
    # Separate plot() and errorbar() calls each advance the colour cycle, so a
    # curve and its error bars came out in different colours.
    ax.errorbar(
        max_raw_len_list,
        ci_center_series,
        yerr=ci_half_width_series,
        fmt="o-",
        label=f"Max len: {max_token_len}",
    )

# These must be *calls*: assigning to plt.xlabel/plt.ylabel replaced the pyplot
# functions in the kernel and left the axes unlabelled.
ax.set_xlabel("Max input length")
ax.set_ylabel("Padding fraction")
ax.legend()
plt.show()
        