## Introduction

This notebook fine-tunes an `all-mpnet-base-v2` model to classify human-generated versus AI-generated text. The fine-tuning data comes from three contexts:
 * `grover`: news articles,
 * `wiki`: Wikipedia intro paragraphs, and
 * `reviews`: product reviews.

Our dataset consists of 10,000 human-generated and 10,000 AI-generated samples in each context (prior to our train / validation / test split). For citations and more information on our data, see the readme in our `raw data` folder. 

We can restrict our model to only train on data from a subset of the contexts (see the `TRAINING_DATA_CONTEXTS` variable below). After every epoch, we validate our model on data from *all* contexts, allowing us to test how well it generalizes to unseen contexts. For results and discussion, see the readme in the base folder of the repository.

## Package imports

In [None]:
# Basic imports
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from pathlib import Path
from datetime import datetime

# Torch imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import Dataset

# Transformer imports
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

## Setup

In [None]:
# Contexts whose rows are kept in the training split; the validation and test
# accuracies are still reported separately for all three contexts.
# Must be a nonempty subset of ['grover', 'wiki', 'reviews'].
# The chosen contexts also become part of the output directory name.
TRAINING_DATA_CONTEXTS = ['grover', 'wiki', 'reviews']

In [None]:
# Set training device: prefer CUDA, then Apple's MPS backend, then fall back to CPU.
# The MPS check is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {DEVICE}")

In [None]:
# Misc flags

# If True, the full dataset is downsampled to 320 rows *before* the
# train / validation / test split, for quick smoke-test runs.
ABRIDGED_RUN = False

# If True, the MPNet encoder weights are frozen and only the classification
# head is optimized; the output directory name is prefixed with "baseline_".
FREEZE_EMBEDDING = False

In [None]:
# Training hyperparameters
BATCH_SIZE = 16  # Number of texts per training batch
NUM_EPOCHS = 10  # Full passes over the training split
# Warmup length as a fraction of ONE epoch: the scheduler is given
# WARMUP_PROPORTION * (number of batches per epoch) warmup steps.
WARMUP_PROPORTION = 0.5
LEARNING_RATE = 2e-5  # Base learning rate passed to AdamW

In [None]:
# Directories
# Every run writes into its own folder named
#   training_results/[baseline_]<contexts joined by "_">_<timestamp>
# The name is assembled from separate pieces instead of one f-string with
# nested double quotes, which is a SyntaxError on Python versions before 3.12.
run_prefix = "baseline_" if FREEZE_EMBEDDING else ""
run_contexts = "_".join(TRAINING_DATA_CONTEXTS)
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
OUTPUT_DIR = Path("training_results") / f"{run_prefix}{run_contexts}_{run_timestamp}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Location of combined_data.csv, relative to this notebook
DATA_DIR = Path("../raw data/")

## Data loading and preprocessing

In [None]:
data = pd.read_csv(DATA_DIR / "combined_data.csv")

# Throw out essays data. Take an explicit copy so that the column assignments
# below act on an independent frame rather than on a slice of the raw data
# (this avoids pandas' SettingWithCopyWarning).
data = data[data['Original dataset'] != 'essays'].copy()

# Downsample if abridged run. Seeded like the splits below so that quick test
# runs are reproducible, and capped so tiny datasets do not raise.
if ABRIDGED_RUN:
    data = data.sample(min(320, len(data)), random_state=406)

# Create a column to stratify our train test split by, accounting for both label and context.
# This must happen before the labels are converted to integers, since it
# concatenates the original string labels with the context name.
data['Stratify'] = data['Label'] + " " + data['Original dataset']

# Set integer labels for training: Human -> 0, anything else (Machine) -> 1
data['Label'] = (data['Label'] != 'Human').astype('int64')

# Train test split, stratified by labels and original dataset:
# 80% train+validation / 20% test, then 80% train / 20% validation.
data_train, data_test = train_test_split(
    data, test_size=0.2, stratify=data['Stratify'], random_state=406
)
train_tt, train_vv = train_test_split(
    data_train, test_size=0.2, stratify=data_train['Stratify'], random_state=406
)

# Restrict training set to only the specified contexts.
# The validation and test sets deliberately keep every context.
train_tt = train_tt[train_tt['Original dataset'].isin(TRAINING_DATA_CONTEXTS)]

In [None]:
# Wrap the training texts and their integer labels in a Hugging Face Dataset
training_columns = {
    "text": train_tt['Text'],
    "label": train_tt['Label'],
}
train_dataset = Dataset.from_dict(training_columns)

# Batched, shuffled iteration over the training set
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Load the tokenizer belonging to the same checkpoint that is fine-tuned below.
# It is used both by the training collate function and during validation.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")

In [None]:
def smart_batching_collate(batch):
    """Turn a list of dataset examples into model-ready tensors.

    Adapted from https://github.com/UKPLab/sentence-transformers/blob/7290448809cb73f08f63c955550815775434beb4/sentence_transformers/cross_encoder/CrossEncoder.py#L143

    Each example is a mapping with a 'text' and a 'label' entry. The texts are
    tokenized together (padded, and truncated with the "longest_first"
    strategy) using the module-level ``tokenizer``; the resulting tensors and
    the labels are moved to ``DEVICE``.

    Returns a pair ``(tokenized_texts, labels)`` where ``labels`` is a
    ``torch.long`` tensor with one entry per example."""

    texts = [example['text'] for example in batch]
    labels = [example['label'] for example in batch]

    encoded = tokenizer(texts, padding=True, truncation="longest_first", return_tensors="pt")
    # Move every tokenizer output tensor onto the training device
    for name in list(encoded.keys()):
        encoded[name] = encoded[name].to(DEVICE)

    label_tensor = torch.tensor(labels, dtype=torch.long).to(DEVICE)

    return encoded, label_tensor

In [None]:
# The dataloader was created before the collate function was defined, so
# attach it here: batches are now tokenized and moved to DEVICE on the fly.
train_dataloader.collate_fn = smart_batching_collate

## Load model

In [None]:
# Load the pretrained encoder with a sequence-classification head that emits a
# single logit per text, and place it on the training device.
model = AutoModelForSequenceClassification.from_pretrained(
    "sentence-transformers/all-mpnet-base-v2",
    num_labels=1,
)
model = model.to(DEVICE)

# Baseline mode: freeze the MPNet encoder so that it receives no gradients
if FREEZE_EMBEDDING:
    for encoder_param in model.mpnet.parameters():
        encoder_param.requires_grad = False

## Training

In [None]:
# Binary cross-entropy computed directly on the model's single raw logit
criterion = nn.BCEWithLogitsLoss()

# Optimize only the classification head when the encoder is frozen
if FREEZE_EMBEDDING:
    params_to_optimize = model.classifier.parameters()
else:
    params_to_optimize = model.parameters()
optimizer = optim.AdamW(params_to_optimize, lr=LEARNING_RATE)

# Warmup and total length are both expressed in batches (dataloader length)
batches_per_epoch = len(train_dataloader)
scheduler = SentenceTransformer._get_scheduler(
    optimizer,
    scheduler='WarmupLinear',
    warmup_steps=WARMUP_PROPORTION * batches_per_epoch,
    t_total=batches_per_epoch * NUM_EPOCHS,
)

In [None]:
def validation_accuracy(model, tokenizer, validation_samples):
    """Compute model accuracy stratified over our three data contexts.

    Args:
        model: Sequence-classification model producing one logit per text.
        tokenizer: Callable that turns a single text into a mapping of
            PyTorch tensors accepted by ``model``.
        validation_samples: DataFrame with the columns 'Text', 'Label'
            (0 = human, 1 = machine) and 'Original dataset'.

    Returns:
        Dictionary with keys 'grover', 'wiki', and 'reviews' and values equal
        to model accuracy on each data context from validation_samples. A
        context without any samples maps to NaN.

    The model is evaluated in eval mode with gradient tracking disabled; its
    previous train/eval mode is restored before returning."""

    contexts = ['grover', 'wiki', 'reviews']
    accuracies = {}

    # Feed inputs to whichever device the model's weights live on
    device = next(model.parameters()).device

    # Disable dropout for evaluation, remembering the mode to restore afterwards
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for context in contexts:
                in_context = validation_samples['Original dataset'] == context
                context_samples = validation_samples[in_context]

                num_correct = 0
                # Samples are scored one at a time, so no padding is ever added
                for text, label in zip(context_samples['Text'], context_samples['Label']):
                    features = tokenizer(text, padding=True, truncation="longest_first", return_tensors="pt")
                    features = {key: value.to(device) for key, value in features.items()}
                    logits = model(**features, return_dict=True).logits
                    prob = torch.sigmoid(logits).item()
                    # Predict "machine" (1) when the probability rounds to 1
                    num_correct += int(round(prob) == label)

                num_samples = len(context_samples)
                accuracies[context] = num_correct / num_samples if num_samples else float('nan')
    finally:
        model.train(was_training)

    return accuracies

In [None]:
# Write accuracies before training.
# This creates (or overwrites) accuracies.csv with a header row followed by the
# untrained model's validation accuracy per context, recorded as epoch -1.
with torch.no_grad():
    acc = validation_accuracy(model, tokenizer, train_vv)
with open(OUTPUT_DIR/'accuracies.csv', 'w') as f:
    f.write('epoch, grover, wiki, reviews\n')
    # Double-quoted f-string so the single-quoted keys also parse before Python 3.12
    f.write(f"-1, {acc['grover']}, {acc['wiki']}, {acc['reviews']}\n")

In [None]:
# Training loop: one pass over the training data per epoch, followed by a
# validation pass whose per-context accuracies are appended to accuracies.csv.
for epoch in tqdm(range(NUM_EPOCHS), position = 0):

    # from_pretrained returns the model in eval mode; switch to train mode so
    # that dropout is active while fitting.
    model.train()

    for features, labels in tqdm(train_dataloader, position = 1, leave = False):
        logits = model(**features, return_dict=True).logits
        # The model emits shape (batch, 1); give the labels the same shape and a float dtype
        loss = criterion(logits, labels.unsqueeze(1).float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The scheduler's warmup_steps and t_total are counted in batches, so
        # it has to advance once per optimizer step. Stepping it once per epoch
        # would keep the learning rate at (almost) zero for the entire run.
        scheduler.step()

    # Write accuracies for this epoch (dropout off, no gradient tracking)
    model.eval()
    with torch.no_grad():
        acc = validation_accuracy(model, tokenizer, train_vv)
    with open(OUTPUT_DIR/'accuracies.csv', 'a') as f:
        # Double-quoted f-string so the single-quoted keys also parse before Python 3.12
        f.write(f"{epoch}, {acc['grover']}, {acc['wiki']}, {acc['reviews']}\n")

In [None]:
# Write accuracies on final test set: evaluate once on the held-out split,
# append the result to accuracies.csv, and print it as percentages.
with torch.no_grad():
    acc = validation_accuracy(model, tokenizer, data_test)
with open(OUTPUT_DIR/'accuracies.csv', 'a') as f:
    # Double-quoted f-string so the single-quoted keys also parse before Python 3.12
    f.write(f"TEST SET, {acc['grover']}, {acc['wiki']}, {acc['reviews']}\n")

print(f"""Accuracies on held out test set:
    grover: {100*acc['grover']:.2f}
    wiki: {100*acc['wiki']:.2f}
    reviews: {100*acc['reviews']:.2f}""")