In [None]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
from tqdm import tqdm

In [None]:
# Hub model ids this notebook prepares data for; `model_id` in the
# configuration cell is validated against this list.
available_jax_models = [
    'erfanzar/FlaxMpt-7B',
    'erfanzar/FlaxMpt-1B',
    'erfanzar/FlaxFalcon',
    'erfanzar/JaxLLama',
    'erfanzar/GT-J',
]

In [None]:
# --- Identifiers (replace the placeholders before running) -------------------
model_id = '<YOUR_MODEL_ID_HERE>'  # must be one of `available_jax_models`
tokenizer_id = "<TOKENIZER_ID>"  # tokenizer used to encode the dataset
data_set_name = '<DATASET_NAME_TO_TOKENIZE>'  # dataset passed to `load_dataset`
push_to = '<HUGGINGFACE_REPO_NAME_TO_PUSH_DATASET>'  # Hub repo the result is pushed to

# --- Processing mode ---------------------------------------------------------
# use_padding + use_over : re-encode each long-enough example padded to `block_size`
# use_padding only       : push the tokenized dataset as-is
# pretrain (no padding)  : concatenate all tokens and cut into `block_size` chunks
use_padding = True
use_over = True
pretrain = False

# --- Lengths -----------------------------------------------------------------
block_size = 2048
minimum_length = block_size // 2  # padding mode keeps only examples longer than this

pre_train_pick_up = 1  # fraction of pretrain chunks to keep; between 0.0 - 1.0

In [None]:
# Fail early, with a readable message, if the configured model is unsupported
# (e.g. the placeholder was never replaced).
assert model_id in available_jax_models, (
    f'model_id must be one of {available_jax_models}, got {model_id!r}'
)

In [None]:
# Fetch the raw dataset and the tokenizer that will encode it.
data = load_dataset(data_set_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

In [None]:
def _tokenize_batch(batch):
    """Tokenize a batch of rows using the raw text in the 'news' column."""
    return tokenizer(batch['news'])


# Replace the original columns (as listed for the train split) with the
# tokenizer outputs, processing 1000 rows per batch.
data = data.map(
    _tokenize_batch,
    batched=True,
    batch_size=1000,
    remove_columns=data['train'].column_names,
)


def simple_chunk(input_ids_, attention_mask_, chunk=512, drop_last=True, pad_token_id=0):
    """Split parallel token-id / attention-mask sequences into fixed-size chunks.

    Args:
        input_ids_: Flat sequence of token ids.
        attention_mask_: Flat attention mask aligned with ``input_ids_``.
        chunk: Number of tokens per chunk; must be positive.
        drop_last: If True, a trailing chunk shorter than ``chunk`` is dropped.
            If False, it is kept and right-padded to ``chunk``.
        pad_token_id: Id used to pad ``input_ids`` in a short trailing chunk
            (default 0, the value previously hard-coded). The attention mask
            is always padded with 0.

    Returns:
        A tuple ``(input_ids, attention_mask)`` of lists of lists, every inner
        list having length ``chunk``. Empty input yields ``([], [])``.

    Raises:
        ValueError: If ``chunk`` is not positive or the two inputs differ in length.
    """
    if chunk <= 0:
        raise ValueError(f'chunk must be positive, got {chunk}')
    if len(input_ids_) != len(attention_mask_):
        raise ValueError(
            f'input_ids_ and attention_mask_ differ in length: '
            f'{len(input_ids_)} != {len(attention_mask_)}'
        )

    input_ids = []
    attention_mask = []
    for start in range(0, len(attention_mask_), chunk):
        # Slicing never raises past the end; it just returns a shorter slice,
        # so a short chunk has to be detected by its length.
        ids = list(input_ids_[start:start + chunk])
        mask = list(attention_mask_[start:start + chunk])
        missing = chunk - len(ids)
        if missing:
            # Only the final slice can come up short.
            if drop_last:
                break
            ids += [pad_token_id] * missing
            mask += [0] * missing
        input_ids.append(ids)
        attention_mask.append(mask)
    return input_ids, attention_mask


In [None]:
# Build the final dataset according to the mode flags and push it to the Hub.
# Precedence: the `use_padding` branches are checked before `pretrain`;
# if no mode applies, fail loudly.
if use_padding and use_over:
    def g_gen():
        """Yield each example longer than `minimum_length`, re-encoded to `block_size`."""
        for example in data['train']:
            input_ids = example['input_ids']
            if len(input_ids) > minimum_length:
                # NOTE(review): decode -> re-encode round trip. `decode` keeps special
                # tokens by default, and the tokenizer may add them again on encode —
                # confirm this does not duplicate e.g. BOS for the tokenizer in use.
                # `truncation=True` is required: with padding='max_length' and no
                # truncation, sequences longer than `block_size` are left as-is.
                yield tokenizer(
                    tokenizer.decode(input_ids),
                    max_length=block_size,
                    padding='max_length',
                    truncation=True,
                )

    data_set = DatasetDict({'train': Dataset.from_generator(g_gen)})
    data_set.push_to_hub(push_to)
elif use_padding and not use_over:
    # Push the tokenized (unpadded) dataset unchanged.
    data.push_to_hub(push_to)
elif pretrain:
    # Concatenate every example into one token stream, then cut it into
    # `block_size` chunks; the last chunk is kept and padded (drop_last=False).
    all_input_ids, all_attention_mask = [], []
    for example in tqdm(data['train']):
        all_input_ids.extend(example['input_ids'])
        all_attention_mask.extend(example['attention_mask'])
    chunked_ids, chunked_mask = simple_chunk(all_input_ids, all_attention_mask, block_size, False)
    # Keep only the first `pre_train_pick_up` fraction of the chunks.
    len_ = int(len(chunked_mask) * pre_train_pick_up)

    def gen():
        """Yield the selected chunks as dataset rows."""
        for ids, mask in tqdm(zip(chunked_ids[:len_], chunked_mask[:len_]), total=len_):
            yield {'input_ids': ids, 'attention_mask': mask}

    data_set = DatasetDict({'train': Dataset.from_generator(gen)})
    # Push the chunked dataset built above, not the raw tokenized `data`.
    data_set.push_to_hub(push_to)
else:
    raise ValueError(
        'No processing mode selected: set `use_padding=True` '
        'or `pretrain=True` in the configuration cell.'
    )