In [None]:
import torch
import torch.nn as nn
from lstm_sentiment import build_vocab, preprocess_data
from torchtext.datasets import IMDB
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from utils import save_plot

In [None]:
# GRU Model
class GRUClassifier(nn.Module):
    """Sequence classifier: embedding -> single-layer GRU -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the GRU hidden state.
        num_classes: Number of output scores produced per sequence.

    Token id 0 is the padding index: its embedding is a zero vector that is
    not updated during training.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, lengths=None):
        """Return class scores of shape (batch, num_classes).

        Args:
            x: Integer token ids of shape (batch, seq_len).
            lengths: Optional true (unpadded) length of every row in ``x``, as
                a tensor or a sequence of ints. When given, ``x`` is assumed to
                be right-padded and the GRU stops at each row's last real
                token, so trailing padding no longer influences the result.
                When omitted, the hidden state after the final timestep is
                used, exactly as before.
        """
        embedded = self.embedding(x)
        if lengths is not None:
            # pack_padded_sequence needs CPU lengths that are all >= 1; an
            # all-padding row is therefore treated as a single pad token.
            seq_lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu().clamp(min=1)
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, seq_lengths, batch_first=True, enforce_sorted=False
            )
        # h_n has shape (num_layers, batch, hidden_dim); take the last layer.
        _, h_n = self.gru(embedded)
        return self.fc(h_n[-1])

In [None]:
# Training loop (reusing from Step 5)
from lstm_sentiment import LSTMClassifier, train_model

if __name__ == "__main__":
    # Experiment settings shared by both models so the comparison is like-for-like.
    SEED = 42
    BATCH_SIZE = 32
    EPOCHS = 2

    # Data
    # The IMDB splits are re-created before every consumer; presumably the
    # helpers exhaust the iterators they are handed -- TODO confirm against
    # lstm_sentiment before collapsing these into a single load.
    train_iter, test_iter = IMDB(split=("train", "test"))
    vocab = build_vocab(train_iter)
    train_iter, test_iter = IMDB(split=("train", "test"))
    collate_fn = preprocess_data(train_iter, test_iter, vocab)
    train_iter, test_iter = IMDB(split=("train", "test"))
    # Materialise the splits as lists so the training loader can shuffle them.
    train_loader = DataLoader(list(train_iter), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(list(test_iter), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # LSTM
    # Seed right before building each model so weight initialisation and the
    # subsequent training run are repeatable from one execution to the next.
    torch.manual_seed(SEED)
    lstm_model = LSTMClassifier(len(vocab)).to(device)
    lstm_losses, lstm_accs = train_model(lstm_model, train_loader, test_loader, epochs=EPOCHS, device=device)

    # GRU
    torch.manual_seed(SEED)
    gru_model = GRUClassifier(len(vocab)).to(device)
    gru_losses, gru_accs = train_model(gru_model, train_loader, test_loader, epochs=EPOCHS, device=device)

    # Plot accuracy comparison
    fig, ax = plt.subplots()
    ax.plot(lstm_accs, label="LSTM")
    ax.plot(gru_accs, label="GRU")
    ax.set_title("GRU vs LSTM on IMDB")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Accuracy")
    ax.legend()
    save_plot(fig, "gru_vs_lstm.png")