In [1]:
from overrides import overrides

import numpy as np

import pandas as pd

import torch
import torch.nn as nn

from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.data import Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.fields import TextField, MetadataField, ArrayField
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor
from allennlp.nn.util import get_text_field_mask

In [2]:
label_cols = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

In [3]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    testing=True,
    seed=1,
    batch_size=64,
    lr=3e-4,
    epochs=2,
    hidden_sz=64,
    max_seq_len=100, # necessary to limit memory usage
    max_vocab_size=100000,
)

class JigsawDatasetReader(DatasetReader):
    def __init__(self, tokenizer=lambda x: x.split(),
                 token_indexers=None,
                 max_seq_len=config.max_seq_len):
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len
 
    @overrides
    def text_to_instance(self, tokens, id=None, labels=None):
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"tokens": sentence_field}
         
        id_field = MetadataField(id)
        fields["id"] = id_field
         
        if labels is None:
            labels = np.zeros(len(label_cols))
        label_field = ArrayField(array=labels)
        fields["label"] = label_field
 
        return Instance(fields)
     
    @overrides
    def _read(self, file_path):
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(1000)
        for i, row in df.iterrows():
            yield self.text_to_instance(
                [Token(x) for x in self.tokenizer(row["comment_text"])],
                row["id"], row[label_cols].values,
            )

In [4]:
token_indexer = PretrainedBertIndexer(
    pretrained_model="bert-base-uncased",
    max_pieces=config.max_seq_len,
    do_lowercase=True,
    truncate_long_sequences=False, # Use sliding window for contexts
 )
 
def tokenizer(s: str):
    return token_indexer.wordpiece_tokenizer(s)[:config.max_seq_len - 2]

reader = JigsawDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer}
)

In [5]:
vocab = Vocabulary()

In [6]:
bert_embedder = PretrainedBertEmbedder(
         pretrained_model="bert-base-uncased",
         top_layer_only=True, # conserve memory
)
word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                        # we'll be ignoring masks so we'll need to set this to True
                                        allow_unmatched_keys = True)

In [7]:
BERT_DIM = word_embeddings.get_output_dim()
 
class BertSentencePooler(Seq2VecEncoder):
    def forward(self, embs: torch.tensor, 
                mask: torch.tensor=None) -> torch.tensor:
        # extract first token tensor
        return embs[:, 0]
     
    @overrides
    def get_output_dim(self) -> int:
        return BERT_DIM
     
encoder = BertSentencePooler(vocab)

In [8]:
class BaselineModel(Model):
    def __init__(self, word_embeddings,
                 encoder,
                 out_sz=len(label_cols)):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.BCEWithLogitsLoss()
         
    def forward(self, tokens,
                id, label):
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)
         
        output = {"class_logits": class_logits, "state": state}
        output["loss"] = self.loss(class_logits, label)
 
        return output

In [9]:
model = BaselineModel(word_embeddings, encoder)

In [10]:
tagger = SentenceTaggerPredictor(model, reader)

In [11]:
tagger.predict("This thing is pretty cool.")

{'class_logits': [0.4939308762550354,
  0.18493486940860748,
  -0.0856907069683075,
  -0.16543801128864288,
  0.03377683088183403,
  0.3908044993877411],
 'state': [0.08956055343151093,
  -0.04516694322228432,
  0.15095344185829163,
  -0.4251112639904022,
  -0.4735545516014099,
  -0.7761418223381042,
  0.14447686076164246,
  0.6996430158615112,
  0.2625007927417755,
  -0.1823796182870865,
  0.08206482231616974,
  -0.0040482450276613235,
  0.2658458352088928,
  0.4016270637512207,
  -0.20179183781147003,
  -0.308272123336792,
  -0.19338111579418182,
  0.5061939358711243,
  0.345022976398468,
  -0.0024992097169160843,
  -0.11763352155685425,
  -0.15226660668849945,
  0.11027456074953079,
  -0.2619054615497589,
  -0.02325110137462616,
  0.08160116523504257,
  0.22415059804916382,
  -0.09971704334020615,
  -0.06847239285707474,
  0.30956393480300903,
  0.04546726867556572,
  0.02099713310599327,
  -0.3198445439338684,
  -0.1277340054512024,
  0.2875863313674927,
  -0.3489682078361511,
  0.

In [12]:
df = pd.read_csv("toxic-train-clean.csv")

In [24]:
i = 0
l = len(df)
def get_vector(text):
    global i
    global l
    i+=1
    if i % 1000 == 0:
        print(i, "/", l)
    return tagger.predict(text)["state"]

In [25]:
embeddings = np.stack(df["comment_text"].apply(get_vector).values)

1000 / 159571
2000 / 159571
3000 / 159571
4000 / 159571
5000 / 159571
6000 / 159571
7000 / 159571
8000 / 159571
9000 / 159571
10000 / 159571
11000 / 159571
12000 / 159571
13000 / 159571
14000 / 159571
15000 / 159571
16000 / 159571
17000 / 159571
18000 / 159571
19000 / 159571
20000 / 159571
21000 / 159571
22000 / 159571
23000 / 159571
24000 / 159571
25000 / 159571
26000 / 159571
27000 / 159571
28000 / 159571
29000 / 159571
30000 / 159571
31000 / 159571
32000 / 159571
33000 / 159571
34000 / 159571
35000 / 159571
36000 / 159571
37000 / 159571
38000 / 159571
39000 / 159571
40000 / 159571
41000 / 159571
42000 / 159571
43000 / 159571
44000 / 159571
45000 / 159571
46000 / 159571
47000 / 159571
48000 / 159571
49000 / 159571
50000 / 159571
51000 / 159571
52000 / 159571
53000 / 159571
54000 / 159571
55000 / 159571


KeyboardInterrupt: 

In [None]:
np.savetxt('toxic_bert_matrix.out', embeddings, delimiter=',')