# Q-learning with linear function approximation

In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

## Reminders

Standard Q-learning update:
$$Q(s_t, a_t)\leftarrow (1-\alpha)\,Q(s_t, a_t) + \alpha\left[R(s_t,a_t,s_{t+1})+\gamma\,\max_{a'} Q(s_{t+1}, a')\right]$$

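Linear approximation of $Q$ (with $\phi_0(s,a) = 1$ acting as a bias feature):
$$Q(s,a)=\theta_0 + \theta_1 \phi_1(s,a) + \dots + \theta_n \phi_n(s,a) = \theta^T \phi(s,a)$$
In the code below, $\phi(s)$ is a vector of Gaussian RBF features of the state (no bias feature is used there), and each action $a$ has its own weight vector $\theta_a$, so $Q(s,a) = \theta_a^T \phi(s)$.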

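Loss: squared TD error. The target $y_t = R(s_t,a_t,s_{t+1}) + \gamma\,\max_{a'} Q(s_{t+1}, a')$ is treated as a constant (semi-gradient), which gives:
$$L(\theta) = \frac{1}{2}\left[y_t - \theta^T \phi(s_t, a_t)\right]^2, \qquad \nabla L(\theta) = -\left[y_t - \theta^T \phi(s_t, a_t)\right]\phi(s_t, a_t)$$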

Parameter update (gradient descent with step size $\alpha$):
$$\theta^{(i+1)}\leftarrow \theta^{(i)} - \alpha \nabla L(\theta^{(i)})$$

In [40]:
class Agent:
    """Epsilon-greedy Q-learning agent with a linear Q-function over RBF state features.

    Q(s, a) is approximated as ``weights[a] @ encode_state(s)``: one weight vector
    per discrete action, over a shared vector of Gaussian RBF features of the state.

    Args:
        action_dim: number of discrete actions.
        state_dim: dimension of the raw state vector.
        n_features: number of RBF centers, i.e. length of the feature vector.
        discount: discount factor gamma used in the TD target.
        rbf_param: RBF width sigma.
        epsilon: probability of taking a uniformly random action.
    """

    def __init__(self, action_dim, state_dim, n_features, discount=0.99, rbf_param=0.1, epsilon=0.1):
        self.n_features = n_features
        self.rbf_param = rbf_param
        # Centers drawn uniformly in [-1, 1]^state_dim.
        # NOTE(review): CartPole velocities can fall outside this box; consider
        # rescaling states or spreading the centers wider.
        self.rbf_centers = np.random.uniform(low=-1., high=1., size=(n_features, state_dim))
        self.weights = np.zeros((action_dim, n_features))
        self.epsilon = epsilon
        self.discount = discount
        self.possible_actions = np.arange(action_dim)

    def norm2(self, z):
        """Squared Euclidean norm of ``z`` (sum of squared entries)."""
        return np.sum(z**2)

    def rbf(self, state, center, sigma):
        """Gaussian RBF kernel: exp(-||state - center||^2 / (2 * sigma^2))."""
        # The division by 2*sigma^2 belongs inside the exponent; dividing the
        # result of np.exp instead ignores sigma as a kernel width.
        return np.exp(-self.norm2(state - center) / (2 * sigma**2))

    def encode_state(self, state):
        """Converts a state of shape (state_dim,) into a feature vector of shape (n_features,).

        Feature i is ``rbf(state, rbf_centers[i], rbf_param)``.
        """
        # Use instance attributes/methods, not notebook globals (which are not
        # defined on a fresh kernel and caused a NameError).
        state = np.asarray(state, dtype=float)
        return np.array([self.rbf(state, center, self.rbf_param) for center in self.rbf_centers])

    def choose_action(self, state):
        """Epsilon-greedy action selection.

        With probability ``1 - epsilon``, returns argmax_a Q(state, a) (ties go to
        the lowest index). Otherwise, returns a uniformly random action from
        ``possible_actions``.
        """
        if np.random.random() > self.epsilon:
            # weights: (action_dim, n_features) @ features: (n_features,) -> (action_dim,)
            q_values = self.weights @ self.encode_state(state)
            return int(np.argmax(q_values))
        return int(np.random.choice(self.possible_actions))


## Training loop

In [None]:
SEED = 0
np.random.seed(SEED)  # Agent's RBF centers and epsilon-greedy draws use np.random

env = gym.make('CartPole-v1')
agent = Agent(2, 4, 15)
n_episodes = 300
learning_rate = 0.01  # step size alpha for the gradient update on theta
episode_returns = []

for ep in tqdm(range(n_episodes)):
    # Seed the environment once; later resets continue its RNG stream.
    state, _ = env.reset(seed=SEED if ep == 0 else None)
    done = False
    episode_return = 0.0

    while not done:
        action = agent.choose_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)

        encoded_state = agent.encode_state(state)
        encoded_next_state = agent.encode_state(next_state)

        # Bootstrap from max_a' Q(s', a') unless s' is terminal. A time-limit
        # truncation is not a true terminal state, so it is still bootstrapped.
        next_q_value = 0.0 if terminated else np.max(agent.weights @ encoded_next_state)
        td_target = reward + agent.discount * next_q_value
        td_error = td_target - agent.weights[action] @ encoded_state
        # Semi-gradient step: dQ(s, a)/d(theta_a) = phi(s); other actions' rows are unaffected.
        agent.weights[action] += learning_rate * td_error * encoded_state

        state = next_state
        episode_return += reward
        done = terminated or truncated

    episode_returns.append(episode_return)

env.close()

In [36]:
# Scratch check: outer product of two length-4 ones vectors -> 4x4 matrix of ones
np.outer(np.ones(4), np.ones(4))

array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])