In [None]:
import numpy as np

In [70]:
# error display
def error_display(linucb_set, true_set):
    return "different. linucb setting : {0}, true settings : {1}".format(linucb_set, true_set)

In [76]:
# 既存のクラスにインスタンスメソッドを追加する関数
def add_instance_method(Class, method):
    aetattr(Class, method.__name__, method)

In [158]:
# LinUCBアルゴリズムのクラス
class LinUCB:
    def __init__(self, alpha=0.01, sigma_0=0.1):
        self.alpha = alpha
        self.sigma_0 = sigma_0
        
    # 初期化関数
    # 入力：
    # n_arms     引ける腕の数
    # n_features 文脈の次元数
    # sigma      誤差項の分散
    def initialize(self, n_arms, n_features, sigma):
        self.n_arms = n_arms
        self.n_features = n_features
        self.sigma  = sigma
        self.A_inv  = self.sigma_0 / sigma * np.eye(n_features)
        self.b      = np.zeros(n_features)
        self.theta  = self.A_inv.dot(self.b)
        
    # theta の係数格納
    def set_theta(theta):
        self.theta = theta
        
    # iteration 回数のset関数
    def set_iteration_number(self, t):
        # t が0以下の値の場合、エラーを返す
        assert t > 0, "iteration number must be positive. t = {0}".format(t)
        
        self.iter = t
        
    # context のset関数
    def set_context(self, context):
        self.context = np.array(context)
        
    # UCBスコアの計算
    # 入力：
    # t       反復回数
    # context 入力する文脈. 属性情報とかそれまでの選択とか. n_arms * n_features の行列
    def calc_UCBScore(self, t, context):
        # iteration, context のセット
        self.set_iteration_number(t)
        self.set_context(context)
        
        alpha_t = self.alpha * np.sqrt(np.log(t))
        
        UCBScore = [0 for _ in range(self.n_arms)]
        for idx, arm in enumerate(range(self.n_arms)):
            # context を各armの場合に分ける
            a_it = self.context[idx]
            UCBScore[idx] = a_it.dot(self.theta) + alpha_t * self.sigma * a_it.T.dot(self.A_inv).dot(a_it)
        
        return UCBScore
    
    # 腕を選択
    def select_arm(self, UCBScore):
        return np.argmax(UCBScore)
    
    # 観測した報酬からパラメータ更新
    # 入力
    # selected_arm 選択した腕
    # reward       得られた報酬
    def update(self, selected_arm, reward):
        # 視認性向上のため、context 情報を変数に格納しておく
        a_it = self.context[selected_arm]
        A_inv_a_it = self.A_inv.dot(a_it)
        A_inv_a_it_a_it_T = self.A_inv.dot(a_it).dot(a_it.T)
        
        # パラメータ更新
        self.A_inv = self.A_inv - np.dot(A_inv_a_it_a_it_T, self.A_inv) / 1+a_it.T.dot(A_inv_a_it)
        self.b = self.b + a_it*reward

In [72]:
def test_LinUCB_initialize():
    lu = LinUCB()
    n_arms = 3
    n_features = 4
    sigma = 0.1
    
    lu.initialize(n_arms, n_features, sigma)
    
    # your settings
    assert lu.n_arms == n_arms, "n_arms " + error_display(lu.n_arms, n_arms)
    assert lu.n_features == n_features, "n_features " + error_display(lu.n_features, n_features)
    assert lu.sigma == sigma, "sigma " + error_display(lu.sigma, sigma)
    
    # initialize
    assert (lu.A_inv == lu.sigma_0 / sigma * np.eye(n_features)).all(), "A_inv different. "
    assert (lu.b == np.zeros(n_features)).all(), "b different. "
    assert (lu.theta == lu.A_inv * lu.b).all(), "theta different"
    
    # check size
    assert lu.A_inv.shape[0] == n_features, "A_inv row size " + error_display(lu.A_inv.shape[0], n_features)
    assert lu.A_inv.shape[1] == n_features, "A_inv column size " + error_display(lu.A_inv.shape[1], n_features)
    assert lu.b.shape[0] == n_features, "b size " + error_display(lu.b.shape[0], n_features)
    assert lu.theta.shape[0] == n_features, "theta size " + error_display(lu.theta.shape[0], n_features)
    
    # theta はベクトルのはずなので2次元目が存在したらエラー
    try:
        lu.theta.shape[1]
    except:
        print("OK")
    else:
        assert False, "LinUCB.theta has column!"
        
    
    # すべてのテスト完了
    print("Conglatulations!")
    return

In [73]:
test_LinUCB_initialize()

OK
Conglatulations!


In [111]:
# calc_UCBScore のテスト
def test_LinUCB_calc_UCBScore():
    # parameter settings
    lu = LinUCB()
    n_arms = 3
    n_features = 2
    sigma = 0.1
    lu.initialize(n_arms, n_features, sigma)
    
    t = 2
    context = np.array([[1,2], [4,5], [7,8]])
    
    ucb_score = lu.calc_UCBScore(t,context)
    
    assert len(ucb_score) == n_arms, "n_arms " + error_display(len(ucb_score), n_arms)
    
    print(ucb_score)
    
    # 反復回数が0以下だった時の例外処理
    t_0 = 0
    try:
        lu.calc_UCBScore(t_0,context)
    except:
        print("OK")
    else:
        assert False, "0 divide exception denied."
    
    # すべてのテスト完了
    print("Conglatulations!")
    return

In [112]:
test_LinUCB_calc_UCBScore()

[0.004162773055788488, 0.034134739057465606, 0.09407867106081984]
OK
Conglatulations!


In [114]:
# 腕の選択に関するテスト
def test_LinUCB_select_arm():
    # parameter settings
    lu = LinUCB()
    n_arms = 3
    n_features = 2
    sigma = 0.1
    lu.initialize(n_arms, n_features, sigma)
    t = 2
    context = np.array([[1,2], [4,5], [7,8]])
    ucb_score = lu.calc_UCBScore(t,context)
    
    selected_arm = lu.select_arm(ucb_score)
    
    assert selected_arm == np.argmax(ucb_score), "selected arm " + error_display(selected_arm, np.argmax(ucb_score))

In [115]:
test_LinUCB_select_arm()

In [161]:
# update に関するテスト
def test_LinUCB_update():
    # parameter settings
    lu = LinUCB()
    n_arms = 3
    n_features = 2
    sigma = 0.1
    lu.initialize(n_arms, n_features, sigma)
    t = 2
    context = np.array([[1,2], [4,5], [7,8]])
    ucb_score = lu.calc_UCBScore(t,context)
    selected_arm = lu.select_arm(ucb_score)
    
    reward_list = [0.1, 0.4, 0.3]
    reward = reward_list[selected_arm]
    
    lu.update(selected_arm, reward)
    
    assert lu.A_inv.shape[0] == n_features, "A_inv size " + error_display(lu.A_inv.shape[0], n_features)
    assert lu.A_inv.shape[1] == n_features, "A_inv size " + error_display(lu.A_inv.shape[1], n_features)
    assert lu.b.shape[0] == n_features, "A_inv size " + error_display(lu.b.shape[0], n_features)
    
    # b はベクトルのはずなので2次元目が存在したらエラー
    try:
        lu.b.shape[1]
    except:
        print("OK")
    else:
        assert False, "LinUCB.b has column!"

In [162]:
test_LinUCB_update()

OK
[[  1. 113.]
 [113.   1.]]
