# New Dataloader for BERT

Huggingface BERT needs inputs only as text. Uses BertTokenizer

In [30]:
import numpy as np
import pandas as pd
import time
import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [26]:
data = np.load('data/articles_comb.npy', allow_pickle=True) 

In [27]:
data.shape

(142568, 2)

In [28]:
data[40000,0]

'North Korean prison camp survivor changes story'

In [29]:
data[40000,1]

'(CNN) Shin ’s horrific descriptions of his time in a North Korean prison camp became a book, made him a key witness before the United Nations and grabbed headlines around the world. He was one of the most North Korean defectors, winning several human rights awards and inspiring a documentary'

In [53]:
len(data)

142568

In [57]:
data[3][0].split()

['Among',
 'Deaths',
 'in',
 '2016,',
 'a',
 'Heavy',
 'Toll',
 'in',
 'Pop',
 'Music',
 '-',
 'The',
 'New',
 'York',
 'Times']

In [58]:
max_title = 0
max_text = 0
for i in range(len(data)):
    title = data[i][0].split()
    if max_title< len(title):
        max_title = len(title)
    
    text = data[i][1].split()
    if max_text< len(text):
        max_text = len(text)

In [59]:
max_title

33

In [60]:
max_text

50

In [87]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, pt_model = 'bert-base-uncased'):
        '''
        Args:
        root_dir (string): Path to npy directory
        pt_model (string): Which pre-trained model to use
        
        Outputs:
        title (torch.tensor): indexed input for model
        text (torch.tensor): indexed target for model
        title_mask (torch.tensor): masking for titles
        text_mask (torch.tensor): masking for text
        '''
        self.data = np.load(root_dir, allow_pickle=True) 
        self.tokenizer = BertTokenizer.from_pretrained(pt_model)
        #self.max_len_title = max_title
        #self.max_len_text = max_text
        #self.word2idx, self.idx2word = self.indexify_vocab(self.data)
        
    
    def __len__(self):
        return len(self.data) 
    
    def __getitem__(self, idx):
        #load text and title data and tokenize with ready-made tokenizer
        #tokenized_title = self.tokenizer(self.data[idx, 0], return_tensors='pt', padding=True)
        #tokenized_text = self.tokenizer(self.data[idx, 1], return_tensors='pt',padding = True)
        
        # get indexed text & title
        #title = tokenized_title['input_ids']
        #text = tokenized_text['input_ids']
        
        # masking for text (+2 in torch.ones comes from adding start and stop tokens)
        #title_mask = tokenized_title['attention_mask']
        #text_mask = tokenized_text['attention_mask']
        
        #convert to tensors
        #title = title.to(device)
        #text = text.to(device)
        
        title = self.data[idx, 0]
        text = self.data[idx, 1]
        
        return title, text  #, title_mask, text_mask
    

## Testing

In [88]:
start = time.time()
ds = CustomDataset("data/articles_comb.npy")
print(time.time()-start)

2.053868055343628


In [89]:
test, train = torch.utils.data.random_split(ds, [round(len(data)*0.2), round(len(data)*0.8)])

In [103]:
trainloader = DataLoader(dataset=train, batch_size=8, shuffle=True)

In [104]:
for i, hopo in enumerate(trainloader):
    if i==1:
        break
    print(hopo[0])
    print(hopo[1])

('On The State Dinner Guest List, A ’Letter Writer’ Ignites Imaginations', 'Thirteen killed, 31 injured in California tour bus crash', 'Trump in Mexico: ’Who Pays for the Wall? We Didn’t Discuss it’', 'Artist Accused of Disowning a Painting Testifies - The New York Times', 'One Police Shift: Patrolling an Anxious America - The New York Times', 'WATCH: Apple Stores in Bay Area Targeted by Gangs of Thieves - Breitbart', 'U.S. Open 2015: Jordan Spieth wins second major title', 'The ‘Strength’ of Vladimir Putin')


In [105]:
hopo[0][4]

'Sources: Russian Government Intercepted and Will Release Hillary’s Emails - Breitbart'

In [106]:
hopo[1][4]

'WASHINGTON DC — Sources are reporting that the Russian government is prepared to release private emails that it obtained from Hillary Clinton’s email server, proving Clinton allowed classified information to end up in the hands of foreign adversaries. [Intelligence sources are bracing for the Russian release of Clinton’s intercepted emails,'

In [107]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [108]:
hopo[0]

('This is the best way to escape bad Christmas\xa0music',
 'Father’s Day Fast Facts',
 'Poll: Donald Trump Surges to New High, Doubling Closest Competitor - Breitbart',
 'Joe Girardi’s magic is gone — and so is his\xa0patience',
 'Sources: Russian Government Intercepted and Will Release Hillary’s Emails - Breitbart',
 'Human rights groups vow to challenge burkini ban on Cannes beaches',
 '17 Colombian Nationals Arrested for Burglarizing Houston Communities',
 'WSJ: Donald Trump Overhauls Campaign Team Stephen K. Bannon Named CEO - Breitbart')

In [111]:
encoded_target = tokenizer(hopo[0], return_tensors='pt', padding=True)

In [112]:
encoded_target

{'input_ids': tensor([[  101,  2023,  2003,  1996,  2190,  2126,  2000,  4019,  2919,  4234,
          2189,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  2269,  1521,  1055,  2154,  3435,  8866,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  8554,  1024,  6221,  8398, 12058,  2015,  2000,  2047,  2152,
          1010, 19383,  7541, 12692,  1011,  7987, 20175,  8237,  2102,   102,
             0,     0,     0,     0],
        [  101,  3533, 21025, 25561,  2072,  1521,  1055,  3894,  2003,  2908,
          1517,  1998,  2061,  2003,  2010, 11752,   102,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  4216,  1024,  2845,  2231, 16618,  1998,  2097,  2713, 18520,
          1521,  1055, 22028,  1011,  7987, 20175,  8237,  2102,   102,     0,
             0,     0,     0,     0],
      