In [1]:
import os
import sys

project_root = os.path.dirname(os.getcwd())
sys.path.append(f"{project_root}/src")
sys.path.append(f"{project_root}/third_party")

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_COLOR_PREFIX"] = "0"

from config import gpt2_cfg as cfg


In [2]:
import ray
if ray.is_initialized():
    ray.shutdown()

ray.init(
        runtime_env={
            "env_vars": {          
                "PYTHONPATH": "$PYTHONPATH:" + cfg.project_root + "/src",
                "RAY_DATA_VERBOSE_PROGRESS": "1",
            },
            "working_dir": cfg.project_root,
            "excludes": [
                "/bazel-*",
                ".git",
                "*.pyc",
                "/__pycache__",
                "/outputs",
                "/model",
            ],
        },
        ignore_reinit_error=True,
    )

# convience for debugging
ray.data.DataContext.get_current().execution_options.verbose_progress = True
ray.data.DataContext.log_internal_stack_trace_to_stdout = True

2024-08-02 13:28:35,136	INFO worker.py:1772 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
2024-08-02 13:28:35,142	INFO packaging.py:530 -- Creating a file package for local directory '/workspaces/CaiZi'.
2024-08-02 13:28:35,147	INFO packaging.py:358 -- Pushing file package 'gcs://_ray_pkg_891678a0e78684b4.zip' (0.32MiB) to Ray cluster...
2024-08-02 13:28:35,149	INFO packaging.py:371 -- Successfully pushed file package 'gcs://_ray_pkg_891678a0e78684b4.zip'.


In [3]:
from pathlib import Path
data_sources = [ Path(item["path"]) for item in cfg["dataset"]]
text_document_paths = ray.data.from_items(data_sources)

In [4]:
from document_processor import TextDocumentProcessor
train_text_document_processor = TextDocumentProcessor(section="train")
train_texts=text_document_paths.map(train_text_document_processor)

validate_text_document_processor = TextDocumentProcessor(section="validate")
validate_texts=text_document_paths.map(validate_text_document_processor)


In [5]:
from token_processor import TikTokenizer
tokenizer = TikTokenizer()
train_tokens = train_texts.map(tokenizer)
validate_tokens = validate_texts.map(tokenizer)

In [6]:
from chunk_processor import ChunkProcessor

chunk_processor = ChunkProcessor(max_length=cfg["124M"]["block_size"],stride=cfg["124M"]["stride"])
train_chunked_tokens = train_tokens.flat_map(chunk_processor)
validate_chunked_tokens = validate_tokens.flat_map(chunk_processor)

In [7]:
import torch
from torchmetrics.text import Perplexity

import ray
import ray.train
from ray.train import Checkpoint

from model.GPT import GPT
from utility import save_checkpoint, resume_checkpoint


def train_loop_per_worker(config):
    vocab_size = config["vocab_size"]
    dimension_embedding = config["dimension_embedding"]
    block_size = config["block_size"]
    num_layers = config["num_layers"]
    num_headers = config["num_headers"]
    drop_rate = config["drop_rate"]
    qkv_bias = config["qkv_bias"]
    check_frequency = config["check_frequency"]
    batch_size_per_worker = config["batch_size_per_worker"]
    num_epoch_per_worker = config["num_epoch_per_worker"]
    resume_training = config["resume_training"]
    best_checkpoint_dir = config["best_checkpoint_dir"]

    # GPT model
    model = GPT(
        vocab_size,
        dimension_embedding,
        block_size,
        num_layers,
        num_headers,
        drop_rate,
        qkv_bias,
    )
    model = ray.train.torch.prepare_model(model)

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

    # ====== Resume training state from the checkpoint. ======
    epoch_start = 0
    best_perplexity = float("inf")
    best_epoch = 0

    if resume_training:
        if os.path.exists(best_checkpoint_dir):
            checkpoint = ray.train.Checkpoint.from_directory(best_checkpoint_dir)
        else:
            checkpoint = None
        if checkpoint:
            best_epoch, best_perplexity = resume_checkpoint(model, optimizer, checkpoint)
            print(
                f"Resumed training from best_epoch {best_epoch},best_perplexity {best_perplexity}"
            )
        else:
            print(f"Checkpoint not found, starting from epoch 0")

    # loss function
    loss_function = torch.nn.CrossEntropyLoss()

    rank = ray.train.get_context().get_world_rank()
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # metrics
    metric = Perplexity().to(device)

    # data
    train_data_shard = ray.train.get_dataset_shard("train")
    validate_data_shard = ray.train.get_dataset_shard("validate")

    report_metrics = {
        "epoch": 0,
        "train_loss": 0.0,
        "validate_loss": 0.0,
        "perplexity": 0.0,
    }

    for epoch in range(epoch_start + 1, num_epoch_per_worker + 1):
        model.train()

        report_metrics["epoch"] = epoch

        train_loss = 0
        batch_count = 0
        for batch in train_data_shard.iter_torch_batches(
            batch_size=batch_size_per_worker,
            drop_last=True,
            local_shuffle_buffer_size=1000,
        ):
            batch_count += 1
            input_ids = batch["input_ids"]
            logits = model(input_ids)
            target_ids = batch["target_ids"]
            loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
            train_loss += loss.item()  # only for reporting
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = train_loss / batch_count

        report_metrics["train_loss"] = train_loss

        validate_loss = 0
        perplexity = 0

        if epoch % check_frequency == 0:

            model.eval()

            with torch.no_grad():
                batch_count = 0
                for batch in validate_data_shard.iter_torch_batches(
                    batch_size=1,
                    drop_last=False,
                ):
                    batch_count += 1
                    input_ids = batch["input_ids"]
                    logits = model(input_ids)
                    target_ids = batch["target_ids"]
                    loss = loss_function(logits.flatten(0, 1), target_ids.flatten())
                    validate_loss += loss.item()  # only for reporting
                    metric.update(logits, target_ids)

            validate_loss = validate_loss / batch_count
            perplexity = metric.compute().item()
            metric.reset()

            report_metrics["validate_loss"] = validate_loss
            report_metrics["perplexity"] = perplexity
            report_metrics["best_epoch"] = best_epoch
            report_metrics["best_perplexity"] = best_perplexity

            # In standard DDP training, where the model is the same across all ranks,
            # only the global rank 0 worker needs to save and report the checkpoint
            if ray.train.get_context().get_world_rank() == 0:
                if perplexity < best_perplexity:
                    best_perplexity = perplexity
                    best_epoch = epoch

                    # create the best_checkpoint_dir if it does not exist
                    if not os.path.exists(best_checkpoint_dir):
                        os.makedirs(best_checkpoint_dir)

                    save_checkpoint(
                        model,
                        optimizer,
                        epoch,
                        perplexity,
                        best_checkpoint_dir,
                    )

                ray.train.report(metrics=report_metrics)

In [8]:
from ray.train.torch import TorchTrainer
from ray.train import Result
from ray import train

train_loop_config = {
    "vocab_size": cfg["124M"]["vocab_size"],
    "dimension_embedding": cfg["124M"]["dimension_embedding"],
    "block_size": cfg["124M"]["block_size"],
    "num_layers": cfg["124M"]["num_layers"],
    "num_headers": cfg["124M"]["num_headers"],
    "drop_rate": cfg["124M"]["drop_rate"],
    "qkv_bias": cfg["124M"]["qkv_bias"],
    "check_frequency": cfg["ray_train"]["check_frequency"],
    "batch_size_per_worker": cfg["ray_train"]["batch_size_per_worker"],
    "num_epoch_per_worker": cfg["ray_train"]["num_epoch_per_worker"],
    "resume_training":cfg["ray_train"]["resume_training"],
    "best_checkpoint_dir":cfg["ray_train"]["best_checkpoint_dir"],
}


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    datasets={
        "train": train_chunked_tokens,
        "validate": validate_chunked_tokens,
    },
    scaling_config=ray.train.ScalingConfig(
        num_workers=cfg["ray_train"]["num_workers"],
        use_gpu=cfg["ray_train"]["use_gpu"],
        resources_per_worker={
            "CPU": cfg["ray_train"]["num_cpus_per_worker"],
            "GPU": cfg["ray_train"]["num_gpus_per_worker"],
        },
    ),
    run_config=train.RunConfig(
        storage_path=cfg["ray_train"]["storage_path"],
        name=cfg["ray_train"]["name"],
    ),
    
)
result: Result = trainer.fit()
print(result.metrics)

2024-08-02 13:28:38,600	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2024-08-02 13:28:38 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 0/32 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_13-28-34_407310_432788/artifacts/2024-08-02_13-28-38/124M/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2024-08-02 13:28:43 (running for 00:00:05.20)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_13-28-34_407310_432788/artifacts/2024-08-02_13-28-38/124M/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(TorchTrainer pid=435232) Started distributed worker processes: 
(TorchTrainer pid=435232) - (node_id=05b22615e8a4cafa239a671ee5d7134909a17e17dbf9c30248c66a7a, ip=172.17.0.2, pid=435588) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=435588) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=435588) Moving model to device: cuda:0
(RayTrainWorker pid=435588)   model_state_dict = torch.load(
(RayTrainWorker pid=435588)   torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))


(RayTrainWorker pid=435588) Resumed training from best_epoch 9,best_perplexity 503.8351135253906


(RayTrainWorker pid=435588)   extra_state = torch.load(os.path.join(checkpoint_dir, "extra_state.pt"))
(SplitCoordinator pid=435664) Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-08-02_13-28-34_407310_432788/logs/ray-data
(SplitCoordinator pid=435664) Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor)] -> OutputSplitter[split(1, equal=True)]
(SplitCoordinator pid=435664) Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-08-02_13-28-34_407310_432788/logs/ray-data
(SplitCoordinator pid=435664) Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor)] -> OutputSplitter[split(1, equal=True)]
(SplitCoordinator pid=435664) Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-08-02_13-28-34_407310_432788/logs/ray-data
(SplitCoordinator pid=435664) Executio

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435665) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435665) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435665) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435665) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435665) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435665) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435665) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435665) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435665) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435665) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435665) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435665) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435664) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435664) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435664) Running 0: 0 bundle [00:00, ? bundle/s]

(pid=435665) - Map(TextDocumentProcessor)->Map(TikTokenizer)->FlatMap(ChunkProcessor) 1: 0 bundle [00:00, ? bu…

(pid=435665) - split(1, equal=True) 2: 0 bundle [00:00, ? bundle/s]

(pid=435665) Running 0: 0 bundle [00:00, ? bundle/s]

== Status ==
Current time: 2024-08-02 13:28:48 (running for 00:00:10.23)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_13-28-34_407310_432788/artifacts/2024-08-02_13-28-38/124M/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2024-08-02 13:28:49,910	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/workspaces/CaiZi/outputs/gpt2/124M' in 0.0020s.
2024-08-02 13:28:49,911	INFO tune.py:1041 -- Total run time: 11.31 seconds (11.30 seconds for the tuning loop).


== Status ==
Current time: 2024-08-02 13:28:49 (running for 00:00:11.30)
Using FIFO scheduling algorithm.
Logical resource usage: 2.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2024-08-02_13-28-34_407310_432788/artifacts/2024-08-02_13-28-38/124M/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


{'epoch': 15, 'train_loss': 0.17282094061374664, 'validate_loss': 7.031112194061279, 'perplexity': 1131.2891845703125, 'best_epoch': 9, 'best_perplexity': 503.8351135253906, 'timestamp': 1722605328, 'checkpoint_dir_name': None, 'done': True, 'training_iteration': 5, 'trial_id': '208a2_00000', 'date': '2024-08-02_13-28-48', 'time_this_iter_s': 0.6792900562286377, 'time_total_s': 7.634586334228516, 'pid': 435232, 'hostname': '569d72e6c54c', 'node_ip': '172.17.0.2', 'config': {'train_loop_config': {'vocab_size': 50257, 'dimension_embedding': 768, 'block_size': 256, 'num_layers': 12, 'num_headers': 12, 'drop_rate': 0.1, 'qkv_bias': False, 'check_frequency'