# Bert to the rescue!
- based on https://towardsdatascience.com/bert-to-the-rescue-17671379687f
- but changed imdb dataset (not from pytorch-nlp, but from a file, imdb_master.csv)
- So preprossing is different from the original post

In [3]:
import warnings
warnings.filterwarnings(action='ignore')
import sys
import numpy as np
import random
import torch
from torch import nn
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [4]:
from transformers import BertTokenizer, BertModel

In [5]:
from keras.preprocessing.sequence import pad_sequences

Using TensorFlow backend.


In [6]:
import pandas as pd

In [7]:
from IPython.display import clear_output

In [8]:
random.seed(321)
np.random.seed(321)
torch.manual_seed(321)
torch.cuda.manual_seed(321)

## Prepare the Data

In [9]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [10]:
tokenizer.tokenize('Hi my name is Dima')

['hi', 'my', 'name', 'is', 'dim', '##a']

In [11]:
df = pd.read_csv('imdb_master.csv', encoding='latin-1')

In [12]:
df.head()

Unnamed: 0.1,Unnamed: 0,type,review,label,file
0,0,test,Once again Mr. Costner has dragged out a movie...,neg,0_2.txt
1,1,test,This is an example of why the majority of acti...,neg,10000_4.txt
2,2,test,"First of all I hate those moronic rappers, who...",neg,10001_1.txt
3,3,test,Not even the Beatles could write songs everyon...,neg,10002_3.txt
4,4,test,Brass pictures (movies is not a fitting word f...,neg,10003_3.txt


In [13]:
df.drop(['Unnamed: 0'], axis = 1, inplace = True)
df.drop(['file'], axis = 1, inplace = True)

In [14]:
df.head()

Unnamed: 0,type,review,label
0,test,Once again Mr. Costner has dragged out a movie...,neg
1,test,This is an example of why the majority of acti...,neg
2,test,"First of all I hate those moronic rappers, who...",neg
3,test,Not even the Beatles could write songs everyon...,neg
4,test,Brass pictures (movies is not a fitting word f...,neg


In [15]:
train_df = df[df['type'] == 'train'][:500].append(df[df['type'] == 'train'][-500:])
test_df = df[df['type'] =='test'][:50].append(df[df['type'] =='test'][-50:])

In [16]:
print(train_df.shape)
print(test_df.shape)

(1000, 3)
(100, 3)


In [17]:
train_texts = train_df['review'].tolist()
test_texts = test_df['review'].tolist()

In [18]:
train_labels = train_df['label']
test_labels = test_df['label']

In [19]:
train_texts[0]

"Story of a man who has unnatural feelings for a pig. Starts out with a opening scene that is a terrific example of absurd comedy. A formal orchestra audience is turned into an insane, violent mob by the crazy chantings of it's singers. Unfortunately it stays absurd the WHOLE time with no general narrative eventually making it just too off putting. Even those from the era should be turned off. The cryptic dialogue would make Shakespeare seem easy to a third grader. On a technical level it's better than you might think with some good cinematography by future great Vilmos Zsigmond. Future stars Sally Kirkland and Frederic Forrest can be seen briefly."

In [20]:
print(len(train_texts), len(test_texts))
print(len(train_labels), len(test_labels))

1000 100
1000 100


In [21]:
train_tokens = list(map(lambda t: ['[CLS]'] + tokenizer.tokenize(t)[:510] + ['[SEP]'], train_texts))
test_tokens = list(map(lambda t: ['[CLS]'] + tokenizer.tokenize(t)[:510] + ['[SEP]'], test_texts))

len(train_tokens), len(test_tokens)                   

(1000, 100)

In [22]:
train_tokens[0]

['[CLS]',
 'story',
 'of',
 'a',
 'man',
 'who',
 'has',
 'unnatural',
 'feelings',
 'for',
 'a',
 'pig',
 '.',
 'starts',
 'out',
 'with',
 'a',
 'opening',
 'scene',
 'that',
 'is',
 'a',
 'terrific',
 'example',
 'of',
 'absurd',
 'comedy',
 '.',
 'a',
 'formal',
 'orchestra',
 'audience',
 'is',
 'turned',
 'into',
 'an',
 'insane',
 ',',
 'violent',
 'mob',
 'by',
 'the',
 'crazy',
 'chanting',
 '##s',
 'of',
 'it',
 "'",
 's',
 'singers',
 '.',
 'unfortunately',
 'it',
 'stays',
 'absurd',
 'the',
 'whole',
 'time',
 'with',
 'no',
 'general',
 'narrative',
 'eventually',
 'making',
 'it',
 'just',
 'too',
 'off',
 'putting',
 '.',
 'even',
 'those',
 'from',
 'the',
 'era',
 'should',
 'be',
 'turned',
 'off',
 '.',
 'the',
 'cryptic',
 'dialogue',
 'would',
 'make',
 'shakespeare',
 'seem',
 'easy',
 'to',
 'a',
 'third',
 'grade',
 '##r',
 '.',
 'on',
 'a',
 'technical',
 'level',
 'it',
 "'",
 's',
 'better',
 'than',
 'you',
 'might',
 'think',
 'with',
 'some',
 'good',
 'c

In [23]:
train_tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, train_tokens)), maxlen=512, truncating="post", padding="post", dtype="int")
test_tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, test_tokens)), maxlen=512, truncating="post", padding="post", dtype="int")

train_tokens_ids.shape, test_tokens_ids.shape

((1000, 512), (100, 512))

In [24]:
train_tokens_ids[0]

array([  101,  2466,  1997,  1037,  2158,  2040,  2038, 21242,  5346,
        2005,  1037, 10369,  1012,  4627,  2041,  2007,  1037,  3098,
        3496,  2008,  2003,  1037, 27547,  2742,  1997, 18691,  4038,
        1012,  1037,  5337,  4032,  4378,  2003,  2357,  2046,  2019,
        9577,  1010,  6355, 11240,  2011,  1996,  4689, 22417,  2015,
        1997,  2009,  1005,  1055,  8453,  1012,  6854,  2009, 12237,
       18691,  1996,  2878,  2051,  2007,  2053,  2236,  7984,  2776,
        2437,  2009,  2074,  2205,  2125,  5128,  1012,  2130,  2216,
        2013,  1996,  3690,  2323,  2022,  2357,  2125,  1012,  1996,
       26483,  7982,  2052,  2191,  8101,  4025,  3733,  2000,  1037,
        2353,  3694,  2099,  1012,  2006,  1037,  4087,  2504,  2009,
        1005,  1055,  2488,  2084,  2017,  2453,  2228,  2007,  2070,
        2204, 16434,  2011,  2925,  2307,  6819, 13728,  2891,  1062,
        5332, 21693, 15422,  1012,  2925,  3340,  8836, 11332,  3122,
        1998, 15296,

Our target variable is currently a list of neg and pos strings. We’ll convert it to numpy arrays of booleans:

In [25]:
train_y = np.array(train_labels) != 'neg'
test_y = np.array(test_labels) != 'neg'
train_y.shape, test_y.shape, np.mean(train_y), np.mean(test_y)

((1000,), (100,), 0.5, 0.5)

In [26]:
train_y

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,

In [27]:
train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

# Baseline

In [28]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.metrics import classification_report

In [29]:
baseline_model = make_pipeline(CountVectorizer(ngram_range=(1,3)), LogisticRegression()).fit(train_texts, train_y)

In [30]:
baseline_predicted = baseline_model.predict(test_texts)

In [31]:
print(classification_report(test_y, baseline_predicted))

              precision    recall  f1-score   support

       False       0.88      0.58      0.70        50
        True       0.69      0.92      0.79        50

    accuracy                           0.75       100
   macro avg       0.78      0.75      0.74       100
weighted avg       0.78      0.75      0.74       100



# Bert Model

In [32]:
class BertBinaryClassifier(nn.Module):
    def __init__(self, dropout=0.1):
        super(BertBinaryClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')

        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, tokens, masks=None):
        _, pooled_output = self.bert(tokens, attention_mask=masks)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        proba = self.sigmoid(linear_output)
        return proba
        

In [33]:
bert_clf = BertBinaryClassifier()

In [34]:
train_tokens_ids[:3]

array([[ 101, 2466, 1997, ...,    0,    0,    0],
       [ 101, 3199, 1005, ..., 2004, 2172,  102],
       [ 101, 2023, 2143, ...,    0,    0,    0]])

In [35]:
x = torch.tensor(train_tokens_ids[:3])
y, pooled = bert_clf.bert(x)
x.shape, y.shape, pooled.shape

(torch.Size([3, 512]), torch.Size([3, 512, 768]), torch.Size([3, 768]))

In [36]:
y = bert_clf(x)
y.detach().numpy()

array([[0.6004767 ],
       [0.68869555],
       [0.62791777]], dtype=float32)

# Fine-tune BERT

In [45]:
# device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

'GeForce RTX 2070 SUPER'

In [46]:
BATCH_SIZE = 4
EPOCHS = 10

In [47]:
train_tokens_tensor = torch.tensor(train_tokens_ids)
train_y_tensor = torch.tensor(train_y.reshape(-1, 1)).float()

test_tokens_tensor = torch.tensor(test_tokens_ids)
test_y_tensor = torch.tensor(test_y.reshape(-1, 1)).float()

train_masks_tensor = torch.tensor(train_masks)
test_masks_tensor = torch.tensor(test_masks)

In [48]:
train_dataset = TensorDataset(train_tokens_tensor, train_masks_tensor, train_y_tensor)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)

test_dataset = TensorDataset(test_tokens_tensor, test_masks_tensor, test_y_tensor)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE)

In [49]:
param_optimizer = list(bert_clf.sigmoid.named_parameters()) 
optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

In [50]:
optimizer = Adam(bert_clf.parameters(), lr=3e-6)

In [52]:
bert_clf = bert_clf.to(device)
bert_clf = nn.DataParallel(bert_clf)
bert_clf.cuda()

DataParallel(
  (module): BertBinaryClassifier(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
      

In [None]:
for epoch_num in range(EPOCHS):
    bert_clf.train()
    train_loss = 0
    
    print('Epoch: ', epoch_num + 1)
    
    for step_num, batch_data in enumerate(train_dataloader):
        token_ids, masks, labels = tuple(t.to(device) for t in batch_data)
        logits = bert_clf(token_ids, masks)
        
        loss_func = nn.BCELoss()

        batch_loss = loss_func(logits, labels)
        train_loss += batch_loss.item()
        
        
        bert_clf.zero_grad()
        batch_loss.backward()
        

        clip_grad_norm_(parameters=bert_clf.parameters(), max_norm=1.0)
        optimizer.step()
        
        print("\r" + "{0}/{1} loss: {2} ".format(step_num, len(train_df) / BATCH_SIZE, train_loss / (step_num + 1)))
        

Epoch:  1
0/250.0 loss: 0.9797638058662415 
1/250.0 loss: 0.7473504841327667 
2/250.0 loss: 0.682931919892629 
3/250.0 loss: 0.7297441512346268 
4/250.0 loss: 0.7674580454826355 
5/250.0 loss: 0.7653777996699015 
6/250.0 loss: 0.7358746613774981 
7/250.0 loss: 0.7303956672549248 
8/250.0 loss: 0.708951559331682 
9/250.0 loss: 0.6826391637325286 
10/250.0 loss: 0.6871807304295626 
11/250.0 loss: 0.6901405900716782 
12/250.0 loss: 0.7011182078948388 
13/250.0 loss: 0.6919980006558555 
14/250.0 loss: 0.6885264158248902 
15/250.0 loss: 0.7005675919353962 
16/250.0 loss: 0.6987220574827755 
17/250.0 loss: 0.7049259576532576 
18/250.0 loss: 0.705495464174371 
19/250.0 loss: 0.7001570165157318 
20/250.0 loss: 0.702959574404217 
21/250.0 loss: 0.7088658836754885 
22/250.0 loss: 0.7040618113849474 
23/250.0 loss: 0.701962985098362 
24/250.0 loss: 0.7022914433479309 
25/250.0 loss: 0.7020295147712414 
26/250.0 loss: 0.6958229563854359 
27/250.0 loss: 0.6935213484934398 
28/250.0 loss: 0.69284745

232/250.0 loss: 0.6622553269494756 
233/250.0 loss: 0.6621223979780817 
234/250.0 loss: 0.6626566959188339 
235/250.0 loss: 0.6624896350806042 
236/250.0 loss: 0.6618388248646813 
237/250.0 loss: 0.6616936793096927 
238/250.0 loss: 0.6612029022003317 
239/250.0 loss: 0.6602355220665534 
240/250.0 loss: 0.6609301847788308 
241/250.0 loss: 0.6607732168160194 
242/250.0 loss: 0.6596623335355594 
243/250.0 loss: 0.6601774841547012 
244/250.0 loss: 0.6598543461488218 
245/250.0 loss: 0.6591818075839097 
246/250.0 loss: 0.6581241733390792 
247/250.0 loss: 0.6575413155219247 
248/250.0 loss: 0.657438865986215 
249/250.0 loss: 0.6575364841222763 
Epoch:  2
0/250.0 loss: 0.4162587523460388 
1/250.0 loss: 0.3888309895992279 
2/250.0 loss: 0.39116912086804706 
3/250.0 loss: 0.40637245774269104 
4/250.0 loss: 0.46193733215332033 
5/250.0 loss: 0.47094933191935223 
6/250.0 loss: 0.449678157057081 
7/250.0 loss: 0.4785948842763901 
8/250.0 loss: 0.4769847293694814 
9/250.0 loss: 0.4826521843671799 


214/250.0 loss: 0.5217100960570712 
215/250.0 loss: 0.5213259902127363 
216/250.0 loss: 0.5233627933373649 
217/250.0 loss: 0.5233349968944121 
218/250.0 loss: 0.5220680899543849 
219/250.0 loss: 0.5243619261817498 
220/250.0 loss: 0.5284934015565328 
221/250.0 loss: 0.527969623739655 
222/250.0 loss: 0.5288959851179422 
223/250.0 loss: 0.528619778342545 
224/250.0 loss: 0.5270292886098226 
225/250.0 loss: 0.5258541916851449 
226/250.0 loss: 0.5263085633122448 
227/250.0 loss: 0.5268909703744086 
228/250.0 loss: 0.5254884550925426 
229/250.0 loss: 0.5242591704363408 
230/250.0 loss: 0.5229185131334123 
231/250.0 loss: 0.5219710999136341 
232/250.0 loss: 0.5232894549170277 
233/250.0 loss: 0.5226701870560646 
234/250.0 loss: 0.5210972704785936 
235/250.0 loss: 0.5197333320975304 
236/250.0 loss: 0.5197779182894824 
237/250.0 loss: 0.518574356907556 
238/250.0 loss: 0.5212759965383856 
239/250.0 loss: 0.5203671641647816 
240/250.0 loss: 0.5193378325567206 
241/250.0 loss: 0.5196353727874

193/250.0 loss: 0.41269460086201887 
194/250.0 loss: 0.4131621640080061 
195/250.0 loss: 0.4115607668550647 
196/250.0 loss: 0.4105325787503102 
197/250.0 loss: 0.4118577502291612 
198/250.0 loss: 0.41175394845967317 
199/250.0 loss: 0.4126826174557209 
200/250.0 loss: 0.4163923413302768 
201/250.0 loss: 0.4175869838731124 
202/250.0 loss: 0.4160731315906412 
203/250.0 loss: 0.4185016875465711 
204/250.0 loss: 0.4174504034402894 
205/250.0 loss: 0.4183234327915803 
206/250.0 loss: 0.4186830650205198 
207/250.0 loss: 0.41965197026729584 
208/250.0 loss: 0.4197934520872016 
209/250.0 loss: 0.41977596084276836 
210/250.0 loss: 0.4240586172912923 
211/250.0 loss: 0.42270343033772595 
212/250.0 loss: 0.4224900485764087 
213/250.0 loss: 0.4210998140902163 
214/250.0 loss: 0.4198321484549101 
215/250.0 loss: 0.41900744544411145 
216/250.0 loss: 0.41907467634721834 
217/250.0 loss: 0.4180545445292368 
218/250.0 loss: 0.41718613291712114 
219/250.0 loss: 0.41619388142769986 
220/250.0 loss: 0.4

171/250.0 loss: 0.33398375339632813 
172/250.0 loss: 0.3348943983027012 
173/250.0 loss: 0.3363638262467823 
174/250.0 loss: 0.335765489254679 
175/250.0 loss: 0.33710010967810045 
176/250.0 loss: 0.3363280056391732 
177/250.0 loss: 0.33487767554568443 
178/250.0 loss: 0.33380285600543685 
179/250.0 loss: 0.33232608830763233 
180/250.0 loss: 0.3312934771674114 
181/250.0 loss: 0.32987504579372456 
182/250.0 loss: 0.328988796822686 
183/250.0 loss: 0.3285315630795515 
184/250.0 loss: 0.32867983978342363 
185/250.0 loss: 0.3292953318725991 
186/250.0 loss: 0.3307247293744495 
187/250.0 loss: 0.3295794382215814 
188/250.0 loss: 0.32859430418758795 
189/250.0 loss: 0.3277135041199232 
190/250.0 loss: 0.3268708809820145 
191/250.0 loss: 0.3269809697133799 
192/250.0 loss: 0.32614469165320226 
193/250.0 loss: 0.3262479358266309 
194/250.0 loss: 0.32604522896118654 
195/250.0 loss: 0.3247729363490124 
196/250.0 loss: 0.32355110907010015 
197/250.0 loss: 0.3255759081003642 
198/250.0 loss: 0.3

149/250.0 loss: 0.2778594579299291 
150/250.0 loss: 0.2764188405989811 
151/250.0 loss: 0.27578423226154164 
152/250.0 loss: 0.27709338531579847 
153/250.0 loss: 0.2756300452467683 
154/250.0 loss: 0.2784763789946033 
155/250.0 loss: 0.28111778390713227 
156/250.0 loss: 0.27966238487108497 
157/250.0 loss: 0.27821058808248256 
158/250.0 loss: 0.2767484505315247 
159/250.0 loss: 0.2779676512349397 
160/250.0 loss: 0.27661622584051226 
161/250.0 loss: 0.28171567319903845 
162/250.0 loss: 0.2831914546939493 
163/250.0 loss: 0.2817991009300075 
164/250.0 loss: 0.280374255934448 
165/250.0 loss: 0.2817937062761511 
166/250.0 loss: 0.28047936710858057 
167/250.0 loss: 0.2792643685381682 
168/250.0 loss: 0.2819688063357356 
169/250.0 loss: 0.2806292149512207 
170/250.0 loss: 0.28250250982785086 
171/250.0 loss: 0.2812374587235756 
172/250.0 loss: 0.2798831544687293 
173/250.0 loss: 0.2786172437428058 
174/250.0 loss: 0.27740669344152724 
175/250.0 loss: 0.27917600550096144 
176/250.0 loss: 0.

126/250.0 loss: 0.1975181594373673 
127/250.0 loss: 0.19621287191694137 
128/250.0 loss: 0.19500945668754188 
129/250.0 loss: 0.19665306312246963 
130/250.0 loss: 0.19605748132155595 
131/250.0 loss: 0.19934472276575185 
132/250.0 loss: 0.19872548474573104 
133/250.0 loss: 0.20212131420345003 
134/250.0 loss: 0.20135447784430452 
135/250.0 loss: 0.20023297643124618 
136/250.0 loss: 0.19929145095720344 
137/250.0 loss: 0.1984501311933433 
138/250.0 loss: 0.19717545647897738 
139/250.0 loss: 0.1965162896817284 
140/250.0 loss: 0.1953584531071127 
141/250.0 loss: 0.19614780939538295 
142/250.0 loss: 0.19533670786116925 
143/250.0 loss: 0.19417107556687874 
144/250.0 loss: 0.19295639304508422 
145/250.0 loss: 0.19179633947421018 
146/250.0 loss: 0.19443187212609514 
147/250.0 loss: 0.19337158299026055 
148/250.0 loss: 0.19289061803395716 
149/250.0 loss: 0.1917723848298192 
150/250.0 loss: 0.190794910557124 
151/250.0 loss: 0.18978933503508175 
152/250.0 loss: 0.18863129041772458 
153/250.

In [43]:
bert_clf.eval()
bert_predicted = []
all_logits = []
with torch.no_grad():
    for step_num, batch_data in enumerate(test_dataloader):

        token_ids, masks, labels = tuple(t.to(device) for t in batch_data)

        logits = bert_clf(token_ids, masks)
        loss_func = nn.BCELoss()
        loss = loss_func(logits, labels)
        numpy_logits = logits.cpu().detach().numpy()
        
        bert_predicted += list(numpy_logits[:, 0] > 0.5)
        all_logits += list(numpy_logits[:, 0])
    

In [44]:
np.mean(bert_predicted)

0.62

In [45]:
print(classification_report(test_y, bert_predicted))

              precision    recall  f1-score   support

       False       0.89      0.68      0.77        50
        True       0.74      0.92      0.82        50

   micro avg       0.80      0.80      0.80       100
   macro avg       0.82      0.80      0.80       100
weighted avg       0.82      0.80      0.80       100

