In [1]:
%cd ..

/media/crocoder/New Volume/Personal/Projects/Task-4/ReCAM


In [2]:
from src.datasets.concat_dataset import ConcatDataset

In [3]:
from src.utils.configuration import Config

In [4]:
data_config = Config(path= './configs/datasets/concat/bert_answer_attention.yaml')

In [5]:
data_config.train

Config(dic={'name': 'concat', 'data_dir': './data/train/', 'truncate_length': 512, 'preprocessor': {'name': 'transformersPreprocessor', 'tokenizer': {'name': 'AutoTokenizer', 'init_params': {'pretrained_model_name_or_path': 'bert-base-uncased'}}}, 'file_path': './data/train/Task_1_train.jsonl', 'split': 'train'})

In [6]:
from src.utils.mapper import configmapper
from src.modules.tokenizers import *


In [7]:
tokenizer = configmapper.get_object("tokenizers",data_config.train.preprocessor.tokenizer.name).from_pretrained(**data_config.train.preprocessor.tokenizer.init_params.as_dict())

In [8]:
from transformers import LongformerTokenizer
tok2 = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

In [9]:
dat = ConcatDataset(data_config.train, tokenizer)

In [10]:
from torch.utils.data import DataLoader

In [11]:
loader = DataLoader(dat, batch_size=4,shuffle=True, collate_fn=dat.custom_collate_fn)

In [12]:
sample = next(iter(loader))

In [13]:
sample[0].keys()

dict_keys(['concats_token_ids', 'concats_token_type_ids', 'answer_indices', 'concats_attention_masks', 'options_indices', 'options_attention_masks'])

In [14]:
sample[0]['concats_token_ids'].shape

torch.Size([4, 493])

In [15]:
sample[0]['concats_token_type_ids'].shape

torch.Size([4, 493])

In [16]:
sample[0]['answer_indices']

tensor([376, 458, 454, 462])

In [17]:
sample[0]['concats_attention_masks'].shape

torch.Size([4, 493])

In [18]:
sample[0]['options_indices']

tensor([[[384],
         [386],
         [388],
         [390],
         [392]],

        [[464],
         [466],
         [468],
         [470],
         [472]],

        [[483],
         [485],
         [487],
         [489],
         [491]],

        [[474],
         [476],
         [478],
         [480],
         [482]]])

In [19]:
sample[0]['options_attention_masks']

tensor([[[1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1]]])

In [20]:
import gc

In [21]:
gc.collect()

148

In [22]:
from transformers import BertModel, BertConfig, BertPreTrainedModel

In [23]:
from torch import nn
import torch.nn.functional as F


class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()

        self.linear = nn.Linear(in_features=in_features, out_features=out_features)
        self.init_params()

    def init_params(self):
        nn.init.kaiming_normal_(self.linear.weight.data)
        nn.init.constant_(self.linear.bias.data, 0)

    def forward(self, x):

        # x: [batch_size, seq_len, in_features]
        x = self.linear(x)
        # x: [batch_size, seq_len, out_features]
        return x


In [24]:
class MLPAttentionLogits(nn.Module):
    def __init__(self, dim, dropout):
        super(MLPAttentionLogits, self).__init__()

        self.Q_W = Linear(dim, dim)
        self.K_W = Linear(dim, dim)
        self.V_W = Linear(dim, dim)

        self.linear = Linear(dim, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V):
        # Q: [batch_size, dim]
        # K: [batch_size, seq_len, dim]
        # V: [batch_size, seq_len, dim]

        Q = self.dropout(self.Q_W(Q))  # [batch_size, dim]
        K = self.dropout(self.K_W(K))  # [batch_size, seq_len, dim]
        V = self.dropout(self.V_W(V))  # [batch_size, seq_len, dim]

        Q = Q.unsqueeze(1)  # [batch_size, 1, dim]
        M = self.dropout(Q * K)  # [batch_size, seq_len, dim]
        scores = self.dropout(self.linear(M))  # [batch_size, seq_len, 1]

        return scores

In [137]:
class AnswerAttentionBert(nn.Module):
    def __init__(self, config):
        super(AnswerAttentionBert, self).__init__()
        self.config = config
        self.bert = BertModel.from_pretrained(self.config.bert_pretrained_name)
        self.attention = MLPAttentionLogits(self.config.hidden_size, self.config.dropout)

    def forward(self, batch):
        concats_token_ids = batch["concats_token_ids"]  # [batch_size,seq_length]
        concats_token_type_ids = batch[
            "concats_token_type_ids"
        ]  # [batch_size, seq_length]
        answer_indices = batch["answer_indices"]  # [batch_size,]
        concats_attention_masks = batch[
            "concats_attention_masks"
        ]  # [batch_size, seq_length]
        options_indices = batch[
            "options_indices"
        ]  # [batch_size, 5, max_options_length] May be padded using 1000000
        options_attention_masks = batch["options_attention_masks"] # [batch_size, 5, max_options_length]

        concat_embeddings = self.bert(
            input_ids = concats_token_ids,
            attention_mask = concats_attention_masks,
            token_type_ids=concats_token_type_ids,
        )[0]  # [batch_size, seq_length, hidden_size]

        batch_size = answer_indices.shape[0]
        hidden_size = concat_embeddings.shape[-1]
        answer_indices = answer_indices.reshape(-1, 1, 1).expand(
            batch_size, 1, hidden_size
        )  # [batch_size, 1, hidden_size]
        mask_embedding = torch.gather(
            concat_embeddings, 1, answer_indices
        )  # [batch_size,1,768]

        tokens_per_option_per_batch = torch.sum(options_attention_masks,dim=2) ##[batch_size,5]

        ops_avg_embeddings = []

        for i in range(5):
            ops_i_indices = options_indices[:,i,:] ## [batch_size,max_options_length]
            ops_i_indices = ops_i_indices.reshape(batch_size,-1,1).expand(batch_size,-1,hidden_size) ##[batch_size,max_options_length,hidden_size]
            ops_i_masks = options_attention_masks[:,i,:].reshape(batch_size,-1,1).expand(batch_size,-1,hidden_size) ##[batch_size,max_options_length,hidden_size]
            ops_i_embeddings = torch.gather(concat_embeddings, 1,ops_i_indices) ## [batch_size,max_options_length,hidden_size]
            ops_i_embeddings = ops_i_masks*ops_i_embeddings ## [batch_size,max_options_length,hidden_size]
            ops_i_avg_embeddings = torch.sum(ops_i_embeddings,dim=1)/tokens_per_option_per_batch[:,i].reshape(-1,1).expand(-1,hidden_size) ##[batch_size,hidden_size]
            ops_avg_embeddings.append(ops_i_avg_embeddings.unsqueeze(1))
        ops_avg_embeddings = torch.cat(ops_avg_embeddings,dim=1)

        out_logits = self.attention(mask_embedding.squeeze(), ops_avg_embeddings, ops_avg_embeddings).squeeze()

        return out_logits   



In [138]:
model_config = {
    'bert_pretrained_name':'bert-base-uncased',
    'num_labels': 5,
    'hidden_size': 768,
    'dropout':0.2,

}

In [139]:
model_config = Config(dic=model_config)

In [140]:
aab = AnswerAttentionBert(model_config)

In [141]:
aab(sample[0])

tensor([[-0.4213, -0.0611,  0.4953, -0.2289,  0.0000],
        [-0.2388,  0.3919,  0.4119, -0.4489, -0.0000],
        [-0.2882, -0.0000, -1.2821,  0.9153,  0.0890],
        [-0.3784,  0.0237, -0.5242, -0.8572,  0.1734]],
       grad_fn=<SqueezeBackward0>)

In [142]:
gc.collect()

270