In [51]:
import logging

import lightning as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import pandas as pd
import numpy as np

from calvera.bandits.neural_linear_bandit import NeuralLinearBandit
from calvera.benchmark.datasets.statlog import StatlogDataset
from calvera.utils.data_storage import InMemoryDataBuffer, AllDataBufferStrategy


from calvera.benchmark.environment import BanditBenchmarkEnvironment
from calvera.benchmark.logger_decorator import OnlineBanditLoggerDecorator

In [52]:
class Network(nn.Module):
    """Two-layer perceptron used as the feature extractor.

    The input is passed through a linear layer, a ReLU, and a second linear
    layer whose output width is ``n_embedding_size``.

    Args:
        dim: Number of input features.
        hidden_size: Width of the hidden layer.
        n_embedding_size: Width of the output produced by the second linear layer.
    """

    def __init__(self, dim, hidden_size=100, n_embedding_size=10):
        super().__init__()
        # Layers are created in this order and under these attribute names so the
        # parameter names stay ``fc1.*`` / ``fc2.*``.
        self.fc1 = nn.Linear(in_features=dim, out_features=hidden_size)
        self.activate = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=n_embedding_size)

    def forward(self, x):
        """Apply ``fc1``, then the ReLU, then ``fc2`` to ``x`` and return the result."""
        hidden = self.fc1(x)
        hidden = self.activate(hidden)
        return self.fc2(hidden)

In [53]:
# Load the Statlog benchmark dataset, then print its `context_size` followed by
# its length (one value per line).
dataset = StatlogDataset()
print(dataset.context_size, len(dataset), sep="\n")

63
58000


In [54]:
# Experiment configuration for the objects built in this cell.
SEED = 42
# Number of dataset rows replayed through the environment; the same value is
# used as the buffer capacity (both were 10000 before being named).
N_SAMPLES = 10_000
# Output width of the network; the same value is handed to the bandit as
# `n_embedding_size`, so it is defined once to keep the two in agreement.
EMBEDDING_SIZE = 10

# Seed the Python, NumPy and PyTorch RNGs so the weight initialisation and the
# shuffled sample order are repeatable after a kernel restart.
pl.seed_everything(SEED)

buffer = InMemoryDataBuffer(
    buffer_strategy=AllDataBufferStrategy(),
    max_size=N_SAMPLES,
)

network = Network(dataset.context_size, hidden_size=100, n_embedding_size=EMBEDDING_SIZE)

# One sample per step (batch_size=1), drawn in shuffled order from the first
# N_SAMPLES rows of the dataset.
train_loader = DataLoader(Subset(dataset, range(N_SAMPLES)), batch_size=1, shuffle=True)
env = BanditBenchmarkEnvironment(train_loader)
bandit_module = NeuralLinearBandit(
    n_embedding_size=EMBEDDING_SIZE,
    network=network,
    buffer=buffer,
    train_batch_size=32,
    early_stop_threshold=1e-3,
    weight_decay=1e-3,
    learning_rate=1e-3,
    min_samples_required_for_training=8,
    initial_train_steps=2048,
)

# Silence Lightning's rank-zero messages below FATAL, then log bandit metrics
# to CSV files under logs/neural_linear_bandit.
logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.FATAL)
logger = OnlineBanditLoggerDecorator(
    CSVLogger("logs", name="neural_linear_bandit", flush_logs_every_n_steps=100),
    enable_console_logging=False,
)

In [None]:
# Online bandit loop: for every batch of contextualized actions from the
# environment, let the bandit choose, collect the realized reward and regret,
# record the feedback, and run a training pass.
#
# Per-step results are collected in Python lists and concatenated once after
# the loop; growing a NumPy array with np.append on every step copies the whole
# history each time.
reward_chunks = []
regret_chunks = []
regret_total = 0.0  # running sum of all regrets seen so far
regret_count = 0  # number of regret values seen so far

progress_bar = tqdm(iter(env), total=len(env))
for contextualized_actions in progress_bar:
    chosen_actions, _ = bandit_module.forward(contextualized_actions)

    chosen_contextualized_actions, realized_rewards = env.get_feedback(chosen_actions)
    batch_regret = env.compute_regret(chosen_actions)

    # Store this step's rewards and regrets as flat arrays.
    reward_chunks.append(realized_rewards.cpu().numpy().ravel())
    step_regrets = batch_regret.cpu().numpy().ravel()
    regret_chunks.append(step_regrets)
    regret_total += float(step_regrets.sum())
    regret_count += step_regrets.size

    progress_bar.set_postfix(
        reward=realized_rewards.sum(dim=1).mean().item(),  # shape batch_size, selected_arms
        regret=batch_regret.mean().item(),
        average_regret=regret_total / regret_count if regret_count else float("nan"),
    )

    bandit_module.record_feedback(chosen_contextualized_actions, realized_rewards)

    # A new Trainer is built for every step and used for a single fit call.
    trainer = pl.Trainer(
        max_epochs=1,
        max_steps=1000,
        logger=logger,
        log_every_n_steps=1,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer.fit(bandit_module)

# Flatten the history into float64 arrays; an empty run yields empty arrays.
rewards = np.concatenate(reward_chunks).astype(np.float64) if reward_chunks else np.array([])
regrets = np.concatenate(regret_chunks).astype(np.float64) if regret_chunks else np.array([])

metrics = pd.DataFrame(
    {
        "reward": rewards,
        "regret": regrets,
    }
)

  0%|          | 0/40 [00:00<?, ?it/s]


AttributeError: 'int' object has no attribute 'ndim'

In [None]:
# Running totals of the "reward" and "regret" columns of the in-memory
# `metrics` frame, restricted to its first 10000 rows.
import numpy as np

cumulative_reward, cumulative_regret = (
    np.cumsum(metrics[column].head(10000)) for column in ("reward", "regret")
)

In [None]:
import matplotlib.pyplot as plt

# Draw cumulative reward and cumulative regret as two lines on one set of axes.
fig, ax = plt.subplots()
for series, series_label in (
    (cumulative_reward, "Cumulative Reward"),
    (cumulative_regret, "Cumulative Regret"),
):
    ax.plot(series, label=series_label)
ax.set_xlabel("steps")
ax.set_ylabel("cumulative reward/regret")
ax.legend()
plt.show()

In [None]:
# Mean reward, then mean regret, over the first 10, 100 and 10000 recorded steps
# (six lines, in that order).
#
# Series.mean() skips NaN entries and divides by the number of values actually
# present. Dividing a plain sum by the nominal window size understates the
# average whenever the run is shorter than the window or values were dropped.
for column in ("reward", "regret"):
    for window in (10, 100, 10000):
        print(metrics[column].head(window).mean())

In [None]:
# Locate the metrics.csv written by the wrapped CSV logger, load it, and plot
# the non-missing "loss" values found in its first 10000 rows.
log_directory = logger._logger_wrappee.log_dir
bandit_metrics_csv = f"{log_directory}/metrics.csv"
print(bandit_metrics_csv)
bandit_metrics = pd.read_csv(bandit_metrics_csv)

loss_values = bandit_metrics["loss"].head(10000).dropna()

fig, ax = plt.subplots()
ax.plot(loss_values, label="Loss")
ax.set_xlabel("steps")
ax.set_ylabel("loss")
ax.legend()
plt.show()