In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
import numpy as np
from datasets import Dataset, DatasetDict
import evaluate
import torch
from utils import load_nli_data

# Pick the compute device: CUDA when torch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# NLI class name -> integer id.
label_map = {"entailment": 0, "contradiction": 1, "neutral": 2}
# Inverse lookup: integer id -> NLI class name.
id2label = dict(zip(label_map.values(), label_map.keys()))

# Read Model and Tokenize NLI Dataset

In [None]:
# Tokenizer and 3-label sequence-classification model, both loaded from the
# same "data/bert-base-uncased" location.
tokenizer = AutoTokenizer.from_pretrained("data/bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("data/bert-base-uncased", num_labels=3)

# One load_nli_data call per SNLI JSONL split (train / dev / test).
# The results are later passed to Dataset.from_pandas, so they are presumably
# pandas DataFrames -- confirm against utils.load_nli_data.
snli_train, snli_dev, snli_test = (
    load_nli_data("data/snli_1.0_train.jsonl"),
    load_nli_data("data/snli_1.0_dev.jsonl"),
    load_nli_data("data/snli_1.0_test.jsonl"),
)

In [None]:
def preprocess_function(examples):
    """Tokenize the ``sentence1`` / ``sentence2`` fields of ``examples`` as a pair.

    The two fields are handed to the module-level ``tokenizer`` as its first
    and second text arguments, with truncation enabled and ``max_length=512``.

    Returns:
        The tokenizer's output, unchanged.
    """
    first_text = examples["sentence1"]
    second_text = examples["sentence2"]
    return tokenizer(first_text, second_text, truncation=True, max_length=512)


# Wrap the train/dev frames as datasets under the split names the Trainer cell
# reads ("train" and "validation"), tokenize them, then drop the raw text
# columns that have been consumed by the tokenizer.
#
# batched=True hands preprocess_function whole slices of rows at once instead
# of calling it once per example; the tokenizer accepts lists of sentence
# pairs, so the produced columns are the same while the pass over the large
# training split avoids per-row call overhead.
snli_dataset = (
    DatasetDict(
        {
            "train": Dataset.from_pandas(snli_train),
            "validation": Dataset.from_pandas(snli_dev),
        }
    )
    .map(preprocess_function, batched=True)
    .remove_columns(["sentence1", "sentence2"])
)

# Set up the metric, training arguments, and Trainer

In [None]:
# Collator built around the tokenizer; handed to the Trainer to assemble batches.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# GLUE "mnli" metric object, consumed by compute_metrics.
metric = evaluate.load("glue", "mnli")


def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    ``eval_pred`` is unpacked into ``(predictions, labels)``. The predictions
    are reduced to class ids with an argmax over axis 1 before being compared
    against the labels.

    Returns:
        Whatever ``metric.compute`` returns for those predictions/references.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)


# The notebook explicitly allows a CPU fallback when picking `device`, but
# bf16 mixed precision and the fused AdamW kernel are GPU-oriented settings.
# Enable them only when the hardware can honour them, so the cell still runs
# on CPU-only machines or GPUs without bf16 support; on a bf16-capable CUDA
# device the resulting arguments are identical to the hard-coded ones.
use_cuda = torch.cuda.is_available()
# Short-circuit so the bf16 capability query is only made when CUDA exists.
use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

args = TrainingArguments(
    output_dir="data/checkpoints/",
    eval_strategy="epoch",  # evaluate once per epoch
    logging_steps=500,
    save_strategy="no",  # no intermediate checkpoints; the model is saved explicitly later
    learning_rate=2e-5,
    bf16=use_bf16,
    seed=42,
    weight_decay=0.01,
    optim="adamw_torch_fused" if use_cuda else "adamw_torch",
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    num_train_epochs=1,
)

# Bundle the model, training arguments, tokenized splits, collator and metric
# hook into a single Trainer.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
# of `processing_class=` -- confirm against the installed version.
trainer = Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=snli_dataset["train"],
    eval_dataset=snli_dataset["validation"],
)

# Train and Evaluate the model

In [None]:
# Fine-tune on the "train" split with the settings held in `args`; the call's
# return value is shown as this cell's output.
trainer.train()

In [None]:
# Evaluate on the Trainer's eval_dataset (the "validation" split), scored via
# compute_metrics; the returned metrics are shown as this cell's output.
trainer.evaluate()

# Evaluate the model on the test set

In [None]:
# Tokenize the held-out test split the same way as the train/validation data.
# batched=True lets the tokenizer process slices of rows per call instead of
# one example at a time; the resulting columns are unchanged.
test_data = (
    Dataset.from_pandas(snli_test)
    .map(preprocess_function, batched=True)
    .remove_columns(["sentence1", "sentence2"])
)

# Keep the raw prediction output under its own name so `predictions` is only
# ever bound to the array of predicted class ids.
test_output = trainer.predict(test_data)
predictions = np.argmax(test_output.predictions, axis=1)

# Materialize the gold labels as an array so the comparison below is an
# explicit element-wise one rather than relying on implicit conversion of
# whatever container the dataset returns for a column.
test_labels = np.asarray(test_data["label"])
assert predictions.shape == test_labels.shape, "prediction/label count mismatch"

accuracy = (predictions == test_labels).mean()
print(f"Accuracy: {accuracy}")

# Save the model

In [None]:
# Save the model and the tokenizer side by side in the same directory.
# The second call is the cell's last expression, so its return value is what
# the notebook displays for this cell.
model.save_pretrained("data/checkpoints/bert-snli")
tokenizer.save_pretrained("data/checkpoints/bert-snli")