In [1]:
import sys
sys.path.append("../")
sys.path.append("../../")

In [2]:
import os
import torch
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertTokenizerFast
from cosmosqa.data_loader.dataloader import create_data_loader

In [3]:
from cosmosqa.data_loader.dataset import CosmosQADataset

In [4]:
from collections import defaultdict

In [5]:
# Explicitly enable parallelism in the Rust-backed HF tokenizers.
# NOTE(review): if create_data_loader forks DataLoader workers (num_workers > 0),
# "true" can lead to deadlocks after the tokenizer has been used — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="true")

In [6]:
PRE_TRAINED_MODEL_NAME = "bert-base-cased"  # checkpoint for both tokenizer and encoder
MAX_LEN = 160  # max_len handed to create_data_loader (presumably tokens per sequence)
BATCH_SIZE = 2
EPOCHS = 1
SEED = 42

# Seed torch's global RNG so the randomly initialised head layers, dropout and
# any torch-driven shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED);

In [7]:
# Same checkpoint name as the encoder loaded in BertMultiwayMatch, so the
# vocabularies agree.
tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=PRE_TRAINED_MODEL_NAME
)

In [8]:
# Load the training and validation splits (read in that order).
df_train, df_valid = (
    pd.read_csv(csv_path)
    for csv_path in ("../data/train_sample.csv", "../data/valid_sample.csv")
)

In [9]:
# Both loaders share every option except the DataFrame they wrap.
_loader_options = {
    "tokenizer": tokenizer,
    "max_len": MAX_LEN,
    "batch_size": BATCH_SIZE,
}
train_data_loader = create_data_loader(df=df_train, **_loader_options)
valid_data_loader = create_data_loader(df=df_valid, **_loader_options)
del _loader_options

In [10]:
# BERT with multiway attention
class BertMultiwayMatch(nn.Module):
    """BERT encoder with a multi-way matching head for multiple-choice QA.

    Every (passage, question, answer) candidate is encoded by BERT as one
    sequence. The encoded sequence is split into its segments, the segments
    are cross-attended pairwise (``matching``), the attended views are fused
    with subtraction/multiplication features and max-pooled (``fusing_mlp``),
    and a linear layer produces one score per candidate.

    Args:
        config: attribute-access object; ``hidden_dropout_prob`` is read if
            present (otherwise the BERT checkpoint's value is used). The
            hidden size is always taken from the loaded BERT model; a
            different ``config.hidden_size`` only triggers a warning.
        num_choices: number of candidate answers per example.
        pretrained_model_name: checkpoint passed to
            ``BertModel.from_pretrained``; defaults to the notebook-level
            ``PRE_TRAINED_MODEL_NAME``.
    """

    def __init__(self, config, num_choices=4, pretrained_model_name=None):
        import warnings

        super().__init__()
        if pretrained_model_name is None:
            pretrained_model_name = PRE_TRAINED_MODEL_NAME
        self.num_choices = num_choices
        self.bert = BertModel.from_pretrained(pretrained_model_name)

        # The head must match the encoder width; a config value that disagrees
        # with the checkpoint (e.g. 1024 vs. 768) would otherwise only surface
        # as a shape error in the first forward pass.
        hidden_size = self.bert.config.hidden_size
        configured_size = getattr(config, "hidden_size", None)
        if configured_size is not None and configured_size != hidden_size:
            warnings.warn(
                f"config.hidden_size={configured_size} does not match the BERT "
                f"checkpoint ({hidden_size}); using {hidden_size}."
            )
        dropout_prob = getattr(config, "hidden_dropout_prob", None)
        if dropout_prob is None:
            dropout_prob = self.bert.config.hidden_dropout_prob

        # NOTE(review): self.dropout is created but never applied in forward.
        self.dropout = nn.Dropout(dropout_prob)
        # Shared projection used on both sides of every attention score.
        self.linear_trans = nn.Linear(hidden_size, hidden_size)
        # One fusion layer per anchor (passage / question / answer); each maps
        # the concatenated [sub; mul] features (2H) back to H.
        self.linear_fuse_p = nn.Linear(hidden_size * 2, hidden_size)
        self.linear_fuse_q = nn.Linear(hidden_size * 2, hidden_size)
        self.linear_fuse_a = nn.Linear(hidden_size * 2, hidden_size)
        # Scores the concatenation of the three pooled views (3H) per choice.
        self.classifier = nn.Linear(hidden_size * 3, 1)

    def matching(
        self,
        passage_encoded,
        question_encoded,
        passage_attention_mask,
        question_attention_mask,
    ):
        """Bidirectional attention between two encoded segments.

        Args:
            passage_encoded: ``(N, Lp, H)`` encoded first segment.
            question_encoded: ``(N, Lq, H)`` encoded second segment.
            passage_attention_mask: ``(N, Lp)``; 0 marks positions to ignore.
            question_attention_mask: ``(N, Lq)``; same convention.

        Returns:
            ``(mp, mq)``: ``mp`` is ``(N, Lp, H)``, each first-segment position
            as an attention-weighted sum of second-segment vectors; ``mq`` is
            ``(N, Lq, H)``, the reverse direction.
        """
        # Scores through a shared projection: (N, Lp, Lq)
        passage_trans = self.linear_trans(passage_encoded)
        question_trans = self.linear_trans(question_encoded)
        scores = torch.matmul(passage_trans, question_trans.transpose(2, 1))

        # Outer product of the masks: 1 only where both positions are kept.
        # Computed in float, then cast to the parameter dtype (fp16 support).
        pair_mask = torch.matmul(
            passage_attention_mask.unsqueeze(2).float(),
            question_attention_mask.unsqueeze(1).float(),
        ).to(dtype=next(self.parameters()).dtype)
        # Large negative bias so masked pairs get ~0 weight after softmax.
        scores = scores + (1.0 - pair_mask) * -10000.0

        # Normalise over second-segment positions for mp, first-segment for mq.
        p2q_weights = F.softmax(scores, dim=-1)
        q2p_weights = F.softmax(scores, dim=1)

        mp = torch.matmul(p2q_weights, question_encoded)
        mq = torch.matmul(q2p_weights.transpose(2, 1), passage_encoded)
        return mp, mq

    @staticmethod
    def _sub_mul(matched, original):
        """Concatenate ``[matched - original; matched * original]`` on the feature axis."""
        return torch.cat([matched - original, matched * original], 2)

    # sub and multiply
    def fusing_mlp(
        self,
        passage_encoded,
        mp_q,
        mp_a,
        mp_qa,
        question_encoded,
        mq_p,
        mq_a,
        mq_pa,
        answers_encoded,
        ma_p,
        ma_q,
        ma_pq,
    ):
        """Fuse each segment with its attended views and pool per choice.

        Each ``m*`` tensor must have the same shape as the segment it was
        attended for (``(N, L, H)``), with ``N = batch * num_choices``.

        Returns:
            ``(batch, num_choices, 3H)``: max-pooled passage, question and
            answer features concatenated on the last axis.
        """
        # (N, 3L, 2H) per anchor: the three views stacked along the sequence axis.
        new_mp = torch.cat(
            [self._sub_mul(view, passage_encoded) for view in (mp_q, mp_a, mp_qa)], 1
        )
        new_mq = torch.cat(
            [self._sub_mul(view, question_encoded) for view in (mq_p, mq_a, mq_pa)], 1
        )
        new_ma = torch.cat(
            [self._sub_mul(view, answers_encoded) for view in (ma_p, ma_q, ma_pq)], 1
        )

        # Separate linear layer per anchor, back to (N, 3L, H).
        fused_p = F.relu(self.linear_fuse_p(new_mp))
        fused_q = F.relu(self.linear_fuse_q(new_mq))
        fused_a = F.relu(self.linear_fuse_a(new_ma))

        # Max-pool over positions -> (N, H), then regroup rows by example.
        pooled = [
            torch.max(fused, 1).values.view(-1, self.num_choices, fused.size(2))
            for fused in (fused_p, fused_q, fused_a)
        ]
        return torch.cat(pooled, 2)

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        doc_len=None,
        ques_len=None,
        option_len=None,
        labels=None,
    ):
        """Score every answer candidate and optionally compute the loss.

        Args:
            input_ids: token ids shaped ``(..., seq_len)``. All leading
                dimensions are flattened, so ``(batch, num_choices, seq_len)``
                becomes ``batch * num_choices`` BERT sequences.
            token_type_ids: same shape as ``input_ids``, or ``None``.
            attention_mask: same shape as ``input_ids``. ``None`` means every
                token is attended.
            doc_len, ques_len, option_len: per-sequence segment lengths
                (presumably ``(batch, num_choices)``). These are required to
                split the BERT output into passage / question / answer parts.
            labels: ``(batch,)`` index of the correct choice, or ``None``.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.
            ``logits`` has shape ``(batch, num_choices)``.

        Raises:
            ValueError: if any of the segment-length tensors is missing.
        """
        if doc_len is None or ques_len is None or option_len is None:
            raise ValueError(
                "doc_len, ques_len and option_len are required to split the "
                "BERT output into passage/question/answer segments"
            )

        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = (
            None
            if token_type_ids is None
            else token_type_ids.view(-1, token_type_ids.size(-1))
        )
        flat_attention_mask = (
            torch.ones_like(flat_input_ids)
            if attention_mask is None
            else attention_mask.view(-1, attention_mask.size(-1))
        )
        # One length per BERT sequence. Unlike .squeeze(), reshape(-1) keeps a
        # 1-D tensor even when there is only a single sequence.
        doc_len = doc_len.reshape(-1)
        ques_len = ques_len.reshape(-1)
        option_len = option_len.reshape(-1)

        # Pass keyword arguments because transformers' BertModel.forward takes
        # (input_ids, attention_mask, token_type_ids, ...) positionally. A
        # positional call in the old pytorch-pretrained-bert order would swap
        # the mask and the segment ids.
        bert_outputs = self.bert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
        )
        # Index 0 is the last hidden state for both tuple and ModelOutput
        # returns. Tuple-unpacking a ModelOutput would instead yield its key
        # names.
        sequence_output = bert_outputs[0]

        # NOTE(review): seperate_seq / seperate_seq_attention are not defined
        # or imported anywhere in this notebook. They must be provided before
        # forward can run.
        (
            passage_encoded,
            question_encoded,
            answers_encoded,
            passage_question_encoded,
            passage_answer_encoded,
            question_answer_encoded,
        ) = seperate_seq(sequence_output, doc_len, ques_len, option_len)
        (
            passage_attention_mask,
            question_attention_mask,
            answers_attention_mask,
            passage_question_attention_mask,
            passage_answer_attention_mask,
            question_answer_attention_mask,
        ) = seperate_seq_attention(
            flat_attention_mask, doc_len, ques_len, option_len
        )

        # Matching layer. Only the views consumed by fusing_mlp are kept.
        mp_q, mq_p = self.matching(
            passage_encoded,
            question_encoded,
            passage_attention_mask,
            question_attention_mask,
        )
        mp_a, ma_p = self.matching(
            passage_encoded,
            answers_encoded,
            passage_attention_mask,
            answers_attention_mask,
        )
        mp_qa, _ = self.matching(
            passage_encoded,
            question_answer_encoded,
            passage_attention_mask,
            question_answer_attention_mask,
        )
        mq_a, ma_q = self.matching(
            question_encoded,
            answers_encoded,
            question_attention_mask,
            answers_attention_mask,
        )
        mq_pa, _ = self.matching(
            question_encoded,
            passage_answer_encoded,
            question_attention_mask,
            passage_answer_attention_mask,
        )
        ma_pq, _ = self.matching(
            answers_encoded,
            passage_question_encoded,
            answers_attention_mask,
            passage_question_attention_mask,
        )

        # MLP fuse -> (batch, num_choices, 3H) -> one logit per candidate.
        fused = self.fusing_mlp(
            passage_encoded,
            mp_q,
            mp_a,
            mp_qa,
            question_encoded,
            mq_p,
            mq_a,
            mq_pa,
            answers_encoded,
            ma_p,
            ma_q,
            ma_pq,
        )
        logits = self.classifier(fused.view(-1, fused.size(2)))
        reshaped_logits = logits.view(-1, self.num_choices)

        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)
            return loss, reshaped_logits
        return reshaped_logits

In [11]:
class dotdict(dict):
    """dict whose keys can also be read, set and deleted as attributes.

    Missing keys raise ``AttributeError`` instead of returning ``None``.
    Typos therefore fail immediately, and ``hasattr``,
    ``getattr(obj, name, default)``, ``copy`` and ``pickle`` behave as they do
    for ordinary objects.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

In [12]:
# Settings read by BertMultiwayMatch. hidden_size must equal the width of the
# PRE_TRAINED_MODEL_NAME checkpoint: bert-base-* models are 768-dimensional,
# while 1024 is the bert-large width and would mismatch the encoder output.
config = dotdict(
    {
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
    }
)

In [13]:
# num_choices must equal the number of candidate answers packed per example
# by the data loader.
model = BertMultiwayMatch(
    num_choices=4,
    config=config,
)

In [None]:
# Fine-tune for EPOCHS passes over the training set. Calling backward() without
# an optimizer step would only accumulate gradients and never update weights.
LEARNING_RATE = 2e-5  # typical BERT fine-tuning rate; tune as needed

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
model.train()

epoch_losses = []
for epoch in range(EPOCHS):
    running_loss, n_batches = 0.0, 0
    for batch in train_data_loader:
        optimizer.zero_grad()
        loss, logits = model(
            input_ids=batch["input_ids"],
            token_type_ids=batch["token_type_ids"],
            attention_mask=batch["attention_mask"],
            doc_len=batch["c_len"],
            ques_len=batch["q_len"],
            option_len=batch["a_len"],
            labels=batch["target"].long(),
        )
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    # Mean training loss of the epoch (guarded against an empty loader).
    epoch_losses.append(running_loss / max(n_batches, 1))

epoch_losses