In [None]:
from collections import OrderedDict
import logging
import os
import sys
sys.path.append(os.path.join(os.getcwd(), os.pardir))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import PolynomialFeatures
from statsmodels.tsa.api import SimpleExpSmoothing

from bandits.arms import GaussianMixtureArm
from bandits.context import Context
from bandits.banditPlayer import BanditPlayer
from bandits.banditLearner import (SGDLearner, XGBLearner, OptimisticSGDLearner, AdaptiveRandomForestLearner, 
                                   PerceptronLearner, BaggedLinearRegressor ,LinearExpertsLearner)

In [None]:
# Configure the root logger so that records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)

In [None]:
# Corner points of the [-1, 1] x [-1, 1] square, reading each pair as (x, y).
_top_left, _top_right = [-1, 1], [1, 1]
_bottom_left, _bottom_right = [-1, -1], [1, -1]

# Each layout is a 2x2 integer array holding two of those corners (one per row).
# They are used as `centres` for the arms and as start/stop targets of the
# drift schedules; the names describe where the two points sit on the square.
diag_down = np.array([_top_left, _bottom_right])
diag_up = np.array([_bottom_left, _top_right])
left = np.array([_bottom_left, _top_left])
right = np.array([_bottom_right, _top_right])
top = np.array([_top_left, _top_right])
bottom = np.array([_bottom_left, _bottom_right])

def _make_arm(centres):
    """Build a GaussianMixtureArm with the settings shared by both arms.

    `centres` is copied into a fresh array; `stds`, `factor` and `noise` are
    the same constants for every arm created here.
    """
    return GaussianMixtureArm(
        centres=np.array(centres),
        stds=np.array([1, 1]),
        factor=1,
        noise=.05,
    )


# The two arms of the bandit. The centres given here are only the constructor
# values: the drift schedules (update_sudden / update_gradual) overwrite
# `centres` when they are called with step 0.
a0 = _make_arm(left)
a1 = _make_arm(right)


def update_sudden(player: BanditPlayer, start: tuple, stop: tuple, n=4000):
    """Create a per-step callback that switches the arms' centres abruptly.

    Args:
        player: object whose `arms` mapping holds the arms to modify. The key
            order of `player.arms` at creation time decides which entry of
            `start` / `stop` goes to which arm.
        start: one centres value per arm, assigned when the callback is
            called with step 0.
        stop: one centres value per arm, assigned when the callback is
            called with step `n`.
        n: step at which the switch to `stop` happens.

    Returns:
        A function `update(i)` that mutates the arms in place and returns
        None. Steps other than 0 and `n` leave the arms untouched.

    Raises:
        AssertionError: if `start` or `stop` does not have one entry per arm.
    """
    arm_names = list(player.arms.keys())
    assert len(start) == len(arm_names)
    assert len(stop) == len(arm_names)

    def _assign(centre_sets):
        # Pair every arm with its new centres, in arm order.
        for arm_name, centres in zip(arm_names, centre_sets):
            player.arms[arm_name].centres = centres

    def update(i):
        # Two independent checks: with n == 0 both fire on the same step.
        if i == 0:
            print("init")
            _assign(start)
        if i == n:
            print("switch")
            _assign(stop)

    return update
            

def update_gradual(player: BanditPlayer, start: tuple, stop: tuple, n1=4000, n2=6000):
    """Create a per-step callback that shifts the arms gradually.

    The shift is implemented by moving every arm's `factor` linearly from
    1 (at step `n1`) to -1 (at step `n2`); the arms' centres stay at `start`.

    Args:
        player: object whose `arms` mapping holds the arms to modify. The key
            order of `player.arms` at creation time decides which entry of
            `start` goes to which arm.
        start: one centres value per arm, assigned when the callback is
            called with step 0 (where every `factor` is also reset to 1).
        stop: one value per arm. Only its length is checked; the centres are
            never moved towards it by this schedule.
        n1: first step of the shift window.
        n2: last step of the shift window. If `n2 == n1` the factor flips
            straight to -1 on that step; if `n2 < n1` the window is empty and
            the factor is never changed.

    Returns:
        A function `update(i)` that mutates the arms in place and returns
        None. After step `n2` the factor keeps the last value that was set.

    Raises:
        AssertionError: if `start` or `stop` does not have one entry per arm.
    """
    arm_names = list(player.arms.keys())
    assert len(start) == len(arm_names)
    assert len(stop) == len(arm_names)

    def update(i):
        if i == 0:
            print("init")
            for arm_name, centres in zip(arm_names, start):
                player.arms[arm_name].centres = centres
            for arm in player.arms.values():
                arm.factor = 1
        if n1 <= i <= n2:
            if i == n1:
                print("Start shift")
            # Progress through the window in [0, 1]. A zero-length window
            # would divide by zero, so treat it as an instantaneous flip.
            alpha = 1.0 if n2 == n1 else (i - n1) / (n2 - n1)
            for arm in player.arms.values():
                arm.factor = 1 - 2 * alpha

    return update
            
    
def update_eps(learner, i, warmup=300, eps_explore=0.9, eps_exploit=0.1):
    """Set `learner.eps` according to a two-phase schedule.

    Previously the low value was assigned unconditionally after the `if`,
    so the high warm-up value was always overwritten and never took effect.

    Args:
        learner: any object with a writable `eps` attribute.
        i: current step index.
        warmup: number of initial steps (i < warmup) that use `eps_explore`.
        eps_explore: value assigned during the warm-up phase.
        eps_exploit: value assigned from step `warmup` onwards.

    Returns:
        None; the learner is modified in place.
    """
    learner.eps = eps_explore if i < warmup else eps_exploit
    
    
# Degree-4 polynomial feature expansion, shared through this notebook-level
# name (plot_learner and the simulation loop call poly.fit_transform).
poly = PolynomialFeatures(4)

def logistic(f: float) -> float:
    """Return the logistic sigmoid of `f`, i.e. 1 / (1 + exp(-f)).

    Because the exponential is computed with NumPy, array inputs are
    transformed element-wise.
    """
    decay = np.exp(-f)
    return 1 / (1 + decay)

In [None]:
# Experiment size: passed as the first argument to Context, and used to place
# the drift window (n/3 .. 2n/3) and to label the results figure.
n = 8000

In [None]:
def make_audit_data(n_points: int = 100, low: float = -2.0, high: float = 2.0):
    """Return a regular 2-D grid of points as an (n_points**2, 2) array.

    Column 0 (x) cycles through the `n_points` evenly spaced values from
    `low` to `high`; column 1 (y) holds the same values in descending order,
    each repeated `n_points` times. Reshaping any per-point result to
    (n_points, n_points) therefore gives an image whose first row is
    y == high and whose first column is x == low.

    Args:
        n_points: number of grid values per axis (default 100).
        low: smallest coordinate on both axes (default -2).
        high: largest coordinate on both axes (default 2).

    Returns:
        numpy.ndarray of shape (n_points * n_points, 2).
    """
    margins = np.linspace(low, high, n_points)
    # Locals are deliberately not called `px`: that name is the notebook's
    # alias for plotly.express.
    xs = np.tile(margins, n_points)
    ys = np.repeat(margins[::-1], n_points)
    return np.column_stack([xs, ys])

In [None]:
def reshape_vals(vals):
    """Arrange a flat sequence of k*k values into a (k, k) array.

    The side length is the rounded square root of `len(vals)`; NumPy's
    reshape raises ValueError if the length is not a perfect square.
    """
    side = int(round(len(vals) ** 0.5))
    square = np.array(vals)
    return square.reshape((side, side))

In [None]:
def plot_arm(arm):
    """Show an arm's values over the audit grid as an image.

    Uses element [1] of what `arm.value(...)` returns for the grid from
    make_audit_data() (presumably the noise-free value; verify against the
    arm implementation) and draws it on the current matplotlib axes.
    """
    grid_values = arm.value(make_audit_data())[1]
    image = reshape_vals(grid_values)
    plt.imshow(image)

In [None]:
def compare_arms(arm1, arm2):
    """Show where `arm2` beats `arm1` over the audit grid.

    Draws the difference `arm2 - arm1` of element [1] of each arm's
    `value(...)` result on the current matplotlib axes.
    """
    # arm2 is evaluated first, each arm on its own freshly built grid.
    second = arm2.value(make_audit_data())[1]
    first = arm1.value(make_audit_data())[1]
    plt.imshow(reshape_vals(second - first))

In [None]:
def show_choice(learners: dict, gt_name="Ground Truth", bandit=None):
    """Plot each learner's chosen arm over the audit grid next to the ground truth.

    One panel is drawn per learner, coloured by the index of the arm that
    `learner.choose(...)` returns for every grid point. A final panel shows
    the true value difference (second arm minus first arm) over the same grid.

    Args:
        learners: mapping of panel title -> learner. Learners whose title
            contains "xperts" or "agged" are given degree-4 polynomial
            features of the grid; all others get the raw 2-D points.
        gt_name: title of the ground-truth panel.
        bandit: object whose `arms` mapping provides the true arms. Defaults
            to the notebook-level `player`, which is what this function
            always used before the parameter existed.

    Raises:
        ValueError: if `learners` is empty (the arm names for the
            ground-truth panel are taken from the learners).
    """
    if not learners:
        raise ValueError("show_choice needs at least one learner to plot")
    if bandit is None:
        bandit = player  # notebook-level default, resolved at call time

    hidden_ticks = dict(left=False, bottom=False, labelleft=False, labelbottom=False)
    _, axes = plt.subplots(1, len(learners) + 1)

    arm_names = None
    for ax, (name, learner) in zip(axes, learners.items()):
        X = make_audit_data()
        if "xperts" in name or "agged" in name:
            X = PolynomialFeatures(4).fit_transform(X)
        # Position of every arm in the learner's own ordering -> colour index.
        arm_index = {arm: pos for pos, arm in enumerate(learner.learners.keys())}
        choices = [arm_index[learner.choose(row)] for row in X]
        # As before, the ground-truth panel uses the last learner's arm order.
        arm_names = list(arm_index)
        ax.tick_params(**hidden_ticks)
        ax.imshow(reshape_vals(choices))
        ax.title.set_text(name)

    gt_ax = axes[-1]
    gt_ax.imshow(reshape_vals(
        bandit.arms[arm_names[1]].value(make_audit_data())[1] -
        bandit.arms[arm_names[0]].value(make_audit_data())[1]
    ))
    gt_ax.tick_params(**hidden_ticks)
    gt_ax.title.set_text(gt_name)

In [None]:
# NOTE(review): this cell uses learner4/learner5/learner6, the global `player`
# (via show_choice) and the loop index `i`, all of which are only defined in
# later cells. On a fresh kernel (Restart & Run All) it raises NameError;
# run it after the simulation loop or move it below that cell.
show_choice(learners={"ADA-RF": learner4, "Experts": learner5, "Bagged": learner6}, gt_name=f"GT: {i}")

In [None]:
def plot_learner(learner):
    """Show a single learner's predictions over the audit grid as an image.

    Every grid point is expanded with the notebook-level `poly` transformer,
    converted to a dict keyed by the feature position as a string
    ("0", "1", ...), and passed to `learner.predict_one`.
    """
    features = poly.fit_transform(make_audit_data())
    predictions = []
    for row in features:
        sample = {str(pos): val for pos, val in enumerate(row.squeeze())}
        predictions.append(learner.predict_one(sample))
    plt.imshow(reshape_vals(predictions))

In [None]:
# NOTE(review): np.random.seed is only called in a later cell. If Context or
# the learners draw from NumPy's global RNG when they are constructed, this
# cell is not reproducible across kernel restarts; consider seeding first.

# Environment: a two-armed bandit and the stream of contexts it is played on.
player = BanditPlayer({"a0": a0, "a1": a1})
context = Context(n, 2)

# Degree-4 polynomial expansion of the context; the simulation loop feeds it
# to the two linear learners (learner5 and learner6).
poly = PolynomialFeatures(4)

# Learners under comparison, each paired with the list that collects its
# per-round regret.
learner4 = AdaptiveRandomForestLearner(2, n_trees=21, min_samples_split=9)
learner5 = LinearExpertsLearner(2)
learner6 = BaggedLinearRegressor(2)
regrets4, regrets5, regrets6 = [], [], []



In [None]:
# Fix NumPy's global random state for everything that runs after this cell.
np.random.seed(272)

# Drift schedule used by the simulation loop: a gradual shift over the middle
# third of the run.
shift_start = int(n/3)
shift_end = int(2/3*n)
update = update_gradual(
    player,
    start=(diag_down, diag_up),
    stop=(diag_up, diag_down),
    n1=shift_start,
    n2=shift_end,
)
# Alternative schedule, an abrupt swap half-way through the run:
# update = update_sudden(player, (diag_down, diag_up), (diag_up, diag_down), n=int(n/2))

In [None]:
# Learners in the order they are played each round: (learner, regret log,
# whether it is fed the polynomial features instead of the raw context).
contenders = (
    (learner4, regrets4, False),
    (learner5, regrets5, True),
    (learner6, regrets6, True),
)
# Steps after which a snapshot of the learners' decision maps is drawn.
snapshot_steps = (100, 500, 2000, 3000, 4001, 4500, 6000, 7000, 7999)

for i in range(context.contexts.shape[0]):
    if (i % 1000) == 0:
        print(i)

    # Apply the drift schedule BEFORE the round is played. It used to run at
    # the end of the iteration, so round 0 was played with the arms'
    # constructor centres instead of the schedule's `start` centres, and
    # every later change lagged one round behind.
    update(i)

    state = context.contexts[i, :].reshape([1, -1])
    poly_state = poly.fit_transform(state)

    for learner, regret_log, wants_poly in contenders:
        features = poly_state if wants_poly else state
        action = learner.choose(features)
        # The environment is always played on the raw context.
        reward, regret = player.play_one(state, action)
        learner.update(features, action, reward)
        regret_log.append(regret)
        update_eps(learner, i)

    if i in snapshot_steps:
        show_choice(learners={"ADA-RF": learner4, "Experts": learner5, "Bagged": learner6}, gt_name=f"GT: {i}")


In [None]:
def _smoothed_optimal_fraction(regrets, smoothing_level=0.006):
    """Exponentially smooth the indicator "regret == 0" over the rounds.

    Returns the fitted values of a SimpleExpSmoothing model with the given
    fixed smoothing level, one value per entry of `regrets`.
    """
    was_optimal = np.array(regrets) == 0
    return SimpleExpSmoothing(was_optimal).fit(smoothing_level=smoothing_level).fittedvalues


frac_opt_actions4 = _smoothed_optimal_fraction(regrets4)
frac_opt_actions5 = _smoothed_optimal_fraction(regrets5)
frac_opt_actions6 = _smoothed_optimal_fraction(regrets6)

In [None]:
# Smoothed fraction of optimal actions per learner over the run.
fig = go.Figure(
    layout_title=f"Fraction of optimal choice (Exponentially smoothed), reversal at {int(n/2)}",
    layout_xaxis_title="Index",
    layout_yaxis_title="Cumulative fraction",
)

# (legend label, regret log, smoothed series, line colour) per learner.
series = (
    ("Adaptive RF learner", regrets4, frac_opt_actions4, "#167ab3"),
    ("Linear experts", regrets5, frac_opt_actions5, "#565659"),
    ("Bagged Linear", regrets6, frac_opt_actions6, "#12a9c1"),
)
for label, regret_log, fractions, colour in series:
    fig.add_trace(go.Scatter(
        x=np.arange(len(regret_log)) + 1,  # 1-based round index
        y=fractions,
        line={"color": colour, "dash": "solid"},
        name=label,
        mode="lines",
    ))

# Merge the y-range into the existing layout. Assigning `fig.layout = {...}`
# replaced the whole layout and discarded the title and axis titles set above.
fig.update_layout(yaxis_range=[0, 1])
fig.show()

In [None]:
# Final ground truth: value difference a1 - a0 over the audit grid, drawn
# without axis ticks or tick labels.
plt.tick_params(**{"left": False, "bottom": False, "labelleft": False, "labelbottom": False})
a1_values = player.arms["a1"].value(make_audit_data())[1]
a0_values = player.arms["a0"].value(make_audit_data())[1]
plt.imshow(reshape_vals(a1_values - a0_values))