In [1]:
import pandas as pd

In [2]:
from torch.utils.data import Dataset
from torchvision import transforms

In [82]:
noop_tok = '@@@'
sentence_length = 128
fpath = '../data/quora/train.csv'

In [47]:
class QuoraDataset(Dataset):
    def __init__(self, fpath):
        self.fpath = fpath
        self.df = pd.read_csv(fpath)
        self.df = self.df.dropna()
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, index):
        item = self.df.iloc[index]
        return [item['is_duplicate'], item['question1'], item['question2']]
        

In [48]:
qdata = QuoraDataset(fpath)
qdata[2], len(qdata)

([0,
  'How can I increase the speed of my internet connection while using a VPN?',
  'How can Internet speed be increased by hacking through DNS?'],
 404287)

In [76]:
from gensim.corpora.dictionary import Dictionary
documents = qdata.df['question1'] + qdata.df['question2'] + [noop_tok]
documents = [sentence.split(' ') for sentence in documents]

In [78]:
class FixSentencesLength(object):
    def __init__(self, sentence_length=128, padding_token=noop_tok):
        self.sentence_length = sentence_length
        self.padding_token = padding_token
        
    def fix_length(self, sentence):
        sentence = sentence.split(' ')
        if len(sentence) >= self.sentence_length:
            return ' '.join(sentence[:self.sentence_length])
        else:
            return ' '.join(sentence + [self.padding_token]*(self.sentence_length - len(sentence)))
         
    def __call__(self, item):
        assert len(item) == 3
        item[1] = self.fix_length(item[1])
        item[2] = self.fix_length(item[2])
        return item
    
class TokToID(object):
    def __init__(self, documents=[]):
        self.dictionary = Dictionary(documents)
        
    def toID(self, sentence):
        sentence = sentence.split(' ')
        return self.dictionary.doc2idx(sentence)
        
    def __call__(self, item):
        item[1] = self.toID(item[1])
        item[2] = self.toID(item[2])
        return item

In [81]:
transformer = transforms.Compose([
    FixSentencesLength(sentence_length),
    TokToID(documents)
])
qdata[0], transformer(qdata[0])

([0,
  'What is the step by step guide to invest in share market in india?',
  'What is the step by step guide to invest in share market?'],
 [0,
  [0,
   6,
   11,
   10,
   1,
   10,
   2,
   12,
   5,
   3,
   9,
   7,
   3,
   68583,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   1455,
   145