In [22]:
import numpy as np
from distribution import Gaussian, GenericFunction
from bandit_type import Linear
from instance import soare
import utils

In [23]:
class General(Linear):
    """Top-two style arm selection driven by a Gaussian posterior.

    Each round draws two distinct "leader" arms from the posterior, then
    pulls the arm whose simulated observation leaves the largest Monte-Carlo
    variance in the estimated gap between the two leaders.

    NOTE(review): relies on ``Linear.__init__`` to provide ``self.X``,
    ``self.K``, ``self.d``, ``self.V``, ``self.T``, ``self.name`` and
    ``self.arms_chosen`` -- the base class is not visible here; confirm.
    """

    def __init__(self, X, Y, gen_star, T, sigma, name, B=100):
        """Set up the posterior and the ground-truth reward generator.

        Args:
            X, Y, T, sigma, name: forwarded unchanged to ``Linear.__init__``.
            gen_star: ground-truth generator; its ``pull(x)`` is used to
                observe the reward of the arm chosen each round.
            B: number of Monte-Carlo samples per candidate arm used to
                estimate the gap variance (defaults to the previous
                hard-coded value of 100).
        """
        super().__init__(X, Y, gen_star, T, sigma, name)
        self.pi = Gaussian(np.zeros(self.d), self.V)
        self.gen_star = gen_star
        self.B = B

    def _sample_top_two(self):
        """Return two distinct arms, each the argmax of a posterior sample."""
        f1 = self.pi.sample()
        idx1 = np.argmax(f1.evaluate(self.X))

        # Resample until a different best arm shows up. There is no iteration
        # cap, so this spins for as long as every sample agrees on idx1.
        idx2 = idx1
        while idx1 == idx2:
            f2 = self.pi.sample()  # TODO: sample 10 at a time
            idx2 = np.argmax(f2.evaluate(self.X))

        return self.X[idx1], self.X[idx2]

    def _gap_variance(self, x, x1, x2):
        """Monte-Carlo variance of the x1-vs-x2 gap after a simulated pull of x."""
        expected_diff = 0
        expected_diff_squared = 0

        for _ in range(self.B):
            # Simulate an observation at x from a sampled model, condition a
            # copy of the posterior on it, and sample the gap from that copy.
            gen_b1 = self.pi.sample()
            y_b1 = gen_b1.pull(x)
            pi_plus = self.pi.update_posterior(x, y_b1, copy=True)
            gen_b2 = pi_plus.sample()

            # Computed once and reused for both moments.
            # NOTE(review): assumes evaluate() is deterministic -- confirm.
            diff = gen_b2.evaluate(x1) - gen_b2.evaluate(x2)
            expected_diff += diff / self.B
            expected_diff_squared += diff**2 / self.B

        return expected_diff_squared - expected_diff**2

    def run(self, logging_period=1):
        """Run ``self.T`` rounds of arm selection and posterior updates.

        Args:
            logging_period: print a progress line every this many rounds;
                a falsy value (0 or None) disables progress output.
        """
        for t in range(self.T):
            x1, x2 = self._sample_top_two()

            v = [self._gap_variance(self.X[idx], x1, x2)
                 for idx in range(self.K)]

            # Index the instance's own arm set, not a notebook-level global X.
            x_n = self.X[np.argmax(v)]
            y_n = self.gen_star.pull(x_n)
            self.pi.update_posterior(x_n, y_n)
            self.arms_chosen.append(x_n)

            if logging_period and t % logging_period == 0:
                print('general run', self.name, 'iter', t, "/", self.T, end="\r")

In [1]:
# Experiment configuration: dimension, horizon and noise level.
d = 5
T = 1000
sigma = 1.0

# Build the arm set and true parameter, then the derived set Y.
X, theta_star = soare(d, alpha=0.1)
Y = utils.compute_Y(X)

# Ground-truth reward generator: linear in theta_star, with noise level sigma.
gen_star = GenericFunction(lambda x: x @ theta_star, sigma)

# Instantiate the sampler and run it for T rounds.
alg = General(
    X=X,
    Y=Y,
    gen_star=gen_star,
    T=T,
    sigma=sigma,
    name='top-two',
)
alg.run()