In [1]:
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField, LabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

TODOs:

- Make compatible with GPU
- Try replicating SST results

In [2]:
class Config:
    """Lightweight attribute bag: every keyword argument becomes an instance attribute."""
    def __init__(self, **kwargs):
        # Store each option under its own name so it can be read back as `config.<name>`.
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [40]:
# Single place for every knob used by the cells below.
config = Config(
    testing=True,  # when true, the Jigsaw reader keeps only the first 10000 rows
    seed=1,  # passed to torch.manual_seed
    batch_size=64, # This is probably too large: need to handle effective vs. machine batch size
    embed_dim=256,  # embedding width on the non-BERT path
    hidden_sz=768,  # LSTM hidden size and input width of the classifier projection
    dataset="jigsaw",  # selects the reader from the registry and the sub-directory of ../data
    n_classes=2,  # default output size of the classifier
    max_seq_len=128, # necessary to limit memory usage
    # bert_model=None,  # alternative: None selects the SingleIdTokenIndexer + LSTM path
    bert_model="bert-base-cased",
)

In [4]:
DATA_ROOT = Path("../data") / config.dataset

Set random seed manually to replicate results

In [5]:
import random

# Seed Python's and numpy's RNGs as well as torch's: seeding torch alone leaves
# any shuffling or sampling done through `random` / `numpy` unreproducible.
random.seed(config.seed)
np.random.seed(config.seed)
# Kept last so the cell still displays the torch generator.
torch.manual_seed(config.seed)

<torch._C.Generator at 0x119157e70>

# Load Data

In [6]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader, StanfordSentimentTreeBankDatasetReader

### Prepare dataset

In [7]:
import glob

In [8]:
reader_registry = {}

In [9]:
def register(name: str):
    """Return a decorator that files the decorated object in `reader_registry` under `name`.

    The decorated object is handed back untouched, so the decorator can sit on
    a class definition without altering it.
    """
    def _store(target: Callable):
        reader_registry[name] = target
        return target
    return _store

In [10]:
@register("jigsaw")
class JigsawDatasetReader(DatasetReader):
    """Dataset reader for the Jigsaw toxic-comment CSV files.

    Each row's ``comment_text`` column is tokenized into a ``TextField`` and its
    ``toxic`` column becomes a numeric ``LabelField``.

    Parameters
    ----------
    tokenizer:
        Callable mapping a raw string to a list of token strings
        (whitespace splitting by default).
    token_indexers:
        Indexers handed to every ``TextField``; defaults to a single
        ``SingleIdTokenIndexer`` under the key ``"tokens"``.
    max_seq_len:
        Maximum number of tokens kept per comment; ``None`` disables truncation.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    def text_to_instance(self, tokens: List[Token], label: Optional[int] = None) -> Instance:
        """Build an ``Instance`` from tokens and an optional integer class id.

        The ``"label"`` field is only added when ``label`` is given, so the
        method can also be used for unlabeled text.
        """
        # TODO: Reimplement
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        if label is not None:
            # Labels are already class ids, so vocabulary indexing is skipped.
            fields["label"] = LabelField(label=int(label), skip_indexing=True)

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one ``Instance`` per row of the CSV at ``file_path``.

        Only the first 10000 rows are used when ``config.testing`` is true.
        """
        df = pd.read_csv(file_path)
        if config.testing:
            df = df.head(10000)
        # Walk the two needed columns directly instead of building a Series per row.
        for text, label in zip(df["comment_text"], df["toxic"]):
            words = self.tokenizer(text)
            if self.max_seq_len is not None:
                # Enforce the configured length cap regardless of the tokenizer used.
                words = words[:self.max_seq_len]
            yield self.text_to_instance([Token(w) for w in words], label)

In [11]:
@register("imdb")
class IMDBDatasetReader(DatasetReader):
    """Dataset reader for an IMDB-style directory of review files.

    ``_read`` expects ``<file_path>/pos/*.txt`` and ``<file_path>/neg/*.txt``;
    the sub-directory name is used as the label of every file inside it.

    Parameters
    ----------
    tokenizer:
        Callable mapping a raw string to a list of token strings;
        ``None`` falls back to whitespace splitting.
    token_indexers:
        Indexers handed to every ``TextField``; defaults to a single
        ``SingleIdTokenIndexer`` under the key ``"tokens"``.
    max_seq_len:
        Maximum number of tokens kept per review; ``None`` disables truncation.
    """
    def __init__(self, tokenizer=None, 
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_seq_len=None) -> None:
        super().__init__(lazy=False)
        # A missing tokenizer would otherwise fail as soon as `_read` calls it.
        self.tokenizer = tokenizer if tokenizer is not None else str.split
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    def text_to_instance(self, tokens: List[Token], label: Optional[str] = None) -> Instance:
        """Build an ``Instance`` from tokens and an optional string label.

        The ``"label"`` field is only added when ``label`` is given.
        """
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        # TODO: Add statistical features?

        if label is not None:
            fields["label"] = LabelField(label=label)

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one ``Instance`` per ``*.txt`` file under ``pos`` and ``neg``."""
        for label in ["pos", "neg"]:
            # Sorted so the instance order does not depend on the filesystem.
            for file in sorted((Path(file_path) / label).glob("*.txt")):
                # read_text opens and closes the file, so no handle is leaked.
                text = file.read_text(encoding="utf-8")
                words = self.tokenizer(text)
                if self.max_seq_len is not None:
                    words = words[:self.max_seq_len]
                yield self.text_to_instance([Token(word) for word in words],
                                            label)

### Prepare token handlers

In [12]:
from allennlp.data.token_indexers import PretrainedBertIndexer, SingleIdTokenIndexer

if config.bert_model is None:
    # No BERT model configured: plain vocabulary ids over case-preserved tokens.
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=False)

    def tokenizer(x):
        """Split raw text on whitespace."""
        return x.split()
else:
    token_indexer = PretrainedBertIndexer(
        pretrained_model=config.bert_model,
        max_pieces=config.max_seq_len,
        # Lower-case the input only for model names containing "uncased".
        do_lowercase="uncased" in config.bert_model,
    )

    def tokenizer(s: str):
        """Word-piece tokenize `s`, keeping at most `config.max_seq_len - 2` pieces."""
        # Truncation has to happen here; the two spare slots are presumably
        # reserved for BERT's start/end markers — TODO confirm.
        pieces = token_indexer.wordpiece_tokenizer(s)
        return pieces[:config.max_seq_len - 2]

01/17/2019 17:11:41 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /Users/keitakurita/.pytorch_pretrained_bert/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [13]:
reader_cls = reader_registry[config.dataset]
reader = reader_cls(tokenizer=tokenizer,
                    token_indexers={"tokens": token_indexer})

In [14]:
# The reader registry keys are lower-case ("imdb", "jigsaw"), so match the
# dataset name case-insensitively.
if config.dataset.lower() == "imdb":
    # NOTE(review): DATA_ROOT already ends with the dataset name, so this resolves
    # to <data>/<dataset>/imdb/aclImdb — confirm that is the intended layout.
    data_dir = DATA_ROOT / "imdb" / "aclImdb"
    train_ds, test_ds = (reader.read(data_dir / fname) for fname in ["train", "test"])
    val_ds = None
elif config.dataset.lower() == "sst":
    # Fail loudly instead of leaving train_ds / test_ds / val_ds undefined
    # for the cells that follow.
    raise NotImplementedError("Loading the SST dataset is not implemented yet")  # TODO: Implement
else:
    train_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train.csv", "test_proced.csv"])
    val_ds = None

10000it [00:07, 1410.47it/s]
10000it [00:06, 1529.95it/s]


In [15]:
len(train_ds), len(test_ds)

(10000, 10000)

### Prepare vocabulary

In [16]:
if config.bert_model is None:
    # Without BERT the vocabulary is derived from the training instances.
    vocab = Vocabulary.from_instances(train_ds)
else:
    # With BERT, start from an empty Vocabulary and let the indexer fill it in
    # (presumably with its word-piece vocabulary — TODO confirm).
    vocab = Vocabulary()
    token_indexer._add_encoding_to_vocabulary(vocab)

### Prepare iterator

In [17]:
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator

In [18]:
# TODO: Allow for customization
# Batches are produced by a BucketIterator keyed on the token count of the
# "tokens" field, using the batch size from the config cell.
# TODO: Allow for customization
iterator = BucketIterator(batch_size=config.batch_size, 
                          biggest_batch_first=True,
                          sorting_keys=[("tokens", "num_tokens")],)
# The iterator needs the vocabulary before it can be called on a dataset.
iterator.index_with(vocab)

### Read sample

In [19]:
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    batch = next(iter(iterator(train_ds)))

In [20]:
batch

{'tokens': {'tokens': tensor([[  101,  3320,  1116,  ..., 24834, 28137,   102],
          [  101,   107,   156,  ...,  1474,  1122,   102],
          [  101, 15112, 14494,  ...,  1559, 28131,   102],
          ...,
          [  101,  1345,  1386,  ...,  1103,  3371,   102],
          [  101,   107,  6710,  ...,  1592,  1590,   102],
          [  101,   107,  5046,  ...,  2543,   107,   102]]),
  'tokens-offsets': tensor([[  1,   2,   3,  ..., 124, 125, 126],
          [  1,   2,   3,  ..., 124, 125, 126],
          [  1,   2,   3,  ..., 124, 125, 126],
          ...,
          [  1,   2,   3,  ..., 124, 125, 126],
          [  1,   2,   3,  ..., 124, 125, 126],
          [  1,   2,   3,  ..., 124, 125, 126]]),
  'mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]])},
 'label': tensor([0, 0, 0, 0, 0, 0, 1,

In [21]:
batch["tokens"]["tokens"]

tensor([[  101,  3320,  1116,  ..., 24834, 28137,   102],
        [  101,   107,   156,  ...,  1474,  1122,   102],
        [  101, 15112, 14494,  ...,  1559, 28131,   102],
        ...,
        [  101,  1345,  1386,  ...,  1103,  3371,   102],
        [  101,   107,  6710,  ...,  1592,  1590,   102],
        [  101,   107,  5046,  ...,  2543,   107,   102]])

In [22]:
batch["tokens"]["tokens"].shape

torch.Size([16, 128])

# Prepare Model

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim

In [24]:
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.nn.util import get_text_field_mask

In [25]:
class LstmEncoder(nn.Module):
    """Encode a padded batch into one vector per sequence with a bidirectional LSTM.

    The returned vector is the concatenation of the last layer's final forward
    and final backward hidden states, so its width is ``2 * lstm.hidden_size``.
    The wrapped ``lstm`` must be bidirectional.
    """
    def __init__(self, lstm):
        super().__init__()
        self.lstm = lstm

    def get_output_dim(self) -> int:
        """Width of the vectors returned by `forward` (both directions concatenated)."""
        return 2 * self.lstm.hidden_size

    def forward(self, x, mask=None): # TODO: replace with allennlp built in modules
        """Return a ``(batch, 2 * hidden_size)`` tensor of final hidden states.

        Parameters
        ----------
        x:
            Embedded input in the layout the wrapped LSTM expects.
        mask:
            Optional ``(batch, seq_len)`` tensor, non-zero at real tokens. When
            given, padded positions are excluded so they cannot influence the
            final states. When ``None`` the whole sequence is used.
        """
        if mask is None:
            _, (state, _) = self.lstm(x)
        else:
            batch_dim = 0 if self.lstm.batch_first else 1
            # A fully padded row would make packing fail; treat it as length 1.
            lengths = mask.long().sum(dim=1).clamp(min=1)
            # Packing requires length-sorted input; sort by hand so this also
            # works on torch versions without `enforce_sorted`.
            lengths, order = lengths.sort(descending=True)
            packed = nn.utils.rnn.pack_padded_sequence(
                x.index_select(batch_dim, order),
                lengths.cpu(),
                batch_first=self.lstm.batch_first,
            )
            _, (state, _) = self.lstm(packed)
            # Sorting a permutation yields its inverse: undo the reordering.
            _, restore = order.sort()
            state = state.index_select(1, restore)
        # The last two slices are the top layer's forward and backward states
        # (identical to indices 0 and 1 for a single-layer LSTM).
        return torch.cat([state[-2], state[-1]], dim=1)

In [26]:
class BertPooler(nn.Module):
    """Pool a sequence of hidden states into one vector per example.

    Takes the hidden state of the first token and passes it through a dense
    layer followed by tanh. The dense layer is created here, so its weights
    start freshly initialised.

    Parameters
    ----------
    hidden_size:
        Width of the incoming hidden states and of the pooled output
        (768 by default).
    """
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.hidden_size = hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def get_output_dim(self) -> int:
        """Width of the vectors returned by `forward`."""
        return self.hidden_size

    def forward(self, hidden_states, mask=None):
        """Return the pooled ``(batch, hidden_size)`` tensor.

        ``mask`` is accepted so the pooler is interchangeable with encoders
        that need one; it is not used.
        """
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

In [41]:
class SentimentAnalysisModel(Model):
    """Sequence classifier: embed tokens, encode to one vector, project to class logits.

    Parameters
    ----------
    word_embeddings:
        Module called on the token id tensor (``tokens["tokens"]``) to produce
        per-token embeddings.
    encoder:
        Module called as ``encoder(embeddings, mask)`` that returns one vector
        per example. If it exposes ``get_output_dim()`` the classifier is
        sized from it; otherwise ``config.hidden_sz`` is assumed.
    out_sz:
        Number of output classes.

    The vocabulary is taken from the notebook-level ``vocab`` variable.
    """
    def __init__(self, word_embeddings: TextFieldEmbedder,
                 encoder: nn.Module,
                 out_sz=config.n_classes):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # Prefer the width the encoder reports over a global constant, so the
        # projection cannot silently disagree with the encoder output.
        if hasattr(encoder, "get_output_dim"):
            encoder_dim = encoder.get_output_dim()
        else:
            encoder_dim = config.hidden_sz
        self.projection = nn.Linear(encoder_dim, out_sz)
        self.accuracy = CategoricalAccuracy()
        # Built once here rather than on every forward pass.
        self.loss_fn = nn.CrossEntropyLoss()
        
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Compute class logits and, when ``label`` is given, the loss.

        Returns a dict with ``"class_logits"`` and, if ``label`` is not
        ``None``, ``"loss"``. The accuracy metric is updated in that case too.
        """
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens["tokens"])
        state = self.encoder(embeddings, mask)

        class_logits = self.projection(state)
        
        output = {"class_logits": class_logits}
        if label is not None:
            self.accuracy(class_logits, label, None)
            output["loss"] = self.loss_fn(class_logits, label)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy, optionally resetting the metric."""
        return {"accuracy": self.accuracy.get_metric(reset)}

In [28]:
if config.bert_model is None:
    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                                embedding_dim=config.embed_dim)
    # SentimentAnalysisModel.forward calls the embedder on the raw id tensor
    # (tokens["tokens"]), which is what a token embedder takes; a
    # BasicTextFieldEmbedder expects the whole field dict instead.
    word_embeddings = token_embedding

    # LstmEncoder concatenates the forward and backward final states, so each
    # direction gets half of hidden_sz and the encoder output is hidden_sz wide,
    # matching the classifier's projection layer.
    encoder = LstmEncoder(
        nn.LSTM(config.embed_dim, config.hidden_sz // 2, batch_first=True, bidirectional=True)
    )

else:
    word_embeddings = PretrainedBertEmbedder(
        pretrained_model=config.bert_model,
        top_layer_only=True, # conserve memory
    )
    encoder = BertPooler()

01/17/2019 17:11:59 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz from cache at /Users/keitakurita/.pytorch_pretrained_bert/a803ce83ca27fecf74c355673c434e51c265fb8a3e0e57ac62a80e38ba98d384.681017f415dfb33ec8d0e04fe51a619f3f01532ecea04edbfd48c5d160550d9c
01/17/2019 17:11:59 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file /Users/keitakurita/.pytorch_pretrained_bert/a803ce83ca27fecf74c355673c434e51c265fb8a3e0e57ac62a80e38ba98d384.681017f415dfb33ec8d0e04fe51a619f3f01532ecea04edbfd48c5d160550d9c to temp dir /var/folders/hy/1czs1y5j2d58zgkqx6w_wnpw0000gn/T/tmpucgh649e
01/17/2019 17:12:02 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 

In [42]:
# Take the number of classes from the config rather than repeating the literal,
# so the classifier follows `config.n_classes` when it changes.
model = SentimentAnalysisModel(
    word_embeddings, 
    encoder, 
    out_sz=config.n_classes,
)

In [43]:
model

SentimentAnalysisModel(
  (word_embeddings): PretrainedBertEmbedder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (

### Basic sanity checks

In [44]:
np.isnan(list(model.word_embeddings.parameters())[0].detach().numpy()).any()

False

In [45]:
[np.isnan(x.detach().numpy()).any() for x in list(model.encoder.parameters())]

[False, False]

In [46]:
[np.isinf(x.detach().numpy()).any() for x in list(model.encoder.parameters())]

[False, False]

In [47]:
tokens = batch["tokens"]
encoder(model.word_embeddings(tokens["tokens"]), get_text_field_mask(tokens))
# encoder(model.word_embeddings(tokens["tokens"]))[1][0].size()

tensor([[-0.0755,  0.2792, -0.2031,  ..., -0.5464, -0.2699, -0.1019],
        [-0.2183,  0.3834, -0.3327,  ..., -0.4463, -0.1837,  0.0743],
        [-0.1531,  0.2794, -0.2308,  ..., -0.5165, -0.0732, -0.2113],
        ...,
        [-0.0934,  0.3107, -0.0174,  ..., -0.4465, -0.2054, -0.1657],
        [-0.1436,  0.4138, -0.1198,  ..., -0.4273,  0.0112, -0.2166],
        [-0.1005,  0.4191, -0.0668,  ..., -0.5434, -0.1772, -0.1340]],
       grad_fn=<TanhBackward>)

In [48]:
loss = model(**batch)["loss"]

In [49]:
nn.CrossEntropyLoss()(model(**batch)["class_logits"][:10, :], batch["label"][:10])

tensor(0.7642, grad_fn=<NllLossBackward>)

In [50]:
loss

tensor(0.7745, grad_fn=<NllLossBackward>)

In [51]:
loss.backward()

In [52]:
[x.grad for x in list(model.encoder.parameters())]

[tensor([[ 8.5304e-03,  3.5480e-04, -1.8205e-03,  ..., -8.0811e-03,
           5.0827e-03,  1.0404e-03],
         [-3.4866e-03, -2.8361e-05,  7.4339e-04,  ...,  3.3525e-03,
          -1.9785e-03, -3.1576e-04],
         [ 5.3810e-03,  2.0569e-04, -1.1151e-03,  ..., -5.1062e-03,
           3.2566e-03,  6.4371e-04],
         ...,
         [-4.5126e-03, -1.9369e-04,  9.5409e-04,  ...,  4.3904e-03,
          -2.7528e-03, -4.6456e-04],
         [ 1.9360e-03,  6.3666e-05, -4.0991e-04,  ..., -1.8234e-03,
           1.1646e-03,  2.4010e-04],
         [ 5.1147e-03,  1.6980e-04, -1.0655e-03,  ..., -4.8573e-03,
           3.0409e-03,  5.9746e-04]]),
 tensor([ 1.9886e-02, -8.1437e-03,  1.2440e-02, -7.9120e-03,  9.4516e-03,
          1.6293e-02, -6.7221e-03,  1.3400e-02,  3.4142e-03,  9.0865e-03,
          5.8568e-03,  9.3201e-03,  1.3623e-02, -9.1804e-03,  8.2231e-04,
          7.2423e-03, -1.4093e-02,  3.5922e-03,  8.2995e-03,  2.2107e-02,
         -1.1998e-02,  3.4935e-03,  1.4428e-02,  1.2406e-0

# Train

In [53]:
from allennlp.training.trainer import Trainer

In [54]:
# Adam over every model parameter, i.e. the embedder is trained as well.
# NOTE(review): lr=1e-3 is far above the 2e-5 to 5e-5 range usually recommended
# for fine-tuning BERT — confirm this is intended when config.bert_model is set.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [55]:
# Checkpoints are written under DATA_ROOT / "ckpts".
# NOTE(review): val_ds is None in this notebook, so there is no validation
# metric for `patience` to monitor — confirm early stopping is actually active.
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_ds,
                  validation_dataset=val_ds,
                  serialization_dir=DATA_ROOT / "ckpts",
                  patience=3,
                  num_epochs=10)

In [56]:
trainer.train()

01/17/2019 17:15:52 - INFO - allennlp.training.trainer -   Beginning training.
01/17/2019 17:15:52 - INFO - allennlp.training.trainer -   Epoch 0/9
01/17/2019 17:15:52 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1792.45056
01/17/2019 17:15:52 - INFO - allennlp.training.trainer -   Training
accuracy: 0.9168, loss: 0.2234 ||: 100%|██████████| 157/157 [35:20<00:00, 12.87s/it]
01/17/2019 17:51:13 - INFO - allennlp.training.trainer -                     Training |  Validation
01/17/2019 17:51:13 - INFO - allennlp.training.trainer -   loss          |     0.223  |       N/A
01/17/2019 17:51:13 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1792.451  |       N/A
01/17/2019 17:51:13 - INFO - allennlp.training.trainer -   accuracy      |     0.917  |       N/A
01/17/2019 17:51:13 - INFO - allennlp.training.trainer -   Epoch duration: 00:35:21
01/17/2019 17:51:13 - INFO - allennlp.training.trainer -   Estimated training time remaining: 5:18:10
01/17/2019 17:51:13 - I

01/18/2019 09:30:42 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 2293.202944
01/18/2019 09:30:42 - INFO - allennlp.training.trainer -   Training
accuracy: 0.9370, loss: 0.1591 ||: 100%|██████████| 157/157 [34:20<00:00, 11.01s/it]
01/18/2019 10:05:03 - INFO - allennlp.training.trainer -                     Training |  Validation
01/18/2019 10:05:03 - INFO - allennlp.training.trainer -   loss          |     0.159  |       N/A
01/18/2019 10:05:03 - INFO - allennlp.training.trainer -   cpu_memory_MB |  2293.203  |       N/A
01/18/2019 10:05:03 - INFO - allennlp.training.trainer -   accuracy      |     0.937  |       N/A
01/18/2019 10:05:03 - INFO - allennlp.training.trainer -   Epoch duration: 00:34:20


{'peak_cpu_memory_MB': 2293.202944,
 'training_duration': '16:49:10',
 'training_start_epoch': 0,
 'training_epochs': 9,
 'epoch': 9,
 'training_accuracy': 0.937,
 'training_loss': 0.15912552930082485,
 'training_cpu_memory_MB': 2293.202944,
 'best_epoch': 9}