In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
from torch import nn 
from tqdm import tqdm

In [2]:
from utils import *

In [3]:
# True arm parameters: one row per arm. Each row holds weights for the
# context features followed by a trailing bias weight (contexts are
# augmented with a constant 1 before being multiplied by arms.T).
arms = torch.tensor(
    [
        [0.6667, 0.3333, -0.6667],
        [-0.2182, 0.8729, -0.4364],
    ],
    dtype=torch.float32,
)

In [4]:
# Seed every RNG that may be in play so the synthetic agents and the bandit's
# initial parameters are reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Derive the problem dimensions from `arms` so they cannot drift from it:
# one row per arm, and each row carries CONTEXT_SIZE weights plus a bias.
N_ARMS = arms.shape[0]
CONTEXT_SIZE = arms.shape[1] - 1
N_AGENTS = 1000

bandit = GradientBandit(n_arms=N_ARMS, context_size=CONTEXT_SIZE)
non_strat = Agents(n=N_AGENTS, n_arms=N_ARMS, context_size=CONTEXT_SIZE,
                   arms=arms,
                   max_reward=0, max_variance=0)

In [5]:
def pick_arm(y):
    """Encode the sign pattern of ``y`` as an integer region/arm code.

    Bit ``i`` of the result is set when ``y[i] > 0``; zero, negative and NaN
    entries contribute a 0 bit. With two outputs the code is in 0..3.

    Args:
        y: 1-D array-like of scores (list, NumPy array, or a CPU tensor that
            does not require grad).

    Returns:
        int: the bit-packed code. It is a plain Python int so it can be used
        directly as a list index (e.g. by ``color_picker``); the previous
        float result could not.
    """
    positive = np.asarray(y) > 0
    return int(sum(1 << i for i, is_pos in enumerate(positive) if is_pos))

def color_picker(v):
    """Map an integer region/arm code to a matplotlib color name.

    Args:
        v: code in ``range(8)``. Integral floats (e.g. ``np.float64(2.0)``)
            and 0-d integer tensors are accepted and converted with ``int``;
            a raw float would otherwise raise TypeError when indexing.

    Returns:
        str: a named matplotlib color.

    Raises:
        IndexError: if ``int(v)`` is outside the 8-entry palette. Negative
            codes are rejected instead of silently wrapping to the end.
    """
    colors = ['red', 'green', 'blue', 'orange', 'yellow', 'purple', 'pink', 'black']
    index = int(v)
    if not 0 <= index < len(colors):
        raise IndexError(f'color index {v!r} outside 0..{len(colors) - 1}')
    return colors[index]

def _evaluate(model, data_loader, agent_rewards):
    """Predict on every batch of ``data_loader`` without tracking gradients.

    Returns:
        tuple of np.ndarray: ``(features, predictions, labels)`` where
        ``features`` stacks the input rows and ``predictions`` / ``labels``
        are the ``pick_arm`` codes of the model outputs and of the targets.
    """
    features, predictions, labels = [], [], []
    with torch.no_grad():
        for xs, ys, variances in tqdm(data_loader, desc='evaluating'):
            y_hat = model(x=xs, variances=variances, agent_rewards=agent_rewards)
            features.extend(xs.numpy().tolist())
            # int() so the codes can index color_picker's palette.
            predictions.extend(int(pick_arm(pred)) for pred in y_hat)
            labels.extend(int(pick_arm(y)) for y in ys)
    return np.asarray(features), np.asarray(predictions), np.asarray(labels)


def _plot_predictions(features, predictions, title):
    """Scatter the first two feature columns in a new figure, coloured by code."""
    if len(features) == 0:
        return
    fig, ax = plt.subplots()
    ax.scatter(features[:, 0], features[:, 1],
               c=[color_picker(int(code)) for code in predictions])
    ax.set(title=title, xlabel='x[0]', ylabel='x[1]')
    plt.show()


def train(model, dataset, learning_rate=1e-5, epochs=10):
    """Fit ``model`` on ``dataset`` with Adam, reporting after every epoch.

    Each epoch runs one optimisation pass with batch size 1, then evaluates
    the model on the same data (so the reported figure is a training error):
    predictions and labels are both encoded with ``pick_arm``, the fraction of
    mismatching codes is printed, and the inputs are plotted in a fresh figure
    coloured by predicted code.

    Args:
        model: object exposing ``parameters()``; called as
            ``model(x=, y=, variances=, agent_rewards=)`` to obtain a loss
            tensor, and without ``y`` to obtain predictions.
        dataset: dataset yielding ``(x, y, variance)`` triples and exposing a
            ``rewards`` attribute that is forwarded to the model.
        learning_rate: Adam learning rate.
        epochs: number of passes over ``dataset``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    agent_rewards = dataset.rewards
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    for epoch in range(1, epochs + 1):
        # NOTE(review): the recorded run fails on the first backward() with
        # "modified by an inplace operation"; the in-place op is presumably
        # inside the model's forward (not visible here). Wrap this loop in
        # torch.autograd.set_detect_anomaly(True) to locate it.
        for xs, ys, variances in tqdm(data_loader, desc=f'epoch {epoch}/{epochs}'):
            optimizer.zero_grad()
            loss = model(x=xs, y=ys, variances=variances, agent_rewards=agent_rewards)
            loss.backward()
            optimizer.step()

        features, predictions, labels = _evaluate(model, data_loader, agent_rewards)
        # Element-wise mismatch rate; comparing the Python lists directly
        # produced a single bool rather than a per-sample comparison.
        error_rate = float(np.mean(predictions != labels)) if labels.size else float('nan')
        print(f'Epoch {epoch}/{epochs} - Error: {error_rate:.4f}')
        _plot_predictions(features, predictions,
                          title=f'Predicted regions after epoch {epoch}')

In [6]:
# Train the gradient bandit on the non-strategic agents; hyperparameters are
# spelled out so they are visible at the call site.
train(model=bandit, dataset=non_strat, learning_rate=1e-5, epochs=10)

  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 3]], which is output 0 of SubBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [None]:
# Ground-truth regions, coloured with the same pick_arm -> color_picker
# encoding that train() uses for its prediction plots, so the figures can be
# compared directly. (The previous inline rule swapped green/blue relative to
# that encoding and treated an exact 0 as non-negative; pick_arm treats it as
# non-positive.)
true_x0, true_x1, true_colors = [], [], []
for i in range(len(non_strat)):
    x, y, _ = non_strat[i]
    true_x0.append(float(x[0]))
    true_x1.append(float(x[1]))
    # int() keeps the palette lookup valid even if pick_arm yields a float.
    true_colors.append(color_picker(int(pick_arm(y))))

# One scatter call instead of one per point.
fig, ax = plt.subplots()
ax.axis('equal')
ax.scatter(true_x0, true_x1, c=true_colors)
ax.set(title='True regions', xlabel='x[0]', ylabel='x[1]');

In [None]:
# Best arm per agent under the true arm weights: append a bias column of ones
# to the contexts and take the argmax of x_aug @ arms.T.
contexts = torch.as_tensor(non_strat.x).detach().float()
x_aug = torch.cat([contexts, torch.ones(len(contexts), 1)], dim=1)
ys = (x_aug @ arms.T).argmax(dim=-1)  # index of the best arm for each agent

# Plot the same contexts the labels were computed from (the previous loop used
# non_strat[i][0]; presumably identical to non_strat.x — confirm), in a single
# scatter call instead of one per point. Arm 0 is red, any other arm blue.
fig, ax = plt.subplots()
ax.axis('equal')
ax.scatter(contexts[:, 0].numpy(), contexts[:, 1].numpy(),
           c=['r' if arm == 0 else 'b' for arm in ys.tolist()])
ax.set(title='Best arm under true weights', xlabel='x[0]', ylabel='x[1]');