<a href="https://colab.research.google.com/github/probml/bandits/blob/main/bandits/scripts/subspace_bandits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bayesian Subspace bandits

See  https://arxiv.org/abs/2112.00195 for details.


## Installation

In [1]:
# !git clone --depth 1 https://github.com/probml/bandits

In [1]:
import matplotlib.pyplot as plt

In [2]:
!pip install --upgrade jax jaxlib # CPU-only
!pip3 install fire
!pip3 install ml-collections
!pip3 install git+https://github.com/deepmind/optax.git
!pip3 install flax
!pip3 install --upgrade git+https://github.com/google/flax.git
#!pip3 install -qqq git+git://github.com/deepmind/optax.git
#!pip3 install -qqq --upgrade git+https://github.com/google/flax.git

Collecting git+https://github.com/deepmind/optax.git
  Cloning https://github.com/deepmind/optax.git to /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-ybinu26u
  Running command git clone -q https://github.com/deepmind/optax.git /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-ybinu26u
  Resolved https://github.com/deepmind/optax.git to commit ebe7168f064c44924ac124735a6c9490fbd359ed
Collecting git+https://github.com/google/flax.git
  Cloning https://github.com/google/flax.git to /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-u9z7im_1
  Running command git clone -q https://github.com/google/flax.git /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-u9z7im_1
  Resolved https://github.com/google/flax.git to commit fec10eb643b68527bcb4a4e2b67de8649301e03e


In [3]:
!pip3 install --upgrade git+https://github.com/google/flax.git
!pip3 install --upgrade tensorflow-probability
!pip3 install git+https://github.com/blackjax-devs/blackjax.git
!pip3 install git+https://github.com/deepmind/distrax.git
!pip3 install blackjax

Collecting git+https://github.com/google/flax.git
  Cloning https://github.com/google/flax.git to /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-dpc93b0v
  Running command git clone -q https://github.com/google/flax.git /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-dpc93b0v
  Resolved https://github.com/google/flax.git to commit 4ba1d3416f4ccbada0b1e157e59b9763469ac699
Collecting git+https://github.com/blackjax-devs/blackjax.git
  Cloning https://github.com/blackjax-devs/blackjax.git to /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-fxd_r_85
  Running command git clone -q https://github.com/blackjax-devs/blackjax.git /private/var/folders/9t/286b6px10md3kt42t9bshfr40000gn/T/pip-req-build-fxd_r_85
  Resolved https://github.com/blackjax-devs/blackjax.git to commit 28167ad35fb880f7b90d0a78d8324f7fda8032e2
Collecting git+https://github.com/deepmind/distrax.git
  Cloning https://github.com/deepmind/distrax.git to /priva

## Test the installatation

In [2]:
%%bash
cd /Users/dmitrisaberi/Documents/GitHub/bandits
python bandits test

Expected Reward : 4420.50 ± 14.35
Time : 7.601s


## Setup 

In [3]:
%cd /content/bandits/bandits/experiments

[Errno 2] No such file or directory: '/content/bandits/bandits/experiments'
/Users/dmitrisaberi/Documents/GitHub/bandits/bandits/scripts


In [4]:
import os
os.chdir("..")

import jax
import ml_collections

import pandas as pd

import glob
from datetime import datetime

import scripts.movielens_exp as movielens_run
import scripts.mnist_exp as mnist_run
import scripts.tabular_exp as tabular_run
import scripts.tabular_subspace_exp as tabular_sub_run

print(jax.device_count())

1


In [5]:
def get_config(results_filename):
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()
  config.filepath = results_filename
  config.ntrials = 2 # was 10 in paper
  return config

In [6]:
timestamp = datetime.timestamp(datetime.now())

In [7]:
def plot_figure(data, x, y, filename, figsize=(24, 9), log_scale=False):   
    sns.set(font_scale=1.5)
    plt.style.use("seaborn-poster")

    fig, ax = plt.subplots(figsize=figsize, dpi=300)
    g = sns.barplot(x=x, y=y, hue="Method", data=data, errwidth=2, ax=ax, palette=colors)
    if log_scale:
        g.set_yscale("log")
    plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))
    plt.tight_layout()
    plt.savefig(f"./figures/{filename}.png")
    plt.show()

def read_data(dataset_name):
    *_, filename = sorted(glob.glob(f"./results/{dataset_name}_results*.csv"))
    df = pd.read_csv(filename)
    if dataset_name=="mnist":
        linear_df = df[(df["Method"]=="Lin-KF") | (df["Method"]=="Lin")].copy()
        linear_df["Model"] = "MLP2"
        df = df.append(linear_df)
        linear_df["Model"] = "LeNet5"
        df = df.append(linear_df)

    by = ["Rank"] if dataset_name=="tabular" else ["Rank", "AltRank"]

    data_up = df.sort_values(by=by).copy()
    data_down = df.sort_values(by=by).copy()

    data_up["Reward"] = data_up["Reward"] + data_up["Std"]
    data_down["Reward"] = data_down["Reward"] - data_down["Std"]
    data = pd.concat([data_up, data_down])
    return data

def plot_subspace_figure(df, filename=None):
    df = df.reset_index().drop(columns=["index"])
    plt.style.use("seaborn-darkgrid")
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.lineplot(x="Subspace Dim", y="Reward", hue="Method", marker="o", data=df)
    lines, labels = ax.get_legend_handles_labels()
    for line, method in zip(lines, labels):
        data = df[df["Method"]==method]
        color = line.get_c()
        y_lower_bound =  data["Reward"] -  data["Std"]
        y_upper_bound = data["Reward"] + data["Std"]
        ax.fill_between(data["Subspace Dim"],  y_lower_bound, y_upper_bound, color=color, alpha=0.3)

    ax.set_ylabel("Reward", fontsize=16)
    plt.setp(ax.get_xticklabels(), fontsize=16) 
    plt.setp(ax.get_yticklabels(), fontsize=16) 
    ax.set_xlabel("Subspace Dimension(d)", fontsize=16)
    dataset = df.iloc[0]["Dataset"]
    ax.set_title(f"{dataset.title()} - Subspace Dim vs. Reward", fontsize=18)
    legend = ax.legend(loc="lower right", prop={'size': 16},frameon=1)
    frame = legend.get_frame()
    frame.set_color('white')
    frame.set_alpha(0.6)
    
    file_path = "./figures/"
    file_path = file_path + f"{dataset}_sub_reward.png" if filename is None else file_path + f"{filename}.png"
    plt.savefig(file_path)

# Run tabular experiments (not using this!)

In [8]:
%pwd
%cd /content/bandits/bandits
tabular_filename = f"./results/tabular_results_{timestamp}.csv"
config = get_config(tabular_filename)
tabular_run.main(config)
dataset_name = "tabular"
tabular_df = read_data(dataset_name)
tabular_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND',
                'EKF-Orig-Full',  'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin', 'Lim2', 'NeuralTS']
tabular_df = tabular_df[tabular_df['Method'].isin(tabular_rows)]
x, y = "Dataset", "Reward"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(tabular_df, x, y, filename)
x, y = "Dataset", "Time"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(tabular_df[tabular_df["Method"] != "NeuralTS"], x, y, filename, log_scale=True)

[Errno 2] No such file or directory: '/content/bandits/bandits'
/Users/dmitrisaberi/Documents/GitHub/bandits/bandits/scripts


FileNotFoundError: [Errno 2] No such file or directory: './bandit-data/bandit-statlog.pkl'

# Run movielens experiments

In [7]:
movielens_filename = f"./results/movielens_results_{timestamp}.csv"
config = get_config(movielens_filename)
movielens_run.main(config)

Bandit : Linear
Expected Reward : 2665.50 ± 16.50
	Time : 7.686638116836548:0.3f
Bandit : BNN
Expected Reward : 2078.50 ± 145.50
	Time : 38.232105016708374:0.3f
Bandit : Linear
Expected Reward : 2665.50 ± 16.50
	Time : 6.169766902923584:0.3f
Bandit : BNN
Expected Reward : 2078.50 ± 145.50
	Time : 51.37640404701233:0.3f


KeyError: 'BNN'

In [6]:
dataset_name = "movielens"
movielens_df = read_data(dataset_name)
movielens_rows =  ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND',
                   'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin']
movielens_df = movielens_df[movielens_df['Method'].isin(movielens_rows)]

In [None]:
x, y = "Model", "Reward"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(movielens_df, x, y, filename)

In [None]:
x, y = "Model", "Time"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(movielens_df, x, y, filename)

# Run MNIST experiments

In [None]:
mnist_filename = f"./results/mnist_results_{timestamp}.csv"
config = get_config(mnist_filename)
mnist_run.main(config)

In [None]:
method_ordering = {"EKF-Sub-SVD": 0,
                   "EKF-Sub-RND": 1,
                   "EKF-Sub-Diag-SVD": 2,
                   "EKF-Sub-Diag-RND": 3,
                   "EKF-Orig-Full": 4,
                   "EKF-Orig-Diag": 5,
                   "NL-Lim": 6,
                   "NL-Unlim": 7,
                   "Lin": 8,
                   "Lin-KF": 9,
                   "Lin-Wide": 9,
                   "Lim2": 10,
                   "NeuralTS": 11}
                   
colors = {k : sns.color_palette("Paired")[v]
          if k!="Lin-KF" else  sns.color_palette("tab20")[8]
          for k,v in method_ordering.items()}

In [None]:
dataset_name = "mnist"
# For possible methods, run mnist_df.Method.unique()
mnist_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND', 'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin']

In [None]:
mnist_df = read_data(dataset_name)
mnist_df = mnist_df[mnist_df['Method'].isin(mnist_rows)]

In [None]:
x, y = "Model", "Reward"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(mnist_df, x, y, filename)

In [None]:
x, y = "Model", "Time"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(mnist_df, x, y, filename, log_scale=True)

# Run tabular subspace experiment

In [None]:
tabular_sub_filename = f"./results/tabular_subspace_results_{timestamp}.csv"
config = get_config(tabular_sub_filename)
tabular_sub_run.main(config)

In [None]:
*_, filename = sorted(glob.glob(f"./results/tabular_subspace_results*.csv"))
tabular_sub_df = pd.read_csv(filename)

In [None]:
dataset_name = "shuttle"
shuttle = tabular_sub_df[tabular_sub_df["Dataset"]==dataset_name]
plot_subspace_figure(shuttle)

In [None]:
dataset_name = "adult"
adult = tabular_sub_df[tabular_sub_df["Dataset"]==dataset_name]
plot_subspace_figure(adult)

In [None]:
dataset_name = "covertype"
covertype = tabular_sub_df[tabular_sub_df["Dataset"]==dataset_name]
plot_subspace_figure(covertype)