<a href="https://colab.research.google.com/github/jppgks/DL-from-Scratch-with-PyTorch/blob/main/RNN_for_Text_Classification_with_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [617]:
import torch
from torch import nn
from torch.nn import functional as F

## Data

In [638]:
epochs = 200
texts = [
  "I am very interested in learning about the process of making coffee.",
  "I admire the hard work that goes into producing coffee beans.",
  "Conditions of workers on coffee plantations are really bad.",
  "I do not like coffee at all.",
  "Do not drink coffee."
] * epochs
sentiment = [
  1,
  1,
  0,
  0,
  0,
] * epochs

### Tokenize text

In [619]:
token2idx = dict()

for idx, token in enumerate({token.lower() for text in texts for token in text.strip(".").split()}):
  token2idx[token] = idx

token2idx["EOS"] = idx + 1

In [620]:
idx2token = dict((idx, token) for (token, idx) in token2idx.items())

In [621]:
str(token2idx)

"{'in': 0, 'producing': 1, 'really': 2, 'beans': 3, 'conditions': 4, 'work': 5, 'am': 6, 'on': 7, 'interested': 8, 'into': 9, 'not': 10, 'bad': 11, 'i': 12, 'the': 13, 'of': 14, 'all': 15, 'about': 16, 'plantations': 17, 'goes': 18, 'like': 19, 'learning': 20, 'hard': 21, 'are': 22, 'making': 23, 'process': 24, 'very': 25, 'coffee': 26, 'that': 27, 'at': 28, 'workers': 29, 'do': 30, 'drink': 31, 'admire': 32, 'EOS': 33}"

In [622]:
token2idx['interested']

8

In [623]:
idx2token[22]

'are'

In [624]:
indexes = [
  torch.tensor(
    [token2idx[word] for word in text.lower().strip(".").split()] + [token2idx["EOS"]],
    dtype=torch.long
  ).view(-1, 1)
  for text in texts
]
indexes[0].shape

torch.Size([13, 1])

## Model

In [625]:
class SentimentClassifier(nn.Module):
  def __init__(self, vocab_size):
    super(SentimentClassifier, self).__init__()

    self.embed = nn.Embedding(vocab_size, 64)
    self.rnn = nn.GRU(64, 64)
    self.out = nn.Linear(64, 2)
  
  def forward(self, inputs, hidden):
    x = self.embed(inputs).view(1, 1, -1)

    ctx, hidden = self.rnn(x, hidden)
    
    out = self.out(ctx)

    return out, hidden

In [626]:
model = SentimentClassifier(len(token2idx))

## Optimization

In [627]:
loss_fn = nn.CrossEntropyLoss()

In [628]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [630]:
from random import sample

for txt_ids, sent in sample(list(zip(indexes, sentiment)), len(indexes)):
  model.zero_grad()
  optimizer.zero_grad()

  hidden = torch.zeros(1, 1, 64)

  for tok_i in range(txt_ids.shape[1]):
    out, hidden = model(txt_ids[0][tok_i], hidden)
  
  loss = loss_fn(out.view(-1, 2), torch.tensor([sent]))

  loss.backward()
  optimizer.step()

## Inference

In [645]:
for i in range(round(len(texts) / epochs)):
  print(f"Text: {texts[i]}\n True label: {sentiment[i]}")
  txt_ids = indexes[i]
  hidden = torch.zeros(1, 1, 64)

  for tok_i in range(txt_ids.shape[1]):
    out, hidden = model(txt_ids[0][tok_i], hidden)

  correct = torch.argmax(F.softmax(out.squeeze(), dim=0)).item() == sentiment[i]
  print(f" Prediction correct: {correct}")

Text: I am very interested in learning about the process of making coffee.
 True label: 1
 Prediction correct: True
Text: I admire the hard work that goes into producing coffee beans.
 True label: 1
 Prediction correct: True
Text: Conditions of workers on coffee plantations are really bad.
 True label: 0
 Prediction correct: True
Text: I do not like coffee at all.
 True label: 0
 Prediction correct: False
Text: Do not drink coffee.
 True label: 0
 Prediction correct: True
