In [85]:
# import torch, tokenizer and bert-base-uncased model
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [86]:
# import datasets lib and sms_spam dataset
from datasets import load_dataset
raw_dataset = load_dataset("sms_spam")

# check if mps gpu is available
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())
mps_device = torch.device("mps")
model.to(mps_device)

Found cached dataset sms_spam (/Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c)
100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 504.30it/s]


True
True


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [87]:
def preprocessing_tokenize_data(sms):
    return tokenizer(
        sms['sms'], padding='max_length', truncation=True, max_length=128
    )

In [88]:
# final preprocessing steps after tokenizing dataset
tokenized_dataset = raw_dataset.map(preprocessing_tokenize_data, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns('sms')
tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
tokenized_dataset = tokenized_dataset.with_format("torch")

Loading cached processed dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-b23114d6f3872216.arrow


In [89]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [90]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [91]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="training_checkpoints", evaluation_strategy="epoch")

In [92]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'].shuffle(seed=22).select(range(4000)),
    eval_dataset=tokenized_dataset['train'].shuffle(seed=22).select(range(4000,5000)),
    compute_metrics=compute_metrics,
)
test_dataset = tokenized_dataset['train'].shuffle(seed=22).select(range(5000,5570))

Loading cached shuffled indices for dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-5216da0691771bed.arrow
Loading cached shuffled indices for dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-5216da0691771bed.arrow
Loading cached shuffled indices for dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-5216da0691771bed.arrow


In [94]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.0001,0.051911,0.993
2,0.0001,0.051911,0.993
3,0.0001,0.051911,0.993


TrainOutput(global_step=1500, training_loss=0.00010177452862262726, metrics={'train_runtime': 2700.3617, 'train_samples_per_second': 4.444, 'train_steps_per_second': 0.555, 'total_flos': 789333166080000.0, 'train_loss': 0.00010177452862262726, 'epoch': 3.0})

In [95]:
test_results = trainer.evaluate(test_dataset)
print(test_results)

{'eval_loss': 0.00014053804625291377, 'eval_accuracy': 1.0, 'eval_runtime': 34.6007, 'eval_samples_per_second': 16.474, 'eval_steps_per_second': 2.081, 'epoch': 3.0}


In [97]:
model_path = "fine_tuned_bert_sms_spam"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

('fine_tuned_bert_sms_spam/tokenizer_config.json',
 'fine_tuned_bert_sms_spam/special_tokens_map.json',
 'fine_tuned_bert_sms_spam/vocab.txt',
 'fine_tuned_bert_sms_spam/added_tokens.json',
 'fine_tuned_bert_sms_spam/tokenizer.json')

In [99]:
print(model.device)
model.to(torch.device("mps"))
print(model.device)

cpu
mps:0
