In [10]:
import torch
from torch import nn
import numpy as np
from torch.nn import CrossEntropyLoss, BCELoss
from transformers import BertForQuestionAnswering, BertTokenizer, BertModel, AdamW

In [25]:
from transformers import BertPreTrainedModel
class DST_SPAN(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.clf = nn.Linear(config.hidden_size, 3)
        
        self.init_weights()


    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        slot_label=None,
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        
        rCLS = sequence_output[:, 0, :]
        slot_logits = self.clf(rCLS)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            
            loss_bce = BCELoss()
            slot_loss = loss_fct(slot_logits, slot_label)
            
            total_loss = (start_loss + end_loss + slot_loss) / 3
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)


In [26]:
model = DST_SPAN.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters)        

Some weights of the model checkpoint at bert-base-uncased were not used when initializing DST_SPAN: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing DST_SPAN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DST_SPAN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DST_SPAN were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias

In [27]:
def preprocess(context, question, value):
    max_len = 512
    tokenized_context = tokenizer(context)
    tokenized_question = tokenizer(question)
    tokenized_value = tokenizer(value)['input_ids'][1:-1]

    # Create inputs
    input_ids = tokenized_context['input_ids'] + tokenized_question['input_ids'][1:]
    token_type_ids = [0] * len(tokenized_context['input_ids']) + [1] * len(
        tokenized_question['input_ids'][1:]
    )
    attention_mask = [1] * len(input_ids)

    # Pad and create attention masks.
    # Skip if truncation is needed
    padding_length = max_len - len(input_ids)
    if padding_length > 0:  # pad
        input_ids = input_ids + ([0] * padding_length)
        attention_mask = attention_mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)

    # Find start and stop position of span
    l = len(tokenized_value)
    start, end = -1, -1
    for i in range(len(input_ids), l, -1):
        if input_ids[i-l: i] == [0] * l:
            continue
        elif input_ids[i-l: i] == tokenized_value:
            start = i-l
            end = i-1
            break
    return np.array(input_ids), np.array(token_type_ids), np.array(attention_mask), start, end

In [28]:
batch = preprocess("book a hotel on monday", "hotel-day", "monday")

In [42]:
model.train()
for i in range(10):
    inputs = {
        "input_ids": torch.tensor([batch[0]]+[batch[0]]),
        "attention_mask": torch.tensor([batch[1]]+[batch[1]]),
        "token_type_ids": torch.tensor([batch[2]]+[batch[2]]),
        "start_positions": torch.tensor([batch[3], batch[3]]),
        "end_positions": torch.tensor([batch[4], batch[4]]),
        "slot_label": torch.tensor([2,2]),
    }
    outputs = model(**inputs)
    loss = outputs[0]
    print(loss.item())
    loss.backward()
    optimizer.step()
    model.zero_grad()

4.639381408691406
4.377444744110107
4.268583297729492
4.237524032592773
4.354714870452881
4.079228401184082
4.131597995758057
4.092955589294434
4.08426570892334
4.602624416351318


In [40]:
inputs['input_ids'].shape

torch.Size([2, 512])

In [35]:
inputs['end_positions']

tensor([5, 5])

In [17]:
outputs

(tensor([[ 3.1427e-02,  7.1057e-02, -2.6099e-01, -6.7041e-02,  6.6192e-02,
           3.0436e-01, -1.3487e-01, -1.4287e-01,  5.1434e-01,  1.0080e-01,
          -1.3572e-01,  4.6623e-02,  1.8986e-02,  1.6334e-02,  4.4456e-03,
           9.9585e-03,  4.9683e-02,  8.7223e-02, -1.3443e-02, -5.6302e-03,
           4.2708e-03,  4.0149e-02,  1.5471e-01,  6.4796e-03,  2.0075e-02,
           3.4052e-02,  3.6993e-02,  8.3071e-02,  1.7934e-02,  1.1810e-02,
           3.1950e-02,  1.7915e-01,  5.9460e-02,  1.1314e-01,  8.0607e-03,
           3.0791e-02,  3.8122e-02,  1.7604e-02,  3.7104e-02,  9.3630e-03,
           2.0538e-02,  2.6220e-03, -5.2061e-03,  6.9189e-03, -1.2484e-02,
          -1.1577e-03,  1.6699e-01,  1.7600e-01,  3.5491e-02, -2.1768e-03,
           2.9017e-03,  3.8176e-02,  4.7572e-02,  3.5803e-02,  3.1142e-02,
           2.5433e-02,  3.0105e-02,  1.2157e-01,  1.1454e-02,  2.7681e-03,
           7.3168e-02,  1.8511e-01,  3.5458e-02,  6.2214e-02, -1.3537e-02,
           3.8877e-03,  2