# Dog-BERT classifier 
Here, we fine-tune a DistilBERT classifier that labels input text as dog-related or not. This is needed to classify our examples from fineweb/lmsys.

In [6]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd

# These are for eval charting
from matplotlib import pyplot as plt
from IPython.display import clear_output

# Load DistilBert from transformers
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Pick the compute device: prefer Apple-silicon MPS, then CUDA, and fall back
# to CPU so the notebook also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Matplotlib is building the font cache; this may take a moment.


ModuleNotFoundError: No module named 'transformers'

# GPT-labeling  
Generate GPT labels for the data that will be used to train dogbert 

# Training 
Here, we take the GPT-labeled data and use it to train the dog classifier.

In [None]:
# The tokenizer and the classifier are both loaded from the same pre-trained
# checkpoint, so the name is written once.
checkpoint = "distilbert-base-cased"

# Tokenizer that converts post content into tokens
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)

# Sequence classifier with two output labels, moved onto the compute device
model = DistilBertForSequenceClassification.from_pretrained(checkpoint, num_labels = 2)
model = model.to(device)

# Seed Python's `random` module (temporary reproducibility aid).
# NOTE(review): this only seeds `random`; `shuffled` is built before this
# point, outside this cell — confirm the seed is set where it matters.
random.seed(42)

# Cut points for an 80% / 10% / 10% train / test / validation split.
n_total = len(shuffled)
n1, n2 = int(0.8*n_total), int(0.9*n_total)

# Slice the already-shuffled data into the three consecutive subsets.
raw_datasets = dict(
    train = shuffled[:n1],
    test = shuffled[n1:n2],
    val = shuffled[n2:],
)

# Batch and tokenize
# For every split, group the rows into batches of 10 and pre-tokenize each
# batch. Each batch dict keeps the raw text, the label tensor and the model
# inputs together, with all tensors already moved to `device`.
batched_datasets = {}
for key, df in raw_datasets.items():
    # Only the text and the label column `k` are needed downstream.
    for_batching = df[['content', k]].to_dict('records')

    # Batch datasets
    batched_posts = []
    for batch in chunking(for_batching, size = 10):
        texts = [post['content'] for post in batch]

        # truncation = True is required alongside max_length: without it a post
        # longer than 512 tokens is not cut down, so the batch can no longer be
        # padded into a fixed 512-wide tensor that the model accepts.
        tokenized_input = tokenizer(
            texts,
            return_tensors = 'pt',
            max_length = 512,
            padding = 'max_length',
            truncation = True,
        ).to(device)

        batched_posts.append({
            'content': texts,
            'labels': torch.tensor([post[k] for post in batch], dtype = torch.long).to(device),
            'input_ids': tokenized_input['input_ids'],
            'attention_mask': tokenized_input['attention_mask']
        })

    batched_datasets[key] = batched_posts

#### MODEL TRAINING #####
# Single pass over the pre-batched training split. Every EVAL_EVERY steps the
# model is checked against the test split and the loss curves are redrawn.
# The same value is used as the moving-average window for the train-loss curve.
EVAL_EVERY = 20

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-5) # instantiate optimizer
# NOTE(review): scheduler.step() is never called in this cell, so the learning
# rate stays at 3e-5. With step_size = 1 and gamma = 0.5 each step() call would
# halve it — confirm whether and where it is meant to be stepped.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.5)

train_loss = []  # one record per training step
test_loss = []   # one record per evaluation

model.train()

# tqdm wraps the list itself (not the enumerate object) so it knows the number
# of batches and can show a total and an ETA.
for batch_iteration, batch in enumerate(tqdm(batched_datasets['train'])):

    # Zero out grads so they don't accumulate
    model.zero_grad()

    # Forward pass
    logits = model(batch['input_ids'], batch['attention_mask']).logits
    # Cross-entropy between the predicted logits and the true labels
    loss = F.cross_entropy(logits, batch['labels'])

    # Backwards - via deriv. of loss function
    loss.backward()

    # Update model params
    optimizer.step()

    # Logging and eval
    train_loss.append({
        'batch_iteration': batch_iteration,
        'cross_entropy_loss': loss.item()
    })

    # Evaluate and plot every EVAL_EVERY steps (skipping step 0)
    if batch_iteration > 0 and batch_iteration % EVAL_EVERY == 0:
        # Checking for overfit on the full test split
        gut_check(k, model, tokenizer, device)
        examples_res = eval_performance_on_examples(batched_datasets['test'], len(batched_datasets['test']), model)
        # Back to training mode (presumably the helpers above switch the model
        # to eval mode — they are defined outside this cell).
        model.train()
        test_loss.append({
            'batch_iteration': batch_iteration,
            'precision': examples_res['precision'],
            'recall': examples_res['recall'],
            'cross_entropy_loss': examples_res['cross_entropy_loss'],
            'accuracy': examples_res['accuracy']
        })

        # Plot output: smoothed train loss (red) against test loss (green)
        # clear_output(wait = True)
        # A 'valid' moving average over EVAL_EVERY points yields
        # len(train_loss) - (EVAL_EVERY - 1) values, so the first EVAL_EVERY - 1
        # x positions are dropped to keep both axes the same length.
        raw_train = np.array([h['cross_entropy_loss'] for h in train_loss])
        train_smoothed = np.convolve(raw_train, np.ones(EVAL_EVERY)/EVAL_EVERY, mode= 'valid')
        plt.plot([h['batch_iteration'] for h in train_loss][EVAL_EVERY - 1:], train_smoothed, 'r')
        plt.plot([h['batch_iteration'] for h in test_loss], [h['cross_entropy_loss'] for h in test_loss], 'g')
        plt.show()

##  Model storage ##
# Persist only the learned weights (the state dict), not the whole model
# object; the file name is looked up in model_names under key k.
trained_weights = model.state_dict()
torch.save(trained_weights, f"../models/{model_names[k]}.pt")
