In [None]:
from functools import partial
from typing import Any

import lightning as pl
import torch
from torch.utils.data import DataLoader, Subset
from transformers import BertModel, DataCollatorForTokenClassification

from calvera.bandits import NeuralLinearBandit
from calvera.benchmark.datasets import ImdbMovieReviews
from calvera.benchmark import BanditBenchmarkEnvironment, BertWrapper
from calvera.utils import ListDataBuffer, AllDataBufferStrategy

import logging

# Raise the threshold of Lightning's rank-zero logger to FATAL so that only
# records at that level are still emitted by it.
_rank_zero_logger = logging.getLogger("lightning.pytorch.utilities.rank_zero")
_rank_zero_logger.setLevel(logging.FATAL)

In [None]:
def transformers_collate(batch: Any, data_collator: DataCollatorForTokenClassification) -> Any:
    """Collate bandit samples whose context is a triple of tokenizer outputs.

    The first entry of every sample is indexed at positions 0, 1 and 2, which are
    passed to ``data_collator`` as ``input_ids``, ``attention_mask`` and
    ``token_type_ids``. The remaining entries of the samples are stacked with
    ``torch.stack``.

    Args:
        batch: Sequence of samples. Each sample is either ``(inputs, reward)`` or
            ``(inputs, embedded_actions, reward, chosen_actions)``; the layout is
            decided from the length of the first sample.
        data_collator: Callable that receives the list of per-sample feature dicts
            and returns a mapping holding the three collated token fields.

    Returns:
        ``((input_ids, attention_mask, token_type_ids), rewards)`` for two-element
        samples; otherwise ``((input_ids, attention_mask, token_type_ids),
        embedded_actions, rewards, chosen_actions)``, where ``embedded_actions`` and
        ``chosen_actions`` are ``None`` whenever the first sample stores ``None`` in
        that slot.
    """
    field_names = ("input_ids", "attention_mask", "token_type_ids")

    # One feature dict per sample, built from the first three entries of its inputs.
    features = [{name: sample[0][idx] for idx, name in enumerate(field_names)} for sample in batch]

    # The collator turns the per-sample dicts into batched fields.
    collated = data_collator(features)
    context = tuple(collated[name] for name in field_names)

    def _stack_slot(slot: int) -> Any:
        # Stack the entry at position `slot` of every sample along a new leading dim.
        return torch.stack([sample[slot] for sample in batch])

    head = batch[0]

    # Short layout: only the rewards follow the inputs.
    if len(head) == 2:
        return context, _stack_slot(1)

    # Long layout: optional slots are decided by looking at the first sample only.
    embedded_actions = _stack_slot(1) if head[1] is not None else None
    realized_rewards = _stack_slot(2)
    chosen_actions = _stack_slot(3) if head[3] is not None else None

    return context, embedded_actions, realized_rewards, chosen_actions

In [None]:
# Load the IMDB movie reviews dataset
dataset = ImdbMovieReviews()
print(f"Dataset context size: {dataset.context_size}")
print(f"Dataset sample count: {len(dataset)}")

# Seed torch's global RNG so the shuffled sample order of the DataLoader below
# (and any later torch-based sampling) is the same on every "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Tunable sizes for this experiment.
N_SAMPLES = 10000  # number of leading dataset samples to use
BATCH_SIZE = 32

# Create data loader for a subset of the data. The subset is capped at the dataset
# length so that a smaller dataset cannot produce out-of-range indices.
collate_fn = partial(transformers_collate, data_collator=dataset.get_data_collator())
train_subset = Subset(dataset, range(min(N_SAMPLES, len(dataset))))
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# Set up the environment
accelerator = "cpu"
env = BanditBenchmarkEnvironment(train_loader, device=accelerator)

In [None]:
# Feature network: a pretrained BERT checkpoint, switched to eval mode and wrapped
# in BertWrapper before it is handed to the bandit.
bert_model = BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2", output_hidden_states=True)
network = BertWrapper(bert_model.eval())

# Data buffer the bandit records its feedback into.
buffer_strategy = AllDataBufferStrategy()
buffer = ListDataBuffer(buffer_strategy=buffer_strategy, max_size=1024)

# Neural-linear bandit on top of the network.
bandit_module = NeuralLinearBandit(
    network=network,
    buffer=buffer,
    n_embedding_size=128,
    # Marked "Very Important" by the author: do not drop the next two settings.
    contextualization_after_network=True,
    n_arms=2,
    initial_train_steps=128,
    min_samples_required_for_training=128,
)

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

# Per-batch results are collected in lists and concatenated once after the loop.
# Growing an array with np.append would copy the whole history on every step.
reward_chunks = []
regret_chunks = []
# Running totals for the average regret shown in the progress bar.
regret_total = 0.0
regret_count = 0
progress = tqdm(iter(env), total=len(env))

for contextualized_actions in progress:
    # 1. Select actions
    chosen_actions, _ = bandit_module.forward(contextualized_actions)

    # 2. Create a trainer for this step
    # NOTE(review): a fresh Trainer is built every iteration; presumably a Trainer
    # with max_epochs=1 would not train again after its first fit() - confirm
    # before hoisting it out of the loop.
    trainer = pl.Trainer(
        max_epochs=1,
        enable_progress_bar=False,
        enable_model_summary=False,
        accelerator=accelerator,
    )

    # 3. Get feedback from environment
    chosen_contextualized_actions, realized_rewards = env.get_feedback(chosen_actions)
    batch_regret = env.compute_regret(chosen_actions)

    # 4. Track metrics (flattened, matching what np.append would have stored)
    batch_rewards_np = realized_rewards.cpu().numpy().ravel()
    batch_regret_np = batch_regret.cpu().numpy().ravel()
    reward_chunks.append(batch_rewards_np)
    regret_chunks.append(batch_regret_np)
    regret_total += float(batch_regret_np.sum())
    regret_count += batch_regret_np.size
    progress.set_postfix(
        {
            "reward": realized_rewards.mean().item(),
            "regret": batch_regret.mean().item(),
            "avg_regret": regret_total / regret_count if regret_count else float("nan"),
        }
    )

    # 5. Update the bandit
    bandit_module.record_feedback(chosen_contextualized_actions, realized_rewards, chosen_actions)
    trainer.fit(bandit_module)
    # Move the module back to the configured device after fitting.
    bandit_module = bandit_module.to(accelerator)

# Build the final flat arrays once. The leading empty float64 array keeps the
# result float64 and makes the concatenation valid even if the loop never ran.
rewards = np.concatenate([np.array([])] + reward_chunks)
regrets = np.concatenate([np.array([])] + regret_chunks)

# Store metrics
metrics = pd.DataFrame(
    {
        "reward": rewards,
        "regret": regrets,
    }
)

In [None]:
import matplotlib.pyplot as plt

# Running totals of the per-sample reward and regret columns.
cumulative_reward = metrics["reward"].cumsum()
cumulative_regret = metrics["regret"].cumsum()

# Plot the first 1000 entries of both cumulative curves on one axes.
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in ((cumulative_reward, "reward"), (cumulative_regret, "regret")):
    ax.plot(curve[:1000], label=curve_label)
ax.set_xlabel("steps")
ax.set_ylabel("cumulative reward/regret")
ax.legend()
plt.show()

# Print the mean of each metric over the first 10, the first 100 and all rows,
# with one blank line between the two metrics.
horizons = (("first 10 rounds", 10), ("first 100 rounds", 100), ("all rounds", None))
for position, column in enumerate(("reward", "regret")):
    if position > 0:
        print("")
    for horizon_label, limit in horizons:
        # A limit of None slices the whole column.
        print(f"Average {column} ({horizon_label}): {np.mean(metrics[column][:limit]):.4f}")