In [9]:
import torch
import torchtext

### Create Dataset Object


In [1]:
from torchtext.datasets import IMDB
# Load the IMDB train and test splits as two separate dataset iterators.
train_iter, test_iter = IMDB(split=('train', 'test'))

In [2]:
# Peek at the first training sample; the cell displays it as a (label, text) tuple.
next(iter(train_iter))

(1,
 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far betwee

### Build Data Processing Pipeline

In [3]:
from torchtext.data.utils import get_tokenizer
# torchtext's 'basic_english' tokenizer; reused below to build the vocabulary
# and to encode raw text.
tokenizer = get_tokenizer('basic_english')

In [12]:
from collections import Counter
from torchtext.vocab import vocab

# Count every token of the training split, then keep only tokens that occur
# at least 10 times. The four special tokens are passed as `specials`, and
# any token missing from the vocabulary resolves to the index of '<unk>'.
train_iter = IMDB(split='train')
counter = Counter(
    token
    for _, text in train_iter
    for token in tokenizer(text)
)
vocabulary = vocab(
    counter,
    min_freq=10,
    specials=('<unk>', '<BOS>', '<EOS>', '<PAD>'),
)
vocabulary.set_default_index(vocabulary['<unk>'])

In [5]:
# Fetch both lookup tables first, then report a few facts about the vocabulary.
new_stoi = vocabulary.get_stoi()  # token -> index
new_itos = vocabulary.get_itos()  # index -> token

print("The length of the new vocab is", len(vocabulary))
print("The index of '<BOS>' is", new_stoi['<BOS>'])
print("The token at index 2 is", new_itos[2])

The length of the new vocab is 13020
The index of '<BOS>' is 1
The token at index 2 is <EOS>


In [6]:
def text_transform(x):
    """Encode raw text as a list of vocabulary indices.

    The text is tokenized with ``tokenizer``, each token is looked up in
    ``vocabulary``, and the result is wrapped in the ``<BOS>`` / ``<EOS>``
    indices.
    """
    token_ids = [vocabulary['<BOS>']]
    token_ids.extend(vocabulary[token] for token in tokenizer(x))
    token_ids.append(vocabulary['<EOS>'])
    return token_ids
def label_transform(x):
    """Map a raw IMDB label to a binary class id (1 = positive, 0 = negative).

    Both label encodings are accepted: the string labels ``'pos'`` / ``'neg'``
    and the integer labels ``2`` / ``1``. The dataset loaded in this notebook
    yields integers (the first sample is labelled ``1``), so comparing against
    ``'pos'`` alone would turn every label into 0.
    """
    return 1 if x in ('pos', 2) else 0

# Show text_transform on a short sentence: the input first, then its encoding.
example_text = "here is an example"
print("input to the text_transform:", example_text)
print("output of the text_transform:", text_transform(example_text))

input to the text_transform: here is an example
output of the text_transform: [1, 938, 54, 195, 3244, 2]


### Generate Batch Iterator

In [25]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    """Collate raw ``(label, text)`` samples into a padded batch.

    Args:
        batch: Iterable of ``(label, text)`` pairs.

    Returns:
        A ``(padded_text, labels)`` tuple. ``padded_text`` holds the encoded
        texts, each padded with the ``<PAD>`` index up to the longest sequence
        in the batch; ``pad_sequence`` is called with its default layout, so
        (per the PyTorch docs) the sequence dimension comes first. ``labels``
        is a 1-D tensor of the ``label_transform`` outputs, in batch order.
    """
    # Look the pad index up instead of hard-coding 3.0, so padding stays
    # correct if the order of the vocabulary's special tokens ever changes.
    pad_index = float(vocabulary['<PAD>'])

    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_transform(_label))
        processed_text = torch.tensor(text_transform(_text))
        text_list.append(processed_text)
    return pad_sequence(text_list, padding_value=pad_index), torch.tensor(label_list)

# Materialize the training split as a list and wrap it in a shuffling
# DataLoader that assembles batches of eight samples through collate_batch.
train_iter = IMDB(split='train')
train_dataloader = DataLoader(
    list(train_iter),
    collate_fn=collate_batch,
    shuffle=True,
    batch_size=8,
)

#### BucketIterator

In [None]:
import random

batch_size = 8  # number of samples per batch

# Keep the whole training split in memory so samples can be fetched by index.
train_iter = IMDB(split='train')
train_list = list(train_iter)

def batch_sampler():
    """Yield lists of dataset indices whose texts have similar token counts.

    Every sample in ``train_list`` is paired with its token count, the pairs
    are shuffled, and the shuffled sequence is cut into pools of
    ``batch_size * 100`` entries. Each pool is sorted by token count, and the
    resulting index order is emitted in consecutive chunks of ``batch_size``
    (the final chunk may be shorter).
    """
    pool_size = batch_size * 100

    # (dataset position, number of tokens) for every training sample.
    keyed = []
    for position, sample in enumerate(train_list):
        keyed.append((position, len(tokenizer(sample[1]))))
    random.shuffle(keyed)

    # Sort inside each pool only, so neighbours have similar lengths while
    # the overall order still depends on the shuffle.
    ordered = []
    for start in range(0, len(keyed), pool_size):
        pool = keyed[start:start + pool_size]
        pool.sort(key=lambda pair: pair[1])
        ordered.extend(position for position, _ in pool)

    # Emit the index batches one at a time.
    for start in range(0, len(ordered), batch_size):
        yield ordered[start:start + batch_size]

class _BucketBatchSampler:
    """Re-iterable wrapper around ``batch_sampler()``.

    ``batch_sampler()`` returns a generator, which is exhausted after a single
    pass; handing it to ``DataLoader`` directly therefore produces batches for
    the first epoch only and makes ``len(dataloader)`` fail. This wrapper
    builds a fresh generator each time iteration starts, so every epoch gets
    newly shuffled buckets.
    """

    def __iter__(self):
        return batch_sampler()

    def __len__(self):
        # Batches per epoch, counting a trailing partial batch (ceil division).
        return -(-len(train_list) // batch_size)


bucket_dataloader = DataLoader(train_list, batch_sampler=_BucketBatchSampler(),
                               collate_fn=collate_batch)

# Show the first bucketed batch.
print(next(iter(bucket_dataloader)))