#1. RNN for baseline model

In [None]:
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

class IMDBRNN(nn.Module):
  """Baseline classifier: token embedding -> single-layer RNN -> linear head.

  The final hidden state of the RNN is projected to one logit per sequence.
  """

  def __init__(self, input_size, embed_dim, hidden_size, padding_idx):
    """Build the layers.

    Args:
      input_size: number of rows in the embedding table (vocabulary size).
      embed_dim: size of each embedding vector.
      hidden_size: size of the RNN hidden state.
      padding_idx: token id passed to ``nn.Embedding`` as the padding index.
    """
    super().__init__()
    self.embedding = nn.Embedding(
        num_embeddings=input_size,
        embedding_dim=embed_dim,
        padding_idx=padding_idx,
    )
    self.rnn = nn.RNN(
        input_size=embed_dim,
        hidden_size=hidden_size,
        num_layers=1,
        batch_first=True,
    )
    self.linear = nn.Linear(in_features=hidden_size, out_features=1)

  def forward(self, x, lengths):
    """Return one logit per sequence.

    Args:
      x: batch of padded token ids, batch-first (the RNN is built with
        ``batch_first=True``).
      lengths: tensor with the unpadded length of every sequence in ``x``.

    Returns:
      Tensor of logits with the trailing size-1 dimension squeezed away.
    """
    embedded = self.embedding(x)  # (B, T, E)

    # Lengths must live on the CPU for packing; enforce_sorted=False because
    # batches are not sorted by length (e.g. a shuffled training set).
    packed_input = pack_padded_sequence(
        embedded,
        lengths.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    # Only the final hidden state is needed: (num_layers=1, B, H).
    _, final_hidden = self.rnn(packed_input)
    last_layer_state = final_hidden[-1]  # (B, H)

    # (B, 1) -> (B,)
    return self.linear(last_layer_state).squeeze(-1)

#2. LSTM model

In [None]:
class IMDBLSTM(nn.Module):
  """Classifier: token embedding -> single-layer LSTM -> linear head.

  The final hidden state (not the cell state) is projected to one logit
  per sequence.
  """

  def __init__(self, input_size, embed_dim, hidden_size, padding_idx):
    """Build the layers.

    Args:
      input_size: number of rows in the embedding table (vocabulary size).
      embed_dim: size of each embedding vector.
      hidden_size: size of the LSTM hidden state.
      padding_idx: token id passed to ``nn.Embedding`` as the padding index.
    """
    super().__init__()
    self.embedding = nn.Embedding(
        num_embeddings=input_size,
        embedding_dim=embed_dim,
        padding_idx=padding_idx,
    )
    self.lstm = nn.LSTM(
        input_size=embed_dim,
        hidden_size=hidden_size,
        num_layers=1,
        batch_first=True,
    )
    self.linear = nn.Linear(in_features=hidden_size, out_features=1)

  def forward(self, x, lengths):
    """Return one logit per sequence.

    Args:
      x: batch of padded token ids, batch-first (the LSTM is built with
        ``batch_first=True``).
      lengths: tensor with the unpadded length of every sequence in ``x``.

    Returns:
      Tensor of logits with the trailing size-1 dimension squeezed away.
    """
    token_vectors = self.embedding(x)

    # Packing lets the LSTM skip padded time steps; lengths go to the CPU
    # and need not be sorted.
    packed_input = pack_padded_sequence(
        token_vectors,
        lengths.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    # The per-step outputs and the cell state are not used.
    _, (hidden_state, _cell_state) = self.lstm(packed_input)
    last_layer_state = hidden_state[-1]

    return self.linear(last_layer_state).squeeze(-1)

#3. Transformer model

Positional Encoding

In [None]:
import torch
import math

class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position information to a batch of embeddings.

  Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
  ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
  precomputed for ``max_len`` positions and stored as a non-trainable buffer.
  """

  def __init__(self, d_model, max_len=5000):
    """Precompute the encoding table.

    Args:
      d_model: feature size of the embeddings (may be odd or even).
      max_len: number of positions to precompute.
    """
    super().__init__()
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))

    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine half must be truncated to match; otherwise the assignment fails
    # with a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # Buffer (not a parameter): follows .to(device) and is saved in the
    # state_dict, but receives no gradient updates.
    self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

  def forward(self, x):
    """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

    Args:
      x: tensor whose dimension 1 is the sequence axis and whose last
        dimension is ``d_model``.
    """
    return x + self.pe[:, :x.size(1)]

Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
  """Scaled dot-product attention split across ``num_heads`` heads."""

  def __init__(self, d_model, num_heads):
    """Build the query/key/value/output projections.

    Args:
      d_model: feature size of the inputs and of the output.
      num_heads: number of attention heads; must divide ``d_model`` evenly.

    Raises:
      ValueError: if ``num_heads`` is not positive or does not divide
        ``d_model``.
    """
    super().__init__()
    # Without this check the ``view(batch, -1, heads, d_k)`` calls in
    # ``forward`` can silently reshape into a wrong sequence length instead
    # of failing.
    if num_heads <= 0 or d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
      )
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.fc = nn.Linear(d_model, d_model)

  def _split_heads(self, t, batch_size):
    """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
    return t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask=None):
    """Attend from ``q`` over ``k``/``v``.

    Args:
      q, k, v: tensors of shape (batch, seq, d_model).
      mask: optional tensor broadcastable to the score tensor
        (batch, heads, seq_q, seq_k). Positions where ``mask == 1`` are
        excluded from attention.

    Returns:
      Tensor of shape (batch, seq_q, d_model).
    """
    batch_size = q.size(0)

    q = self._split_heads(self.w_q(q), batch_size)
    k = self._split_heads(self.w_k(k), batch_size)
    v = self._split_heads(self.w_v(v), batch_size)

    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch, heads, seq_q, seq_k)
    if mask is not None:
      # Use the dtype's most negative finite value rather than a hard-coded
      # -1e9, which is not representable in float16 (mixed precision).
      scores = scores.masked_fill(mask == 1, torch.finfo(scores.dtype).min)

    attn = torch.softmax(scores, dim=-1)
    context = torch.matmul(attn, v)  # (batch, heads, seq_q, d_k)

    # Merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model).
    context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
    return self.fc(context)

Encoder Block

In [None]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and a feed-forward network.

  Each sub-layer is wrapped in a residual connection followed by layer
  normalization (``LayerNorm(x + sublayer(x))``).
  """

  def __init__(self, d_model, num_heads, d_ff):
    """Build the sub-layers.

    Args:
      d_model: feature size of the block's input and output.
      num_heads: number of attention heads passed to ``MultiHeadAttention``.
      d_ff: hidden size of the position-wise feed-forward network.
    """
    super().__init__()
    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model)
    )
    self.layernorm1 = nn.LayerNorm(d_model)
    self.layernorm2 = nn.LayerNorm(d_model)

  def forward(self, x, mask=None):
    """Apply self-attention then the feed-forward network.

    Args:
      x: input tensor whose last dimension is ``d_model``.
      mask: optional attention mask forwarded unchanged to the attention
        module; ``None`` (the default) means no masking, matching
        ``MultiHeadAttention.forward``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Self-attention: queries, keys and values all come from x.
    attn_out = self.mha(x, x, x, mask)
    x = self.layernorm1(x + attn_out)  # Add & Norm

    ffn_out = self.ffn(x)
    x = self.layernorm2(x + ffn_out)  # Add & Norm
    return x

Transformer Classifier

In [None]:
class TransformerClassifier(nn.Module):
  """Sequence classifier: embedding + positional encoding + encoder stack.

  The encoder output is mean-pooled over the sequence axis and projected to
  ``num_classes`` logits.
  """

  def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, num_classes, max_len,
               padding_idx=None):
    """Build the model.

    Args:
      vocab_size: number of rows in the embedding table.
      d_model: embedding / model feature size.
      num_heads: attention heads per encoder layer.
      num_layers: number of stacked ``EncoderLayer`` blocks.
      d_ff: hidden size of each layer's feed-forward network.
      num_classes: number of output logits.
      max_len: number of positions precomputed by ``PositionalEncoding``.
      padding_idx: optional padding token id passed to ``nn.Embedding``,
        mirroring the RNN/LSTM models in this file. ``None`` (the default)
        keeps the previous behavior of no dedicated padding row.
    """
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
    self.pos_encoding = PositionalEncoding(d_model, max_len)
    self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
    self.classifier = nn.Linear(d_model, num_classes)

  def forward(self, x, mask=None):
    """Return class logits for a batch of token-id sequences.

    Args:
      x: batch of token ids with the sequence on dimension 1.
      mask: optional attention mask handed to every encoder layer; ``None``
        (the default) disables masking.

    Returns:
      Tensor of logits whose last dimension is ``num_classes``.
    """
    x = self.embedding(x)
    x = self.pos_encoding(x)

    for layer in self.layers:
      x = layer(x, mask)

    # NOTE(review): the mean runs over every position on dim 1, so padded
    # positions (if any) contribute to the pooled vector. A mask-aware mean
    # would need the mask layout, which this block does not pin down —
    # confirm against the caller before changing.
    x = x.mean(dim=1)
    return self.classifier(x)