In [None]:
# default_exp model

In [None]:
# export
import logging
logging.disable(logging.CRITICAL)
import torch
from transformers import DistilBertModel, DistilBertForSequenceClassification, DistilBertConfig
from emotion_transformer.dataloader import dataloader

# Model

> construction of the DoubleDistilBert model for the SemEval-2019 Task 3 dataset (contextual emotion detection in text)

## Transformer Sentence Embeddings

First we create sentence embeddings for each utterance. We use a pretrained DistilBert model to obtain contextual word embeddings and then concatenate the CLS token embedding and the mean of the last-layer token embeddings. Note that in order to feed batches into our model we need to temporarily flatten our `input_ids`, i.e. we get three times as many input sentences as the specified `batch_size`.

For more information on the (Distil)Bert models one can look at 
Jay Alammar's blog posts ([A Visual Guide to Using BERT for the First Time](https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/) and [The Illustrated BERT, ELMo, and co.](https://jalammar.github.io/illustrated-bert/)), from which the following illustration is also taken.

![DistilBert output](./images/bert-distilbert-output-tensor-predictions.png)

Further references:
 
* [DistilBert paper](https://arxiv.org/abs/1910.01108) and [blog post](https://medium.com/huggingface/distilbert-8cf3380435b5)
* [Original Bert (Bidirectional Encoder Representations from Transformers) paper](https://arxiv.org/abs/1810.04805)
* [tutorial for custom PyTorch modules](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html)
* [Huggingface transformers documentation](https://huggingface.co/transformers/v2.3.0/index.html)

In [None]:
# export
class sentence_embeds_model(torch.nn.Module):
    """
    Sentence-embedding model built on the pretrained `distilbert-base-uncased` DistilBert.

    The embedding of an utterance is the concatenation of its CLS-token embedding and the
    mean of its last-layer token embeddings (over non-padding tokens when an attention
    mask is given), hence `embedding_size = 2 * hidden_size`.
    """

    def __init__(self, dropout = 0.1):
        super(sentence_embeds_model, self).__init__()

        self.transformer = DistilBertModel.from_pretrained('distilbert-base-uncased', dropout=dropout,
                                                           output_hidden_states=True)
        # CLS embedding and mean-pooled embedding are concatenated
        self.embedding_size = 2 * self.transformer.config.hidden_size

    def layerwise_lr(self, lr, decay):
        """
        Returns optimizer parameter groups with a layer-wise decaying learning rate.

        The top transformer layer gets `lr`; every layer below it gets one more factor of
        `decay`; the embedding layer (below layer 0) gets `lr * decay**n_layers`.
        Each group's 'params' is a list, so the groups can be iterated more than once.
        """
        bert = self.transformer
        num_layers = bert.config.n_layers
        opt_parameters = [{'params': list(bert.embeddings.parameters()),
                           'lr': lr * decay ** num_layers}]
        # layer l (0 = lowest) sits num_layers - 1 - l layers below the top layer
        opt_parameters += [{'params': list(bert.transformer.layer[l].parameters()),
                            'lr': lr * decay ** (num_layers - 1 - l)}
                           for l in range(num_layers)]
        return opt_parameters

    def forward(self, input_ids = None, attention_mask = None, input_embeds = None):
        """
        Returns the sentence embeddings, reshaped to (-1, 3, embedding_size).

        `input_ids` and `attention_mask` are flattened over their first two dimensions
        before being fed to DistilBert. `input_embeds` is passed through unchanged, so it
        must already be flattened. When `attention_mask` is given, padding tokens are
        excluded from the mean pooling.
        """
        if input_ids is not None:
            input_ids = input_ids.flatten(end_dim = 1)
        if attention_mask is not None:
            attention_mask = attention_mask.flatten(end_dim = 1)
        output = self.transformer(input_ids = input_ids,
                                  attention_mask = attention_mask, inputs_embeds = input_embeds)

        last_hidden = output[0]
        cls = last_hidden[:, 0]
        if attention_mask is None:
            hidden_mean = last_hidden.mean(1)
        else:
            mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            # clamp guards against division by zero for an all-padding row
            hidden_mean = (last_hidden * mask).sum(1) / mask.sum(1).clamp(min = 1)
        sentence_embeds = torch.cat([cls, hidden_mean], dim = -1)

        return sentence_embeds.view(-1, 3, self.embedding_size)

To illustrate the model let us import our dataloader.

In [None]:
# Small illustrative batch: 5 conversations, each utterance padded/truncated to 10 tokens.
path = 'data/clean_train.txt'
batch_size = 5
max_seq_len = 10
# label name -> class index; index 0 is the `others` class
emo_dict = {'others': 0, 'sad': 1, 'angry': 2, 'happy': 3}
loader = dataloader(path, max_seq_len, batch_size, emo_dict)
# take only the first batch for the demonstrations below
input_ids, attention_mask, labels = next(iter(loader))

The DistilBert model outputs 

* 768-dimensional embeddings for each of the `max_seq_len` tokens of each of the three utterances of the `batch_size` conversations, and
* the hidden states of the embedding layer and of each of the 6 DistilBert transformer layers (7 in total)

In [None]:
embeds_model = sentence_embeds_model()

# Run DistilBert directly on the flattened batch (conversations x utterances merged into one batch dim).
last_layer, hidden_states = embeds_model.transformer(input_ids = input_ids.flatten(end_dim = 1), attention_mask = attention_mask.flatten(end_dim = 1))
# Output of the embedding layer alone, to compare with the first hidden state below.
input_embeds = embeds_model.transformer.embeddings(input_ids.flatten(end_dim = 1))

# NOTE(review): exact equality presumably requires dropout to be inactive, although eval() is only
# called explicitly in a later cell -- confirm that the pretrained model starts in eval mode.
assert torch.all(hidden_states[0] == input_embeds)
assert torch.all(hidden_states[-1] == last_layer)

len(hidden_states), last_layer.shape

(7, torch.Size([15, 10, 768]))

Let us now create sentence embeddings (we put the model in evaluation mode to deactivate dropout for later consistency checks). Note that the forward method of the model reshapes the output again back to the shape of the corresponding `input_ids`.

In [None]:
embeds_model.eval()
# eval() must switch every submodule (including all dropout layers) out of training mode
assert not any(module.training for module in embeds_model.modules())

sentence_embeds = embeds_model(input_ids = input_ids, attention_mask = attention_mask)

In [None]:
# one embedding per utterance: the (conversation, utterance) dims match `input_ids`
assert input_ids.shape[:2] == sentence_embeds.shape[:2]
# each embedding is CLS + mean pooled, i.e. `embedding_size` wide
assert sentence_embeds.shape[-1] == embeds_model.embedding_size
input_ids.shape, sentence_embeds.shape

(torch.Size([5, 3, 10]), torch.Size([5, 3, 1536]))

We also check if the `layerwise_lr` method outputs all model parameters.

In [None]:
param_groups = embeds_model.layerwise_lr(2.0e-5, 0.95)
# single pass over each group, so this works whether 'params' is a list or a generator
grouped_params = [param for group in param_groups for param in group['params']]
count = len(grouped_params)

# every model parameter must appear in exactly one group: same count AND same set
assert count == len(list(embeds_model.parameters()))
assert {id(param) for param in grouped_params} == {id(param) for param in embeds_model.parameters()}

## Context Transformer and Classification

Next we use another transformer model to create contextual sentence embeddings, i.e. to take into account that a conversation consists of three utterances. This is partly motivated by the [BERTSUM paper](https://arxiv.org/abs/1903.10318).

Moreover, we add a classification model for the emotion of the last utterance. Because the data is unbalanced, we augment its loss with an additional binary loss for the `others` class.

Note that for our convenience we use

* a linear projection of the sentence embeddings to a given `projection_size`
* a (not pre-trained) DistilBertForSequenceClassification and flip the order of the utterances as the first input embedding gets classified by default
* only one attention head, see also the paper [Are Sixteen Heads Really Better than One?](https://arxiv.org/abs/1905.10650).

In [None]:
# export 
class context_classifier_model(torch.nn.Module):
    """
    Context transformer and emotion classifier on top of the sentence embeddings.

    Instantiates a linear projection of the sentence embeddings, learned position
    embeddings for the three utterances, a (not pre-trained) single-head
    DistilBertForSequenceClassification, and an additional binary loss for the
    `others` label.
    """

    def __init__(self, embedding_size, projection_size, n_layers, emo_dict, dropout = 0.1):
        super(context_classifier_model, self).__init__()

        self.projection_size = projection_size
        self.projection = torch.nn.Linear(embedding_size, projection_size)
        # one learned position embedding per utterance of a conversation
        self.position_embeds = torch.nn.Embedding(3, projection_size)
        self.norm = torch.nn.LayerNorm(projection_size)
        self.drop = torch.nn.Dropout(dropout)

        context_config = DistilBertConfig(dropout=dropout,
                                          dim=projection_size,
                                          hidden_dim=4*projection_size,
                                          n_layers=n_layers,
                                          n_heads = 1,
                                          # one output per emotion label
                                          num_labels=len(emo_dict))

        self.context_transformer = DistilBertForSequenceClassification(context_config)
        self.others_label = emo_dict['others']
        self.bin_loss_fct = torch.nn.BCEWithLogitsLoss()

    def bin_loss(self, logits, labels):
        """
        Defines the additional binary loss for the `others` label.

        It is a BCE-with-logits loss between the `others` logit and the indicator
        `labels == others_label`.
        """
        bin_labels = (labels == self.others_label).float()
        bin_logits = logits[:, self.others_label]
        return self.bin_loss_fct(bin_logits, bin_labels)

    def forward(self, sentence_embeds, labels = None):
        """
        Returns the logits, or `(loss, logits)` if `labels` are given.

        The loss is the classifier's own loss plus the binary `others` loss. The
        utterance order is flipped before the context transformer because it
        classifies the first input embedding, so the last utterance is classified.
        """
        position_ids = torch.arange(3, dtype=torch.long, device=sentence_embeds.device)
        position_ids = position_ids.expand(sentence_embeds.shape[:2])
        position_embeds = self.position_embeds(position_ids)
        sentence_embeds = self.projection(sentence_embeds) + position_embeds
        sentence_embeds = self.drop(self.norm(sentence_embeds))

        outputs = self.context_transformer(inputs_embeds = sentence_embeds.flip(1), labels = labels)
        if labels is None:
            return outputs[0]
        # index rather than unpack: further outputs may follow (loss, logits)
        loss, logits = outputs[0], outputs[1]
        return loss + self.bin_loss(logits, labels), logits

Let us instantiate the `context_classifier_model` with the `embedding_size` of the sentence embedding model and a `projection_size` of our choice

In [None]:
# small context transformer for the demo: 100-dim projection, 2 transformer layers
projection_size = 100
n_layers = 2

# input size must match the sentence embeddings produced by `embeds_model`
classifier = context_classifier_model(embeds_model.embedding_size, projection_size, n_layers, emo_dict)

and do some basic checks.

In [None]:
classifier.eval()
# eval() must switch every submodule (including all dropout layers) out of training mode
assert not any(module.training for module in classifier.modules())

loss, logits = classifier(sentence_embeds, labels = labels)
# without labels the model returns only the logits, which must be identical
assert torch.all(logits == classifier(sentence_embeds))
# the loss is the cross-entropy loss plus the binary `others` loss
assert loss == torch.nn.CrossEntropyLoss()(logits, labels) + classifier.bin_loss(logits, labels)

loss, logits

(tensor(2.1022, grad_fn=<AddBackward0>),
 tensor([[-0.0509, -0.0632, -0.0165,  0.0358],
         [-0.0536, -0.0644, -0.0126,  0.0392],
         [-0.0487, -0.0654, -0.0186,  0.0388],
         [-0.0541, -0.0641, -0.0206,  0.0491],
         [-0.0566, -0.0647, -0.0171,  0.0375]], grad_fn=<AddmmBackward>))

Finally, for our main consistency check we compute the gradient of a prediction w.r.t. the input embeddings. 

In [None]:
# Recompute the input embeddings now that the model is in eval mode, so this check does not
# depend on the dropout state at the time an earlier cell computed `input_embeds`.
with torch.no_grad():
    input_embeds = embeds_model.transformer.embeddings(input_ids.flatten(end_dim = 1))
# a fresh leaf tensor, so gradients w.r.t. the embeddings can be inspected below
input_embeds.requires_grad_(True)
sentence_embeds_check = embeds_model(input_embeds = input_embeds, attention_mask = attention_mask)
logits_check = classifier(sentence_embeds_check)
assert torch.all(logits == logits_check)

In [None]:
# Backpropagate the first logit of the second conversation (row 1) ...
logits_check[1,0].backward()
# ... and show its gradient w.r.t. feature 0 of token 0 of each of the 15 flattened utterances.
input_embeds.grad[:,0,0]

tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00, -7.6451e-07, -1.2572e-06,
        -7.2055e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00])

As anticipated, we see that only the fourth, fifth, and sixth input embeddings affect the second prediction.
These correspond to the second conversation:

In [None]:
# rows 3-5 of the flattened batch are the three utterances of conversation 1
assert torch.all(input_embeds[3:6] == embeds_model.transformer.embeddings(input_ids[1]))

## Metrics

Lastly, we define the metrics for the evaluation of our model according to the [SemEval-2019 Task 3 challenge](https://www.aclweb.org/anthology/S19-2005/), i.e. micro-averaged precision, recall, and F1-score (ignoring the `others` class). 

In [None]:
# export
def metrics(loss, logits, labels):
    """
    Returns the validation metrics of one batch.

    `val_acc` is the accuracy over all classes. `tp`, `fp`, `fn` are micro-averaged counts
    over every class except index 0 (assumed to be `others`). They are read off the
    confusion matrix `cm[true, pred]`. The number of classes is taken from `logits`
    (the size of dim 1).
    """
    num_labels = logits.shape[1]
    preds = torch.argmax(logits, dim=1)
    acc = (labels == preds).float().mean()

    # vectorised confusion matrix: each (true, pred) pair maps to one flat bin
    flat_pairs = labels.view(-1).long() * num_labels + preds.view(-1).long()
    cm = torch.bincount(flat_pairs, minlength=num_labels ** 2).view(num_labels, num_labels)
    cm = cm.to(device=loss.device, dtype=torch.float)

    tp = cm.diagonal()[1:].sum()
    fp = cm[:, 1:].sum() - tp
    fn = cm[1:, :].sum() - tp
    return {'val_loss': loss, 'val_acc': acc, 'tp': tp, 'fp': fp, 'fn': fn}

def _safe_div(num, den):
    """Returns `num / den`, or 0 where `den == 0` (instead of nan or ZeroDivisionError)."""
    if torch.is_tensor(num) or torch.is_tensor(den):
        device = (num if torch.is_tensor(num) else den).device
        num = torch.as_tensor(num, device=device)
        den = torch.as_tensor(den, device=device)
        ratio = num / den
        return torch.where(den == 0, torch.zeros_like(ratio), ratio)
    return num / den if den != 0 else 0.0

def f1_score(tp, fp, fn):
    """
    Returns the micro-averaged precision, recall and F1-score for the counts `tp`, `fp`, `fn`.

    A metric whose denominator is zero (e.g. no positive predictions) is reported as 0
    rather than nan.
    """
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    return {'precision': precision, 'recall': recall, 'f1_score': f1}

In [None]:
# batch metrics of the untrained demo classifier from above
metric = metrics(loss, logits, labels)
metric

{'val_loss': tensor(2.1022, grad_fn=<AddBackward0>),
 'val_acc': tensor(0.),
 'tp': tensor(0.),
 'fp': tensor(5.),
 'fn': tensor(3.)}

In [None]:
# precision, recall and F1-score from the counts of the previous cell
f1_score(metric['tp'], metric['fp'], metric['fn'])

{'precision': tensor(0.), 'recall': tensor(0.), 'f1_score': tensor(nan)}