In [None]:
import os
import sys

# Assumes the notebook's working directory is one level below the project root.
project_root = os.path.dirname(os.getcwd())


def ensure_on_sys_path(path: str) -> None:
    """Append ``path`` to ``sys.path`` unless it is already present.

    Keeps this cell idempotent: re-running it (routine in a notebook) does not
    keep growing ``sys.path`` with duplicate entries.
    """
    if path not in sys.path:
        sys.path.append(path)


# Make the project's local packages importable from the notebook kernel.
ensure_on_sys_path(f"{project_root}/src")
ensure_on_sys_path(f"{project_root}/third_party")

from config import gpt2_cfg as cfg


In [None]:
import ray

# Start Ray once per kernel; the guard keeps this cell safe to re-run.
if not ray.is_initialized():
    ray.init(
        runtime_env={
            "env_vars": {
                # Ray documents braced `${VAR}` references in env_vars as expanding to the
                # value already set on the node, so this appends <project_root>/src to it.
                "PYTHONPATH": "${PYTHONPATH}:" + cfg.project_root + "/src",
            },
            # Upload the project directory to the workers, skipping VCS data,
            # bytecode caches and large artifact directories.
            "working_dir": cfg.project_root,
            "excludes": [
                "/bazel-*",
                ".git",
                "*.pyc",
                "/__pycache__",
                "/outputs",
                "/model",
            ],
        },
        # Private Ray option: fixed port for exporting metrics.
        _metrics_export_port=8080,
    )

# Convenience settings for debugging Ray Data pipelines.
ray.data.DataContext.get_current().execution_options.verbose_progress = True
# Set this on the active context instance (like the line above), not on the
# DataContext class, so the running session actually picks it up.
ray.data.DataContext.get_current().log_internal_stack_trace_to_stdout = True

In [None]:
from pathlib import Path
# Resolve each dataset entry from the config to a filesystem path, then
# expose the collection of paths as a Ray Dataset.
data_sources = list(map(lambda entry: Path(entry["path"]), cfg["dataset"]))
text_document_paths = ray.data.from_items(data_sources)

In [None]:
from document_processor import TextDocumentProcessor
# One processor per split: both read the same source documents and differ
# only in the `section` they are constructed with.
train_text_document_processor, validate_text_document_processor = (
    TextDocumentProcessor(section="train"),
    TextDocumentProcessor(section="validate"),
)

train_texts, validate_texts = (
    text_document_paths.map(train_text_document_processor),
    text_document_paths.map(validate_text_document_processor),
)


In [None]:
from token_processor import TikTokenizer
# A single tokenizer instance is applied to both the train and validate texts.
tokenizer = TikTokenizer()
train_tokens, validate_tokens = (
    train_texts.map(tokenizer),
    validate_texts.map(tokenizer),
)

In [None]:
from chunk_processor import ChunkProcessor
# flat_map lets each tokenized row expand into zero or more chunk rows;
# the same chunker is shared by both splits.
chunk_processor = ChunkProcessor()
train_chunked_tokens, validate_chunked_tokens = (
    train_tokens.flat_map(chunk_processor),
    validate_tokens.flat_map(chunk_processor),
)

In [None]:
import tempfile

import torch
from torcheval.metrics.text import Perplexity

import ray
from model.GPT import GPT
from utility import save_checkpoint, resume_checkpoint



# Configuration passed to every Ray Train worker: the GPT-2 "124M" architecture
# settings first, followed by the training-loop settings from "ray_train".
train_loop_config = {
    **{
        key: cfg["124M"][key]
        for key in (
            "vocab_size",
            "dimension_embedding",
            "block_size",
            "n_layers",
            "num_header",
            "drop_rate",
            "qkv_bias",
        )
    },
    **{
        key: cfg["ray_train"][key]
        for key in (
            "check_frequency",
            "batch_size_per_worker",
            "epoch_start",
            "num_epoch_per_worker",
        )
    },
}

def train_loop_per_worker(config):
    """Train the GPT model on a single Ray Train worker.

    Args:
        config: the ``train_loop_config`` dict. Model keys: ``vocab_size``,
            ``dimension_embedding``, ``block_size``, ``n_layers``, ``num_header``,
            ``drop_rate``, ``qkv_bias``. Loop keys: ``check_frequency`` (validate,
            report and checkpoint on epochs divisible by it), ``batch_size_per_worker``,
            ``epoch_start`` (first epoch when no checkpoint is available) and
            ``num_epoch_per_worker`` (exclusive upper bound of the epoch range).

    Each worker trains on its shard of the "train" dataset. On every checked epoch
    it evaluates on its "validate" shard and calls ``ray.train.report`` with the
    mean train/validation loss and the perplexity; the rank-0 worker also attaches
    a checkpoint written by ``save_checkpoint``.
    """
    vocab_size = config["vocab_size"]
    dimension_embedding = config["dimension_embedding"]
    block_size = config["block_size"]
    n_layers = config["n_layers"]
    num_header = config["num_header"]
    drop_rate = config["drop_rate"]
    qkv_bias = config["qkv_bias"]
    check_frequency = config["check_frequency"]
    batch_size_per_worker = config["batch_size_per_worker"]
    epoch_start = config["epoch_start"]
    num_epoch_per_worker = config["num_epoch_per_worker"]

    device = ray.train.torch.get_device()

    # GPT model, prepared by Ray Train for distributed training.
    model = GPT(vocab_size, dimension_embedding, block_size, n_layers, num_header, drop_rate, qkv_bias)
    model = ray.train.torch.prepare_model(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

    # ====== Resume training state from the checkpoint. ======
    # The configured epoch_start applies unless Ray provides a checkpoint to resume from.
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # NOTE(review): if resume_checkpoint returns the epoch that was *saved* rather than
        # the next one, that epoch is trained again after a restart — confirm.
        epoch_start = resume_checkpoint(model, optimizer, checkpoint)

    loss_function = torch.nn.CrossEntropyLoss()
    # Keep the metric state on the same device as the model outputs.
    metric = Perplexity(device=device)

    train_data_shard = ray.train.get_dataset_shard("train")
    validate_data_shard = ray.train.get_dataset_shard("validate")

    def compute_loss(batch):
        """Forward one batch through the model; return (logits, target_ids, loss)."""
        logits = model(batch["input_ids"])
        target_ids = batch["target_ids"]
        # CrossEntropyLoss expects (N, C) logits and (N,) class indices. The logits are
        # 3-D (they are fed to Perplexity.update, which requires (batch, seq, vocab)),
        # so merge only the batch and sequence dims and keep the vocab dim as classes.
        loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
        return logits, target_ids, loss

    for epoch in range(epoch_start, num_epoch_per_worker):
        # ---- training ----
        model.train()
        train_loss_total, train_batches = 0.0, 0
        # Fresh iterator over this worker's shard for every epoch.
        for batch in train_data_shard.iter_torch_batches(batch_size=batch_size_per_worker):
            _, _, loss = compute_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_total += loss.item()  # only for reporting
            train_batches += 1

        if epoch % check_frequency != 0:
            continue

        # ---- validation ----
        model.eval()
        validate_loss_total, validate_batches = 0.0, 0
        perplexity = 0.0
        with torch.no_grad():
            for batch in validate_data_shard.iter_torch_batches(batch_size=1):
                logits, target_ids, loss = compute_loss(batch)
                validate_loss_total += loss.item()  # only for reporting
                validate_batches += 1
                metric.update(logits, target_ids)
            if validate_batches:
                perplexity = metric.compute().item()
        metric.reset()

        # Losses are averaged over all batches of this worker's shard for the epoch.
        metrics = {
            "epoch": epoch,
            "train_loss": train_loss_total / train_batches if train_batches else 0.0,
            "validate_loss": validate_loss_total / validate_batches if validate_batches else 0.0,
            "perplexity": perplexity,
        }

        # ---- report + checkpoint ----
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            # In standard DDP training, where the model is the same across all ranks,
            # only the global rank 0 worker needs to save and report the checkpoint.
            if ray.train.get_context().get_world_rank() == 0:
                # === Make sure to save all state needed for resuming training ===
                save_checkpoint(model, optimizer, epoch, temp_checkpoint_dir)
                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            # Every rank calls report(). It must run inside this `with` block: the
            # temporary directory is deleted as soon as the block exits.
            ray.train.report(metrics=metrics, checkpoint=checkpoint)

# Launch distributed training: every worker runs `train_loop_per_worker` on its own
# shard of the "train" and "validate" datasets.
ray_train_cfg = cfg["ray_train"]
use_gpu = ray_train_cfg["use_gpu"]

# Only request GPUs when GPU training is enabled: Ray's ScalingConfig raises a
# ValueError when resources_per_worker asks for GPUs while use_gpu is False.
resources_per_worker = {"CPU": ray_train_cfg["num_cpus_per_worker"]}
if use_gpu:
    resources_per_worker["GPU"] = ray_train_cfg["num_gpus_per_worker"]

trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    datasets={
        "train": train_chunked_tokens,
        "validate": validate_chunked_tokens,
    },
    scaling_config=ray.train.ScalingConfig(
        num_workers=ray_train_cfg["num_workers"],
        use_gpu=use_gpu,
        resources_per_worker=resources_per_worker,
    ),
)
result = trainer.fit()

# Latest metrics reported by the training run (rendered as the cell output).
result.metrics
