In [None]:
# parameter estimation of 
# alpha (learning rate)
# beta  (inverse temperature)
# in softmax Q-learning models

In [1]:
import numpy as np
from scipy import optimize as optim
from scipy.special import logsumexp, softmax

np.set_printoptions(precision=2)

In [2]:
# --- simulation settings ---
N_SAMPLES = 10  # number of agents simulated in parallel
N_TRIALS = 40   # trials played in each block
# per-arm noise scale: multiplies standard-normal noise when rewards are drawn
VARS = np.array([2, 2])
# candidate (arm 0, arm 1) mean rewards; one row is picked at random per block
MEAN_LS = np.array([[1, -1], [-1, 1], [0, 0]])

# --- ground-truth agent parameters ---
ALPHA = 0.1  # learning rate
BETA = 1     # inverse temperature

In [3]:
class Soft_Q_learners():
    """A population of independent softmax Q-learning agents for a 2-armed bandit.

    Each agent owns one row of ``Q_tables`` (its two action values); all
    agents share the same learning rate and inverse temperature.

    Parameters
    ----------
    alpha : float
        Learning rate of the delta-rule update.
    beta : float
        Inverse temperature applied to the Q-values before the softmax.
    n_sample : int, optional
        Number of agents simulated in parallel. Defaults to ``N_SAMPLES``.
    """

    def __init__(self, alpha, beta, n_sample=N_SAMPLES):

        self.n = n_sample
        # one (Q_arm0, Q_arm1) row per agent, initialised to zero
        self.Q_tables = np.zeros((n_sample, 2))
        self.alpha = alpha
        self.beta  = beta

    def act(self):
        """Sample one choice per agent from the policy softmax(beta * Q).

        Returns
        -------
        numpy.ndarray of bool, shape (n, 1)
            ``True`` where the agent picked arm 0 and ``False`` where it
            picked arm 1. This is a "chose arm 0" flag, not an arm index.
        """
        # Scale by beta: without it the inverse temperature is silently
        # ignored and the agents behave as if beta == 1.
        probs = softmax(self.beta * self.Q_tables, axis=-1)
        actions = probs[:,0] > np.random.rand(self.n)
        return actions.reshape(-1,1)

    def learn(self, actions, rewards):
        """Move each agent's chosen-arm value towards the reward it received.

        Parameters
        ----------
        actions : array of int, length n
            Arm index (0 or 1) chosen by each agent.
        rewards : array of float, length n
            Reward received by each agent.
        """
        rows = np.arange(self.n)
        # delta rule: Q <- Q + alpha * (reward - Q), applied to the chosen arm only
        self.Q_tables[rows, actions] += self.alpha * (rewards - self.Q_tables[rows, actions])

In [4]:
# Simulate the agents playing a sequence of 2-armed bandit blocks.
# Q-values carry over from one block to the next (the agents are never reset).

choice_ls, reward_ls, Qvalue_ls = [], [], []

agents = Soft_Q_learners(alpha=ALPHA, beta=BETA)
for _ in range(10):

    # pick this block's pair of arm means at random
    MEANS = MEAN_LS[np.random.choice(range(len(MEAN_LS)))]

    for _ in range(N_TRIALS):
        # act() flags "picked arm 0"; negate and cast to obtain the arm index
        picked_arm0 = agents.act()[:, 0]
        actions = (~picked_arm0).astype(int)

        # reward = mean of the chosen arm + scaled standard-normal noise
        noise = np.random.randn(N_SAMPLES)
        rewards = MEANS[actions] + noise * VARS[actions]
        agents.learn(actions, rewards)

        choice_ls.append(actions)
        reward_ls.append(rewards)
        # snapshot the table: it is updated in place on every trial
        Qvalue_ls.append(agents.Q_tables.copy())

# choices and rewards become (agent, trial); Q-values stay (trial, agent, arm)
choice_ls = np.asarray(choice_ls).T
reward_ls = np.asarray(reward_ls).T
Qvalue_ls = np.asarray(Qvalue_ls)

In [5]:
def likelihood(param, choices, rewards):
    """Negative log-likelihood of one agent's choices under softmax Q-learning.

    Despite its name, the value returned is the *negative* log-likelihood,
    so it can be handed directly to a minimiser.

    Parameters
    ----------
    param : sequence of two floats
        Unconstrained parameters. ``param[0]`` is mapped to the learning
        rate alpha in (0, 1) with a logistic function; ``param[1]`` is mapped
        to the inverse temperature beta in (0, inf) with ``exp``.
    choices : iterable of int
        Arm index (0 or 1) chosen on each trial.
    rewards : iterable of float
        Reward received on each trial, aligned with ``choices``.

    Returns
    -------
    float
        Sum over trials of ``-log P(choice | Q, beta)``; Q starts at zero and
        is updated with the delta rule after every trial.
    """

    LL = 0.0
    Q = np.array([0.,0.])

    # map range of alpha to (0, 1); beta to (0, inf)
    alpha = 1 / (1 + np.exp(-param[0]))
    beta = np.exp(param[1])

    for choice, reward in zip(choices, rewards):

        # make prediction
        # log-softmax via logsumexp: exact, and still finite when the chosen
        # arm's probability underflows (no additive epsilon needed)
        logits = beta * Q
        LL -= logits[choice] - logsumexp(logits)

        # update Q-value
        Q_err = reward - Q[choice]
        Q[choice] += alpha * Q_err

    return LL

In [6]:
print(f"ALPHA:{ALPHA}   BETA:{BETA}\n")

# Fit (alpha, beta) separately for every simulated agent.

X0 = np.array([0, 0])  # starting point in the optimiser's unconstrained space

for i in range(N_SAMPLES):

    # minimise the negative log-likelihood of agent i's choice sequence
    x = optim.minimize(likelihood, X0, args=(choice_ls[i], reward_ls[i]))

    # undo the reparameterisation: logistic for alpha, exponential for beta
    raw_alpha, raw_beta = x.x
    alpha = 1 / (1 + np.exp(-raw_alpha))
    beta = np.exp(raw_beta)

    print(f"Sample{i}:    alpha:{alpha:.4f}    beta:{beta:.4f}    iters:{x.nit:02d}")

ALPHA:0.1   BETA:1

Sample0:    alpha:0.0886    beta:1.0593    iters:09
Sample1:    alpha:0.1084    beta:0.7843    iters:10
Sample2:    alpha:0.0867    beta:0.8635    iters:10
Sample3:    alpha:0.0922    beta:1.1205    iters:10
Sample4:    alpha:0.1139    beta:0.8045    iters:10
Sample5:    alpha:0.0856    beta:0.9799    iters:08
Sample6:    alpha:0.1136    beta:0.9014    iters:09
Sample7:    alpha:0.0983    beta:0.9181    iters:10
Sample8:    alpha:0.1307    beta:0.7285    iters:09
Sample9:    alpha:0.0772    beta:1.1619    iters:07


In [7]:
# Latent-variable estimation: rebuild agent 0's Q-values from its fitted parameters.

# fit the parameters on agent 0's data, then map them back to (alpha, beta)
alpha_, beta_ = optim.minimize(likelihood, X0, args=(choice_ls[0], reward_ls[0])).x
alpha_ = 1 / (1 + np.exp(-alpha_))
beta_  = np.exp(beta_)

# replay the observed choices and rewards through the delta rule
# (only alpha_ is needed here; beta_ does not enter the Q-value update)
Q_ = [0.,0.]
ls = []
for i, (choice, reward) in enumerate(zip(choice_ls[0], reward_ls[0])):
    Q_[choice] = Q_[choice] + alpha_ * (reward - Q_[choice])
    # pair the simulated Q-values with the reconstructed ones for this trial
    ls.append([Qvalue_ls[i, 0], np.array(Q_)])

# show the first 20 trials: [Q from the simulation, Q rebuilt from the fit]
ls[:20]

[[array([-0.37,  0.  ]), array([-0.33,  0.  ])],
 [array([-0.37, -0.05]), array([-0.33, -0.05])],
 [array([-0.37,  0.05]), array([-0.33,  0.04])],
 [array([-0.37,  0.26]), array([-0.33,  0.23])],
 [array([-0.37,  0.43]), array([-0.33,  0.38])],
 [array([-0.37,  0.57]), array([-0.33,  0.51])],
 [array([-0.33,  0.57]), array([-0.3 ,  0.51])],
 [array([-0.33,  0.77]), array([-0.3 ,  0.69])],
 [array([-0.19,  0.77]), array([-0.17,  0.69])],
 [array([-0.43,  0.77]), array([-0.39,  0.69])],
 [array([-0.43,  1.01]), array([-0.39,  0.91])],
 [array([-0.43,  1.44]), array([-0.39,  1.3 ])],
 [array([-0.43,  1.09]), array([-0.39,  1.  ])],
 [array([-0.57,  1.09]), array([-0.51,  1.  ])],
 [array([-0.54,  1.09]), array([-0.49,  1.  ])],
 [array([-0.54,  1.48]), array([-0.49,  1.36])],
 [array([-0.54,  1.72]), array([-0.49,  1.59])],
 [array([-0.54,  1.66]), array([-0.49,  1.54])],
 [array([-0.54,  1.39]), array([-0.49,  1.31])],
 [array([-0.54,  1.07]), array([-0.49,  1.03])]]