In [1]:
# Load variables from a .env file into os.environ via the dotenv IPython extension.
%load_ext dotenv
%dotenv

import os

# Move to the project root named by PROJECT_PATH; falls back to "." when the variable is unset or empty.
%cd {os.getenv("PROJECT_PATH") or "."}

# NOTE(review): autoreload mode 1 only reloads modules registered with `%aimport`, and no
# `%aimport` appears in the cells visible here — confirm this is intended (mode 2 reloads all modules).
%load_ext autoreload
%autoreload 1

from IPython.display import display


/home/aris/projects/evagpt


In [2]:
# NOTE(review): torch, plt, sys, Path, tqdm, timer and pickle are not referenced in the cells
# visible here — prune them if nothing later in the notebook uses them.
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import sys
from pathlib import Path
from absl import logging
from tqdm.notebook import tqdm
from timeit import default_timer as timer
import pickle

# Show absl log messages at INFO level and above.
logging.set_verbosity(logging.INFO)


In [3]:
from pandarallel import pandarallel

# One pandarallel worker per logical CPU reported by os.cpu_count(); progress bars on, logs silenced.
pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=os.cpu_count())


In [4]:
def show_df(df: pd.DataFrame, n: int = 5):
    """Display the first ``n`` rows of ``df`` as rich output, then print its shape.

    Args:
        df: frame to preview.
        n: number of leading rows to show; defaults to 5, matching ``DataFrame.head``.
    """
    display(df.head(n))
    print(df.shape)


In [12]:
from datasets import load_dataset
from transformers import PreTrainedTokenizerBase
import multiprocessing as mp
from itertools import chain


def get_datasets(
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
    dataset_name: str,
    block_size: int = 1024,
    text_column: str = "text",
):
    """Load a Hugging Face dataset and turn it into fixed-length causal-LM blocks.

    Every split goes through two batched ``map`` passes:

    1. Tokenize each entry of ``text_column`` after appending ``"\\n\\n"`` to it.
    2. Concatenate the tokenized entries of each map batch and cut the result into
       consecutive chunks of exactly ``block_size`` tokens. The tail of each batch
       that does not fill a whole chunk is dropped. ``labels`` is a copy of
       ``input_ids``.

    Args:
        tokenizer: tokenizer applied to the raw text.
        dataset_path: dataset path/name passed to ``load_dataset`` (e.g. ``"wikitext"``).
        dataset_name: dataset configuration name (e.g. ``"wikitext-2-raw-v1"``).
        block_size: number of tokens per output example; must be positive.
        text_column: name of the raw-text column to tokenize.

    Returns:
        The processed dataset object returned by the chained ``map`` calls.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    def preprocess_function(examples):
        # Append a blank-line separator so entries remain delimited once concatenated.
        return tokenizer([text + "\n\n" for text in examples[text_column]])

    def group_texts(examples):
        concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
        # All tokenizer outputs have the same length; an empty batch yields no chunks.
        total_length = len(next(iter(concatenated.values()), []))
        # Drop the remainder so every chunk has exactly block_size tokens.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    dataset_raw = load_dataset(dataset_path, dataset_name)
    # Drop every raw column (not only the text column) so group_texts only sees
    # tokenizer outputs, which are all lists that can be concatenated.
    raw_columns = next(iter(dataset_raw.values())).column_names
    dataset = dataset_raw.map(
        preprocess_function,
        batched=True,
        num_proc=mp.cpu_count(),
        load_from_cache_file=True,
        remove_columns=raw_columns,
        desc=f"Tokenizing {dataset_name} dataset",
    ).map(
        group_texts,
        batched=True,
        num_proc=mp.cpu_count(),
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )

    return dataset

In [13]:
from transformers import AutoTokenizer

# GPT-2 fast tokenizer; WikiText-2 (raw) is grouped into 1024-token blocks.
# NOTE(review): `datasets` shares its name with the `datasets` library; a more specific
# name (e.g. `lm_datasets`) would be clearer, but later cells reference it as-is.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
datasets = get_datasets(
    tokenizer=tokenizer,
    dataset_path="wikitext",
    dataset_name="wikitext-2-raw-v1",
    block_size=1024,
)

In [None]:

import math
import numpy as np
from datasets import Dataset
import jax.random as jrandom

def data_loader(rng: jrandom.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Yield batches of `batch_size` consecutive rows from `dataset`, each as a dict of numpy arrays.

    Every batch except possibly the last holds exactly `batch_size` rows. If `drop_last` is `True`,
    the trailing incomplete batch is skipped. Otherwise it is yielded and holds 1 to `batch_size` rows.
    If `shuffle` is `True`, rows are visited in the order given by `jrandom.permutation(rng, len(dataset))`;
    `rng` is only used in that case.

    Raises:
        ValueError: if `batch_size` is not positive (raised when the generator is first advanced).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_examples = len(dataset)
    if num_examples == 0:
        return

    if shuffle:
        indices = np.asarray(jrandom.permutation(rng, num_examples))
    else:
        indices = np.arange(num_examples)

    # Slice fixed-size windows (np.array_split would spread rows evenly across batches instead).
    # With drop_last, stop at the last full batch boundary.
    stop = (num_examples // batch_size) * batch_size if drop_last else num_examples

    for start in range(0, stop, batch_size):
        batch = dataset[indices[start : start + batch_size]]
        yield {k: np.array(v) for k, v in batch.items()}



In [None]:
from flax.training import common_utils, jax_utils

# Expose the tokenized splits with the "jax" format and take the training split.
lm_ds = datasets.with_format("jax")
lm_train_dataset = lm_ds["train"]

for epoch in range(10):
    # Reshuffle every epoch, seeded with the epoch number so the order is reproducible.
    ds_epoch = lm_train_dataset.shuffle(seed=epoch)

    def data_stream():
        # NOTE(review): iterating a Dataset presumably yields one example at a time, while
        # common_utils.shard splits a leading batch axis across local devices — confirm this
        # receives real batches (e.g. from `data_loader` with a batch size divisible by the device count).
        for batch in ds_epoch:
            yield common_utils.shard(batch)
            
    # NOTE(review): this call is unterminated as saved, so the cell does not parse — TODO complete it.
    train_iter = jax_utils.prefetch_to_device(

{'input_ids': array([ 628,  796,  569, ..., 1998, 2173,  389]),
 'attention_mask': array([1, 1, 1, ..., 1, 1, 1]),
 'labels': array([ 628,  796,  569, ..., 1998, 2173,  389])}