


In this notebook we perform one forward pass using BERT embeddings in BiDAF network. This is done for demonstration purposes. The forward pass is performed on a batch of 64 samples. 

To perform full training the code from this notebook is copied into the 'models' and 'layers' files of BiDAF implementation from Stanford. 
For details on the code that is concerned with BERT embedding replacement, refer the 'embedding_replacement' in this repository.

The implementation of code might seem inefficient but this was done for easy merging with the BiDAF implementation from Stanford.

In [1]:
import torch 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
!git clone https://github.com/huggingface/transformers \
&& cd transformers \
&& git checkout a3085020ed0d81d4903c50967687192e3101e770 

fatal: destination path 'transformers' already exists and is not an empty directory.


In [3]:
!pip install ./transformers
!pip install tensorboardX

Processing ./transformers
Building wheels for collected packages: transformers
  Building wheel for transformers (setup.py) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-2.3.0-cp36-none-any.whl size=458556 sha256=890dc92db49dd4bf5a02f6977d8360f3d786066d01318fdb9c2b0c9b516df627
  Stored in directory: /tmp/pip-ephem-wheel-cache-jxbuvok4/wheels/23/19/dd/2561a4e47240cf6b307729d58e56f8077dd0c698f5992216cf
Successfully built transformers
Installing collected packages: transformers
  Found existing installation: transformers 2.3.0
    Uninstalling transformers-2.3.0:
      Successfully uninstalled transformers-2.3.0
Successfully installed transformers-2.3.0


In [4]:
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

import numpy as np
import torch 
from transformers import BertModel

In [5]:
import json  
f = open('/content/word2idx.json',) 
data = json.load(f)

In [6]:
idx2word = {}
for key in data.keys():
    idx2word[data[key]] = key

In [7]:
# Tensors obtained from the pre-processing steps of the base line BiDAF 
# These tensors will now be transformed into equivalent BERT embeddings.

# Context file for one batch
import pickle
with open('cw_idxs.pickle', 'rb') as handle:
    cw_idxs = pickle.load(handle)

# Question file for one batch
with open('qw_idxs.pickle', 'rb') as handle:
    qw_idxs = pickle.load(handle)

# Answer starts
with open('y1.pickle', 'rb') as handle:
    y1 = pickle.load(handle)

# Answer ends
with open('y2.pickle', 'rb') as handle:
    y2 = pickle.load(handle)  

In [8]:
# shapes of pre-processed tensors

print(cw_idxs.shape)
print(qw_idxs.shape)
print(y1.shape)
print(y2.shape)

torch.Size([64, 376])
torch.Size([64, 23])
torch.Size([64])
torch.Size([64])


In [9]:
b = 64
cw_idxs = cw_idxs[:b]
qw_idxs = qw_idxs[:b]
y1 = y1[:b]
y2 = y2[:b]

print(cw_idxs.shape)
print(qw_idxs.shape)
print(y1.shape)
print(y2.shape)

torch.Size([64, 376])
torch.Size([64, 23])
torch.Size([64])
torch.Size([64])


In [10]:
# NEW SWAP_TOKENS FUNCTION
def swap_tokens(cw_idxs):
    cw_idxs_words = []
    for c in cw_idxs:
        new_list = []
        for i in c:
            new_list.append(idx2word[i.item()])
        cw_idxs_words.append(new_list) 

    sentences = []
    for l in cw_idxs_words:
        sent = []
        for i in l:
            if i=='--OOV--' or i =='--NULL--':
                continue
            else:
                sent.append(i)

        sent = ' '.join(sent)   
        sentences.append(sent)

    sentences_tokenized = []

    bert_words = []
    for s in sentences:
        tt = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(s))
        bert_words.append(tokenizer.convert_ids_to_tokens(tt))
        tt = torch.Tensor(tt).type(torch.LongTensor)
        sentences_tokenized.append(tt)

    max_len = 0
    for s in sentences_tokenized:
        max_len = max(len(s),max_len)

    sentences_tokenized_tensors = [] 
    for s in sentences_tokenized:
        tt = torch.nn.ConstantPad1d((0, max_len - s.shape[0]), 0)(s)
        sentences_tokenized_tensors.append(tt)

    CT_new = torch.Tensor([])

    for l in sentences_tokenized_tensors:
        l = l.reshape((1,l.shape[0]))
        CT_new = torch.cat((CT_new, l), 0)   

    c_mask = torch.zeros_like(CT_new) != CT_new 

    return (CT_new, c_mask, bert_words)    

In [11]:
def collect_hash_words(bert_words):
    import more_itertools as mit
    hash_words_list = []

    for sample in range(len(bert_words)):
        test_mask = []
        for i in range(len(bert_words[sample])):
            if '#' in bert_words[sample][i]:
                test_mask.append(1)
            else:
                test_mask.append(0)

        ones = []
        for i in range(len(test_mask)):
            if test_mask[i]==1:
                ones.append(i)

        start_ones = []
        for i in ones:
            start_ones.append(i-1)
        full_ones = sorted(list(set(sorted(start_ones + ones))))

        ll = [list(group) for group in mit.consecutive_groups(full_ones)] 
        hash_words_list.append(ll)

    return hash_words_list

In [12]:
# model_name_or_path = 'bert-base-uncased'
import torch
import torch.nn as nn
from transformers import BertModel
from tqdm import tqdm
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
hidden_size = 100

class Bertify(nn.Module):

  def __init__(self, hidden_size):
    super(Bertify, self).__init__()
    self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
  
  def forward(self, input_ids, attention_mask):
    last_hidden_state ,_ = self.bert(input_ids=input_ids,attention_mask=attention_mask)
    output = last_hidden_state
    return output

In [13]:
def remove_hash(f, hash_words_list, hs):
    sub = []
    for l in hash_words_list[f]:
        arr = []
        for i in l:
            c = hs[f][i].detach().numpy()
            arr.append(c)

        arr = np.array(arr)
        arr = np.mean(arr, axis=0)
        sub.append((arr, l[0]))

    # sub --> [([],__),  ([],__),  ([],__)....]    

    #  Replace all means
    for s,i in sub:
        hs[f][i] = torch.Tensor(s)     

    # Remove unnecessary values
    remove = []
    for l in hash_words_list[f]:
        remove.append(l[1:])
    flat_list = [item for sublist in remove for item in sublist]  


    hs_new = torch.Tensor([])
    for i in range(len(hs[f])):
        if i in flat_list:
            continue
        else:    
            p = hs[f][i].reshape((1,-1))
            hs_new = torch.cat((hs_new, p), 0)

    return hs_new, flat_list        

In [14]:
def generate_final_bert_embeddings(cw_idxs, q_or_c):
    context, context_m, bert_words_C = swap_tokens(cw_idxs)
    hash_words_list_C = collect_hash_words(bert_words_C)

    context = torch.Tensor(context).type(torch.LongTensor)

    model = Bertify(hidden_size)
    
    with torch.no_grad():
        c_hs = model(input_ids=context[:b].reshape((b,context.shape[1])), 
                attention_mask=context_m[:b].reshape((b,context.shape[1])))
    
    all_mods = []
    all_falt_lists_C = []

    for i in range(b):
        hs_new, flat_list = remove_hash(i, hash_words_list_C, c_hs)
        all_falt_lists_C.append(flat_list)
        all_mods.append(hs_new)

    if q_or_c == 'c':
        max_len = 388
    elif q_or_c == 'q':
        max_len = 25

    all_mods_tensors = [] 
    for s in all_mods:
        tt = torch.transpose(torch.nn.ConstantPad2d((0, max_len - s.shape[0]), 0)(torch.transpose(s, 0, 1)), 0, 1)
        all_mods_tensors.append(tt) 

    rect_c_hs = torch.Tensor([])
    for l in all_mods_tensors:
        l = l.reshape((1,l.shape[0], l.shape[1]))
        rect_c_hs = torch.cat((rect_c_hs, l), 0) 

    print(rect_c_hs.shape)    

    # generating new mask
    context_np = context.numpy()       

    all_mod_mask_C = []
    for i in range(len(context_np)):
        arr = []
        for j in range(len(context_np[i])):
            if j in all_falt_lists_C[i]:
                continue
            else:
                arr.append(context_np[i][j])

        for z in range(len(all_falt_lists_C[i])):
            arr.append(0)
        all_mod_mask_C.append(arr) 

    all_mod_mask_C_pt = []

    for l in all_mod_mask_C:
        all_mod_mask_C_pt.append(torch.Tensor(l))     

    context_new = torch.Tensor([])

    for l in all_mod_mask_C_pt:
        l = l.reshape((1,l.shape[0]))
        context_new = torch.cat((context_new, l), 0)

    c_mask = torch.zeros_like(context_new) != context_new     

    print(c_mask.shape)    

    c_len = c_mask.sum(-1)
    print(c_len.shape)  
         
    return rect_c_hs, c_mask, c_len


In [15]:
c_emb_new, c_mask, c_len = generate_final_bert_embeddings(cw_idxs, 'c')

torch.Size([64, 388, 768])
torch.Size([64, 388])
torch.Size([64])


In [16]:
q_emb_new, q_mask, q_len = generate_final_bert_embeddings(qw_idxs, 'q')

torch.Size([64, 25, 768])
torch.Size([64, 25])
torch.Size([64])


Implementing the layers of BiDAF network. For more details on the formulation of the layers, refer the project handout in this repository.

In [17]:
class HighwayEncoder(nn.Module):
    
    def __init__(self, num_layers, hidden_size):
        super(HighwayEncoder, self).__init__()
        self.transforms = nn.ModuleList([nn.Linear(hidden_size, hidden_size)
                                         for _ in range(num_layers)])
        self.gates = nn.ModuleList([nn.Linear(hidden_size, hidden_size)
                                    for _ in range(num_layers)])

    def forward(self, x):
        for gate, transform in zip(self.gates, self.transforms):
            # Shapes of g, t, and x are all (batch_size, seq_len, hidden_size)
            g = torch.sigmoid(gate(x))
            t = F.relu(transform(x))
            x = g * t + (1 - g) * x

        return x


# model_name_or_path = 'bert-base-uncased'
hidden_size = 100
class Bert_Embeddings(nn.Module):

  def __init__(self, hidden_size):
    super(Bert_Embeddings, self).__init__()
    # self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.drop = nn.Dropout(p=0.1)
    self.lin = nn.Linear(768, hidden_size)
    self.hwy = HighwayEncoder(2, hidden_size)
  
  def forward(self, c_emb):
    # last_hidden_state ,_ = self.bert(input_ids=input_ids,attention_mask=attention_mask)
    output = self.drop(c_emb)
    # output = last_hidden_state
    output = self.lin(output)
    output = self.hwy(output)
    return output        

In [18]:
class RNNEncoder(nn.Module):
    
    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers,
                 drop_prob=0.):
        super(RNNEncoder, self).__init__()
        self.drop_prob = drop_prob
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers,
                           batch_first=True,
                           bidirectional=True,
                           dropout=drop_prob if num_layers > 1 else 0.)

    def forward(self, x, lengths):
        # Save original padded length for use by pad_packed_sequence
        orig_len = x.size(1)

        # Sort by length and pack sequence for RNN
        lengths, sort_idx = lengths.sort(0, descending=True)
        x = x[sort_idx]     # (batch_size, seq_len, input_size)
        x = pack_padded_sequence(x, lengths, batch_first=True)

        # Apply RNN
        x, _ = self.rnn(x)  # (batch_size, seq_len, 2 * hidden_size)

        # Unpack and reverse sort
        x, _ = pad_packed_sequence(x, batch_first=True, total_length=orig_len)
        _, unsort_idx = sort_idx.sort(0)
        x = x[unsort_idx]   # (batch_size, seq_len, 2 * hidden_size)

        # Apply dropout (RNN applies dropout after all but the last layer)
        x = F.dropout(x, self.drop_prob, self.training)

        return x

In [19]:
class BiDAFAttention(nn.Module):

    def __init__(self, hidden_size, drop_prob=0.1):
        super(BiDAFAttention, self).__init__()
        self.drop_prob = drop_prob
        self.c_weight = nn.Parameter(torch.zeros(hidden_size, 1))
        self.q_weight = nn.Parameter(torch.zeros(hidden_size, 1))
        self.cq_weight = nn.Parameter(torch.zeros(1, 1, hidden_size))
        for weight in (self.c_weight, self.q_weight, self.cq_weight):
            nn.init.xavier_uniform_(weight)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, c, q, c_mask, q_mask):
        batch_size, c_len, _ = c.size()
        q_len = q.size(1)
        s = self.get_similarity_matrix(c, q)        # (batch_size, c_len, q_len)
        c_mask = c_mask.view(batch_size, c_len, 1)  # (batch_size, c_len, 1)
        q_mask = q_mask.view(batch_size, 1, q_len)  # (batch_size, 1, q_len)
        s1 = masked_softmax(s, q_mask, dim=2)       # (batch_size, c_len, q_len)
        s2 = masked_softmax(s, c_mask, dim=1)       # (batch_size, c_len, q_len)

        # (bs, c_len, q_len) x (bs, q_len, hid_size) => (bs, c_len, hid_size)
        a = torch.bmm(s1, q)
        # (bs, c_len, c_len) x (bs, c_len, hid_size) => (bs, c_len, hid_size)
        b = torch.bmm(torch.bmm(s1, s2.transpose(1, 2)), c)

        x = torch.cat([c, a, c * a, c * b], dim=2)  # (bs, c_len, 4 * hid_size)

        return x

    def get_similarity_matrix(self, c, q):

        c_len, q_len = c.size(1), q.size(1)
        c = F.dropout(c, self.drop_prob, self.training)  # (bs, c_len, hid_size)
        q = F.dropout(q, self.drop_prob, self.training)  # (bs, q_len, hid_size)

        # Shapes: (batch_size, c_len, q_len)
        s0 = torch.matmul(c, self.c_weight).expand([-1, -1, q_len])
        s1 = torch.matmul(q, self.q_weight).transpose(1, 2)\
                                           .expand([-1, c_len, -1])
        s2 = torch.matmul(c * self.cq_weight, q.transpose(1, 2))
        s = s0 + s1 + s2 + self.bias

        return s

In [20]:
class BiDAFOutput(nn.Module):

    def __init__(self, hidden_size, drop_prob):
        super(BiDAFOutput, self).__init__()
        self.att_linear_1 = nn.Linear(8 * hidden_size, 1)
        self.mod_linear_1 = nn.Linear(2 * hidden_size, 1)

        self.rnn = RNNEncoder(input_size=2 * hidden_size,
                              hidden_size=hidden_size,
                              num_layers=1,
                              drop_prob=drop_prob)

        self.att_linear_2 = nn.Linear(8 * hidden_size, 1)
        self.mod_linear_2 = nn.Linear(2 * hidden_size, 1)

    def forward(self, att, mod, mask):
        # Shapes: (batch_size, seq_len, 1)
        logits_1 = self.att_linear_1(att) + self.mod_linear_1(mod)
        mod_2 = self.rnn(mod, mask.sum(-1))
        logits_2 = self.att_linear_2(att) + self.mod_linear_2(mod_2)

        # Shapes: (batch_size, seq_len)
        log_p1 = masked_softmax(logits_1.squeeze(), mask, log_softmax=True)
        log_p2 = masked_softmax(logits_2.squeeze(), mask, log_softmax=True)

        return log_p1, log_p2

In [21]:
def masked_softmax(logits, mask, dim=-1, log_softmax=False):
    
    mask = mask.type(torch.float32)
    masked_logits = mask * logits + (1 - mask) * -1e30
    softmax_fn = F.log_softmax if log_softmax else F.softmax
    probs = softmax_fn(masked_logits, dim)

    return probs

In [22]:
hidden_size = 100

class BiDAF(nn.Module):

    def __init__(self, hidden_size, drop_prob=0.):
        super(BiDAF, self).__init__()
        self.bert_emb = Bert_Embeddings(hidden_size=hidden_size)

        self.enc = RNNEncoder(input_size=hidden_size,
                                     hidden_size=hidden_size,
                                     num_layers=1,
                                     drop_prob=drop_prob)

        self.att = BiDAFAttention(hidden_size=2 * hidden_size,
                                         drop_prob=drop_prob)

        self.mod = RNNEncoder(input_size=8 * hidden_size,
                                     hidden_size=hidden_size,
                                     num_layers=2,
                                     drop_prob=drop_prob)

        self.out = BiDAFOutput(hidden_size=hidden_size,
                                      drop_prob=drop_prob)



    def forward(self, cw_idxs, qw_idxs):
        
        c_emb_new, c_mask, c_len = generate_final_bert_embeddings(cw_idxs, 'c')
        q_emb_new, q_mask, q_len = generate_final_bert_embeddings(qw_idxs, 'q')

        c_emb_new = c_emb_new.to(device)
        q_emb_new = q_emb_new.to(device)
        c_mask = c_mask.to(device)
        q_mask = q_mask.to(device)
        
        c_emb = self.bert_emb(c_emb_new)
        q_emb = self.bert_emb(q_emb_new)
        
        print("------------")
        print(c_emb.shape)
        print(q_emb.shape)

        c_enc = self.enc(c_emb, c_len)   
        q_enc = self.enc(q_emb, q_len) 

        print("------------")
        print(c_enc.shape)
        print(q_enc.shape) 

        att = self.att(c_enc, q_enc,
                       c_mask, q_mask) 
        
        print("------------")
        print(att.shape)

        mod = self.mod(att, c_len)  

        print("------------")
        print(mod.shape)

        out = self.out(att, mod, c_mask)
        print("------------")
        # print(out.shape)

        return out

In [23]:
hidden_size = 100
bidaf_model = BiDAF(hidden_size = hidden_size)
bidaf_model = bidaf_model.to(device)
log_p1, log_p2 = bidaf_model(cw_idxs, qw_idxs)

torch.Size([64, 388, 768])
torch.Size([64, 388])
torch.Size([64])
torch.Size([64, 25, 768])
torch.Size([64, 25])
torch.Size([64])
------------
torch.Size([64, 388, 100])
torch.Size([64, 25, 100])
------------
torch.Size([64, 388, 200])
torch.Size([64, 25, 200])
------------
torch.Size([64, 388, 800])
------------
torch.Size([64, 388, 200])
------------


In [24]:
loss = F.nll_loss(log_p1, y1[:b]) + F.nll_loss(log_p2, y2[:b])

Loss after one full pass through the BiDAF network

In [25]:
loss

tensor(9.9463, grad_fn=<AddBackward0>)