In [1]:
%pip install transformers datasets evaluate lightning


Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset

# Load SNLI dataset (train/validation/test)
snli = load_dataset("snli")

# Remove examples with missing labels (-1)
snli = snli.filter(lambda x: x["label"] != -1)


In [3]:
from transformers import AutoTokenizer

model_name = "google/electra-small-discriminator"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_function(batch):
    return tokenizer(batch["premise"], batch["hypothesis"],
                     truncation=True, padding="max_length", max_length=128)

encoded_snli = snli.map(preprocess_function, batched=True)
encoded_snli.set_format(type="torch",
                        columns=["input_ids", "attention_mask", "label"])


Map:   0%|          | 0/549367 [00:00<?, ? examples/s]

In [4]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)


Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
from transformers import Trainer, TrainingArguments
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return metric.compute(predictions=preds, references=labels)

training_args = TrainingArguments(
    output_dir="checkpoints",
    per_device_eval_batch_size=64,
)

trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=encoded_snli["validation"],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

results = trainer.evaluate()
print(results)


  trainer = Trainer(


{'eval_loss': 1.0990803241729736, 'eval_model_preparation_time': 0.0058, 'eval_accuracy': 0.3029871977240398, 'eval_runtime': 11.4236, 'eval_samples_per_second': 861.552, 'eval_steps_per_second': 13.481}


In [7]:
%pip install transformers[torch]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


zsh:1: no matches found: transformers[torch]
Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install 'accelerate>=0.26.0'

# fine tuning results
{'eval_loss': 0.31268370151519775, 'eval_accuracy': 0.8890469416785206, 'eval_runtime': 5.0437, 'eval_samples_per_second': 1951.36, 'eval_steps_per_second': 30.533, 'epoch': 3.0}
{'train_runtime': 2492.0103, 'train_samples_per_second': 661.354, 'train_steps_per_second': 20.668, 'train_loss': 0.38749799531377893, 'epoch': 3.0}

Best dev metrics: {'eval_loss': 0.31268370151519775, 'eval_accuracy': 0.8890469416785206, 'eval_runtime': 5.0038, 'eval_samples_per_second': 1966.893, 'eval_steps_per_second': 30.776, 'epoch': 3.0}

Test metrics: {'eval_loss': 0.318878173828125, 'eval_accuracy': 0.8871131921824105, 'eval_runtime': 4.9029, 'eval_samples_per_second': 2003.71, 'eval_steps_per_second': 31.41, 'epoch': 3.0}

=== Robust Accuracy (overall) ===
62.00%

=== Per-phenomenon accuracy ===
            accuracy
phenomenon          
paraphrase    92.00%
quantifier    57.33%
negation      56.00%
antonymy      42.67%

+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 848    |
| Number of failed attacks:     | 32     |
| Number of skipped attacks:    | 120    |
| Original accuracy:            | 88.0%  |
| Accuracy under attack:        | 3.2%   |
| Attack success rate:          | 96.36% |
| Average perturbed word %:     | 7.31%  |
| Average num. words per input: | 21.94  |
| Avg num queries:              | 43.8   |
+-------------------------------+--------+