In [1]:
import collections
import datasets
import functools
import mininlp
import random
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm

In [2]:
seed = 1234

torch.manual_seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Loading the Dataset

In [3]:
imdb = datasets.load_dataset('imdb')

Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)


In [4]:
imdb

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [5]:
imdb['train'][0]

{'label': 1,
 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!'}

In [6]:
train_data, test_data = datasets.load_dataset('imdb', split=['train', 'test'])

Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)


In [7]:
print(len(train_data), len(test_data))

25000 25000


In [8]:
train_data[0]

{'label': 1,
 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!'}

In [9]:
def get_train_valid_split(train_data, valid_ratio=0.2, shuffle=True):
    data = train_data.train_test_split(test_size=valid_ratio, shuffle=shuffle)
    train_data = data['train']
    valid_data = data['test']
    return train_data, valid_data

In [10]:
valid_ratio = 0.2
shuffle = True

train_data, valid_data = get_train_valid_split(train_data, valid_ratio, shuffle)

Loading cached split indices for dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/cache-920dca5bb59550b9.arrow and /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/cache-65748ac44d181710.arrow


In [11]:
print(len(train_data), len(valid_data), len(test_data))

20000 5000 25000


## Initializing the Tokenizer

In [12]:
tokenize_fn = lambda x : x.split()

In [13]:
tokenizer = mininlp.tokenizer.Tokenizer(tokenize_fn)

In [14]:
example_string = 'Hello world! How is everyone doing today?'

In [15]:
tokenizer.tokenize(example_string)

['Hello', 'world!', 'How', 'is', 'everyone', 'doing', 'today?']

In [16]:
nlp = spacy.load('en_core_web_sm')

def spacy_tokenize(s: str, nlp: spacy.lang):
    return [t.text for t in nlp.tokenizer(s)]

In [17]:
spacy_tokenize(example_string, nlp)

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']

In [18]:
_spacy_tokenize = functools.partial(spacy_tokenize, nlp=nlp)

In [19]:
_spacy_tokenize(example_string)

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']

In [20]:
tokenizer = mininlp.tokenizer.Tokenizer(_spacy_tokenize)

In [21]:
tokenizer.tokenize(example_string)

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']

## Building the Vocabulary

In [22]:
field = 'text'

counter = mininlp.vocab.build_vocab_counter(train_data, field, tokenizer)

In [23]:
counter.most_common(10)

[('the', 232322),
 (',', 220773),
 ('.', 189909),
 ('a', 125392),
 ('and', 125260),
 ('of', 115263),
 ('to', 107115),
 ('is', 87381),
 ('in', 70335),
 ('I', 61975)]

In [24]:
min_freq = 6
max_size = 30_000

In [25]:
vocab = mininlp.vocab.Vocab(counter, min_freq, max_size)

In [26]:
len(vocab)

28386

In [27]:
vocab.stoi('Hello')

7594

In [28]:
vocab.stoi('hello')

11977

In [29]:
vocab.itos(11977)

'hello'

In [30]:
vocab.stoi('Cthulhu')

0

In [31]:
vocab.itos(0)

'<unk>'

In [32]:
example_string = 'Hello world! How is everyone doing today?'

example_tokens = tokenizer.tokenize(example_string)

print(example_tokens)

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']


In [33]:
vocab.stoi(example_tokens)

[7594, 223, 42, 568, 9, 353, 428, 572, 58]

In [34]:
vocab.itos(vocab.stoi(example_tokens))

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']

In [35]:
example_string = 'My best friend is named Cthulhu'

example_tokens = tokenizer.tokenize(example_string)

vocab.itos(vocab.stoi(example_tokens))

['My', 'best', 'friend', 'is', 'named', '<unk>']

## Creating the DataLoader

In [36]:
text_transforms = mininlp.transforms.sequential_transforms(tokenizer.tokenize,
                                                           vocab.stoi,
                                                           mininlp.transforms.to_longtensor)

In [37]:
label_transforms = mininlp.transforms.sequential_transforms(mininlp.transforms.to_longtensor)

In [38]:
train_dataset = mininlp.dataset.TextClassificationDataset(train_data, text_transforms, label_transforms)

In [39]:
train_dataset[0]

(tensor([ 7473,  2616,    11,   277,     8,    36,    43,   348,   836,    18,
           707,    31,    47,  5534,    15,  4419,    31,   264,    47,  2955,
            18,  4015,     2,    22,     9,   176,     6,    67,    32,    82,
           343,   323,     6,    59,     7,     2,   131,    69,   648,    36,
           144,  2160,    20,    14,   993,    14,    26,     2,    75,     9,
         16574,    18,   116,    75,     7,   111,     2,   223,    69,    36,
             6,   112,  1674,    29,  1024,     9,    62,     0,    18,  2536,
             2,   150,     2,   925,    19,    28,     5,   925,    26,     2,
            75, 12995,    10,    43,   851,    40,     2,   233, 22119,    18,
           425,    15,  9308,   715,   293,  4716,    63,    36,    14,  2067,
          4369,    14,    10,  1338,     7,   111,     2,  1174,     9,   644,
             8,    36,   907,    17,   223,   790,     4]),
 tensor(1))

In [40]:
train_dataset.data[0]

{'label': 1,
 'text': 'Soylent Green I found to be an excellent movie.<br /><br />If you like Logan\'s Run you\'ll like this.<br /><br />Yes the movie is old and there are no special effects and some of the acting can somewhat be best described as "cheesy" but the story is excellent.<br /><br />The story of how the world can be and its impact on society is very poignant.<br /><br />At the end the mystery wasn\'t a mystery but the story unfolded in an easy at the right pace.<br /><br />It\'s nearest modern day equivalent would be "Dark Angel" in terms of how the US is shown to be third-world country.'}

In [41]:
vocab.stoi('Soylent')

7473

In [42]:
valid_dataset = mininlp.dataset.TextClassificationDataset(valid_data, text_transforms, label_transforms)

In [43]:
test_dataset = mininlp.dataset.TextClassificationDataset(test_data, text_transforms, label_transforms)

In [44]:
pad_idx = vocab.stoi(vocab.pad_token)

print(pad_idx)

1


In [45]:
collator = mininlp.collator.TextClassificationCollator(pad_idx)

In [46]:
batch_size = 256

In [47]:
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           collate_fn=collator.collate)

In [48]:
valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=False, 
                                           collate_fn=collator.collate)

In [49]:
test_loader = torch.utils.data.DataLoader(train_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False, 
                                          collate_fn=collator.collate)

## Creating the NBOW model

In [50]:
class NBOW(nn.Module):
    def __init__(self, input_dim: int, emb_dim: int, output_dim: int, pad_idx: int):
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)
        self.fc = nn.Linear(emb_dim, output_dim)
        
    def forward(self, text: torch.LongTensor) -> torch.FloatTensor:
        
        # text = [seq len, batch size]
        
        embedded = self.embedding(text)
        
        # embedded = [seq len, batch size, emb dim]
        
        pooled = embedded.mean(0)
        
        # pooled = [batch size, emb dim]
        
        prediction = self.fc(pooled)
        
        # prediction = [batch size, output dim]
        
        return prediction

In [51]:
input_dim = len(vocab)
emb_dim = 100
output_dim = 2

model = NBOW(input_dim, emb_dim, output_dim, pad_idx)

In [52]:
print(f'The model has {mininlp.utils.count_parameters(model):,} trainable parameters')

The model has 2,838,802 trainable parameters


In [53]:
optimizer = optim.Adam(model.parameters())

In [54]:
criterion = nn.CrossEntropyLoss()

In [55]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using: {device}')

Using: cuda


In [56]:
model = model.to(device)
criterion = criterion.to(device)

In [57]:
def train(model, data_loader, optimizer, criterion, device):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for text, labels in data_loader:
        
        text = text.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        predictions = model(text)
        
        loss = criterion(predictions, labels)
        
        acc = mininlp.utils.calculate_accuracy(predictions, labels)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(data_loader), epoch_acc / len(data_loader)

In [58]:
def evaluate(model, data_loader, criterion, device):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for text, labels in data_loader:

            text = text.to(device)
            labels = labels.to(device)
            
            predictions = model(text)
            
            loss = criterion(predictions, labels)
            
            acc = mininlp.utils.calculate_accuracy(predictions, labels)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(data_loader), epoch_acc / len(data_loader)

In [59]:
n_epochs = 10

best_valid_loss = float('inf')

for epoch in tqdm.notebook.tqdm(range(n_epochs)):

    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion, device)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'nbow-model.pt')
        
    print(f'Epoch: {epoch:2}')
    print(f'  Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'  Valid Loss: {valid_loss:.3f} | Valid Acc: {valid_acc*100:.2f}%')

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Epoch:  0
  Train Loss: 0.691 | Train Acc: 53.32%
  Valid Loss: 0.686 | Valid Acc: 62.64%
Epoch:  1
  Train Loss: 0.676 | Train Acc: 68.02%
  Valid Loss: 0.668 | Valid Acc: 70.93%
Epoch:  2
  Train Loss: 0.648 | Train Acc: 73.11%
  Valid Loss: 0.637 | Valid Acc: 73.11%
Epoch:  3
  Train Loss: 0.609 | Train Acc: 76.55%
  Valid Loss: 0.599 | Valid Acc: 76.36%
Epoch:  4
  Train Loss: 0.566 | Train Acc: 79.25%
  Valid Loss: 0.562 | Valid Acc: 78.44%
Epoch:  5
  Train Loss: 0.522 | Train Acc: 81.53%
  Valid Loss: 0.527 | Valid Acc: 80.63%
Epoch:  6
  Train Loss: 0.483 | Train Acc: 83.35%
  Valid Loss: 0.495 | Valid Acc: 82.26%
Epoch:  7
  Train Loss: 0.445 | Train Acc: 85.44%
  Valid Loss: 0.467 | Valid Acc: 83.05%
Epoch:  8
  Train Loss: 0.414 | Train Acc: 86.93%
  Valid Loss: 0.443 | Valid Acc: 83.94%
Epoch:  9
  Train Loss: 0.385 | Train Acc: 87.86%
  Valid Loss: 0.422 | Valid Acc: 84.75%



In [60]:
model.load_state_dict(torch.load('nbow-model.pt'))

test_loss, test_acc = evaluate(model, test_loader, criterion, device)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.368 | Test Acc: 88.40%


In [61]:
def predict(sentence, text_transforms, model, device):
    model.eval()
    tensor = text_transforms(sentence).unsqueeze(-1).to(device)
    prediction = model(tensor)
    probabilities = nn.functional.softmax(prediction, dim=-1)
    pos_probability = probabilities.squeeze(0)[-1].item()
    return pos_probability

In [62]:
sentence = 'the absolute worst movie of all time.'

predict(sentence, text_transforms, model, device)

3.0972864806244615e-07

In [63]:
sentence = 'one of the greatest films i have ever seen in my life.'

predict(sentence, text_transforms, model, device)

1.0

In [64]:
sentence = "i thought it was going to be one of the greatest films i have ever seen in my life, \
but it was actually the absolute worst movie of all time."

predict(sentence, text_transforms, model, device)

0.8011810779571533

In [65]:
sentence = "i thought it was going to be the absolute worst movie of all time, \
but it was actually one of the greatest films i have ever seen in my life."

predict(sentence, text_transforms, model, device)

0.8011810779571533