# Interpretation of Natural Language Rules in Conversational Machine Reading, M. Saeidi et al., EMNLP 2018

https://arxiv.org/abs/1809.01494

## Step 1/2: Baseline CNN Classifier
using BERT embeddings (which isn't part of baseline though!)

## Problem

Inputs:
- Rules (encoded in natural language)
- History (question and answers b/w bot and user, that may include disambiguate user's first question)
- Question (user's question on rules)

Output:
- Answer labels
  - `Yes` | `No` (bot's answer)
  - `More`  (Bot needs more info to answer this question. Step 2/2 of pipeline).
  - `Irrelevant` (question is irrelevant to these rules)

In [1]:
from tqdm import tqdm_notebook as _tqdm  # required this way by allennlp in jupyter notebooks

In [2]:
import logging
logging.basicConfig(level=logging.ERROR)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import json

from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.vocabulary import Vocabulary

from allennlp.data.iterators import BucketIterator
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data import Instance, fields
from allennlp.common.file_utils import cached_path

from allennlp.models import Model
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding, PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper, CnnEncoder
from allennlp.nn.util import sequence_cross_entropy_with_logits, add_sentence_boundary_token_ids, get_text_field_mask

from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor

## Data Reader
We join History of Follow Up Question and Answers as one text field. Someone interested in creating subsets of data based on history (length, etc) could compute fields needed further down the pipeline.

In [4]:
class SharcDatasetReader(DatasetReader):
    
    def __init__(self, token_indexers, tokenizer, lazy=False):
        super().__init__(lazy)
        self.token_indexers = token_indexers
        self.tokenizer = tokenizer
        
    def item_to_instance(self, item):
        # meta fields
        attrs = {
            "utterance_id": fields.MetadataField(item["utterance_id"]),
            "tree_id": fields.MetadataField(item["tree_id"]),
            "source_url": fields.MetadataField(item["source_url"])}

        # text fields.
        attrs["snippet"] = fields.TextField(
            tokens=self.tokenizer(item["snippet"]),
            token_indexers=self.token_indexers)
        
        attrs["question"] = fields.TextField(
            tokens=self.tokenizer(item["question"]),
            token_indexers=self.token_indexers)
        
        # "START" is a fix for empty cases.
        attrs["scenario"] = fields.TextField(
            tokens=self.tokenizer(item["scenario"]),
            token_indexers=self.token_indexers)
        
        attrs["answer"] = fields.TextField(
            tokens=self.tokenizer(item["answer"]),
            token_indexers=self.token_indexers)
        
        # computed field. answer exists or not.
        # labels: yes, no, irrelevant, more
        # Note: `label_namespace` can be set to tokens..
        attrs["answer_exists"] = fields.LabelField(
            label=item["answer"] if item["answer"] in ("Yes", "No", "Irrelevant") else "More")
        
        # history, joined as one string.
        history = []
        for x in item["history"]:
            history.append(x["follow_up_question"])
            history.append(x["follow_up_answer"])
        attrs["history"] = fields.TextField(
            tokens=self.tokenizer(" ".join(history)),
            token_indexers=self.token_indexers)
            
        return Instance(attrs)
        
    def _read(self, file_path):
        with open(file_path) as fp:
            ds = json.load(fp)
            for item in ds:
                yield self.item_to_instance(item)

In [5]:
token_indexer = PretrainedBertIndexer("bert-base-uncased")

def tokenizer(s):
    return [Token(t) for t in token_indexer.wordpiece_tokenizer(s)]

reader = SharcDatasetReader(token_indexers={"tokens": token_indexer}, tokenizer=tokenizer)

train_ds = reader.read(cached_path("~/sharc/sharc1-official/json/sharc_train.json"))
test_ds = reader.read(cached_path("~/sharc/sharc1-official/json/sharc_dev.json"))

vocab = Vocabulary.from_instances(train_ds)
vocab

21890it [00:09, 2262.63it/s]
2270it [00:00, 3455.81it/s]
100%|██████████| 21890/21890 [00:00<00:00, 54891.98it/s]


Vocabulary with namespaces:  labels, Size: 4 || Non Padded Namespaces: {'*tags', '*labels'}

# Classifier

Section 5.1 describes using samples that have no `scenario` field for training the classifier.

In [6]:
# consider only instances where `scenario` is null i.e., it only has "CLS and SEP"
# refer to sec 5.1 in paper.
clf_train_ds = [ins for ins in train_ds if len(ins.get("scenario").tokens) == 0]
clf_test_ds = [ins for ins in test_ds if len(ins.get("scenario").tokens) == 0]

len(clf_train_ds), len(clf_test_ds)

(4025, 431)

In [7]:
token_embedding = PretrainedBertEmbedder("bert-base-uncased", requires_grad=False)

padding_order = [
    ("snippet", "num_tokens"),
    ("scenario", "num_tokens"),
    ("question", "num_tokens"),
]

iterator = BucketIterator(
    batch_size=2,
    sorting_keys=padding_order)

iterator.index_with(vocab)

In [8]:
def join_fields(batch, features):
    joined = {"tokens": [], "mask": []}
    for f in features:
        mask = batch[f]["mask"]
        tokens = batch[f]["tokens"]
        # Note: default mask does not inlcude SOS/CLS and EOS/SEP tokens in mask. weird.
        if mask.shape != tokens.shape:
            del batch[f]["mask"]
            mask = batch[f]["mask"] = get_text_field_mask(batch[f])
        joined["tokens"].append(tokens)
        joined["mask"].append(mask)
    joined["tokens"] = torch.cat(joined["tokens"], dim=-1)
    joined["mask"] = torch.cat(joined["mask"], dim=-1)
    return joined

class YesNoClassifier(Model):
    
    def __init__(self, embedding, vocab):
        super().__init__(vocab)
        hidden_size = 100
        self.embedding = embedding
        self.encoder = CnnEncoder(embedding.get_output_dim(), hidden_size)
        self.final = torch.nn.Linear(self.encoder.get_output_dim(), vocab.get_vocab_size("labels"))
        self.accuracy = CategoricalAccuracy()
        
    def forward(self, *inputs, **kw):
        X = join_fields(kw, ["question", "snippet", "history"])
        out = self.embedding(X["tokens"])
        out = self.encoder(out, X["mask"])
        out = self.final(out)
        target = "answer_exists"
        response = {}
        if target in kw:
            self.accuracy(out, kw[target])
            response["loss"] = F.cross_entropy(out, kw[target])
        return response
    
    def get_metrics(self, reset=False):
        return {"accuracy": self.accuracy.get_metric(reset)}

    
model = YesNoClassifier(token_embedding, vocab).to("cuda")

optimizer = optim.Adam(model.parameters(), lr=0.1)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_dataset=clf_train_ds,
    validation_dataset=clf_test_ds,
    iterator=iterator,
    num_epochs=100,
    cuda_device=1,
)

trainer.train()

accuracy: 0.4489, loss: 153.5865 ||: 100%|██████████| 2013/2013 [01:01<00:00, 33.42it/s]
accuracy: 0.4339, loss: 989.1312 ||: 100%|██████████| 216/216 [00:04<00:00, 45.83it/s]
accuracy: 0.4557, loss: 376.2320 ||: 100%|██████████| 2013/2013 [00:58<00:00, 33.92it/s]
accuracy: 0.4339, loss: 2.7709 ||: 100%|██████████| 216/216 [00:04<00:00, 45.15it/s]
accuracy: 0.4492, loss: 2.4610 ||: 100%|██████████| 2013/2013 [00:59<00:00, 34.35it/s]
accuracy: 0.5197, loss: 1.0512 ||: 100%|██████████| 216/216 [00:04<00:00, 45.51it/s]
accuracy: 0.4778, loss: 0.9944 ||: 100%|██████████| 2013/2013 [00:58<00:00, 35.63it/s]
accuracy: 0.5383, loss: 1.0406 ||: 100%|██████████| 216/216 [00:04<00:00, 45.58it/s]
accuracy: 0.4333, loss: 451.5980 ||: 100%|██████████| 2013/2013 [01:00<00:00, 32.80it/s]
accuracy: 0.5406, loss: 0.8793 ||: 100%|██████████| 216/216 [00:04<00:00, 43.78it/s]
accuracy: 0.3752, loss: 1.7987 ||: 100%|██████████| 2013/2013 [01:02<00:00, 30.46it/s]
accuracy: 0.3202, loss: 1.4112 ||: 100%|█████