In [None]:
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import warnings

# Silence UserWarnings attributed to the "torchtext" module. torchtext is not
# imported anywhere in the visible code, so this filter only matters if some
# dependency pulls it in.
warnings.filterwarnings("ignore", category=UserWarning, module="torchtext")

# Set random seed for reproducibility: seed PyTorch and NumPy with the same
# value and ask cuDNN to use deterministic algorithms.
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True

# Load pre-trained BERT tokenizer. `max_input_length` is the maximum sequence
# length reported by the tokenizer; it is used below both as the truncation
# limit and as the fixed padding length.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_input_length = tokenizer.model_max_length

# Load IMDB dataset using Hugging Face's `datasets` library. The result is
# indexed by split name ("train" / "test") further down.
dataset = load_dataset("imdb")

# Define custom PyTorch dataset class
class IMDBDataset(Dataset):
    """Map-style PyTorch dataset over one split of the module-level ``dataset``.

    Each item is a ``(label, text)`` tuple built from the ``"label"`` and
    ``"text"`` fields of a single row.
    """

    def __init__(self, split):
        """Select the split named ``split`` (e.g. ``"train"``) from ``dataset``."""
        self.data = dataset[split]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(label, text)`` for the row at position ``idx``."""
        # Index the row first and then pick the fields. The previous
        # column-first form (`self.data["text"][idx]`) selects the whole column
        # before indexing into it, which can materialise every row of the
        # split for each single item that is requested.
        row = self.data[idx]
        return row["label"], row["text"]

# Build one dataset wrapper per IMDB split
train_dataset, test_dataset = IMDBDataset("train"), IMDBDataset("test")

# Optimized batch collation function
def collate_batch(batch):
    """Turn a list of ``(label, text)`` pairs into model-ready tensors.

    Labels become a float tensor. Texts are tokenized with the module-level
    ``tokenizer``, truncated and padded to ``max_input_length``. Only the
    ``input_ids`` of the encoding are kept. Both tensors are moved to the
    module-level ``device`` before being returned as ``(labels, input_ids)``.
    """
    raw_labels, raw_texts = zip(*batch)
    label_tensor = torch.tensor(raw_labels, dtype=torch.float)
    # NOTE(review): every batch is padded to the full `max_input_length`
    # regardless of the longest text in it; `predict_sentiment` pads the same
    # way, so change both together if this is ever relaxed.
    encoding = tokenizer(
        list(raw_texts),
        truncation=True,
        padding="max_length",
        max_length=max_input_length,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    return label_tensor.to(device), input_ids.to(device)

# Batch size and compute device (GPU when one is available, otherwise CPU)
BATCH_SIZE = 16  # Reduced for better performance
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# DataLoaders: shuffle the training split only; both use the shared collate function
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_batch,
)

# Pre-trained BERT encoder
bert = BertModel.from_pretrained("bert-base-uncased")

# Freeze every BERT parameter so only the layers stacked on top are trained
bert.requires_grad_(False)

# Define BERT + GRU Model
class BERTGRUSentiment(nn.Module):
    """Sentiment classifier: BERT token embeddings -> (bi)GRU -> linear head.

    Args:
        bert: Encoder module. Must expose ``config.hidden_size`` and return an
            object with ``last_hidden_state`` when called with ``input_ids``
            and ``attention_mask``.
        hidden_dim: Hidden size of each GRU direction.
        output_dim: Number of output logits per example.
        n_layers: Number of stacked GRU layers.
        bidirectional: Whether the GRU runs in both directions.
        dropout: Dropout probability applied between GRU layers (only when
            ``n_layers >= 2``) and on the final hidden state.
    """

    def __init__(self, bert, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.bert = bert
        # Take the padding id from the encoder's own config instead of reading
        # a module-level tokenizer, so the model does not depend on globals.
        # It is a plain attribute, so it does not appear in the state dict.
        self.pad_token_id = getattr(bert.config, "pad_token_id", None)
        embedding_dim = bert.config.hidden_size
        self.rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            batch_first=True,
            # nn.GRU only applies dropout between layers, so it is disabled for
            # a single-layer network.
            dropout=0 if n_layers < 2 else dropout,
        )
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Return logits of shape ``(batch, output_dim)`` for a batch of token ids.

        Args:
            text: Integer tensor of token ids, batch dimension first.
        """
        if self.pad_token_id is None:
            # No padding id known: attend to every position.
            attention_mask = torch.ones_like(text)
        else:
            attention_mask = (text != self.pad_token_id).long()

        # The encoder is used as a fixed feature extractor, so no autograd
        # graph is recorded for it.
        with torch.no_grad():
            outputs = self.bert(input_ids=text, attention_mask=attention_mask)
            embedded = outputs.last_hidden_state

        _, hidden = self.rnn(embedded)
        if self.rnn.bidirectional:
            # The last two entries of `hidden` are the top layer's forward and
            # backward final states; concatenate them along the feature axis.
            hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])

        return self.out(hidden)

# Model hyper-parameters
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25

# Build the classifier on top of the frozen encoder and move it to the target device
model = BERTGRUSentiment(
    bert,
    HIDDEN_DIM,
    OUTPUT_DIM,
    N_LAYERS,
    BIDIRECTIONAL,
    DROPOUT,
)
model = model.to(device)

# Adam with default settings; binary cross-entropy computed directly on logits
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(device)

# Train Function with Debug Prints
def train(model, iterator, optimizer, criterion):
    """Run one training pass over ``iterator``.

    ``iterator`` yields ``(labels, text)`` batches and must support ``len()``.
    Progress is printed for every batch, and loss/accuracy for every tenth one.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    model.train()
    n_batches = len(iterator)
    running_loss = 0
    running_acc = 0

    for step, batch in enumerate(iterator):
        targets, token_ids = batch
        print(f"Processing batch {step+1}/{n_batches}")

        optimizer.zero_grad()
        logits = model(token_ids).squeeze(1)
        batch_loss = criterion(logits, targets)

        # Accuracy: threshold the sigmoid probabilities by rounding and compare with the labels.
        hard_predictions = torch.round(torch.sigmoid(logits))
        batch_acc = (hard_predictions == targets).float().mean().item()

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        running_acc += batch_acc
        if step % 10 == 0:
            print(f"Batch {step}: Loss {batch_loss.item():.4f}, Accuracy {batch_acc*100:.2f}%")

    return running_loss / n_batches, running_acc / n_batches

# Evaluate Function
def evaluate(model, iterator, criterion):
    """Score ``model`` on ``iterator`` without updating any weights.

    The model is put in eval mode and gradients are disabled for the whole
    pass. ``iterator`` yields ``(labels, text)`` batches and must support
    ``len()``.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    model.eval()
    loss_sum = 0
    acc_sum = 0

    with torch.no_grad():
        for targets, token_ids in iterator:
            logits = model(token_ids).squeeze(1)
            batch_loss = criterion(logits, targets)
            # Threshold the sigmoid probabilities by rounding, then compare with the labels.
            matches = torch.round(torch.sigmoid(logits)) == targets
            loss_sum += batch_loss.item()
            acc_sum += matches.float().mean().item()

    n_batches = len(iterator)
    return loss_sum / n_batches, acc_sum / n_batches

# Training Loop
N_EPOCHS = 5
VALID_FRACTION = 0.1  # share of the training split held out for checkpoint selection
best_valid_loss = float("inf")

# Choosing the best checkpoint on the test split would make the final test
# score optimistically biased, so carve a validation subset out of the
# training data and keep the test split for the final evaluation only.
n_valid = int(len(train_dataset) * VALID_FRACTION)
fit_subset, valid_subset = torch.utils.data.random_split(
    train_dataset,
    [len(train_dataset) - n_valid, n_valid],
    generator=torch.Generator().manual_seed(SEED),  # reproducible split
)
fit_dataloader = DataLoader(fit_subset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(valid_subset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

for epoch in range(N_EPOCHS):
    train_loss, train_acc = train(model, fit_dataloader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_dataloader, criterion)

    # Keep only the weights with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "bert_gru_model.pt")

    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%")
    print(f"Valid Loss: {valid_loss:.3f} | Valid Acc: {valid_acc*100:.2f}%")

# Load Best Model and Test
# map_location makes the checkpoint loadable on the current device even if it
# was written from a different one.
model.load_state_dict(torch.load("bert_gru_model.pt", map_location=device))
test_loss, test_acc = evaluate(model, test_dataloader, criterion)
print(f"Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%")

# Predict Sentiment
def predict_sentiment(model, tokenizer, sentence):
    """Return the model's sigmoid score for a single sentence as a float.

    Args:
        model: Classifier that maps a batch of token ids to one logit per example.
        tokenizer: Tokenizer whose ``encode`` produces the token-id tensor.
        sentence: Raw text to score.

    The sentence is truncated and padded to the module-level
    ``max_input_length``, the same fixed length ``collate_batch`` uses, and
    moved to the module-level ``device``.
    """
    model.eval()
    tokens = tokenizer.encode(
        sentence,
        truncation=True,
        padding="max_length",
        max_length=max_input_length,
        return_tensors="pt",
    ).to(device)
    # Pure inference: skip building an autograd graph for the trainable layers.
    with torch.no_grad():
        prediction = torch.sigmoid(model(tokens)).item()
    return prediction

# Score one negative and one positive example sentence
for example_sentence in ("This film is terrible", "This film is great"):
    print(predict_sentiment(model, tokenizer, example_sentence))


Processing batch 1/1563
Batch 0: Loss 0.6704, Accuracy 56.25%
Processing batch 2/1563
Processing batch 3/1563
Processing batch 4/1563
Processing batch 5/1563
Processing batch 6/1563
Processing batch 7/1563
Processing batch 8/1563
Processing batch 9/1563
Processing batch 10/1563
Processing batch 11/1563
Batch 10: Loss 0.6529, Accuracy 56.25%
Processing batch 12/1563
Processing batch 13/1563
Processing batch 14/1563
Processing batch 15/1563
Processing batch 16/1563
Processing batch 17/1563
Processing batch 18/1563
Processing batch 19/1563
Processing batch 20/1563
Processing batch 21/1563
Batch 20: Loss 0.4891, Accuracy 87.50%
Processing batch 22/1563
Processing batch 23/1563
Processing batch 24/1563
Processing batch 25/1563
Processing batch 26/1563
Processing batch 27/1563
Processing batch 28/1563
Processing batch 29/1563
Processing batch 30/1563
Processing batch 31/1563
Batch 30: Loss 0.3547, Accuracy 93.75%
Processing batch 32/1563
Processing batch 33/1563
Processing batch 34/1563
Proc