In [None]:
import subprocess
import sys
from pathlib import Path

import tiktoken
import torch

from llm_e2e import GPT2Config, GPT2Model, StreamingDatasetGenerator

def setup_colab_env(project_root: str):
    """Install the project's dependencies when running inside Google Colab.

    Outside Colab (detected by 'google.colab' being absent from ``sys.modules``)
    this is a no-op. Inside Colab it runs ``uv sync`` followed by an editable
    install of the project, both with ``project_root`` as working directory.

    Args:
      project_root: directory expected to contain ``pyproject.toml``

    Raises:
      subprocess.CalledProcessError: if either ``uv`` command exits non-zero
    """
    if 'google.colab' not in sys.modules:
        return

    print("In Colab notebook. Installing dependencies ...")
    root_path = Path(project_root)
    if not (root_path / "pyproject.toml").is_file():
        print(f"Error: 'pyproject.toml' not found in {root_path}")
        return

    # Both commands act on the project that was just validated, so run them
    # from its root rather than from whatever directory the kernel started in.
    for command in (["uv", "sync"], ["uv", "pip", "install", "-e", "."]):
        subprocess.run(
            command,
            cwd=root_path,
            check=True,
            capture_output=True,
            text=True
        )

def setup_cuda(cfg: GPT2Config):
    """Report the CUDA version and enable faster float32 matmul where supported.

    Does nothing when CUDA is unavailable. Otherwise it checks that the config
    targets CUDA, prints the CUDA version, and sets the float32 matmul
    precision to "high" on devices with compute capability 7.0 or newer.

    Args:
      cfg: configuration whose ``device`` attribute must be 'cuda' when a GPU
        is available

    Raises:
      AssertionError: if CUDA is available but ``cfg.device`` is not 'cuda'
    """
    if not torch.cuda.is_available():
        return

    assert cfg.device == 'cuda'
    print(f"CUDA version: {torch.version.cuda}")
    capability = torch.cuda.get_device_capability()
    if capability[0] >= 7:  # Volta (7.0+), Turing (7.5+), Ampere (8.0+), Hopper (9.0+)
        torch.set_float32_matmul_precision("high")
        print("Uses tensor cores")
    else:
        print("Tensor cores not supported on this GPU.")
    
# NOTE(review): absolute, user-specific path; consider making this configurable
# so the notebook runs on other machines.
project_root = '/home/jimsingh/src/llm_e2e/'
setup_colab_env(project_root)
config_yaml = "gpt2_bert_corpus_gpu.yaml"
# Join with pathlib so the trailing slash on project_root does not produce "//".
cfg = GPT2Config.from_yaml(str(Path(project_root) / "config" / config_yaml))
encoding = tiktoken.get_encoding(cfg.encoding_name)
setup_cuda(cfg)

# NOTE(review): both generators are built from identical arguments, so from this
# cell alone train and validation appear to read the same stream -- confirm.
# Later cells refer to `train_loader` / `val_loader`, which are not defined here.
train_dataset = StreamingDatasetGenerator(cfg, encoding=encoding)
val_dataset = StreamingDatasetGenerator(cfg, encoding=encoding)


In [None]:
import torch
import itertools
from datetime import datetime

def generate_text(model, tokenizer, prompt: str, max_tokens=20) -> str:
    """Return the decoded output of ``model.generate`` for ``prompt``.

    The model is switched to eval mode, the prompt is encoded with
    ``tokenizer`` and moved to the device of the model's first parameter, and
    generation runs with gradients disabled.

    Args:
      model: module exposing ``parameters()`` and ``generate(ids, max_tokens)``
      tokenizer: object exposing ``encode(str)`` and ``decode(list)``
      prompt: text to encode and feed to the model
      max_tokens: forwarded as the second argument of ``model.generate``

    Returns:
      the decoded text of the first sequence returned by ``model.generate``
    """
    model.eval()
    target_device = next(model.parameters()).device

    # Single-row batch holding the prompt's token ids.
    prompt_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)
    prompt_ids = prompt_ids.to(target_device)

    with torch.no_grad():
        generated = model.generate(prompt_ids, max_tokens)

    first_sequence = generated[0].cpu().tolist()
    return tokenizer.decode(first_sequence)


def estimate_loss(model, loader, device, eval_iters):
    """Average the model's loss over up to ``eval_iters`` batches from ``loader``.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning. Gradients are disabled while evaluating.

    Args:
      model: callable as ``model(X, Y)`` returning ``(logits, loss)``
      loader: iterable of ``(X, Y)`` batches
      device: device the batches are moved to before the forward pass
      eval_iters: maximum number of batches to pull from the loader

    Returns:
      0-dim tensor with the mean loss over the batches actually seen; NaN when
      the loader yielded no batches
    """
    model.eval()
    losses = torch.zeros(eval_iters)
    batches_seen = 0
    with torch.no_grad():
        for i, (X, Y) in enumerate(itertools.islice(loader, eval_iters)):
            X, Y = X.to(device), Y.to(device)
            _, loss = model(X, Y)
            losses[i] = loss.item()
            batches_seen = i + 1
    model.train()
    # Only average the filled slots: a loader shorter than eval_iters would
    # otherwise drag the mean towards zero.
    return losses[:batches_seen].mean()

@torch.no_grad()
def evaluate_model(model, train_loader, val_loader, device, eval_iters):
    """
    Estimate the loss on the training and validation data with gradients disabled.

    Args:
      model: to evaluate
      train_loader: training dataset iterator
      val_loader: validation dataset iterator
      device: device the batches are moved to before the forward pass
      eval_iters: the number of iterations to pull from the loaders

    Returns:
      dict with 'train' and 'val' loss
    """
    # Training loss is measured first, then validation loss.
    return {
        'train': estimate_loss(model, train_loader, device, eval_iters),
        'val': estimate_loss(model, val_loader, device, eval_iters),
    }

def print_gpu_memory_stats(checkpoint_name, device):
    """Print current and peak CUDA memory usage for ``device`` in MB.

    Prints nothing unless CUDA is available and ``device.type`` is 'cuda'.

    Args:
      checkpoint_name: label included in the header line of the report
      device: ``torch.device`` whose memory counters are queried
    """
    if not torch.cuda.is_available() or device.type != 'cuda':
        return

    def _to_mb(num_bytes):
        # The torch.cuda counters are divided by 1024**2 for display.
        return num_bytes / (1024**2)

    cur_alloc = _to_mb(torch.cuda.memory_allocated(device))
    cur_reserved = _to_mb(torch.cuda.memory_reserved(device))
    peak_alloc = _to_mb(torch.cuda.max_memory_allocated(device))
    peak_reserved = _to_mb(torch.cuda.max_memory_reserved(device))

    print(f"--- GPU Memory Stats at: {checkpoint_name} ({device}) ---")
    print(f"  Current Allocated: {cur_alloc:.2f} MB")
    print(f"  Current Reserved:  {cur_reserved:.2f} MB")
    print(f"  Peak Allocated:    {peak_alloc:.2f} MB")
    print(f"  Peak Reserved:     {peak_reserved:.2f} MB")
    print("----------------------------------------------------")
        
def train_model(model, train_loader, val_loader, optimizer, gen_f, cfg):
    """Run the training loop with periodic logging, evaluation and checkpointing.

    Args:
      model: called as ``model(X, Y)`` and expected to return ``(logits, loss)``;
        may or may not be wrapped by ``torch.compile``
      train_loader: iterable of ``(X, Y)`` training batches, iterated once per epoch
      val_loader: iterable of ``(X, Y)`` validation batches
      optimizer: optimizer stepping the model's parameters
      gen_f: callable taking the model and returning a sample completion to print
      cfg: config providing device, num_epochs, log_interval, eval_interval,
        eval_iters, n_params and save_filename
    """
    device = torch.device(cfg.device) # Ensure device object

    if device.type == 'cuda':
        print("Starting training on CUDA device. Initializing memory stats.")
        # Reset peak stats at the beginning of training if you want to track peaks per training run
        torch.cuda.reset_peak_memory_stats(device)
        print_gpu_memory_stats("Start of training_model", device)

    print(f"started training model with {cfg.n_params:_} parameters. model parameters file: {cfg.save_filename}")

    for epoch in range(cfg.num_epochs):
        model.train()
        running_loss = 0.0
        print(f"[{epoch + 1} / {cfg.num_epochs}]: starting at {datetime.now()}, will show running loss every {cfg.log_interval} steps, will eval every {cfg.eval_interval} steps")
        if device.type == 'cuda':
            print_gpu_memory_stats(f"Start of Epoch {epoch + 1}", device)

        for i, (X, Y) in enumerate(train_loader):
            X, Y = X.to(device), Y.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            logits, loss = model(X, Y)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if (i + 1) % cfg.log_interval == 0:
                print(f"[{epoch + 1}  {i + 1:5d}]: running loss {running_loss / cfg.log_interval:.3f}")
                running_loss = 0.0

            if (i + 1) % cfg.eval_interval == 0:
                # NOTE(review): train_loader is also being iterated by the loop
                # above; if it is a one-shot iterator, evaluation consumes
                # training batches -- confirm against the loader implementation.
                losses = evaluate_model(model, train_loader, val_loader, device, eval_iters=cfg.eval_iters)
                print(f"[{epoch + 1}  {i + 1:5d}]: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}, eval_iters: {cfg.eval_iters}")
                completion = gen_f(model)
                print(f"[{epoch + 1}  {i + 1:5d}]: {completion}")
                print_gpu_memory_stats(f"[{epoch + 1}  {i + 1:5d}]", device)
                # torch.compile exposes the wrapped module as `_orig_mod`; fall
                # back to the model itself when it was not compiled so that
                # checkpointing does not raise AttributeError.
                checkpoint_source = getattr(model, "_orig_mod", model)
                torch.save(checkpoint_source.state_dict(), cfg.save_filename)


In [None]:
import os

# Resume options: prefer an existing checkpoint at cfg.save_filename, otherwise
# optionally load only weights, otherwise start from a freshly built model.
load_full_model = True
load_weights = False


def _strip_compile_prefix(state_dict):
    """Return a copy of ``state_dict`` without the "_orig_mod." key prefix."""
    prefix = "_orig_mod."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


if load_full_model and os.path.exists(cfg.save_filename):
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoint files you trust.
    checkpoint = torch.load(cfg.save_filename, weights_only=False)
    if isinstance(checkpoint, dict):
        # train_model() periodically writes a bare state_dict to this same
        # path, so an interrupted run leaves weights here, not a full model.
        model = GPT2Model(cfg)
        model.load_state_dict(_strip_compile_prefix(checkpoint))
    else:
        model = checkpoint
    print(f"loaded model weights: {cfg.save_filename}")
elif load_weights:
    model = GPT2Model(cfg)
    model.load_state_dict(_strip_compile_prefix(torch.load(cfg.save_filename)))
else:
    model = GPT2Model(cfg)

if cfg.device == 'cuda':
    model.to(torch.bfloat16)

model.to(cfg.device)

if cfg.compile_model:
    model = torch.compile(model)


def gen_f(m):
    """Sample a short completion from ``m`` for the periodic training log."""
    return generate_text(m, encoding, "Paris is")


optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
# NOTE(review): the visible notebook only defines train_dataset / val_dataset;
# assumes they yield (X, Y) batches as train_model expects -- confirm.
train_model(model, train_dataset, val_dataset, optimizer, gen_f, cfg=cfg)

# Save the underlying module rather than the torch.compile wrapper so the
# checkpoint reloads as a plain model.
torch.save(getattr(model, "_orig_mod", model), cfg.save_filename)

print('Finished Training')

In [None]:
# Scratch helpers, disabled by default: flip a guard to True to run it by hand.
# Save only the weights of the current model under a fixed file name.
if False:
    model_path = 'gpt2_training_fineweb-edu_47000_steps.pth'
    torch.save(model.state_dict(), model_path)
# Load the saved state_dict
if False:
    # Rebuild the config and tokenizer, then read the saved weights
    cfg = GPT2Config.from_yaml("gpt2_config_wikipedia_cpu.yaml")
    enc = tiktoken.get_encoding(cfg.encoding_name)
    state_dict = torch.load(cfg.save_filename)

    # Create a new state_dict without the "_orig_mod." prefix
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith("_orig_mod."):
            new_key = k[len("_orig_mod."):]  # Remove the prefix
            new_state_dict[new_key] = v
        else:
            new_state_dict[k] = v # If for some reason some keys don't have it

    # Instantiate the model and restore the cleaned weights
    model = GPT2Model(cfg)
    model.load_state_dict(new_state_dict)


In [None]:
#print_gpu_memory_stats('foo', torch.device(cfg.device))

import gc
import torch
# Assuming 'my_lingering_tensor_var' is a variable holding a tensor
# or 'my_list_of_tensors' is a list holding them.

# If you know the variable names:
# del my_lingering_tensor_var
# del my_list_of_tensors

# Run garbage collection once before dropping any notebook-level references.
gc.collect()
# Disabled by default: flip to True to delete the notebook's references to the
# training objects before the second collection below.
if False:
  del model, optimizer, train_loader, val_loader

# Second pass, so anything released by the `del` above (when enabled) is collected.
gc.collect()
# It's also good practice to clear PyTorch's cache AFTER Python references are gone
if torch.cuda.is_available():
    torch.cuda.empty_cache()