In [1]:
import os
import sys
import torch

PROJ_DIR = os.path.join(os.environ['WORKSPACE'], 'tutorial/')

if PROJ_DIR not in sys.path:
    sys.path.append(PROJ_DIR)

# Download the IMDB dataset

In [2]:
!cd .. \
&& [ ! -f aclImdb_v1.tar.gz ] \
&& wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \
&& tar -xzf  aclImdb_v1.tar.gz || echo "Data (most likely) already downloaded"

Data (most likely) already downloaded


# Read the files and tokenize the data

In [3]:
import random
import copy
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset

random.seed(2)

def read_files(datadir, sentiment, maxlen):
    sent_dir = os.path.join(datadir, sentiment)
    
    tokens = [word_tokenize(open(os.path.join(sent_dir, sent_file)).read())[:maxlen]
              for sent_file in os.listdir(sent_dir)
              if sent_file.endswith('.txt')]
    labels = [sentiment] * len(tokens)
    
    return tokens, labels
    
def shuffle(tokens, labels):
    z = list(zip(tokens, labels))
    random.shuffle(z)
    return zip(*z)

In [4]:
class IMDBDatset(Dataset):
    def __init__(self, datadir, maxlen=64):
        assert os.path.exists(datadir), datadir
        
        self.tokens = []
        self.labels = []
        
        self.maxlen = maxlen
        self.label_to_index = {'pos': 1, 'neg': 0}
        
        pos_tokens, pos_labels = read_files(datadir, 'pos', maxlen)
        neg_tokens, neg_labels = read_files(datadir, 'neg', maxlen)
        
        self.tokens.extend(pos_tokens + neg_tokens)
        self.labels.extend(pos_labels + neg_labels)
        
        self.tokens, self.labels = shuffle(self.tokens, self.labels)
        self.labels = [self.label_to_index[label] for label in self.labels]

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, item):
        return self.tokens[item], self.labels[item]
    
    def split_data(self, size):
        dataset = copy.deepcopy(self)
        dataset.tokens = dataset.tokens[-size:]
        dataset.labels = dataset.labels[-size:]

        self.tokens = self.tokens[:-size]
        self.labels = self.labels[:-size]

        return dataset

In [5]:
train = IMDBDatset(os.path.join(PROJ_DIR, 'aclImdb/train'))
test = IMDBDatset(os.path.join(PROJ_DIR, 'aclImdb/test'))

len(train), len(test)

(25000, 25000)

In [6]:
print(train[1])

(['This', 'slick', 'and', 'gritty', 'film', 'consistently', 'delivers', '.', 'It', "'s", 'one', 'of', 'Frankenheimer', "'s", 'best', 'and', 'most', 'underrated', 'films', 'and', 'it', "'s", 'easily', 'the', 'best', 'Elmore', 'Leonard', 'adaptation', 'to', 'date', '(', 'and', 'if', 'you', 'are', 'scratching', 'your', 'head', 'thinking', '``', 'but', 'I', 'loved', 'GET', 'SHORTY', "''", 'you', 'need', 'to', 'be', 'punched', 'in', 'the', 'face', ')', '.', 'In', 'my', 'opinion', ',', 'no', 'one', 'captures', 'the'], 1)


In [7]:
_ = train.split_data(10000) # Drop some data 
_ = test.split_data(15000) # Drop some data 

dev = train.split_data(5000)

In [8]:
len(train), len(dev), len(test)

(10000, 5000, 10000)

# Inspect the dataset

In [9]:
for i in range(len(train)):
    tokens, label = train[i]
    if len(tokens) <= 20:
        print(i, label)
        print(tokens)
        print()

1718 0
['Ming', 'The', 'Merciless', 'does', 'a', 'little', 'Bardwork', 'and', 'a', 'movie', 'most', 'foul', '!']

4895 1
['Adrian', 'Pasdar', 'is', 'excellent', 'is', 'this', 'film', '.', 'He', 'makes', 'a', 'fascinating', 'woman', '.']

5336 0
['This', 'movie', 'is', 'terrible', 'but', 'it', 'has', 'some', 'good', 'effects', '.']

8134 1
['This', 'is', 'the', 'definitive', 'movie', 'version', 'of', 'Hamlet', '.', 'Branagh', 'cuts', 'nothing', ',', 'but', 'there', 'are', 'no', 'wasted', 'moments', '.']



# Dump it to avoid processing again

In [10]:
import pickle

corpus = {
    'train': train,
    'dev': dev,
    'test': test
}

with open('data.pickle', 'wb') as fp:
    pickle.dump(corpus, fp)

with open('data.pickle', 'rb') as fp:
    corpus = pickle.load(fp)

In [11]:
len(corpus['train']), len(corpus['dev']), len(corpus['test'])

(10000, 5000, 10000)