In [None]:
import numpy as np
import lqg1d
import matplotlib.pyplot as plt
from utils import collect_episodes
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from utils import collect_episodes, estimate_performance
from tqdm import tqdm

In [None]:
# Experiment constants: `discount` is passed as gamma to the evaluation helper
# and `horizon` caps the length of the collected episodes.
discount = 0.9
horizon = 50

# LQG1D environment; initial_state_type is forwarded verbatim to its constructor.
env = lqg1d.LQG1D(initial_state_type='random')

# Finite action set: 20 evenly spaced values covering [-8, 8].
discrete_actions = np.linspace(-8, 8, num=20)
actions = discrete_actions

In [None]:
#################################################################
# Show the optimal Q-function
#################################################################
def make_grid(x, y):
    """Return every (x_i, y_j) combination as the rows of a 2-column array.

    The result has ``len(x) * len(y)`` rows. Rows are ordered with ``x``
    varying slowest and ``y`` varying fastest ('ij' indexing, C-order
    flattening). Column 0 holds the ``x`` values, column 1 the ``y`` values.
    """
    grid_x, grid_y = np.meshgrid(x, y, copy=False, indexing='ij')
    # Flatten each coordinate grid and pair them up column-wise.
    return np.stack((grid_x.ravel(), grid_y.ravel())).T

In [None]:
# State grid: 20 evenly spaced values covering [-10, 10].
states = discrete_states = np.linspace(-10, 10, 20)
# One row per (state, action) pair; column 0 is the state, column 1 the action.
SA = make_grid(states, actions)
S, A = SA[:, 0], SA[:, 1]

# Gain returned by the environment for this discount, plus a fixed covariance value.
K, cov = env.computeOptimalK(discount), 0.001
print('Optimal K: {} Covariance S: {}'.format(K, cov))


def _q_opt_scalar(s, a):
    """Evaluate the environment's Q-function for a single (state, action) pair.

    NOTE(review): the meaning of the trailing ``1`` is defined by
    ``lqg1d.computeQFunction`` and is not visible here -- confirm.
    """
    return env.computeQFunction(s, a, K, cov, discount, 1)


# Element-wise wrapper so the scalar evaluation can be applied to whole arrays.
Q_fun_ = np.vectorize(_q_opt_scalar)


def Q_fun(X):
    """Evaluate the Q-function on an (n, 2) array of (state, action) rows."""
    return Q_fun_(X[:, 0], X[:, 1])


Q_opt = Q_fun(SA)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(S, A, Q_opt)
# Label the axes: the state and action grids have similar ranges, so an
# unlabelled plot is ambiguous.
ax.set_xlabel('state')
ax.set_ylabel('action')
ax.set_zlabel('Q(state, action)')
ax.set_title('Optimal Q-function on the state-action grid')

plt.show()

In [None]:
class BehPolicy:
    """Behaviour policy that picks uniformly at random from a fixed action set."""

    def __init__(self, actions):
        # Candidate actions; only len() and integer indexing are used.
        self.actions = actions

    def draw_action(self, state):
        """Return a uniformly sampled element of ``self.actions``.

        ``state`` is accepted for interface compatibility and is ignored.
        """
        n_actions = len(self.actions)
        choice = np.random.randint(0, n_actions)
        return self.actions[choice]

In [None]:
class FQIPolicy:
    """Greedy policy with respect to a Q-function over a finite action set.

    Parameters
    ----------
    Q : callable
        ``Q(state, action)`` returning a scalar value estimate.
    actions : sequence
        Candidate actions; must support iteration and integer indexing.
    """

    def __init__(self, Q, actions):
        self.Q = Q
        self.actions = actions

    def draw_action(self, state):
        """Return the element of ``self.actions`` maximising ``self.Q(state, action)``.

        Ties go to the first maximiser (``np.argmax`` semantics).
        """
        # Use the Q-function given to the constructor, not a module-level
        # name, so the policy works with any injected callable.
        values = [self.Q(state, a) for a in self.actions]
        return self.actions[np.argmax(values)]

In [None]:
#################################################################
# Collect the samples using the behavioural policy
#################################################################
# You should use discrete actions
# Ridge regularisation coefficient used when fitting theta.
lmbda = 0.1

# Weights of the linear Q-model, one per feature; start at zero.
theta = np.zeros((3,))


def phi(s, a):
    """Feature vector of a (state, action) pair: [a, s*a, s^2 + a^2]."""
    return np.array([a, s*a, s**2+a**2])


def Q(s, a):
    """Linear Q-value ``phi(s, a) . theta`` using the current module-level ``theta``.

    Both arguments are converted with ``float`` before the features are built.
    """
    features = phi(float(s), float(a))
    return features.dot(theta.T)


In [None]:
# Fitted Q-Iteration (FQI) with the linear model Q(s, a) = phi(s, a) . theta.
# theta is refitted by ridge regression on a fixed batch of transitions.
# The greedy policy is evaluated with estimate_performance after each iteration.
fqi = FQIPolicy(Q, actions)
beh_policy = BehPolicy(actions)

n_itr = 100

# Fixed batch gathered once with the uniformly random behaviour policy.
dataset = collect_episodes(env, n_episodes=100, policy=beh_policy, horizon=horizon)

# Flatten the episodes into (state, action, reward, next_state) tuples once.
transitions = [
    (episode["states"][i], episode["actions"][i],
     episode["rewards"][i], episode["next_states"][i])
    for episode in dataset
    for i in range(len(episode["actions"]))
]

# The batch never changes, so the design matrix and the regularised Gram
# matrix are loop invariants: build them a single time.
Z = np.array([phi(s, a).transpose() for s, a, _, _ in transitions])
ridge_gram = Z.T.dot(Z) + lmbda * np.eye(Z.shape[1])


def fqi_step():
    """Run one FQI backup and return ``(targets, new_theta)``.

    Targets are ``r + discount * max_a Q(s', a)`` evaluated with the current
    module-level ``theta``. The new weights solve the ridge normal equations
    ``(Z^T Z + lmbda I) theta = Z^T targets``.
    """
    targets = np.array([
        r + discount * np.max([Q(s_next, a) for a in actions])
        for _, _, r, s_next in transitions
    ])
    # Solve the linear system directly instead of forming an explicit inverse.
    return targets, np.linalg.solve(ridge_gram, Z.T.dot(targets))


# Always start from zero weights so re-running this cell gives the same
# procedure as a fresh run. With theta = 0 the first targets are just the
# immediate rewards.
theta = np.zeros(Z.shape[1])
y, theta = fqi_step()

J_t = []
for _ in tqdm(range(n_itr), desc="Simulating"):
    y, theta = fqi_step()
    J_t.append(estimate_performance(env, policy=fqi, horizon=10, n_episodes=50, gamma=discount))
    print(theta)
    

In [None]:
# Learning curve: one performance estimate of the FQI policy per iteration.
# Separate figure/axes names keep any earlier `fig` / `ax` intact.
fig_perf, ax_perf = plt.subplots()
ax_perf.plot(J_t)
ax_perf.set_xlabel('FQI iteration')
ax_perf.set_ylabel('Estimated performance')
ax_perf.set_title('FQI policy performance per iteration')
plt.show()

In [None]:
# Final evaluation of the FQI policy with the same settings used during training.
# TODO: plot the obtained Q-function against the true one (not done in this cell).
J = estimate_performance(
    env,
    policy=fqi,
    horizon=10,
    n_episodes=50,
    gamma=discount,
)
message = 'Policy performance: {}'.format(J)
print(message)
