In [19]:
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd
from functools import partial
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField, LabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

TODOs:

- Make compatible with GPU
- Try replicating SST results

In [2]:
class Config:
    """Lightweight settings holder: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        # Assign through setattr so each entry behaves exactly like a plain
        # `obj.name = value` statement.
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [3]:
# Central experiment settings; every keyword below becomes an attribute of `config`.
config = Config(
    testing=True,  # when true, JigsawDatasetReader keeps only the first 10,000 CSV rows
    seed=1,  # fed to torch.manual_seed
    batch_size=64, # This is probably too large: need to handle effective vs. machine batch size
    embed_dim=256,  # embedding width on the non-BERT path
    hidden_sz=768,  # LSTM hidden size and projection input width; the BERT pooler is 768 wide
    dataset="sst",  # key into reader_registry and sub-directory of ../data
    n_classes=2,  # SST granularity ("2-class") and number of output logits
    max_seq_len=128, # necessary to limit memory usage
    # Set bert_model=None to use the plain embedding + LSTM path instead of BERT.
    bert_model="bert-base-cased",
)

In [4]:
DATA_ROOT = Path("../data") / config.dataset

Set random seed manually to replicate results

In [5]:
import random

# Seed every RNG the pipeline can touch, not only torch's: Python's `random`
# (presumably what AllenNLP uses when shuffling instances — TODO confirm) and
# NumPy would otherwise differ from run to run.
random.seed(config.seed)
np.random.seed(config.seed)
# Kept last so the cell still displays the torch Generator, as before.
torch.manual_seed(config.seed)

<torch._C.Generator at 0x11a3b1e50>

# Load Data

In [6]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader, StanfordSentimentTreeBankDatasetReader

### Prepare dataset

In [7]:
import glob

In [8]:
# Maps a dataset name (e.g. "sst") to the reader class registered for it via @register.
# Re-running this cell empties the registry, so the reader-definition cells must be re-run too.
reader_registry: Dict[str, Callable] = {}

In [9]:
def register(name: str):
    """Return a decorator that files its target in ``reader_registry`` under ``name``.

    The decorated object is handed back untouched, so it stays usable under
    its own name. Registering the same ``name`` again overwrites the earlier
    entry.
    """
    def _remember(target: Callable):
        reader_registry[name] = target
        return target
    return _remember

In [20]:
@register("jigsaw")
class JigsawDatasetReader(DatasetReader):
    """Reads a Jigsaw CSV (columns ``comment_text`` and, optionally, ``toxic``).

    Args:
        tokenizer: callable turning a raw string into a list of token strings
            (defaults to whitespace splitting).
        token_indexers: AllenNLP token indexers; defaults to a single-id
            indexer under the key ``"tokens"``.
        max_seq_len: if not ``None``, every instance is truncated to at most
            this many tokens.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token], label: Optional[int] = None) -> Instance:
        """Build an instance with a ``tokens`` field and, when given, a ``label`` field.

        ``label`` is used as-is as the class index (``skip_indexing=True``).
        It is omitted when ``None`` so unlabeled text can be converted too.
        """
        # TODO: Reimplement
        # Apply the length limit that was previously accepted but never used.
        if self.max_seq_len is not None:
            tokens = tokens[:self.max_seq_len]
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        if label is not None:
            # int() guards against NumPy integer scalars coming out of pandas.
            fields["label"] = LabelField(label=int(label), skip_indexing=True)

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per CSV row (only the first 10,000 rows when ``config.testing``)."""
        df = pd.read_csv(file_path)
        if config.testing:
            df = df.head(10000)
        # Unlabeled files (no "toxic" column) produce label-free instances
        # instead of failing with a KeyError.
        has_labels = "toxic" in df.columns
        for _, row in df.iterrows():
            yield self.text_to_instance(
                [Token(x) for x in self.tokenizer(row["comment_text"])],
                row["toxic"] if has_labels else None
            )

In [21]:
@register("imdb")
class IMDBDatasetReader(DatasetReader):
    """Reads IMDB reviews laid out as ``<dir>/pos/*.txt`` and ``<dir>/neg/*.txt``.

    Args:
        tokenizer: callable turning a raw string into a list of token strings;
            whitespace splitting is used when omitted.
        token_indexers: AllenNLP token indexers; defaults to a single-id
            indexer under the key ``"tokens"``.
        max_seq_len: if not ``None``, every instance is truncated to at most
            this many tokens.
    """
    def __init__(self, tokenizer=None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_seq_len=None) -> None:
        super().__init__(lazy=False)
        # The default used to be None, which _read then tried to call.
        self.tokenizer = tokenizer or (lambda text: text.split())
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token], label: str = None) -> Instance:
        """Build an instance with a ``tokens`` field and, when given, a ``label`` field."""
        if self.max_seq_len is not None:
            tokens = tokens[:self.max_seq_len]
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        # TODO: Add statistical features?

        # Leave the label out for unlabeled text rather than passing None on.
        if label is not None:
            fields["label"] = LabelField(label=label)

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per review file; the sub-directory name is the label."""
        for label in ["pos", "neg"]:
            # sorted(): glob order is filesystem-dependent, which would make
            # runs irreproducible despite the fixed seed.
            for file in sorted((Path(file_path) / label).glob("*.txt")):
                # read_text closes the handle; the old open().read() leaked it.
                text = file.read_text(encoding="utf-8")
                yield self.text_to_instance([Token(word) for word in self.tokenizer(text)],
                                            label)

In [67]:
@register("sst")
class SSTDatasetReader(StanfordSentimentTreeBankDatasetReader):
    """SST reader that can re-tokenize each sentence with a custom tokenizer.

    Args:
        tokenizer: optional callable mapping a string to a list of token
            strings. When omitted, the tokens from the tree file are used as-is.
        *args, **kwargs: forwarded to the parent reader. ``granularity``
            defaults to ``"<config.n_classes>-class"`` but may be overridden.
    """
    def __init__(self, *args, tokenizer=None, **kwargs):
        # setdefault: a caller-supplied granularity no longer collides with
        # the one derived from the config.
        kwargs.setdefault("granularity", f"{config.n_classes}-class")
        super().__init__(*args, **kwargs)
        self._tokenizer = tokenizer

    @overrides
    def text_to_instance(self, tokens: List[str], sentiment: str=None) -> Instance:
        """
        Forcibly re-tokenize the input (e.g. into wordpieces) when a tokenizer
        was supplied, then defer to the parent implementation.
        """
        # With the default tokenizer=None this used to fail on calling None.
        if self._tokenizer is not None:
            tokens = self._tokenizer(" ".join(tokens))
        return super().text_to_instance(tokens, sentiment=sentiment)

### Prepare token handlers

In [23]:
from allennlp.data.token_indexers import PretrainedBertIndexer, SingleIdTokenIndexer
if config.bert_model is not None:
    # Lowercase only for "uncased" checkpoints.
    token_indexer = PretrainedBertIndexer(
        pretrained_model=config.bert_model,
        max_pieces=config.max_seq_len,
        do_lowercase=("uncased" in config.bert_model),
    )

    def tokenizer(s: str):
        """Wordpiece-tokenize ``s`` and cut the result down to the piece budget."""
        # The sequence has to be truncated here, before indexing. Two slots
        # are held back — presumably for the [CLS]/[SEP] pieces the indexer
        # adds (the sample batch starts with 101 and ends with 102) — TODO confirm.
        budget = config.max_seq_len - 2
        return token_indexer.wordpiece_tokenizer(s)[:budget]
else:
    # don't lowercase by default
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=False)

    def tokenizer(x):
        """Plain whitespace tokenizer for the non-BERT path."""
        return x.split()

01/18/2019 10:10:59 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /Users/keitakurita/.pytorch_pretrained_bert/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [68]:
# Pick the reader class registered for the configured dataset and give it the
# tokenizer / indexer pair chosen in the token-handler cell.
reader_cls = reader_registry[config.dataset]
reader = reader_cls(tokenizer=tokenizer,
                    token_indexers={"tokens": token_indexer})

In [69]:
# Read the splits for the configured dataset; val_ds is None when no dev split exists.
if config.dataset == "imdb":
    # NOTE(review): DATA_ROOT already ends in config.dataset, so this resolves
    # to ../data/imdb/imdb/aclImdb — confirm that is the intended layout.
    data_dir = DATA_ROOT / "imdb" / "aclImdb"
    train_ds = reader.read(data_dir / "train")
    test_ds = reader.read(data_dir / "test")
    val_ds = None
elif config.dataset == "sst":
    data_dir = DATA_ROOT / "trees"
    train_ds = reader.read(data_dir / "train.txt")
    val_ds = reader.read(data_dir / "dev.txt")
    test_ds = reader.read(data_dir / "test.txt")
else:
    train_ds = reader.read(DATA_ROOT / "train.csv")
    test_ds = reader.read(DATA_ROOT / "test_proced.csv")
    val_ds = None

0it [00:00, ?it/s]01/18/2019 10:29:05 - INFO - allennlp.data.dataset_readers.stanford_sentiment_tree_bank -   Reading instances from lines in file at: ../data/sst/trees/train.txt
6920it [00:02, 2549.30it/s]
0it [00:00, ?it/s]01/18/2019 10:29:07 - INFO - allennlp.data.dataset_readers.stanford_sentiment_tree_bank -   Reading instances from lines in file at: ../data/sst/trees/dev.txt
872it [00:00, 2716.32it/s]
0it [00:00, ?it/s]01/18/2019 10:29:08 - INFO - allennlp.data.dataset_readers.stanford_sentiment_tree_bank -   Reading instances from lines in file at: ../data/sst/trees/test.txt
1821it [00:00, 2073.79it/s]


In [70]:
len(train_ds), len(test_ds)

(6920, 1821)

### Prepare vocabulary

In [71]:
# Build the vocabulary from the training instances only.
vocab = Vocabulary.from_instances(train_ds)
if config.bert_model is not None: 
    # Presumably copies the BERT wordpiece vocabulary into `vocab` — TODO confirm.
    # NOTE(review): this is a private (underscore-prefixed) indexer method and may
    # change between AllenNLP releases.
    token_indexer._add_encoding_to_vocabulary(vocab)

01/18/2019 10:29:10 - INFO - allennlp.data.vocabulary -   Fitting token dictionary from dataset.
100%|██████████| 6920/6920 [00:00<00:00, 144745.86it/s]


### Prepare iterator

In [72]:
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator

In [73]:
# TODO: Allow for customization
# Bucket instances by the "tokens" field's num_tokens so that batches hold
# similarly long sequences (presumably to limit padding — TODO confirm).
iterator = BucketIterator(batch_size=config.batch_size, 
                          biggest_batch_first=True,
                          sorting_keys=[("tokens", "num_tokens")],)
# The iterator needs the vocabulary to turn instances into index tensors.
iterator.index_with(vocab)

### Read sample

In [74]:
import warnings
# Pull a single batch for inspection; warnings raised while batching are
# silenced only inside this block.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    batch = next(iter(iterator(train_ds)))

In [75]:
batch

{'tokens': {'tokens': tensor([[  101,  1249,  1342,  1193,  1112,  1103,  2523,  4642,  1106,  1294,
            2305,  1104,  1157,  1641,  1959,   117,  1175,  2606,   170,  3321,
            7275,  1206,  1103,  1273,   112,  1116, 19857,   117,  4044, 28137,
           12734, 10136, 18208,  1200,   118, 26161,  2064, 28137,  6983, 16513,
            2511,   118,  2069, 22672, 28137,  1105, 20497,  6696,  2944,  4096,
            1115,  1185,  2971,  1104, 21304, 18977, 15604, 21155, 12805,  5389,
            6185,  1169,  2738,   119,   102,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0],
          [  101,  1332,  1122,   112,  1116,  1136,  2095, 13621,  1107, 16358,
           26654,  1348,  1143,  2858,  7412,  1918,   117,   169, 28152,  5230,
            2453,  4373,   140,  2149,  5710,   112, 28131,  1110,   170,  4105,
             117,  7345,   117,  1105, 24815,  37

In [76]:
batch["tokens"]["tokens"]

tensor([[  101,  1249,  1342,  1193,  1112,  1103,  2523,  4642,  1106,  1294,
          2305,  1104,  1157,  1641,  1959,   117,  1175,  2606,   170,  3321,
          7275,  1206,  1103,  1273,   112,  1116, 19857,   117,  4044, 28137,
         12734, 10136, 18208,  1200,   118, 26161,  2064, 28137,  6983, 16513,
          2511,   118,  2069, 22672, 28137,  1105, 20497,  6696,  2944,  4096,
          1115,  1185,  2971,  1104, 21304, 18977, 15604, 21155, 12805,  5389,
          6185,  1169,  2738,   119,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  1332,  1122,   112,  1116,  1136,  2095, 13621,  1107, 16358,
         26654,  1348,  1143,  2858,  7412,  1918,   117,   169, 28152,  5230,
          2453,  4373,   140,  2149,  5710,   112, 28131,  1110,   170,  4105,
           117,  7345,   117,  1105, 24815,  3789, 28137,  7412,  1918,  1164,
           170

In [77]:
batch["tokens"]["tokens"].shape

torch.Size([8, 84])

# Prepare Model

In [78]:
import torch
import torch.nn as nn
import torch.optim as optim

In [79]:
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.nn.util import get_text_field_mask

In [80]:
class LstmEncoder(nn.Module):
    """Collapse a padded batch of sequences into one vector per sequence.

    The vector is the LSTM's final hidden state of the last layer (forward and
    backward states concatenated when the LSTM is bidirectional).

    Args:
        lstm: an ``nn.LSTM`` instance.
    """
    def __init__(self, lstm):
        super().__init__()
        self.lstm = lstm

    def get_output_dim(self) -> int:
        """Width of the vector returned by ``forward``."""
        directions = 2 if self.lstm.bidirectional else 1
        return self.lstm.hidden_size * directions

    def forward(self, x, mask=None): # TODO: replace with allennlp built in modules
        """Encode ``x``, ignoring the positions that ``mask`` marks as padding.

        Args:
            x: input embeddings, laid out according to ``lstm.batch_first``.
            mask: optional ``(batch, seq_len)`` tensor, non-zero at real
                tokens. Assumes padding sits at the end of each row — TODO
                confirm for non-AllenNLP callers. With ``None`` the whole
                sequence is consumed.

        Returns:
            Tensor of shape ``(batch, get_output_dim())``.
        """
        if mask is None:
            _, (state, _) = self.lstm(x)
        else:
            # Pack the batch so the recurrence stops at each sequence's true
            # length; otherwise the forward direction's final state is
            # computed after stepping through padding.
            batch_dim = 0 if self.lstm.batch_first else 1
            lengths = mask.long().sum(dim=1).clamp(min=1)  # packing rejects length 0
            sorted_lengths, order = lengths.sort(descending=True)
            packed = nn.utils.rnn.pack_padded_sequence(
                x.index_select(batch_dim, order),
                sorted_lengths.cpu(),
                batch_first=self.lstm.batch_first,
            )
            _, (state, _) = self.lstm(packed)
            # Undo the length sort so row i of the output matches row i of x.
            _, restore = order.sort()
            state = state.index_select(1, restore)

        # state is (num_layers * num_directions, batch, hidden); the last
        # entries belong to the top layer.
        if self.lstm.bidirectional:
            return torch.cat([state[-2], state[-1]], dim=1)
        return state[-1]

In [81]:
class BertPooler(nn.Module):
    """Pool a sequence of hidden states into one vector via the first token.

    The first position's hidden state goes through a dense layer and a tanh.
    The source was copied from elsewhere (presumably the BERT reference
    pooler — TODO confirm); the dense layer is a freshly created
    ``nn.Linear``.

    Args:
        hidden_size: width of the incoming hidden states and of the pooled
            output. Defaults to 768, the value that used to be hard-coded.
    """
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def get_output_dim(self) -> int:
        """Width of the vector returned by ``forward``."""
        return self.dense.out_features

    def forward(self, hidden_states, mask=None):
        """Return ``tanh(dense(hidden_states[:, 0]))``.

        ``mask`` is accepted only so the call signature matches the other
        encoder; it is not used.
        """
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        return self.activation(self.dense(first_token_tensor))

In [82]:
class SentimentAnalysisModel(Model):
    """Sentence classifier: embed tokens, pool to one vector, project to class logits.

    Args:
        word_embeddings: either an AllenNLP ``TextFieldEmbedder`` (called with
            the whole ``tokens`` dict) or any other module that takes the
            ``tokens["tokens"]`` id tensor directly (e.g. a BERT embedder).
        encoder: module called as ``encoder(embeddings, mask)`` and returning
            one vector per sequence.
        out_sz: number of output classes.

    The notebook-level ``vocab`` is passed to the AllenNLP ``Model`` base class.
    """
    def __init__(self, word_embeddings: nn.Module,
                 encoder: nn.Module,
                 out_sz=config.n_classes):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # Size the projection from the encoder when it reports its own output
        # width; otherwise keep the previous assumption of config.hidden_sz.
        get_output_dim = getattr(encoder, "get_output_dim", None)
        encoder_dim = get_output_dim() if callable(get_output_dim) else config.hidden_sz
        self.projection = nn.Linear(encoder_dim, out_sz)
        self.accuracy = CategoricalAccuracy()
        # Built once instead of on every forward pass.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Compute class logits and, when ``label`` is given, the loss.

        Returns:
            Dict with ``"class_logits"`` and, if ``label`` is not ``None``,
            ``"loss"``. Accuracy is accumulated in ``self.accuracy`` and read
            back through ``get_metrics``.
        """
        mask = get_text_field_mask(tokens)
        # A TextFieldEmbedder wants the full indexer-output dict; a bare token
        # embedder wants only the id tensor.
        if isinstance(self.word_embeddings, TextFieldEmbedder):
            embeddings = self.word_embeddings(tokens)
        else:
            embeddings = self.word_embeddings(tokens["tokens"])
        state = self.encoder(embeddings, mask)

        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        if label is not None:
            # The metric call only updates running counts and returns nothing,
            # so its result is no longer stored in the output (it was always None).
            self.accuracy(class_logits, label)
            output["loss"] = self.loss_fn(class_logits, label)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy, optionally resetting the counter."""
        return {"accuracy": self.accuracy.get_metric(reset)}

In [55]:
# Choose the embedder / sentence-encoder pair that matches the configured path.
if config.bert_model is not None:
    word_embeddings = PretrainedBertEmbedder(
        pretrained_model=config.bert_model,
        top_layer_only=True, # conserve memory
    )
    encoder = BertPooler()
else:
    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                                embedding_dim=config.embed_dim)
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

    # Alternative kept for reference: wrap the same nn.LSTM in AllenNLP's
    # PytorchSeq2VecWrapper instead of the hand-written LstmEncoder.
    # NOTE(review): the LSTM is bidirectional, so the concatenated final state
    # is 2 * config.hidden_sz wide — make sure the projection layer agrees.
    encoder = LstmEncoder(
        nn.LSTM(config.embed_dim, config.hidden_sz, batch_first=True, bidirectional=True)
    )

01/18/2019 10:26:02 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz from cache at /Users/keitakurita/.pytorch_pretrained_bert/a803ce83ca27fecf74c355673c434e51c265fb8a3e0e57ac62a80e38ba98d384.681017f415dfb33ec8d0e04fe51a619f3f01532ecea04edbfd48c5d160550d9c
01/18/2019 10:26:02 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file /Users/keitakurita/.pytorch_pretrained_bert/a803ce83ca27fecf74c355673c434e51c265fb8a3e0e57ac62a80e38ba98d384.681017f415dfb33ec8d0e04fe51a619f3f01532ecea04edbfd48c5d160550d9c to temp dir /var/folders/hy/1czs1y5j2d58zgkqx6w_wnpw0000gn/T/tmp_0utw264
01/18/2019 10:26:05 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 

In [64]:
model = SentimentAnalysisModel(
    word_embeddings, 
    encoder, 
    out_sz=config.n_classes,
)

In [65]:
model

SentimentAnalysisModel(
  (word_embeddings): PretrainedBertEmbedder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (

### Basic sanity checks

In [83]:
# Any NaNs in the first embedder parameter? .cpu() keeps the check working
# when the model has been moved to a GPU.
np.isnan(list(model.word_embeddings.parameters())[0].detach().cpu().numpy()).any()

False

In [84]:
# One NaN flag per encoder parameter; .cpu() makes this GPU-safe.
[np.isnan(x.detach().cpu().numpy()).any() for x in model.encoder.parameters()]

[False, False]

In [85]:
# One infinity flag per encoder parameter; .cpu() makes this GPU-safe.
[np.isinf(x.detach().cpu().numpy()).any() for x in model.encoder.parameters()]

[False, False]

In [86]:
# Run the embedder and the pooling encoder by hand on the sample batch to
# eyeball the pooled representation.
tokens = batch["tokens"]
encoder(model.word_embeddings(tokens["tokens"]), get_text_field_mask(tokens))
# encoder(model.word_embeddings(tokens["tokens"]))[1][0].size()

tensor([[-0.0811,  0.2274, -0.0964,  ..., -0.6108, -0.4058, -0.0385],
        [-0.2271,  0.3030, -0.2617,  ..., -0.6076, -0.2934,  0.0424],
        [-0.1593,  0.2987, -0.2722,  ..., -0.6099, -0.3207, -0.0030],
        ...,
        [-0.1332,  0.4748, -0.2816,  ..., -0.7125, -0.2877, -0.0654],
        [-0.1009,  0.2775, -0.1097,  ..., -0.5298, -0.2518,  0.0039],
        [-0.0431,  0.5821, -0.4042,  ..., -0.4234, -0.2138,  0.1691]],
       grad_fn=<TanhBackward>)

In [87]:
loss = model(**batch)["loss"]

In [88]:
# Recompute the loss by hand on (up to) the first 10 examples as a cross-check.
# NOTE(review): this is a second forward pass; the model has not been put in
# eval mode, so dropout presumably makes it differ slightly from `loss` —
# confirm. Each labelled call also updates the running accuracy metric.
nn.CrossEntropyLoss()(model(**batch)["class_logits"][:10, :], batch["label"][:10])

tensor(0.6900, grad_fn=<NllLossBackward>)

In [89]:
loss

tensor(0.6905, grad_fn=<NllLossBackward>)

In [90]:
loss.backward()

In [91]:
[x.grad for x in list(model.encoder.parameters())]

[tensor([[-2.8748e-04,  1.4027e-04,  2.7458e-04,  ...,  3.5053e-04,
           7.0984e-08, -1.5276e-04],
         [ 2.1549e-03, -1.2364e-03, -2.3016e-03,  ..., -2.7667e-03,
           9.1618e-05,  1.2072e-03],
         [-2.7512e-03,  1.2500e-03,  2.6125e-03,  ...,  3.2471e-03,
          -3.3257e-05, -1.3607e-03],
         ...,
         [ 8.0093e-04, -2.2354e-04, -5.5953e-04,  ..., -6.8971e-04,
           1.0722e-04,  2.7345e-04],
         [ 1.2561e-03, -6.4963e-04, -1.2403e-03,  ..., -1.5432e-03,
           2.7694e-05,  6.6736e-04],
         [ 2.1756e-03, -1.0390e-03, -2.0930e-03,  ..., -2.5987e-03,
           6.2243e-05,  1.1039e-03]]),
 tensor([-8.7704e-04,  6.5429e-03, -8.0143e-03, -2.4327e-03, -1.6460e-03,
          9.6156e-04,  8.4481e-03, -2.2472e-03, -1.1686e-03, -4.3503e-03,
         -2.8182e-03, -4.1889e-03, -2.7003e-03,  3.6320e-03,  1.2600e-03,
          5.2557e-03, -9.4235e-04, -6.7970e-05,  2.4847e-03,  2.6885e-03,
         -1.8087e-03, -6.4176e-04,  5.3326e-04,  1.7899e-0

# Train

In [92]:
from allennlp.training.trainer import Trainer

In [93]:
# NOTE(review): lr=1e-3 is Adam's usual default; BERT fine-tuning is commonly
# run with a much smaller rate (roughly 2e-5 to 5e-5) — worth confirming for
# this setup.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Extra keyword arguments unpacked into the Trainer constructor.
training_options = {
    # TODO: Add appropriate learning rate scheduler
    "should_log_parameter_statistics": True,
    "should_log_learning_rate": True,
    # NOTE(review): patience presumably drives early stopping on the
    # validation metric; val_ds is None for the imdb/jigsaw datasets — confirm
    # the Trainer copes with that.
    "patience": 3,
    "num_epochs": 10,
}

In [94]:
# Wire model, optimizer, iterator and datasets together. No device argument is
# passed, so training presumably runs on the CPU (see the "Make compatible
# with GPU" TODO at the top of the notebook).
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    # Output directory handed to the Trainer, under the dataset's data directory.
    serialization_dir=DATA_ROOT / "ckpts",
    **training_options,
)

In [None]:
trainer.train()

01/18/2019 10:31:08 - INFO - allennlp.training.trainer -   Beginning training.
01/18/2019 10:31:08 - INFO - allennlp.training.trainer -   Epoch 0/9
01/18/2019 10:31:08 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1521.72544
01/18/2019 10:31:08 - INFO - allennlp.training.trainer -   Training
accuracy: 0.7148, loss: 0.5558 ||: 100%|██████████| 109/109 [08:55<00:00,  4.50s/it]
01/18/2019 10:40:04 - INFO - allennlp.training.trainer -   Validating
accuracy: 0.7890, loss: 0.4587 ||: 100%|██████████| 14/14 [00:56<00:00,  5.05s/it]
01/18/2019 10:41:00 - INFO - allennlp.training.trainer -                     Training |  Validation
01/18/2019 10:41:00 - INFO - allennlp.training.trainer -   loss          |     0.556  |     0.459
01/18/2019 10:41:00 - INFO - allennlp.training.trainer -   accuracy      |     0.715  |     0.789
01/18/2019 10:41:00 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1521.725  |       N/A
01/18/2019 10:41:00 - INFO - allennlp.training.trainer -