## Imports

In [None]:
# IPython autoreload extension; mode 2 reloads imported modules before each
# cell is executed, so edits to local modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import urllib

from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from tqdm.notebook import tqdm

## Config

In [None]:
# Capacity of the agent's replay buffers (passed to Agent as `buffer_size`).
BUFFER_SIZE = 4096
# Mini-batch size used when fitting the reward model (Agent `batch_size`).
BATCH_SIZE = 64
# Exploration probability of the epsilon-greedy policy (Agent `eps`).
EPS = 0.0
# Learning rate handed to the SGD optimizer.
LEARNING_RATE = 1e-3
# Number of environment interactions; one mushroom is drawn per step.
N_TRAINING_STEPS = 50000

## Data

In [None]:
# A bare `import urllib` does not guarantee that the `request` submodule is
# loaded (it only works when some other import happened to pull it in), so
# import it explicitly before use.
import urllib.request

# Download the UCI mushroom dataset next to the notebook.
urllib.request.urlretrieve(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data",
    "agaricus-lepiota.data"
)

In [None]:
# The raw file has no header row; every column is a single-letter category.
df = pd.read_csv("agaricus-lepiota.data", header=None)

# Find the labels: the first column holds the class, coded "p" -> 0, "e" -> 1.
# After the bool cast below, True therefore marks the "e" class, which the
# rest of the notebook treats as "edible".
labels = df.pop(df.columns[0])
labels = pd.Categorical(labels, categories=["p", "e"]).codes

# Get the contexts: integer-code each categorical column, then one-hot encode.
for col in df:
    df[col] = pd.Categorical(df[col]).codes
# Densify with .toarray() instead of passing `sparse=False`: that keyword was
# renamed to `sparse_output` in scikit-learn 1.2 and later removed, whereas
# the default (sparse) output followed by .toarray() works on old and new
# versions alike and yields the same float32 matrix.
contexts = OneHotEncoder(dtype=np.float32).fit_transform(df).toarray()

# Convert to torch tensors
contexts = torch.tensor(contexts)
labels = torch.tensor(labels, dtype=bool)

contexts, labels

In [None]:
# Sanity check: shapes of the one-hot context matrix and the label vector.
contexts.shape, labels.shape

In [None]:
# Sanity check: element types of the context and label tensors.
contexts.dtype, labels.dtype

## Environment

In [None]:
class Environment:
    """Mushroom bandit environment.

    Holds the encoded mushrooms (`contexts`) together with their edibility
    flags (`labels`) and defines the reward scheme for the agent and for an
    oracle that knows the true label.
    """

    def __init__(self, contexts, labels):
        """Store the context rows and the matching per-row labels."""
        self.contexts = contexts
        self.labels = labels

    def get_random_mushroom(self):
        """Draw one mushroom uniformly at random.

        Returns:
            A `(context, edible)` pair: the context row and the label at the
            same index, the latter unwrapped with `.item()`.
        """
        idx = np.random.randint(len(self.contexts))
        context = self.contexts[idx]
        edible = self.labels[idx].item()
        return context, edible

    def get_agent_reward(self, edible, eaten):
        """Reward received by the agent for its decision.

        Not eating yields 0. Eating an edible mushroom yields 5. Eating a
        non-edible one yields -35 or 5, decided by a single uniform draw
        compared against 0.5.
        """
        if eaten and edible:
            return 5
        if eaten:
            # The random draw happens only on this branch.
            unlucky = torch.rand(1).item() <= 0.5
            return -35 if unlucky else 5
        return 0

    def get_oracle_reward(self, edible):
        """Reward of an oracle: 5.0 when edible, 0.0 otherwise."""
        return float(edible) * 5


## Agent

In [None]:
class Agent:
    """Epsilon-greedy bandit agent driven by a learned reward model.

    The model scores a context concatenated with a two-element action
    encoding (`[1, 0]` for eating, `[0, 1]` for not eating). Observed
    (context-action, reward) pairs are kept in fixed-size buffers that are
    overwritten cyclically and replayed when training the model.
    """

    def __init__(
            self, model, optimizer, batch_size,
            eps=0.0, buffer_size=4096, context_size=117,
        ):
        """Set up the model, replay buffers and bookkeeping counters.

        Args:
            model: callable reward model taking a context-action vector.
            optimizer: optimizer updating `model`'s parameters.
            batch_size: mini-batch size used in `train`.
            eps: probability of replacing the greedy action by a coin flip.
            buffer_size: number of slots in the replay buffers.
            context_size: length of a context vector (the action encoding
                adds two more entries).
        """
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size

        self.eps = eps
        print(f"eps={self.eps}")

        # Replay storage: one row per remembered interaction.
        self.buffer_size = buffer_size
        self.context_action_buffer = torch.zeros(buffer_size, context_size + 2)
        self.reward_buffer = torch.zeros(buffer_size)

        # Running confusion-matrix counts of the decisions taken so far ...
        self.tp = self.tn = self.fp = self.fn = 0
        # ... and their value after every recorded interaction.
        self.tps, self.tns, self.fps, self.fns = [], [], [], []

        # Every eat / not-eat decision, in order.
        self.decisions = []

    def step(self, mushroom_context):
        """Decide whether to eat the mushroom described by `mushroom_context`.

        The greedy choice eats when the predicted reward for eating is
        strictly larger than for not eating. With probability `eps` the
        choice is replaced by a fair coin flip.
        """
        with_eat = torch.hstack([mushroom_context, torch.tensor([1, 0])])
        without_eat = torch.hstack([mushroom_context, torch.tensor([0, 1])])

        self.model.eval()
        with torch.no_grad():
            eat_reward = self.model(with_eat).item()
            not_eat_reward = self.model(without_eat).item()

        greedy_choice = eat_reward > not_eat_reward

        explore = np.random.rand() <= self.eps
        if not explore:
            return greedy_choice
        return np.random.rand() <= 0.5

    def update_buffers(self, context_action_pair, reward, step, edible, eat):
        """Store one interaction and update the decision statistics.

        The interaction is written to slot `step % buffer_size`, so the
        oldest entries are overwritten once the buffers are full.
        """
        slot = step % self.buffer_size
        self.context_action_buffer[slot, :] = context_action_pair
        self.reward_buffer[slot] = reward

        # Classify the decision: "positive" means the mushroom was eaten.
        if edible:
            if eat:
                self.tp += 1
            else:
                self.fn += 1
        elif eat:
            self.fp += 1
        else:
            self.tn += 1

        for history, count in (
            (self.tps, self.tp),
            (self.tns, self.tn),
            (self.fps, self.fp),
            (self.fns, self.fn),
        ):
            history.append(count)

        self.decisions.append(eat)

    def _get_training_dataloader(self, step):
        """Build a shuffled DataLoader over the interactions stored so far.

        Only the first `min(step, buffer_size)` slots are considered. While
        the buffers are not yet full, those entries are resampled with
        replacement to `buffer_size` rows, so every call yields a dataset of
        the same size.
        """
        n_filled = min(step, self.buffer_size)

        pairs = self.context_action_buffer[:n_filled, :]
        rewards = self.reward_buffer[:n_filled]

        if n_filled < self.buffer_size:
            picks = torch.randint(
                high=pairs.shape[0], size=(self.buffer_size, ),
            )
            pairs, rewards = pairs[picks], rewards[picks]

        return DataLoader(
            dataset=TensorDataset(pairs, rewards),
            batch_size=self.batch_size,
            shuffle=True, drop_last=False, num_workers=0,
        )

    def train(self, step):
        """Run one pass over the replay data and return the mean MSE loss.

        The returned value is the per-sample average of the batch losses
        (each batch loss is weighted by its batch length).
        """
        loader = self._get_training_dataloader(step=step)

        self.model.train()
        total_loss = 0
        for inputs, targets in loader:
            predictions = self.model(inputs).squeeze(dim=-1)
            batch_loss = F.mse_loss(input=predictions, target=targets)

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            total_loss += batch_loss.item() * len(targets)

        return total_loss / len(loader.dataset)

## Trainer

In [None]:
def train_rl(n_steps, environment, agent, name, scheduler=None):
    """Run the online bandit loop and train the agent after every step.

    At each step a mushroom is drawn from `environment`, the agent decides
    whether to eat it, the interaction is stored in the agent's buffers, the
    regret (oracle reward minus agent reward) is accumulated and the agent
    is trained once.

    About 50 times per run, progress is printed, the regret/loss history is
    written to `{name}.csv` and a checkpoint is saved to
    `checkpoint_{step}.ckpt`.

    Args:
        n_steps: number of interactions to run.
        environment: object providing `get_random_mushroom`,
            `get_agent_reward` and `get_oracle_reward`.
        agent: object providing `step`, `update_buffers` and `train`, plus
            the `tp`/`tn`/`fp`/`fn`, `model` and `optimizer` attributes.
        name: stem of the CSV file the history is written to.
        scheduler: optional learning-rate scheduler, stepped once per
            interaction.

    Returns:
        `(cumulative_regrets, losses)`, each with one entry per step.
    """
    cumulative_regrets = [0]
    losses = []

    # Report/checkpoint roughly 50 times per run. Clamp the interval to at
    # least 1: for n_steps < 50 the integer division yields 0 and the modulo
    # below would raise ZeroDivisionError.
    report_every = max(1, n_steps // 50)

    loop = tqdm(range(1, n_steps+1), total=n_steps, leave=False)
    for step in loop:
        # Get a new mushroom
        mushroom_context, edible = environment.get_random_mushroom()

        # Decide whether to eat it
        eat = agent.step(mushroom_context=mushroom_context)

        # Calculate the agent's reward and the oracle's reward
        agent_reward = environment.get_agent_reward(edible=edible, eaten=eat)
        oracle_reward = environment.get_oracle_reward(edible=edible)

        # Update the buffers; the buffer slot index is 0-based, hence step-1
        action = torch.Tensor([1, 0] if eat else [0, 1])
        agent.update_buffers(
            context_action_pair=torch.hstack([mushroom_context, action]),
            reward=agent_reward,
            step=step-1,
            edible=edible,
            eat=eat,
        )

        # Calculate regret
        regret = oracle_reward - agent_reward
        cumulative_regrets.append(cumulative_regrets[-1] + regret)

        loss = agent.train(step=step)
        losses.append(loss)

        if scheduler is not None:
            scheduler.step()
            loop.set_postfix(
                cumulative_regret=cumulative_regrets[-1], loss=loss,
                # Public accessor instead of the private `_last_lr` attribute.
                lr=scheduler.get_last_lr()[0]
            )
        else:
            loop.set_postfix(
                cumulative_regret=cumulative_regrets[-1], loss=loss,
            )

        if step % report_every == 0:
            print(f"Step {step}: regret: {cumulative_regrets[-1]} ({len(cumulative_regrets[1:])})")
            print(f"TP: {agent.tp}, TN: {agent.tn}, FP: {agent.fp}, FN: {agent.fn}")
            df = pd.DataFrame.from_dict({
                "cumulative_regrets": cumulative_regrets[1:],
                "losses": losses
            })
            df.to_csv(f"{name}.csv")

            # NOTE: the whole agent object is pickled alongside the state
            # dicts, so loading a checkpoint needs the agent's class to be
            # importable.
            save_dict = {
                "model": agent.model.state_dict(),
                "optimizer": agent.optimizer.state_dict(),
                "agent": agent,
                "cumulative_regrets": cumulative_regrets,
            }
            if scheduler is not None:
                save_dict["scheduler"] = scheduler.state_dict()
            torch.save(save_dict, f"checkpoint_{step}.ckpt")

    return cumulative_regrets[1:], losses

## Run

In [None]:
def init_weights(m):
    """Initialise an `nn.Linear` module in place; ignore any other module.

    Weights are drawn with Kaiming-normal initialisation for ReLU and the
    bias, when the layer has one, is set to zero. Intended to be passed to
    `Module.apply`.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    # Layers built with bias=False have `bias is None`; skip them instead of
    # failing with AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0)

In [None]:
%%time

# Wrap the encoded dataset in the bandit environment.
environment = Environment(contexts=contexts, labels=labels)

# Reward model: a small MLP whose input is a context vector plus the
# two-element action encoding and whose output is one predicted reward.
model = nn.Sequential(
    nn.Linear(in_features=contexts.shape[1] + 2, out_features=100), nn.ReLU(),
    nn.Linear(in_features=100, out_features=100), nn.ReLU(),
    nn.Linear(in_features=100, out_features=1),
)
# Re-initialise every linear layer with the notebook's init_weights scheme.
model.apply(init_weights)
print(f"Number of model parameters: {sum(p.nelement() for p in model.parameters())}")

# SGD with Nesterov momentum at the configured learning rate.
optimizer = torch.optim.SGD(
    params=model.parameters(), lr=LEARNING_RATE,
    momentum=0.9, nesterov=True,
)
print(optimizer)

# No learning-rate schedule is used in this run; the StepLR line below is
# kept as an optional alternative.
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.5)
scheduler = None

# context_size is taken from the data so the buffers match the encoding.
agent = Agent(
    model=model, optimizer=optimizer,
    batch_size=BATCH_SIZE,
    eps=EPS, buffer_size=BUFFER_SIZE,
    context_size=contexts.shape[1],
)

# Run the bandit loop; `name` determines the CSV file the history goes to.
print("Training")
cumulative_regrets, losses = train_rl(
    n_steps=N_TRAINING_STEPS, environment=environment,
    agent=agent, scheduler=scheduler,
    name=f"greedy_eps_{EPS}",
)

In [None]:
# Peek at the last ten cumulative-regret values of the run.
cumulative_regrets[-10:]

In [None]:
# Cumulative regret against the 1-based step index, on a log-scaled y-axis.
regret_fig, regret_ax = plt.subplots()
regret_ax.plot(range(1, len(cumulative_regrets) + 1), cumulative_regrets)
regret_ax.set_yscale("log")
regret_ax.grid()
plt.show()

In [None]:
# Training loss per step, drawn as a line with a marker at every point.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(range(1, len(losses) + 1), losses, "-o")
loss_ax.grid()