In [None]:
import sys
sys.path.append('/content')
from transformer import Transformer
from autoencoder import SparseAutoencoder
from data_utils import get_batch_iterator
import torch, h5py, os
import numpy as np
from tqdm import tqdm

In [None]:
# Hyperparameters and file locations shared by the cells below.
config = dict(
    # Transformer dimensions (forwarded to the Transformer constructor).
    vocab_size=50304,
    context_length=128,
    n_embed=128,
    n_head=8,
    # Dataset files. 'train_path' is opened with h5py by the batch iterator;
    # 'dev_path' is not read by any cell visible in this notebook.
    train_path='data/med_pile_train.h5',
    dev_path='data/pile_val.h5',
    # Checkpoint destinations for the transformer and the autoencoder.
    t_out_path='models/transformer_full.pt',
    a_out_path='models/autoencoder.pt',
    # Use the GPU whenever PyTorch reports one, otherwise fall back to CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [None]:
def get_small_batch_iterator(data_path, batch_size, context_length, device='cpu'):
    """Return an endless iterator of shuffled next-token-prediction batches.

    The entire ``tokens`` dataset of the HDF5 file at ``data_path`` is read into
    memory and cut into windows of ``context_length + 1`` tokens that start every
    ``context_length`` tokens (so neighbouring windows share one boundary token).
    Window order is reshuffled with ``np.random.shuffle`` at the start of every
    pass; a trailing group of fewer than ``batch_size`` windows is dropped.

    Args:
        data_path: Path to an HDF5 file containing a ``tokens`` dataset.
        batch_size: Number of windows per yielded batch (must be >= 1).
        context_length: Number of input tokens per window (must be >= 1).
        device: Device the yielded tensors are moved to.

    Returns:
        A generator yielding ``(xb, yb)`` pairs of int64 tensors on ``device``.
        ``xb`` holds the first ``context_length`` tokens of each window and
        ``yb`` the same window shifted one token to the right.

    Raises:
        ValueError: If ``batch_size`` or ``context_length`` is not positive, or
            if the file holds fewer than ``batch_size`` complete windows. The
            latter used to make the generator spin forever without yielding.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if context_length < 1:
        raise ValueError(f"context_length must be >= 1, got {context_length}")

    with h5py.File(data_path, 'r') as f:
        tokens = np.asarray(f['tokens'][:])

    # Each window needs one extra token for the shifted target, hence the -1.
    n_examples = (len(tokens) - 1) // context_length
    if n_examples < batch_size:
        raise ValueError(
            f"{data_path!r} holds {max(n_examples, 0)} windows of length "
            f"{context_length}, fewer than batch_size={batch_size}"
        )

    indices = np.arange(n_examples)
    # Offsets of every token inside one window, reused for each batch gather.
    window = np.arange(context_length + 1)

    def generator():
        while True:
            np.random.shuffle(indices)
            for i in range(0, n_examples - batch_size + 1, batch_size):
                starts = indices[i:i + batch_size] * context_length
                # Gather all windows of the batch in one indexing operation and
                # widen to int64 in numpy before handing the data to torch.
                batch = tokens[starts[:, None] + window].astype(np.int64)
                samples = torch.from_numpy(batch)
                xb = samples[:, :-1].to(device)
                yb = samples[:, 1:].to(device)
                yield xb, yb

    return generator()

In [None]:
# Build the transformer, its optimizer and the stream of training batches.
device = config['device']
model = Transformer(
    n_head=config['n_head'],
    n_embed=config['n_embed'],
    context_length=config['context_length'],
    vocab_size=config['vocab_size']
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
# The window length comes from the same config entry the model was built with,
# so the two cannot drift apart when the config is edited. Batch size defaults
# to 8 unless the config provides an explicit 'batch_size'.
iterator = get_small_batch_iterator(
    config['train_path'],
    batch_size=config.get('batch_size', 8),
    context_length=config['context_length'],
    device=device,
)

In [None]:
# Train the transformer for 500 optimizer steps, then write a checkpoint.
losses = []
# Make sure the model is in training mode even if this cell is re-run after
# something else switched it to eval mode.
model.train()
pbar = tqdm(range(500))
for step in pbar:
    xb, yb = next(iterator)
    _, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    # Show a running mean over the last 50 steps to smooth out batch noise.
    pbar.set_description(f"Step {step} | Loss {np.mean(losses[-50:]):.4f}")

# Create the checkpoint's own parent directory (derived from the configured
# path rather than a hard-coded name) so the save works for any 't_out_path'.
t_out_dir = os.path.dirname(config['t_out_path'])
if t_out_dir:
    os.makedirs(t_out_dir, exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict()
}, config['t_out_path'])
print("Transformer saved to:", config['t_out_path'])

In [None]:
# Sparse autoencoder sized to the transformer's embedding width, with 512
# features, plus the Adam optimizer that trains it.
autoencoder = SparseAutoencoder(n_embed=config['n_embed'], n_features=512)
autoencoder = autoencoder.to(device)
optimizer_ae = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)

In [None]:
# Train the sparse autoencoder for 200 steps on embeddings produced by the
# transformer, then write a checkpoint.
recon_losses = []
reg_losses = []
reg_weight = 3e-3  # weight of the regularisation term in the total loss

# The transformer is only read from here. Eval mode keeps any dropout or
# normalisation layers it may contain from perturbing the embeddings.
model.eval()

pbar = tqdm(range(200))
for _ in pbar:
    xb, _ = next(iterator)
    with torch.no_grad():
        x_embed, _ = model.forward_embedding(xb)
    batch = xb.shape[0]
    # Pick one random position per sequence. The index tensors are created on
    # the embedding's device so no host-to-device copy is needed per step.
    rand_idx = torch.randint(config['context_length'], (batch,), device=x_embed.device)
    rows = torch.arange(batch, device=x_embed.device)
    embed_samples = x_embed[rows, rand_idx, :]
    flat_embed = embed_samples.reshape(batch, -1)
    optimizer_ae.zero_grad()
    _, recon_loss, reg_loss = autoencoder(flat_embed, compute_loss=True)
    loss = recon_loss + reg_weight * reg_loss
    loss.backward()
    optimizer_ae.step()
    # Applied after every update; the implementation lives with SparseAutoencoder.
    autoencoder.normalize_decoder_weights()
    recon_losses.append(recon_loss.item())
    reg_losses.append(reg_loss.item())
    pbar.set_description(f"Recon: {np.mean(recon_losses[-50:]):.3f}, Reg: {np.mean(reg_losses[-50:]):.3f}")

# Create the checkpoint's parent directory here rather than relying on an
# earlier cell having created it.
a_out_dir = os.path.dirname(config['a_out_path'])
if a_out_dir:
    os.makedirs(a_out_dir, exist_ok=True)
torch.save({'model_state_dict': autoencoder.state_dict()}, config['a_out_path'])
print("Autoencoder saved to:", config['a_out_path'])