In [None]:
import os
import sys

project_root = os.path.dirname(os.getcwd())
sys.path.append(f"{project_root}/src")
sys.path.append(f"{project_root}/third_party")

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_COLOR_PREFIX"] = "0"

from config import gpt2_cfg as cfg


In [None]:
import ray
if ray.is_initialized():
    ray.shutdown()

ray.init(
        runtime_env={
            "env_vars": {          
                "PYTHONPATH": "$PYTHONPATH:" + cfg.project_root + "/src",
                "RAY_DATA_VERBOSE_PROGRESS": "1",
            },
            "working_dir": cfg.project_root,
            "excludes": [
                "/bazel-*",
                ".git",
                "*.pyc",
                "/__pycache__",
                "/outputs",
                "/model",
            ],
        },
        ignore_reinit_error=True,
    )

# convience for debugging
ray.data.DataContext.get_current().execution_options.verbose_progress = False
ray.data.DataContext.log_internal_stack_trace_to_stdout = True

In [None]:
from pathlib import Path
data_sources = [ Path(item["path"]) for item in cfg["dataset"]]
text_document_paths = ray.data.from_items(data_sources)

In [None]:
from document_processor import TextDocumentProcessor
train_text_document_processor = TextDocumentProcessor(section="train")
train_texts=text_document_paths.map(train_text_document_processor)

validate_text_document_processor = TextDocumentProcessor(section="validate")
validate_texts=text_document_paths.map(validate_text_document_processor)


In [None]:
from token_processor import TikTokenizer
tokenizer = TikTokenizer()
train_tokens = train_texts.map(tokenizer)
validate_tokens = validate_texts.map(tokenizer)

In [None]:
from chunk_processor import ChunkProcessor

chunk_processor = ChunkProcessor(max_length=cfg["124M"]["block_size"],stride=cfg["124M"]["stride"])
train_chunked_tokens = train_tokens.flat_map(chunk_processor)
validate_chunked_tokens = validate_tokens.flat_map(chunk_processor)

In [41]:
iter= train_chunked_tokens.iterator()

iteable= iter.iter_torch_batches(batch_size =2)

print(next(iteable.iterator_gen()))

print(next(iteable.iterator_gen()))









2024-08-07 15:26:18,552	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-08-07_15-09-39_319166_12085/logs/ray-data
2024-08-07 15:26:18,554	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor)]


Running 0: 0 bundle [00:00, ? bundle/s]

2024-08-07 15:26:19,364	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-08-07_15-09-39_319166_12085/logs/ray-data
2024-08-07 15:26:19,367	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor)]


{'input_ids': tensor([[ 5962, 22307,    25,  ...,  1842,   484,  6842],
        [  514,    13,   198,  ...,   275,  1000,    13]], device='cuda:0'), 'target_ids': tensor([[22307,    25,   198,  ...,   484,  6842,   514],
        [   13,   198,   198,  ...,  1000,    13,   198]], device='cuda:0')}


Running 0: 0 bundle [00:00, ? bundle/s]

{'input_ids': tensor([[ 5962, 22307,    25,  ...,  1842,   484,  6842],
        [  514,    13,   198,  ...,   275,  1000,    13]], device='cuda:0'), 'target_ids': tensor([[22307,    25,   198,  ...,   484,  6842,   514],
        [   13,   198,   198,  ...,  1000,    13,   198]], device='cuda:0')}


In [None]:
import torch
from torchmetrics.text import Perplexity

import ray
import ray.train

from model.GPT import GPT
from utility import save_checkpoint, resume_checkpoint
from text_generator import TextGenerator


def train_loop_per_worker(config):
    vocab_size = config["vocab_size"]
    dimension_embedding = config["dimension_embedding"]
    block_size = config["block_size"]
    num_layers = config["num_layers"]
    num_headers = config["num_headers"]
    drop_rate = config["drop_rate"]
    qkv_bias = config["qkv_bias"]
    check_frequency = config["check_frequency"]
    batch_size_per_worker = config["batch_size_per_worker"]
    num_epoch_per_worker = config["num_epoch_per_worker"]
    resume_training = config["resume_training"]
    best_checkpoint_dir = config["best_checkpoint_dir"]
    start_context = config["start_context"]

    # GPT model
    model = GPT(
        vocab_size,
        dimension_embedding,
        block_size,
        num_layers,
        num_headers,
        drop_rate,
        qkv_bias,
    )
    model = ray.train.torch.prepare_model(model)

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

    # ====== Resume training state from the checkpoint. ======
    epoch_start = 0
    best_perplexity = 1000000.0
    best_epoch = 0

    if resume_training:
        if os.path.exists(best_checkpoint_dir):
            checkpoint = ray.train.Checkpoint.from_directory(best_checkpoint_dir)
        else:
            checkpoint = None
        if checkpoint:
            best_epoch, best_perplexity = resume_checkpoint(model, optimizer, checkpoint)
            epoch_start = best_epoch
            print(
                f"Resumed training from best_epoch {best_epoch},best_perplexity {best_perplexity}"
            )
        else:
            print(f"Checkpoint not found, starting from epoch 0")

    # loss function
    loss_function = torch.nn.CrossEntropyLoss()

    rank = ray.train.get_context().get_world_rank()
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # metrics
    metric = Perplexity().to(device)

    # data
    train_data_shard = ray.train.get_dataset_shard("train")
    validate_data_shard = ray.train.get_dataset_shard("validate")

    report_metrics = {
        "epoch": 0,
        "train_loss": 0.0,
        "validate_loss": 0.0,
        "perplexity": 0.0,
        "best_epoch": best_epoch,
        "best_perplexity": best_perplexity,
    }
 
 
    text_generator = TextGenerator(model,device=device)
    tokenizer = TikTokenizer()
    for epoch in range(epoch_start + 1, num_epoch_per_worker + 1):
        
        print(f"epoch:{epoch}")
        
        model.train()

        report_metrics["epoch"] = epoch

        train_loss = 0
        batch_count = 0
        for batch in train_data_shard.iter_torch_batches(
            batch_size=batch_size_per_worker,
            drop_last=True,
            local_shuffle_buffer_size=1000,
        ):
            
            batch_count += 1
            input_ids = batch["input_ids"]
            logits = model(input_ids)
            
            target_ids = batch["target_ids"]
            loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
            train_loss += loss.item()  # only for reporting
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = train_loss / batch_count

        report_metrics["train_loss"] = train_loss

        validate_loss = 0
        perplexity = 0

        if epoch % check_frequency == 0:

            model.eval()
 

            with torch.no_grad():
                batch_count = 0
                for batch in validate_data_shard.iter_torch_batches(
                    batch_size=1,
                    drop_last=False,
                ):
                    batch_count += 1
                    input_ids = batch["input_ids"]
                    logits = model(input_ids)
                    target_ids = batch["target_ids"]
                    loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
                    validate_loss += loss.item()  # only for reporting
                    metric.update(logits, target_ids)

            validate_loss = validate_loss / batch_count
            perplexity = metric.compute().item()
            metric.reset()

            report_metrics["validate_loss"] = validate_loss
            report_metrics["perplexity"] = perplexity


            # In standard DDP training, where the model is the same across all ranks,
            # only the global rank 0 worker needs to save and report the checkpoint
            if ray.train.get_context().get_world_rank() == 0:
                if perplexity < best_perplexity:
                    best_perplexity = perplexity
                    best_epoch = epoch
                    
                    report_metrics["best_epoch"] = best_epoch
                    report_metrics["best_perplexity"] = best_perplexity


                    # create the best_checkpoint_dir if it does not exist
                    # if not os.path.exists(best_checkpoint_dir):
                    #     os.makedirs(best_checkpoint_dir)

                    # save_checkpoint(
                    #     model,
                    #     optimizer,
                    #     epoch,
                    #     perplexity,
                    #     best_checkpoint_dir,
                    # )

            ray.train.report(metrics=report_metrics)
                
        decoded = text_generator(tokenizer.encode(start_context), max_new_tokens=50, context_size=block_size)
        print(f"\n epoch{epoch}:{decoded}")

In [None]:
from ray.train.torch import TorchTrainer
from ray.train import Result
from ray import train

train_loop_config = {
    "vocab_size": cfg["124M"]["vocab_size"],
    "dimension_embedding": cfg["124M"]["dimension_embedding"],
    "block_size": cfg["124M"]["block_size"],
    "num_layers": cfg["124M"]["num_layers"],
    "num_headers": cfg["124M"]["num_headers"],
    "drop_rate": cfg["124M"]["drop_rate"],
    "qkv_bias": cfg["124M"]["qkv_bias"],
    "check_frequency": cfg["ray_train"]["check_frequency"],
    "batch_size_per_worker": cfg["ray_train"]["batch_size_per_worker"],
    "num_epoch_per_worker": cfg["ray_train"]["num_epoch_per_worker"],
    "resume_training":cfg["ray_train"]["resume_training"],
    "best_checkpoint_dir":cfg["ray_train"]["best_checkpoint_dir"],
    "start_context":cfg["ray_train"]["start_context"],
}


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    datasets={
        "train": train_chunked_tokens,
        "validate": validate_chunked_tokens,
    },
    scaling_config=ray.train.ScalingConfig(
        num_workers=cfg["ray_train"]["num_workers"],
        use_gpu=cfg["ray_train"]["use_gpu"],
        resources_per_worker={
            "CPU": cfg["ray_train"]["num_cpus_per_worker"],
            "GPU": cfg["ray_train"]["num_gpus_per_worker"],
        },
    ),
    run_config=train.RunConfig(
        storage_path=cfg["ray_train"]["storage_path"],
        name=cfg["ray_train"]["name"],
    ),
    
)
result: Result = trainer.fit()
print(result.metrics)