In [1]:
import os
import sys

# Make the project's `src` and `third_party` packages importable from this notebook.
# NOTE(review): assumes the notebook is run from a directory one level below the project root.
project_root = os.path.dirname(os.getcwd())

for _extra_path in (f"{project_root}/src", f"{project_root}/third_party"):
    # Guard against duplicate sys.path entries when this cell is re-run in the same kernel.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

from config import gpt2_cfg as cfg


In [2]:
import ray

# Shut down any previous Ray session so this cell can be re-run in the same kernel
# with a fresh runtime_env.
if ray.is_initialized():
    ray.shutdown()

ray.init(
    runtime_env={
        "env_vars": {
            # `${VAR}` is Ray's documented syntax for appending to an existing variable;
            # an unset variable expands to an empty string instead of a literal "$PYTHONPATH".
            "PYTHONPATH": "${PYTHONPATH}:" + cfg.project_root + "/src",
        },
        # Ship the project directory to the workers, minus VCS data, caches and large artefacts.
        "working_dir": cfg.project_root,
        "excludes": [
            "/bazel-*",
            ".git",
            "*.pyc",
            "/__pycache__",
            "/outputs",
            "/model",
        ],
    },
)

# Convenience settings for debugging Ray Data pipelines.
data_context = ray.data.DataContext.get_current()
data_context.execution_options.verbose_progress = True
# Set on the current context *instance*: assigning on the DataContext class does not
# update the context object that get_current() has already created.
data_context.log_internal_stack_trace_to_stdout = True

2024-08-02 09:11:51,651	INFO worker.py:1772 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
2024-08-02 09:11:51,656	INFO packaging.py:530 -- Creating a file package for local directory '/workspaces/CaiZi'.
2024-08-02 09:11:51,662	INFO packaging.py:358 -- Pushing file package 'gcs://_ray_pkg_6c71d926d037020d.zip' (0.32MiB) to Ray cluster...
2024-08-02 09:11:51,664	INFO packaging.py:371 -- Successfully pushed file package 'gcs://_ray_pkg_6c71d926d037020d.zip'.


In [3]:
from pathlib import Path
# Each cfg["dataset"] entry carries a "path"; collect those as Path objects and wrap
# them in a Ray Dataset with one row per item.
data_sources = list(Path(dataset_entry["path"]) for dataset_entry in cfg["dataset"])
text_document_paths = ray.data.from_items(data_sources)

In [4]:
from document_processor import TextDocumentProcessor
# One document processor per split; both splits are derived from the same source paths.
train_text_document_processor, validate_text_document_processor = (
    TextDocumentProcessor(section=split_name) for split_name in ("train", "validate")
)

train_texts, validate_texts = (
    text_document_paths.map(split_processor)
    for split_processor in (train_text_document_processor, validate_text_document_processor)
)


In [5]:
from token_processor import TikTokenizer
# A single tokenizer instance is mapped over both the train and validate text datasets.
tokenizer = TikTokenizer()
train_tokens, validate_tokens = (
    split_texts.map(tokenizer) for split_texts in (train_texts, validate_texts)
)

In [6]:
from chunk_processor import ChunkProcessor

# Split token streams into fixed-length windows sized to the model's block size.
# flat_map is used because one tokenized row can yield many chunks.
chunk_processor = ChunkProcessor(
    max_length=cfg["124M"]["block_size"],
    stride=cfg["124M"]["stride"],
)
train_chunked_tokens, validate_chunked_tokens = (
    split_tokens.flat_map(chunk_processor) for split_tokens in (train_tokens, validate_tokens)
)


In [7]:
import tempfile


import ray.train
import torch
from torchmetrics.text import Perplexity

import ray
from ray.train import Checkpoint

from model.GPT import GPT
from utility import save_checkpoint, resume_checkpoint

# Hyper-parameters forwarded to every Ray Train worker as `config`.
# Architecture keys come from cfg["124M"]; training-loop keys come from cfg["ray_train"].
_GPT_ARCH_KEYS = (
    "vocab_size",
    "dimension_embedding",
    "block_size",
    "num_layers",
    "num_headers",
    "drop_rate",
    "qkv_bias",
)
_RAY_TRAIN_LOOP_KEYS = (
    "check_frequency",
    "batch_size_per_worker",
    "num_epoch_per_worker",
)

train_loop_config = {key: cfg["124M"][key] for key in _GPT_ARCH_KEYS}
train_loop_config.update({key: cfg["ray_train"][key] for key in _RAY_TRAIN_LOOP_KEYS})


def _train_one_epoch(model, optimizer, loss_function, data_shard, batch_size, shuffle_buffer_size):
    """Run one optimisation pass over ``data_shard`` and return the mean per-batch training loss.

    Raises:
        RuntimeError: if the shard yields no full batch (``drop_last=True``). Without this
            check the failure would be a bare ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    batch_count = 0
    for batch in data_shard.iter_torch_batches(
        batch_size=batch_size,
        drop_last=True,
        local_shuffle_buffer_size=shuffle_buffer_size,
    ):
        logits = model(batch["input_ids"])
        target_ids = batch["target_ids"]
        # Assumes logits are (batch, seq, vocab): merge batch and seq so CrossEntropyLoss
        # receives (N, vocab) logits against (N,) targets.
        loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # only for reporting
        batch_count += 1

    if batch_count == 0:
        raise RuntimeError(
            "Training shard yielded no full batches; reduce batch_size_per_worker "
            "or provide more training data."
        )
    return total_loss / batch_count


def _evaluate(model, loss_function, metric, data_shard, batch_size):
    """Return ``(mean validation loss, perplexity)`` over ``data_shard`` without updating weights.

    ``metric`` is reset after computing, so it is ready for the next evaluation.

    Raises:
        RuntimeError: if the shard yields no batches.
    """
    model.eval()
    total_loss = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch in data_shard.iter_torch_batches(batch_size=batch_size, drop_last=False):
            logits = model(batch["input_ids"])
            target_ids = batch["target_ids"]
            loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
            total_loss += loss.item()  # only for reporting
            metric.update(logits, target_ids)
            batch_count += 1

    if batch_count == 0:
        raise RuntimeError("Validation shard yielded no batches; cannot compute validation metrics.")
    perplexity = metric.compute().item()
    metric.reset()
    return total_loss / batch_count, perplexity


def train_loop_per_worker(config):
    """Ray Train per-worker loop: train a GPT model, validating and checkpointing periodically.

    Each epoch trains on this worker's shard of the "train" dataset. Every
    ``check_frequency`` epochs, and always on the final epoch, the loop does three things:
    it evaluates on the "validate" shard, computes perplexity, and calls
    ``ray.train.report`` with the metrics. The checkpoint in that report is saved only by
    world-rank 0.

    Args:
        config: dict of hyper-parameters (built from ``train_loop_config``).
            Required keys: vocab_size, dimension_embedding, block_size, num_layers,
            num_headers, drop_rate, qkv_bias, check_frequency, batch_size_per_worker,
            num_epoch_per_worker.
            Optional keys, whose defaults match the previous hard-coded values:
            learning_rate (0.0004), weight_decay (0.1), validate_batch_size (1),
            shuffle_buffer_size (1000).

    Raises:
        ValueError: if ``check_frequency`` is smaller than 1.
        RuntimeError: if the train or validate shard yields no batches.
    """
    # Import the submodule explicitly so `ray.train.torch` resolves on the worker regardless
    # of what happened to be imported elsewhere.
    import ray.train.torch

    vocab_size = config["vocab_size"]
    dimension_embedding = config["dimension_embedding"]
    block_size = config["block_size"]
    num_layers = config["num_layers"]
    num_headers = config["num_headers"]
    drop_rate = config["drop_rate"]
    qkv_bias = config["qkv_bias"]
    check_frequency = config["check_frequency"]
    batch_size_per_worker = config["batch_size_per_worker"]
    num_epoch_per_worker = config["num_epoch_per_worker"]
    learning_rate = config.get("learning_rate", 0.0004)
    weight_decay = config.get("weight_decay", 0.1)
    validate_batch_size = config.get("validate_batch_size", 1)
    shuffle_buffer_size = config.get("shuffle_buffer_size", 1000)

    if check_frequency < 1:
        raise ValueError(f"check_frequency must be >= 1, got {check_frequency!r}")

    # GPT model, wrapped and moved to the worker's device by Ray Train.
    model = GPT(
        vocab_size,
        dimension_embedding,
        block_size,
        num_layers,
        num_headers,
        drop_rate,
        qkv_bias,
    )
    model = ray.train.torch.prepare_model(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # ====== Resume training state from the checkpoint. ======
    epoch_start = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # NOTE(review): assumes resume_checkpoint returns the epoch to start from; if it returns
        # the epoch that was *saved*, that epoch is trained twice — confirm in utility.py.
        epoch_start = resume_checkpoint(model, optimizer, checkpoint)

    loss_function = torch.nn.CrossEntropyLoss()

    # Use Ray's device for this worker. Deriving "cuda:{world_rank}" breaks on multi-node
    # clusters, where world rank != local GPU index.
    device = ray.train.torch.get_device()
    metric = Perplexity().to(device)

    train_data_shard = ray.train.get_dataset_shard("train")
    validate_data_shard = ray.train.get_dataset_shard("validate")
    world_rank = ray.train.get_context().get_world_rank()

    last_epoch = num_epoch_per_worker - 1
    for epoch in range(epoch_start, num_epoch_per_worker):
        train_loss = _train_one_epoch(
            model,
            optimizer,
            loss_function,
            train_data_shard,
            batch_size_per_worker,
            shuffle_buffer_size,
        )

        # Validate and checkpoint every `check_frequency` epochs, and always on the final epoch
        # so the weights from trailing epochs are not lost.
        if epoch % check_frequency != 0 and epoch != last_epoch:
            continue

        validate_loss, perplexity = _evaluate(
            model, loss_function, metric, validate_data_shard, validate_batch_size
        )
        report_metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "validate_loss": validate_loss,
            "perplexity": perplexity,
        }

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            # In standard DDP training the model is identical across ranks, so only the
            # global rank 0 worker needs to save the checkpoint.
            if world_rank == 0:
                # === Make sure to save all state needed for resuming training ===
                save_checkpoint(model, optimizer, epoch, temp_checkpoint_dir)
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Every worker still reports, so all ranks call report the same number of times.
            ray.train.report(metrics=report_metrics, checkpoint=checkpoint)

In [8]:

from ray.train.torch import TorchTrainer
from ray.train import Result
from ray import train

# Cluster resources: one Ray Train worker per cfg["ray_train"]["num_workers"].
scaling_config = ray.train.ScalingConfig(
    num_workers=cfg["ray_train"]["num_workers"],
    use_gpu=cfg["ray_train"]["use_gpu"],
    resources_per_worker={
        "CPU": cfg["ray_train"]["num_cpus_per_worker"],
        "GPU": cfg["ray_train"]["num_gpus_per_worker"],
    },
)

# Keep the two checkpoints with the lowest reported perplexity.
run_config = ray.train.RunConfig(
    storage_path=cfg["ray_train"]["storage_path"],
    name=cfg["ray_train"]["name"],
    checkpoint_config=ray.train.CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="perplexity",
        checkpoint_score_order="min",
    ),
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    datasets={"train": train_chunked_tokens, "validate": validate_chunked_tokens},
    scaling_config=scaling_config,
    run_config=run_config,
)

result: Result = trainer.fit()
print(result.metrics)

2024-08-02 09:11:55,181	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2024-08-02 09:11:55 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 0/32 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_09-11-50_921992_224777/artifacts/2024-08-02_09-11-55/124M/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2024-08-02 09:12:00 (running for 00:00:05.13)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_09-11-50_921992_224777/artifacts/2024-08-02_09-11-55/124M/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




[36m(TorchTrainer pid=227183)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=227183)[0m - (node_id=2a13ba843ae388e40be93d48336a6ae800c6b9ccb4e90b961169ae36, ip=172.17.0.2, pid=227540) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=227540)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(RayTrainWorker pid=227540)[0m Moving model to device: cuda:0
[36m(SplitCoordinator pid=227613)[0m Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-08-02_09-11-50_921992_224777/logs/ray-data
[36m(SplitCoordinator pid=227613)[0m Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor)] -> OutputSplitter[split(1, equal=True)]
[36m(RayTrainWorker pid=227540)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/workspaces/CaiZi/outputs/gpt2/124M/TorchTrainer_43631_00000_0_2024-08-02_09-11-55/checkpoint_000000)
[36m

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227614) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227614) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227614) Running 0: 0 bundle [00:00, ? bundle/s]

== Status ==
Current time: 2024-08-02 09:12:05 (running for 00:00:10.20)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_09-11-50_921992_224777/artifacts/2024-08-02_09-11-55/124M/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227614) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227614) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227614) Running 0: 0 bundle [00:00, ? bundle/s]

== Status ==
Current time: 2024-08-02 09:12:10 (running for 00:00:15.22)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_09-11-50_921992_224777/artifacts/2024-08-02_09-11-55/124M/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227614) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227614) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227614) Running 0: 0 bundle [00:00, ? bundle/s]

== Status ==
Current time: 2024-08-02 09:12:15 (running for 00:00:20.24)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_09-11-50_921992_224777/artifacts/2024-08-02_09-11-55/124M/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=227613) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=227613) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=227613) Running 0: 0 bundle [00:00, ? bundle/s]

2024-08-02 09:12:19,794	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/workspaces/CaiZi/outputs/gpt2/124M' in 0.0017s.
2024-08-02 09:12:19,794	INFO tune.py:1041 -- Total run time: 24.61 seconds (24.60 seconds for the tuning loop).


== Status ==
Current time: 2024-08-02 09:12:19 (running for 00:00:24.60)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_09-11-50_921992_224777/artifacts/2024-08-02_09-11-55/124M/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


{'epoch': 10, 'train_loss': 1.016655594110489, 'validate_loss': 6.5713417530059814, 'perplexity': 714.3267822265625, 'timestamp': 1722589934, 'checkpoint_dir_name': 'checkpoint_000002', 'should_checkpoint': True, 'done': True, 'training_iteration': 3, 'trial_id': '43631_00000', 'date': '2024-08-02_09-12-17', 'time_this_iter_s': 6.307123899459839, 'time_total_s': 19.69325876235962, 'pid': 227183, 'hostname': '569d72e6c54c', 'node_ip': '172.17.0.2', 'config': {'train_loop_config': {'vocab_size': 50257, 'dimension_embedding': 768, 'block_size': 256, 'num_layers': 12, 'num_headers': 12, 'drop_rate': 0.1, 'qkv_bias': False, 'check_frequency': 5, 'batch_size