In [30]:
"""
DistilBERT: a distilled version of BERT: smaller, faster, cheaper and lighter - https://arxiv.org/abs/1910.01108
sms_spam: a public set of SMS labeled messages that have been collected for mobile phone spam research, 5574 sms messages, 203 KB https://huggingface.co/datasets/sms_spam
"""
try:
    import torch, transformers, datasets, accelerate
except:
#    %pip install -q torch transformers 'datasets==2.18.0' accelerate
    %pip install -q torch transformers datasets accelerate

def use_best_device():
    # Check if GPU is available
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        torch.set_default_device(device)
    elif torch.backends.mps.is_available():
        device = "mps"   
    return device

device = use_best_device()

print(f"PyTorch version: {torch.__version__}")
print(datasets.__version__)
print(f"device: {device}")
# print("HF_HOME:", os.environ.get("HF_HOME"))

PyTorch version: 2.2.2
2.19.0
device: mps


In [5]:
from transformers import AutoTokenizer
from datasets import load_dataset

class SMS_SPAM_Dataset:
    model_id = "distilbert-base-uncased"
    splits = ["train", "test"]
    limit_data = 5574  # 5574 is all of them, only has "train" split
    dataset_id = "sms_spam"

    def __init__(self):
        # self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        # Load the train and test splits of the imdb dataset, 25k train, 25k test
        # self.dataset = {split: ds for split, ds in zip(self.splits, load_dataset(self.dataset_id, split=self.splits))}

        # for split in self.splits:
        #     self.dataset[split] = self.dataset[split].shuffle(seed=42).select(range(self.limit_data))

        # Split the dataset into training and testing using a 20% split for the test set

        sms_ds = load_dataset("sms_spam")
        # DatasetDict({train: Dataset({features: ['sms', 'label'],num_rows: 5574})})
        
        self.dataset = sms_ds['train'].train_test_split(test_size=0.2, shuffle=True, seed=42)
        # Assign the splits to new variables in a dictionary
        # self.dataset = {'train': train_test_split['train'], 'test': train_test_split['test']}

        # self.dataset = load_dataset("imdb")
        # print(self.dataset)

    def tokenize_function(self, examples):
        # return self.tokenizer(examples["text"], padding="max_length", truncation=True)
        return self.tokenizer(examples["sms"], max_length=512, truncation=True)

    def prepare_data(self):
        """Tokenizing 100,000 reviews take some time (~3 minutes)."""
        # self.tokenized_datasets = self.dataset.map(self.tokenize_function, batched=True)

        self.tokenized_ds = {}
        for split in self.splits:
            self.tokenized_ds[split] = self.dataset[split].map(self.tokenize_function, batched=True)

dataset = SMS_SPAM_Dataset()
dataset.prepare_data()
print(dataset.tokenizer)
print(dataset.dataset)
print(dataset.tokenized_ds)

print(dataset.tokenized_ds["train"][0]["input_ids"])

Map:   0%|          | 0/4459 [00:00<?, ? examples/s]

Map:   0%|          | 0/1115 [00:00<?, ? examples/s]

DistilBertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
DatasetDict({
    train: Dataset({
        features: ['sms', 'label'],
        num_rows

In [21]:
# Define IMDBTrainer class, loading distilbert-base-uncased
# and imdb dataset for fine-tuning task
# from transformers import DistilBertForSequenceClassification
# from transformers import DistilBertTokenizer

from transformers import AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from transformers import TrainingArguments, Trainer
import numpy as np

class SMS_SPAM_Finetune:
    """
    distilbert-base-uncased size: ~268 MB
    """

    model_id = "distilbert-base-uncased"
    output_dir = "/tmp/MAI_SMS_SPAM_Trainer"
    splits = ["train", "test"]
    dataset_id = "imdb"

    def __init__(self):
        # self.model = DistilBertForSequenceClassification.from_pretrained(
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_id,
            num_labels=2,
            id2label={0: "not spam", 1: "spam"},
            label2id={"not spam": 0, "spam": 1},
        )

        # Freeze all the parameters of the base model
        for param in self.model.base_model.parameters():
            param.require_grad = False

    def compute_metrics(self, eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return {"accuracy": (predictions == labels).mean()}

    def train(self, tokenizer, tokenized_ds):
        training_args = TrainingArguments(
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            output_dir=self.output_dir,
            learning_rate=0.0001,
            num_train_epochs=1,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
        )
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_ds["train"],
            eval_dataset=tokenized_ds["test"],
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=self.compute_metrics,
        )
        self.trainer.train()
    
    def evaluate(self):
        return self.trainer.evaluate()
    
    def predict(self, x):
        return self.trainer.predict(x)
    
    

trainer = SMS_SPAM_Finetune()
print(trainer.model)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [22]:
trainer.train(dataset.tokenizer, dataset.tokenized_ds)

  0%|          | 0/279 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

{'eval_loss': 0.044411901384592056, 'eval_accuracy': 0.9919282511210762, 'eval_runtime': 1.4904, 'eval_samples_per_second': 748.138, 'eval_steps_per_second': 46.968, 'epoch': 1.0}
{'train_runtime': 37.3346, 'train_samples_per_second': 119.434, 'train_steps_per_second': 7.473, 'train_loss': 0.06464275237052672, 'epoch': 1.0}


In [23]:
trainer.evaluate()

  0%|          | 0/70 [00:00<?, ?it/s]

{'eval_loss': 0.044411901384592056,
 'eval_accuracy': 0.9919282511210762,
 'eval_runtime': 1.6244,
 'eval_samples_per_second': 686.401,
 'eval_steps_per_second': 43.092,
 'epoch': 1.0}

In [29]:
import pandas as pd

# Extract a subset of entries from the 'test' split of the tokenized dataset for detailed review.
indices_review = [0, 1, 2, 10, 50, 100, 1000]
items_for_manual_review = dataset.tokenized_ds["test"].select(indices_review)

# Use the trained model to make predictions on the selected subset.
results = trainer.predict(items_for_manual_review)

# Assemble a DataFrame with the text messages, their corresponding model predictions, and the true labels.
df = pd.DataFrame(
    {
        "sms": [
            item["sms"] for item in items_for_manual_review
        ],  # Collect the message texts from the dataset.
        "predictions": results.predictions.argmax(axis=1),
        "labels": results.label_ids,  # Include the actual labels for comparison.
    }
)

# Adjust the display settings of pandas to ensure that entire text messages are shown without truncation.
pd.set_option("display.max_colwidth", None)
df

  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,sms,predictions,labels
0,sports fans - get the latest sports news str* 2 ur mobile 1 wk FREE PLUS a FREE TONE Txt SPORT ON to 8007 www.getzed.co.uk 0870141701216+ norm 4txt/120p \n,1,1
1,It's justbeen overa week since we broke up and already our brains are going to mush!\n,0,0
2,Not directly behind... Abt 4 rows behind ü...\n,0,0
3,Cramps stopped. Going back to sleep\n,0,0
4,Fancy a shag? I do.Interested? sextextuk.com txt XXUK SUZY to 69876. Txts cost 1.50 per msg. TnCs on website. X\n,1,1
5,I'm good. Have you registered to vote?\n,0,0
6,Desires- u going to doctor 4 liver. And get a bit stylish. Get ur hair managed. Thats it.\n,0,0
