# Pipeline up to the point where the input is encoded by BERT

In [1]:
import numpy as np
import pandas as pd
import time
import torch
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Load the combined articles array (the same file CustomDataset reads by default).
# NOTE(review): allow_pickle=True unpickles arbitrary objects while loading,
# so this file should only ever come from a trusted source.
data = np.load(file='data/articles_comb.npy', allow_pickle=True)

In [4]:
class CustomDataset(Dataset):
    '''
    Tokenized (title, text) article pairs, ready to be fed to a BERT model.

    The underlying array is indexed as ``data[idx, 0]`` for the title and
    ``data[idx, 1]`` for the article text.
    '''

    def __init__(self, root_dir='data/articles_comb.npy', pt_model = 'bert-base-uncased', max_title=50, max_text=75, device=None):
        '''
        Args:
        root_dir (string): Path to the .npy file holding the articles
        pt_model (string): Which pre-trained model's tokenizer to use
        max_title (int): Titles are padded/truncated to this many tokens
        max_text (int): Texts are padded/truncated to this many tokens
        device (torch.device or string, optional): Where returned tensors are
            placed. Defaults to "cuda:0" when CUDA is available, else "cpu".

        Outputs (of __getitem__), each keeping the tokenizer's leading
        dimension of size 1:
        title (torch.tensor): token ids of the title
        text (torch.tensor): token ids of the article text
        title_mask (torch.tensor): attention mask for the title (0 on padding)
        text_mask (torch.tensor): attention mask for the text (0 on padding)

        The training loop in this notebook uses the titles as targets and the
        texts as encoder inputs.
        '''
        # NOTE(review): allow_pickle=True executes pickled content on load;
        # only point root_dir at files from a trusted source.
        self.data = np.load(root_dir, allow_pickle=True)
        self.tokenizer = BertTokenizer.from_pretrained(pt_model)
        self.max_title = max_title
        self.max_text = max_text
        # Resolve the device here instead of reading a module-level global, so
        # the dataset does not depend on another notebook cell having run.
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = device

    def __len__(self):
        '''Number of (title, text) rows in the loaded array.'''
        return len(self.data)

    def _encode(self, raw, max_length):
        '''Tokenize one string, padded/truncated to exactly max_length tokens.'''
        return self.tokenizer(
            raw,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )

    def __getitem__(self, idx):
        '''Return (title, text, title_mask, text_mask) for row idx on self.device.'''
        # tokenize title and text with the ready-made tokenizer
        title_enc = self._encode(self.data[idx, 0], self.max_title)
        text_enc = self._encode(self.data[idx, 1], self.max_text)

        # token ids first, then the matching attention masks, in the order
        # callers unpack them
        fields = (
            title_enc['input_ids'],
            text_enc['input_ids'],
            title_enc['attention_mask'],
            text_enc['attention_mask'],
        )

        # move to device
        return tuple(tensor.to(self.device) for tensor in fields)

## Encoder

In [5]:
# Time how long building the dataset takes (it loads the .npy file and the
# pre-trained tokenizer). perf_counter() is used because it is a monotonic
# clock meant for measuring intervals; time.time() can jump if the system
# clock is adjusted.
start = time.perf_counter()
ds = CustomDataset("data/articles_comb.npy")
print(time.perf_counter() - start)

2.161966323852539


In [6]:
# Hold out 20% of the dataset for testing and train on the remaining 80%.
# The sizes are derived from `ds` itself rather than from the separately loaded
# `data` array, so this cell does not depend on that global still holding the
# articles. The train size is taken as the remainder so both lengths always
# sum to len(ds), which random_split requires.
n_test = round(len(ds) * 0.2)
test, train = torch.utils.data.random_split(ds, [n_test, len(ds) - n_test])

In [7]:
# Serve the training split in shuffled mini-batches of 8 examples.
trainloader = DataLoader(
    dataset=train,
    batch_size=8,
    shuffle=True,
)

In [18]:
# Pre-trained BERT encoder, moved to the device chosen at the top of the notebook.
encoder = BertModel.from_pretrained("bert-base-uncased").to(device)
encoder  # last expression of the cell, so the notebook displays the module summary


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [19]:
# Smoke test: push the first training batch through the BERT encoder.
# The loop variable is called `batch` (not `data`) so it does not overwrite the
# module-level `data` array holding the articles.
for i, batch in enumerate(trainloader):
    if i == 1:
        break
    # Batch layout follows CustomDataset.__getitem__:
    # (title, text, title_mask, text_mask). Titles are the targets, article
    # texts are the encoder inputs.
    # Each tensor arrives as (batch, 1, seq_len); squeeze(1) drops only that
    # singleton tokenizer dimension. A bare squeeze() would also collapse the
    # batch axis whenever a batch contains a single sample.
    targets, inputs, targets_masks, inputs_masks = (t.squeeze(1) for t in batch)
    print(inputs.shape)
    print(inputs_masks.shape)
    print(targets_masks.shape)
    print(targets.shape)
    encoded_inputs = encoder(inputs, attention_mask=inputs_masks)
    # [0] is the last hidden state, (batch, len, dim); permute it to the
    # (len, batch, dim) layout the transformer decoder wants.
    encoded_inputs = encoded_inputs[0].permute(1, 0, 2)

    # Next encoded would go through decoder and then a loss would be calculated

torch.Size([8, 75])
torch.Size([8, 75])
torch.Size([8, 50])
torch.Size([8, 50])


In [25]:
# `encoded_inputs` was already permuted to (len, batch, dim) inside the loop
# above. Indexing it with [0] would drop the sequence axis and make a 3-axis
# permute fail, so both prints work on the full tensor.
print(encoded_inputs.permute(1, 0, 2).shape)  # BERT's last hidden state layout (batch, len, dim)

# transformer decoder wants shape (len, batch, dim)
print(encoded_inputs.shape)

torch.Size([8, 75, 768])
torch.Size([75, 8, 768])
