In [2]:
# 0. Load tools
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# 1. Setup: Data + Model
# Load dataset
dataset = load_dataset("snli")

# Initialize model + tokenizer
model_name = "google/electra-small-discriminator"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
if hasattr(model, 'electra'):
        for param in model.electra.parameters():
            if not param.is_contiguous():
                param.data = param.data.contiguous()
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)


Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# 2. Preprocess the dataset
def preprocess(example):
    return tokenizer(example['premise'], example['hypothesis'], truncation=True, padding='max_length', max_length=tokenizer.model_max_length)

dataset = dataset.filter(lambda ex: ex['label'] != -1)
encoded_dataset = dataset.map(preprocess, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")  # Ensure labels are named correctly
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

Filter: 100%|██████████| 10000/10000 [00:00<00:00, 181776.20 examples/s]
Filter: 100%|██████████| 10000/10000 [00:00<00:00, 181776.20 examples/s]
Filter: 100%|██████████| 550152/550152 [00:02<00:00, 230812.70 examples/s]
Map: 100%|██████████| 9824/9824 [00:02<00:00, 4186.61 examples/s]
Map: 100%|██████████| 9842/9842 [00:02<00:00, 4146.56 examples/s]
Map: 100%|██████████| 549367/549367 [02:17<00:00, 4009.24 examples/s]


In [11]:
# 3. Train

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
)
trainer.train()


  0%|          | 7/103155 [45:25<127:55:04,  4.46s/it]  

{'loss': 0.8273, 'grad_norm': 11.229413986206055, 'learning_rate': 4.975730040385213e-05, 'epoch': 0.01}




KeyboardInterrupt: 

In [None]:
# 4. Evaluate

eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")