In [None]:
!pip install transformers datasets
from datasets import load_dataset, Dataset, DatasetDict
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import DataLoader
import torch
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score


def load_snli(sample_size=1000, val_test_size=100, seed=42):
    """Load the SNLI dataset and return a down-sampled ``DatasetDict``.

    Args:
        sample_size: Maximum number of rows kept from the ``train`` split.
        val_test_size: Maximum number of rows kept from each of the
            ``validation`` and ``test`` splits.
        seed: Seed passed to ``shuffle`` so the drawn sample is reproducible.

    Returns:
        A ``DatasetDict`` with ``train``, ``validation`` and ``test`` entries.
        A split is shuffled and truncated only when it holds more rows than
        requested; otherwise it is returned whole and in its original order.
    """
    dataset = load_dataset("snli")

    def sample_data(split, size):
        """Return ``size`` shuffled rows of ``split``, or the entire split."""
        full_split = dataset[split]
        if size < len(full_split):
            # Shuffle before selecting so the sample is not just the first rows.
            return full_split.shuffle(seed=seed).select(range(size))
        return full_split

    return DatasetDict({
        "train": sample_data("train", sample_size),
        "validation": sample_data("validation", val_test_size),
        "test": sample_data("test", val_test_size),
    })


def preprocess_data(dataset, tokenizer, max_length=128):
    """Drop rows labelled ``-1`` and tokenize premise/hypothesis pairs.

    Args:
        dataset: Mapping of split name to dataset; every split must expose
            ``premise``, ``hypothesis`` and ``label`` columns.
        tokenizer: Callable tokenizer applied to (premise, hypothesis) batches.
        max_length: Length every encoded pair is truncated and padded to.

    Returns:
        A ``DatasetDict`` holding the filtered splits with the tokenizer's
        output columns added.
    """
    def has_valid_label(example):
        # -1 is treated as an invalid label and such rows are discarded.
        return example['label'] != -1

    def tokenize_pairs(batch):
        return tokenizer(
            batch['premise'],
            batch['hypothesis'],
            truncation=True,
            padding='max_length',
            max_length=max_length,
        )

    cleaned_splits = {}
    for split_name in dataset.keys():
        cleaned_splits[split_name] = dataset[split_name].filter(has_valid_label)

    return DatasetDict(cleaned_splits).map(tokenize_pairs, batched=True)


def compute_metrics(pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores; the highest-scoring index
            along the last axis is taken as the predicted class).

    Returns:
        Dict with keys ``"accuracy"`` and ``"f1"``.
    """
    gold = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "f1": f1_score(gold, predicted, average="weighted"),
    }


# Pretrained RoBERTa tokenizer and a sequence classifier with three output labels.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=3)

# Sample SNLI, tokenize it, and rename the target column to "labels".
dataset = load_snli()
encoded_dataset = preprocess_data(dataset, tokenizer).rename_column("label", "labels")

# Expose only the listed columns, formatted as torch tensors.
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])


# Splits used during training; the test split is held back for the final evaluation.
train_dataset = encoded_dataset["train"]
val_dataset = encoded_dataset["validation"]

# Training configuration: evaluate and checkpoint once per epoch, keep at most
# two checkpoints, and reload the checkpoint with the best "accuracy" when
# training finishes.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy`; the pip install above is unpinned, so confirm this keyword
# is still accepted by the installed version.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none",  # disable external experiment-tracking integrations
)


# Trainer wiring: model, configuration, train/validation data and the metric
# callback defined in this file.
# NOTE(review): newer transformers versions deprecate the `tokenizer` argument
# in favour of `processing_class` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)


# Fine-tune the model using the configuration above.
trainer.train()


# Final evaluation on the held-out test split.
# NOTE(review): Trainer.evaluate presumably prefixes metric keys with "eval_"
# by default, so the printed dict may show eval_accuracy / eval_f1 even though
# these are test-set numbers — confirm.
test_dataset = encoded_dataset["test"]
metrics = trainer.evaluate(test_dataset)
print("Test Metrics:", metrics)
