In [103]:
# import dependencies
import nltk
import json
import gzip
import torch
import string
import random
import jsonlines
import pandas as pd
import pickle as pkl
import numpy as np
from tqdm import tqdm
from torch.autograd import Variable
from torch.utils.data import Dataset, RandomSampler, SequentialSampler, DataLoader

In [2]:
# Detect GPU support once; the tensor-type aliases below follow this flag.
USE_CUDA = torch.cuda.is_available()
gpus = [0]
# Select the device only when CUDA is present: set_device fails on CPU-only
# machines, which would make this cell crash before the CPU fallbacks are defined.
if USE_CUDA:
    torch.cuda.set_device(gpus[0])

# Tensor constructors that resolve to the CUDA variants when a GPU is available.
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

In [196]:
def convCategories(categories, category_to_index):
    """Translate category labels into their integer ids, keeping input order.

    Args:
        categories: iterable of category labels.
        category_to_index: mapping from label to integer id.

    Returns:
        A list with one id per label, in the same order as `categories`.

    Raises:
        KeyError: if a label is missing from a dict `category_to_index`.
    """
    indices = []
    for label in categories:
        indices.append(category_to_index[label])
    return indices

In [197]:
def train_validate_test_split(df, train_percent=.8, validate_percent=.1, seed=None):
    """Randomly split a DataFrame into train / validation / test parts.

    Args:
        df: DataFrame to split; its index may hold any labels.
        train_percent: fraction of rows for the training part.
        validate_percent: fraction of rows for the validation part.
        seed: value passed to np.random.seed before shuffling. Note that this
            reseeds NumPy's global random state (with seed=None it is reseeded
            from fresh entropy).

    Returns:
        (train, validate, test) DataFrames. Part sizes are truncated with
        int(), and the test part receives every row left after the first two.
    """
    np.random.seed(seed)
    m = len(df.index)
    # Shuffle row *positions* rather than index labels: the slices below use
    # positional .iloc, so permuting labels breaks (or silently picks the wrong
    # rows) whenever the index is not exactly 0..m-1, e.g. after filtering.
    perm = np.random.permutation(m)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end
    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]
    return train, validate, test

In [198]:
def tokenize_dataset(dataset, word_to_index):
    """Encode every token sequence in dataset['tokens'] as a list of integer ids.

    Tokens present in `word_to_index` map to their id; any other token maps to
    the id stored under '<UNK>'. A tqdm progress bar tracks the sequences.

    Args:
        dataset: object whose 'tokens' entry iterates over token sequences.
        word_to_index: mapping from token to integer id.

    Returns:
        A list with one list of ids per token sequence, in input order.
    """
    def _encode(token):
        # The '<UNK>' id is looked up only when a token is actually unknown.
        if token in word_to_index:
            return word_to_index[token]
        return word_to_index['<UNK>']

    encoded_sequences = []
    for token_sequence in tqdm(dataset['tokens']):
        encoded_sequences.append([_encode(token) for token in token_sequence])
    return encoded_sequences

In [3]:
# Path of the pickled, already-tokenized wiki DataFrame that this notebook reads.
OUTPUT_FILE = 'wikitext_tokenized.p'
# NOTE(review): pickle.load can execute arbitrary code from the file - only load
# pickles that were produced by a trusted pipeline.
# The context manager closes the file handle as soon as loading finishes.
with open(OUTPUT_FILE, "rb") as pickle_file:
    wiki_df = pkl.load(pickle_file)

In [86]:
# Distinct mid-level categories across all articles.
# Sorted so that the category -> id assignment is identical on every run: plain
# set iteration order for strings can differ between interpreter sessions
# (hash randomization), which would silently re-number the labels.
# Assumes the category labels are mutually comparable (e.g. all strings) - TODO confirm.
categories = sorted({
    category
    for article_categories in wiki_df['mid_level_categories']
    for category in article_categories
})

In [89]:
# Category label -> integer id, following the order of `categories`.
category_to_index = {category: index for index, category in enumerate(categories)}
# Inverse lookup (id -> category label). It has to be derived from
# category_to_index; inverting word_to_index would yield vocabulary words
# instead of categories.
index_to_category = {index: category for category, index in category_to_index.items()}

In [93]:
# Encode each article's category labels as integer ids.
# Applying over the single column that is read avoids building a full row
# object per article, as DataFrame.apply(axis=1) does.
wiki_df['category_tokens'] = wiki_df['mid_level_categories'].apply(
    lambda article_categories: convCategories(article_categories, category_to_index)
)


In [95]:
# Fixed seed so the split - and the vocabulary later built from the training
# part - is the same on every Restart & Run All.
SPLIT_SEED = 42
wiki_train, wiki_valid, wiki_test = train_validate_test_split(wiki_df, seed=SPLIT_SEED)

In [98]:
# Targets: one list of category ids per article, for the train and validation splits.
y_train = [label_ids for label_ids in wiki_train['category_tokens']]
y_val = [label_ids for label_ids in wiki_valid['category_tokens']]

In [96]:
# Vocabulary: every distinct token that occurs in the training split.
vocab = {token for document in wiki_train['tokens'] for token in document}

In [97]:
# Number of distinct tokens in the vocabulary.
len(vocab)

605558

In [8]:
# Token -> integer id. Ids 0 and 1 are reserved for padding and unknown tokens.
word_to_index = {"<PAD>": 0, "<UNK>": 1}
# Walk the vocabulary in sorted order so every token receives the same id on
# every run; iterating the raw set gives an order that can change between
# interpreter sessions, which would invalidate saved tokenizations or embeddings.
# Assumes tokens are mutually comparable (e.g. all strings) - TODO confirm.
for word in sorted(vocab):
    # Skip tokens that collide with the reserved entries.
    if word not in word_to_index:
        word_to_index[word] = len(word_to_index)
# Inverse lookup: integer id -> token.
index_to_word = {v: k for k, v in word_to_index.items()}

In [75]:
# Encode the training documents as lists of vocabulary ids.
wiki_tokenized_train = tokenize_dataset(wiki_train, word_to_index)

100%|██████████| 82547/82547 [00:08<00:00, 9417.68it/s] 


In [82]:
# Encode the validation documents; tokens missing from word_to_index become the '<UNK>' id.
wiki_tokenized_val = tokenize_dataset(wiki_valid, word_to_index)

100%|██████████| 10318/10318 [00:02<00:00, 4042.74it/s]


In [83]:
# Encoded documents keyed by split name.
wiki_tokenized_datasets = {
    'train': wiki_tokenized_train,
    'val': wiki_tokenized_val,
}

In [100]:
# Pair each split's encoded documents with its targets in a dataset object.
# NOTE(review): TensoredDataset is not defined in the visible cells - confirm it
# is defined before this cell when running on a fresh kernel.
wiki_tensor_dataset = {
    'train': TensoredDataset(wiki_tokenized_datasets['train'], y_train),
    'val': TensoredDataset(wiki_tokenized_datasets['val'], y_val),
}

In [185]:
def pad_list_of_tensors(list_of_tensors, pad_token):
    """Right-pad tensors to a common length and stack them into one 2-D tensor.

    Each tensor is reshaped to a single row, extended on the right with
    `pad_token` up to the largest last-dimension size in the list, and the rows
    are concatenated along dim 0.

    Args:
        list_of_tensors: sequence of tensors; presumably 1-D long tensors of
            token ids - TODO confirm against the dataset class.
        pad_token: integer id used to fill the padded positions.

    Returns:
        Tensor with one row per input tensor and `max_length` columns. An empty
        input yields an empty (0, 0) long tensor instead of raising.
    """
    # max() on an empty sequence raises ValueError; return an empty batch instead.
    if len(list_of_tensors) == 0:
        return torch.empty((0, 0), dtype=torch.long)

    max_length = max(t.size(-1) for t in list_of_tensors)
    padded_rows = []
    for t in list_of_tensors:
        # Build the padding on the same device as the tensor it extends, so the
        # concatenation also works for tensors that are not on the CPU.
        padding = torch.full(
            (1, max_length - t.size(-1)),
            pad_token,
            dtype=torch.long,
            device=t.device,
        )
        padded_rows.append(torch.cat([t.reshape(1, -1), padding], dim=-1))
    return torch.cat(padded_rows, dim=0)

def pad_collate_fn(batch):
    """Collate samples into a padded input tensor plus a list of targets.

    Args:
        batch: list of samples; element 0 of each sample is the input tensor
            and element 1 is its target.

    Returns:
        (input_tensor, target_list): the inputs right-padded with the '<PAD>'
        id from the module-level `word_to_index` and stacked row-wise, and the
        targets as a plain list (one entry per sample, left unpadded).
    """
    inputs = []
    targets = []
    for sample in batch:
        inputs.append(sample[0])
        targets.append(sample[1])

    padded_inputs = pad_list_of_tensors(inputs, word_to_index['<PAD>'])
    return padded_inputs, targets

In [186]:
# One DataLoader per split, batching with the padding collate function.
wiki_loaders = {}

batch_size = 32

for split, wiki_dataset in wiki_tensor_dataset.items():
    # Shuffle only the training split: a fixed validation order keeps
    # evaluation batches stable and lets predictions be matched back to rows.
    wiki_loaders[split] = DataLoader(
        wiki_dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        collate_fn=pad_collate_fn,
    )

In [189]:
# Smoke test: pull a single collated batch from the training loader.
# next(iter(...)) states the intent directly instead of a loop that breaks at once.
inp, target = next(iter(wiki_loaders['train']))
# The collate function yields one padded input row and one target entry per sample.
assert inp.size(0) == len(target)