In [1]:
from read_jsonl import read_jsonl
from datasets import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, DistilBertTokenizer, DistilBertForSequenceClassification
from sklearn.model_selection import train_test_split
import torch

In [2]:
# Pick the compute device: the Apple Silicon GPU (MPS) when PyTorch reports it
# as available, otherwise the CPU. Set force_cpu to True to stay on the CPU.
force_cpu = False
mps_ready = torch.backends.mps.is_available()
device = torch.device("mps") if mps_ready and not force_cpu else torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [3]:
# Path of the JSONL file to load (relative to the working directory)
data_path = "DB-bio/combined_train_and_train_sft_anonymized.jsonl"
df = read_jsonl(data_path)

In [4]:
# Quick sanity check of the loaded data: (rows, columns) and the column names
df.shape, df.columns

((3876, 2), Index(['text', 'label'], dtype='object'))

In [5]:
test_size = 0.3

# Pull the two columns out as plain Python lists
all_texts = df["text"].tolist()
all_labels = df["label"].tolist()

# Hold out `test_size` of the rows for validation; the seed is fixed so the
# split is the same on every run
train_texts, val_texts, train_labels, val_labels = train_test_split(
    all_texts,
    all_labels,
    test_size=test_size,
    random_state=42,
)

# Wrap each split in a datasets.Dataset with "text" and "label" columns
train_dataset = Dataset.from_dict(dict(text=train_texts, label=train_labels))
val_dataset = Dataset.from_dict(dict(text=val_texts, label=val_labels))

In [6]:
# Model switch: True selects DistilBERT ("distilbert-base-uncased"),
# False selects BERT ("bert-base-uncased")
use_distillbert = False

In [7]:
# Tokenization: load the tokenizer that matches the selected model
if use_distillbert:
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
else:
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def tokenize(example):
    """Tokenize the "text" field, padding/truncating every sequence to 64 tokens."""
    return tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )


# Tokenize in batches, then drop the raw "text" column from each split
train_dataset = train_dataset.map(tokenize, batched=True).remove_columns(["text"])
val_dataset = val_dataset.map(tokenize, batched=True).remove_columns(["text"])

# Have both datasets return torch tensors when indexed
train_dataset.set_format("torch")
val_dataset.set_format("torch")

Map:   0%|          | 0/2713 [00:00<?, ? examples/s]

Map:   0%|          | 0/1163 [00:00<?, ? examples/s]

In [8]:
# Load the two-label sequence-classification model and move it to the selected device
if use_distillbert:
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )
else:
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [9]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
# Metrics function
def compute_metrics(eval_pred):
    """Compute binary-classification metrics for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` is an array-like of
            per-class scores with the class axis last; if it is a tuple, its
            first element is used as the logits. ``labels`` is an array-like
            of integer class ids.

    Returns:
        dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``. Precision/recall/F1 are computed with ``average='binary'``.

    Side effects:
        Prints the number of correct predictions out of the total.
    """
    logits, labels = eval_pred
    # Some models hand back several prediction arrays; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    # Convert to NumPy once so scikit-learn receives plain arrays, not tensors.
    preds = torch.as_tensor(logits).detach().cpu().argmax(dim=-1).numpy()
    labels = torch.as_tensor(labels).detach().cpu().numpy()

    acc = accuracy_score(labels, preds)
    # zero_division=0 yields the same 0.0 as the default when a metric is
    # undefined (e.g. no positive predictions), but without emitting a warning
    # on every evaluation step.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )

    correct = int((preds == labels).sum())
    total = len(labels)
    print(f"\nCorrect predictions: {correct} / {total}")
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Training arguments, grouped by concern
results_dir = "./results/distilbert" if use_distillbert else "./results/bert"

training_args = TrainingArguments(
    # output locations and reporting
    output_dir=results_dir,
    logging_dir="./logs",
    report_to="none",
    disable_tqdm=True,
    # device and data loading
    use_cpu=force_cpu,
    fp16=False,  # disable fp16 for MPS
    dataloader_pin_memory=False,  # suppress pin_memory warning
    # optimisation
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    learning_rate=5e-6,
    weight_decay=0.01,
    # logging and evaluation cadence, both counted in steps
    logging_steps=8,
    eval_strategy="steps",
    eval_steps=32,
)

In [10]:
# Custom Trainer to support MPS
class MPSTrainer(Trainer):
    """Trainer subclass that places the model and every batch on one device.

    Both hooks use the device chosen by the Trainer itself (the ``device``
    argument passed to ``_move_model_to_device`` and ``self.args.device``),
    so the model and its inputs cannot end up on different devices.
    """

    def _move_model_to_device(self, model, device):
        """Move ``model`` to ``device`` and print the device that was used."""
        print(device)
        model.to(device)

    def _prepare_inputs(self, inputs):
        """Return a copy of ``inputs`` with every tensor moved to the training device.

        Args:
            inputs: Mapping of input name to value for one batch.

        Returns:
            dict with the same keys; ``torch.Tensor`` values are moved to
            ``self.args.device``, all other values are passed through as-is.
        """
        # Use the Trainer's own device rather than a notebook-level global:
        # the model is moved to the device the Trainer passes in, so the batch
        # must follow the same choice to avoid a device mismatch.
        target_device = self.args.device
        return {
            k: v.to(target_device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

# Build the trainer from the custom subclass, wiring in the model, the
# training arguments, both dataset splits and the metrics callback
trainer = MPSTrainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

mps


In [11]:
# Run fine-tuning; the returned TrainOutput is displayed as the cell result
trainer.train()

{'loss': 0.6109, 'grad_norm': 7.521792888641357, 'learning_rate': 4.7941176470588245e-06, 'epoch': 0.047058823529411764}
{'loss': 0.5616, 'grad_norm': 7.2485737800598145, 'learning_rate': 4.558823529411765e-06, 'epoch': 0.09411764705882353}
{'loss': 0.4946, 'grad_norm': 7.49031400680542, 'learning_rate': 4.323529411764707e-06, 'epoch': 0.1411764705882353}
{'loss': 0.4373, 'grad_norm': 6.766918182373047, 'learning_rate': 4.088235294117647e-06, 'epoch': 0.18823529411764706}

Correct predictions: 1148 / 1163
{'eval_loss': 0.3788065016269684, 'eval_accuracy': 0.9871023215821152, 'eval_precision': 0.9875886524822695, 'eval_recall': 0.9858407079646018, 'eval_f1': 0.9867139061116031, 'eval_runtime': 6.3705, 'eval_samples_per_second': 182.561, 'eval_steps_per_second': 2.983, 'epoch': 0.18823529411764706}
{'loss': 0.3618, 'grad_norm': 6.805063724517822, 'learning_rate': 3.852941176470589e-06, 'epoch': 0.23529411764705882}
{'loss': 0.3159, 'grad_norm': 6.938209533691406, 'learning_rate': 3.61764

TrainOutput(global_step=170, training_loss=0.23278589774580563, metrics={'train_runtime': 108.8826, 'train_samples_per_second': 24.917, 'train_steps_per_second': 1.561, 'train_loss': 0.23278589774580563, 'epoch': 1.0})