In [1]:
import torch
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import sys
import hydra
import torch
from omegaconf import DictConfig, OmegaConf

In [2]:
sys.path.append("..")
from model import JoeyLLM
from data import get_dataloader
from utils.logger import wandbLogger
# from train.trainer import  Trainer

In [3]:
print("✅ Loaded Config:")

# Go UP one level to find the conf directory
with hydra.initialize(config_path="../configs", version_base=None):
    cfg = hydra.compose(config_name="config")

wandbLogger.set_mode(cfg.wandb.mode)

logger = wandbLogger(
    project_name=cfg.wandb.project,
    config=OmegaConf.to_container(cfg, resolve=True)
)

✅ Loaded Config:


In [4]:
def compute_loss(outputs, labels):
    criterion = torch.nn.CrossEntropyLoss()
    B, T, V = outputs.size()
    outputs = outputs.view(B * T, V)    # [B*T, V]
    labels = labels.view(B * T)         # [B*T]
    return criterion(outputs, labels)

In [5]:
print("🧠 Initializing Model...")
model = JoeyLLM(
    vocab_size=cfg.model.vocab_size,
    max_seq_len=cfg.model.max_seq_len,
    embed_dim=cfg.model.embed_dim,
    num_layers=cfg.model.num_layers,
    num_heads=cfg.model.num_heads,
    dropout=cfg.model.dropout,
)
logger.watch_model(model, log="all", log_freq=10)

🧠 Initializing Model...


In [6]:

print("📦 Loading Dataset...")
dataloader = get_dataloader(
    data_path=cfg.data.data_path,
    chunk_size=cfg.data.chunk_size,
    buffer_text_size=cfg.data.buffer_text_size,
    batch_size=cfg.data.batch_size,
    num_workers=cfg.data.num_workers
)

📦 Loading Dataset...


In [7]:
one_batch = next(iter(dataloader))


In [8]:
type(one_batch)

dict

In [14]:
one_batch.keys()

dict_keys(['inputs', 'labels'])

In [16]:
print(one_batch["inputs"][0][:10])
print(one_batch["labels"][0][:10])

tensor([   91,   860,   287, 11579,  3962,  5659,    25, 57049, 28257,   369])
tensor([  860,   287, 11579,  3962,  5659,    25, 57049, 28257,   369,   279])


In [28]:
# def _train_epoch(self, epoch):
model.train()
total_loss = 0

progress_bar = tqdm(dataloader, desc=f"Steps", leave=True)
scaler = GradScaler()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.optim as optim

optimizer = optim.AdamW(
    model.parameters(),
    lr=2e-5,             # learning rate (tune this!)
    betas=(0.9, 0.95),   # GPT-like betas
    eps=1e-8,            # for numerical stability
    weight_decay=0.01    # decoupled weight decay
)
model = model.to(device)


Steps: 0it [00:18, ?it/s]


In [33]:

for batch_idx, batch in enumerate(progress_bar):
    # Handle dict or tuple batch format
    inputs = batch["inputs"].to(device)
    labels = batch["labels"].to(device)

    optimizer.zero_grad()

    with autocast(device_type="cuda"):
        outputs = model(inputs)
        loss = compute_loss(outputs, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    total_loss += loss.item()

    print(f"Epoch 1 | Training Loss: {total_loss:.4f}")


Epoch 1 | Training Loss: 2213.9766
Epoch 1 | Training Loss: 2704.3449
Epoch 1 | Training Loss: 3145.9790
Epoch 1 | Training Loss: 3538.5645
Epoch 1 | Training Loss: 3888.6259
Epoch 1 | Training Loss: 4198.3061
Epoch 1 | Training Loss: 4475.4331
Epoch 1 | Training Loss: 4724.1740
Epoch 1 | Training Loss: 4950.1433
Epoch 1 | Training Loss: 5158.5845
Epoch 1 | Training Loss: 5349.2645
Epoch 1 | Training Loss: 5525.5820
Epoch 1 | Training Loss: 5688.5111
Epoch 1 | Training Loss: 5842.1418
Epoch 1 | Training Loss: 5986.2199
Epoch 1 | Training Loss: 6122.9154
Epoch 1 | Training Loss: 6252.9100
Epoch 1 | Training Loss: 6378.2153
Epoch 1 | Training Loss: 6499.3868
Epoch 1 | Training Loss: 6616.3841
Epoch 1 | Training Loss: 6730.7193
Epoch 1 | Training Loss: 6841.5037
Epoch 1 | Training Loss: 6949.1263
Epoch 1 | Training Loss: 7054.5111
Epoch 1 | Training Loss: 7158.2688
Epoch 1 | Training Loss: 7260.0808
Epoch 1 | Training Loss: 7360.3504
Epoch 1 | Training Loss: 7459.2471
Epoch 1 | Training L

KeyboardInterrupt: 

In [None]:

def save_checkpoint(self, path):
    checkpoint = {
        "model_state": self.model.state_dict(),
        "optimizer_state": self.optimizer.state_dict(),
        "scaler_state": self.scaler.state_dict()
    }
    if self.scheduler:
        checkpoint["scheduler_state"] = self.scheduler.state_dict()
    torch.save(checkpoint, path)
    print(f"✅ Checkpoint saved to {path}")


In [None]:

def load_checkpoint(self, path):
    checkpoint = torch.load(path)
    self.model.load_state_dict(checkpoint["model_state"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state"])
    self.scaler.load_state_dict(checkpoint["scaler_state"])
    if self.scheduler and "scheduler_state" in checkpoint:
        self.scheduler.load_state_dict(checkpoint["scheduler_state"])
    print(f"✅ Checkpoint loaded from {path}")


In [None]:

def fit(self, num_epochs=20, checkpoint_path="checkpoints/checkpoint.pth"):
    for epoch in range(1, num_epochs + 1):
        train_loss = self._train_epoch(epoch)
        self.save_checkpoint(checkpoint_path)

        if self.scheduler:
            self.scheduler.step()

    print("🏁 Training complete!")
