In [1]:
import torch
import torch.nn as nn
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader

In [2]:
# Training split of IMDB; iterating it yields (label, text) pairs that the
# vocabulary cell below consumes via yield_tokens.
# NOTE(review): in some torchtext releases this is a single-pass iterator that
# the vocab cell exhausts — confirm for the installed version before reusing it.
train_iter = IMDB(split="train")

# torchtext's built-in "basic_english" rule-based tokenizer (str -> list of tokens).
tokenizer = get_tokenizer("basic_english")

In [3]:
def yield_tokens(data_iter):
    """Lazily yield one token list per example in ``data_iter``.

    Each item of ``data_iter`` is unpacked as a ``(label, text)`` pair; the
    label is discarded and the module-level ``tokenizer`` is applied to the text.
    """
    yield from (tokenizer(text) for _, text in data_iter)


In [4]:
# Build the token -> index mapping from every tokenized training example.
# "<unk>" is the only special token; set_default_index makes any lookup of an
# unseen token fall back to the "<unk>" index instead of raising.
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=["<unk>"],
)
vocab.set_default_index(vocab["<unk>"])


In [8]:
# Simple LSTM model
class SimpleLSTM(nn.Module):
    """Single-layer, unidirectional LSTM text classifier.

    Token ids are embedded, run through the LSTM, and the final hidden state of
    the top LSTM layer is projected to ``num_classes`` logits.

    Args:
        vocab_size: Number of rows in the embedding table (e.g. ``len(vocab)``).
        embed_dim: Size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        num_classes: Number of output logits per example.
        padding_idx: Optional token id used for padding. Its embedding row is
            initialised to zeros and never updated. ``None`` (default) keeps the
            original behaviour of treating every id as a real token.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, lengths=None):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids; expected shape ``(batch, seq_len)``
                because the LSTM is built with ``batch_first=True``.
            lengths: Optional true (unpadded) length of each row of ``x``, as a
                tensor or list of positive ints. When given, the batch is packed
                so the LSTM stops at each sequence's last real token. Without it
                (the default), trailing padding is fed through the LSTM and
                leaks into the final hidden state.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        embedded = self.embedding(x)
        if lengths is not None:
            # pack_padded_sequence needs int64 lengths on the CPU. With
            # enforce_sorted=False the batch may be in any order; nn.LSTM restores
            # the caller's order in the returned hidden state.
            lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
        _, (h, _) = self.lstm(embedded)
        # h is (num_layers, batch, hidden_dim); h[-1] is the top layer's final state.
        return self.fc(h[-1])


In [9]:
# Example model. Weight initialisation is random, so seed torch first to make
# the instantiated parameters reproducible across Restart & Run All.
SEED = 42
EMBED_DIM = 64
HIDDEN_DIM = 128

torch.manual_seed(SEED)
model = SimpleLSTM(vocab_size=len(vocab), embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM)
model

SimpleLSTM(
  (embedding): Embedding(100683, 64)
  (lstm): LSTM(64, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=2, bias=True)
)
