# BERT Fine-tuning for Burney Attribution

This notebook trains a BERT model for 18th-century authorship attribution.

**Setup Instructions:**
1. Runtime → Change runtime type → GPU (T4 or better)
2. Upload your `burney-attribution` folder to Colab or mount Google Drive
3. Run all cells

**Expected time:** 1-2 hours on T4 GPU

## 1. Check GPU Availability

In [None]:
# Shell out to NVIDIA's status tool to show the GPU attached to this runtime
!nvidia-smi

## 2. Install Dependencies

In [None]:
!pip install -q transformers>=4.30.0 datasets>=2.12.0 accelerate>=1.1.0 scikit-learn

## 3. Mount Google Drive (Optional)

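If you've uploaded the data to Google Drive, mount it here (running this cell prompts for Drive authorization, so skip it if you upload the data directly in section 4):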

In [None]:
from google.colab import drive

# Directory under which the Drive contents are exposed once mounted
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# To work from a copy of the project stored on Drive, point PROJECT_PATH at
# the uploaded burney-attribution folder and uncomment the lines below
# PROJECT_PATH = '/content/drive/MyDrive/burney-attribution'
# import os
# os.chdir(PROJECT_PATH)

## 4. Alternative: Upload Data Directly

If you don't want to use Google Drive, you can upload the prepared dataset:

In [None]:
# Uncomment the lines below if uploading a zip file instead of using Google Drive.
# The archive is expected to unpack into a `data/` directory (listed by the last line).
# from google.colab import files
# uploaded = files.upload()  # Upload burney_data.zip
# !unzip -q burney_data.zip
# !ls data/

## 5. Training Script

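The cells in this section make up the training code: imports and device check, data paths, data and model loading, tokenization, and trainer configuration.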

In [None]:
import json
import numpy as np
from pathlib import Path
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, f1_score, classification_report
import torch

banner = "=" * 60
print(banner)
print("BERT AUTHORSHIP ATTRIBUTION TRAINING")
print(banner)

# Check GPU: fall back to the CPU when no CUDA device is visible
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"\nUsing device: {device}")
if has_cuda:
    # Report the name and total memory of the first CUDA device
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_properties.total_memory / 1e9:.1f} GB")

### Set Data Paths

In [None]:
# Adjust these paths based on how you uploaded the data
DATA_DIR = Path('data') / 'bert_data'  # or '/content/drive/MyDrive/burney-attribution/data/bert_data'
OUTPUT_DIR = Path('models') / 'bert_authorship'
# Create the output directory (and any missing parents) up front
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for description, location in (
    ("Data directory", DATA_DIR),
    ("Output directory", OUTPUT_DIR),
):
    print(f"{description}: {location}")

### Load Data and Model

In [None]:
# Load label mapping
print("\nLoading label mapping...")
label_info = json.loads((DATA_DIR / 'label_mapping.json').read_text())

author_to_id = label_info['author_to_id']
# JSON object keys are strings, so convert the ids back to integers
id_to_author = {}
for raw_id, author_name in label_info['id_to_author'].items():
    id_to_author[int(raw_id)] = author_name
num_labels = len(author_to_id)

print(f"Training for {num_labels} authors: {sorted(author_to_id)}")

# Load datasets
print("\nLoading datasets...")
datasets = load_from_disk(str(DATA_DIR / 'chunked_datasets'))

# Report the number of text chunks in each split
for split_title, split_key in (
    ("Train", 'train'),
    ("Val", 'validation'),
    ("Test", 'test'),
):
    print(f"{split_title}: {len(datasets[split_key])} chunks")

# Load model
print("\nLoading BERT model...")
model_name = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Classification-head settings: one output per author, plus the id <-> name
# mappings stored in the model config
classifier_config = {
    'num_labels': num_labels,
    'id2label': id_to_author,
    'label2id': author_to_id,
}
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    **classifier_config
)

print(f"✓ Loaded {model_name}")

### Tokenize Data

In [None]:
def tokenize_function(examples):
    """Tokenize the 'text' field of a batch with the module-level tokenizer.

    Every sequence is truncated and padded to exactly 512 tokens.
    """
    tokenizer_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': 512,
    }
    return tokenizer(examples['text'], **tokenizer_options)

print("Tokenizing datasets...")
# Raw columns dropped after tokenization; only the tokenizer outputs and the
# remaining columns are kept
raw_columns = ['text', 'author', 'title', 'year']
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_columns
)

# Return torch tensors when the datasets are indexed
tokenized_datasets.set_format('torch')
print("✓ Tokenization complete")

### Define Training Configuration

In [None]:
def compute_metrics(eval_pred):
    """Return accuracy and macro/weighted F1 for one evaluation pass.

    `eval_pred` unpacks into (model outputs, true label ids). The predicted
    class is the argmax over axis 1 of the outputs, which are assumed to be
    per-class scores of shape (num_examples, num_labels).
    """
    scores, gold_labels = eval_pred
    predicted_labels = np.argmax(scores, axis=1)

    return {
        'accuracy': accuracy_score(gold_labels, predicted_labels),
        'f1_macro': f1_score(gold_labels, predicted_labels, average='macro'),
        'f1_weighted': f1_score(gold_labels, predicted_labels, average='weighted')
    }

# Training arguments optimized for Colab
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    # Evaluate and checkpoint once per epoch; the two strategies must match
    # for load_best_model_at_end to work
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,  # Larger batch on GPU
    per_device_eval_batch_size=16,
    num_train_epochs=3,  # Reduced for faster training
    weight_decay=0.01,
    warmup_steps=500,
    logging_steps=50,
    # Restore the checkpoint with the best validation weighted F1 after training
    load_best_model_at_end=True,
    metric_for_best_model='f1_weighted',
    save_total_limit=2,
    # Mixed precision for speed. fp16 is only valid on a CUDA device, so it is
    # switched off on the CPU fallback instead of making TrainingArguments raise.
    fp16=torch.cuda.is_available(),
    report_to='none'
)

# Stop training as soon as one evaluation fails to improve the tracked metric
early_stopping = EarlyStoppingCallback(early_stopping_patience=1)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics,
    callbacks=[early_stopping]
)

print("✓ Trainer configured")

## 6. Train the Model

This will take 1-2 hours on a T4 GPU:

In [None]:
print(f"\n{'='*60}")
print("TRAINING")
print(f"{'='*60}")

# Run the fine-tuning loop; the returned object carries the training metrics
train_result = trainer.train()

print("\n✓ Training complete!")

## 7. Save Model

In [None]:
# Save model and tokenizer side by side so the directory can be reloaded on its own
final_dir = OUTPUT_DIR / 'final'
trainer.save_model(str(final_dir))
tokenizer.save_pretrained(str(final_dir))

# Save training metrics
metrics = train_result.metrics
(OUTPUT_DIR / 'train_metrics.json').write_text(json.dumps(metrics, indent=2))

print(f"✓ Model saved to {final_dir}")

## 8. Evaluate on Validation Set

In [None]:
print("\n" + "="*60)
print("VALIDATION RESULTS")
print("="*60)

# A single predict() pass provides both the aggregate metrics and the raw
# predictions. metric_key_prefix='eval' keeps the 'eval_*' metric keys, so the
# validation set does not have to be run through the model a second time.
val_predictions = trainer.predict(
    tokenized_datasets['validation'],
    metric_key_prefix='eval'
)
val_metrics = val_predictions.metrics

print(f"\nValidation Accuracy: {val_metrics['eval_accuracy']:.2%}")
print(f"Validation F1 (macro): {val_metrics['eval_f1_macro']:.3f}")
print(f"Validation F1 (weighted): {val_metrics['eval_f1_weighted']:.3f}")

# Detailed predictions
val_preds = np.argmax(val_predictions.predictions, axis=1)
val_labels = val_predictions.label_ids

# Name the report rows from the id -> author mapping, in label-id order.
# Alphabetically sorted author names only line up with the ids if the ids were
# assigned alphabetically. Passing `labels` explicitly also keeps the report
# from failing when an author has no chunks in this split.
report_label_ids = sorted(id_to_author)
report_label_names = [id_to_author[label_id] for label_id in report_label_ids]

print("\nValidation Classification Report:")
print(classification_report(
    val_labels,
    val_preds,
    labels=report_label_ids,
    target_names=report_label_names,
    zero_division=0
))

## 9. Evaluate on Test Set

In [None]:
print("\n" + "="*60)
print("TEST RESULTS")
print("="*60)

# Run the held-out test split through the model and take the highest-scoring
# class per chunk as the prediction
test_predictions = trainer.predict(tokenized_datasets['test'])
test_preds = np.argmax(test_predictions.predictions, axis=1)
test_labels = test_predictions.label_ids

test_accuracy = accuracy_score(test_labels, test_preds)
test_f1_macro = f1_score(test_labels, test_preds, average='macro')
test_f1_weighted = f1_score(test_labels, test_preds, average='weighted')

print(f"\nTest Accuracy: {test_accuracy:.2%}")
print(f"Test F1 (macro): {test_f1_macro:.3f}")
print(f"Test F1 (weighted): {test_f1_weighted:.3f}")

# Name the report rows from the id -> author mapping, in label-id order.
# Alphabetically sorted author names only line up with the ids if the ids were
# assigned alphabetically. Passing `labels` explicitly also keeps the report
# from failing when an author has no chunks in this split.
report_label_ids = sorted(id_to_author)
report_label_names = [id_to_author[label_id] for label_id in report_label_ids]

print("\nTest Classification Report:")
print(classification_report(
    test_labels,
    test_preds,
    labels=report_label_ids,
    target_names=report_label_names,
    zero_division=0
))

# Save test results (cast to plain floats so the values are JSON-serializable)
test_results = {
    metric_name: float(metric_value)
    for metric_name, metric_value in (
        ('accuracy', test_accuracy),
        ('f1_macro', test_f1_macro),
        ('f1_weighted', test_f1_weighted),
    )
}

results_path = OUTPUT_DIR / 'test_results.json'
results_path.write_text(json.dumps(test_results, indent=2))

print(f"\n✓ Test results saved to {results_path}")

## 10. Compare with Baseline

In [None]:
separator = "=" * 60
print("\n" + separator)
print("COMPARISON WITH BASELINE")
print(separator)

# Reference accuracy that the BERT test accuracy is compared against
baseline_accuracy = 0.80

print(f"Baseline (Burrows' Delta): {baseline_accuracy:.1%}")
print(f"BERT:                      {test_accuracy:.1%}")

# Positive when BERT beats the baseline, negative when it falls short
improvement = test_accuracy - baseline_accuracy
if improvement > 0:
    verdict = f"\n✓ BERT improves on baseline by {improvement:.1%}"
elif improvement < 0:
    verdict = f"\n⚠ BERT underperforms baseline by {-improvement:.1%}"
else:
    verdict = "\nBERT matches baseline performance"
print(verdict)

print("\n" + separator)
print("TRAINING COMPLETE!")
print(separator)

## 11. Download Results

Download the trained model and results back to your computer:

In [None]:
# Zip the model directory (recursively and quietly) into bert_model.zip.
# This archives everything under models/, not only the final model.
!zip -r -q bert_model.zip models/

# Download the archive through the browser (Colab-only helper)
from google.colab import files
files.download('bert_model.zip')

print("✓ Download started!")