In [1]:
import pandas as pd
import numpy as np
import transformers
import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report

BATCH_SIZE = 32
MAX_LEN = 128
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
class ReviewDataset(Dataset):
    def __init__(self, reviews, targets, tokenizer, max_len=128):
        self.reviews = reviews
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        target = self.targets[idx]
        encoding = self.tokenizer.encode_plus(
          review,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          pad_to_max_length=True,
          return_attention_mask=True,
          return_tensors='pt',
        )
        return {
          'review_text': review,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'targets': torch.tensor(target, dtype=torch.long)
        }

def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = ReviewDataset(
        reviews=df.sentence.to_numpy(),
        targets=df.has_spoiler.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True
    )


In [4]:
df = pd.read_pickle('data/goodreads_sent_spoil_titles.pkl')
df.shape

(3534334, 2)

In [5]:
df.head()

Unnamed: 0,sentence,has_spoiler
0,The Maltese Falcon [SEP] Essential reading for...,0
1,The Maltese Falcon [SEP] Hammett is the grandf...,0
2,The Maltese Falcon [SEP] While his Continental...,0
3,"The Maltese Falcon [SEP] Expect sharp dialog, ...",0
4,The Maltese Falcon [SEP] If you've seen John H...,0


In [6]:
df.has_spoiler.value_counts()

0    3419580
1     114754
Name: has_spoiler, dtype: int64

In [7]:
# let's downsample 0 first
df = pd.concat([df[df.has_spoiler==1], df[df.has_spoiler==0].sample(len(df[df.has_spoiler==1]))]).sample(frac=1)

In [8]:
df_train, df_test = train_test_split(
  df,
  test_size=0.2,
  random_state=42
)
df_val, df_test = train_test_split(
  df_test,
  test_size=0.5,
  random_state=42
)

In [9]:
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

In [10]:
ni = next(iter(train_data_loader))
ni['review_text'][8], ni['input_ids'][8]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


("Days of Blood & Starlight (Daughter of Smoke & Bone, #2) [SEP] :lol: Also I'm unapologetically on Team Ziri.",
 tensor([  101,  6637,  1104,  5657,   111,  2537,  4568,   113, 16039,  1104,
         19440,   111, 17722,   117,   108,   123,   114,   102,   131, 25338,
          1233,   131,  2907,   146,   112,   182,  8362, 11478, 12805, 16609,
          9203,  1113,  2649,   163, 17262,   119,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,   

In [11]:
tokenizer.encode_plus(
          'Diabolic',
          add_special_tokens=True,
          max_length=MAX_LEN,
          return_token_type_ids=False,
          pad_to_max_length=True,
          return_attention_mask=True,
          return_tensors='pt',
        )

{'input_ids': tensor([[  101, 12120,  6639, 14987,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

# LSTM Model Test

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMSpoilerClassifier(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(LSTMSpoilerClassifier, self).__init__()
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout1 = nn.Dropout2d(0.25)

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, dropout=0.1, batch_first=True, num_layers=2)
        self.dropout2 = nn.Dropout(0.4)

        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, sentence):
        embeds = self.embedding(sentence)
        embeds = embeds.unsqueeze(2)
        embeds = embeds.permute(0, 3, 2, 1)
        embeds = self.dropout1(embeds)
        embeds = embeds.permute(0, 3, 2, 1)
        embeds = embeds.squeeze(2)
        
        out, _ = self.lstm(embeds)
        out = self.dropout2(out)

        return self.linear(out[:,-1,:])


In [13]:
from torch import nn, optim


def train_epoch(
    model,
    data_loader,
    loss_fn,
    optimizer,
    device,
    n_examples,
    scheduler=None,
    model_type='transformer'
):
    model = model.train()
    losses = []
    avg_losses = []
    aurocs = []
    avg_aurocs = []
    correct_predictions = 0
    i = 0
    t0 = time()
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        targets = d["targets"].to(device).view(-1, 1)
        outputs = torch.zeros_like(targets)
        
        if model_type == 'LSTM':
            outputs = model(input_ids)
        else:
            attention_mask = d["attention_mask"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        preds = torch.zeros_like(outputs)
        ones = torch.ones_like(preds)
        preds = torch.where(outputs < 0, preds, ones)
        
        loss = loss_fn(outputs, targets.float())
        
        correct_predictions += torch.sum(preds == targets)
        aurocs.append(roc_auc_score(targets.cpu(), preds.cpu()))
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler:
            scheduler.step()
        optimizer.zero_grad()
        i += 1
        if i % 100 == 0:
            avg_aurocs.append(np.mean(aurocs[i-100:i]))
            avg_losses.append(np.mean(losses[i-100:i]))
            print(i, 'iters, auroc, loss, time : ', avg_aurocs[-1], avg_losses[-1], time()-t0)

    return correct_predictions.double() / n_examples, np.mean(losses), avg_losses, avg_aurocs


In [14]:
def eval_model(model, data_loader, loss_fn, device, n_examples, model_type='transformer'):
    model = model.eval()
    losses = []
    aurocs = []
    correct_predictions = 0
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            targets = d["targets"].to(device).view(-1, 1)
            outputs = torch.zeros_like(targets)
            
            if model_type == 'LSTM':
                outputs = model(input_ids)
            else:
                attention_mask = d["attention_mask"].to(device)
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

            preds = torch.zeros_like(outputs)
            ones = torch.ones_like(preds)
            preds = torch.where(outputs < 0, preds, ones)

            loss = loss_fn(outputs, targets.float())
            
            correct_predictions += torch.sum(preds == targets)
            aurocs.append(roc_auc_score(targets.cpu(), preds.cpu()))
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses), np.mean(aurocs)


In [15]:
tokenizer.vocab_size

28996

In [16]:
EPOCHS = 20
model = LSTMSpoilerClassifier(128, 128, tokenizer.vocab_size + 1)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.003)

loss_fn = nn.BCEWithLogitsLoss().to(device)

In [17]:
import warnings
warnings.filterwarnings('ignore')

In [18]:
%%time
from collections import defaultdict
from time import time

history = defaultdict(list)
best_auroc = 0
for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)
  train_acc, train_loss, train_avg_losses, train_auroc = train_epoch(
    model,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    len(df_train),
    scheduler=None,
    model_type='LSTM'
  )
  print(f'Train loss {train_loss} accuracy {train_acc}  auroc {np.mean(train_auroc)}')
  val_acc, val_loss, val_auroc = eval_model(
    model,
    val_data_loader,
    loss_fn,
    device,
    len(df_val),
    model_type='LSTM'
  )
  print(f'Val   loss {val_loss} accuracy {val_acc} auroc {val_auroc}')
  print()
  history['train_auroc'] += train_auroc
  history['train_loss'] += train_avg_losses
  history['val_auroc'].append(val_auroc)
  history['val_loss'].append(val_loss)
  if val_auroc > best_auroc:
    torch.save(model.state_dict(), 'lstm_best_model_state.bin')
    best_auroc = val_auroc

Epoch 1/20
----------
100 iters, auroc, loss, time :  0.5056341314129628 0.696103281378746 3.507298231124878
200 iters, auroc, loss, time :  0.4985214590228723 0.6942408961057663 6.253425121307373
300 iters, auroc, loss, time :  0.49380076796698313 0.6936281329393387 8.991044282913208
400 iters, auroc, loss, time :  0.49794757845783394 0.693161916732788 11.754669427871704
500 iters, auroc, loss, time :  0.49651023298241964 0.6936494219303131 14.533803462982178
600 iters, auroc, loss, time :  0.5053956912895363 0.6942979383468628 17.326940536499023
700 iters, auroc, loss, time :  0.5108229292637188 0.6932029539346695 20.08956551551819
800 iters, auroc, loss, time :  0.5001433792021606 0.6931630688905716 22.83718729019165
900 iters, auroc, loss, time :  0.5067665248798627 0.6931708770990371 25.575806856155396
1000 iters, auroc, loss, time :  0.4883055900004783 0.6936912286281586 28.301424026489258
1100 iters, auroc, loss, time :  0.48764955546689287 0.694076127409935 31.066049098968506
1

3400 iters, auroc, loss, time :  0.6566006480554114 0.6304172819852829 94.52043604850769
3500 iters, auroc, loss, time :  0.6543334544282686 0.6268904867768288 97.29306364059448
3600 iters, auroc, loss, time :  0.6352257231932155 0.637335154414177 100.11320662498474
3700 iters, auroc, loss, time :  0.650905611686701 0.6318503403663636 102.91684126853943
3800 iters, auroc, loss, time :  0.6529454532070731 0.6279098537564277 105.72298073768616
3900 iters, auroc, loss, time :  0.6545552689288472 0.6292661184072494 108.51211643218994
4000 iters, auroc, loss, time :  0.6814491877067459 0.6043163549900055 111.26874017715454
4100 iters, auroc, loss, time :  0.6678388374275102 0.6223081415891647 114.03087043762207
4200 iters, auroc, loss, time :  0.6565683356984172 0.6380071723461151 116.81249928474426
4300 iters, auroc, loss, time :  0.6545247358364702 0.6184215167164803 119.57963061332703
4400 iters, auroc, loss, time :  0.6749939007500927 0.6226626524329185 122.4002685546875
4500 iters, aur

800 iters, auroc, loss, time :  0.7116963112669551 0.570658663213253 22.25502610206604
900 iters, auroc, loss, time :  0.721828470755998 0.5481592100858689 25.036160230636597
1000 iters, auroc, loss, time :  0.7273831457742986 0.5507377234101295 27.79778504371643
1100 iters, auroc, loss, time :  0.7138345414303618 0.5656679317355156 30.57191753387451
1200 iters, auroc, loss, time :  0.7277759272489301 0.5509052482247353 33.38555383682251
1300 iters, auroc, loss, time :  0.7124267629487292 0.5640265908837319 36.1676881313324
1400 iters, auroc, loss, time :  0.7179312770581447 0.5605403411388398 38.951823234558105
1500 iters, auroc, loss, time :  0.7253955360253482 0.5606430351734162 41.7044460773468
1600 iters, auroc, loss, time :  0.7344936501304367 0.5459060269594193 44.45257234573364
1700 iters, auroc, loss, time :  0.7134556394070324 0.5604409539699554 47.22119903564453
1800 iters, auroc, loss, time :  0.7064094215561438 0.575375318825245 49.973326206207275
1900 iters, auroc, loss, 

4200 iters, auroc, loss, time :  0.730439324221605 0.5385742071270943 116.73260855674744
4300 iters, auroc, loss, time :  0.7333823236727228 0.5318632617592811 119.52776336669922
4400 iters, auroc, loss, time :  0.7297629511320032 0.5317714804410935 122.3193953037262
4500 iters, auroc, loss, time :  0.7490791057454758 0.516124676167965 125.09554672241211
4600 iters, auroc, loss, time :  0.7341180840825393 0.5374943307042122 127.87317514419556
4700 iters, auroc, loss, time :  0.7373894110369072 0.527716127038002 130.6523277759552
4800 iters, auroc, loss, time :  0.7276321463110483 0.548678049147129 133.41095185279846
4900 iters, auroc, loss, time :  0.7426362293966784 0.5274134770035743 136.18510246276855
5000 iters, auroc, loss, time :  0.7236392525864528 0.5443230420351028 138.96773219108582
5100 iters, auroc, loss, time :  0.764455139480328 0.5043703603744507 141.78689336776733
5200 iters, auroc, loss, time :  0.7423453273450648 0.5170297753810883 144.56352138519287
5300 iters, auroc

1600 iters, auroc, loss, time :  0.7544945721334199 0.5055798670649528 44.679614782333374
1700 iters, auroc, loss, time :  0.7698560444132525 0.4982479766011238 47.46774983406067
1800 iters, auroc, loss, time :  0.7504678608851327 0.5202221572399139 50.27338457107544
1900 iters, auroc, loss, time :  0.7590740870125074 0.500416105389595 53.10102438926697
2000 iters, auroc, loss, time :  0.7475368383980999 0.5104706993699074 55.879653215408325
2100 iters, auroc, loss, time :  0.7707958095844419 0.48109748303890226 58.651784896850586
2200 iters, auroc, loss, time :  0.7585346767740344 0.5040212634205818 61.42291712760925
2300 iters, auroc, loss, time :  0.7559793000911086 0.5104172122478485 64.18754243850708
2400 iters, auroc, loss, time :  0.7543966356738937 0.5032137182354927 66.94567203521729
2500 iters, auroc, loss, time :  0.7581998187966889 0.5056246575713158 69.69329380989075
2600 iters, auroc, loss, time :  0.7647438026603375 0.5041597226262092 72.489431142807
2700 iters, auroc, l

4900 iters, auroc, loss, time :  0.7727297768815218 0.4819739231467247 136.28092169761658
5000 iters, auroc, loss, time :  0.7668197719500554 0.49075698494911196 139.08955717086792
5100 iters, auroc, loss, time :  0.7675610246456255 0.49044061452150345 141.8736915588379
5200 iters, auroc, loss, time :  0.7655894164259187 0.48632364988327026 144.66382765769958
5300 iters, auroc, loss, time :  0.7541937332251288 0.48752554386854174 147.4484577178955
5400 iters, auroc, loss, time :  0.7625071751602073 0.4954091501235962 150.209082365036
5500 iters, auroc, loss, time :  0.7610528781718796 0.4994657826423645 152.97921347618103
5600 iters, auroc, loss, time :  0.7578338410600408 0.4973536518216133 155.74783992767334
5700 iters, auroc, loss, time :  0.7702590830950805 0.48913665741682055 158.58248138427734
Train loss 0.49029089425001965 accuracy 0.7655741097785476  auroc 0.7649721325072737
Val   loss 0.5216033898273218 accuracy 0.7452398588296807 auroc 0.7459701526937509

Epoch 9/20
---------

2300 iters, auroc, loss, time :  0.7752828435324566 0.481585351228714 64.25910806655884
2400 iters, auroc, loss, time :  0.7734180049181983 0.47203361421823503 67.06874346733093
2500 iters, auroc, loss, time :  0.7754960268274896 0.47819055020809176 69.87637901306152
2600 iters, auroc, loss, time :  0.7865934068221978 0.4669159197807312 72.65400743484497
2700 iters, auroc, loss, time :  0.7799589792909639 0.4725782346725464 75.4496397972107
2800 iters, auroc, loss, time :  0.7748902713017752 0.4826607382297516 78.22277188301086
2900 iters, auroc, loss, time :  0.79069962790795 0.46090446710586547 80.98490118980408
3000 iters, auroc, loss, time :  0.7744774153593897 0.47144490242004394 83.76753044128418
3100 iters, auroc, loss, time :  0.7850490996051704 0.46873598247766496 86.55266547203064
3200 iters, auroc, loss, time :  0.7654094965753498 0.49516832739114763 89.3432970046997
3300 iters, auroc, loss, time :  0.7763123237934871 0.47525364011526106 92.15193247795105
3400 iters, auroc, 

5600 iters, auroc, loss, time :  0.7704038428807324 0.4843812514841557 155.7987139225006
5700 iters, auroc, loss, time :  0.7725924529388065 0.48063719511032105 158.6118552684784
Train loss 0.4667291696601813 accuracy 0.7818535341982288  auroc 0.7819914020290765
Val   loss 0.5077828108060659 accuracy 0.7475926974859484 auroc 0.747875504719408

Epoch 12/20
----------
100 iters, auroc, loss, time :  0.7841027958623614 0.4489326500892639 2.835632801055908
200 iters, auroc, loss, time :  0.7865326532912431 0.4589266939461231 5.6142613887786865
300 iters, auroc, loss, time :  0.7772115244894051 0.4675405931472778 8.375390768051147
400 iters, auroc, loss, time :  0.8023149834814309 0.4352897065877914 11.153019189834595
500 iters, auroc, loss, time :  0.775026567854716 0.4701409015059471 13.911147594451904
600 iters, auroc, loss, time :  0.773406824599549 0.47331968814134595 16.67777371406555
700 iters, auroc, loss, time :  0.7884744392109365 0.4494171196222305 19.477406978607178
800 iters, a

3000 iters, auroc, loss, time :  0.7962792274337738 0.45066003665328025 83.36791276931763
3100 iters, auroc, loss, time :  0.7793702306932384 0.4680487781763077 86.17305207252502
3200 iters, auroc, loss, time :  0.7892708222234238 0.45946107387542723 88.99419474601746
3300 iters, auroc, loss, time :  0.7862306597873717 0.4658225938677788 91.76282119750977
3400 iters, auroc, loss, time :  0.7857787368813414 0.4519660350680351 94.5314474105835
3500 iters, auroc, loss, time :  0.7808074095863081 0.4578387124836445 97.28557515144348
3600 iters, auroc, loss, time :  0.797928687985038 0.44269969761371614 100.04370474815369
3700 iters, auroc, loss, time :  0.7758731432984615 0.4703147625923157 102.81633186340332
3800 iters, auroc, loss, time :  0.8012944307793584 0.44420813769102097 105.59846639633179
3900 iters, auroc, loss, time :  0.783274510931068 0.4596172997355461 108.40010046958923
4000 iters, auroc, loss, time :  0.8018406605257039 0.43194596529006957 111.22624468803406
4100 iters, au

400 iters, auroc, loss, time :  0.7892470911889743 0.4431622962653637 11.302062034606934
500 iters, auroc, loss, time :  0.8042337432595896 0.4203598964214325 14.064687490463257
600 iters, auroc, loss, time :  0.7931241231954956 0.44568360134959223 16.857318878173828
700 iters, auroc, loss, time :  0.797674224299317 0.43591583013534546 19.656952619552612
800 iters, auroc, loss, time :  0.807122915139556 0.42743513360619545 22.464587688446045
900 iters, auroc, loss, time :  0.8038463044187232 0.4376411381363869 25.260220050811768
1000 iters, auroc, loss, time :  0.7834755824422284 0.4538119772076607 28.029846668243408
1100 iters, auroc, loss, time :  0.8030519142911575 0.44217077881097794 30.793472051620483
1200 iters, auroc, loss, time :  0.7996739982396867 0.4316512405872345 33.5550971031189
1300 iters, auroc, loss, time :  0.7986117515273271 0.4485708749294281 36.31272077560425
1400 iters, auroc, loss, time :  0.7982981983588643 0.44966991633176806 39.09735083580017
1500 iters, auroc

3700 iters, auroc, loss, time :  0.7882604266041388 0.4542626678943634 102.75985932350159
3800 iters, auroc, loss, time :  0.8021246220240026 0.4293617004156113 105.5674946308136
3900 iters, auroc, loss, time :  0.8035624445587765 0.4471716710925102 108.35412526130676
4000 iters, auroc, loss, time :  0.7943692383281156 0.43661881059408186 111.14175581932068
4100 iters, auroc, loss, time :  0.7974981262315781 0.4378155091404915 113.90538120269775
4200 iters, auroc, loss, time :  0.7907345340203312 0.4477998712658882 116.66500568389893
4300 iters, auroc, loss, time :  0.7970311904512755 0.4491949412226677 119.4381377696991
4400 iters, auroc, loss, time :  0.8041959265536666 0.4475304937362671 122.21876692771912
4500 iters, auroc, loss, time :  0.7898168728285668 0.4511301609873772 124.97139859199524
4600 iters, auroc, loss, time :  0.7829063309533006 0.4549855978786945 127.77503323554993
4700 iters, auroc, loss, time :  0.7920779051996747 0.445872161090374 130.5726659297943
4800 iters, a

1100 iters, auroc, loss, time :  0.7996224810449268 0.43709814205765724 30.56490683555603
1200 iters, auroc, loss, time :  0.8101314484906742 0.4242499741911888 33.364540100097656
1300 iters, auroc, loss, time :  0.7941471906146828 0.44131165266036987 36.17017459869385
1400 iters, auroc, loss, time :  0.8096343702249428 0.43092001259326934 38.97480916976929
1500 iters, auroc, loss, time :  0.8031291028550501 0.4338154980540276 41.78944635391235
1600 iters, auroc, loss, time :  0.8024386654880918 0.4274004776775837 44.583078384399414
1700 iters, auroc, loss, time :  0.7990646894814852 0.4236337211728096 47.340702295303345
1800 iters, auroc, loss, time :  0.7996105415631765 0.4382718575000763 50.08582878112793
1900 iters, auroc, loss, time :  0.784974054840928 0.4413318721950054 52.847453355789185
2000 iters, auroc, loss, time :  0.7954723322091155 0.44108046919107435 55.616079807281494
2100 iters, auroc, loss, time :  0.7965642001387507 0.44889860332012177 58.41671347618103
2200 iters, 

4400 iters, auroc, loss, time :  0.7958632130253648 0.44794058680534365 122.14376831054688
4500 iters, auroc, loss, time :  0.7977634660227675 0.4395627951622009 124.93440890312195
4600 iters, auroc, loss, time :  0.8001202916450595 0.4278684124350548 127.71503806114197
4700 iters, auroc, loss, time :  0.8028787462612172 0.4250856776535511 130.51319408416748
4800 iters, auroc, loss, time :  0.7989365840501761 0.4385932357609272 133.3058261871338
4900 iters, auroc, loss, time :  0.7977818620615262 0.43381144776940345 136.05797171592712
5000 iters, auroc, loss, time :  0.8004376339277824 0.43929546147584914 138.83012175559998
5100 iters, auroc, loss, time :  0.80715173216876 0.43378781765699387 141.60974192619324
5200 iters, auroc, loss, time :  0.7924870486625517 0.4415299394726753 144.3628692626953
5300 iters, auroc, loss, time :  0.8095401676506216 0.42799960568547246 147.16250276565552
5400 iters, auroc, loss, time :  0.8020127173495056 0.428125212341547 149.9716432094574
5500 iters,

ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

In [23]:
def get_predictions(model, data_loader):
  model = model.eval()
  review_texts = []
  predictions = []
  prediction_probs = []
  real_values = []
  with torch.no_grad():
    for d in data_loader:
      texts = d["review_text"]
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)
      targets = d["targets"].to(device)
      outputs = torch.zeros_like(targets)
      outputs = model(input_ids)
      preds = torch.zeros_like(outputs)
      ones = torch.ones_like(preds)
      preds = torch.where(outputs < 0, preds, ones)
      review_texts.extend(texts)
      predictions.extend(preds)
      prediction_probs.extend(outputs)
      real_values.extend(targets)
  predictions = torch.stack(predictions).cpu()
  prediction_probs = torch.stack(prediction_probs).cpu()
  real_values = torch.stack(real_values).cpu()
  return review_texts, predictions, prediction_probs, real_values

In [24]:
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)

In [25]:
class_names = ['no_spoilers', 'has_spoilers']
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.7618126077593343
              precision    recall  f1-score   support

 no_spoilers       0.78      0.73      0.75     11464
has_spoilers       0.75      0.79      0.77     11487

    accuracy                           0.76     22951
   macro avg       0.76      0.76      0.76     22951
weighted avg       0.76      0.76      0.76     22951



In [26]:
test_df = pd.read_pickle('data/15ktest_goodreads_sent_spoil_titles.pkl')
test_dl = create_data_loader(test_df, tokenizer, MAX_LEN, BATCH_SIZE)
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_dl
)
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.7643611541477016
              precision    recall  f1-score   support

 no_spoilers       0.99      0.73      0.84     14485
has_spoilers       0.10      0.80      0.17       516

    accuracy                           0.73     15001
   macro avg       0.54      0.76      0.51     15001
weighted avg       0.96      0.73      0.82     15001



# Conclusion
LSTM achieves 0.75 AUROC and seems to stall there, at least on the balanced dataset with BERT tokenizer and a slightly simplified architecture. It's much faster, so might make sense to experiment further, at least to roughly match the performance from the paper.
UPDATE: AdamW with linear scheduler didn't seem to take off. Stayed around 0.5
But Adam with lr=0.003 Performed the best with test AUROC 0.78
Increasing hidden size doesn't help

# BERT test

In [15]:
from transformers import BertModel


class SpoilerClassifier(nn.Module):

    def __init__(self, n_classes, model_name='bert-base-cased', dropout=0.3):
        super(SpoilerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.drop = nn.Dropout(p=dropout)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask
        ).to_tuple()
        output = self.drop(pooled_output)
        return self.out(output)

In [16]:
model = SpoilerClassifier(1)
model = model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
EPOCHS=5
total_steps = len(train_data_loader) * EPOCHS
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)

scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

In [18]:
loss_fn = nn.BCEWithLogitsLoss().to(device)

import warnings
warnings.filterwarnings('ignore')

In [19]:
%%time
from collections import defaultdict
from time import time

history = defaultdict(list)
best_auroc = 0
for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)
  train_acc, train_loss, train_avg_losses, train_auroc = train_epoch(
    model,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    len(df_train),
    scheduler=scheduler
  )
  print(f'Train loss {train_loss} accuracy {train_acc}  auroc {np.mean(train_auroc)}')
  val_acc, val_loss, val_auroc = eval_model(
    model,
    val_data_loader,
    loss_fn,
    device,
    len(df_val)
  )
  print(f'Val   loss {val_loss} accuracy {val_acc} auroc {val_auroc}')
  print()
  history['train_auroc'] += train_auroc
  history['train_loss'] += train_avg_losses
  history['val_auroc'].append(val_auroc)
  history['val_loss'].append(val_loss)
  if val_auroc > best_auroc:
    torch.save(model.state_dict(), 'best_model_state.bin')
    best_auroc = val_auroc

Epoch 1/5
----------
100 iters, auroc, loss, time :  0.6073380191044428 0.6589620217680932 30.66746973991394
200 iters, auroc, loss, time :  0.6603143513039359 0.619895167350769 60.00142741203308
300 iters, auroc, loss, time :  0.6853292075338028 0.5932223698496819 89.45625805854797
400 iters, auroc, loss, time :  0.7143135987567271 0.5570016050338745 118.87770128250122
500 iters, auroc, loss, time :  0.6978554373491781 0.5736238944530487 147.88466691970825
600 iters, auroc, loss, time :  0.7011920317534129 0.5695155012607575 176.9453945159912
700 iters, auroc, loss, time :  0.7150421311518531 0.5481693094968796 205.9999713897705
800 iters, auroc, loss, time :  0.7163693076761035 0.5656557375192642 235.0478081703186
900 iters, auroc, loss, time :  0.715884376146617 0.5514068275690078 264.0816526412964
1000 iters, auroc, loss, time :  0.7180947906321106 0.5505121290683747 293.13843059539795
1100 iters, auroc, loss, time :  0.7284679484490024 0.53363265812397 322.139466047287
1200 iters,

3400 iters, auroc, loss, time :  0.8421060991167415 0.37036869943141937 987.8345942497253
3500 iters, auroc, loss, time :  0.8394950056627017 0.36238186821341517 1016.9964461326599
3600 iters, auroc, loss, time :  0.8365431935410649 0.3724065293371677 1046.0672607421875
3700 iters, auroc, loss, time :  0.8511920700810609 0.3556854476034641 1075.1085867881775
3800 iters, auroc, loss, time :  0.8361202026387784 0.3768331180512905 1104.1769170761108
3900 iters, auroc, loss, time :  0.8387127277164512 0.3655434301495552 1133.2007610797882
4000 iters, auroc, loss, time :  0.8412193312945685 0.3635358515381813 1162.2374501228333
4100 iters, auroc, loss, time :  0.8325858489215351 0.3782467573881149 1191.3521196842194
4200 iters, auroc, loss, time :  0.8344822843817242 0.3665767501294613 1220.394289970398
4300 iters, auroc, loss, time :  0.8412494519472241 0.36837791204452514 1249.4221794605255
4400 iters, auroc, loss, time :  0.8400256845071593 0.36825042754411696 1278.447521686554
4500 iter

800 iters, auroc, loss, time :  0.9266371838244987 0.19038855619728565 233.56610441207886
900 iters, auroc, loss, time :  0.9282653033390684 0.19500117868185043 262.71165776252747
1000 iters, auroc, loss, time :  0.9240498425722883 0.19614223174750806 291.91273975372314
1100 iters, auroc, loss, time :  0.9169034317729803 0.19837752789258956 321.06520104408264
1200 iters, auroc, loss, time :  0.9188762359619189 0.2061696244031191 350.2571792602539
1300 iters, auroc, loss, time :  0.9294236194596521 0.18088679447770117 379.3787007331848
1400 iters, auroc, loss, time :  0.9211956342236158 0.2017993625998497 408.5717496871948
1500 iters, auroc, loss, time :  0.9238808065803119 0.19799875192344188 437.72119331359863
1600 iters, auroc, loss, time :  0.9171218282878112 0.21307206796482206 466.83413791656494
1700 iters, auroc, loss, time :  0.9197250915364306 0.19830245289951562 496.0142526626587
1800 iters, auroc, loss, time :  0.9244002059816688 0.2018032491207123 525.1762132644653
1900 iter

4100 iters, auroc, loss, time :  0.947265760866577 0.14523312639445066 1191.0941898822784
4200 iters, auroc, loss, time :  0.9475673571223947 0.13836329544894396 1220.4636182785034
4300 iters, auroc, loss, time :  0.9535378754329605 0.14004778811708093 1249.6656568050385
4400 iters, auroc, loss, time :  0.9431443880551853 0.1595402224827558 1278.9215445518494
4500 iters, auroc, loss, time :  0.9448428991048289 0.15589131323620678 1307.9711842536926
4600 iters, auroc, loss, time :  0.9451992878122908 0.1563861914910376 1337.0890820026398
4700 iters, auroc, loss, time :  0.9459315335796775 0.14897262210026382 1366.1789309978485
4800 iters, auroc, loss, time :  0.9487766648490752 0.1299898583535105 1395.2446827888489
4900 iters, auroc, loss, time :  0.9412014241563391 0.14538740644231438 1424.3029561042786
5000 iters, auroc, loss, time :  0.9443496168219464 0.154413817608729 1453.4005970954895
5100 iters, auroc, loss, time :  0.9520863508051975 0.13301018133759498 1482.4890656471252
5200 

In [20]:
def get_predictions(model, data_loader):
  model = model.eval()
  review_texts = []
  predictions = []
  prediction_probs = []
  real_values = []
  with torch.no_grad():
    for d in data_loader:
      texts = d["review_text"]
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)
      targets = d["targets"].to(device)
      outputs = torch.zeros_like(targets)
      outputs = model(
        input_ids=input_ids, attention_mask=attention_mask
      )
      preds = torch.zeros_like(outputs)
      ones = torch.ones_like(preds)
      preds = torch.where(outputs < 0, preds, ones)
      review_texts.extend(texts)
      predictions.extend(preds)
      prediction_probs.extend(outputs)
      real_values.extend(targets)
  predictions = torch.stack(predictions).cpu()
  prediction_probs = torch.stack(prediction_probs).cpu()
  real_values = torch.stack(real_values).cpu()
  return review_texts, predictions, prediction_probs, real_values

In [21]:
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)

In [22]:
class_names = ['no_spoilers', 'has_spoilers']
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.8000746826477275
              precision    recall  f1-score   support

 no_spoilers       0.80      0.79      0.80     11519
has_spoilers       0.80      0.81      0.80     11432

    accuracy                           0.80     22951
   macro avg       0.80      0.80      0.80     22951
weighted avg       0.80      0.80      0.80     22951



In [23]:
test_df = pd.read_pickle('data/15ktest_goodreads_sent_spoil_titles.pkl')
test_dl = create_data_loader(test_df, tokenizer, MAX_LEN, BATCH_SIZE)
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_dl
)
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.864916593749749
              precision    recall  f1-score   support

 no_spoilers       1.00      0.79      0.88     14485
has_spoilers       0.14      0.94      0.24       516

    accuracy                           0.80     15001
   macro avg       0.57      0.86      0.56     15001
weighted avg       0.97      0.80      0.86     15001

