In [1]:
import pandas as pd
import numpy as np
import transformers
import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report

BATCH_SIZE = 32
MAX_LEN = 128
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

In [18]:
df = pd.read_pickle('data/goodreads_sent_spoil_titles.pkl')
len(df[df.has_spoiler == 0]) / len(df[df.has_spoiler == 1])

29.805480171583646

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [4]:
class ReviewDataset(Dataset):
    def __init__(self, reviews, targets, tokenizer, max_len=128):
        self.reviews = reviews
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        target = self.targets[idx]
        encoding = self.tokenizer.encode_plus(
          review,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          pad_to_max_length=True,
          return_attention_mask=True,
          return_tensors='pt',
        )
        return {
          'review_text': review,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'targets': torch.tensor(target, dtype=torch.long)
        }

def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = ReviewDataset(
        reviews=df.sentence.to_numpy(),
        targets=df.has_spoiler.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True
    )


In [4]:
df = pd.read_pickle('data/goodreads_sent_spoil_titles.pkl')
df.shape

(3518910, 2)

In [5]:
df.head()

Unnamed: 0,sentence,has_spoiler
0,The Maltese Falcon [SEP] Essential reading for...,0
1,The Maltese Falcon [SEP] Hammett is the grandf...,0
2,The Maltese Falcon [SEP] While his Continental...,0
3,"The Maltese Falcon [SEP] Expect sharp dialog, ...",0
4,The Maltese Falcon [SEP] If you've seen John H...,0


In [6]:
df.has_spoiler.value_counts()

0    3404680
1     114230
Name: has_spoiler, dtype: int64

In [7]:
# let's downsample 0 first
df = pd.concat([df[df.has_spoiler==1], df[df.has_spoiler==0].sample(len(df[df.has_spoiler==1]))]).sample(frac=1)

In [8]:
df_train, df_test = train_test_split(
  df,
  test_size=0.2,
  random_state=42
)
df_val, df_test = train_test_split(
  df_test,
  test_size=0.5,
  random_state=42
)

In [9]:
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

In [10]:
ni = next(iter(train_data_loader))
ni['review_text'][8], ni['input_ids'][8]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


("Ever [SEP] Levine's books so much.",
 tensor([  101, 10006,   102, 19319,   112,   188,  2146,  1177,  1277,   119,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             

In [11]:
tokenizer.encode_plus(
          'Diabolic',
          add_special_tokens=True,
          max_length=MAX_LEN,
          return_token_type_ids=False,
          pad_to_max_length=True,
          return_attention_mask=True,
          return_tensors='pt',
        )

{'input_ids': tensor([[  101, 12120,  6639, 14987,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

# LSTM Model Test

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMSpoilerClassifier(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(LSTMSpoilerClassifier, self).__init__()
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout1 = nn.Dropout2d(0.25)

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, dropout=0.1, batch_first=True, num_layers=2)
        self.dropout2 = nn.Dropout(0.4)

        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, sentence):
        embeds = self.embedding(sentence)
        embeds = embeds.unsqueeze(2)
        embeds = embeds.permute(0, 3, 2, 1)
        embeds = self.dropout1(embeds)
        embeds = embeds.permute(0, 3, 2, 1)
        embeds = embeds.squeeze(2)
        
        out, _ = self.lstm(embeds)
        out = self.dropout2(out)

        return self.linear(out[:,-1,:])


In [13]:
from torch import nn, optim


def train_epoch(
    model,
    data_loader,
    loss_fn,
    optimizer,
    device,
    n_examples,
    scheduler=None,
    model_type='transformer'
):
    model = model.train()
    losses = []
    avg_losses = []
    aurocs = []
    avg_aurocs = []
    correct_predictions = 0
    i = 0
    t0 = time()
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        targets = d["targets"].to(device).view(-1, 1)
        outputs = torch.zeros_like(targets)
        
        if model_type == 'LSTM':
            outputs = model(input_ids)
        else:
            attention_mask = d["attention_mask"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        preds = torch.zeros_like(outputs)
        ones = torch.ones_like(preds)
        preds = torch.where(outputs < 0, preds, ones)
        
        loss = loss_fn(outputs, targets.float())
        
        correct_predictions += torch.sum(preds == targets)
        aurocs.append(roc_auc_score(targets.cpu(), preds.cpu()))
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler:
            scheduler.step()
        optimizer.zero_grad()
        i += 1
        if i % 100 == 0:
            avg_aurocs.append(np.mean(aurocs[i-100:i]))
            avg_losses.append(np.mean(losses[i-100:i]))
            print(i, 'iters, auroc, loss, time : ', avg_aurocs[-1], avg_losses[-1], time()-t0)

    return correct_predictions.double() / n_examples, np.mean(losses), avg_losses, avg_aurocs


In [12]:
def eval_model(model, data_loader, loss_fn, device, n_examples, model_type='transformer'):
    model = model.eval()
    losses = []
    aurocs = []
    correct_predictions = 0
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            targets = d["targets"].to(device).view(-1, 1)
            outputs = torch.zeros_like(targets)
            
            if model_type == 'LSTM':
                outputs = model(input_ids)
            else:
                attention_mask = d["attention_mask"].to(device)
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

            preds = torch.zeros_like(outputs)
            ones = torch.ones_like(preds)
            preds = torch.where(outputs < 0, preds, ones)

            loss = loss_fn(outputs, targets.float())
            
            correct_predictions += torch.sum(preds == targets)
            aurocs.append(roc_auc_score(targets.cpu(), preds.cpu()))
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses), np.mean(aurocs)


In [15]:
tokenizer.vocab_size

28996

In [16]:
EPOCHS = 20
model = LSTMSpoilerClassifier(128, 128, tokenizer.vocab_size + 1)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.003)

loss_fn = nn.BCEWithLogitsLoss().to(device)

In [6]:
import warnings
warnings.filterwarnings('ignore')

In [19]:
%%time
from collections import defaultdict
from time import time

#history = defaultdict(list)
#best_auroc = 0
for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)
  train_acc, train_loss, train_avg_losses, train_auroc = train_epoch(
    model,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    len(df_train),
    scheduler=None,
    model_type='LSTM'
  )
  print(f'Train loss {train_loss} accuracy {train_acc}  auroc {np.mean(train_auroc)}')
  val_acc, val_loss, val_auroc = eval_model(
    model,
    val_data_loader,
    loss_fn,
    device,
    len(df_val),
    model_type='LSTM'
  )
  print(f'Val   loss {val_loss} accuracy {val_acc} auroc {val_auroc}')
  print()
  history['train_auroc'] += train_auroc
  history['train_loss'] += train_avg_losses
  history['val_auroc'].append(val_auroc)
  history['val_loss'].append(val_loss)
  if val_auroc > best_auroc:
    torch.save(model.state_dict(), 'lstm_best_model_state.bin')
    best_auroc = val_auroc

Epoch 1/20
----------
100 iters, auroc, loss, time :  0.8023473638075032 0.4181264296174049 2.86763072013855
200 iters, auroc, loss, time :  0.8152275706571963 0.39972867876291274 5.68964147567749
300 iters, auroc, loss, time :  0.8274172916436849 0.38704847037792206 8.48624300956726
400 iters, auroc, loss, time :  0.81464845043027 0.4083470715582371 11.287711143493652
500 iters, auroc, loss, time :  0.8162584826594116 0.41139460265636446 14.090497493743896
600 iters, auroc, loss, time :  0.8169312511683289 0.4113690695166588 16.884098291397095
700 iters, auroc, loss, time :  0.808247918215629 0.42577339440584183 19.694702863693237
800 iters, auroc, loss, time :  0.8123361754989076 0.42857020869851115 22.47830104827881
900 iters, auroc, loss, time :  0.810771404996792 0.407685304582119 25.328914403915405
1000 iters, auroc, loss, time :  0.8152755680515473 0.41500212281942367 28.155717134475708
1100 iters, auroc, loss, time :  0.8067160711751269 0.41794635698199273 30.961320638656616
12

3400 iters, auroc, loss, time :  0.7996718068372901 0.4282331296801567 95.64281368255615
3500 iters, auroc, loss, time :  0.8193825791530227 0.4198426681756973 98.59246635437012
3600 iters, auroc, loss, time :  0.8164752613900212 0.4103186173737049 101.42909216880798
3700 iters, auroc, loss, time :  0.8154305790127148 0.4153905537724495 104.23771214485168
3800 iters, auroc, loss, time :  0.8111746620420903 0.4217881017923355 107.03433012962341
3900 iters, auroc, loss, time :  0.8099786442095468 0.419307569116354 109.87495756149292
4000 iters, auroc, loss, time :  0.7969942471314375 0.4336804734170437 112.71558427810669
4100 iters, auroc, loss, time :  0.8051247825746666 0.42781452625989913 115.53320693969727
4200 iters, auroc, loss, time :  0.8104664628036968 0.4152581608295441 118.35984492301941
4300 iters, auroc, loss, time :  0.8239878263076793 0.3940646481513977 121.13261318206787
4400 iters, auroc, loss, time :  0.8076341159485508 0.4268468470871449 123.90640759468079
4500 iters, 

800 iters, auroc, loss, time :  0.8259586327200491 0.39870751708745955 22.620333909988403
900 iters, auroc, loss, time :  0.8058529883812139 0.4231607066094875 25.456324338912964
1000 iters, auroc, loss, time :  0.806323587231482 0.43060873478651046 28.253309726715088
1100 iters, auroc, loss, time :  0.8091830664624109 0.4182757374644279 31.028291940689087
1200 iters, auroc, loss, time :  0.801033068613421 0.43213102892041205 33.79395866394043
1300 iters, auroc, loss, time :  0.8174038090035853 0.39690670877695083 36.60908579826355
1400 iters, auroc, loss, time :  0.8188000741406484 0.4092402058839798 39.47871971130371
1500 iters, auroc, loss, time :  0.8067863239589416 0.41790072709321974 42.36140441894531
1600 iters, auroc, loss, time :  0.8257557453203234 0.3957230831682682 45.21403479576111
1700 iters, auroc, loss, time :  0.8094591597062992 0.42521251782774927 48.02565574645996
1800 iters, auroc, loss, time :  0.8100983931144536 0.41236258402466774 50.83865571022034
1900 iters, au

4100 iters, auroc, loss, time :  0.8172858839525365 0.41199275851249695 114.98832631111145
4200 iters, auroc, loss, time :  0.7967994008915062 0.43343226850032807 117.78394293785095
4300 iters, auroc, loss, time :  0.8103333863493711 0.4108465501666069 120.56755828857422
4400 iters, auroc, loss, time :  0.8125898533450443 0.41122852995991704 123.37087368965149
4500 iters, auroc, loss, time :  0.8112936249701376 0.40302729591727254 126.22651982307434
4600 iters, auroc, loss, time :  0.8050602934105424 0.42350882306694987 129.0910608768463
4700 iters, auroc, loss, time :  0.817110564630924 0.4082597590982914 131.9180555343628
4800 iters, auroc, loss, time :  0.8210366282968836 0.40990290746092795 134.75802731513977
4900 iters, auroc, loss, time :  0.7969458028295944 0.4373797261714935 137.5550343990326
5000 iters, auroc, loss, time :  0.7933752702835521 0.4401363900303841 140.33501839637756
5100 iters, auroc, loss, time :  0.8091182857408965 0.43632001891732214 143.1420316696167
5200 ite

1500 iters, auroc, loss, time :  0.8132222552792702 0.41275415137410165 42.146185636520386
1600 iters, auroc, loss, time :  0.8067143022834316 0.4191195370256901 44.938164949417114
1700 iters, auroc, loss, time :  0.8225819445731278 0.4006949634850025 47.816149950027466
1800 iters, auroc, loss, time :  0.8061954718911829 0.41630209878087043 50.586180686950684
1900 iters, auroc, loss, time :  0.8183595784795813 0.4032995739579201 53.40808343887329
2000 iters, auroc, loss, time :  0.8159695928211793 0.41316632270812986 56.244709491729736
2100 iters, auroc, loss, time :  0.8145144823655727 0.41109742641448976 59.09284520149231
2200 iters, auroc, loss, time :  0.8117718130337006 0.4131432346999645 61.94546556472778
2300 iters, auroc, loss, time :  0.8043463598496744 0.4183852261304855 64.7410831451416
2400 iters, auroc, loss, time :  0.8152242904817696 0.399614551961422 67.56571435928345
2500 iters, auroc, loss, time :  0.8166496476605609 0.40106711611151696 70.35933113098145
2600 iters, a

4800 iters, auroc, loss, time :  0.8102738166219373 0.41404626131057737 134.639306306839
4900 iters, auroc, loss, time :  0.8031217950651002 0.4151790016889572 137.43685960769653
5000 iters, auroc, loss, time :  0.821835500788304 0.39813336357474327 140.22347497940063
5100 iters, auroc, loss, time :  0.8156605332988679 0.40728614926338197 143.03109502792358
5200 iters, auroc, loss, time :  0.8162655566397826 0.41397703617811205 145.85771918296814
5300 iters, auroc, loss, time :  0.795085451151838 0.4243687231838703 148.67270874977112
5400 iters, auroc, loss, time :  0.8209760846308833 0.40333888605237006 151.4641399383545
5500 iters, auroc, loss, time :  0.8082107759240954 0.41312880292534826 154.23803281784058
5600 iters, auroc, loss, time :  0.8057212480557985 0.4281528437137604 157.02474689483643
5700 iters, auroc, loss, time :  0.8114289431581041 0.4054996865987778 159.81886839866638
Train loss 0.4087808986040516 accuracy 0.8143548104701042  auroc 0.8142378421267215
Val   loss 0.53

2200 iters, auroc, loss, time :  0.8218137063623249 0.3996186825633049 62.0205864906311
2300 iters, auroc, loss, time :  0.8118209551979407 0.4029868210852146 64.79719972610474
2400 iters, auroc, loss, time :  0.8148679828728203 0.4127984435856342 67.58181500434875
2500 iters, auroc, loss, time :  0.8100766794162523 0.408486014008522 70.36944079399109
2600 iters, auroc, loss, time :  0.810365796644838 0.42000390887260436 73.20874547958374
2700 iters, auroc, loss, time :  0.8209408704769454 0.41129530772566797 76.03887796401978
2800 iters, auroc, loss, time :  0.8032856185016132 0.4092544166743755 78.86591529846191
2900 iters, auroc, loss, time :  0.8217401163956218 0.411599587649107 81.68490815162659
3000 iters, auroc, loss, time :  0.8131456062810298 0.4030188588798046 84.47541904449463
3100 iters, auroc, loss, time :  0.8166771364983186 0.39464901462197305 87.28254389762878
3200 iters, auroc, loss, time :  0.8328601764377694 0.3941754886507988 90.05266451835632
3300 iters, auroc, los

5500 iters, auroc, loss, time :  0.8109839746811496 0.40400506675243375 156.28917789459229
5600 iters, auroc, loss, time :  0.8125909615098486 0.4134947445988655 159.10530519485474
5700 iters, auroc, loss, time :  0.8214427837347142 0.39234828531742094 161.96193552017212
Train loss 0.408051100679535 accuracy 0.8150715661384925  auroc 0.8150107081863098
Val   loss 0.5417558406831837 accuracy 0.7583822113280224 auroc 0.7579597507534053

Epoch 12/20
----------
100 iters, auroc, loss, time :  0.8195879390871315 0.4143442957103252 2.8336257934570312
200 iters, auroc, loss, time :  0.8239139427823637 0.39798112034797667 5.647246837615967
300 iters, auroc, loss, time :  0.8154304831380315 0.4169994910061359 8.456940174102783
400 iters, auroc, loss, time :  0.8303738260260926 0.3816690295934677 11.249556303024292
500 iters, auroc, loss, time :  0.8212425542807575 0.40390445828437804 14.029170751571655
600 iters, auroc, loss, time :  0.8186034848149787 0.39957978054881094 16.80777907371521
700 

2900 iters, auroc, loss, time :  0.8123856355660688 0.4291214594244957 83.29185485839844
3000 iters, auroc, loss, time :  0.8116878701083843 0.4041051161289215 86.18300580978394
3100 iters, auroc, loss, time :  0.8166834108658798 0.4003393180668354 89.07222199440002
3200 iters, auroc, loss, time :  0.8227901318630217 0.3961240515112877 91.95595717430115
3300 iters, auroc, loss, time :  0.8266739656850456 0.3947597788274288 94.83748126029968
3400 iters, auroc, loss, time :  0.8193456149758395 0.3936011353135109 97.72627568244934
3500 iters, auroc, loss, time :  0.8135996445702933 0.41146346643567083 100.59598755836487
3600 iters, auroc, loss, time :  0.8160767133445739 0.39879701048135757 103.45862030982971
3700 iters, auroc, loss, time :  0.817208962973955 0.40372956097126006 106.3002610206604
3800 iters, auroc, loss, time :  0.8131296747628016 0.4010891088843346 109.18089723587036
3900 iters, auroc, loss, time :  0.8246921762834689 0.40160680875182153 112.07804465293884
4000 iters, au

300 iters, auroc, loss, time :  0.8219006921842681 0.3981443336606026 8.666691780090332
400 iters, auroc, loss, time :  0.8229499920401235 0.39548415765166284 11.547876119613647
500 iters, auroc, loss, time :  0.8051272005577766 0.4215534895658493 14.404600381851196
600 iters, auroc, loss, time :  0.8185079985588887 0.39857140347361564 17.262831926345825
700 iters, auroc, loss, time :  0.8261884944046323 0.39250240966677663 20.117469549179077
800 iters, auroc, loss, time :  0.8155114405792828 0.40122890859842303 22.955103635787964
900 iters, auroc, loss, time :  0.8114856516004971 0.4265470276772976 25.80424475669861
1000 iters, auroc, loss, time :  0.8109681081087815 0.4087137557566166 28.638960123062134
1100 iters, auroc, loss, time :  0.8199175294999717 0.40659171029925345 31.490103483200073
1200 iters, auroc, loss, time :  0.8171731926054682 0.40339435771107673 34.36574673652649
1300 iters, auroc, loss, time :  0.8137903716333098 0.40272221580147743 37.228612422943115
1400 iters, a

3600 iters, auroc, loss, time :  0.8194432625552983 0.40361615240573884 102.54305720329285
3700 iters, auroc, loss, time :  0.8161060503179312 0.4124511070549488 105.37108182907104
3800 iters, auroc, loss, time :  0.8093801082351192 0.40958900257945063 108.19608402252197
3900 iters, auroc, loss, time :  0.8163334737765262 0.39951544925570487 111.03904747962952
4000 iters, auroc, loss, time :  0.8347435597598135 0.38609076410531995 113.87306523323059
4100 iters, auroc, loss, time :  0.8177819496931928 0.4118698516488075 116.72004723548889
4200 iters, auroc, loss, time :  0.8144400078503928 0.41043530717492105 119.59905934333801
4300 iters, auroc, loss, time :  0.8241258439964189 0.4021993507444859 122.46102237701416
4400 iters, auroc, loss, time :  0.8075479978500656 0.42326515197753906 125.32739090919495
4500 iters, auroc, loss, time :  0.8139091517740645 0.4090408881008625 128.14806842803955
4600 iters, auroc, loss, time :  0.8101164033811343 0.4194313779473305 131.00470685958862
4700

1000 iters, auroc, loss, time :  0.8188293277966855 0.40354149162769315 28.571012020111084
1100 iters, auroc, loss, time :  0.8092537583040341 0.4113039004802704 31.402164697647095
1200 iters, auroc, loss, time :  0.807053244461787 0.41904417231678964 34.244800090789795
1300 iters, auroc, loss, time :  0.8111455284509614 0.4128097353875637 37.04886317253113
1400 iters, auroc, loss, time :  0.8125778964271693 0.3952157409489155 39.89349889755249
1500 iters, auroc, loss, time :  0.8233266762262506 0.4020050649344921 42.78316569328308
1600 iters, auroc, loss, time :  0.8148454494679727 0.3983841758966446 45.666881799697876
1700 iters, auroc, loss, time :  0.8186164177454812 0.39398602724075316 48.50552535057068
1800 iters, auroc, loss, time :  0.8229034313228688 0.39512763753533364 51.354092836380005
1900 iters, auroc, loss, time :  0.8127722579735891 0.4121507193148136 54.19163751602173
2000 iters, auroc, loss, time :  0.8256330318131637 0.3939321245253086 57.037641286849976
2100 iters, 

4300 iters, auroc, loss, time :  0.8193696469443793 0.39594648450613024 122.54836010932922
4400 iters, auroc, loss, time :  0.8279841653612686 0.39885851860046384 125.3919951915741
4500 iters, auroc, loss, time :  0.8110914792775911 0.4189515298604965 128.24263167381287
4600 iters, auroc, loss, time :  0.8191501108064325 0.3975316371023655 131.10177445411682
4700 iters, auroc, loss, time :  0.8192532375975553 0.39972741544246676 133.9574134349823
4800 iters, auroc, loss, time :  0.8143219227406283 0.41642715081572534 136.79855680465698
4900 iters, auroc, loss, time :  0.8318566695210088 0.39822726517915724 139.64019179344177
5000 iters, auroc, loss, time :  0.8207034890942131 0.3969672539830208 142.473961353302
5100 iters, auroc, loss, time :  0.8031332170954932 0.4052795569598675 145.28159165382385
5200 iters, auroc, loss, time :  0.8218497297775211 0.4012441910803318 148.15934109687805
5300 iters, auroc, loss, time :  0.8157284548536147 0.39971822381019595 151.02203130722046
5400 ite

In [7]:
def get_predictions(model, data_loader):
  model = model.eval()
  review_texts = []
  predictions = []
  prediction_probs = []
  real_values = []
  with torch.no_grad():
    for d in data_loader:
      texts = d["review_text"]
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)
      targets = d["targets"].to(device)
      outputs = torch.zeros_like(targets)
      outputs = model(input_ids)
      preds = torch.zeros_like(outputs)
      ones = torch.ones_like(preds)
      preds = torch.where(outputs < 0, preds, ones)
      review_texts.extend(texts)
      predictions.extend(preds)
      prediction_probs.extend(outputs)
      real_values.extend(targets)
  predictions = torch.stack(predictions).cpu()
  prediction_probs = torch.stack(prediction_probs).cpu()
  real_values = torch.stack(real_values).cpu()
  return review_texts, predictions, prediction_probs, real_values

In [20]:
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)

In [21]:
class_names = ['no_spoilers', 'has_spoilers']
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.5108160241345758
              precision    recall  f1-score   support

 no_spoilers       0.51      0.97      0.67     11480
has_spoilers       0.65      0.05      0.09     11366

    accuracy                           0.51     22846
   macro avg       0.58      0.51      0.38     22846
weighted avg       0.58      0.51      0.38     22846



In [26]:
test_df = pd.read_pickle('data/15ktest_goodreads_sent_spoil_titles.pkl')
test_dl = create_data_loader(test_df, tokenizer, MAX_LEN, BATCH_SIZE)
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_dl
)
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.7643611541477016
              precision    recall  f1-score   support

 no_spoilers       0.99      0.73      0.84     14485
has_spoilers       0.10      0.80      0.17       516

    accuracy                           0.73     15001
   macro avg       0.54      0.76      0.51     15001
weighted avg       0.96      0.73      0.82     15001



# Conclusion
LSTM achieves 0.75 AUROC and seems to stall there, at least on the balanced dataset with BERT tokenizer and a slightly simplified architecture. It's much faster, so might make sense to experiment further, at least to roughly match the performance from the paper.
UPDATE: AdamW with linear scheduler didn't seem to take off. Stayed around 0.5
But Adam with lr=0.003 Performed the best with test AUROC 0.78
Increasing hidden size doesn't help

# BERT test

In [10]:
from transformers import BertModel


class SpoilerClassifier(nn.Module):

    def __init__(self, n_classes, model_name='bert-base-cased', dropout=0.3):
        super(SpoilerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.drop = nn.Dropout(p=dropout)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask
        ).to_tuple()
        output = self.drop(pooled_output)
        return self.out(output)

In [13]:
model = SpoilerClassifier(1)
model.load_state_dict(torch.load('models/bert_ep5_balanced.pt'))
model = model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
EPOCHS=5
total_steps = len(train_data_loader) * EPOCHS
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)

scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

In [18]:
loss_fn = nn.BCEWithLogitsLoss().to(device)

import warnings
warnings.filterwarnings('ignore')

In [19]:
%%time
from collections import defaultdict
from time import time

history = defaultdict(list)
best_auroc = 0
for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)
  train_acc, train_loss, train_avg_losses, train_auroc = train_epoch(
    model,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    len(df_train),
    scheduler=scheduler
  )
  print(f'Train loss {train_loss} accuracy {train_acc}  auroc {np.mean(train_auroc)}')
  val_acc, val_loss, val_auroc = eval_model(
    model,
    val_data_loader,
    loss_fn,
    device,
    len(df_val)
  )
  print(f'Val   loss {val_loss} accuracy {val_acc} auroc {val_auroc}')
  print()
  history['train_auroc'] += train_auroc
  history['train_loss'] += train_avg_losses
  history['val_auroc'].append(val_auroc)
  history['val_loss'].append(val_loss)
  if val_auroc > best_auroc:
    torch.save(model.state_dict(), 'best_model_state.bin')
    best_auroc = val_auroc

Epoch 1/5
----------
100 iters, auroc, loss, time :  0.6073380191044428 0.6589620217680932 30.66746973991394
200 iters, auroc, loss, time :  0.6603143513039359 0.619895167350769 60.00142741203308
300 iters, auroc, loss, time :  0.6853292075338028 0.5932223698496819 89.45625805854797
400 iters, auroc, loss, time :  0.7143135987567271 0.5570016050338745 118.87770128250122
500 iters, auroc, loss, time :  0.6978554373491781 0.5736238944530487 147.88466691970825
600 iters, auroc, loss, time :  0.7011920317534129 0.5695155012607575 176.9453945159912
700 iters, auroc, loss, time :  0.7150421311518531 0.5481693094968796 205.9999713897705
800 iters, auroc, loss, time :  0.7163693076761035 0.5656557375192642 235.0478081703186
900 iters, auroc, loss, time :  0.715884376146617 0.5514068275690078 264.0816526412964
1000 iters, auroc, loss, time :  0.7180947906321106 0.5505121290683747 293.13843059539795
1100 iters, auroc, loss, time :  0.7284679484490024 0.53363265812397 322.139466047287
1200 iters,

3400 iters, auroc, loss, time :  0.8421060991167415 0.37036869943141937 987.8345942497253
3500 iters, auroc, loss, time :  0.8394950056627017 0.36238186821341517 1016.9964461326599
3600 iters, auroc, loss, time :  0.8365431935410649 0.3724065293371677 1046.0672607421875
3700 iters, auroc, loss, time :  0.8511920700810609 0.3556854476034641 1075.1085867881775
3800 iters, auroc, loss, time :  0.8361202026387784 0.3768331180512905 1104.1769170761108
3900 iters, auroc, loss, time :  0.8387127277164512 0.3655434301495552 1133.2007610797882
4000 iters, auroc, loss, time :  0.8412193312945685 0.3635358515381813 1162.2374501228333
4100 iters, auroc, loss, time :  0.8325858489215351 0.3782467573881149 1191.3521196842194
4200 iters, auroc, loss, time :  0.8344822843817242 0.3665767501294613 1220.394289970398
4300 iters, auroc, loss, time :  0.8412494519472241 0.36837791204452514 1249.4221794605255
4400 iters, auroc, loss, time :  0.8400256845071593 0.36825042754411696 1278.447521686554
4500 iter

800 iters, auroc, loss, time :  0.9266371838244987 0.19038855619728565 233.56610441207886
900 iters, auroc, loss, time :  0.9282653033390684 0.19500117868185043 262.71165776252747
1000 iters, auroc, loss, time :  0.9240498425722883 0.19614223174750806 291.91273975372314
1100 iters, auroc, loss, time :  0.9169034317729803 0.19837752789258956 321.06520104408264
1200 iters, auroc, loss, time :  0.9188762359619189 0.2061696244031191 350.2571792602539
1300 iters, auroc, loss, time :  0.9294236194596521 0.18088679447770117 379.3787007331848
1400 iters, auroc, loss, time :  0.9211956342236158 0.2017993625998497 408.5717496871948
1500 iters, auroc, loss, time :  0.9238808065803119 0.19799875192344188 437.72119331359863
1600 iters, auroc, loss, time :  0.9171218282878112 0.21307206796482206 466.83413791656494
1700 iters, auroc, loss, time :  0.9197250915364306 0.19830245289951562 496.0142526626587
1800 iters, auroc, loss, time :  0.9244002059816688 0.2018032491207123 525.1762132644653
1900 iter

4100 iters, auroc, loss, time :  0.947265760866577 0.14523312639445066 1191.0941898822784
4200 iters, auroc, loss, time :  0.9475673571223947 0.13836329544894396 1220.4636182785034
4300 iters, auroc, loss, time :  0.9535378754329605 0.14004778811708093 1249.6656568050385
4400 iters, auroc, loss, time :  0.9431443880551853 0.1595402224827558 1278.9215445518494
4500 iters, auroc, loss, time :  0.9448428991048289 0.15589131323620678 1307.9711842536926
4600 iters, auroc, loss, time :  0.9451992878122908 0.1563861914910376 1337.0890820026398
4700 iters, auroc, loss, time :  0.9459315335796775 0.14897262210026382 1366.1789309978485
4800 iters, auroc, loss, time :  0.9487766648490752 0.1299898583535105 1395.2446827888489
4900 iters, auroc, loss, time :  0.9412014241563391 0.14538740644231438 1424.3029561042786
5000 iters, auroc, loss, time :  0.9443496168219464 0.154413817608729 1453.4005970954895
5100 iters, auroc, loss, time :  0.9520863508051975 0.13301018133759498 1482.4890656471252
5200 

In [14]:
def get_predictions(model, data_loader):
  model = model.eval()
  review_texts = []
  predictions = []
  prediction_probs = []
  real_values = []
  with torch.no_grad():
    for d in data_loader:
      texts = d["review_text"]
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)
      targets = d["targets"].to(device)
      outputs = torch.zeros_like(targets)
      outputs = model(
        input_ids=input_ids, attention_mask=attention_mask
      )
      preds = torch.zeros_like(outputs)
      ones = torch.ones_like(preds)
      preds = torch.where(outputs < 0, preds, ones)
      review_texts.extend(texts)
      predictions.extend(preds)
      prediction_probs.extend(outputs)
      real_values.extend(targets)
  predictions = torch.stack(predictions).cpu()
  prediction_probs = torch.stack(prediction_probs).cpu()
  real_values = torch.stack(real_values).cpu()
  return review_texts, predictions, prediction_probs, real_values

In [21]:
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)

In [22]:
class_names = ['no_spoilers', 'has_spoilers']
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

0.8000746826477275
              precision    recall  f1-score   support

 no_spoilers       0.80      0.79      0.80     11519
has_spoilers       0.80      0.81      0.80     11432

    accuracy                           0.80     22951
   macro avg       0.80      0.80      0.80     22951
weighted avg       0.80      0.80      0.80     22951



In [15]:
class_names = ['no_spoilers', 'has_spoilers']
test_df = pd.read_pickle('data/15ktest_goodreads_sent_spoil_titles.pkl')
test_dl = create_data_loader(test_df, tokenizer, MAX_LEN, BATCH_SIZE)
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_dl
)
print(roc_auc_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=class_names))

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


0.7742554446861629
              precision    recall  f1-score   support

 no_spoilers       0.99      0.79      0.88     14485
has_spoilers       0.11      0.76      0.20       516

    accuracy                           0.79     15001
   macro avg       0.55      0.77      0.54     15001
weighted avg       0.96      0.79      0.85     15001



In [17]:
print(roc_auc_score(y_test, y_pred_probs))

0.8570326694549026
