Useful resources
* [Andrej Karpathy Let's Build GPT](https://www.youtube.com/watch?v=kCc8FmEb1nY)
* [PyTorch-Transformers](https://pytorch.org/hub/huggingface_pytorch-transformers/) (since renamed to Hugging Face's `transformers` package)
* [Actual transformers in PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)
* [In-depth sentiment analysis with transformers on Kaggle](https://www.kaggle.com/code/emirkocak/in-depth-series-sentiment-analysis-w-transformers)
* Some papers
    * [Text Sentiment Analysis Based on Transformer and Augmentation](https://www.frontiersin.org/articles/10.3389/fpsyg.2022.906061/full)
    * [Transformer-based deep learning models for the sentiment analysis of social media data](https://www.sciencedirect.com/science/article/pii/S2590005622000224)

In [5]:
import sys
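# Add the directory two levels up to the import path, presumably so the
# local `tut` package imported later in this notebook can be found.
sys.path.append("../..")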

import torch
import torch.nn as nn
import torch.nn.functional as F


class SentimentAnalysis(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 128,
        num_classes: int = 2,
    ):
        super().__init__()

        self.emb = nn.Embedding(vocab_size, embedding_dim=embedding_dim)
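        # Same values as the original positional call, spelled out by keyword.
        # NOTE: dim_feedforward=4 is far below PyTorch's default of 2048.
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=8,
            num_encoder_layers=4,
            num_decoder_layers=4,
            dim_feedforward=4,
            dropout=0.1,
        )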

        self.fc = nn.Linear(embedding_dim, num_classes)

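    def forward(self, x):
        # x: (batch, seq_len) tensor of token ids
        x = self.emb(x)  # (batch, seq_len, embedding_dim)
        # nn.Transformer defaults to batch_first=False, so move seq_len first
        x = x.permute(1, 0, 2)  # (seq_len, batch, embedding_dim)
        # The same sequence is passed as both source and target
        x = self.transformer(x, x)
        x = x.permute(1, 0, 2)  # back to (batch, seq_len, embedding_dim)
        # Use the representation at the last position as the sequence summary
        x = x[:, -1]  # (batch, embedding_dim)
        x = self.fc(x)  # (batch, num_classes)
        # Element-wise sigmoid: per-class scores in [0, 1], not normalized to sum to 1
        x = torch.sigmoid(x)
        return x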
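
A quick way to sanity-check the tensor shapes is to push a random batch through a reduced copy of the model. The sketch below is illustrative only: `TinySentimentModel`, the vocabulary size of 1000, the batch of 4 sequences of length 16, and the single encoder/decoder layer are assumptions chosen to keep it fast, not values from this notebook.

```python
import torch
import torch.nn as nn


class TinySentimentModel(nn.Module):
    """Hypothetical, reduced copy of the model above for a quick shape check."""

    def __init__(self, vocab_size: int, embedding_dim: int = 128, num_classes: int = 2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        # One encoder and one decoder layer (an assumption, to keep the check fast)
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=8,
            num_encoder_layers=1,
            num_decoder_layers=1,
            dim_feedforward=4,
            dropout=0.1,
        )
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.emb(x)              # (batch, seq_len, embedding_dim)
        x = x.permute(1, 0, 2)       # (seq_len, batch, embedding_dim)
        x = self.transformer(x, x)   # same sequence as source and target
        x = x.permute(1, 0, 2)       # (batch, seq_len, embedding_dim)
        x = x[:, -1]                 # last position: (batch, embedding_dim)
        return torch.sigmoid(self.fc(x))


# Random token ids stand in for tokenized reviews (made-up sizes)
model = TinySentimentModel(vocab_size=1000).eval()
tokens = torch.randint(0, 1000, (4, 16))

with torch.no_grad():
    probs = model(tokens)

print(probs.shape)  # one row per sequence, one column per class
```

If the output shape is not `(batch, num_classes)`, the most likely culprit is a mismatch between the `permute` calls and the `batch_first` setting of `nn.Transformer`.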

In [11]:
from tut.sentiment_analysis.helpers import load_sentiment_data, load_tokenizer, calc_accuracy

(
    train_data,
    train_labels,
    train_lengths,
    test_data,
    test_labels,
    test_lengths,
) = load_sentiment_data()

tokenizer = load_tokenizer()
vocab_size = tokenizer.get_vocab_size()