In [1]:
import archie
import archie.training

import torch
import torchinfo
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.amp import autocast

from datetime import datetime
from transformers import Trainer, TrainingArguments

import math

  from .autonotebook import tqdm as notebook_tqdm


# Helper stuff

In [None]:
def enable_gradient_checkpointing(model):
    for layer in model.layers:
        layer._forward = layer.forward
        layer.forward = lambda x, m=layer: torch.utils.checkpoint.checkpoint(
            m._forward, x, use_reentrant=False
        )


# Learning rate scheduler (cosine with warmup)
def get_lr(step, warmup_steps=100, max_steps=100000):  # Changed from 2000 to 100
    if step < warmup_steps:
        return 3e-4 * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return 3e-4 * 0.5 * (1 + torch.cos(torch.tensor(progress * 3.14159)))


@torch.no_grad()
def generate_text(model, tokenizer, prompt="The", max_tokens=50, temperature=0.8):
    model.eval()
    tokens = torch.tensor([tokenizer.encode(prompt)]).to(model.config.device)

    for _ in range(max_tokens):
        # Forward pass
        logits, _ = model(tokens)

        # Get logits for last token
        logits = logits[:, -1, :] / temperature

        # Sample
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append
        tokens = torch.cat([tokens, next_token], dim=1)

        # Stop at max sequence length
        if tokens.shape[1] >= model.config.max_seq_len:
            break

    model.train()
    return tokenizer.decode(tokens[0].cpu().tolist())



# Model Creation

In [None]:
config = archie.Config(device="cuda")

model = archie.ArchieModel(config).to(config.device).to(torch.bfloat16)
enable_gradient_checkpointing(model)

torchinfo.summary(model)

tokenizer = archie.get_tokenizer()

optimizer = torch.optim.AdamW(
    model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=0.1
)

global_step = 0
tokens_seen = 0

In [None]:
# Load the model
state = torch.load('archie_1b.pt', map_location=config.device)
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
global_step = state['step']
tokens_seen = state['tokens_seen']

In [9]:
result = generate_text(model, tokenizer, "The Capital of France is")
result

'The Capital of France is in. a material the of was in of. is for the in process the for or in process the of.The is is for courses minute that positive is in the in connection the is user\nail Social of is completely clear thex is and'

In [13]:
# Checkpoint!
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'step': global_step,
    'tokens_seen': tokens_seen,
}, f'archie_1b.pt')

socket.send() raised exception.
socket.send() raised exception.


Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7fe6398d3620>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7fe5ec5e9e50, raw_cell="# Checkpoint!
torch.save({
    'model': model.stat.." transformed_cell="# Checkpoint!
torch.save({
    'model': model.stat.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B7b22686f73744e616d65223a226475626c696e6572227d/tank/nick/byte-gpt/archie.ipynb#X20sdnNjb2RlLXJlbW90ZQ%3D%3D>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

socket.send() raised exception.
socket.send() raised exception.


Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7fe6398d3620>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fe5ec3eb620, execution_count=13 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7fe5ec5e9e50, raw_cell="# Checkpoint!
torch.save({
    'model': model.stat.." transformed_cell="# Checkpoint!
torch.save({
    'model': model.stat.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B7b22686f73744e616d65223a226475626c696e6572227d/tank/nick/byte-gpt/archie.ipynb#X20sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [10]:
batch_size = 6
accumulation_steps = 8  # batch_size=4 â†’ effective batch_size=32

dataset = archie.training.get_datasets()
text_dataset = archie.training.TextDataset(dataset, tokenizer, config)
train_loader = DataLoader(
    text_dataset,
    batch_size=batch_size,
    num_workers=0,
    # prefetch_factor=64,  # Prefetch
    pin_memory=True,  # Optimization
)

In [11]:
import wandb
total_params = sum(p.numel() for p in model.parameters())
model_name = f'archie_{format_number(total_params)}'
wandb.init(project="archie-training", name=model_name,
    config={
        "model_size": format_number(total_params),
        "batch_size": batch_size,
        "accumulation_steps": accumulation_steps,
    })

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /home/nick/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mnickwanninger[0m ([33mnickwanninger-northwestern-universtiy[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Accumulated Training Loop

In [12]:
total_params = sum(p.numel() for p in model.parameters())


for i, (x, y) in enumerate(train_loader):
    x = x.to(config.device)
    y = y.to(config.device)
    
    # Track tokens (batch_size * seq_len)
    tokens_seen += x.numel()
    
    # Forward + backward
    logits, loss = model(x, labels=y)
    loss = loss / accumulation_steps
    loss.backward()
    
    # Update weights every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        global_step += 1
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # Optimizer step
        optimizer.step()

        lr = get_lr(global_step)
        # lr = 1e-4
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.zero_grad()

        effective_loss = loss.item() * accumulation_steps
        
        # Logging
        perplexity = math.exp(effective_loss)
        tokens_M = tokens_seen / 1_000_000
        tokens_per_param = tokens_seen / total_params

        try:
            wandb.log({
                'loss': effective_loss,
                'perplexity': perplexity,
                'params_trained': tokens_seen,
                'lr': lr,
            }, step=global_step)
        except Exception as e:
            print('Exception while loging to W&B:', e)
        print(f"Step {global_step} | Tokens: {tokens_M:.2f}M ({tokens_per_param:.3f}/param) | PPL: {perplexity:9.2f} | loss: {effective_loss:.2f} | LR: {lr:.2e}")
        

Step 417 | Tokens: 41.78M (0.036/param) | PPL:   1062.89 | loss: 6.97 | LR: 3.00e-04
Step 418 | Tokens: 41.88M (0.036/param) | PPL:    998.50 | loss: 6.91 | LR: 3.00e-04
Step 419 | Tokens: 41.98M (0.036/param) | PPL:   1030.19 | loss: 6.94 | LR: 3.00e-04
Step 420 | Tokens: 42.07M (0.036/param) | PPL:    998.50 | loss: 6.91 | LR: 3.00e-04
Step 421 | Tokens: 42.17M (0.036/param) | PPL:   1096.63 | loss: 7.00 | LR: 3.00e-04
Step 422 | Tokens: 42.27M (0.036/param) | PPL:    938.00 | loss: 6.84 | LR: 3.00e-04
Step 423 | Tokens: 42.37M (0.036/param) | PPL:   1131.44 | loss: 7.03 | LR: 3.00e-04
Step 424 | Tokens: 42.47M (0.036/param) | PPL:    909.14 | loss: 6.81 | LR: 3.00e-04
Step 425 | Tokens: 42.57M (0.036/param) | PPL:    909.14 | loss: 6.81 | LR: 3.00e-04
Step 426 | Tokens: 42.66M (0.037/param) | PPL:   1131.44 | loss: 7.03 | LR: 3.00e-04
Step 427 | Tokens: 42.76M (0.037/param) | PPL:    854.06 | loss: 6.75 | LR: 3.00e-04
Step 428 | Tokens: 42.86M (0.037/param) | PPL:    909.14 | loss: 

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7fe6398d3620>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fe6394ecfa0, execution_count=12 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7fe6383653d0, raw_cell="total_params = sum(p.numel() for p in model.parame.." transformed_cell="total_params = sum(p.numel() for p in model.parame.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B7b22686f73744e616d65223a226475626c696e6572227d/tank/nick/byte-gpt/archie.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost