In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame, concat
from sisyphus import ModelFree, ValueIteration
from sisyphus.envs import FreeChoice
from sisyphus.misc import softmax
from tqdm import tqdm
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Environment the agents are fit on.
gym = FreeChoice()

## Schedule passed to the agent's fit: 100 log-spaced values from 10**0 to 10**1.
## NOTE(review): presumably one exploration (inverse-temperature) value per
## training episode -- confirm against sisyphus' ModelFree.fit.
schedule = np.logspace(start=0, stop=1, num=100)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## One record (beta, theta) is collected per beta value and repetition.
data = []
for beta in (1.0, 0.5, 0.0):

    ## Create the agent for this beta; the same object is re-fit on
    ## every repetition below.
    agent = ModelFree('betamax', eta=0.1, gamma=1, beta=beta)

    for _ in tqdm(range(100)):

        ## Fit on the environment; the return value of fit is rebound to agent.
        ## NOTE(review): overwrite=True presumably discards the previous
        ## fit's Q-values -- confirm against sisyphus.
        agent = agent.fit(gym, schedule=schedule, overwrite=True)

        ## Scale the first two Q-values by the last schedule value, then
        ## turn them into choice likelihoods with a softmax.
        scaled_q = agent.Q[:2] * schedule[-1]
        theta = softmax(scaled_q)

        ## Record the likelihood of the first of the two options.
        data.append({'beta': beta, 'theta': theta[0]})

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        
## Convert the list of per-run records into a DataFrame
## (columns: 'beta', 'theta').
data = DataFrame(data)

## Point plot of theta against beta. x and y are passed as keywords:
## seaborn >= 0.12 accepts only `data` positionally, so the positional
## form raises a TypeError there.
sns.pointplot(x='beta', y='theta', data=data)