In [None]:

import logging
import sys

import torch

from neural_bandits.benchmark.datasets.imdb_reviews import ImdbMovieReviews

# Root logger at INFO so library messages (e.g. from the dataset loader) are shown.
# basicConfig installs a root handler only if the root logger has none yet.
logging.basicConfig(level=logging.INFO)

# Notebook-local logger that writes to stdout (regular cell output).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Attach the stdout handler only once: re-running this cell must not stack
# another handler and duplicate every message.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler(sys.stdout))
# The root logger already has a handler, so propagating would print every
# record from `logger` a second time.
logger.propagate = False


# Training split of the IMDB movie-review dataset.
# NOTE(review): max_len presumably caps the tokenized review length and
# dest_path is presumably the download/cache directory — confirm in ImdbMovieReviews.
dataset = ImdbMovieReviews(
    partition="train",
    max_len=256,
    dest_path="./data",
)

In [None]:
from torch import optim
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification

# Seed torch's global RNG so the newly initialised classification head and the
# DataLoader shuffle order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Small pretrained BERT checkpoint; a fresh 2-way classification head is added on top.
MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 1
# Trailing ';' suppresses the (very long) module repr as cell output.
model.train();

In [None]:
import tqdm

# Fine-tune for `num_epochs` epochs. The progress bar shows the current batch
# loss and an exponential moving average (EMA) of it.
# Set training mode here as well: if the evaluation cell ran `model.eval()`,
# re-running this cell would otherwise train with dropout disabled.
model.train()

EMA_DECAY = 0.9  # weight kept on the previous average; the rest goes to the newest batch loss

for epoch in range(num_epochs):
    total_loss = 0.0
    # Seeded with the first batch loss instead of 0.0 so the displayed average
    # is not biased toward zero during the first steps.
    exp_moving_avg = None
    progress_bar = tqdm.tqdm(dataloader, total=len(dataloader))
    for batch_inputs, labels in progress_bar:
        optimizer.zero_grad()
        # Forward pass: passing labels causes the model to compute the loss.
        # NOTE(review): batch_inputs is unpacked positionally into the model —
        # presumably (input_ids, attention_mask, ...); confirm against ImdbMovieReviews.
        outputs = model(*batch_inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # .item() copies the scalar to the host; do it once per step.
        batch_loss = loss.item()
        total_loss += batch_loss
        if exp_moving_avg is None:
            exp_moving_avg = batch_loss
        else:
            exp_moving_avg = EMA_DECAY * exp_moving_avg + (1 - EMA_DECAY) * batch_loss
        progress_bar.set_description(f"Loss: {batch_loss:.4f}, Exp. Moving Avg. Loss: {exp_moving_avg:.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")



In [None]:
import itertools

# Evaluate on a random subset of the test split.
NUM_EVAL_BATCHES = 100  # number of test batches to score

model.eval()

dataset_eval = ImdbMovieReviews(max_len=256, partition="test", dest_path="./data")
# Separate name so the training `dataloader` is not replaced by the test split
# (re-running the training cell would otherwise train on test data).
# shuffle=True makes the first NUM_EVAL_BATCHES batches a random sample of the split.
eval_dataloader = DataLoader(dataset_eval, batch_size=batch_size, shuffle=True)

preds = []
labels = []
with torch.no_grad():
    # A single pass over one iterator: every sampled review is distinct, unlike
    # calling next(iter(...)) repeatedly, which re-draws batches with replacement.
    for batch_inputs, labels_batch in itertools.islice(eval_dataloader, NUM_EVAL_BATCHES):
        outputs = model(*batch_inputs)
        preds.extend(outputs.logits.argmax(dim=-1).tolist())
        # NOTE(review): labels appear to be one-hot (argmax over the last dim
        # recovers the class index) — confirm against ImdbMovieReviews.
        labels.extend(labels_batch.argmax(dim=-1).tolist())

# Divide by the number of samples actually scored: the split may hold fewer than
# NUM_EVAL_BATCHES full batches, and the last batch may be partial.
num_evaluated = len(preds)
accuracy = sum(p == t for p, t in zip(preds, labels)) / num_evaluated if num_evaluated else float("nan")
print("Accuracy:", accuracy)