## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import urllib

from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributions as D
from torch.utils.data import TensorDataset, DataLoader

from tqdm.notebook import tqdm

## Config

In [3]:
BUFFER_SIZE = 4096
BATCH_SIZE = 64
LEARNING_RATE = 5e-6
N_TRAINING_STEPS = 18000
N_SAMPLES = 2

## Data

In [4]:
urllib.request.urlretrieve(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data",
    "agaricus-lepiota.data"
)

('agaricus-lepiota.data', <http.client.HTTPMessage at 0x7f35f8acedd0>)

In [5]:
df = pd.read_csv("agaricus-lepiota.data", header=None)

# Find the labels
labels = df.pop(df.columns[0])
labels = pd.Categorical(labels, categories=["p", "e"]).codes
labels

# Get the contexts
for col in df:
    df[col] = pd.Categorical(df[col]).codes
contexts = OneHotEncoder(sparse=False, dtype=np.float32).fit_transform(df)

# Convert to torch tensors
contexts = torch.tensor(contexts)
labels = torch.tensor(labels, dtype=bool)

contexts, labels

(tensor([[0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([False,  True,  True,  ...,  True, False,  True]))

In [6]:
contexts.shape, labels.shape

(torch.Size([8124, 117]), torch.Size([8124]))

In [7]:
contexts.dtype, labels.dtype

(torch.float32, torch.bool)

## BNN

In [8]:
def softplus_inverse(x):
    '''
        Computes the inverse of softplus f(x) = log(exp(x) - 1) in a numerically stable way.
    '''
    return x + torch.log(-torch.expm1(-x))


class VariationalLinear(nn.Module):
    def __init__(
        self,
        in_features:int, out_features:int, prior_distribution, bias=True,
        nonlinearity="relu", param=None,
    ):
        '''
            Args:
                prior:
                    the prior to be used.
                nonlinearity:
                    the nonlinearity that will follow the linear layer. This will be used to 
                    calculate the gain required to properly initialize the weights.
                    For more information see
                    https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain.
                    Default value is "relu".
                param:
                    any parameters needed to be passed for the function that calculates the gain
                    (see also the "nonlinearity" parameter).
        '''


        super().__init__()

        self.bias = bias

        # Calculate gain
        gain = torch.nn.init.calculate_gain(nonlinearity=nonlinearity, param=param)
        # Transform the gain according to the parameterization for sigma in the paper
        scale = softplus_inverse(gain / torch.sqrt(torch.tensor(in_features)))

        # Note that initialization is so that initially the sampled weights follow
        # N(0, gain ** 2 / in_features) distribution. This helps retain output distribution to
        # be closer to standard normal
        self.mu_weights = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.2, 0.2))
        self.rho_weights = nn.Parameter(torch.empty(out_features, in_features).uniform_(-5, -4))

        if bias:
            # Same initialization as above
            self.mu_bias = nn.Parameter(torch.empty(out_features).uniform_(-0.2, 0.2))
            self.rho_bias = nn.Parameter(torch.empty(out_features).uniform_(-5, -4))
        
        self.prior_distribution = prior_distribution


    def forward(self, x:torch.Tensor, prune_weights=False, pruning_threshold=0.0):
        '''
            Args:
                x: input tensor of size (batch_size, in_features)
            Output:
                tuple (tensor, scalar tensor):
                    - the first element is tensor of logits of the model for the input x. The shape is (batch_size, out_features).
                    - the second element is the KL divergence of the model weights sampled
                      in the forward pass. It is a 0-dim tensor containing only one scalar.
        '''
        
        kl_divergence = 0

        # Calculate W
        sigma_weights = F.softplus(self.rho_weights)
        weight_distribution = D.Normal(
            loc=self.mu_weights, scale=sigma_weights
        )
        W = weight_distribution.rsample()
        if prune_weights:
            snr = self.mu_weights.abs() / sigma_weights
            mask = snr <= pruning_threshold
            W[mask] = 0

        # Calculate weight contribution to KL divergence
        kl_divergence += weight_distribution.log_prob(W).sum()
        kl_divergence -= self.prior_distribution.log_prob(W).sum()

        # Multiply input by W
        out = torch.mm(x, W.T)

        # Handle bias
        if self.bias:
            sigma_bias = F.softplus(self.rho_bias)
            bias_distribution = D.Normal(
                loc=self.mu_bias, scale=sigma_bias
            )
            b = bias_distribution.rsample()
            if prune_weights:
                snr = self.mu_bias.abs() / sigma_bias
                mask = snr <= pruning_threshold
                b[mask] = 0

            # Add the bias
            out += b

            # Calculate bias contribution to KL divergence
            kl_divergence += bias_distribution.log_prob(b).sum()
            kl_divergence -= self.prior_distribution.log_prob(b).sum()

        return out, kl_divergence

In [9]:
class VariationalMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, prior_distribution):
        super().__init__()

        self.prior_distribution = prior_distribution

        self.layers = nn.ModuleList([
            VariationalLinear(
                in_features=input_dim, out_features=hidden_dim,
                prior_distribution=prior_distribution
            ),
            VariationalLinear(
                in_features=hidden_dim, out_features=hidden_dim,
                prior_distribution=prior_distribution
            ),
            VariationalLinear(
                in_features=hidden_dim, out_features=output_dim,
                prior_distribution=prior_distribution
            ),
        ])
    
    def _single_forward(self, x, prune_weights, pruning_threshold):
        total_kl_divergence = 0
        for i, layer in enumerate(self.layers):
            x, kl_divergence = layer(
                x, prune_weights=prune_weights, pruning_threshold=pruning_threshold
            )

            if i < len(self.layers) - 1:
                x = F.relu(x)

            total_kl_divergence += kl_divergence

        x = x.unsqueeze(dim=1)

        return x, total_kl_divergence
    

    def forward(self, x, n_samples=1, prune_weights=False, pruning_threshold=0.0):

        logits = []
        kl_divergence = 0
        
        for _ in range(n_samples):
            curr_logits, curr_kl_divergence = self._single_forward(
                x, prune_weights=prune_weights, pruning_threshold=pruning_threshold
            )
            logits.append(curr_logits)
            kl_divergence += curr_kl_divergence

        logits = torch.cat(logits, axis=1)
        kl_divergence /= n_samples
    
        return logits, kl_divergence

In [10]:
class RegressionELBO(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, outputs, labels, kl_divergence, kl_weight):
        nll = self._get_neg_log_lik(y_pred=outputs, y_true=labels)

        elbo = kl_weight * kl_divergence + nll
        
        return elbo, nll

    
    def _get_neg_log_lik(self, y_true, y_pred):
        batched_nll = (y_pred - y_true.unsqueeze(-1))**2 / 2
        
        return batched_nll.sum(dim=0).mean(dim=0)


## Environment

In [11]:
class Environment:

    def __init__(self, contexts, labels):
        self.contexts = contexts
        self.labels = labels
    
    def get_random_mushroom(self):
        mushroom_idx = np.random.randint(len(self.contexts))

        return self.contexts[mushroom_idx], self.labels[mushroom_idx].item()

    def get_agent_reward(self, edible, eaten):
        if not eaten:
            return 0
        if edible:
            return 5
        if torch.rand(1).item() <= 0.5:
            return -35
        return 5
    
    def get_oracle_reward(self, edible):
        return 5 * float(edible)


## Agent

In [12]:
def get_kl_weight(M, batch_index=-1, uniform_kl_weight=True):
    '''
        M: number of batches
    '''

    kl_weight = 1 / M
    if not uniform_kl_weight:
        if batch_index == -1:
            raise Exception("Batch Index Not specified while getting Loss")

        # The batch_index + 1 is because we should be counting the batches
        # from 1 to M, not from 0 to M-1
        kl_weight = 2**(M - (batch_index + 1)) / (2**M - 1)

    return kl_weight

In [13]:
class AgentBNN:

    def __init__(
            self, model, optimizer, batch_size, n_samples=2,
            uniform_kl_weight=False, buffer_size=4096, context_size=117,
        ):
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.n_samples = n_samples
        self.uniform_kl_weight = uniform_kl_weight
        print(f"uniform_kl_weight={self.uniform_kl_weight}")
        
        self.criterion = RegressionELBO()
        
        self.buffer_size = buffer_size
        self.context_action_buffer = torch.zeros(buffer_size, context_size + 2)
        self.reward_buffer = torch.zeros(buffer_size)
        
        self.tp = 0
        self.tn = 0
        self.fp = 0
        self.fn = 0
        
        self.tps = []
        self.tns = []
        self.fps = []
        self.fns = []
        
        self.decisions = []


    def step(self, mushroom_context):
        eat_action = torch.hstack([mushroom_context, torch.tensor([1, 0])]).view(1, -1)
        not_eat_action = torch.hstack([mushroom_context, torch.tensor([0, 1])]).view(1, -1)

        self.model.eval()
        with torch.no_grad():
            eat_reward = self.model(eat_action, n_samples=self.n_samples)[0].mean().item()
            not_eat_reward = self.model(not_eat_action, n_samples=self.n_samples)[0].mean().item()
        
        # Action is to eat or not
        return eat_reward > not_eat_reward
    

    def update_buffers(self, context_action_pair, reward, step, edible, eat):
        new_idx = step % self.buffer_size
        self.context_action_buffer[new_idx, :] = context_action_pair
        self.reward_buffer[new_idx] = reward
        
        # record bandit action
        if edible and eat:
            self.tp += 1
        elif edible and not eat:
            self.fn += 1
        elif not edible and eat:
            self.fp += 1
        else:
            self.tn += 1
        
        self.tps.append(self.tp)
        self.tns.append(self.tn)
        self.fps.append(self.fp)
        self.fns.append(self.fn)
        
        self.decisions.append(eat)


    def _get_training_dataloader(self, step):
        max_idx = min(step, self.buffer_size)
        
        training_context_action_pairs = self.context_action_buffer[:max_idx, :]
        training_rewards = self.reward_buffer[:max_idx]

        if max_idx < self.buffer_size:
            indices = torch.randint(
                high=training_context_action_pairs.shape[0], size=(self.buffer_size, ),
            )
            training_context_action_pairs = training_context_action_pairs[indices]
            training_rewards = training_rewards[indices]

        dataset = TensorDataset(training_context_action_pairs, training_rewards)

        return DataLoader(
            dataset=dataset, batch_size=self.batch_size,
            shuffle=True, drop_last=False, num_workers=0,
        )


    def train(self, step):
        dataloader = self._get_training_dataloader(step=step)

        running_elbo = 0
        running_kl = 0
        running_nll = 0
        self.model.train()
        for batch_idx, (x, y) in enumerate(dataloader):
            predicted_rewards, kl_divergence = self.model(x, n_samples=self.n_samples)

            kl_weight = get_kl_weight(
                M=len(dataloader.dataset) // self.batch_size,
                batch_index=batch_idx,
                uniform_kl_weight=self.uniform_kl_weight,
            )
            elbo, nll = self.criterion(
                outputs=predicted_rewards, labels=y.view(-1, 1),
                kl_divergence=kl_divergence, kl_weight=kl_weight,
            )

            self.optimizer.zero_grad()
            elbo.backward()
            self.optimizer.step()

            running_elbo += elbo.item()
            running_kl += (elbo - nll).item()
            running_nll += nll.item()

        return running_elbo, running_kl, running_nll

## Trainer

In [26]:
def train_rl(n_steps, environment, agent, name, scheduler=None, checkpoint=None):
    print(f"n_steps={n_steps}")

    cumulative_regrets = [0]
    elbos = []
    kl_divergences = []
    nlls = []
    
    if checkpoint is not None:
        print(f"Loading checkpoint: {checkpoint}")
        checkpoint = torch.load(checkpoint)

        agent.model.load_state_dict(checkpoint['model'])
        agent.optimizer.load_state_dict(checkpoint['optimizer'])

        agent = checkpoint["agent"]
        cumulative_regrets = checkpoint["cumulative_regrets"]
        elbos = checkpoint["elbos"]
        kl_divergences = checkpoint["kl_divergences"]
        nlls = checkpoint["nlls"]
        
        if len(cumulative_regrets) % 10 == 0:
            cumulative_regrets = [0] + cumulative_regrets
    

    loop = tqdm(range(1, n_steps+1), total=n_steps, leave=False)
    for step in loop:
        # Get a new mushroom
        mushroom_context, edible = environment.get_random_mushroom()

        # Decide whether to eat it
        eat = agent.step(mushroom_context=mushroom_context)
    
        # Calculate the different reward
        agent_reward = environment.get_agent_reward(edible=edible, eaten=eat)
        oracle_reward = environment.get_oracle_reward(edible=edible)

        # Update the buffers
        action = torch.Tensor([1, 0] if eat else [0, 1])
        agent.update_buffers(
            context_action_pair=torch.hstack([mushroom_context, action]),
            reward=agent_reward,
            step=step-1,
            edible=edible, eat=eat,
        )

        # Calculate regret
        regret = oracle_reward - agent_reward
        cumulative_regrets.append(cumulative_regrets[-1] + regret)

        elbo, kl, nll = agent.train(step=step)
        elbos.append(elbo)
        kl_divergences.append(kl)
        nlls.append(nll)
        
        loop.set_postfix(
            cumulative_regret=cumulative_regrets[-1], elbo=elbo, kl=kl, nll=nll,
        )
        
        if scheduler is not None:
            scheduler.step()
        
        # Checkpoint
        if step % 1000 == 0:
            print(f"Step {step}: regret: {cumulative_regrets[-1]} ({len(cumulative_regrets[1:])})")
            print(f"TP: {agent.tp}, TN: {agent.tn}, FP: {agent.fp}, FN: {agent.fn}")
            if scheduler is not None:
                print(f"LR: {scheduler.get_last_lr()[0]}")
            df = pd.DataFrame.from_dict({
                "cumulative_regret": cumulative_regrets[1:],
                "elbo": elbos,
                "kl_divergence": kl_divergences,
                "nll": nlls,
            })
            df.to_csv(f"{name}.csv")
        
            save_dict = {
                "model": agent.model.state_dict(),
                "optimizer": agent.optimizer.state_dict(),
                "agent": agent,
                "cumulative_regrets": cumulative_regrets,
                "elbos": elbos,
                "kl_divergences": kl_divergences,
                "nlls": nlls,
            }
            if scheduler is not None: save_dict["scheduler"] = scheduler.state_dict()
            torch.save(save_dict, f"checkpoint_{step}.ckpt")
    
    return cumulative_regrets[1:], elbos, kl_divergences, nlls

## Run

In [27]:
%%time

# Prior
sigma_1 = torch.exp(-torch.tensor(0))
sigma_2 = torch.exp(-torch.tensor(6))

p = 1/2
mixture_distribution = D.Categorical(probs=torch.tensor([p, 1 - p]))
component_distribution = D.Normal(
    loc=torch.zeros(2),
    scale=torch.tensor([sigma_1, sigma_2]),
)
prior_distribution = D.MixtureSameFamily(
    mixture_distribution=mixture_distribution, component_distribution=component_distribution
)


# Model
model = VariationalMLP(
    input_dim=contexts.shape[1] + 2,
    hidden_dim=100, output_dim=1,
    prior_distribution=prior_distribution
)
print(f"Number of model parameters: {sum(p.nelement() for p in model.parameters())}")


# Optimizer
optimizer = torch.optim.Adam(
    params=model.parameters(), lr=LEARNING_RATE,
)
print(optimizer)


# Scheduler
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.5)
# def epoch_to_factor(epoch):
#     if epoch <= 5000: return 1
#     if epoch <= 25000: return 0.1
#     return 1/20
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=epoch_to_factor)
scheduler = None
print(f"scheduler={scheduler}")


environment = Environment(contexts=contexts, labels=labels)

agent = AgentBNN(
    model=model, optimizer=optimizer,
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
    context_size=contexts.shape[1],
    n_samples=N_SAMPLES,
    uniform_kl_weight=True,
)

cumulative_regrets, elbos, kl_divergences, nlls = train_rl(
    n_steps=N_TRAINING_STEPS,
    environment=environment,
    agent=agent,
    name=f"bnn",
    scheduler=scheduler,
    checkpoint="/kaggle/input/aml-rl-bnn-checpoint-32000-v2/checkpoint_32000.ckpt",
)

Number of model parameters: 44402
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5e-06
    maximize: False
    weight_decay: 0
)
scheduler=None
uniform_kl_weight=True
n_steps=18000
Loading checkpoint: /kaggle/input/aml-rl-bnn-checpoint-32000-v2/checkpoint_32000.ckpt


  0%|          | 0/18000 [00:00<?, ?it/s]

Step 10: regret: 3965.0 (32010)
TP: 16502, TN: 15195, FP: 196, FN: 117


KeyboardInterrupt: 

In [16]:
# checkpoint = torch.load("/kaggle/input/aml-rl-bnn-checkpoint-32000/checkpoint_32000.ckpt")

# model.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])

# agent = checkpoint["agent"]
# cumulative_regrets = checkpoint["cumulative_regrets"][1:]
# elbos = checkpoint["elbos"]
# kl_divergences = checkpoint["kl_divergences"]
# nlls = checkpoint["nlls"]

In [17]:
len(cumulative_regrets)

NameError: name 'cumulative_regrets' is not defined

In [None]:
cumulative_regrets[-10:]

In [None]:
plt.plot(range(1, len(cumulative_regrets) + 1), cumulative_regrets)
plt.yscale("log")
plt.grid()
plt.show()

In [None]:
plt.plot(range(1, len(elbos) + 1), elbos, "-o")
plt.grid()

In [None]:
plt.plot(range(1, len(kl_divergences) + 1), kl_divergences, "-o")
plt.grid()

In [None]:
plt.plot(range(1, len(nlls) + 1), nlls, "-o")
plt.grid()