Reference: https://github.com/bentrevett/pytorch-text-classification

In [5]:
import utils
import functools
import numpy as np
import matplotlib.pyplot as plt
import datasets
import collections
import torch
import random
import spacy
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm

In [6]:
seed = 1234

# Seed every global RNG the notebook can touch so a fresh "Restart & Run All"
# reproduces the same numbers.
torch.manual_seed(seed)
random.seed(seed)
# NumPy's global RNG is independent of the torch / `random` seeds above.
# NOTE(review): the unseeded shuffle in `datasets`' train_test_split appears to
# draw from it — confirm for the installed `datasets` version.
np.random.seed(seed)
# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [7]:
%load_ext autoreload
%autoreload 2

# Loading the Dataset

In [8]:
imdb = datasets.load_dataset('imdb')

Reusing dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)


In [9]:
imdb

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [10]:
imdb['train'][0]

{'label': 1,
 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!'}

In [11]:
train_data, test_data = datasets.load_dataset('imdb', split=['train','test'])

Reusing dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)


In [12]:
print(len(train_data),len(test_data))

25000 25000


In [13]:
print(train_data[0])

{'label': 1, 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!'}


In [14]:
def get_train_valid_split(train_data, valid_ratio = 0.2, shuffle = True, seed = None):
    """Split ``train_data`` into a training part and a validation part.

    Args:
        train_data: Object exposing ``train_test_split(test_size=..., shuffle=..., seed=...)``
            that returns a mapping with ``'train'`` and ``'test'`` entries.
        valid_ratio: Forwarded as ``test_size``; the portion held out for validation.
        shuffle: Forwarded as ``shuffle``.
        seed: Optional seed forwarded to ``train_test_split`` so the split can be
            pinned explicitly. ``None`` (the default) keeps the previous behaviour.

    Returns:
        Tuple ``(train_data, valid_data)`` taken from the ``'train'`` and ``'test'``
        entries of the split.
    """
    data = train_data.train_test_split(test_size = valid_ratio, shuffle = shuffle, seed = seed)
    # The library names the held-out part 'test'; here it serves as the validation set.
    return data['train'], data['test']

In [15]:
train_data, valid_data = get_train_valid_split(train_data, 0.2, True)

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/cache-a76c42879a7d62bf.arrow and /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/cache-80e03183e3efcbd4.arrow


In [16]:
print(len(train_data), len(valid_data))

20000 5000


# Tokenizing the Data

In [17]:
def tokenizer_fn(x):
    """Split the string ``x`` on runs of whitespace and return the list of pieces."""
    return x.split()

In [18]:
tokenizer = utils.tokenizer.Tokenizer(tokenizer_fn)

In [19]:
example_string = "Hello Guys!! What's going on?"

In [20]:
print(tokenizer.tokenize(example_string))

['Hello', 'Guys!!', "What's", 'going', 'on?']


In [21]:
import spacy
nlp = spacy.load('en_core_web_sm')

def spacy_tokenize(s: str, nlp: "spacy.language.Language") -> list:
    """Tokenize ``s`` with the tokenizer of a loaded spaCy pipeline.

    Only ``nlp.tokenizer`` is invoked, not the full pipeline.

    Args:
        s: Text to tokenize.
        nlp: Loaded spaCy pipeline (e.g. the result of ``spacy.load``); its
            ``tokenizer`` attribute is called on ``s``.

    Returns:
        List with the ``.text`` of every token, in order.
    """
    return [token.text for token in nlp.tokenizer(s)]

In [22]:
print(spacy_tokenize(example_string,nlp))

['Hello', 'Guys', '!', '!', 'What', "'s", 'going', 'on', '?']


In [23]:
_spacy_tokenize = functools.partial(spacy_tokenize, nlp = nlp)

In [24]:
print(_spacy_tokenize(example_string))

['Hello', 'Guys', '!', '!', 'What', "'s", 'going', 'on', '?']


In [25]:
tokenizer = utils.tokenizer.Tokenizer(_spacy_tokenize)

In [26]:
tokenizer.tokenize(example_string)

['Hello', 'Guys', '!', '!', 'What', "'s", 'going', 'on', '?']

# Building Vocabulary

In [27]:
field = 'text'

counter = utils.vocab.build_vocab_counter(train_data, field, tokenizer)

In [28]:
counter.most_common(10)

[('the', 232322),
 (',', 220773),
 ('.', 190010),
 ('a', 125392),
 ('and', 125259),
 ('of', 115263),
 ('to', 107115),
 ('is', 87381),
 ('in', 70335),
 ('I', 61975)]

In [29]:
# Limits handed to utils.vocab.Vocab together with the token counter.
# NOTE(review): presumably tokens seen fewer than `min_freq` times are dropped and
# the vocabulary is capped at `max_size` entries — confirm in utils.vocab.Vocab.
min_freq = 6
max_size = 30_000

vocab = utils.vocab.Vocab(counter, min_freq, max_size)

In [30]:
len(vocab)

28392

In [31]:
example_string = 'Hello world! How is everyone doing today?'

example_tokens = tokenizer.tokenize(example_string)

print(example_tokens)

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']


In [32]:
vocab.stoi(example_tokens)

[7589, 223, 42, 568, 9, 353, 428, 572, 58]

In [33]:
vocab.itos(vocab.stoi(example_tokens))

['Hello', 'world', '!', 'How', 'is', 'everyone', 'doing', 'today', '?']

In [34]:
example_string = 'My best friend is named Cthulhu'

example_tokens = tokenizer.tokenize(example_string)

vocab.itos(vocab.stoi(example_tokens))

['My', 'best', 'friend', 'is', 'named', '<unk>']

# Creating the DataLoader

In [35]:
# Text pipeline built from three callables: tokenize, map tokens to vocabulary
# indices, convert to a LongTensor.
# NOTE(review): presumably applied left to right — confirm in
# utils.transforms.sequential_transforms.
text_transforms = utils.transforms.sequential_transforms(tokenizer.tokenize,
                                                           vocab.stoi,
                                                           utils.transforms.to_longtensor)

In [36]:
label_transforms = utils.transforms.sequential_transforms(utils.transforms.to_longtensor)

In [37]:
train_dataset = utils.dataset.TextClassificationDataset(train_data, text_transforms, label_transforms)

In [38]:
train_dataset[0]

(tensor([ 7471,  2611,    11,   277,     8,    36,    43,   348,   835,    18,
           706,    31,    47,  5532,    15,  4418,    31,   264,    47,  2950,
            18,  4014,     2,    22,     9,   176,     6,    67,    32,    82,
           343,   322,     6,    59,     7,     2,   131,    69,   648,    36,
           144,  2159,    20,    14,   993,    14,    26,     2,    75,     9,
         16570,    18,   116,    75,     7,   110,     2,   223,    69,    36,
             6,   112,  1675,    30,  1023,     9,    62,     0,    18,  2536,
             2,   150,     2,   925,    19,    29,     5,   925,    26,     2,
            75, 12992,    10,    43,   850,    40,     2,   233, 22121,    18,
           425,    15,  9305,   713,   293,  4712,    63,    36,    14,  2067,
          4368,    14,    10,  1338,     7,   110,     2,  1173,     9,   644,
             8,    36,   907,    17,   223,   790,     4]), tensor(1))

In [39]:
train_dataset.data[0]

{'label': 1,
 'text': 'Soylent Green I found to be an excellent movie.<br /><br />If you like Logan\'s Run you\'ll like this.<br /><br />Yes the movie is old and there are no special effects and some of the acting can somewhat be best described as "cheesy" but the story is excellent.<br /><br />The story of how the world can be and its impact on society is very poignant.<br /><br />At the end the mystery wasn\'t a mystery but the story unfolded in an easy at the right pace.<br /><br />It\'s nearest modern day equivalent would be "Dark Angel" in terms of how the US is shown to be third-world country.'}

In [40]:
vocab.stoi('Soylent')

7471

In [41]:
valid_dataset = utils.dataset.TextClassificationDataset(valid_data, text_transforms, label_transforms)

In [42]:
test_dataset = utils.dataset.TextClassificationDataset(test_data, text_transforms, label_transforms)

In [43]:
pad_idx = vocab.stoi(vocab.pad_token)

print(pad_idx)

1


In [44]:
collator = utils.collator.TextClassificationCollator(pad_idx)

In [45]:
batch_size = 256

In [64]:
# Training loader: sample order is reshuffled every epoch and each batch is
# assembled by `collator.collate`.
# NOTE(review): `torch.get_num_threads()` is torch's intra-op thread count, used
# here as the number of loader worker processes — confirm that is intended.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collator.collate,
    num_workers=torch.get_num_threads(),
)

In [65]:
# Validation loader: fixed sample order (no shuffling), batches assembled by
# `collator.collate`.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collator.collate,
    num_workers=torch.get_num_threads(),
)

In [66]:
# Test loader: fixed sample order (no shuffling), batches assembled by
# `collator.collate`.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collator.collate,
    num_workers=torch.get_num_threads(),
)

# Creating the NBOW model

In [74]:
class NBOW(pl.LightningModule):
    """Neural bag-of-words classifier.

    Tokens are embedded, the embeddings are averaged over the sequence
    dimension, and a single linear layer maps the average to class logits.

    Args:
        input_dim: Number of rows in the embedding table.
        emb_dim: Size of each embedding vector.
        output_dim: Number of logits produced per example.
        pad_idx: Index passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, input_dim: int, emb_dim: int, output_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.fc = nn.Linear(emb_dim, output_dim)

    def forward(self, text: torch.LongTensor) -> torch.FloatTensor:
        """Return class logits for a batch of token-index sequences."""
        # text[seq_len, batch_size]
        embedded = self.embedding(text)
        # embedded[seq_len, batch_size, emb_dim]
        # NOTE: the mean runs over every position along dim 0, so padded
        # positions are counted in the denominator as well.
        pooled = embedded.mean(0)
        # pooled[batch_size, emb_dim]
        prediction = self.fc(pooled)
        # prediction[batch_size, output_dim]
        return prediction

    def _shared_step(self, batch):
        """Compute loss and accuracy for one batch; shared by train/valid/test steps."""
        text, labels = batch
        # text[seq_len, bs]
        # labels[bs]
        # Call the module itself rather than `.forward` so registered hooks run.
        predictions = self(text)
        # predictions[bs, output_dim]
        loss = F.cross_entropy(predictions, labels)
        acc = utils.util.calculate_accuracy(predictions, labels)
        return {'loss': loss, 'acc': acc}

    def training_step(self, batch, batch_idx):
        """Return the loss (used for backprop) and accuracy of a training batch."""
        return self._shared_step(batch)

    def training_epoch_end(self, training_step_outputs):
        """Print the epoch number and the averaged training metrics."""
        loss, acc = self.calculate_metrics(training_step_outputs)
        print(f'Epoch: {self.current_epoch:2}')
        print(f' Train_loss: {loss:.3f}  | Train_acc: {acc*100:.2f}%')

    def validation_step(self, batch, batch_idx):
        """Return loss/accuracy of a validation batch and log the loss as 'valid_loss'."""
        metrics = self._shared_step(batch)
        # Logged under this exact key so callbacks can monitor it.
        self.log('valid_loss', metrics['loss'])
        return metrics

    def validation_epoch_end(self, validation_step_outputs):
        """Print the averaged validation metrics."""
        loss, acc = self.calculate_metrics(validation_step_outputs)
        print(f' valid_loss: {loss:.3f} | valid_acc: {acc*100:.2f}%')

    def test_step(self, batch, batch_idx):
        """Return the loss and accuracy of a test batch."""
        return self._shared_step(batch)

    def test_epoch_end(self, test_step_outputs):
        """Print the averaged test metrics, formatted like the train/valid lines."""
        loss, acc = self.calculate_metrics(test_step_outputs)
        print(f' test_loss: {loss:.3f} | test_acc: {acc*100:.2f}%')

    def configure_optimizers(self):
        """Optimise all parameters with Adam using its default hyperparameters."""
        return optim.Adam(self.parameters())

    def calculate_metrics(self, step_outputs):
        """Average the per-batch 'loss' and 'acc' entries of ``step_outputs``.

        This is an unweighted mean of per-batch values, so a batch with fewer
        examples carries the same weight as a full batch.

        Returns:
            Tuple ``(loss, acc)`` of scalar tensors.
        """
        loss = torch.mean(torch.stack([x['loss'] for x in step_outputs]))
        acc = torch.mean(torch.stack([x['acc'] for x in step_outputs]))
        return loss, acc

In [75]:
# Model hyperparameters.
# The embedding table gets one row per vocabulary entry.
input_dim = len(vocab)
emb_dim = 100
# Number of logits the linear layer produces per example.
output_dim = 2

model = NBOW(input_dim, emb_dim, output_dim, pad_idx)

In [76]:
# Stop training once the logged 'valid_loss' stops decreasing; patience=0 means
# the first validation check without improvement ends the run.
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='valid_loss',
    mode='min',
    patience=0,
)

In [77]:
# Checkpoint the model according to the logged 'valid_loss' (lower is better).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='valid_loss',
    mode='min',
)

In [78]:
# Use one GPU when CUDA is available and fall back to CPU otherwise, so the
# cell also runs on machines without a GPU (a hard-coded `gpus=1` fails there).
# The pre-training sanity validation pass and the progress bar are disabled.
trainer = pl.Trainer(max_epochs=10,
                    gpus=1 if torch.cuda.is_available() else 0,
                    callbacks=[early_stopping_callback,
                              checkpoint_callback],
                    deterministic=True,
                    num_sanity_val_steps=0,
                    progress_bar_refresh_rate=0)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [79]:
_ = trainer.fit(model,train_loader,valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | embedding | Embedding | 2.8 M 
1 | fc        | Linear    | 202   
----------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.358    Total estimated model params size (MB)


Epoch:  0
 Train_loss: 0.690  | Train_acc: 54.09%
 valid_loss: 0.685 | valid_acc: 61.92%
Epoch:  1
 Train_loss: 0.676  | Train_acc: 67.47%
 valid_loss: 0.667 | valid_acc: 68.67%
Epoch:  2
 Train_loss: 0.648  | Train_acc: 72.32%
 valid_loss: 0.637 | valid_acc: 73.08%
Epoch:  3
 Train_loss: 0.609  | Train_acc: 75.83%
 valid_loss: 0.600 | valid_acc: 75.74%
Epoch:  4
 Train_loss: 0.566  | Train_acc: 78.62%
 valid_loss: 0.563 | valid_acc: 78.25%
Epoch:  5
 Train_loss: 0.523  | Train_acc: 81.22%
 valid_loss: 0.528 | valid_acc: 80.31%
Epoch:  6
 Train_loss: 0.483  | Train_acc: 83.41%
 valid_loss: 0.496 | valid_acc: 81.86%
Epoch:  7
 Train_loss: 0.447  | Train_acc: 85.31%
 valid_loss: 0.468 | valid_acc: 83.18%
Epoch:  8
 Train_loss: 0.416  | Train_acc: 86.57%
 valid_loss: 0.444 | valid_acc: 84.16%
Epoch:  9
 Train_loss: 0.389  | Train_acc: 87.67%
 valid_loss: 0.424 | valid_acc: 85.09%


In [80]:
_ = trainer.test(test_dataloaders=test_loader, verbose=False)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


 test_loss: 0.423 | acc: 0.845241904258728


In [81]:
def predict(sentence, text_transforms, model):
    """Return the probability the model assigns to the last class for ``sentence``.

    Args:
        sentence: Raw input passed to ``text_transforms``.
        text_transforms: Callable turning ``sentence`` into a 1-D tensor of token indices.
        model: Module with a ``device`` attribute; called on a ``[seq_len, 1]``
            tensor and expected to return ``[1, num_classes]`` logits.

    Returns:
        Python float: softmax probability of the last class (the positive class
        for a two-class sentiment model).

    Side effects:
        Switches ``model`` to evaluation mode.
    """
    model.eval()
    # Add a trailing batch dimension of size 1: [seq_len] -> [seq_len, 1].
    tensor = text_transforms(sentence).unsqueeze(-1).to(model.device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prediction = model(tensor)
    probabilities = nn.functional.softmax(prediction, dim=-1)
    pos_probability = probabilities.squeeze(0)[-1].item()
    return pos_probability

In [82]:
sentence = 'the absolute worst movie of all time.'

predict(sentence, text_transforms, model)

3.4957501338084285e-09

In [83]:
sentence = 'one of the greatest films i have ever seen in my life.'

predict(sentence, text_transforms, model)

1.0

In [84]:
sentence = "i thought it was going to be one of the greatest films i have ever seen in my life, \
but it was actually the absolute worst movie of all time."

predict(sentence, text_transforms, model)

0.9154227375984192