In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
import datasets
from typing import List, Tuple
from datasets import load_dataset

class PromptDataset(Dataset):
    """Map-style dataset that pairs prompt texts with integer class labels.

    Each item is tokenized lazily in ``__getitem__`` and returned as a dict
    with the keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``.
    """

    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length=512):
        """Store the samples and the tokenizer used to encode them.

        Args:
            texts: Prompt texts; each one is passed through ``str()`` before
                tokenization.
            labels: Integer class label for the text at the same position.
            tokenizer: Callable invoked as ``tokenizer(text, ...)`` that
                returns a mapping with ``'input_ids'`` and
                ``'attention_mask'`` tensors.
            max_length: Length passed to the tokenizer for padding/truncation.

        Raises:
            ValueError: If ``texts`` and ``labels`` differ in length.
        """
        # Fail fast: a mismatch would otherwise surface as an IndexError in the
        # middle of an epoch (too few labels) or be silently ignored (too many).
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} texts and {len(labels)} labels"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors together with its label."""
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # return_tensors='pt' yields tensors with a leading batch dimension;
        # flatten() drops it so the DataLoader can stack items itself.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

def prepare_data(gsm8k: datasets.Dataset, humaneval: datasets.Dataset, sqleval: datasets.Dataset) -> Tuple[List[str], List[int]]:
    """Flatten the three source datasets into parallel prompt/label lists.

    Prompts are read from the ``'question'`` column of ``gsm8k`` and
    ``sqleval`` and from the ``'prompt'`` column of ``humaneval``; every value
    is converted with ``str()``. Labels are 0 for ``gsm8k``, 1 for
    ``humaneval`` and 2 for ``sqleval``.

    Returns:
        ``(prompts, labels)`` with the datasets concatenated in the order
        gsm8k, humaneval, sqleval.
    """
    # (records, column holding the prompt text, class id)
    sources = (
        (gsm8k, 'question', 0),
        (humaneval, 'prompt', 1),
        (sqleval, 'question', 2),
    )

    all_prompts: List[str] = []
    all_labels: List[int] = []
    for records, column, class_id in sources:
        already_collected = len(all_prompts)
        all_prompts.extend(str(record[column]) for record in records)
        # One label per prompt that this source just contributed.
        all_labels.extend([class_id] * (len(all_prompts) - already_collected))

    return all_prompts, all_labels

def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    steps = 0

    for batch in tqdm(loader, desc='Training'):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        # Passing labels makes the model compute the loss itself.
        outputs = model(input_ids, attention_mask=attention_mask, labels=targets)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        steps += 1

    return total_loss / max(steps, 1)


def _evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without gradients; return (mean batch loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    steps = 0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validation'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['label'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=targets)

            total_loss += outputs.loss.item()
            steps += 1

            predictions = torch.argmax(outputs.logits, dim=1)
            correct += (predictions == targets).sum().item()
            seen += targets.size(0)

    accuracy = correct / seen if seen else 0.0
    return total_loss / max(steps, 1), accuracy


def train_bert_classifier(prompts: List[str], labels: List[int],
                         batch_size=16, num_epochs=3, learning_rate=2e-5,
                         num_labels=3, model_name='bert-base-uncased'):
    """Fine-tune a BERT sequence classifier on ``prompts`` / ``labels``.

    The data is split 80/20 into train and validation sets with a fixed seed
    (``random_state=42``). After every epoch the average training loss, the
    average validation loss and the validation accuracy are printed.

    Args:
        prompts: Input texts.
        labels: Integer class id for each prompt, in ``[0, num_labels)``.
        batch_size: Batch size for both the training and validation loaders.
        num_epochs: Number of passes over the training split.
        learning_rate: Learning rate handed to the optimizer.
        num_labels: Number of output classes of the classification head.
        model_name: Checkpoint passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        ``(model, tokenizer)`` — the fine-tuned model and its tokenizer.

    Raises:
        ValueError: If any label lies outside ``[0, num_labels)``.
    """
    # Check before loading any weights: an out-of-range label would otherwise
    # only fail later, inside the loss computation.
    if any(not 0 <= label < num_labels for label in labels):
        raise ValueError(f"all labels must be in the range [0, {num_labels})")

    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="single_label_classification",
    )

    train_texts, val_texts, train_labels, val_labels = train_test_split(
        prompts, labels, test_size=0.2, random_state=42
    )

    train_dataset = PromptDataset(train_texts, train_labels, tokenizer)
    val_dataset = PromptDataset(val_texts, val_labels, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # NOTE(review): transformers' AdamW is deprecated upstream in favour of
    # torch.optim.AdamW, whose eps / weight_decay defaults differ — confirm the
    # intended hyper-parameters before swapping it.
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch + 1}/{num_epochs}')

        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        print(f'Average training loss: {avg_train_loss:.4f}')

        avg_val_loss, accuracy = _evaluate(model, val_loader, device)
        print(f'Average validation loss: {avg_val_loss:.4f}')
        print(f'Validation accuracy: {accuracy:.4f}')

    return model, tokenizer

if __name__ == "__main__":
    # Source corpora: one per target class.
    gsm8k = load_dataset("gsm8k", "main", split="train")
    humaneval = load_dataset("openai_humaneval", split="test")
    sqleval = load_dataset(
        "csv",
        data_files="./data/questions_gen_postgres.csv",
        split="train",
    )

    prompts, labels = prepare_data(gsm8k, humaneval, sqleval)
    model, tokenizer = train_bert_classifier(prompts, labels)

    # Model and tokenizer are written to the same directory so that both can
    # later be restored from a single path.
    save_dir = "../models/prompt_classifier"
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/3


Training: 100%|██████████| 393/393 [01:21<00:00,  4.80it/s]


Average training loss: 0.0569


Validation: 100%|██████████| 99/99 [00:08<00:00, 12.00it/s]


Average validation loss: 0.0034
Validation accuracy: 1.0000

Epoch 2/3


Training: 100%|██████████| 393/393 [01:26<00:00,  4.57it/s]


Average training loss: 0.0019


Validation: 100%|██████████| 99/99 [00:08<00:00, 11.70it/s]


Average validation loss: 0.0006
Validation accuracy: 1.0000

Epoch 3/3


Training: 100%|██████████| 393/393 [01:26<00:00,  4.55it/s]


Average training loss: 0.0005


Validation: 100%|██████████| 99/99 [00:08<00:00, 11.60it/s]


Average validation loss: 0.0003
Validation accuracy: 1.0000


In [6]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Loaded (model, tokenizer, device) triples keyed by checkpoint path, so that
# repeated predictions do not re-read the checkpoint from disk on every call.
_CLASSIFIER_CACHE = {}


def _load_classifier(model_path: str):
    """Return the cached (model, tokenizer, device) for ``model_path``, loading it on first use."""
    cached = _CLASSIFIER_CACHE.get(model_path)
    if cached is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = BertForSequenceClassification.from_pretrained(model_path)
        tokenizer = BertTokenizer.from_pretrained(model_path)
        model.to(device)
        model.eval()
        cached = (model, tokenizer, device)
        _CLASSIFIER_CACHE[model_path] = cached
    return cached


def predict_dataset(text: str, model_path: str = "../models/prompt_classifier") -> dict:
    """Classify ``text`` as a GSM8K, HumanEval or SqlEval style prompt.

    The model and tokenizer stored at ``model_path`` are loaded on the first
    call for that path and reused afterwards.

    Args:
        text: Prompt to classify.
        model_path: Directory containing the saved model and tokenizer.

    Returns:
        A dict with ``"predicted_dataset"`` (display name of the top class),
        ``"confidence"`` (its probability, formatted as a percentage string)
        and ``"probabilities"`` (percentage strings for all three classes).
    """
    model, tokenizer, device = _load_classifier(model_path)

    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(outputs.logits, dim=1)
        prediction = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][prediction].item()

    # Class ids follow the labelling used when the training data was built:
    # 0 = GSM8K, 1 = HumanEval, 2 = SqlEval.
    dataset_mapping = {0: "GSM8K (Math)", 1: "HumanEval (Programming)", 2: "SqlEval (SQL)"}
    result = {
        "predicted_dataset": dataset_mapping[prediction],
        "confidence": f"{confidence:.2%}",
        "probabilities": {
            "GSM8K": f"{probabilities[0][0].item():.2%}",
            "HumanEval": f"{probabilities[0][1].item():.2%}",
            "SqlEval": f"{probabilities[0][2].item():.2%}"
        }
    }

    return result

if __name__ == "__main__":
    # One hand-written sample per class, used as a smoke test of the classifier.
    math_prompt = """
    Janet has 3 apples. She buys 2 more apples from the store. 
    How many apples does Janet have in total?
    """

    programming_prompt = """
    def add_numbers(a: int, b: int) -> int:
        \"\"\"
        Add two integers and return their sum.
        
        Args:
            a: first integer
            b: second integer
            
        Returns:
            The sum of a and b
        \"\"\"
    """

    sql_prompt = """
    What's the name and rating of all the restaurants that have a rating greater than 4 and are located in the city of New York?
    """

    # Classify every sample first so that the report below is printed in one piece.
    math_result, programming_result, sql_result = [
        predict_dataset(sample)
        for sample in (math_prompt, programming_prompt, sql_prompt)
    ]

    report_sections = (
        ("\nMath Problem Prediction:", math_result),
        ("\nProgramming Problem Prediction:", programming_result),
        ("\nSQL Problem Prediction:", sql_result),
    )
    for heading, outcome in report_sections:
        print(heading)
        print(f"Predicted Dataset: {outcome['predicted_dataset']}")
        print(f"Confidence: {outcome['confidence']}")
        print("Probabilities:", outcome['probabilities'])


Math Problem Prediction:
Predicted Dataset: GSM8K (Math)
Confidence: 99.98%
Probabilities: {'GSM8K': '99.98%', 'HumanEval': '0.01%', 'SqlEval': '0.01%'}

Programming Problem Prediction:
Predicted Dataset: HumanEval (Programming)
Confidence: 99.82%
Probabilities: {'GSM8K': '0.06%', 'HumanEval': '99.82%', 'SqlEval': '0.12%'}

SQL Problem Prediction:
Predicted Dataset: SqlEval (SQL)
Confidence: 99.82%
Probabilities: {'GSM8K': '0.07%', 'HumanEval': '0.11%', 'SqlEval': '99.82%'}
