In [None]:
from typing import Union
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader

import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import pandas as pd
import numpy as np

In [None]:
# Library versions this notebook's outputs were produced with.
(torch.__version__, torchtext.__version__)

('1.10.0+cu111', '0.11.0')

In [None]:
# downloading dataset
! wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv

--2021-11-27 09:07:54--  http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv
Resolving qim.fs.quoracdn.net (qim.fs.quoracdn.net)... 151.101.1.2, 151.101.65.2, 151.101.129.2, ...
Connecting to qim.fs.quoracdn.net (qim.fs.quoracdn.net)|151.101.1.2|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58176133 (55M) [text/tab-separated-values]
Saving to: ‘quora_duplicate_questions.tsv’


2021-11-27 09:07:56 (144 MB/s) - ‘quora_duplicate_questions.tsv’ saved [58176133/58176133]



# The Data

In [None]:
# The Quora file is tab-separated, so tell read_csv which delimiter to use.
df = pd.read_csv("quora_duplicate_questions.tsv", delimiter="\t")
df

Unnamed: 0,id,qid1,qid2,question1,question2,is_duplicate
0,0,1,2,What is the step by step guide to invest in sh...,What is the step by step guide to invest in sh...,0
1,1,3,4,What is the story of Kohinoor (Koh-i-Noor) Dia...,What would happen if the Indian government sto...,0
2,2,5,6,How can I increase the speed of my internet co...,How can Internet speed be increased by hacking...,0
3,3,7,8,Why am I mentally very lonely? How can I solve...,Find the remainder when [math]23^{24}[/math] i...,0
4,4,9,10,"Which one dissolve in water quikly sugar, salt...",Which fish would survive in salt water?,0
...,...,...,...,...,...,...
404285,404285,433578,379845,How many keywords are there in the Racket prog...,How many keywords are there in PERL Programmin...,0
404286,404286,18840,155606,Do you believe there is life after death?,Is it true that there is life after death?,1
404287,404287,537928,537929,What is one coin?,What's this coin?,0
404288,404288,537930,537931,What is the approx annual cost of living while...,I am having little hairfall problem but I want...,0


In [None]:
# Number of question pairs (one pair per row).
df.shape[0]

404290

In [None]:
# Total number of missing cells across all columns.
df.isna().to_numpy().sum()

3

There are only three missing (NaN) cells in the whole frame, so we can simply drop the rows that contain them.

In [None]:
# Rebind rather than mutating with inplace=True: same result, but the cell
# reads as a plain transformation.
# Note: QuoraDuplicateQuestions below re-reads the TSV from disk, so this
# cleaning does not carry over to it.
df = df.dropna()
assert not df.isna().to_numpy().any(), "unexpected missing values after dropna"
df.isnull().values.sum()

0

# The Dataset

In [None]:
class QuoraDuplicateQuestions(Dataset):
    """Map-style dataset of duplicate question pairs from the Quora Question Pairs TSV.

    Only rows with ``is_duplicate == 1`` are kept. Rows where either question is
    missing are dropped. Each item is ``(question1_ids, question2_ids)``: two lists
    of vocabulary indices produced by the ``basic_english`` tokenizer. The vocabulary
    is built from both question columns of the kept rows and includes the special
    tokens ``<unk>``, ``<sos>``, ``<eos>`` and ``<pad>``. Unknown tokens map to ``<unk>``.

    Args:
        root: directory containing the TSV file.
        filename: name of the TSV file inside ``root``.
    """

    def __init__(self, root: Union[str, Path], filename: str = "quora_duplicate_questions.tsv"):
        super().__init__()
        # Path() accepts both str and Path, so no isinstance check is needed.
        path = Path(root) / filename
        raw = pd.read_csv(path, sep="\t")
        # The raw file contains a few NaN questions. The tokenizer would fail on a
        # float NaN, so drop such rows before building the vocabulary.
        self.df = raw[raw["is_duplicate"] == 1].dropna(subset=["question1", "question2"])

        self.tokenizer = get_tokenizer("basic_english")

        def yield_tokens(dataframe: pd.DataFrame):
            # One token list per row, covering both questions of the pair.
            for row in dataframe.itertuples():
                yield self.tokenizer(row.question1) + self.tokenizer(row.question2)

        self.vocab = build_vocab_from_iterator(
            yield_tokens(self.df), specials=["<unk>", "<sos>", "<eos>", "<pad>"]
        )
        self.unk_idx = self.vocab["<unk>"]
        self.eos_idx = self.vocab["<eos>"]
        self.sos_idx = self.vocab["<sos>"]
        self.pad_idx = self.vocab["<pad>"]
        # Out-of-vocabulary tokens map to <unk> instead of raising.
        self.vocab.set_default_index(self.unk_idx)

    # text_pipeline / label_pipeline are methods rather than lambda attributes.
    # Lambdas (and local closures) cannot be pickled, which breaks DataLoader
    # workers when the "spawn" start method is used.
    def text_pipeline(self, text: str) -> list:
        """Tokenize ``text`` and map every token to its vocabulary index."""
        return self.vocab(self.tokenizer(text))

    def label_pipeline(self, text: str) -> list:
        """Tokenize the target question and map it to indices in the shared vocabulary."""
        return self.vocab(self.tokenizer(text))

    def __getitem__(self, index):
        """Return ``(question1_ids, question2_ids)`` for the ``index``-th kept row (positional)."""
        row = self.df.iloc[index]
        return self.text_pipeline(row["question1"]), self.label_pipeline(row["question2"])

    def __len__(self):
        return len(self.df)

    def _add_boundaries(self, ids) -> torch.Tensor:
        """Return ``[<sos>, *ids, <eos>]`` as a 1-D int64 tensor."""
        # Explicit dtype: torch.tensor([]) would be float32 for an empty question.
        return torch.tensor([self.sos_idx, *ids, self.eos_idx], dtype=torch.long)

    def _collate(self, batch):
        """Collate ``[(text_ids, label_ids), ...]`` into ``(texts, labels, lengths)``."""
        texts, labels = zip(*batch)
        texts = [self._add_boundaries(ids) for ids in texts]
        labels = [self._add_boundaries(ids) for ids in labels]
        # Measure lengths after adding <sos>/<eos> so they match the unpadded rows
        # of the returned ``texts`` tensor (what pack_padded_sequence expects).
        lengths = torch.tensor([len(seq) for seq in texts], dtype=torch.long)
        texts = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=self.pad_idx)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=self.pad_idx)
        return texts, labels, lengths

    def collate_fn(self):
        """Return the callable to pass as ``DataLoader(collate_fn=...)``.

        The callable maps a list of items to ``(texts, labels, lengths)``:
        ``texts``/``labels`` are ``(batch, max_len)`` int64 tensors. Each row is
        wrapped in ``<sos>``/``<eos>`` and right-padded with ``pad_idx``.
        ``lengths[i]`` is the unpadded length of ``texts[i]``, including
        ``<sos>``/``<eos>``.
        """
        # A bound method, unlike a local closure, can be pickled for worker processes.
        return self._collate

In [None]:
dataset = QuoraDuplicateQuestions(root=".")
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,  # iterate in file order so the inspected batch is reproducible
    num_workers=2,
    pin_memory=True,
    collate_fn=dataset.collate_fn(),
)

In [None]:
# Pull a single batch to sanity-check the collate function's output shapes.
(texts, labels, lengths) = next(iter(loader))
tuple(tensor.shape for tensor in (texts, labels, lengths))

(torch.Size([16, 23]), torch.Size([16, 24]), torch.Size([16]))

In [None]:
# get_itos() returns the entire vocabulary as a list; fetch it once instead of
# once per token.
itos = dataset.vocab.get_itos()
for t in texts:
    print(" ".join(itos[i] for i in t.tolist()))

<sos> astrology i am a capricorn sun cap moon and cap rising . . . what does that say about me ? <eos>
<sos> how can i be a good geologist ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> how do i read and find my youtube comments ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> what can make physics easy to learn ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> what was your first sexual experience like ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> what would a trump presidency mean for current international master’s students on an f1 visa ? <eos> <pad> <pad> <pad> <pad> <pad>
<sos> what does manipulation mean ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> why are so many quora users posting questions that are readily answered on google ? <eos> <pad> <pad> <pad> 

In [None]:
# get_itos() returns the entire vocabulary as a list; fetch it once instead of
# once per token.
itos = dataset.vocab.get_itos()
for label_ids in labels:
    print(" ".join(itos[i] for i in label_ids.tolist()))

<sos> i ' m a triple capricorn ( sun , moon and ascendant in capricorn ) what does this say about me ? <eos>
<sos> what should i do to be a great geologist ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> how can i see all my youtube comments ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> how can you make physics easy to learn ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> what was your first sexual experience ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> how will a trump presidency affect the students presently in us or planning to study in us ? <eos> <pad> <pad> <pad> <pad>
<sos> what does manipulation means ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> why do people ask quora questions which can be answered easily by google ? <eos> 

# Trying out BucketIterator

In [None]:
from torchtext.legacy.data import BucketIterator

In [None]:
bucketiter = BucketIterator(
    dataset,
    batch_size=16,
    sort_key=lambda example: len(example[0]),  # bucket by question1 token count
    sort=False,
    shuffle=True,
    sort_within_batch=True,
    device="cpu",
)

In [None]:
# Expected to fail: the legacy BucketIterator builds a legacy Batch, which reads
# ``dataset.fields``; only legacy torchtext datasets define that attribute.
# Catch the error so "Restart & Run All" still completes, and show it.
try:
    a = next(iter(bucketiter))
except AttributeError as err:
    print(f"AttributeError: {err}")

AttributeError: ignored

BucketIterator does not seem to work with the new-style API, and searching for this error message turns up nothing useful. On inspection, the `Batch` class in torchtext.legacy tries to access a `fields` attribute on the dataset. Our dataset does not have that attribute, which raises the `AttributeError`. In the legacy API, `fields` belongs to `torchtext.legacy.data.Dataset` and maps each column name to the `Field` object that tokenizes, numericalizes and pads that column. A plain `torch.utils.data.Dataset` like ours has no such mapping, so the legacy `BucketIterator` cannot batch it.