In [1]:
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import DistilBertTokenizer, DistilBertModel
from torch.optim import AdamW
import numpy as np
from tqdm import tqdm

# Custom Dataset
class LossDataset(Dataset):
    """Map-style dataset that pairs a tokenized prompt with three loss targets.

    Every record in ``data`` is read through the keys ``'prompt'``,
    ``'os1_loss'``, ``'os2_loss'`` and ``'os3_loss'``.
    """

    def __init__(self, data, tokenizer, max_length=256):
        """Keep the records, the tokenizer callable and the padded length."""
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            dict: ``'input_ids'`` and ``'attention_mask'`` with the leading
            dimension added by ``return_tensors='pt'`` squeezed away, plus
            ``'losses'``, a float32 tensor ordered (os1, os2, os3).
        """
        record = self.data[idx]

        # Only the prompt is encoded; the answer is deliberately left out.
        tokens = self.tokenizer(
            record['prompt'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Regression targets, always in the same fixed order.
        targets = torch.tensor(
            [record[key] for key in ('os1_loss', 'os2_loss', 'os3_loss')],
            dtype=torch.float32,
        )

        sample = {
            name: tokens[name].squeeze(0)
            for name in ('input_ids', 'attention_mask')
        }
        sample['losses'] = targets
        return sample

In [2]:
# Model Definition
class DistilBERTLossPredictor(nn.Module):
    """DistilBERT encoder followed by dropout and a 3-output linear regression head."""

    def __init__(self, dropout=0.3):
        """Load the pretrained 'distilbert-base-uncased' encoder and build the head.

        Args:
            dropout: Dropout probability applied to the first-token vector.
        """
        super().__init__()
        # Attribute names and creation order are kept as-is so that the keys
        # of state_dict() do not change.
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.distilbert.config.hidden_size
        self.regressor = nn.Linear(hidden_size, 3)  # 3 outputs

    def forward(self, input_ids, attention_mask):
        """Return the regression head's output for each sequence in the batch."""
        encoded = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is the [CLS] token representation.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.regressor(self.dropout(first_token))

In [3]:
# Training Function
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'losses'``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar tensor.
        device: Device every batch tensor is moved to.

    Returns:
        float: Sum of the per-batch ``loss.item()`` values divided by the
        number of batches processed, or NaN when the dataloader yields no
        batch at all (previously a ``ZeroDivisionError``).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc='Training'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['losses'].to(device)

        optimizer.zero_grad()
        predictions = model(input_ids, attention_mask)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Count what was actually iterated instead of trusting len(dataloader).
        num_batches += 1

    if num_batches == 0:
        # Nothing was trained on: report "no value" rather than dividing by zero.
        return float('nan')
    return running_loss / num_batches

# Validation Function
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'losses'``.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar tensor.
        device: Device every batch tensor is moved to.

    Returns:
        float: Sum of the per-batch ``loss.item()`` values divided by the
        number of batches processed, or NaN when the dataloader yields no
        batch at all (previously a ``ZeroDivisionError``). NaN compares
        False against every number, so a ``val_loss < best`` check never
        treats an empty validation run as an improvement.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['losses'].to(device)

            predictions = model(input_ids, attention_mask)
            loss = criterion(predictions, targets)

            running_loss += loss.item()
            # Count what was actually iterated instead of trusting len(dataloader).
            num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated: report "no value" rather than dividing by zero.
        return float('nan')
    return running_loss / num_batches

In [7]:
# Main Training Script
def main(data_path='router_training_data.jsonl', max_samples=25, checkpoint_path='best_model.pt'):
    """Train the DistilBERT loss predictor on a JSONL file and checkpoint the best epoch.

    Args:
        data_path: JSON-lines file with one record per line. Blank lines are
            skipped.
        max_samples: Keep only the first ``max_samples`` records. The default
            of 25 reproduces the cap that used to be hard-coded here; pass
            ``None`` to train on every record in the file.
        checkpoint_path: File the best-validation checkpoint is written to.

    Returns:
        tuple: ``(model, tokenizer)``. The model carries the weights of the
        final epoch, which are not necessarily the ones stored in the
        checkpoint.

    Raises:
        ValueError: If fewer than 2 records are available, because the 80/20
            split would leave the training set empty. A malformed line raises
            ``json.JSONDecodeError`` (itself a ``ValueError``).
    """
    data = []
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines in the JSONL file.
                continue
            data.append(json.loads(line))

    if max_samples is not None:
        data = data[:max_samples]

    # int(0.8 * n) is 0 for n < 2, so fail early with a clear message instead
    # of downloading a model and then crashing inside the training loop.
    if len(data) < 2:
        raise ValueError(
            f"Need at least 2 records to build a train/validation split, "
            f"got {len(data)} from {data_path!r}"
        )

    # Hyperparameters
    MAX_LENGTH = 256
    BATCH_SIZE = 64
    LEARNING_RATE = 2e-5
    EPOCHS = 15
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {DEVICE}")
    print(f"Number of samples: {len(data)}")

    # Initialize tokenizer
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    # Create full dataset
    full_dataset = LossDataset(data, tokenizer, MAX_LENGTH)

    # Split using torch random_split (80-20 split); the seeded generator keeps
    # the split identical from run to run.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Initialize model
    model = DistilBERTLossPredictor().to(DEVICE)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        train_loss = train_epoch(model, train_loader, optimizer, criterion, DEVICE)
        val_loss = validate(model, val_loader, criterion, DEVICE)

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")

        # Save best model (overwrites the previous best checkpoint)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, checkpoint_path)
            print(f"Model saved with validation loss: {val_loss:.4f}")

    print("\nTraining complete!")
    return model, tokenizer

In [8]:
# Prediction Function
def predict(model, tokenizer, prompt, device, max_length=256):
    """Predict the loss values for a single prompt.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; it is
            switched to eval mode and left that way.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        prompt: Text to score.
        device: Device the encoded tensors are moved to.
        max_length: Length the prompt is padded / truncated to.

    Returns:
        numpy.ndarray: First row of the model output, moved to the CPU.
    """
    model.eval()

    encoded = tokenizer(
        prompt,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    ids, mask = (encoded[key].to(device) for key in ('input_ids', 'attention_mask'))

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        batch_out = model(ids, mask)

    # The batch holds exactly one prompt, so hand back its row.
    first_row = batch_out.cpu().numpy()[0]
    return first_row

if __name__ == "__main__":
    # Train the model
    model, tokenizer = main()

    # Example prediction.
    # Take the device from the trained model itself rather than re-detecting
    # it, so the encoded prompt always lands where the weights actually are.
    device = next(model.parameters()).device
    test_prompt = "The woman detaches and holds up the filter on the vacuum cleaner. The woman stores the vacuum cleaners in a closet.\nThe two vacuum cleaners"

    # predictions is indexed in the order (OS1, OS2, OS3).
    predictions = predict(model, tokenizer, test_prompt, device)
    print("\nExample prediction:")
    print(f"OS1 Loss: {predictions[0]:.4f}")
    print(f"OS2 Loss: {predictions[1]:.4f}")
    print(f"OS3 Loss: {predictions[2]:.4f}")

Using device: cpu
Number of samples: 25
Train size: 20, Val size: 5


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]


Epoch 1/15


Training: 100%|██████████| 1/1 [00:38<00:00, 38.04s/it]
Validation: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Train Loss: 44.8169
Val Loss: 25.2478
Model saved with validation loss: 25.2478

Epoch 2/15


Training:   0%|          | 0/1 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [9]:
# Rebuild the architecture, then restore the weights saved in best_model.pt.
model = DistilBERTLossPredictor()

# A checkpoint file is a pickle: weights_only=True (torch >= 1.13) restricts
# unpickling to tensors and plain containers instead of arbitrary objects.
checkpoint = torch.load("best_model.pt", map_location="cpu", weights_only=True)  # use 'cuda' if you want GPU

model.load_state_dict(checkpoint["model_state_dict"])




<All keys matched successfully>