In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to imported project code (e.g. the local
# `lib` package) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
from contextlib import nullcontext
import datasets
from dotenv import load_dotenv
import json
from lib.llama3.reference_impl.model import ModelArgs, Transformer
from lib.utils import black_print
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
import neptune
import os
import subprocess
import torch
import torch.distributed
import torch.nn.functional as F
from torch.optim.adamw import AdamW
from tqdm import tqdm
from typing import Iterable
import wandb

# Populate os.environ from a .env file (python-dotenv) before reading any
# environment-driven settings below.
load_dotenv()

# Distributed setup: a process group is created only when CUDA is available and
# the launcher exported RANK; otherwise the script runs as a single process.
if torch.cuda.is_available() and int(os.environ.get("RANK", -1)) != -1:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # Bind this process to its own GPU before NCCL is initialized so that
    # device-less CUDA work and collectives do not default to GPU 0 on every
    # rank.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
else:
    rank = 0
    local_rank = 0
    world_size = 1

if world_size > 1:
    print(f"Rank: {rank} - Local Rank: {local_rank} - World Size: {world_size}")

# --- Run identity, checkpoint locations and experiment trackers --------------
parser = argparse.ArgumentParser(description="Llama3 training script")
parser.add_argument("--name", type=str, default=None, help="Run name")
# parse_known_args ignores unrecognized argv entries (e.g. those supplied by a
# notebook kernel or a distributed launcher) instead of exiting.
args, _ = parser.parse_known_args()
run_name = args.name
run_path = f"./runs/{run_name}"
gs_path = f"gs://atreides/experiments/runs/{run_name}"
if run_name and not os.path.exists(run_path):
    os.makedirs(run_path, exist_ok=True)
    # Pull any previously uploaded state for this run. The command is passed
    # as an argv list (no shell) so a run name containing spaces or shell
    # metacharacters cannot alter the command. Failures are tolerated on
    # purpose: a brand-new run has nothing to download.
    try:
        subprocess.run(["gsutil", "-m", "rsync", "-r", gs_path, run_path])
    except FileNotFoundError:
        print("gsutil not found; skipping download of remote run state")
# Files that make up a resumable checkpoint inside the run directory.
params_path = f"{run_path}/params.json"
model_path = f"{run_path}/model.pth"
optimizer_path = f"{run_path}/optimizer.pth"
scheduler_path = f"{run_path}/scheduler.pth"
dataset_state_path = f"{run_path}/dataset-state.json"
training_state_path = f"{run_path}/training-state.json"

# Only rank 0 talks to the experiment trackers; `neptune_run` and `wandb_run`
# are therefore undefined on every other rank.
if rank == 0:
    neptune_run = neptune.init_run(custom_run_id=run_name)
    # NOTE(review): reusing run_name as the W&B id with resume="allow" is
    # presumably meant to let a restarted run append to the same record — confirm.
    wandb_run = wandb.init(name=run_name, resume="allow", id=run_name)

# --- Architecture hyper-parameters ------------------------------------------
llama3_2_1B = resolve_model("Llama3.2-1B")
assert llama3_2_1B is not None
# Copy the dict before scaling it: the `//=` updates below would otherwise
# modify the resolved model's own `arch_args` in place, so re-running this cell
# could shrink the architecture again.
params = dict(llama3_2_1B.arch_args)
# Shrink width, heads and depth to a quarter of the Llama 3.2 1B configuration.
params["dim"] //= 4
params["n_heads"] //= 4
params["n_kv_heads"] //= 4
params["n_layers"] //= 4
params["max_seq_len"] = 512
params["max_batch_size"] = 4
params["use_flash_attention"] = True

# A params.json stored with the run takes precedence over the values above so
# that a resumed run rebuilds the same architecture; a new named run records
# its parameters for later resumes.
if os.path.exists(params_path):
    with open(params_path, "r") as params_file:
        params = json.load(params_file)
elif run_name:
    os.makedirs(run_path, exist_ok=True)
    with open(params_path, "w") as params_file:
        json.dump(params, params_file)

model_args = ModelArgs(**params)
if rank == 0:
    black_print(model_args)
# Number of input tokens in one micro-batch on one rank.
micro_step_tokens = model_args.max_seq_len * model_args.max_batch_size

# --- Model construction and device placement --------------------------------
model = Transformer(model_args)
if os.path.exists(model_path):
    # Resume weights from the run directory. map_location="cpu" stops every
    # rank from deserializing the checkpoint onto the device it was saved from
    # (typically cuda:0); the model is moved to its target device just below.
    model.load_state_dict(
        torch.load(model_path, map_location="cpu", weights_only=True)
    )
# No autocast is active on any backend at the moment, so the forward pass runs
# under a no-op context.
context = nullcontext()
if torch.cuda.is_available():
    model = model.to(f"cuda:{local_rank}")
    # model.compile()
    torch.set_float32_matmul_precision("high")
    # Disabled bf16 autocast, kept for reference:
    # context = (
    #     torch.autocast(
    #         device_type="cuda",
    #         dtype=torch.bfloat16,
    #     )
    #     if torch.cuda.is_bf16_supported()
    #     else nullcontext()
    # )
elif torch.backends.mps.is_available():
    model = model.to("mps")
else:
    model = model.to("cpu")
model_device = next(model.parameters()).device
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if rank == 0:
    print(f"Model Device: {model_device}")
    print(f"Trainable Parameters: {total_params:,}")
# Token budget: 20 training tokens per trainable parameter.
pretrain_tokens = int(total_params * 20)  # Chinchilla-optimal

# --- Tokenizer and streaming datasets ---------------------------------------
tokenizer = Tokenizer.get_instance()
# Keep dataset progress bars on rank 0 only.
if rank != 0:
    datasets.disable_progress_bars()
# Training stream: the "sample-10BT" configuration of FineWeb.
dataset: datasets.IterableDataset = datasets.load_dataset(
    "HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=True
)  # type: ignore
# Validation source; only its first `val_size` documents are consumed later.
val_dataset: datasets.IterableDataset = datasets.load_dataset(
    "reflex-ai/fineweb-ultra-mini", split="train", streaming=True
)  # type: ignore


def batches(
    dataset: datasets.IterableDataset,
) -> Iterable[tuple[torch.Tensor, torch.Tensor]]:
    """Yield next-token-prediction batches packed from a stream of documents.

    Documents are sharded round-robin across ranks (`index % world_size`),
    tokenized with BOS/EOS markers and concatenated into one token buffer.
    Each time the buffer holds `(max_seq_len + 1) * max_batch_size` tokens, a
    batch is cut from its front and reshaped to `max_batch_size` rows; the
    inputs are every row without its last token and the targets are every row
    without its first token.

    Args:
        dataset: Iterable of documents exposing a "text" field.

    Yields:
        `(inputs, targets)` long tensors on `model_device`, each with
        `max_batch_size` rows of `max_seq_len` tokens. The final batch is
        right-padded with `tokenizer.pad_id` when the stream ends mid-batch.
    """
    max_tokens = (model_args.max_seq_len + 1) * model_args.max_batch_size

    def to_batch(chunk: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
        # Reshape exactly `max_tokens` ids into rows and shift by one position.
        batch = torch.tensor(chunk, dtype=torch.long, device=model_device).reshape(
            model_args.max_batch_size, -1
        )
        return batch[:, :-1], batch[:, 1:]

    tokens: list[int] = []
    for index, document in enumerate(dataset):
        if index % world_size != rank:
            continue
        tokens += tokenizer.encode(document["text"], bos=True, eos=True)
        # Drain every full batch the buffer holds. Emitting at most one batch
        # per document would let the buffer grow without bound whenever
        # documents are longer than a batch, and the padding step below would
        # then silently drop everything beyond the first `max_tokens` tokens.
        while len(tokens) >= max_tokens:
            yield to_batch(tokens[:max_tokens])
            tokens = tokens[max_tokens:]
    if tokens:
        # Fewer than `max_tokens` tokens remain here: pad up to one full batch.
        tokens += [tokenizer.pad_id] * (max_tokens - len(tokens))
        yield to_batch(tokens)


# --- Resume bookkeeping -----------------------------------------------------
# Loaded first because the accumulation threshold below depends on how many
# tokens a previous session already processed.
if os.path.exists(training_state_path):
    with open(training_state_path, "r") as training_state_file:
        training_state = json.load(training_state_file)
else:
    training_state = {}

# --- Optimizer and learning-rate schedule -----------------------------------
# Peak learning rate shrinks with the cube root of the parameter count
# (expressed in billions of parameters).
peak_lr = 6e-4 / ((total_params * 1e-9) ** (1 / 3))
optimizer = AdamW(model.parameters(), lr=peak_lr)
if os.path.exists(optimizer_path):
    # map_location keeps the saved state off the device it was written from.
    optimizer.load_state_dict(
        torch.load(optimizer_path, map_location=model_device, weights_only=True)
    )
step_tokens = 2**19  # tokens consumed (across all ranks) per optimizer step
cosine_annealing_steps = pretrain_tokens // step_tokens
# At least one warmup step: zero would divide by zero in `start_factor` below
# when the token budget is very small.
warmup_steps = max(1, cosine_annealing_steps // 150)
# Linear warmup to peak_lr, then cosine decay towards 1% of peak_lr.
# NOTE(review): T_max counts steps after the warmup milestone, so warmup plus
# T_max exceeds `cosine_annealing_steps`; confirm whether T_max should be
# `cosine_annealing_steps - warmup_steps`.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [
        torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1 / warmup_steps,
            end_factor=1,
            total_iters=warmup_steps,
        ),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_annealing_steps, eta_min=peak_lr * 0.01),  # type: ignore
    ],
    milestones=[warmup_steps],
)
if os.path.exists(scheduler_path):
    scheduler.load_state_dict(
        torch.load(scheduler_path, map_location=model_device, weights_only=True)
    )
# Micro-batches each rank accumulates per optimizer step (used to scale the loss).
grad_accum_steps = step_tokens // micro_step_tokens // world_size
# Token count at which the next optimizer step fires. It starts from the
# resumed progress: starting from `step_tokens` alone would make a resumed run
# take an optimizer step on every micro-batch until the threshold caught up.
grad_accum_threshold = training_state.get("progress", 0) + step_tokens

# --- Data-parallel wrapper --------------------------------------------------
if world_size > 1:
    distributed_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank]
    )
else:
    distributed_model = model

# --- Running metrics --------------------------------------------------------
# 12.0 is the placeholder shown until a real loss has been measured.
train_loss = torch.tensor(training_state.get("train_loss", 12.0)).to(model_device)
next_train_loss = torch.tensor(0.0).to(model_device)

# The first `val_size` validation documents; their ids are collected so the
# same documents can be excluded from the training stream.
val_size = 1000
val_ids = {document["id"] for document in val_dataset.take(val_size)}
val_token_threshold = training_state.get("val_token_threshold", 0)
val_loss = torch.tensor(training_state.get("val_loss", 12.0)).to(model_device)


def update_val_loss() -> None:
    """Recompute the global `val_loss` tensor on the validation documents.

    Pushes `val_token_threshold` forward by 10M tokens, switches the model to
    eval mode, averages the cross-entropy over this rank's validation batches
    into `val_loss` (in place), averages that value across ranks when running
    distributed, and finally restores train mode.
    """
    global val_token_threshold, val_loss
    val_token_threshold = val_token_threshold + 10_000_000
    distributed_model.eval()
    with torch.no_grad():
        val_loss.zero_()
        batch_count = 0
        for inputs, targets in batches(val_dataset.take(val_size)):
            flat_logits = distributed_model(inputs, 0).view(-1, model.vocab_size)
            val_loss.add_(F.cross_entropy(flat_logits, targets.flatten()))
            batch_count += 1
        # Mean over the batches seen by this rank.
        val_loss.div_(batch_count)
        if world_size > 1:
            torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.AVG)
    distributed_model.train()


# Keep validation documents out of the training stream.
# NOTE: this rebinds `dataset`, so re-running the cell stacks another filter.
dataset = dataset.filter(lambda x: x["id"] not in val_ids)
if os.path.exists(dataset_state_path):
    # Restore the streaming position saved by a previous session.
    with open(dataset_state_path, "r") as dataset_state_file:
        dataset.load_state_dict(json.load(dataset_state_file))
# Checkpoint cadence in tokens; the first save happens after one full interval
# unless a resumed run provides its own threshold.
save_frequecy = 10_000_000
save_token_threshold = training_state.get("save_token_threshold", save_frequecy)


def save_state(pbar: tqdm) -> None:
    """Checkpoint the run locally and start an upload to cloud storage.

    Does nothing for unnamed runs. Otherwise advances `save_token_threshold`
    by `save_frequecy` on every rank, and on the first process of each node
    writes the model, optimizer and scheduler state, the dataset iterator
    state and a small JSON summary of training progress into `run_path`, then
    launches a background `gsutil rsync` of `run_path` to `gs_path`.

    Args:
        pbar: Progress bar whose `n` (tokens processed so far) is stored as
            "progress".
    """
    global save_token_threshold
    if not run_name:
        return
    # Every rank advances the schedule so all ranks agree on the next save.
    save_token_threshold += save_frequecy
    # One writer per node: all ranks target the same paths, so concurrent
    # writes from several processes could corrupt the checkpoint files.
    if local_rank != 0:
        return
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)
    torch.save(scheduler.state_dict(), scheduler_path)
    # Context managers guarantee the JSON files are flushed and closed before
    # the upload below starts reading the directory.
    with open(dataset_state_path, "w") as dataset_state_file:
        json.dump(dataset.state_dict(), dataset_state_file)
    with open(training_state_path, "w") as training_state_file:
        json.dump(
            {
                "val_loss": val_loss.item(),
                "val_token_threshold": val_token_threshold,
                "train_loss": train_loss.item(),
                "save_token_threshold": save_token_threshold,
                "progress": pbar.n,
            },
            training_state_file,
        )
    # Fire-and-forget upload so training is not blocked. Output goes to DEVNULL
    # rather than pipes nobody reads (a full pipe would stall the child), and
    # an argv list avoids passing interpolated paths through a shell.
    try:
        subprocess.Popen(
            ["gsutil", "-m", "rsync", "-r", run_path, gs_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except FileNotFoundError:
        print("gsutil not found; checkpoint kept locally only")


# --- Training loop ----------------------------------------------------------
# Progress is measured in tokens across all ranks. Each micro-batch contributes
# gradients; once the running token count reaches `grad_accum_threshold` the
# optimizer and scheduler step, metrics are logged, and validation and
# checkpointing run whenever their own token thresholds have been reached.
try:
    with tqdm(
        desc="Training model",
        total=pretrain_tokens,
        disable=rank != 0,
        unit="token",
        initial=training_state.get("progress", 0),
        bar_format="{l_bar}{bar}| {n:,}/{total:,} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
    ) as pbar:
        for x, y in batches(dataset):
            micro_batch_tokens = x.numel() * world_size
            # Decide before the forward pass whether this micro-batch completes
            # an optimizer step: DDP only skips gradient synchronization when
            # the forward pass itself runs inside no_sync().
            is_optimizer_step = pbar.n + micro_batch_tokens >= grad_accum_threshold
            sync_context = (
                distributed_model.no_sync()
                if world_size > 1 and not is_optimizer_step
                else nullcontext()
            )
            with sync_context:
                with context:
                    logits = distributed_model(x, 0)
                    loss = F.cross_entropy(
                        logits.view(-1, model.vocab_size), y.flatten()
                    )
                # Scale so the accumulated gradient is the mean over the step.
                loss = loss / grad_accum_steps
                loss.backward()
            # Detach so the running sum does not hold on to autograd graphs.
            next_train_loss += loss.detach()

            # A disabled tqdm ignores update(), so non-zero ranks advance the
            # counter by hand to keep every rank's thresholds in lockstep.
            if rank == 0:
                pbar.update(micro_batch_tokens)
            else:
                pbar.n += micro_batch_tokens

            if is_optimizer_step:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                grad_accum_threshold += step_tokens
                train_loss = next_train_loss
                if world_size > 1:
                    torch.distributed.all_reduce(
                        train_loss, op=torch.distributed.ReduceOp.AVG
                    )
                next_train_loss = torch.tensor(0.0).to(model_device)
                if rank == 0:
                    neptune_run["train/loss"].append(train_loss.item(), step=pbar.n)
                    neptune_run["lr"].append(scheduler.get_last_lr()[0], step=pbar.n)
                    wandb_run.log(
                        {
                            "loss/train": train_loss.item(),
                            "lr": scheduler.get_last_lr()[0],
                        },
                        step=pbar.n,
                    )
                if pbar.n >= val_token_threshold:
                    update_val_loss()
                    if rank == 0:
                        neptune_run["val/loss"].append(val_loss.item(), step=pbar.n)
                        wandb_run.log({"loss/val": val_loss.item()}, step=pbar.n)
                if pbar.n >= save_token_threshold:
                    save_state(pbar)
                if pbar.n >= pretrain_tokens:
                    break

            pbar.set_postfix(
                {
                    "train_loss": train_loss.item(),
                    "val_loss": val_loss.item(),
                    "lr": scheduler.get_last_lr()[0],
                }
            )
finally:
    # Close the trackers even when training is interrupted or fails.
    if rank == 0:
        neptune_run.stop()
        wandb_run.finish()



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/ender-research/atreides/e/AT-24


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


ModelArgs(
    dim=512,
    n_layers=4,
    n_heads=8,
    n_kv_heads=2,
    vocab_size=128256,
    multiple_of=256,
    ffn_dim_multiplier=1.5,
    norm_eps=1e-05,
    rope_theta=500000.0,
    use_scaled_rope=True,
    max_batch_size=4,
    max_seq_len=512,
    vision_chunk_size=-1,
    vision_max_num_chunks=4,
    vision_num_cross_attention_layers=-1,
)
Model Device: cuda:0
Trainable Parameters: 146,543,104


Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/44 [00:00<?, ?it/s]

Training model:   1%|▏         | 37,279,744/2,930,862,080 [32:43<42:19:58, 18986.90token/s, train_loss=6.76, val_loss=7.1, lr=0.00114] 


KeyboardInterrupt: 

In [None]:
# Display the streaming dataset's resumable position; this is the same
# structure that save_state() writes to dataset-state.json.
dataset.state_dict()

{'ex_iterable': {'shard_idx': 0, 'shard_example_idx': 3335},
 'previous_state': None,
 'num_examples_since_previous_state': 0,
 'previous_state_example_idx': 3335}

In [None]:
# Display the dataset's iteration state again, for comparison with an
# earlier reading.
dataset.state_dict()

{'ex_iterable': {'shard_idx': 0, 'shard_example_idx': 177},
 'previous_state': None,
 'num_examples_since_previous_state': 0,
 'previous_state_example_idx': 177}

In [None]:
from lib.llama3.reference_impl.model import ModelArgs
from llama_models.sku_list import resolve_model


llama3_2_1B = resolve_model("Llama3.2-1B")
assert llama3_2_1B is not None
# Copy before scaling: the `//=` updates would otherwise modify the resolved
# model's own `arch_args` in place, so re-running this cell could shrink the
# values again.
params = dict(llama3_2_1B.arch_args)
params["dim"] //= 4
params["n_heads"] //= 4
params["n_kv_heads"] //= 4
params["n_layers"] //= 4

class ExtendedModelArgs(ModelArgs):
    """ModelArgs with one additional `flash_attention` switch (off by default)."""

    flash_attention: bool = False

# Experiment: build the extended args from the scaled-down parameters.
ExtendedModelArgs(
    max_seq_len=512,
    max_batch_size=8,
    flash_attention=False,
    **params,
)

False

In [None]:
from lib.llama3.reference_impl.generation import Llama
import os

# Original-format Llama 3.2 1B checkpoint directory under the user's home; the
# trailing slash lets the tokenizer file name be appended directly.
llama3_2_1B_ckpt_dir = os.path.expanduser("~/.llama/checkpoints/Llama3.2-1B/original/")
tokenizer_path = f"{llama3_2_1B_ckpt_dir}tokenizer.model"

# Build the reference model on CPU for single-sequence generation.
llama = Llama.build(
    device="cpu",
    max_batch_size=1,
    max_seq_len=512,
    tokenizer_path=tokenizer_path,
    ckpt_dir=llama3_2_1B_ckpt_dir,
)

# Show which device type the loaded weights live on.
next(llama.model.parameters()).device.type

Loaded in 7.03 seconds


'cpu'

In [None]:
# Smoke-test the checkpoint built above: complete a sample prompt with
# max_gen_len=10 and display the generated text.
llama.text_completion("What is the meaning of life?", max_gen_len=10).generation

' I have no idea, but I think I can'

In [None]:
# Read the corpus with an explicit encoding and close the file handle as soon
# as the text is in memory.
with open("./data/tinyshakespeare.txt", "r", encoding="utf-8") as shakespeare_file:
    shakespeare_text = shakespeare_file.read()

# Display the first 500 characters
print("First few lines of Shakespeare's text:")
print(shakespeare_text[:500])

# Get some statistics (lines are counted as newline characters)
total_chars = len(shakespeare_text)
total_lines = shakespeare_text.count("\n")

print(f"\nTotal characters: {total_chars}")
print(f"Total lines: {total_lines}")

First few lines of Shakespeare's text:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor

Total characters: 1115394
Total lines: 40000


In [None]:
# Summarize the architecture of the in-memory model.
first_layer = model.layers[0]
first_attention = first_layer.attention
print("Model Parameters:")
print(f"Vocabulary Size: {model.vocab_size}")
print(f"Number of Layers: {model.n_layers}")
print(f"Embedding Dimension: {model.params.dim}")
print(f"Number of Attention Heads: {model.params.n_heads}")
print(f"Max Sequence Length: {model.params.max_seq_len}")
print(f"Feedforward Dimension: {first_layer.feed_forward.w1.out_features}")

# A trainable tensor whose elements sum to exactly zero is reported as not
# initialized.
print("\nParameter Initialization:")
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    status = "Initialized" if param.sum().item() != 0 else "Not initialized"
    print(f"{name}: {status}")

# Shapes of the embedding, output and first-layer attention projections.
print("\nKey Component Shapes:")
print(f"Token Embeddings: {model.tok_embeddings.weight.shape}")
print(f"Output Layer: {model.output.weight.shape}")
print(f"First Layer Query Weight: {first_attention.wq.weight.shape}")
print(f"First Layer Key Weight: {first_attention.wk.weight.shape}")
print(f"First Layer Value Weight: {first_attention.wv.weight.shape}")

# Report any parameter holding a non-finite value (NaN or +/-Inf).
print("\nNaN/Inf Check:")
for name, param in model.named_parameters():
    if torch.isfinite(param).all():
        print(f"{name}: OK")
    else:
        print(f"Warning: {name} contains NaN or Inf values")


Model Parameters:
Vocabulary Size: 128256
Number of Layers: 4
Embedding Dimension: 512
Number of Attention Heads: 8
Max Sequence Length: 512
Feedforward Dimension: 2048

Parameter Initialization:
tok_embeddings.weight: Initialized
layers.0.attention.wq.weight: Initialized
layers.0.attention.wk.weight: Initialized
layers.0.attention.wv.weight: Initialized
layers.0.attention.wo.weight: Initialized
layers.0.feed_forward.w1.weight: Initialized
layers.0.feed_forward.w2.weight: Initialized
layers.0.feed_forward.w3.weight: Initialized
layers.0.attention_norm.weight: Initialized
layers.0.ffn_norm.weight: Initialized
layers.1.attention.wq.weight: Initialized
layers.1.attention.wk.weight: Initialized
layers.1.attention.wv.weight: Initialized
layers.1.attention.wo.weight: Initialized
layers.1.feed_forward.w1.weight: Initialized
layers.1.feed_forward.w2.weight: Initialized
layers.1.feed_forward.w3.weight: Initialized
layers.1.attention_norm.weight: Initialized
layers.1.ffn_norm.weight: Initialized

In [None]:
# Count the elements of every parameter that receives gradients.
total_params = sum(
    map(torch.Tensor.numel, filter(lambda p: p.requires_grad, model.parameters()))
)
print(f"\nTotal Trainable Parameters: {total_params:,}")

# Memory footprint in GiB at 4 bytes (float32) per parameter.
size_in_gb = (total_params * 4) / 2**30
print(f"Approximate Model Size: {size_in_gb:.2f} GB")


Total Trainable Parameters: 146,543,104
Approximate Model Size: 0.55 GB


In [None]:
# Checkpoint the model
import os

# Create a directory for checkpoints if it doesn't exist
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Save the model state. A dedicated variable name is used here: rebinding the
# notebook-level `model_path` would silently redirect the checkpoints that
# save_state() writes for the training run.
checkpoint_model_path = os.path.join(checkpoint_dir, "llama3_model_checkpoint.pth")
torch.save(model.state_dict(), checkpoint_model_path)

print(f"Model checkpoint saved to {checkpoint_model_path}")

# Save the model arguments
import json

model_args_path = os.path.join(checkpoint_dir, "llama3_model_args.json")
with open(model_args_path, 'w') as f:
    json.dump(vars(model.params), f, indent=2)

print(f"Model arguments saved to {model_args_path}")

Model checkpoint saved to checkpoints/llama3_model_checkpoint.pth
Model arguments saved to checkpoints/llama3_model_args.json


In [None]:
# Sample from the trained model: wrap it with the reference tokenizer, generate
# a continuation of the prompt, and print everything after the begin-of-text
# marker that echo=True puts in front of the prompt.
model.eval()
generator = Llama(model, llama.tokenizer, model_args)
completion = generator.text_completion(
    "To be or not to be,",
    temperature=1.0,
    max_gen_len=500,
    echo=True,
)
print(completion.generation.split("<|begin_of_text|>")[1])

To be or not to be, no will to die through love itself.
Why, how to tread how do to honour newly your brother,
But to have five thousand thanks too much to Clarence:
I'll give my soul,
To should our speech of gold and too:
You are dear train, and father, poor brother,
Ere further conference with a passing small.
O Dorsetable.
Your sense may beggarly the tomb,
And bid me mistress sit dispatch: past the boy,
And well lost with one thing just proportion,
And over the board, under his liking!
And all the watchful eye of dear faith,
More fierce and an inditeous wrath!
How well, lords, I befall, and lay,
Is not forgot the tyrant, to fill the crown,
And manage of your glorious sun: regent join'd!
Yet would youravenousoddess, that went;
And well we have heard of all run a needful';
Anduile me with the root
And buryWhat! myself become a tyrant
Stands without the brat's king in Bosworth
To leap upon a black tidings was;
And in all my tumble down: great leaving me,
'Twere a bloody axe to that mak