# LSTM for Sentiment Analysis

Followed [this tutorial](https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html), and completed it [here](tut/nlp/4_lstm.ipynb)

Helpful blogs:
* [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
* [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tokenizers
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import sys
sys.path.append("../..")


class SentimentAnalysisModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        emb_dim: int = 30,
        hidden_size: int = 40,
        n_rnn_layers: int = 1,
    ):
        super().__init__()
        self.emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
        )

        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_size,
            num_layers=n_rnn_layers,
        )

        self.seq = nn.Sequential(
            nn.Linear(hidden_size, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):


        # x shape: (B, L)
        # convert token indices to embedding values
        x = self.emb(x)
        # x shape: (B, L, Emb dim)

        
        x = x.transpose(0, 1)
        
        # Pack
        x = pack_padded_sequence(x, lengths, enforce_sorted=False)

        _, (h, _) = self.lstm(x)

        return self.seq(h[-1, :, :])

In [2]:
from tut.sentiment_analysis.helpers import load_sentiment_data
( 
    train_data,
    train_labels,
    train_lengths,
    test_data,
    test_labels,
    test_lengths,
) = load_sentiment_data()

In [3]:
# load in tokenizer
tokenizer = tokenizers.Tokenizer.from_file("models/tokenizer.json")
vocab_size = tokenizer.get_vocab_size()


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentimentAnalysisModel(vocab_size=vocab_size).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_function = F.cross_entropy

# Training
from tqdm import tqdm

epochs = 10
batch_size = 1000
eval_batch_size = 1000


def _calc_accuracy(
    data: torch.Tensor,
    labels: torch.Tensor,
    lengths: torch.Tensor,
    batch_size: torch.Tensor,
):
    cum_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(0, len(data), batch_size)):
            input = data[i : i + batch_size].to(device)
            _labels = labels[i : i + batch_size].to(device)
            _lengths = lengths[i : i + batch_size]

            output = model(input, _lengths)
            loss = loss_function(output, _labels)

            cum_loss += loss
            _, y_pred = torch.max(output, dim=1)
            correct += sum(_labels == y_pred)
            total += len(_labels)

    return correct, total, cum_loss


for epoch in range(epochs):
    print(epoch, epochs)
    test_correct, test_total, test_loss = _calc_accuracy(
        test_data, test_labels, test_lengths, batch_size=eval_batch_size
    )
    train_correct, train_total, train_loss = _calc_accuracy(
        train_data, train_labels, train_lengths, batch_size=eval_batch_size
    )

    print(f"test: {test_correct / test_total: 0.2f}, {test_loss}")
    print(f"train: {train_correct / train_total: 0.2f}, {train_loss}")

    for i in tqdm(range(0, len(train_data), batch_size)):
        input = train_data[i : i + batch_size].to(device)
        labels = train_labels[i : i + batch_size].to(device)
        lengths = train_lengths[i : i + batch_size]

        optimizer.zero_grad()
        output = model(input, lengths)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()


0 10


100%|██████████| 10/10 [00:01<00:00,  6.38it/s]
100%|██████████| 40/40 [00:04<00:00,  9.25it/s]


test:  0.50, 6.928741931915283
train:  0.50, 27.71198081970215


100%|██████████| 40/40 [00:13<00:00,  3.06it/s]


1 10


100%|██████████| 10/10 [00:01<00:00,  8.21it/s]
100%|██████████| 40/40 [00:04<00:00,  8.88it/s]


test:  0.63, 6.498149394989014
train:  0.66, 25.50154685974121


100%|██████████| 40/40 [00:12<00:00,  3.16it/s]


2 10


100%|██████████| 10/10 [00:01<00:00,  8.29it/s]
100%|██████████| 40/40 [00:04<00:00,  8.58it/s]


test:  0.77, 5.111581325531006
train:  0.82, 17.988630294799805


100%|██████████| 40/40 [00:13<00:00,  3.07it/s]


3 10


100%|██████████| 10/10 [00:01<00:00,  8.09it/s]
100%|██████████| 40/40 [00:04<00:00,  8.72it/s]


test:  0.71, 5.686638355255127
train:  0.74, 20.193777084350586


100%|██████████| 40/40 [00:12<00:00,  3.17it/s]


4 10


100%|██████████| 10/10 [00:01<00:00,  8.13it/s]
100%|██████████| 40/40 [00:04<00:00,  8.83it/s]


test:  0.65, 6.162026405334473
train:  0.69, 22.805904388427734


100%|██████████| 40/40 [00:12<00:00,  3.19it/s]


5 10


100%|██████████| 10/10 [00:01<00:00,  8.40it/s]
100%|██████████| 40/40 [00:04<00:00,  8.62it/s]


test:  0.83, 4.191022872924805
train:  0.89, 11.422185897827148


100%|██████████| 40/40 [00:12<00:00,  3.21it/s]


6 10


100%|██████████| 10/10 [00:01<00:00,  8.36it/s]
100%|██████████| 40/40 [00:04<00:00,  8.73it/s]


test:  0.86, 3.7186737060546875
train:  0.92, 8.56570053100586


100%|██████████| 40/40 [00:12<00:00,  3.20it/s]


7 10


100%|██████████| 10/10 [00:01<00:00,  7.87it/s]
100%|██████████| 40/40 [00:04<00:00,  8.81it/s]


test:  0.88, 3.3242077827453613
train:  0.96, 5.444023132324219


100%|██████████| 40/40 [00:12<00:00,  3.17it/s]


8 10


100%|██████████| 10/10 [00:01<00:00,  7.84it/s]
100%|██████████| 40/40 [00:04<00:00,  8.72it/s]


test:  0.88, 3.539999485015869
train:  0.97, 4.208895206451416


100%|██████████| 40/40 [00:12<00:00,  3.17it/s]


9 10


100%|██████████| 10/10 [00:01<00:00,  8.04it/s]
100%|██████████| 40/40 [00:04<00:00,  8.47it/s]


test:  0.88, 4.035975456237793
train:  0.98, 3.1144649982452393


100%|██████████| 40/40 [00:12<00:00,  3.19it/s]
