In [1]:
from datasets import load_dataset
import numpy as np
import random
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoConfig

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
snli = load_dataset("snli")
for split in snli:
    snli[split] = snli[split].filter(lambda x: x["label"] >= 0)

In [4]:
#Load the tokenizer and model
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=3)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Tokenization function
def tokenize_function(examples):
    return tokenizer(examples["premise"], examples["hypothesis"], truncation=True)

# Tokenize the dataset
encoded_snli = snli.map(tokenize_function, batched=True)
    
# Specify sample size
sample_size = 2000

# Ensure reproducibility
random.seed(68)

# Randomly sample 10 subsets
subsets = []
for idx in range(10):
  subset = encoded_snli['train'].shuffle(seed=68+9-idx).select(range(sample_size))
  subsets.append(subset)

Map: 100%|██████████| 9824/9824 [00:02<00:00, 4790.44 examples/s]
Map: 100%|██████████| 9842/9842 [00:01<00:00, 7099.28 examples/s]
Map: 100%|██████████| 549367/549367 [01:50<00:00, 4953.04 examples/s]


In [6]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="weighted")
    accuracy = accuracy_score(labels, predictions)
    return {"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}

In [9]:
def train_model(training_data):
  for idx in range(len(training_data)):
    subset = training_data[idx]

    training_args = TrainingArguments(
      output_dir= "".join(["./results", str(idx)]),
      evaluation_strategy="epoch",
      learning_rate=2e-05,
      per_device_train_batch_size=16,
      per_device_eval_batch_size=16,
      num_train_epochs=1,
      weight_decay=0.01,
      save_strategy="epoch",
      load_best_model_at_end=True,
      metric_for_best_model="accuracy",)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=subset,
        eval_dataset=encoded_snli["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics)

    trainer.train()


In [10]:
train_model(subsets)

  trainer = Trainer(
                                                 
100%|██████████| 125/125 [08:49<00:00,  2.14s/it]

{'eval_loss': 1.0591257810592651, 'eval_accuracy': 0.4470636049583418, 'eval_f1': 0.4174311815515342, 'eval_precision': 0.5277979180476682, 'eval_recall': 0.4470636049583418, 'eval_runtime': 262.128, 'eval_samples_per_second': 37.547, 'eval_steps_per_second': 2.35, 'epoch': 1.0}


100%|██████████| 125/125 [08:54<00:00,  4.28s/it]
  trainer = Trainer(


{'train_runtime': 534.4926, 'train_samples_per_second': 3.742, 'train_steps_per_second': 0.234, 'train_loss': 1.0857064208984375, 'epoch': 1.0}


                                                 
100%|██████████| 125/125 [09:47<00:00,  2.01s/it]

{'eval_loss': 0.8226264715194702, 'eval_accuracy': 0.6729323308270677, 'eval_f1': 0.6686017393753599, 'eval_precision': 0.6732637016450126, 'eval_recall': 0.6729323308270677, 'eval_runtime': 286.7812, 'eval_samples_per_second': 34.319, 'eval_steps_per_second': 2.148, 'epoch': 1.0}


100%|██████████| 125/125 [09:59<00:00,  4.80s/it]
  trainer = Trainer(


{'train_runtime': 599.3716, 'train_samples_per_second': 3.337, 'train_steps_per_second': 0.209, 'train_loss': 0.95858740234375, 'epoch': 1.0}


100%|██████████| 125/125 [04:35<00:00,  1.86s/it]
100%|██████████| 125/125 [09:11<00:00,  1.86s/it]

{'eval_loss': 0.7231476306915283, 'eval_accuracy': 0.6994513310302783, 'eval_f1': 0.6985844526244166, 'eval_precision': 0.7023731245773276, 'eval_recall': 0.6994513310302783, 'eval_runtime': 271.6009, 'eval_samples_per_second': 36.237, 'eval_steps_per_second': 2.268, 'epoch': 1.0}


100%|██████████| 125/125 [09:21<00:00,  4.49s/it]
  trainer = Trainer(


{'train_runtime': 561.013, 'train_samples_per_second': 3.565, 'train_steps_per_second': 0.223, 'train_loss': 0.7986385498046875, 'epoch': 1.0}


                                                 
100%|██████████| 125/125 [10:06<00:00,  1.91s/it]

{'eval_loss': 0.6739460229873657, 'eval_accuracy': 0.7234301971144076, 'eval_f1': 0.7234031101831481, 'eval_precision': 0.725467008300275, 'eval_recall': 0.7234301971144076, 'eval_runtime': 345.0695, 'eval_samples_per_second': 28.522, 'eval_steps_per_second': 1.785, 'epoch': 1.0}


100%|██████████| 125/125 [10:15<00:00,  4.93s/it]
  trainer = Trainer(


{'train_runtime': 615.7204, 'train_samples_per_second': 3.248, 'train_steps_per_second': 0.203, 'train_loss': 0.7288319702148438, 'epoch': 1.0}


100%|██████████| 125/125 [04:23<00:00,  2.61s/it]
100%|██████████| 125/125 [09:08<00:00,  2.61s/it]

{'eval_loss': 0.6406736969947815, 'eval_accuracy': 0.7398902662060557, 'eval_f1': 0.7375302198908631, 'eval_precision': 0.7381042749826953, 'eval_recall': 0.7398902662060557, 'eval_runtime': 280.1789, 'eval_samples_per_second': 35.128, 'eval_steps_per_second': 2.199, 'epoch': 1.0}


100%|██████████| 125/125 [09:17<00:00,  4.46s/it]
  trainer = Trainer(


{'train_runtime': 557.2478, 'train_samples_per_second': 3.589, 'train_steps_per_second': 0.224, 'train_loss': 0.7246220703125, 'epoch': 1.0}


                                                 
100%|██████████| 125/125 [09:26<00:00,  1.70s/it]

{'eval_loss': 0.6280561089515686, 'eval_accuracy': 0.7415159520422678, 'eval_f1': 0.7386606437736406, 'eval_precision': 0.7428850104494386, 'eval_recall': 0.7415159520422678, 'eval_runtime': 283.1401, 'eval_samples_per_second': 34.76, 'eval_steps_per_second': 2.176, 'epoch': 1.0}


100%|██████████| 125/125 [09:36<00:00,  4.61s/it]
  trainer = Trainer(


{'train_runtime': 576.6831, 'train_samples_per_second': 3.468, 'train_steps_per_second': 0.217, 'train_loss': 0.6986270141601563, 'epoch': 1.0}


100%|██████████| 125/125 [03:53<00:00,  1.68s/it]
100%|██████████| 125/125 [09:17<00:00,  1.68s/it]

{'eval_loss': 0.595384955406189, 'eval_accuracy': 0.7581792318634424, 'eval_f1': 0.7585273160255664, 'eval_precision': 0.7596022715265934, 'eval_recall': 0.7581792318634424, 'eval_runtime': 319.6211, 'eval_samples_per_second': 30.793, 'eval_steps_per_second': 1.927, 'epoch': 1.0}


100%|██████████| 125/125 [09:26<00:00,  4.53s/it]
  trainer = Trainer(


{'train_runtime': 566.3596, 'train_samples_per_second': 3.531, 'train_steps_per_second': 0.221, 'train_loss': 0.6770763549804687, 'epoch': 1.0}


100%|██████████| 125/125 [04:27<00:00,  2.09s/it]
100%|██████████| 125/125 [09:05<00:00,  2.09s/it]

{'eval_loss': 0.5831061005592346, 'eval_accuracy': 0.7673237146921358, 'eval_f1': 0.7670619192126339, 'eval_precision': 0.7668702053148707, 'eval_recall': 0.7673237146921358, 'eval_runtime': 273.1881, 'eval_samples_per_second': 36.026, 'eval_steps_per_second': 2.255, 'epoch': 1.0}


100%|██████████| 125/125 [09:15<00:00,  4.44s/it]
  trainer = Trainer(


{'train_runtime': 555.5783, 'train_samples_per_second': 3.6, 'train_steps_per_second': 0.225, 'train_loss': 0.664669189453125, 'epoch': 1.0}


100%|██████████| 125/125 [04:05<00:00,  1.84s/it]
100%|██████████| 125/125 [08:12<00:00,  1.84s/it]

{'eval_loss': 0.5910376906394958, 'eval_accuracy': 0.7668156878683194, 'eval_f1': 0.7656094426142411, 'eval_precision': 0.7677651224524598, 'eval_recall': 0.7668156878683194, 'eval_runtime': 242.4728, 'eval_samples_per_second': 40.59, 'eval_steps_per_second': 2.54, 'epoch': 1.0}


100%|██████████| 125/125 [08:16<00:00,  3.97s/it]
  trainer = Trainer(


{'train_runtime': 496.8469, 'train_samples_per_second': 4.025, 'train_steps_per_second': 0.252, 'train_loss': 0.6224588012695312, 'epoch': 1.0}


100%|██████████| 125/125 [03:30<00:00,  1.60s/it]
100%|██████████| 125/125 [07:45<00:00,  1.60s/it]

{'eval_loss': 0.5648350715637207, 'eval_accuracy': 0.7763665921560658, 'eval_f1': 0.7748484593190571, 'eval_precision': 0.7750725566945951, 'eval_recall': 0.7763665921560658, 'eval_runtime': 251.8398, 'eval_samples_per_second': 39.08, 'eval_steps_per_second': 2.446, 'epoch': 1.0}


100%|██████████| 125/125 [07:52<00:00,  3.78s/it]

{'train_runtime': 472.0332, 'train_samples_per_second': 4.237, 'train_steps_per_second': 0.265, 'train_loss': 0.6273997802734375, 'epoch': 1.0}



