In [None]:
import os

if not os.path.exists('/content/police-records-project'):
    !git clone https://github.com/c-goenka/police-records-project.git
    %cd /content/police-records-project
    !pip install -r requirements.txt
else:
    %cd /content/police-records-project

from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
import numpy as np
from setfit import SetFitModel, Trainer, TrainingArguments
from datasets import Dataset
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

# Seed value for reproducibility; it seeds NumPy's global RNG here and is
# also handed to the SetFit TrainingArguments further down.
RANDOM_SEED = 42
np.random.seed(seed=RANDOM_SEED)

In [None]:
# Directory on the mounted Google Drive that holds the train/test CSV files.
data_dir = "/content/drive/MyDrive/police-records-project-data/processed"

# Read both splits from "<data_dir>/train.csv" and "<data_dir>/test.csv".
train_df, test_df = (
    pd.read_csv(f"{data_dir}/{split_name}.csv")
    for split_name in ("train", "test")
)

In [None]:
# Build a deterministic label <-> integer id mapping from the training labels
# (ids follow the sorted order of the distinct label values).
label_to_id = {label: idx for idx, label in enumerate(sorted(train_df['label'].unique()))}
id_to_label = {idx: label for label, idx in label_to_id.items()}

# Fail fast if the test split contains a label that never occurs in training.
# Series.map() would otherwise turn such labels into NaN without any warning,
# producing a float `label_id` column and NaN targets in the test dataset.
unseen_test_labels = set(test_df['label'].unique()) - set(label_to_id)
if unseen_test_labels:
    raise ValueError(
        "Test set contains labels not present in the training set: "
        f"{sorted(map(str, unseen_test_labels))}"
    )

# Attach the integer ids to both frames.
train_df['label_id'] = train_df['label'].map(label_to_id)
test_df['label_id'] = test_df['label'].map(label_to_id)

# Wrap text and integer labels in `datasets.Dataset` objects with the
# 'text' / 'label' column names used by the training cell.
train_dataset = Dataset.from_dict({
    'text': train_df['text_clean'].tolist(),
    'label': train_df['label_id'].tolist()
})

test_dataset = Dataset.from_dict({
    'text': test_df['text_clean'].tolist(),
    'label': test_df['label_id'].tolist()
})

In [None]:
# Label names in insertion order of `label_to_id`, i.e. position == label id.
label_names = list(label_to_id)

# Load the pretrained sentence-transformer checkpoint as a SetFit model and
# register the label names with it.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=label_names,
)

# Report which checkpoint was loaded and how many classes it will predict.
print(f"Model loaded: {model.model_body.config.name_or_path}")
print(f"Number of labels: {len(label_to_id)}")

In [None]:
# Training hyperparameters; no evaluation or checkpoint saving is requested,
# and the shared notebook seed is reused for the training run.
training_options = {
    "batch_size": 16,
    "num_epochs": 3,
    "evaluation_strategy": "no",
    "save_strategy": "no",
    "seed": RANDOM_SEED,
}
args = TrainingArguments(**training_options)

# Fine-tune the model on the training split only.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

print("SetFit Model Training complete")

In [None]:
# Predict on the held-out test texts.
# The model was created with `labels=[...]`, so by default predict() maps its
# outputs to label-name strings. `y_true` below holds integer label ids, so
# request raw class indices (use_labels=False) as a NumPy array to keep
# y_true and y_pred in the same label space for the metric computations.
predictions = model.predict(
    test_df['text_clean'].tolist(),
    as_numpy=True,
    use_labels=False,
)
y_true = test_df['label_id'].values
y_pred = predictions

# Quick sanity check that every test sample received a prediction.
print(f"Test samples: {len(y_true)}")
print(f"Predictions: {len(y_pred)}")