In [1]:
"""Implementations of algorithms for continuous control."""
import functools
from jaxrl_m.typing import *

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import DeterministicPolicy,Policy, Critic, ensemblize
from jaxrl_m.rollout import rollout_policy,PolicyRollout
import flax
import flax.linen as nn

from functools import partial



class SACAgent(flax.struct.PyTreeNode):
    """Actor-critic agent holding one actor and a stacked ensemble of critics.

    ``critic`` and ``target_critic`` are ``TrainState`` objects whose leaves
    carry a leading ensemble axis: every method below maps over that axis with
    ``jax.vmap(..., in_axes=0)``.  The actor loss is simply the negated mean
    ensemble Q-value; no entropy term is used anywhere in this class.

    Attributes are documented inline because the class is a flax PyTreeNode.
    """
    # PRNG key consumed (and replaced) by the update methods.
    rng: PRNGKey
    # Online critics, stacked along axis 0.
    critic: TrainState
    # Target critics, stacked along axis 0.
    target_critic: TrainState
    # Policy network state.
    actor: TrainState
    # Static hyper-parameters; reads 'discount' and 'target_update_rate'.
    config: dict = nonpytree_field()

    @jax.jit
    def reset_critic_optimizer(agent):
        """Return a copy of the agent with freshly initialised critic optimizer state.

        The agent is immutable, so callers must use the returned value.

        ``tx.init`` is vmapped over the ensemble axis so that every leaf of the
        new optimizer state (including scalar bookkeeping such as a step
        counter) keeps the leading ensemble axis.  Calling ``tx.init`` directly
        on the stacked parameters would produce unbatched scalar leaves, which
        ``update_critics`` cannot map over with ``in_axes=0``.
        """
        new_opt_state = jax.vmap(agent.critic.tx.init)(agent.critic.params)
        new_critic = agent.critic.replace(opt_state=new_opt_state)

        return agent.replace(critic=new_critic)

    @partial(jax.jit, static_argnames=('num_steps',))
    def update_critics(agent, transitions: Batch, idxs: jnp.array, num_steps: int):
        """Run a sequence of TD updates on every critic of the ensemble.

        Args:
            transitions: pytree of arrays indexed along axis 0 by ``idxs``;
                must provide 'observations', 'actions', 'rewards', 'masks'
                and 'next_observations'.
            idxs: index array; ``jax.lax.scan`` iterates over its leading
                axis, so each row selects the minibatch of one gradient step.
            num_steps: static argument, currently only forwarded (the number
                of steps actually executed is ``idxs.shape[0]``).

        Returns:
            ``(agent, info)`` where ``agent`` carries the updated critics and
            target critics and ``info`` is an empty dict.  ``agent.rng`` is
            left unchanged by this method.
        """

        def update_critic(agent, critic, target_critic, transitions: Batch, idxs: jnp.array, num_steps: int):
            # Trains a single ensemble member; vmapped over the ensemble below.

            def one_update(previous_result, new_element):
                agent, critic, target_critic = previous_result
                batch_idxs = new_element

                new_rng, curr_key, next_key = jax.random.split(agent.rng, 3)

                target_critic_params = target_critic.params
                batch = jax.tree_util.tree_map(lambda x: x[batch_idxs], transitions)

                def critic_loss_fn(critic_params):
                    next_actions = agent.actor(batch['next_observations'])

                    # The third positional argument is forwarded to the critic
                    # module; presumably a training/dropout flag - TODO confirm.
                    q = critic(batch['observations'], batch['actions'], True,
                               params=critic_params["params"],
                               rngs={'dropout': curr_key})

                    next_q = target_critic(batch['next_observations'], next_actions, True,
                                           params=target_critic_params["params"],
                                           rngs={'dropout': next_key})
                    target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_q

                    critic_loss = ((target_q - q) ** 2).mean()

                    return critic_loss, {'critic_loss': critic_loss}

                new_critic, _ = critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True)
                # NOTE(review): the target is moved towards the critic from
                # *before* this gradient step (``critic``, not ``new_critic``).
                # Kept as written - confirm this is intended.
                new_target_critic = target_update(critic, target_critic, agent.config['target_update_rate'])
                new_agent = agent.replace(rng=new_rng)

                return (new_agent, new_critic, new_target_critic), ()

            (agent, critic, target_critic), _ = jax.lax.scan(
                one_update, (agent, critic, target_critic), idxs)

            return critic, target_critic

        # The agent, data and indices are shared (in_axes=None); only the
        # critic and target-critic states are mapped over the ensemble axis.
        critics, target_critics = jax.vmap(update_critic, in_axes=(None, 0, 0, None, None, None))(
            agent, agent.critic, agent.target_critic, transitions, idxs, num_steps)

        agent = agent.replace(critic=critics, target_critic=target_critics)

        return agent, {}

    @jax.jit
    def update_actor(agent, batch: Batch):
        """Take one gradient step on the actor to maximise the mean ensemble Q.

        Args:
            batch: pytree providing 'observations'.

        Returns:
            ``(agent, info)`` with the updated actor, a fresh ``rng`` and
            ``info['actor_loss']``.
        """
        # Only the first sub-key is used; the three-way split is kept so the
        # key stream is unchanged.
        new_rng, _, _ = jax.random.split(agent.rng, 3)

        def actor_loss_fn(actor_params):
            actions = agent.actor(batch['observations'], params=actor_params)

            def f(observations, actions, params):
                return agent.critic(observations, actions, params=params["params"])

            # Evaluate every critic of the ensemble on the same inputs.
            qs = jax.vmap(f, in_axes=(None, None, 0))(batch['observations'], actions, agent.critic.params)

            q = qs.mean(axis=0)
            actor_loss = (-q).mean()

            return actor_loss, {
                'actor_loss': actor_loss,
            }

        new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True)

        return agent.replace(rng=new_rng, actor=new_actor,), {**actor_info}

    @jax.jit
    def sample_actions(agent, observations: np.ndarray) -> jnp.ndarray:
        """Return the actor's output for ``observations`` using its current parameters."""
        actions = agent.actor(observations)

        return actions
 

def create_learner(
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 3e-4,
                 critic_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 discount: float = 0.99,
                 tau: float = 0.005,
            **kwargs):
    """Build a ``SACAgent`` with one actor and a stacked ensemble of 5 critics.

    Args:
        seed: seed for the root PRNG key.
        observations: example observations used to initialise the networks.
        actions: example actions; the last axis gives the action dimension.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critics.
        hidden_dims: hidden layer sizes of the critics (the actor uses a
            fixed ``(64, 64)``).
        discount: stored in the agent config under 'discount'.
        tau: stored in the agent config under 'target_update_rate'.
        **kwargs: ignored apart from being printed.

    Returns:
        The initialised ``SACAgent``.
    """
    print('Extra kwargs:', kwargs)

    root_key = jax.random.PRNGKey(seed)
    rng, actor_key, critic_key = jax.random.split(root_key, 3)

    # --- actor -----------------------------------------------------------
    # Only the 'params' collection is kept for the actor state.
    policy_def = DeterministicPolicy(
        (64,64), action_dim=actions.shape[-1], final_fc_init_scale=1.0)
    policy_params = policy_def.init(actor_key, observations)['params']
    actor = TrainState.create(
        policy_def, policy_params, tx=optax.adam(learning_rate=actor_lr))

    # --- critic ensemble -------------------------------------------------
    # The critic states keep the full variables dict returned by ``init``
    # (no ['params'] indexing here), stacked along a leading axis of size 5.
    critic_def = Critic(hidden_dims)
    ensemble_keys = jax.random.split(critic_key, 5)
    init_ensemble = jax.vmap(critic_def.init, in_axes=(0, None, None))
    ensemble_params = init_ensemble(ensemble_keys, observations, actions)

    make_critic = partial(
        TrainState.create, critic_def, tx=optax.adam(learning_rate=critic_lr))
    critics = jax.vmap(make_critic)(ensemble_params)

    # Target critics start from the same parameters and carry no optimizer.
    make_target = jax.vmap(TrainState.create, in_axes=(None, 0))
    target_critics = make_target(critic_def, ensemble_params)

    config = flax.core.FrozenDict({
        'discount': discount,
        'target_update_rate': tau,
    })

    return SACAgent(
        rng,
        critic=critics,
        target_critic=target_critics,
        actor=actor,
        config=config,
    )


In [2]:




def f(anc_agent,obs,actor):
    """Mean target-ensemble Q-value of the actions an actor picks at ``obs``.

    Args:
        anc_agent: agent providing ``actor``, ``critic`` and ``target_critic``.
        obs: observations fed to both the actor and the critics.
        actor: actor parameters used to choose the actions.

    Returns:
        The critic output averaged over the ensemble axis (axis 0).
    """
    actions = anc_agent.actor(obs, params=actor)

    def q_single(critic_params):
        # One ensemble member; the stored variables dict is unwrapped with
        # ["params"], exactly as SACAgent.update_actor does.
        return anc_agent.critic(obs, actions, params=critic_params["params"])

    # The target parameters are stacked along axis 0, so the critic has to be
    # mapped over that axis rather than called once on the stacked pytree.
    qs = jax.vmap(q_single)(anc_agent.target_critic.params)

    return qs.mean(axis=0)
    

@jax.jit
def estimate_return(anc_agent,anc_return,acq_rollout:PolicyRollout,):
    """Predict the return of a rollout's policy from the anchor agent's critics.

    The prediction is ``anc_return`` plus a masked sum, over the rollout's
    observations, of the Q-value gap between the rollout policy's actions and
    the anchor actor's actions.

    Returns:
        ``(predicted_return, acq_rollout.policy_return)``.
    """
    observations = acq_rollout.observations

    # Same states, two different actors: the rollout's policy vs. the anchor.
    q_rollout_policy = f(anc_agent, observations, acq_rollout.policy_params)
    q_anchor_policy = f(anc_agent, observations, anc_agent.actor.params)

    masked_gap = (q_rollout_policy - q_anchor_policy) * acq_rollout.disc_masks
    # NOTE(review): the divisor 5 is a hard-coded constant; its meaning is not
    # evident from this file - confirm what it normalises.
    advantage = masked_gap.sum()/5

    return anc_return + advantage, acq_rollout.policy_return


def evaluate_critic(anc_agent,anc_return,policy_rollouts):
    """Compute an R^2-style score of predicted vs. observed rollout returns.

    Args:
        anc_agent: agent whose critics produce the predictions.
        anc_return: return of the anchor policy, forwarded to
            ``estimate_return``.
        policy_rollouts: iterable of rollouts to score.

    Returns:
        ``1 - a2 / b2`` where ``a2`` is the sum of squared residuals (each
        squared residual floored at 1e-4) and ``b2`` the total sum of squares
        of the observed returns.  ``b2`` is zero when all observed returns are
        equal, in which case the result is not finite.
    """
    estimates = [
        estimate_return(anc_agent, anc_return, policy_rollout)
        for policy_rollout in policy_rollouts
    ]
    y_pred = np.array([predicted for predicted, _ in estimates])
    y = np.array([observed for _, observed in estimates])

    # The lower bound is passed positionally: the ``a_min`` keyword is
    # deprecated in recent JAX releases, while the positional form works on
    # both old and new versions.
    a2 = jnp.clip((y - y_pred) ** 2, 1e-4).sum()
    b2 = ((y - y.mean()) ** 2).sum()
    R2 = 1 - (a2 / b2)

    return R2


In [3]:
import os
from functools import partial
import numpy as np
import jax
import tqdm
import gymnasium as gym


from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
from jaxrl_m.dataset import ReplayBuffer
from collections import deque
from jax import config
#config.update("jax_disable_jit", True)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
env_name = 'InvertedPendulum-v4'
seed = np.random.choice(1000000)
eval_episodes = 10
batch_size = 256
max_steps = int(1e6)
start_steps = 5000
log_interval = 5000
# Number of critic gradient steps performed after every rollout.
critic_updates_per_rollout = 5000

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'd4rl_test',
    'group': 'sac_test',
    # Left as a plain string on purpose: presumably a template that the wandb
    # helper formats later - TODO confirm.  Note that wandb_config is not
    # passed to setup_wandb below.
    'name': 'sac_{env_name}',
})


env = EpisodeMonitor(gym.make(env_name))
eval_env = EpisodeMonitor(gym.make(env_name))
setup_wandb({"bonjour":1})

example_transition = dict(
    observations=env.observation_space.sample(),
    actions=env.action_space.sample(),
    rewards=0.0,
    masks=1.0,
    next_observations=env.observation_space.sample(),
    discounts=1.0,
)

replay_buffer = ReplayBuffer.create(example_transition, size=int(1e6))
actor_buffer = ReplayBuffer.create(example_transition, size=int(5e3))

agent = create_learner(seed,
                example_transition['observations'][None],
                example_transition['actions'][None],
                max_steps=max_steps,
                )

exploration_metrics = dict()
obs, info = env.reset()
exploration_rng = jax.random.PRNGKey(0)
# Dedicated key stream for minibatch sampling.  It is split on every use so
# that no key is consumed twice (previously PRNGKey(0) was reused for every
# draw).
batch_rng = jax.random.fold_in(jax.random.PRNGKey(seed), 1)
i = 0
unlogged_steps = 0
policy_rollouts = deque([], maxlen=10)
with tqdm.tqdm(total=max_steps) as pbar:

    while i < max_steps:

        # Collect one rollout; warm-up mode is requested until start_steps
        # environment steps have been gathered.
        replay_buffer, actor_buffer, policy_rollout, policy_return, num_steps = rollout_policy(
            agent, env, exploration_rng,
            replay_buffer, actor_buffer,
            warmup=(i < start_steps))
        policy_rollouts.append(policy_rollout)
        unlogged_steps += num_steps
        i += num_steps
        pbar.update(num_steps)

        if replay_buffer.size > start_steps:

            transitions = replay_buffer.get_all()
            batch_rng, idx_key = jax.random.split(batch_rng)
            # One row of indices per critic gradient step.
            idxs = jax.random.choice(
                a=replay_buffer.size,
                shape=(critic_updates_per_rollout, batch_size),
                replace=True,
                key=idx_key)
            # NOTE(review): the agent is immutable and the result of this call
            # is discarded, so the critic optimizer state is NOT reset here.
            # Assigning the result is only safe if reset_critic_optimizer
            # returns optimizer state stacked per critic; otherwise the vmap
            # inside update_critics rejects it.
            agent.reset_critic_optimizer()
            agent, critic_update_info = agent.update_critics(
                transitions, idxs, critic_updates_per_rollout)
            # TODO: re-enable once evaluate_critic is validated:
            # R2 = evaluate_critic(agent,policy_rollouts[-1].policy_return,policy_rollouts)
            R2 = 0

            actor_batch = actor_buffer.get_all()
            agent, actor_update_info = agent.update_actor(actor_batch)

            update_info = {**critic_update_info, **actor_update_info, 'R2_validation': R2}

            if unlogged_steps > log_interval:
                exploration_metrics = {'exploration/disc_return': policy_return}
                wandb.log(exploration_metrics, step=int(i), commit=False)
                train_metrics = {f'training/{k}': v for k, v in update_info.items()}
                wandb.log(train_metrics, step=int(i), commit=False)
                policy_fn = agent.actor
                eval_info = evaluate(policy_fn, eval_env, num_episodes=eval_episodes)
                eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
                print('evaluating')
                # commit=True flushes the three log calls of this step together.
                wandb.log(eval_metrics, step=int(i), commit=True)
                unlogged_steps = 0


2023-12-18 16:42:57.542973: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmahdikallel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Extra kwargs: {'max_steps': 1000000}


  1%|          | 5003/1000000 [00:00<00:40, 24810.87it/s]

evaluating


  1%|          | 10045/1000000 [00:46<1:53:23, 145.51it/s]

evaluating


  2%|▏         | 15277/1000000 [01:25<2:35:57, 105.23it/s]

evaluating


  2%|▏         | 20458/1000000 [01:58<2:14:20, 121.52it/s]

evaluating


  3%|▎         | 25497/1000000 [02:27<2:12:09, 122.89it/s]

evaluating


  3%|▎         | 30681/1000000 [02:55<1:53:45, 142.01it/s]

evaluating


  4%|▎         | 35749/1000000 [03:24<2:19:15, 115.40it/s]

evaluating


  4%|▍         | 40877/1000000 [03:56<2:14:23, 118.95it/s]

evaluating


  5%|▍         | 45995/1000000 [04:26<2:07:57, 124.25it/s]

evaluating


  5%|▌         | 51093/1000000 [04:55<2:01:39, 130.00it/s]

evaluating


  6%|▌         | 56223/1000000 [05:22<1:59:10, 131.99it/s]

evaluating


  6%|▌         | 61371/1000000 [05:48<1:58:14, 132.31it/s]

evaluating


  7%|▋         | 66417/1000000 [06:17<2:19:51, 111.25it/s]

evaluating


  7%|▋         | 71541/1000000 [06:52<2:30:55, 102.53it/s]

evaluating


  8%|▊         | 76579/1000000 [07:23<2:11:58, 116.61it/s]

evaluating


  8%|▊         | 81657/1000000 [07:57<2:27:21, 103.87it/s]

evaluating


  9%|▊         | 86763/1000000 [08:36<2:18:41, 109.74it/s]

evaluating


  9%|▉         | 89700/1000000 [08:55<1:30:37, 167.40it/s]


KeyboardInterrupt: 

In [None]:
    
    @jax.jit
    def train_critics(agent, transitions: "Batch", idxs: jnp.array, num_steps: int, R2_history: jnp.ndarray):
        """Draft: select which critic of the ensemble should be re-initialised.

        A critic is flagged when the smallest entry of ``R2_history`` is
        negative; the flagged critic is the one at ``argmin(R2_history)``.
        The function is still a work in progress: it builds the masks and
        per-critic keys needed by the reset steps listed at the bottom, but
        does not apply them yet and returns ``None``.

        Args:
            agent, transitions, idxs, num_steps: currently unused.
            R2_history: per-critic scores; indexed by critic.
        """
        n_critics = 5
        rng = jax.random.PRNGKey(0)
        rngs = jax.random.split(rng, n_critics)
        opt_mask = jnp.ones(n_critics)

        worst_critic = jnp.argmin(R2_history)
        needs_reset = jnp.min(R2_history) < 0

        # Under jit the condition is a traced value, so a Python ``if`` cannot
        # branch on it; the mask entry is selected with jnp.where instead.
        # ``.at[].set`` returns a new array, hence the assignment.
        critic_mask = jnp.zeros((n_critics,)).at[worst_critic].set(
            jnp.where(needs_reset, 1.0, 0.0))

        # Runtime-conditional logging: jax.debug.print works on traced values.
        jax.lax.cond(
            needs_reset,
            lambda: jax.debug.print('resetting  {}', worst_critic),
            lambda: None,
        )

        # Planned next steps (not implemented yet):
        # b_critic_params = reset_critic_vmap(critic_mask,rngs,agent_state.b_critic_params)
        # b_critic_target_params = reset_critic_vmap(critic_mask,rngs,agent_state.b_critic_target_params)
        # b_critic_opt_state = reset_opt_vmap(opt_mask,rngs,agent_state.b_critic_opt_state)
        # b_batch_idxs = self.generate_batch_vmap(rngs,agent_state.buffer_max_size,num_steps,self.critic_batch_size)
        
    