In [None]:
import numpy as np
import matplotlib.pyplot as plt
from worldModels import *
import scipy.io

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
# Sweep the inference agent's assumed reward probability (prew) and switch
# probability (pswitch), recording the agent's efficiency for every pair.
# pswitch values are spaced uniformly in log space between 0.01 and 0.2.
np.random.seed(0)  # fix the random worlds so the heatmap is reproducible

nblocks = 100

prewlist = np.linspace(0.5, 0.95, 25)
pswitchlistLog = np.linspace(np.log(0.01), np.log(0.2), 20)
pswitchlist = np.exp(pswitchlistLog)
efflist = np.zeros((len(prewlist), len(pswitchlist)))

for i, prew in enumerate(prewlist):
    print(f'prew row {i + 1}/{len(prewlist)}')  # progress indicator
    for j, pswitch in enumerate(pswitchlist):
        # A fresh random world for every (prew, pswitch) pair: block k offers
        # rates (1 - r_k, r_k) with r_k ~ U(0, 1), lasting 200-299 trials.
        raw_rates = np.random.rand(nblocks)
        rates = np.vstack((1 - raw_rates, raw_rates)).T
        ntrials = np.random.uniform(low=200, high=300, size=nblocks).astype('int')

        world = PersistentWorld(rates=rates, ntrials=ntrials)
        agent = EGreedyInferenceBasedAgent(prew=prew, pswitch=pswitch, eps=0)

        exp = Experiment(agent, world)
        exp.run()  # return value not needed; efficiency is read from the agent
        efflist[i, j] = agent.find_efficiency()


In [None]:
# Heatmap of efficiency over the (log pswitch, prew) grid. The image extent is
# padded by half a grid step on every side so that each pixel is centred on the
# parameter pair it represents and the ticks line up with the sampled values.
dx = pswitchlistLog[1] - pswitchlistLog[0]
dy = prewlist[1] - prewlist[0]
extent = [pswitchlistLog[0] - dx / 2, pswitchlistLog[-1] + dx / 2,
          prewlist[0] - dy / 2, prewlist[-1] + dy / 2]

fig, ax = plt.subplots()
im = ax.imshow(efflist, extent=extent, origin='lower', cmap='hot',
               vmin=0.55, vmax=0.85, aspect='auto')
cbar = fig.colorbar(im)
cbar.set_label('Efficiency')

# NOTE(review): fontweight 2 is near the bottom of matplotlib's 0-1000 numeric
# weight scale; 'bold' may have been intended.
fontdict = {'fontsize': 16, 'fontweight': 2}

# The x axis is in log(pswitch); label the ticks with the untransformed values.
ax.set_xticks(pswitchlistLog[::3])
ax.set_xticklabels(np.round(pswitchlist[::3], 2), fontdict=fontdict)
yticks = np.linspace(prewlist[0], prewlist[-1], 4)
ax.set_yticks(yticks)
ax.set_yticklabels(np.round(yticks, 2), fontdict=fontdict)

ax.set_xlabel(r'$P_{switch}$', fontdict=fontdict)
ax.set_ylabel(r'$P_{reward}$', fontdict=fontdict);




In [None]:
# Run one long session (100 blocks of 200-299 trials each) with a fixed seed
# so the diagnostic plots below are reproducible.
nblocks = 100
prew = 1          # used only by the inference-based agent
pswitch = 0.05    # used only by the inference-based agent
AGENT_KIND = 'qlearning'  # 'qlearning' or 'inference'

np.random.seed(100)
raw_rates = np.random.rand(nblocks)
rates = np.vstack((1 - raw_rates, raw_rates)).T
ntrials = np.random.uniform(low=200, high=300, size=nblocks).astype('int')

world = PersistentWorld(rates=rates, ntrials=ntrials)

# Select the agent explicitly instead of toggling commented-out lines.
if AGENT_KIND == 'inference':
    agent = EGreedyInferenceBasedAgent(prew=prew, pswitch=pswitch, eps=0)
elif AGENT_KIND == 'qlearning':
    agent = EGreedyQLearningAgent(gamma=0.5, eps=0.3)
else:
    raise ValueError(f'Unknown AGENT_KIND: {AGENT_KIND!r}')

exp = Experiment(agent, world)
res = exp.run()

In [None]:
# Efficiency reported by the agent for the session above, shown as cell output.
efficiency = agent.find_efficiency()
efficiency

In [None]:
# Overlay the world's side-0 reward rate and the agent's p0_history over the
# first nsteps entries.
rate_history = world.get_rate_history()
nsteps = 10000

fig, ax = plt.subplots()
ax.plot(rate_history[:nsteps, 0], label='World rate, side 0')
ax.plot(agent.p0_history[:nsteps], label='Agent p0')
ax.set_xlabel('Trial')
ax.set_ylabel('Side-0 rate / agent p0')
ax.legend();

In [None]:
# Agent p0 against the world's side-0 rate, one point per history entry.
fig, ax = plt.subplots()
ax.plot(rate_history[:, 0], agent.p0_history, '.')
ax.set_xlabel('World rate, side 0')
ax.set_ylabel('Agent p0');

In [None]:
# Running choice and reward fractions (window argument 10), plus the world's
# rate history.
# NOTE(review): a, b and rates_hist are not used by the visible cells, and
# rates_hist duplicates rate_history computed earlier.
a, b = (agent.get_running_choice_fraction(10),
        agent.get_running_reward_fraction(10))
rates_hist = world.get_rate_history()

In [None]:
# Block boundaries: trials of block k occupy [ntrialcum[k], ntrialcum[k + 1]).
ntrialcum = np.concatenate(([0], np.cumsum(ntrials)))

In [None]:
# Per-block choice fraction (share of trials where side 1 was chosen) and
# income fraction (share of rewards that came from side 1), used by the
# income-vs-choice plot below.
choices_all = np.asarray(agent.choice_history, dtype='int')
outcomes_all = np.asarray(agent.outcome_history, dtype='int')

choice_fracs = []
income_fracs = []
for start, stop in zip(ntrialcum[:-1], ntrialcum[1:]):
    choice_subset = choices_all[start:stop]
    reward_subset = outcomes_all[start:stop]

    rewarded = reward_subset == 1
    income1 = np.sum(rewarded & (choice_subset == 1))
    income0 = np.sum(rewarded & (choice_subset == 0))
    total_income = income1 + income0

    # A block with no rewards (or no trials) has an undefined fraction; record
    # NaN explicitly rather than relying on a 0/0 RuntimeWarning.
    income_fracs.append(income1 / total_income if total_income else np.nan)
    choice_fracs.append(np.mean(choice_subset) if len(choice_subset) else np.nan)

In [None]:
# One point per block; the diagonal marks choice fraction == income fraction.
fig, ax = plt.subplots()
ax.plot(income_fracs, choice_fracs, '.', label='Blocks')
ax.plot([0, 1], [0, 1], label='Choice = income')
ax.set_xlabel('Income fraction')
ax.set_ylabel('Choice fraction')
ax.set_aspect('equal')
ax.legend();

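## Simulate a world that alternates between two states

Blocks alternate between reward rates of [0.9, 0.1] and [0.1, 0.9] on the two sides, and an `InferenceBasedAgent` has to track the switches.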

In [None]:
# Build a world whose blocks alternate between two reward-rate states:
# [0.9, 0.1] on even blocks and [0.1, 0.9] on odd blocks.
prew = 0.8
pswitch = 0.05
nblocks = 20
ntrialsperblock = 50
np.random.seed(0)  # reproducible simulation

rates = np.array([[0.1, 0.9] if k % 2 else [0.9, 0.1] for k in range(nblocks)])
# Every block lasts exactly ntrialsperblock trials (this constant was
# previously defined but ignored in favour of random 200-299 trial blocks).
ntrials = np.full(nblocks, ntrialsperblock, dtype='int')
world = PersistentWorld(rates=rates, ntrials=ntrials)

agent = InferenceBasedAgent(prew=prew, pswitch=pswitch, type='random')

exp = Experiment(agent, world)
res = exp.run()

In [None]:
# First 500 entries of the agent's p0_history.
fig, ax = plt.subplots()
ax.plot(agent.p0_history[:500])
ax.set_xlabel('Trial')
ax.set_ylabel('Agent p0');

In [None]:
# First 500 entries of the world's side_history.
# NOTE(review): presumably the currently richer side per trial — confirm in
# worldModels.
side_history = np.array(world.side_history)
fig, ax = plt.subplots()
ax.plot(side_history[:500])
ax.set_xlabel('Trial')
ax.set_ylabel('World side_history');