# Dog-BERT classifier 
Here, we fine-tune a DistilBERT classifier that labels input text as dog-related or not. We need it to classify our examples from fineweb/lmsys.

In [15]:
from tqdm import tqdm
import random
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd
import os
import sys

# add parent directory to sys.path 
sys.path.append(os.path.join(os.getcwd(), '..'))
from py_helpers import bert

# These are for eval charting
from matplotlib import pyplot as plt
from IPython.display import clear_output

# Load DistilBert from transformers
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Pick the compute device once for the whole notebook. Prefer Apple-silicon MPS (the
# original target), then CUDA, and fall back to CPU so the notebook still runs on
# machines without an accelerator instead of failing at the first `.to(device)`.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Load the train and test splits from the `dogbert` folder inside the current
# working directory.
data_dir = os.path.join(os.getcwd(), 'dogbert')
train = pd.read_csv(os.path.join(data_dir, 'train.csv'))
test = pd.read_csv(os.path.join(data_dir, 'test.csv'))

# Training 
Here, we take the GPT-labeled data and use it to train the dog classifier.

In [16]:
# Tokenizer and model must come from the same checkpoint so token ids line up.
checkpoint = "distilbert-base-cased"

# Pre-trained tokenizer: turns raw text into token ids.
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)

# Two-class sequence classifier on top of the pre-trained encoder, moved to `device`.
model = DistilBertForSequenceClassification.from_pretrained(checkpoint, num_labels = 2)
model = model.to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [61]:
# Tokenize the training texts and group them by batch.
# Result: `batched_train[i]` is a list with one dict per text in batch i, holding the raw
# text (`content`) plus its `input_ids` / `attention_mask` tensors on `device`.
BATCH_SIZE = 10   # rows requested per chunk from bert.chunk_dataframe
MAX_TOKENS = 512  # pad/truncate every text to this many tokens

batched_train = {}
# presumably yields DataFrame slices of up to `size` rows — TODO confirm in py_helpers.bert
batches = bert.chunk_dataframe(train, size = BATCH_SIZE)
for i, batch in enumerate(batches):
    # Start a fresh list for every batch. Creating it once outside the loop made every
    # batched_train[i] point at the same list, which ended up holding ALL prompts.
    batched_prompts = []

    # One row/text entry of the current batch at a time
    for obs in batch['phi3_text']:
        # truncation=True is required: max_length together with padding='max_length'
        # only pads short texts; without it longer texts keep their full length.
        tokenized_input = tokenizer([obs], return_tensors = 'pt',
                                    max_length = MAX_TOKENS,
                                    padding = 'max_length',
                                    truncation = True).to(device)

        batched_prompts.append({
            'content': obs,
            'input_ids': tokenized_input['input_ids'],
            'attention_mask': tokenized_input['attention_mask']
        })

    batched_train[i] = batched_prompts

In [None]:
#### MODEL TRAINING #####
# Single pass over `batched_train`: forward, cross-entropy loss, backward, Adam step.
# Every EVAL_EVERY batches the model is evaluated on the test batches and the
# smoothed train loss is plotted against the test loss.
EVAL_EVERY = 20     # evaluate + redraw the chart every N batches
SMOOTH_WINDOW = 20  # moving-average window for the plotted train loss;
                    # must stay <= EVAL_EVERY + 1 so x and y lengths match on the first plot

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-5) # instantiate optimizer
# NOTE(review): `scheduler` is created but never stepped in this cell, so the learning
# rate stays at 3e-5. With step_size=1 and gamma=0.5 every scheduler.step() would halve
# it — decide on the intended cadence (per epoch?) before wiring it in.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.5)

train_loss = []
test_loss = []

model.train()

# Iterate over the batches themselves (the dict's values). Iterating the dict directly
# yields its integer keys, which cannot be indexed with 'input_ids'. Wrapping the
# iterable (not the enumerate) lets tqdm see the total number of batches.
#
# NOTE(review): the loop body reads batch['input_ids'], batch['attention_mask'] and
# batch['labels'], i.e. it expects each batch to be a mapping of batch-level tensors.
# The batching cell stores each batch as a list of per-prompt dicts with keys
# content / input_ids / attention_mask and no labels. These two must be reconciled
# (collate per batch and attach the label column) before this loop can run.
for batch_iteration, batch in enumerate(tqdm(batched_train.values())):

    # Zero out grads so they don't accumulate across batches
    model.zero_grad()

    # Forward pass
    logits = model(batch['input_ids'], batch['attention_mask']).logits
    # Cross-entropy between the predicted class logits and the true labels
    loss = F.cross_entropy(logits, batch['labels'])

    # Backward pass - gradients of the loss w.r.t. the model parameters
    loss.backward()

    # Update model params
    optimizer.step()

    # Logging
    train_loss.append({
        'batch_iteration': batch_iteration,
        'cross_entropy_loss': loss.item()
    })

    # Overfit check + chart, every EVAL_EVERY batches (skipping batch 0)
    if batch_iteration > 0 and batch_iteration % EVAL_EVERY == 0:
        # NOTE(review): `k` and `batched_datasets` are not defined in any cell visible
        # in this notebook — confirm where they come from or this raises NameError.
        bert.gut_check(k, model, tokenizer, device)
        examples_res = bert.eval_performance_on_examples(batched_datasets['test'], len(batched_datasets['test']), model)
        # presumably the eval helpers switch the model to eval mode — TODO confirm;
        # either way make sure we are back in training mode.
        model.train()
        test_loss.append({
            'batch_iteration': batch_iteration,
            'precision': examples_res['precision'],
            'recall': examples_res['recall'],
            'cross_entropy_loss': examples_res['cross_entropy_loss'],
            'accuracy': examples_res['accuracy']
        })

        # Moving average of the train loss; mode='valid' drops the first
        # SMOOTH_WINDOW - 1 points, so the x values are sliced to match.
        train_steps = [h['batch_iteration'] for h in train_loss]
        train_values = np.array([h['cross_entropy_loss'] for h in train_loss])
        train_smoothed = np.convolve(train_values, np.ones(SMOOTH_WINDOW) / SMOOTH_WINDOW, mode = 'valid')

        # (clear_output(wait = True) can be called here to redraw the chart in place.)
        fig, ax = plt.subplots()
        ax.plot(train_steps[SMOOTH_WINDOW - 1:], train_smoothed, 'r', label = 'train (smoothed)')
        ax.plot([h['batch_iteration'] for h in test_loss],
                [h['cross_entropy_loss'] for h in test_loss], 'g', label = 'test')
        ax.set(title = 'Cross-entropy loss', xlabel = 'batch iteration', ylabel = 'loss')
        ax.legend()
        plt.show()
        plt.close(fig)  # avoid piling up open figures over a long run

##  Model storage ##
# Persist only the learned weights (the state dict); the path is relative to the
# current working directory.
checkpoint_path = "../dogbert.pt"
torch.save(model.state_dict(), checkpoint_path)
