In [None]:

@jit
def example_reward_function(state, goal_state):
    """Reward for being at ``state`` when the target is ``goal_state``.

    Returns 15.0 when every coordinate of ``state`` equals ``goal_state`` exactly;
    otherwise returns ``||state - goal_state|| / sqrt(2) * 3``.

    NOTE(review): the non-goal reward *grows* with distance from the goal, i.e. it
    rewards moving away from it — confirm this is intended. The ``sqrt(2)``
    normaliser presumably assumes a unit-square state space; confirm.
    """
    displacement = state - goal_state
    at_goal = jnp.all(state == goal_state)
    # Euclidean distance, normalised by sqrt(2), then scaled by 3.
    shaped = jnp.linalg.norm(displacement) / jnp.sqrt(2) * 3
    # Branch-free selection so the function stays traceable under jit.
    return jnp.where(at_goal, 15.0, shaped)

def example_transition_function(state, action, state_space_shape):
    """Toy stochastic transition driven by a seed derived from ``state`` and ``action``.

    The action selects one combination of the coordinates ``(x, y)``:
    0 -> ``x + y``, 1 -> ``|x - y|``, 2 -> ``x / y``, 3 -> ``y / x``
    (``jax.lax.switch`` clamps out-of-range action indices). Call that value ``v``.
    ``v`` truncated to int32 seeds a PRNG, and the next state is
    ``uniform(-1, 1, size=2) * v``.

    Args:
        state: length-2 array-like ``(x, y)``; cast to float32.
        action: integer branch index.
        state_space_shape: currently unused; kept for interface compatibility.

    Returns:
        float32 array of shape (2,).
    """
    # Cast up front so every switch branch returns the same dtype
    # (x / y is always floating point, so int states previously broke lax.switch).
    x, y = jnp.asarray(state, dtype=jnp.float32)

    branches = [
        lambda _: x + y,
        lambda _: jnp.abs(x - y),
        # NOTE(review): x / y and y / x are inf/nan when the denominator is 0, which
        # makes the int32 seed below meaningless — decide how that case should behave.
        lambda _: x / y,
        lambda _: y / x,
    ]
    seed_value = jax.lax.switch(action, branches, None)

    # PRNGKey only accepts integer seeds; the selected value is floating point.
    key = random.PRNGKey(seed_value.astype(jnp.int32))
    noise = random.uniform(key, (1, 2), minval=-1, maxval=1)[0]
    return noise * seed_value


In [1]:
import os
import numpy as np
from collections import deque
import flax
from flax import linen as nn
from flax.training import train_state

import jax
from jax import jit, random, vmap, numpy as jnp

import optax

import random as rd


In [None]:
#print(f"episode: {episode} state: {state}, reward: {reward}, action: none done: {done}")
#episode_reward = 0  # Track total reward per episode


In [None]:
# Intentional stop: raising here halts "Run All" so the experiment cells below only run on demand.
raise RuntimeError("Intentional stop: run the cells below manually.")

In [None]:
minibatch = rd.sample(agent.memory, batch_size)

# Unpack the sampled experiences field by field, in the order
# (state, action, reward, next_state, done), each with its own dtype.
states, actions, rewards, next_states, dones = (
    jnp.array([experience[field] for experience in minibatch], dtype=field_dtype)
    for field, field_dtype in enumerate(
        (jnp.float32, jnp.int32, jnp.float32, jnp.float32, jnp.bool_)
    )
)


In [None]:
# Bellman targets for Q-learning, JIT-compiled.
@jax.jit
def compute_target_q_values(rewards, gamma, futures, dones):
    """Return ``rewards + gamma * futures`` with the future term zeroed where ``dones`` is set.

    ``dones`` may be boolean: ``1 - dones`` promotes it to an integer mask.
    """
    not_done = 1 - dones
    discounted_future = gamma * futures * not_done
    return rewards + discounted_future

# Bootstrap value: max over the last axis (presumably actions) of Q(next_state).
# NOTE(review): this uses the same params being trained; no separate target network is visible here.
gamma = agent.gamma
futures = agent.model.apply(agent.state.params, next_states).max(axis=-1)
target_q_values = compute_target_q_values(rewards, gamma, futures, dones)
type(target_q_values)

In [None]:
from jax import vmap


In [None]:

def compute_loss(states, actions, targets, model, params):
    """
    Per-sample squared error between the Q-value of the taken action and its target.

    :param states: batch of states; ``model.apply`` is vmapped over the leading axis
    :param actions: integer array with one action index per sample
    :param targets: target Q-values, one per sample
    :param model: object exposing ``apply(params, state)`` (e.g. a Flax module)
    :param params: parameters passed unchanged to every ``model.apply`` call
    :return: array of squared errors
    """
    batched_q = vmap(lambda single_state: model.apply(params, single_state))(states)

    # Row i keeps only column actions[i]; squeeze() then drops every size-1 axis
    # (NOTE: with a batch of one this also removes the batch axis).
    taken_q = jnp.take_along_axis(batched_q, actions[:, None], axis=1).squeeze()

    return (taken_q - targets) ** 2


In [None]:
# Cross-check compute_loss against the per-sample loss_fn_batch.
# NOTE(review): loss_fn_batch is defined in a later cell; the "Approach 3" version
# (states, actions, targets, model, params) must be in scope. The "Approach 2"
# redefinition takes params first and returns a scalar, so it does not match this call.
L = compute_loss(states, actions, target_q_values, agent.model, agent.state.params)
G = loss_fn_batch(states, actions, target_q_values, agent.model, agent.state.params)
# Two different computation paths: compare with a tolerance, not exact float equality.
jnp.allclose(L, G)

In [None]:
# Approach 3: compute a gradient per sample, average them, then apply one update.
# Per-sample gradients come from vmap(grad(single-sample loss)); their mean is then applied once.
def loss_fn_batch(states, actions, targets, model, params):
    """Per-sample squared TD errors for a batch.

    Args:
        states: batch of states (leading axis = batch).
        actions: integer action index per sample.
        targets: target Q-value per sample.
        model: object exposing ``apply(params, state)`` for a single state.
        params: model parameters, shared across the batch.

    Returns:
        Array with ``(Q(state)[action] - target) ** 2`` for each sample (not reduced).
    """
    def single_loss_fn(state, action, target):
        q_values = model.apply(params, state)
        q_value = q_values[action]
        return (q_value - target) ** 2

    vectorized_loss_fn = jax.vmap(single_loss_fn, in_axes=(0, 0, 0))
    return vectorized_loss_fn(states, actions, targets)


def update_step_average(states, actions, targets, model, train_state):
    """Apply one optimizer step using the mean of per-sample gradients.

    Args:
        states, actions, targets: batched inputs, leading axis = batch.
        model: object exposing ``apply(params, state)`` for a single state.
        train_state: object exposing ``params`` and ``apply_gradients(grads=...)``
            (e.g. a flax TrainState). This parameter name shadows the imported
            ``flax.training.train_state`` module inside the function.

    Returns:
        The updated state returned by ``train_state.apply_gradients``.
    """
    def single_loss(params, state, action, target):
        q_values = model.apply(params, state)
        return (q_values[action] - target) ** 2

    # jax.grad requires a scalar-output function, so differentiate the single-sample
    # loss and vmap it over the batch while sharing params (in_axes=None).
    per_sample_grads = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0, 0))(
        train_state.params, states, actions, targets
    )
    mean_grads = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), per_sample_grads)
    return train_state.apply_gradients(grads=mean_grads)

# Example usage (NOTE(review): assumes agent.state is a TrainState-like object):
# new_state = update_step_average(states, actions, target_q_values, agent.model, agent.state)

In [None]:
def loss_fn(params, state, action, target, model=None):
    """Squared TD error for a single transition.

    Args:
        params: model parameters (the argument ``jax.grad`` differentiates).
        state: one state, passed to ``model.apply``.
        action: integer index into the model's Q-value output.
        target: target Q-value for ``action``.
        model: object exposing ``apply(params, state)``; defaults to the notebook
            global ``agent.model``, resolved at call time.

    Returns:
        Mean of ``(target - Q(state)[action]) ** 2`` (a scalar for scalar targets).
    """
    if model is None:
        model = agent.model
    q_values = model.apply(params, state)
    q_value = q_values[action]
    return jnp.mean((target - q_value) ** 2)
grad_fn = jax.grad(loss_fn)
# Share one copy of params (None) and map over the batch axis of state/action/target.
vmap_grad_fn = jax.vmap(grad_fn, in_axes=(None, 0, 0, 0))

# Gradient for a single transition (index 8 of the minibatch), for spot-checking.
sgrads = grad_fn(agent.state.params, states[8], actions[8], target_q_values[8])
# Per-transition gradients for the whole minibatch, then their mean over the batch axis.
grads = vmap_grad_fn(agent.state.params, states, actions, target_q_values)
average_grads = jax.tree_util.tree_map(lambda leaf: leaf.mean(axis=0), grads)



In [None]:
# Display the gradient w.r.t. params for the single transition computed as `sgrads`.
sgrads

In [None]:
# Approach 2: differentiate the mean batch loss once and apply that single gradient.
# (By linearity this equals the average of the per-sample gradients.)
# NOTE(review): this redefines `loss_fn_batch` from the Approach 3 cell with a different
# signature (params first) and a scalar return; whichever cell ran last is in scope.
def loss_fn_batch(params, states, actions, targets, model):
    """Mean squared TD error over a batch.

    Args:
        params: model parameters (the argument ``jax.grad`` differentiates).
        states: batch of states (leading axis = batch).
        actions: integer action index per sample.
        targets: target Q-value per sample.
        model: object exposing ``apply(params, state)`` for a single state.

    Returns:
        Scalar mean of ``(Q(state)[action] - target) ** 2`` over the batch.
    """
    def single_loss_fn(p, state, action, target):
        q_values = model.apply(p, state)
        q_value = q_values[action]
        return (q_value - target) ** 2

    # in_axes lists four entries, so params is passed explicitly and left unmapped (None);
    # previously only three arguments were passed, which made vmap raise.
    per_sample = jax.vmap(single_loss_fn, in_axes=(None, 0, 0, 0))(
        params, states, actions, targets
    )
    return jnp.mean(per_sample)


def update_step_batch(states, actions, targets, model, train_state):
    """Apply one optimizer step using the gradient of the mean batch loss.

    ``train_state`` must expose ``params`` and ``apply_gradients(grads=...)``
    (e.g. a flax TrainState); the name shadows the imported flax module locally.
    """
    grads = jax.grad(loss_fn_batch)(train_state.params, states, actions, targets, model)
    train_state = train_state.apply_gradients(grads=grads)
    return train_state

# Example usage. Uses the names this notebook defines (target_q_values, agent.model) and
# binds the result to a new name so the imported flax `train_state` module is not shadowed.
# NOTE(review): assumes agent.state is a TrainState-like object (it exposes .params) — confirm.
batch_train_state = update_step_batch(states, actions, target_q_values, agent.model, agent.state)

In [None]:
# Approach 1: Get gradient, apply it, get the next one, apply it, etc.
# In this approach, we compute and apply the gradient for each sample sequentially.

import jax
from flax.training import train_state

def loss_fn(params, state, action, target, model):
    """Squared TD error for one transition: ``(Q(state)[action] - target) ** 2``."""
    predicted = model.apply(params, state)[action]
    return (predicted - target) ** 2


def update_step(state, action, target, model, train_state):
    """Take one optimizer step on a single transition.

    ``train_state`` must expose ``params`` and ``apply_gradients(grads=...)``;
    the parameter name shadows the imported flax ``train_state`` module locally.
    """
    loss_grad = jax.grad(loss_fn)
    return train_state.apply_gradients(
        grads=loss_grad(train_state.params, state, action, target, model)
    )

# Example usage: one update per transition, each starting from the previous result.
# Uses the notebook's own names and a fresh variable so the flax `train_state` module is not shadowed.
# NOTE(review): assumes agent.state is a TrainState-like object (it exposes .params) — confirm.
sequential_train_state = agent.state
for state, action, target in zip(states, actions, target_q_values):
    sequential_train_state = update_step(state, action, target, agent.model, sequential_train_state)