In [1]:
import pandas as pd
import re
import spacy
import numpy as np
import gzip
import gensim.downloader
import torch
from torch.utils.data import Dataset, DataLoader
import pickle

In [2]:
class WordDataset(Dataset):
    """Sentences from a pickle file paired with their GloVe word vectors.

    The pickle must hold a dict with the keys ``'sentences'`` (a sequence of
    token sequences) and ``'sub_indices'`` (one entry per sentence).

    Every token is lower-cased and stripped of apostrophes; tokens that end
    up empty are removed. Tokens missing from the GloVe vocabulary stay in
    the sentence but contribute no row to its embedding array.

    NOTE(review): because tokens can be removed and out-of-vocabulary words
    get no embedding row, positions in the cleaned sentence / embedding array
    may differ from positions in the raw sentence. If ``sub_indices`` are
    token positions, confirm they still point at the intended word.
    """

    # Pre-trained gensim model used when no vectors are passed in explicitly.
    _GLOVE_MODEL = 'glove-wiki-gigaword-200'
    # Shared across instances so the model is loaded once, not once per
    # dataset (train and test previously each triggered a full load).
    _glove_cache = None

    def __init__(self, sfile, glove=None):
        """Load sentences from ``sfile`` and embed every in-vocabulary token.

        Args:
            sfile: Path to the pickle file described in the class docstring.
            glove: Optional word -> vector mapping supporting ``in`` and
                ``[]``. When omitted, the shared GloVe model is loaded
                through ``gensim.downloader`` (and cached on the class).
        """
        # SECURITY: pickle.load can execute arbitrary code; only open files
        # that come from a trusted source.
        with open(sfile, 'rb') as fh:
            d = pickle.load(fh)

        # Build fresh lists instead of deleting while iterating: deleting
        # inside enumerate() made the loop skip the token that followed each
        # removed one, leaving it un-normalised.
        self.sentences = [self._clean(sent) for sent in d['sentences']]
        self.indices = d['sub_indices']

        self.glove = glove if glove is not None else self._load_glove()

        # One array per sentence, one row per token found in the vocabulary.
        self.sentence_embeddings = [
            np.array([self.glove[word] for word in sent if word in self.glove])
            for sent in self.sentences
        ]

    @staticmethod
    def _clean(sent):
        """Return the lower-cased, apostrophe-free, non-empty tokens of ``sent``."""
        normalised = (word.lower().replace("'", "") for word in sent)
        return [word for word in normalised if word != '']

    @classmethod
    def _load_glove(cls):
        """Return the shared GloVe vectors, loading them on first use."""
        if WordDataset._glove_cache is None:
            WordDataset._glove_cache = gensim.downloader.load(cls._GLOVE_MODEL)
        return WordDataset._glove_cache

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.sentence_embeddings)

    def __getitem__(self, idx):
        """Return ``(tokens, embeddings, sub_index)`` for sentence ``idx``."""
        return (self.sentences[idx], self.sentence_embeddings[idx], self.indices[idx])

In [3]:
# Build the training split first, then the test split, from their pickles.
train_dataset, test_dataset = (
    WordDataset(pickle_path)
    for pickle_path in ('subdata.pkl', 'subdata_test.pkl')
)

In [4]:
# Sanity check: size of each split followed by its second example
# (tokens, embedding array, substitution index).
print(len(train_dataset), train_dataset[1], sep='\n')

print(len(test_dataset), test_dataset[1], sep='\n')

3586
(['chancellor', 'of', 'the', 'exchequer', 'nigel', 'lawsons', 'restated', 'commitment', 'to', 'a', 'firm', 'monetary', 'policy', 'has', 'helped', 'to', 'prevent', 'a', 'freefall', 'in', 'sterling', 'over', 'the', 'past', 'week', '.'], array([[ 0.091849, -0.27507 , -0.040291, ...,  0.2757  , -0.25509 ,
        -0.52322 ],
       [ 0.052924,  0.25427 ,  0.31353 , ..., -0.086254, -0.41917 ,
         0.46496 ],
       [-0.071549,  0.093459,  0.023738, ...,  0.33617 ,  0.030591,
         0.25577 ],
       ...,
       [-0.33563 ,  0.17808 , -0.43981 , ...,  0.39282 , -0.018467,
         0.41027 ],
       [-0.28622 ,  0.61687 , -0.42819 , ..., -0.049013,  0.040753,
         0.057147],
       [ 0.12289 ,  0.58037 , -0.069635, ..., -0.039174, -0.16236 ,
        -0.096652]], dtype=float32), 7)
881
(['rockwell', ',', 'based', 'in', 'el', 'segundo', ',', 'calif.', ',', 'is', 'an', 'aerospace', ',', 'electronics', ',', 'automotive', 'and', 'graphics', 'concern', '.'], array([[ 0.13841  , -0.06