## Sentiment Analysis: Determining the polarity of a text (positive or negative).

## Data

[IMDB](http://ai.stanford.edu/~amaas/data/sentiment/) Dataset
- A dataset for binary sentiment classification.
- It provides a set of 25,000 highly polar movie reviews for training, and 25,000 for testing.


**Note**: to run the following code, you need to download the dataset from the link above and set `data_dir` in the cells below accordingly.

## Libraries

In [199]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import re
import sys
import spacy  # used for tokenization
import pickle
import numpy as np

from glob import glob
import tqdm.notebook  # also binds the name `tqdm`, so `tqdm.notebook.tqdm` works below

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from utils import *
from data_utils import Vocabulary, tokenizer
from train_utils import train


# setup
NLP = spacy.load('en_core_web_sm')  # NLP toolkit

## Tokenizing

In [2]:
text = """
Bromwell High is a cartoon comedy. 
It ran at the same time as some other programs about school life, such as 'Teachers'. 
My 35 years in the teaching profession lead me to believe that Bromwell High's 
satire is much closer to reality than is 'Teachers'. 
The scramble to survive financially, the insightful students who can see 
right through their pathetic teachers' pomp, the pettiness of the whole situation, 
all remind me of the schools I knew and their students. 
When I saw the episode in which a student repeatedly tried to burn down the school, 
I immediately recalled ......... at .......... High. 
A classic line: INSPECTOR: I'm here to sack one of your teachers. 
STUDENT: Welcome to Bromwell High. 
I expect that many adults of my age think that Bromwell High is far fetched. 
What a pity that it isn't!!!
"""

In [3]:
# Replace each of the following characters with a space
text = re.sub(r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’;]", " ", str(text)) 
print(text)

 Bromwell High is a cartoon comedy.  It ran at the same time as some other programs about school life, such as 'Teachers'.  My 35 years in the teaching profession lead me to believe that Bromwell High's  satire is much closer to reality than is 'Teachers'.  The scramble to survive financially, the insightful students who can see  right through their pathetic teachers' pomp, the pettiness of the whole situation,  all remind me of the schools I knew and their students.  When I saw the episode in which a student repeatedly tried to burn down the school,  I immediately recalled ......... at .......... High.  A classic line  INSPECTOR  I'm here to sack one of your teachers.  STUDENT  Welcome to Bromwell High.  I expect that many adults of my age think that Bromwell High is far fetched.  What a pity that it isn't!!! 


In [4]:
# Collapse runs of spaces into a single space
text = re.sub(r"[ ]+", " ", text)
print(text)

 Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as 'Teachers'. My 35 years in the teaching profession lead me to believe that Bromwell High's satire is much closer to reality than is 'Teachers'. The scramble to survive financially, the insightful students who can see right through their pathetic teachers' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line INSPECTOR I'm here to sack one of your teachers. STUDENT Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn't!!! 


In [5]:
# Collapse repeated '!' into a single '!'

text = re.sub(r"\!+", "!", text)
print(text)

 Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as 'Teachers'. My 35 years in the teaching profession lead me to believe that Bromwell High's satire is much closer to reality than is 'Teachers'. The scramble to survive financially, the insightful students who can see right through their pathetic teachers' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line INSPECTOR I'm here to sack one of your teachers. STUDENT Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn't! 


In [6]:
# Tokenize with spaCy's tokenizer

tokens = [w.text for w in NLP.tokenizer(text)]
print(tokens)

[' ', 'Bromwell', 'High', 'is', 'a', 'cartoon', 'comedy', '.', 'It', 'ran', 'at', 'the', 'same', 'time', 'as', 'some', 'other', 'programs', 'about', 'school', 'life', ',', 'such', 'as', "'", 'Teachers', "'", '.', 'My', '35', 'years', 'in', 'the', 'teaching', 'profession', 'lead', 'me', 'to', 'believe', 'that', 'Bromwell', 'High', "'s", 'satire', 'is', 'much', 'closer', 'to', 'reality', 'than', 'is', "'", 'Teachers', "'", '.', 'The', 'scramble', 'to', 'survive', 'financially', ',', 'the', 'insightful', 'students', 'who', 'can', 'see', 'right', 'through', 'their', 'pathetic', 'teachers', "'", 'pomp', ',', 'the', 'pettiness', 'of', 'the', 'whole', 'situation', ',', 'all', 'remind', 'me', 'of', 'the', 'schools', 'I', 'knew', 'and', 'their', 'students', '.', 'When', 'I', 'saw', 'the', 'episode', 'in', 'which', 'a', 'student', 'repeatedly', 'tried', 'to', 'burn', 'down', 'the', 'school', ',', 'I', 'immediately', 'recalled', '.........', 'at', '..........', 'High', '.', 'A', 'classic', 'line'

### Tokenizer and Vocabulary

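We have defined a `tokenizer` function in `data_utils.py` that takes the input text and splits it into a sequence of tokens. It combines the cleaning steps shown above, additionally collapses repeated commas and question marks, and drops the single-space tokens spaCy emits. We use the **spaCy** toolkit for tokenization, so you need to install it to run the code.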

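```python
import re
import spacy

NLP = spacy.load('en_core_web_sm')  # spaCy pipeline; only its tokenizer is used

def tokenizer(text):
    # replace quotes, brackets, slashes, newlines and similar symbols with a space
    text = re.sub(r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’;]", " ", str(text))
    text = re.sub(r"[ ]+", " ", text)  # collapse runs of spaces
    text = re.sub(r"\!+", "!", text)   # collapse repeated '!'
    text = re.sub(r"\,+", ",", text)   # collapse repeated ','
    text = re.sub(r"\?+", "?", text)   # collapse repeated '?'
    # tokenize and drop single-space tokens (e.g. from a leading space)
    return [x.text for x in NLP.tokenizer(text) if x.text != " "]
```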
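Because the cleaning half of `tokenizer` is plain `re`, it can be checked without spaCy. The sketch below applies the same five substitutions to a toy string; `clean_text` is a hypothetical helper introduced here for illustration, not a function from the notebook's utility modules.

```python
import re

def clean_text(text):
    """Hypothetical helper: the regex clean-up steps of `tokenizer`, without spaCy."""
    # replace quotes, brackets, slashes, newlines and similar symbols with a space
    text = re.sub(r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’;]", " ", str(text))
    text = re.sub(r"[ ]+", " ", text)  # collapse runs of spaces
    text = re.sub(r"\!+", "!", text)   # collapse repeated '!'
    text = re.sub(r"\,+", ",", text)   # collapse repeated ','
    text = re.sub(r"\?+", "?", text)   # collapse repeated '?'
    return text

print(clean_text('Wow!!! "Great"  film,, right??'))  # prints: Wow! Great film, right?
```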

In [201]:

data_dir = 'dataset/aclImdb/'

vocab_path = 'vocab.pkl'

# parameters
max_len = 200      # keep at most 200 tokens per review: shorter reviews are left-padded, longer ones truncated (compare with mean + 2 * sigma, computed below)
min_count = 10     # tokens that occur fewer than 10 times are replaced by the special token UNK = '<unk>'
batch_size = 50

In [224]:
import splitfolders
input_folder='dataset/aclImdb/test/'
splitfolders.ratio(input_folder, output="dataset/aclImdb/valid", seed=1337, ratio=(0.88, 0.1, 0.02)) 
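`splitfolders.ratio` shuffles the files of each class folder using `seed` and copies them into `train/`, `val/` and `test/` subfolders of `output` in the given proportions. With 12,500 reviews per class in the IMDB test set, `ratio=(0.88, 0.1, 0.02)` gives 11,000 / 1,250 / 250 files per class, i.e. 2,500 `val` and 500 `test` reviews overall; these match the 2,500 / 500 sizes of `dev/train` and `dev/test` reported below, which suggests `dev/` was assembled from those two parts. The pure-Python sketch below illustrates a per-class ratio split; `ratio_split` is a hypothetical helper, the file names are placeholders, and it will not reproduce split-folders' exact file assignment.

```python
import random

def ratio_split(files, ratio, seed=1337):
    """Hypothetical helper: shuffle one class's files and cut them by `ratio`.

    `ratio` is (train, val, test); the test part receives whatever remains
    after rounding, so no file is lost or duplicated.
    """
    files = sorted(files)               # deterministic starting order
    random.Random(seed).shuffle(files)  # seeded shuffle, independent of global state
    n = len(files)
    n_train = round(ratio[0] * n)
    n_val = round(ratio[1] * n)
    train = files[:n_train]
    val = files[n_train:n_train + n_val]
    test = files[n_train + n_val:]
    return train, val, test

# one class folder of the IMDB test set holds 12,500 reviews (placeholder names)
neg_files = [f"{i}_neg.txt" for i in range(12500)]
train, val, test = ratio_split(neg_files, ratio=(0.88, 0.1, 0.02))
print(len(train), len(val), len(test))  # 11000 1250 250
```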


Copying files: 0 files [00:00, ? files/s][A
Copying files: 1 files [00:02,  2.39s/ files][A
Copying files: 9 files [00:02,  4.90 files/s][A
Copying files: 18 files [00:02, 11.22 files/s][A
Copying files: 28 files [00:02, 19.67 files/s][A
Copying files: 38 files [00:02, 29.22 files/s][A
Copying files: 48 files [00:02, 39.22 files/s][A
Copying files: 57 files [00:03, 47.53 files/s][A
Copying files: 68 files [00:03, 58.82 files/s][A
Copying files: 78 files [00:03, 66.00 files/s][A
Copying files: 88 files [00:03, 72.69 files/s][A
Copying files: 98 files [00:03, 74.99 files/s][A
Copying files: 107 files [00:03, 78.15 files/s][A
Copying files: 117 files [00:03, 82.57 files/s][A
Copying files: 127 files [00:03, 86.33 files/s][A
Copying files: 137 files [00:03, 88.64 files/s][A
Copying files: 147 files [00:04, 84.52 files/s][A
Copying files: 156 files [00:04, 79.60 files/s][A
Copying files: 166 files [00:04, 83.43 files/s][A
Copying files: 176 files [00:04, 87.43 files/s][

Copying files: 3226 files [00:36, 99.92 files/s][A
Copying files: 3237 files [00:36, 100.79 files/s][A
Copying files: 3248 files [00:36, 100.85 files/s][A
Copying files: 3259 files [00:36, 96.37 files/s] [A
Copying files: 3269 files [00:37, 93.78 files/s][A
Copying files: 3279 files [00:37, 91.98 files/s][A
Copying files: 3289 files [00:37, 92.67 files/s][A
Copying files: 3299 files [00:37, 91.18 files/s][A
Copying files: 3309 files [00:37, 90.87 files/s][A
Copying files: 3319 files [00:37, 89.93 files/s][A
Copying files: 3329 files [00:37, 89.04 files/s][A
Copying files: 3339 files [00:37, 89.13 files/s][A
Copying files: 3349 files [00:37, 91.38 files/s][A
Copying files: 3359 files [00:37, 93.29 files/s][A
Copying files: 3369 files [00:38, 94.95 files/s][A
Copying files: 3380 files [00:38, 97.62 files/s][A
Copying files: 3391 files [00:38, 100.06 files/s][A
Copying files: 3402 files [00:38, 101.20 files/s][A
Copying files: 3413 files [00:38, 101.70 files/s][A
Copyin

Copying files: 6615 files [01:09, 102.21 files/s][A
Copying files: 6626 files [01:09, 104.15 files/s][A
Copying files: 6637 files [01:09, 104.95 files/s][A
Copying files: 6648 files [01:09, 105.52 files/s][A
Copying files: 6659 files [01:09, 105.01 files/s][A
Copying files: 6670 files [01:09, 104.65 files/s][A
Copying files: 6681 files [01:09, 105.31 files/s][A
Copying files: 6692 files [01:09, 104.56 files/s][A
Copying files: 6703 files [01:10, 102.88 files/s][A
Copying files: 6714 files [01:10, 100.62 files/s][A
Copying files: 6725 files [01:10, 99.90 files/s] [A
Copying files: 6736 files [01:10, 100.50 files/s][A
Copying files: 6747 files [01:10, 100.92 files/s][A
Copying files: 6758 files [01:10, 102.92 files/s][A
Copying files: 6769 files [01:10, 102.62 files/s][A
Copying files: 6780 files [01:10, 102.69 files/s][A
Copying files: 6791 files [01:10, 103.32 files/s][A
Copying files: 6802 files [01:11, 103.18 files/s][A
Copying files: 6813 files [01:11, 103.07 files

Copying files: 9999 files [01:42, 100.37 files/s][A
Copying files: 10010 files [01:42, 98.63 files/s][A
Copying files: 10021 files [01:42, 100.15 files/s][A
Copying files: 10032 files [01:43, 100.40 files/s][A
Copying files: 10043 files [01:43, 102.55 files/s][A
Copying files: 10054 files [01:43, 102.94 files/s][A
Copying files: 10065 files [01:43, 102.34 files/s][A
Copying files: 10076 files [01:43, 102.50 files/s][A
Copying files: 10087 files [01:43, 102.90 files/s][A
Copying files: 10098 files [01:43, 103.18 files/s][A
Copying files: 10109 files [01:43, 102.51 files/s][A
Copying files: 10120 files [01:43, 102.90 files/s][A
Copying files: 10131 files [01:44, 102.32 files/s][A
Copying files: 10142 files [01:44, 101.35 files/s][A
Copying files: 10153 files [01:44, 101.24 files/s][A
Copying files: 10164 files [01:44, 100.06 files/s][A
Copying files: 10175 files [01:44, 100.61 files/s][A
Copying files: 10186 files [01:44, 102.98 files/s][A
Copying files: 10197 files [01

Copying files: 13326 files [02:17, 106.20 files/s][A
Copying files: 13338 files [02:17, 107.71 files/s][A
Copying files: 13349 files [02:18, 107.46 files/s][A
Copying files: 13360 files [02:18, 105.77 files/s][A
Copying files: 13371 files [02:18, 104.31 files/s][A
Copying files: 13382 files [02:18, 105.65 files/s][A
Copying files: 13393 files [02:18, 106.30 files/s][A
Copying files: 13405 files [02:18, 107.80 files/s][A
Copying files: 13416 files [02:18, 104.82 files/s][A
Copying files: 13427 files [02:18, 103.94 files/s][A
Copying files: 13438 files [02:18, 104.20 files/s][A
Copying files: 13449 files [02:19, 104.98 files/s][A
Copying files: 13460 files [02:19, 104.05 files/s][A
Copying files: 13471 files [02:19, 103.40 files/s][A
Copying files: 13482 files [02:19, 103.53 files/s][A
Copying files: 13493 files [02:19, 103.33 files/s][A
Copying files: 13504 files [02:19, 101.48 files/s][A
Copying files: 13515 files [02:19, 101.05 files/s][A
Copying files: 13526 files [

Copying files: 16651 files [02:50, 99.11 files/s][A
Copying files: 16661 files [02:50, 96.10 files/s][A
Copying files: 16672 files [02:50, 97.56 files/s][A
Copying files: 16683 files [02:51, 99.41 files/s][A
Copying files: 16694 files [02:51, 100.72 files/s][A
Copying files: 16705 files [02:51, 99.96 files/s] [A
Copying files: 16716 files [02:51, 99.45 files/s][A
Copying files: 16727 files [02:51, 101.01 files/s][A
Copying files: 16738 files [02:51, 100.72 files/s][A
Copying files: 16749 files [02:51, 101.92 files/s][A
Copying files: 16760 files [02:51, 103.06 files/s][A
Copying files: 16771 files [02:51, 101.57 files/s][A
Copying files: 16782 files [02:52, 102.51 files/s][A
Copying files: 16793 files [02:52, 101.76 files/s][A
Copying files: 16804 files [02:52, 100.69 files/s][A
Copying files: 16815 files [02:52, 100.22 files/s][A
Copying files: 16826 files [02:52, 99.63 files/s] [A
Copying files: 16837 files [02:52, 101.41 files/s][A
Copying files: 16848 files [02:52

Copying files: 19969 files [03:23, 101.52 files/s][A
Copying files: 19980 files [03:23, 100.52 files/s][A
Copying files: 19991 files [03:23, 99.57 files/s] [A
Copying files: 20001 files [03:23, 98.06 files/s][A
Copying files: 20012 files [03:23, 99.21 files/s][A
Copying files: 20022 files [03:23, 98.59 files/s][A
Copying files: 20033 files [03:24, 100.15 files/s][A
Copying files: 20044 files [03:24, 101.25 files/s][A
Copying files: 20055 files [03:24, 101.74 files/s][A
Copying files: 20066 files [03:24, 102.65 files/s][A
Copying files: 20077 files [03:24, 101.29 files/s][A
Copying files: 20088 files [03:24, 100.36 files/s][A
Copying files: 20099 files [03:24, 100.55 files/s][A
Copying files: 20110 files [03:24, 101.51 files/s][A
Copying files: 20121 files [03:24, 101.92 files/s][A
Copying files: 20132 files [03:25, 100.24 files/s][A
Copying files: 20143 files [03:25, 102.14 files/s][A
Copying files: 20154 files [03:25, 103.81 files/s][A
Copying files: 20165 files [03:

Copying files: 23302 files [03:55, 101.23 files/s][A
Copying files: 23313 files [03:56, 101.43 files/s][A
Copying files: 23324 files [03:56, 101.30 files/s][A
Copying files: 23335 files [03:56, 99.28 files/s] [A
Copying files: 23346 files [03:56, 100.33 files/s][A
Copying files: 23357 files [03:56, 101.64 files/s][A
Copying files: 23368 files [03:56, 99.79 files/s] [A
Copying files: 23379 files [03:56, 101.25 files/s][A
Copying files: 23390 files [03:56, 102.29 files/s][A
Copying files: 23401 files [03:56, 102.18 files/s][A
Copying files: 23412 files [03:56, 102.10 files/s][A
Copying files: 23423 files [03:57, 101.20 files/s][A
Copying files: 23434 files [03:57, 101.69 files/s][A
Copying files: 23445 files [03:57, 101.20 files/s][A
Copying files: 23456 files [03:57, 102.26 files/s][A
Copying files: 23467 files [03:57, 103.31 files/s][A
Copying files: 23478 files [03:57, 102.03 files/s][A
Copying files: 23489 files [03:57, 100.05 files/s][A
Copying files: 23500 files [

In [227]:
data_dir='dataset/aclImdb/dev/'
os.listdir(data_dir)

['test', 'train']

In [228]:
os.listdir(f'{data_dir}/train')

['neg', 'pos']

### Statistics

In [229]:
all_filenames = glob(f'{data_dir}/*/*/*.txt')
num_words = [len(open(f, encoding="utf-8").read().split(' ')) for f in tqdm.notebook.tqdm(all_filenames)]

# print statistics
print('Min length =', min(num_words))
print('Max length =', max(num_words))

print('Mean = {:.2f}'.format(np.mean(num_words)))
print('Std  = {:.2f}'.format(np.std(num_words)))

print('mean + 2 * sigma = {:.2f}'.format(np.mean(num_words) + 2.0 * np.std(num_words)))


Min length = 9
Max length = 1090
Mean = 228.50
Std  = 170.35
mean + 2 * sigma = 569.19
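The `mean + 2 * sigma` line uses `np.std`, which by default is the population standard deviation (`ddof=0`). A small pure-Python sketch of the same statistic, together with the share of reviews that would fit in `max_len` without truncation; `length_stats` is a hypothetical helper, and since the counts above come from splitting on spaces they only approximate spaCy token counts.

```python
import statistics

def length_stats(lengths, max_len):
    """Hypothetical helper: mean + 2 * sigma and the share of texts with <= max_len words."""
    mean = statistics.fmean(lengths)
    sigma = statistics.pstdev(lengths)  # population std, like np.std's default ddof=0
    covered = sum(n <= max_len for n in lengths) / len(lengths)
    return mean + 2 * sigma, covered

threshold, covered = length_stats([2, 4, 4, 4, 5, 5, 7, 9], max_len=5)
print(threshold, covered)  # 9.0 0.75
```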


## Dataset

In [230]:
PAD = '<pad>'  # special symbol we use for padding text
UNK = '<unk>'  # special symbol we use for rare or unknown word

In [231]:
class TextClassificationDataset(Dataset):
    
    def __init__(self, path, tokenizer, 
                 split='train', 
                 vocab_path='vocab.pkl', 
                 max_len=100, min_count=10):
        
        self.path = path
        assert split in ['train', 'test']
        self.split = split
        self.vocab_path = vocab_path
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.min_count = min_count
        
        self.cache = {}
        self.vocab = None
        
        self.classes = []
        self.class_to_index = {}
        self.text_files = []
        
        split_path = f'{path}/{split}'
        
        for cls_idx, label in enumerate(sorted(os.listdir(split_path))):  # sorted: os.listdir order is arbitrary
            text_files = [(fname, cls_idx) for fname in glob(f'{split_path}/{label}/*.txt')]
            self.text_files += text_files
            self.classes += [label]
            self.class_to_index[label] = cls_idx
        
        self.num_classes = len(self.classes)
            
        # build vocabulary from training and validation texts
        self.build_vocab()
        
    def __getitem__(self, index):
        # look up the text file and its label (neg=0, pos=1)
        fname, class_idx = self.text_files[index]
        
        if fname in self.cache:
            return self.cache[fname], class_idx
        
        # read text file 
        text = open(fname, encoding="utf-8").read()
        
        # tokenize the text file
        tokens = self.tokenizer(text.lower().strip())
        
        # padding and trimming
        if len(tokens) < self.max_len:
            num_pads = self.max_len - len(tokens)
            tokens = [PAD] * num_pads + tokens
        elif len(tokens) > self.max_len:
            tokens = tokens[:self.max_len]
            
        # numericalizing
        ids = torch.LongTensor(self.max_len)
        for i, word in enumerate(tokens):
            if word not in self.vocab.word2index:
                ids[i] = self.vocab.word2index[UNK]  # unknown words
            elif word != PAD and self.vocab.word2count[word] < self.min_count:
                ids[i] = self.vocab.word2index[UNK]  # rare words
            else:
                ids[i] = self.vocab.word2index[word]
                
        # save in cache for future use
        self.cache[fname] = ids
        
        return ids, class_idx
    
    def __len__(self):
        return len(self.text_files)
    
    def build_vocab(self):
        if not os.path.exists(self.vocab_path):
            vocab = Vocabulary(self.tokenizer)
            filenames = glob(f'{self.path}/*/*/*.txt')
            for filename in tqdm.notebook.tqdm(filenames, desc='Building Vocab'):
                with open(filename, encoding='utf8') as f:
                    for line in f:
                        vocab.add_sentence(line.lower())

            # sort words by their frequencies
            words = [(0, PAD), (0, UNK)]
            words += sorted([(c, w) for w, c in vocab.word2count.items()], reverse=True)

            self.vocab = Vocabulary(self.tokenizer)
            for i, (count, word) in enumerate(words):
                self.vocab.word2index[word] = i
                self.vocab.word2count[word] = count
                self.vocab.index2word[i] = word
                self.vocab.count += 1

            with open(self.vocab_path, 'wb') as f:
                pickle.dump(self.vocab, f)
        else:
            with open(self.vocab_path, 'rb') as f:
                self.vocab = pickle.load(f)
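`__getitem__` pads on the left with `PAD`, keeps only the first `max_len` tokens, and maps unseen or rare tokens to `UNK`. The sketch below isolates that logic in plain Python so it can be checked on a toy vocabulary; `encode` and the two toy dictionaries are hypothetical stand-ins for `self.vocab.word2index` and `self.vocab.word2count`.

```python
PAD, UNK = '<pad>', '<unk>'

def encode(tokens, word2index, word2count, max_len, min_count):
    """Hypothetical helper mirroring TextClassificationDataset.__getitem__."""
    if len(tokens) < max_len:
        tokens = [PAD] * (max_len - len(tokens)) + tokens  # left padding
    else:
        tokens = tokens[:max_len]                          # keep the first max_len tokens
    ids = []
    for word in tokens:
        if word not in word2index:
            ids.append(word2index[UNK])                    # unseen word
        elif word != PAD and word2count[word] < min_count:
            ids.append(word2index[UNK])                    # rare word
        else:
            ids.append(word2index[word])
    return ids

word2index = {PAD: 0, UNK: 1, 'good': 2, 'movie': 3, 'meh': 4}
word2count = {PAD: 0, UNK: 0, 'good': 50, 'movie': 40, 'meh': 3}
print(encode(['good', 'meh', 'movie', 'zzz'], word2index, word2count, max_len=6, min_count=10))
# [0, 0, 2, 1, 3, 1]
```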

In [232]:
train_ds = TextClassificationDataset(data_dir, tokenizer, 'train', vocab_path, max_len, min_count)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

valid_ds = TextClassificationDataset(data_dir, tokenizer, 'test', vocab_path, max_len, min_count)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)

In [233]:
len(train_ds)

2500

In [234]:
len(valid_ds)

500

In [235]:
train_ds.classes

['neg', 'pos']

In [236]:
train_ds.class_to_index

{'neg': 0, 'pos': 1}

In [237]:
ids, label = train_ds[0]

print(train_ds.classes[label])
print(ids.numpy())

neg
[   14     9    40   476     7   144     2  2236     7   213   116    30
     2   176     4  3920     5   368     3    45    16    73   171   283
   157   140     4     6   594   455     7     2   104  1174 14516  2030
     7  1788  1640     5  1788  5044     3    41   147   259 13558   118
   224   131    15    38    30  2196     7   124     3     5   124    82
     4    49    28  1273    22    14    35     3   150    74   175   660
   529     3     1    46   115   175   849  6226    21  1788  1640     3
    46 21342  2813     2  2844     3  1999  2860    46  2399    21  1788
  5044     5    74     2   154   804     4  1788  1640    16  2261  7362
   421   633   170    14    25 18465    36     2  2702     3     5    13
   155   137  1531    56     2  2143   960  5892    19   413    12    14
    25    59     5   144     2  2143    83    31   219   295     2  2539
   176   114    59    43  2062 21040     3   178    25    13   147   119
    22   960  5892    55   101   400     2  253

In [238]:
# convert back the sequence of integers into original text
print(' '.join([train_ds.vocab.index2word[i.item()] for i in ids]))

this is an example of why the majority of action films are the same . generic and boring , there 's really nothing worth watching here . a complete waste of the then barely tapped talents of ice t and ice cube , who 've each proven many times over that they are capable of acting , and acting well . do n't bother with this one , go see new jack city , <unk> or watch new york undercover for ice t , or boyz n the hood , higher learning or friday for ice cube and see the real deal . ice t 's horribly cliched dialogue alone makes this film grate at the teeth , and i 'm still wondering what the heck bill paxton was doing in this film ? and why the heck does he always play the exact same character ? from aliens onward , every film i 've seen with bill paxton has him playing the exact same irritating character , and at least in aliens his character died , which made it somewhat gratifying ... <br > < br > overall , this is second rate action trash . there are countless


In [239]:
# print the original text
print(open(train_ds.text_files[0][0], encoding="utf-8").read())

This is an example of why the majority of action films are the same. Generic and boring, there's really nothing worth watching here. A complete waste of the then barely-tapped talents of Ice-T and Ice Cube, who've each proven many times over that they are capable of acting, and acting well. Don't bother with this one, go see New Jack City, Ricochet or watch New York Undercover for Ice-T, or Boyz n the Hood, Higher Learning or Friday for Ice Cube and see the real deal. Ice-T's horribly cliched dialogue alone makes this film grate at the teeth, and I'm still wondering what the heck Bill Paxton was doing in this film? And why the heck does he always play the exact same character? From Aliens onward, every film I've seen with Bill Paxton has him playing the exact same irritating character, and at least in Aliens his character died, which made it somewhat gratifying...<br /><br />Overall, this is second-rate action trash. There are countless better films to see, and if you really want to se

### Vocabulary size

In [240]:
vocab = train_ds.vocab
freqs = [(count, word) for (word, count) in vocab.word2count.items() if count >= min_count]
vocab_size = len(freqs) + 2  # for PAD and UNK tokens
print(f'Vocab size = {vocab_size}')

print('\nMost common words:')
for c, w in sorted(freqs, reverse=True)[:10]:
    print(f'{w}: {c}')

Vocab size = 29506

Most common words:
the: 666713
,: 543467
.: 470130
and: 324156
a: 321800
of: 289313
to: 267961
is: 217022
>: 202243
it: 187974
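`build_vocab` places `PAD` and `UNK` at indices 0 and 1 and then sorts words by descending frequency, so every word with `count >= min_count` receives an index below `vocab_size = 2 + (number of such words)`. That is why an embedding table with `vocab_size` rows is large enough once rare words are mapped to `UNK`. A minimal sketch with `collections.Counter`; `build_index` is a hypothetical helper.

```python
from collections import Counter

PAD, UNK = '<pad>', '<unk>'

def build_index(counts, min_count):
    """Hypothetical helper: frequency-sorted word -> index map and the resulting vocab size."""
    ranked = sorted(((c, w) for w, c in counts.items()), reverse=True)
    words = [PAD, UNK] + [w for c, w in ranked]
    word2index = {w: i for i, w in enumerate(words)}
    vocab_size = 2 + sum(1 for c in counts.values() if c >= min_count)
    return word2index, vocab_size

counts = Counter("the cat saw the dog and the cat ran".split())
word2index, vocab_size = build_index(counts, min_count=2)
print(vocab_size)  # 4: '<pad>', '<unk>', 'the', 'cat'
# frequent words occupy indices 2 .. vocab_size - 1
assert all(word2index[w] < vocab_size for w, c in counts.items() if c >= 2)
```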


## LSTM Classifier with Attention mechanism

In [241]:
# Attention computes a weighted average of the LSTM hidden states:
# it produces one weight per time step, and the weights sum to 1 over the sequence.

class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )
    
    def forward(self, encoder_outputs):
        # encoder_outputs = [batch size, sent len, hid dim]
        energy = self.projection(encoder_outputs)
        # energy = [batch size, sent len, 1]
        weights = F.softmax(energy.squeeze(-1), dim=1)
        # weights = [batch size, sent len]
        outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)
        # outputs = [batch size, hid dim]
        return outputs, weights

    
class AttentionLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.num_layers = n_layers
        self.bidirectional = bidirectional
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                            bidirectional=bidirectional, 
                            dropout=0 if n_layers < 2 else dropout)
        self.attention = SelfAttention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x = [sent len, batch size]
        # DataLoader batches are [batch size, sent len], so train() presumably
        # transposes them (e.g. x.t()) before calling the model
        embedded = self.embedding(x)
        # embedded = [sent len, batch size, emb dim]
        output, (hidden, cell) = self.lstm(embedded)
        # pass batch_first=True to nn.LSTM if you want the batch dimension first
        # output = [sent len, batch size, hid dim * num directions]
        if self.bidirectional:
            # sum the forward and backward directions
            output = output[:, :, :self.hidden_dim] + output[:, :, self.hidden_dim:]
        # output = [sent len, batch size, hid dim]
        output = output.permute(1, 0, 2)
        # output = [batch size, sent len, hid dim]
        new_embed, weights = self.attention(output)
        # new_embed = [batch size, hid dim]
        # weights = [batch size, sent len]
        new_embed = self.dropout(new_embed)
        return self.fc(new_embed)
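`SelfAttention` scores each time step with a small MLP, normalizes the scores with a softmax over the sequence, and returns the weighted sum of the hidden states. The NumPy sketch below reproduces only the pooling step for given scores (the MLP is omitted, and `attention_pool` is a hypothetical name), which makes the shapes and the sum-to-one property easy to check.

```python
import numpy as np

def attention_pool(hidden, scores):
    """Weighted average over time, as in SelfAttention.forward (hypothetical helper).

    hidden: [batch size, sent len, hid dim]; scores: [batch size, sent len]
    """
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)        # softmax over the time axis
    pooled = (hidden * weights[..., None]).sum(axis=1)   # [batch size, hid dim]
    return pooled, weights

rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, 3))
pooled, weights = attention_pool(hidden, rng.normal(size=(2, 5)))
print(pooled.shape)  # (2, 3)
```

With equal scores at every time step the weights are uniform and the pooled vector reduces to the plain mean of the hidden states.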

In [242]:
vocab_size = 2 + len([w for (w, c) in train_ds.vocab.word2count.items() if c >= min_count])
print(vocab_size)

29506


## Model

In [243]:
# LSTM parameters
embed_size = 100
hidden_size = 256 
num_layers = 1

# training parameters
lr = 0.001
num_epochs = 5

In [244]:
model = AttentionLSTM(vocab_size, embed_size, hidden_size, 
                      output_dim=train_ds.num_classes, 
                      n_layers=num_layers, bidirectional=True, dropout=0.5)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [245]:
criterion = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.7, 0.99))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
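With `step_size=1` and `gamma=0.9`, `StepLR` multiplies the learning rate by 0.9 on every call to `scheduler.step()`. Assuming that call happens once per epoch (presumably inside `train`), epoch `e` (counting from 0) uses `lr * gamma ** (e // step_size)`. A pure-Python sketch of that schedule for the values above; `step_lr` is a hypothetical helper.

```python
def step_lr(lr, gamma, step_size, num_epochs):
    """Learning rate used in each epoch under a StepLR-style decay (hypothetical helper)."""
    return [lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]

for epoch, rate in enumerate(step_lr(0.001, 0.9, 1, 5), start=1):
    print(f"epoch {epoch}: lr = {rate:.6f}")
# epoch 1: lr = 0.001000 ... epoch 5: lr = 0.000656
```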

### Training

In [247]:
hist = train(model, train_dl, valid_dl, criterion, optimizer, device, scheduler, num_epochs)

[Epoch:  1/ 5] | Training Loss: 0.0122 | Testing Loss: 0.0118 | Training Acc: 66.76 | Testing Acc: 70.60
[Epoch:  2/ 5] | Training Loss: 0.0097 | Testing Loss: 0.0118 | Training Acc: 76.72 | Testing Acc: 70.80
[Epoch:  3/ 5] | Training Loss: 0.0082 | Testing Loss: 0.0121 | Training Acc: 81.68 | Testing Acc: 73.20
[Epoch:  4/ 5] | Training Loss: 0.0064 | Testing Loss: 0.0133 | Training Acc: 86.96 | Testing Acc: 71.80

### QRNN

<img src='imgs/QRNN.png' width='100%'/>