In [1]:
import torch
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from llama2 import ModelConfig, Transformer

In [2]:
# Tokenizer used to turn raw text into model inputs. Any checkpoint name
# accepted by AutoTokenizer can be substituted via `tokenizer_name`.
tokenizer_name = 'bert-base-uncased'  # Example tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Raw corpus, fetched by name through the `datasets` library.
dataset = load_dataset('ag_news')

def tokenize_function(examples):
    """Tokenize the ``'text'`` field of ``examples`` with the notebook tokenizer.

    The tokenizer is invoked with ``padding="max_length"`` and
    ``truncation=True``; whatever it returns is passed straight back to the
    caller.
    """
    texts = examples['text']
    encoded = tokenizer(texts, padding="max_length", truncation=True)
    return encoded


# Apply the tokenizer across every split; batched=True makes `map` call
# tokenize_function on groups of examples rather than one at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Split the dataset into training, validation, and test sets.
# The split is seeded so the train/validation partition is identical after a
# kernel restart; without a seed every run validates on different examples and
# validation losses from separate runs cannot be compared.
SPLIT_SEED = 42
train_val_split = tokenized_datasets['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_dataset = train_val_split['train']
val_dataset = train_val_split['test']
test_dataset = tokenized_datasets['test']

# Convert datasets to PyTorch format, exposing the same three columns on each split
torch_columns = ['input_ids', 'attention_mask', 'label']
for split_dataset in (train_dataset, val_dataset, test_dataset):
    split_dataset.set_format(type='torch', columns=torch_columns)

# Create DataLoaders (only the training loader shuffles its examples)
batch_size = 8
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [3]:
# Build the model: the vocabulary size is taken from the tokenizer's vocab,
# and the network is moved to the device named by the config.
vocab_size = len(tokenizer.vocab)
config = ModelConfig(vocab_size=vocab_size)
model = Transformer(config)
model = model.to(config.device)

In [4]:
# Optimisation setup: AdamW over every model parameter, cross-entropy as the objective
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

In [5]:
# Training function
def train_epoch(model, data_loader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Module invoked as ``model(input_ids, 0)``. Its output is
            flattened to two dimensions before being given to ``loss_fn``.
        data_loader: Sized iterable of dict batches that provide
            ``'input_ids'`` and ``'label'`` tensors.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The sum of the per-batch losses divided by ``len(data_loader)``.

    Raises:
        ZeroDivisionError: If ``data_loader`` contains no batches.
    """
    model.train()
    total_loss = 0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['label'].to(device)

        # Forward pass
        outputs = model(input_ids, 0)

        # Collapse every leading dimension so the loss sees (rows, classes).
        # The class dimension is read from the output itself, so the function
        # does not depend on a notebook-level `config` object.
        # NOTE(review): `labels` holds one entry per example in the batch; if
        # the model emits one row per token rather than per example, the row
        # counts will not match — confirm against the model's output shape.
        logits = outputs.reshape(-1, outputs.size(-1))
        loss = loss_fn(logits, labels.reshape(-1))

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(data_loader)

In [6]:
# Validation function
def validate_epoch(model, data_loader, loss_fn, device):
    """Evaluate ``model`` over ``data_loader`` without gradients; return the mean batch loss.

    The forward call and the loss computation mirror ``train_epoch`` so the
    training and validation losses reported side by side are computed the
    same way.

    Args:
        model: Module invoked as ``model(input_ids, 0)``.
        data_loader: Sized iterable of dict batches that provide
            ``'input_ids'`` and ``'label'`` tensors.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The sum of the per-batch losses divided by ``len(data_loader)``.

    Raises:
        ZeroDivisionError: If ``data_loader`` contains no batches.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)

            # Same call as train_epoch: the second positional argument is the
            # integer 0, not the attention-mask tensor.
            # NOTE(review): confirm what the model's second parameter means;
            # this is aligned with the training call, not verified against
            # the model definition.
            outputs = model(input_ids, 0)

            # Flatten to (rows, classes) exactly as the training step does.
            # For an output that is already two-dimensional this is a no-op.
            logits = outputs.reshape(-1, outputs.size(-1))
            loss = loss_fn(logits, labels.reshape(-1))

            total_loss += loss.item()

    return total_loss / len(data_loader)

In [7]:
# Training loop: each epoch runs one optimisation pass over the training
# loader, then one gradient-free pass over the validation loader, and reports
# both mean losses.
epochs = 3
for epoch in range(epochs):
    train_loss = train_epoch(
        model=model,
        data_loader=train_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=config.device,
    )
    val_loss = validate_epoch(
        model=model,
        data_loader=val_dataloader,
        loss_fn=loss_fn,
        device=config.device,
    )
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

ValueError: too many values to unpack (expected 3)