In [None]:
import os
import sys

# Make the project's `src/` and `third_party/` packages importable from this notebook.
# NOTE(review): assumes the kernel's working directory is a direct subdirectory of
# the project root (e.g. notebooks/) — confirm if the notebook is moved.
project_root = os.path.dirname(os.getcwd())
for _subdir in ("src", "third_party"):
    _path = os.path.join(project_root, _subdir)
    # Guard against duplicates so re-running this cell keeps sys.path clean.
    if _path not in sys.path:
        sys.path.append(_path)

from config import gpt2_cfg as cfg


In [None]:
import ray

# Start (or reuse) a Ray session that ships the project directory to workers.
if not ray.is_initialized():
    ray.init(
        runtime_env={
            "env_vars": {
                # Ray's documented `${ENV_VAR}` form appends to the cluster's
                # existing PYTHONPATH and becomes "" if it is unset; a bare
                # `$PYTHONPATH` would remain as a literal path entry in that case.
                "PYTHONPATH": "${PYTHONPATH}:" + cfg.project_root + "/src",
            },
            "working_dir": cfg.project_root,
            # Keep build artifacts, VCS data, caches, and large outputs out of
            # the uploaded working directory.
            "excludes": [
                "/bazel-*",
                ".git",
                "*.pyc",
                "/__pycache__",
                "/output",
                "/model",
            ],
        },
        _metrics_export_port=8080,
    )

# Convenience flags for debugging Ray Data pipelines.
# Both must be set on the *current* DataContext instance: assigning to the
# DataContext class attribute does not change the instance get_current() returns.
data_context = ray.data.DataContext.get_current()
data_context.execution_options.verbose_progress = True
data_context.log_internal_stack_trace_to_stdout = True

In [None]:
from pathlib import Path
# Turn every configured dataset entry's "path" into a pathlib.Path and load the
# resulting list into Ray Data as an in-memory dataset.
# NOTE(review): this cell indexes cfg["dataset"] while other cells use attribute
# access (cfg.batch_size, cfg.project_root); presumably cfg supports both — confirm.
data_sources = list(map(Path, (entry["path"] for entry in cfg["dataset"])))
text_document_paths = ray.data.from_items(data_sources)

In [None]:
from document_processor import TextDocumentProcessor
# Build one processor instance and hand it to Dataset.map, which applies it to
# each path row to produce the document-text dataset.
text_document_processor = TextDocumentProcessor()
texts = text_document_paths.map(text_document_processor)

In [None]:
from token_processor import TikTokenizer
# Tokenize each text row with a single TikTokenizer instance passed to Dataset.map.
tokenizer = TikTokenizer()
tokens = texts.map(tokenizer)

In [None]:
from chunk_processor import ChunkProcessor
# Split token rows into chunks via flat_map, so one input row can expand into
# several output rows (presumably one per chunk — confirm against ChunkProcessor).
chunk_processor = ChunkProcessor()
chunked_tokens = tokens.flat_map(chunk_processor)

In [None]:
# Drain the full pipeline as torch batches, but only print a few of them:
# printing every batch of the whole dataset floods the notebook output.
MAX_PREVIEW_BATCHES = 3  # number of batches to print; raise for a closer look

num_batches = 0
for batch in chunked_tokens.iter_torch_batches(batch_size=cfg.batch_size):
    if num_batches < MAX_PREVIEW_BATCHES:
        print(batch)
    num_batches += 1
print(f"total batches: {num_batches}")