<a href="https://colab.research.google.com/github/danasaur/nlp/blob/main/fine_tuning_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

from tabulate import tabulate
from tqdm import trange

In [None]:
!wget 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'

--2022-12-10 02:37:33--  https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 203415 (199K) [application/x-httpd-php]
Saving to: ‘smsspamcollection.zip.1’


2022-12-10 02:37:34 (304 KB/s) - ‘smsspamcollection.zip.1’ saved [203415/203415]



In [None]:
!unzip -o smsspamcollection.zip

Archive:  smsspamcollection.zip
  inflating: SMSSpamCollection       
  inflating: readme                  


In [None]:
!head -10 SMSSpamCollection

ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham	U dun say so early hor... U c already then say...
ham	Nah I don't think he goes to usf, he lives around here though
spam	FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham	Even my brother is not like to speak with me. They treat me like aids patent.
ham	As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam	WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
spam	H

In [None]:
file_path = '/content/SMSSpamCollection'
df = pd.DataFrame({'label':int(), 'text':str()}, index = [])
with open(file_path) as f:
  for line in f.readlines():
    split = line.split('\t')
    df = df.append({'label': 1 if split[0] == 'spam' else 0,
                    'text': split[1]},
                    ignore_index = True)
df.head()

Unnamed: 0,label,text
0,0,"Go until jurong point, crazy.. Available only ..."
1,0,Ok lar... Joking wif u oni...\n
2,1,Free entry in 2 a wkly comp to win FA Cup fina...
3,0,U dun say so early hor... U c already then say...
4,0,"Nah I don't think he goes to usf, he lives aro..."


In [None]:
text = df.text.values
labels = df.label.values

In [None]:
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case = True
    )

In [None]:
import random

def print_rand_sentence():
  '''Displays the tokens and respective IDs of a random text sample'''
  index = random.randint(0, len(text)-1)
  table = np.array([tokenizer.tokenize(text[index]), 
                    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text[index]))]).T
  print(tabulate(table,
                 headers = ['Tokens', 'Token IDs'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence()

╒══════════╤═════════════╕
│ Tokens   │   Token IDs │
╞══════════╪═════════════╡
│ wrong    │        3308 │
├──────────┼─────────────┤
│ phone    │        3042 │
├──────────┼─────────────┤
│ !        │         999 │
├──────────┼─────────────┤
│ this     │        2023 │
├──────────┼─────────────┤
│ phone    │        3042 │
├──────────┼─────────────┤
│ !        │         999 │
├──────────┼─────────────┤
│ i        │        1045 │
├──────────┼─────────────┤
│ answer   │        3437 │
├──────────┼─────────────┤
│ this     │        2023 │
├──────────┼─────────────┤
│ one      │        2028 │
├──────────┼─────────────┤
│ but      │        2021 │
├──────────┼─────────────┤
│ assume   │        7868 │
├──────────┼─────────────┤
│ the      │        1996 │
├──────────┼─────────────┤
│ other    │        2060 │
├──────────┼─────────────┤
│ is       │        2003 │
├──────────┼─────────────┤
│ people   │        2111 │
├──────────┼─────────────┤
│ i        │        1045 │
├──────────┼─────────────┤
│

In [None]:
print_rand_sentence()

╒══════════╤═════════════╕
│ Tokens   │   Token IDs │
╞══════════╪═════════════╡
│ u        │        1057 │
├──────────┼─────────────┤
│ have     │        2031 │
├──────────┼─────────────┤
│ won      │        2180 │
├──────────┼─────────────┤
│ a        │        1037 │
├──────────┼─────────────┤
│ nokia    │       22098 │
├──────────┼─────────────┤
│ 62       │        5786 │
├──────────┼─────────────┤
│ ##30     │       14142 │
├──────────┼─────────────┤
│ plus     │        4606 │
├──────────┼─────────────┤
│ a        │        1037 │
├──────────┼─────────────┤
│ free     │        2489 │
├──────────┼─────────────┤
│ digital  │        3617 │
├──────────┼─────────────┤
│ camera   │        4950 │
├──────────┼─────────────┤
│ .        │        1012 │
├──────────┼─────────────┤
│ this     │        2023 │
├──────────┼─────────────┤
│ is       │        2003 │
├──────────┼─────────────┤
│ what     │        2054 │
├──────────┼─────────────┤
│ u        │        1057 │
├──────────┼─────────────┤
│

In [None]:
val_ratio = 0.2
# Recommended batch size: 16, 32. See: https://arxiv.org/pdf/1810.04805.pdf
batch_size = 16

# Indices of the train and validation splits stratified by labels
train_idx, val_idx = train_test_split(
    np.arange(len(labels)),
    test_size = val_ratio,
    shuffle = True,
    stratify = labels)

# Train and validation sets
train_set = TensorDataset(token_id[train_idx], 
                          attention_masks[train_idx], 
                          labels[train_idx])

val_set = TensorDataset(token_id[val_idx], 
                        attention_masks[val_idx], 
                        labels[val_idx])

# Prepare DataLoader
train_dataloader = DataLoader(
            train_set,
            sampler = RandomSampler(train_set),
            batch_size = batch_size
        )

validation_dataloader = DataLoader(
            val_set,
            sampler = SequentialSampler(val_set),
            batch_size = batch_size
        )

NameError: ignored