In [1]:
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, TrainingArguments, AutoModelForSequenceClassification, Trainer
from pprint import pprint
from torchinfo import summary
from sklearn.metrics import f1_score
import torch
import numpy as np #required implicitly for training process

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the "rte" configuration of the "glue" dataset; the result holds all splits.
raw_datasets = load_dataset("glue", "rte")

In [3]:
# Display the available splits with their column names and row counts.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 2490
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 277
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3000
    })
})

In [4]:
# Inspect the column types of the train split (the label column lists its class names).
raw_datasets['train'].features

{'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'not_entailment'], id=None),
 'idx': Value(dtype='int32', id=None)}

In [5]:
# Preview the first ten `sentence1` values of the train split.
# Slice the rows first and pick the column afterwards: this yields the same
# list, but avoids building the entire column just to keep ten entries.
raw_datasets['train'][:10]['sentence1']

['No Weapons of Mass Destruction Found in Iraq Yet.',
 'A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.',
 'Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.',
 'Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.',
 "A man is due in court later charged with the murder 26 years ago of a teenager whose case was the first to be featured on BBC One's Crimewatch. Colette Aram, 16, was walking to her boyfriend's house in Keyworth, Nottinghamshire, on 30 October 1983 when she disappeared. Her body was later found i

In [6]:
# Name of the pretrained checkpoint; both the tokenizer and the model are loaded from it.
checkpoint = 'bert-base-cased'

In [7]:
# Load the tokenizer that belongs to `checkpoint`.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [8]:
# Tokenize the first training example as a sentence pair to see what the
# tokenizer returns. The row is indexed before the column so that only one
# example is read instead of two whole columns.
tokenizer(
    raw_datasets['train'][0]['sentence1'],
    raw_datasets['train'][0]['sentence2']
)

{'input_ids': [101, 1302, 20263, 1104, 8718, 14177, 17993, 17107, 1107, 5008, 6355, 119, 102, 20263, 1104, 8718, 14177, 17993, 17107, 1107, 5008, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [9]:
# Tokenize the first training pair explicitly instead of grabbing IPython's
# `_` (the previous cell's output): `_` only holds the right value when the
# cells run in exactly this order in an interactive kernel, so relying on it
# makes the notebook fragile on re-run and unusable as a plain script.
result = tokenizer(
    raw_datasets['train'][0]['sentence1'],
    raw_datasets['train'][0]['sentence2']
)

In [10]:
# List the fields produced by the tokenizer for a sentence pair.
result.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [11]:
# Turn the ids back into text to see the special tokens placed around the two sentences.
tokenizer.decode(result['input_ids'])

'[CLS] No Weapons of Mass Destruction Found in Iraq Yet. [SEP] Weapons of Mass Destruction Found in Iraq. [SEP]'

In [12]:
# The classification head is not part of the checkpoint and is randomly
# initialised when the model is built, so fix torch's RNG seed first to make
# the starting weights (and therefore the run) repeatable.
torch.manual_seed(42)
# Pretrained encoder plus a fresh 2-way classification head (RTE has two labels).
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Show a layer-by-layer parameter count of the model (torchinfo).
summary(model)

Layer (type:depth-idx)                                  Param #
BertForSequenceClassification                           --
├─BertModel: 1-1                                        --
│    └─BertEmbeddings: 2-1                              --
│    │    └─Embedding: 3-1                              22,268,928
│    │    └─Embedding: 3-2                              393,216
│    │    └─Embedding: 3-3                              1,536
│    │    └─LayerNorm: 3-4                              1,536
│    │    └─Dropout: 3-5                                --
│    └─BertEncoder: 2-2                                 --
│    │    └─ModuleList: 3-6                             85,054,464
│    └─BertPooler: 2-3                                  --
│    │    └─Linear: 3-7                                 590,592
│    │    └─Tanh: 3-8                                   --
├─Dropout: 1-2                                          --
├─Linear: 1-3                                           1,538
Total params: 10

In [14]:
# Hyper-parameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    # where checkpoints are written
    output_dir='training_dir',
    # optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # evaluate and checkpoint once per epoch
    evaluation_strategy='epoch',
    save_strategy='epoch',
    # The run is short (468 steps in the recorded output below), so log the
    # training loss often enough that it is actually reported instead of
    # showing up as 'no log'.
    logging_steps=150,
)

In [15]:
# Load the metric bundled with GLUE/RTE to see what it reports.
# NOTE(review): the recorded output shows a warning that `trust_remote_code=True`
# will become mandatory for this call in the next major `datasets` release —
# confirm the intended library version before relying on this cell.
metric = load_metric('glue', 'rte')

  metric = load_metric('glue', 'rte')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [16]:
# Quick check on toy predictions: the bundled metric returns accuracy only.
metric.compute(predictions=[1, 0, 1], references=[1, 0, 0])

{'accuracy': 0.6666666666666666}

In [17]:
# The bundled GLUE/RTE metric reports accuracy only, so evaluation uses this
# custom function, which reports accuracy and F1.
def compute_metrics(logits_and_labels):
    """Compute accuracy and F1 from a ``(logits, labels)`` pair.

    Args:
        logits_and_labels: two-item sequence; the first item holds the model
            scores with the class dimension last, the second the reference
            label ids.

    Returns:
        dict with keys ``'accuracy'`` (fraction of predictions equal to the
        labels) and ``'f1'`` (``sklearn.metrics.f1_score`` with its default
        settings).
    """
    scores, gold = logits_and_labels
    # The predicted class is the index of the highest score along the last axis.
    predicted = np.argmax(scores, axis=-1)
    return {
        'accuracy': np.mean(predicted == gold),
        'f1': f1_score(gold, predicted),
    }

In [18]:
def tokenize_fn(batch):
    """Tokenize the sentence pairs of a batch of rows.

    Args:
        batch: mapping with ``'sentence1'`` and ``'sentence2'`` entries.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the paired
        sentences, with truncation enabled.
    """
    first_sentences = batch['sentence1']
    second_sentences = batch['sentence2']
    return tokenizer(first_sentences, second_sentences, truncation=True)

In [19]:
# Run tokenize_fn over every split, handing it batches of rows (batched=True).
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map: 100%|██████████| 3000/3000 [00:00<00:00, 9321.64 examples/s]


In [20]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Display the selected device.
device

device(type='cuda')

In [21]:
# Assemble the Trainer: model, run configuration, metrics, tokenizer, and data.
trainer = Trainer(
    model=model.to(device),  # model is moved to the selected device up front
    args=training_args,
    compute_metrics=compute_metrics,  # accuracy + F1 at every evaluation
    # NOTE(review): tokenize_fn does not pad, so batching presumably relies on
    # the Trainer padding with this tokenizer — confirm.
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

In [22]:
# Run fine-tuning; the returned TrainOutput (step count, loss, runtime stats) is displayed.
trainer.train()

 32%|███▏      | 150/468 [01:31<03:32,  1.49it/s]

{'loss': 0.6515, 'learning_rate': 3.397435897435898e-05, 'epoch': 0.96}


                                                 
 33%|███▎      | 156/468 [02:02<05:52,  1.13s/it]

{'eval_loss': 0.6195622682571411, 'eval_accuracy': 0.6606498194945848, 'eval_f1': 0.5523809523809524, 'eval_runtime': 22.0144, 'eval_samples_per_second': 12.583, 'eval_steps_per_second': 0.227, 'epoch': 1.0}


 64%|██████▍   | 300/468 [03:43<01:13,  2.29it/s]

{'loss': 0.4507, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}


                                                 
 67%|██████▋   | 312/468 [04:00<03:13,  1.24s/it]

{'eval_loss': 0.7040601968765259, 'eval_accuracy': 0.6895306859205776, 'eval_f1': 0.6742424242424242, 'eval_runtime': 3.679, 'eval_samples_per_second': 75.292, 'eval_steps_per_second': 1.359, 'epoch': 2.0}


 96%|█████████▌| 450/468 [05:40<00:14,  1.28it/s]

{'loss': 0.2475, 'learning_rate': 1.9230769230769234e-06, 'epoch': 2.88}


                                                 
100%|██████████| 468/468 [06:04<00:00,  1.75it/s]

{'eval_loss': 0.8067015409469604, 'eval_accuracy': 0.7256317689530686, 'eval_f1': 0.6935483870967742, 'eval_runtime': 3.63, 'eval_samples_per_second': 76.309, 'eval_steps_per_second': 1.377, 'epoch': 3.0}


100%|██████████| 468/468 [06:09<00:00,  1.27it/s]

{'train_runtime': 369.484, 'train_samples_per_second': 20.217, 'train_steps_per_second': 1.267, 'train_loss': 0.44085366654599834, 'epoch': 3.0}





TrainOutput(global_step=468, training_loss=0.44085366654599834, metrics={'train_runtime': 369.484, 'train_samples_per_second': 20.217, 'train_steps_per_second': 1.267, 'train_loss': 0.44085366654599834, 'epoch': 3.0})