In [1]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.45.1-py3-none-any.whl.metadata (44 kB)
     ---------------------------------------- 0.0/44.4 kB ? eta -:--:--
     ----------------- -------------------- 20.5/44.4 kB 330.3 kB/s eta 0:00:01
     -------------------------------------- 44.4/44.4 kB 550.5 kB/s eta 0:00:00
Collecting huggingface-hub<1.0,>=0.23.2 (from transformers)
  Downloading huggingface_hub-0.25.1-py3-none-any.whl.metadata (13 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp312-none-win_amd64.whl.metadata (3.9 kB)
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.0-cp312-none-win_amd64.whl.metadata (6.9 kB)
Downloading transformers-4.45.1-py3-none-any.whl (9.9 MB)
   ---------------------------------------- 0.0/9.9 MB ? eta -:--:--
    --------------------------------------- 0.2/9.9 MB 6.7 MB/s eta 0:00:02
   -- ------------------------------------- 0.5/9.9 MB 6.7 MB/s eta 0:00:02
   --- -----

In [57]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')

In [59]:
# --- Training hyperparameters -------------------------------------------
EPOCHS = 10
# NOTE(review): every message is padded to this length by the dataset; SMS
# texts are presumably much shorter than 512 tokens — confirm whether a
# smaller value would be enough.
MAX_LEN = 512
BATCH_SIZE = 16

# --- Pretrained checkpoint and its matching tokenizer -------------------
PT_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PT_MODEL_NAME)

# --- Compute device: use the GPU when one is visible, else the CPU ------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [25]:
# Download the SMS collection and keep only the message text and its label.
df = pd.read_csv(
    'https://raw.githubusercontent.com/animesharma3/SPAM-SMS-Detection/master/spam_sms_collection.csv'
).loc[:, ['msg', 'spam']]
# Preview the first rows.
df.head()

Unnamed: 0,msg,spam
0,go jurong point crazy available bugis n great ...,0
1,ok lar joking wif u oni,0
2,free entry wkly comp win fa cup final tkts st ...,1
3,u dun say early hor u c already say,0
4,nah think go usf life around though,0


In [27]:
# Sanity check: (number of rows, number of columns) of the loaded frame.
df.shape

(5572, 2)

In [48]:
class SMSCollectionDataset(Dataset):
    """Map-style dataset that tokenizes one SMS message per item.

    Parameters
    ----------
    spam : sequence
        Labels, positionally aligned with ``msgs``. Indexed with a plain
        integer position, so pass an array/list rather than a pandas Series
        with a non-default index.
    msgs : sequence
        Raw message texts, indexed the same way as ``spam``. Each entry is
        coerced with ``str()`` before tokenization.
    tokenizer : object
        Tokenizer exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
    max_len : int
        Length every encoded message is truncated or padded to.
    """

    def __init__(self, spam, msgs, tokenizer, max_len):
        self.msgs = msgs
        self.spam = spam
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of messages in the dataset."""
        return len(self.msgs)

    def __getitem__(self, i):
        """Return the ``i``-th example as a dict.

        Keys
        ----
        msg : str
            The message text that was tokenized.
        input_ids : torch.Tensor
            Token ids from the tokenizer, flattened to 1-D.
        attention_mask : torch.Tensor
            Attention mask from the tokenizer, flattened to 1-D.
        spam : torch.Tensor
            The label as a ``torch.long`` scalar tensor.
        """
        msg = str(self.msgs[i])
        spam = self.spam[i]

        encoding = self.tokenizer.encode_plus(
            msg,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            return_token_type_ids=False,
            # `pad_to_max_length=True` is deprecated; `padding='max_length'`
            # is the supported way to pad every sequence up to `max_length`.
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'msg': msg,
            # The tokenizer output is flattened so each item is 1-D and the
            # DataLoader can stack items into a batch.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'spam': torch.tensor(spam, dtype=torch.long)
        }

In [75]:
def create_data_loader(df, tokenizer, max_len, batch_size, shuffle=False, num_workers=0):
    """Wrap a DataFrame of messages in a ``DataLoader``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'msg'`` column (text) and a ``'spam'`` column (label).
    tokenizer : object
        Tokenizer forwarded to ``SMSCollectionDataset``.
    max_len : int
        Length every encoded message is truncated or padded to.
    batch_size : int
        Number of examples per batch.
    shuffle : bool, default False
        Reshuffle the examples every epoch. Pass ``True`` for the training
        loader; leave ``False`` for validation/test so evaluation order is
        stable.
    num_workers : int, default 0
        Number of worker processes. Defaults to loading in the main process:
        worker processes must be able to import the dataset class, which is
        not the case on spawn-based platforms (e.g. Windows) when the class
        is defined in a notebook cell — the first batch fetch then fails or
        hangs. Raise this only when the dataset class lives in an importable
        module.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    # Convert to NumPy so the dataset can index by integer position
    # regardless of the DataFrame's (shuffled) index labels.
    ds = SMSCollectionDataset(
        spam=df['spam'].to_numpy(),
        msgs=df['msg'].to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )

In [77]:
# Two-stage split: 80% train, then the remaining 20% is halved into
# validation and test. Both stages stratify on the label so every subset
# keeps the same spam/non-spam proportion as the full data (the classes look
# imbalanced in the preview, so an unstratified split could skew the small
# validation/test sets).
df_train, df_holdout = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['spam']
)
# A separate name for the intermediate hold-out avoids reusing `df_test` for
# two different things within the same cell.
df_val, df_test = train_test_split(
    df_holdout,
    test_size=0.5,
    random_state=42,
    stratify=df_holdout['spam']
)
df_train.shape, df_test.shape, df_val.shape

((4457, 2), (558, 2), (557, 2))

In [79]:
# Build one loader per split, all sharing the tokenizer, sequence length and
# batch size from the config cell.
train_data_loader, test_data_loader, val_data_loader = (
    create_data_loader(split_df, tokenizer, MAX_LEN, BATCH_SIZE)
    for split_df in (df_train, df_test, df_val)
)

In [None]:
# Pull a single batch from the training loader to inspect what it yields.
d = next(iter(train_data_loader))
# Field names available in the batch.
d.keys()

In [None]:
# Shapes of the tensor fields in the sampled batch, in a fixed order.
tuple(d[field].shape for field in ('input_ids', 'attention_mask', 'spam'))