In [68]:
# import torch, tokenizer and bert-base-uncased model
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Tokenizer and model are loaded from the same checkpoint so their vocabularies agree.
CHECKPOINT = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
# Two output classes. The classification head is not in the checkpoint (see the
# "not initialized" warning in the output), so it is trained from scratch below.
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT, num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [82]:
# import datasets lib and sms_spam dataset
from datasets import load_dataset
raw_dataset = load_dataset("sms_spam")

# Report whether the Apple-silicon GPU backend (MPS) is usable / compiled in.
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())
# Use MPS only when it is actually available; otherwise fall back to CPU so the
# notebook still runs on machines without MPS. The name `mps_device` is kept
# because later cells may refer to it, even when it points at the CPU.
mps_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(mps_device)

Found cached dataset sms_spam (/Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c)
100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 672.49it/s]

True
True





In [70]:
def preprocessing_tokenize_data(sms):
    """Tokenize a batch of SMS messages with the global ``tokenizer``.

    Used with ``Dataset.map(..., batched=True)``, so ``sms`` is a batch whose
    ``'sms'`` entry holds the message texts. Every sequence is truncated and
    padded to exactly 128 tokens so that all examples share one length.

    Returns:
        Whatever the tokenizer returns for the batch (its encoded fields).
    """
    messages = sms['sms']
    return tokenizer(
        messages,
        max_length=128,
        truncation=True,
        padding='max_length',
    )

In [77]:
# final preprocessing steps after tokenizing dataset:
# tokenize, drop the raw text column, rename "label" -> "labels"
# (presumably to match the model's `labels` argument), and emit torch tensors.
tokenized_dataset = (
    raw_dataset
    .map(preprocessing_tokenize_data, batched=True)
    .remove_columns('sms')
    .rename_column("label", "labels")
    .with_format("torch")
)

Loading cached processed dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-b23114d6f3872216.arrow


dict_keys(['train'])

In [72]:
import numpy as np
import evaluate

# Accuracy metric loaded through Hugging Face `evaluate`; consumed by compute_metrics.
metric = evaluate.load(path="accuracy")

In [73]:
def compute_metrics(eval_pred):
    """Score one evaluation pass for the Trainer.

    Args:
        eval_pred: a ``(logits, labels)`` pair; the predicted class is the
            arg-max over the last axis of ``logits``.

    Returns:
        The result of ``metric.compute`` (the accuracy metric loaded earlier).
    """
    raw_logits, reference_labels = eval_pred
    predicted_labels = np.argmax(raw_logits, axis=-1)
    scores = metric.compute(references=reference_labels, predictions=predicted_labels)
    return scores

In [74]:
from transformers import TrainingArguments, Trainer

# Checkpoints are written under ./training_checkpoints and evaluation runs at
# the end of every epoch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version before upgrading.
training_args = TrainingArguments(
    output_dir="training_checkpoints",
    evaluation_strategy="epoch",
)

In [84]:
# Carve disjoint train / eval / test splits out of the single "train" split.
# Shuffle ONCE with a fixed seed (the original re-shuffled three times with the
# same seed, which yields the same permutation, so splits are unchanged).
SPLIT_SEED = 22
N_TRAIN = 4000
N_EVAL = 1000

shuffled_dataset = tokenized_dataset['train'].shuffle(seed=SPLIT_SEED)
train_dataset = shuffled_dataset.select(range(N_TRAIN))
eval_dataset = shuffled_dataset.select(range(N_TRAIN, N_TRAIN + N_EVAL))
# Hold out every remaining row for testing. The previous hard-coded end index
# (5570) would drop any rows past it, and would fail if the dataset were shorter.
test_dataset = shuffled_dataset.select(range(N_TRAIN + N_EVAL, len(shuffled_dataset)))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

Loading cached shuffled indices for dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-5216da0691771bed.arrow
Loading cached shuffled indices for dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-5216da0691771bed.arrow
Loading cached shuffled indices for dataset at /Users/jayreddy/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-5216da0691771bed.arrow


In [81]:
# Fine-tune the model; the returned TrainOutput (steps, loss, runtime metrics)
# is displayed as the cell's output.
train_output = trainer.train()
train_output



Epoch,Training Loss,Validation Loss,Accuracy
1,0.1077,0.046291,0.989172
2,0.0467,0.036065,0.993631
3,0.0096,0.033602,0.993631


TrainOutput(global_step=1500, training_loss=0.05466354274749756, metrics={'train_runtime': 2419.8433, 'train_samples_per_second': 4.959, 'train_steps_per_second': 0.62, 'total_flos': 789333166080000.0, 'train_loss': 0.05466354274749756, 'epoch': 3.0})