In [1]:
"""Implementations of algorithms for continuous control."""
import functools
from jaxrl_m.typing import *

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import DeterministicPolicy,Policy, Critic, ensemblize

import flax
import flax.linen as nn

from functools import partial



class SACAgent(flax.struct.PyTreeNode):
    """Actor-critic agent stored as an immutable pytree.

    Every method returns a *new* agent built with ``replace``; nothing is
    mutated in place.  The actor's output is used directly as the action and
    the actor loss is the negated critic value.

    Fields:
        rng: PRNG key carried by the agent; it is advanced on every update.
        critic: Train state of the Q-function being optimised.
        target_critic: Train state created alongside the critic.  None of the
            methods in this class read or update it.
        actor: Train state of the policy network.
        config: Static (non-pytree) hyper-parameters; ``'discount'`` is the
            only key read by the methods below.
    """
    rng: PRNGKey
    critic: TrainState
    target_critic: TrainState
    actor: TrainState
    config: dict = nonpytree_field()

    @jax.jit
    def reset_critic_optimizer(agent):
        """Return a copy of the agent whose critic optimizer state is re-initialised.

        The critic parameters themselves are left untouched; only ``opt_state``
        is rebuilt from ``critic.tx.init``.
        """
        fresh_opt_state = agent.critic.tx.init(agent.critic.params)
        return agent.replace(critic=agent.critic.replace(opt_state=fresh_opt_state))

    @partial(jax.jit, static_argnames=('num_steps',))
    def update_critic(agent, transitions: Batch, idxs: jnp.ndarray, num_steps: int):
        """Run ``num_steps`` critic gradient steps inside a single jitted loop.

        Args:
            transitions: Pytree of arrays; every leaf is indexed with one row
                of ``idxs`` to form a minibatch.
            idxs: Index array; row ``i`` selects the minibatch for step ``i``.
                Must provide at least ``num_steps`` rows.
            num_steps: Number of gradient steps (static, so changing it
                triggers a recompilation).

        Returns:
            ``(updated_agent, info)`` where ``info`` is currently an empty dict.
        """

        def one_update(agent, batch: Batch):
            """Apply one squared-TD-error gradient step to the critic."""
            # Only the first sub-key is kept; the split count is preserved so
            # the key sequence stays the same as before.
            new_rng, _, _ = jax.random.split(agent.rng, 3)

            def critic_loss_fn(critic_params):
                # Derive the split point from the batch itself rather than
                # assuming a fixed minibatch size.
                batch_size = batch['observations'].shape[0]

                next_actions = agent.actor(batch['next_observations'])

                # Current and next state-action pairs are evaluated in one
                # stacked forward pass of the critic.
                concat_actions = jnp.vstack([batch['actions'], next_actions])
                concat_observations = jnp.vstack(
                    [batch['observations'], batch['next_observations']])
                concat_q = agent.critic(concat_observations, concat_actions,
                                        training=True, params=critic_params)
                q, next_q = concat_q[:batch_size], concat_q[batch_size:]

                # The bootstrap target is treated as a constant.
                target_q = jax.lax.stop_gradient(
                    batch['rewards']
                    + agent.config['discount'] * batch['masks'] * next_q)

                critic_loss = ((target_q - q) ** 2).mean()
                return critic_loss, {}

            new_critic, _ = agent.critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True)
            return agent.replace(rng=new_rng, critic=new_critic)

        def get_batch(step):
            """Gather minibatch ``step`` from every leaf of ``transitions``."""
            rows = idxs[step]
            return jax.tree_util.tree_map(lambda leaf: leaf[rows], transitions)

        def loop_body(step, carry_agent):
            return one_update(carry_agent, get_batch(step))

        agent = jax.lax.fori_loop(0, num_steps, loop_body, agent)

        return agent, {}

    @jax.jit
    def update_actor(agent, batch: Batch):
        """Take one actor gradient step that maximises the critic's value.

        Args:
            batch: Mapping with an ``'observations'`` entry.

        Returns:
            ``(updated_agent, info)`` where ``info`` holds ``'actor_loss'``.
        """
        # Only the first sub-key is kept; see ``update_critic``.
        new_rng, _, _ = jax.random.split(agent.rng, 3)

        def actor_loss_fn(actor_params):
            actions = agent.actor(batch['observations'], params=actor_params)
            # The critic is evaluated with its stored parameters, so gradients
            # reach only the actor parameters (through ``actions``).
            q = agent.critic(batch['observations'], actions)
            actor_loss = (-q).mean()

            return actor_loss, {
                'actor_loss': actor_loss,
            }

        new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True)

        return agent.replace(rng=new_rng, actor=new_actor), {**actor_info}

    @jax.jit
    def sample_actions(agent, observations: np.ndarray) -> jnp.ndarray:
        """Return the actor's output for ``observations`` (no noise is added here)."""
        return agent.actor(observations)
    
 

def create_learner(
    seed: int,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    actor_lr: float = 3e-4,
    critic_lr: float = 3e-4,
    hidden_dims: Sequence[int] = (256, 256),
    discount: float = 0.99,
    tau: float = 0.005,
    **kwargs,
):
    """Build a freshly initialised ``SACAgent``.

    Args:
        seed: Integer used to create the root PRNG key.
        observations: Example observations used to initialise both networks.
        actions: Example actions; the last axis gives the action dimension.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        hidden_dims: Hidden layer sizes of the critic.  The actor uses a fixed
            ``(64, 64)`` layout and ignores this argument.
        discount: Stored in the agent config under ``'discount'``.
        tau: Stored in the agent config under ``'target_update_rate'``.
        **kwargs: Ignored apart from being printed.

    Returns:
        A ``SACAgent`` whose critic and target critic start from the same
        parameters; the target critic is created without an optimizer.
    """
    print('Extra kwargs:', kwargs)

    agent_rng, actor_key, critic_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Policy network and its train state.
    policy_net = DeterministicPolicy(
        (64, 64), action_dim=actions.shape[-1], final_fc_init_scale=1.0)
    policy_state = TrainState.create(
        policy_net,
        policy_net.init(actor_key, observations)['params'],
        tx=optax.adam(learning_rate=actor_lr),
    )

    # Q-network; the same initial parameters seed both critic and target.
    q_net = Critic(hidden_dims=hidden_dims)
    q_params = q_net.init(critic_key, observations, actions, training=False)['params']
    q_state = TrainState.create(q_net, q_params, tx=optax.adam(learning_rate=critic_lr))
    q_target_state = TrainState.create(q_net, q_params)

    hyperparams = flax.core.FrozenDict({
        'discount': discount,
        'target_update_rate': tau,
    })

    return SACAgent(
        agent_rng,
        critic=q_state,
        target_critic=q_target_state,
        actor=policy_state,
        config=hyperparams,
    )
        


In [2]:

from jaxrl_m.rollout import PolicyRollout,rollout_policy


def f(anc_agent,obs,actor):
    """Score ``obs`` with the agent's critic under a given set of actor parameters.

    The actions come from ``anc_agent.actor`` evaluated with ``actor`` as its
    parameters; the critic is then called with its own stored parameters and
    ``training=False``.  The critic output is returned unchanged.
    """
    proposed_actions = anc_agent.actor(obs, params=actor)
    return anc_agent.critic(
        obs,
        proposed_actions,
        params=anc_agent.critic.params,
        training=False,
    )
    

@jax.jit
def estimate_return(anc_agent,anc_return,acq_rollout:PolicyRollout,):
    """Predict a rollout's return from the anchor agent's critic.

    On the rollout's observations the critic is queried twice: once with the
    actions of the rollout's own policy parameters and once with the actions
    of the anchor agent's actor.  The masked sum of the difference, divided by
    a fixed 5, is added to ``anc_return``.

    Returns:
        ``(predicted_return, acq_rollout.policy_return)``.
    """
    states = acq_rollout.observations
    weights = acq_rollout.disc_masks

    q_under_acq = f(anc_agent, states, acq_rollout.policy_params)
    q_under_anc = f(anc_agent, states, anc_agent.actor.params)

    # NOTE(review): the divisor 5 is hard-coded; presumably the number of
    # episodes pooled in one rollout -- confirm against rollout_policy.
    advantage = ((q_under_acq - q_under_anc) * weights).sum() / 5

    return anc_return + advantage, acq_rollout.policy_return


def evaluate_critic(anc_agent,anc_return,policy_rollouts):
    """Compute an R^2 score of critic-predicted returns against observed returns.

    Args:
        anc_agent: Agent whose critic produces the predictions.
        anc_return: Baseline return passed through to ``estimate_return``.
        policy_rollouts: Iterable of rollouts accepted by ``estimate_return``.

    Returns:
        ``1 - SS_res / SS_tot``, where every squared residual is floored at
        ``1e-4`` before summing.  If all observed returns are equal the
        denominator is zero and the result is not finite.
    """
    predicted, observed = [], []
    for policy_rollout in policy_rollouts:
        acq_return_pred, acq_return = estimate_return(anc_agent, anc_return, policy_rollout)
        predicted.append(acq_return_pred)
        observed.append(acq_return)

    y_pred, y = np.array(predicted), np.array(observed)

    # The lower bound is passed positionally: the `a_min=` keyword of
    # jnp.clip is deprecated in newer JAX, positional works on old and new.
    a2 = jnp.clip((y - y_pred) ** 2, 1e-4).sum()
    b2 = ((y - y.mean()) ** 2).sum()

    return 1 - (a2 / b2)


In [3]:
    import os
    from functools import partial
    import numpy as np
    import jax
    import tqdm
    import gymnasium as gym


    from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
    import wandb
    from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
    from jaxrl_m.dataset import ReplayBuffer
    from collections import deque


    # ---- Run configuration -------------------------------------------------
    env_name='Ant-v4'
    seed=np.random.choice(1000000)
    eval_episodes=10
    batch_size = 256       # rows per critic minibatch
    critic_steps = 5000    # critic gradient steps after every rollout
    max_steps = int(1e6)
    start_steps = 50000
    log_interval = 5000
    #eval_interval = 10000

    # NOTE(review): `wandb_config` is built here but never passed to
    # `setup_wandb` below -- confirm whether that is intended.  The `name`
    # entry is left as a plain template string; presumably `setup_wandb`
    # formats it -- confirm.
    wandb_config = default_wandb_config()
    wandb_config.update({
        'project': 'd4rl_test',
        'group': 'sac_test',
        'name': 'sac_{env_name}',
    })

    env = EpisodeMonitor(gym.make(env_name))
    eval_env = EpisodeMonitor(gym.make(env_name))
    setup_wandb({"bonjour":1})

    # One sampled transition fixes the layout of both buffers.
    example_transition = dict(
        observations=env.observation_space.sample(),
        actions=env.action_space.sample(),
        rewards=0.0,
        masks=1.0,
        next_observations=env.observation_space.sample(),
        discounts=1.0,
    )

    replay_buffer = ReplayBuffer.create(example_transition, size=int(1e6))
    actor_buffer = ReplayBuffer.create(example_transition, size=int(5e3))

    agent = create_learner(seed,
                    example_transition['observations'][None],
                    example_transition['actions'][None],
                    max_steps=max_steps,
                    #**FLAGS.config
                    )

    exploration_metrics = dict()
    obs,info = env.reset()
    # Two independent key streams.  Each is split before every use so that no
    # PRNG key is ever consumed twice.
    exploration_rng, sampling_rng = jax.random.split(jax.random.PRNGKey(0))
    i = 0
    unlogged_steps = 0
    policy_rollouts = deque([], maxlen=10)
    with tqdm.tqdm(total=max_steps) as pbar:

        while i < max_steps:

            # ---- Collect one rollout with a fresh exploration key ----------
            exploration_rng, rollout_key = jax.random.split(exploration_rng)
            replay_buffer,actor_buffer,policy_rollout,policy_return,num_steps = rollout_policy(agent,env,rollout_key,
                    replay_buffer,actor_buffer,
                    warmup=(i < start_steps))
            policy_rollouts.append(policy_rollout)
            unlogged_steps += num_steps
            i+=num_steps
            pbar.update(num_steps)

            if replay_buffer.size > start_steps:

                # ---- Critic: restart the optimizer, then refit -------------
                transitions = replay_buffer.get_all()
                sampling_rng, idx_key = jax.random.split(sampling_rng)
                idxs = jax.random.choice(a=replay_buffer.size, shape=(critic_steps, batch_size), replace=True, key=idx_key)
                agent = agent.reset_critic_optimizer()
                agent, critic_update_info = agent.update_critic(transitions,idxs,critic_steps)
                R2 = evaluate_critic(agent,policy_rollouts[-1].policy_return,policy_rollouts)

                # ---- Actor: one step on the whole actor buffer -------------
                actor_batch = actor_buffer.get_all()
                agent, actor_update_info = agent.update_actor(actor_batch)

                update_info = {**critic_update_info, **actor_update_info, 'R2_validation': R2}

                # ---- Log and evaluate every `log_interval` env steps -------
                if unlogged_steps > log_interval:
                    exploration_metrics = {'exploration/disc_return': policy_return}
                    wandb.log(exploration_metrics, step=int(i),commit=False)
                    train_metrics = {f'training/{k}': v for k, v in update_info.items()}
                    wandb.log(train_metrics, step=int(i),commit=False)
                    policy_fn = agent.actor
                    eval_info = evaluate(policy_fn, eval_env, num_episodes=eval_episodes)
                    eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
                    print('evaluating')
                    wandb.log(eval_metrics, step=int(i),commit=True)
                    unlogged_steps = 0


2023-12-19 02:56:01.602824: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmahdikallel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Extra kwargs: {'max_steps': 1000000}


  5%|▌         | 50414/1000000 [00:19<02:51, 5528.23it/s]

evaluating


  6%|▌         | 56449/1000000 [00:41<53:12, 295.54it/s]  

evaluating


  6%|▌         | 61661/1000000 [00:58<2:08:14, 121.95it/s]

evaluating


  7%|▋         | 66923/1000000 [01:08<1:01:25, 253.15it/s]

evaluating


  7%|▋         | 73063/1000000 [01:30<18:17, 844.52it/s]  

evaluating


  8%|▊         | 80436/1000000 [01:50<27:54, 549.28it/s]  

evaluating


  9%|▊         | 87019/1000000 [02:15<33:52, 449.08it/s]  

evaluating


  9%|▉         | 93021/1000000 [02:50<1:37:11, 155.52it/s]

evaluating


 10%|█         | 101044/1000000 [03:10<26:30, 565.28it/s] 

evaluating


 11%|█         | 108799/1000000 [03:24<35:06, 423.11it/s]

evaluating


 12%|█▏        | 115354/1000000 [04:20<2:12:53, 110.94it/s]

evaluating


 12%|█▏        | 121551/1000000 [04:40<47:31, 308.11it/s]  

evaluating


 13%|█▎        | 129856/1000000 [05:20<37:31, 386.55it/s]  

evaluating


 14%|█▎        | 137305/1000000 [05:40<36:29, 393.99it/s]  

evaluating


 14%|█▍        | 144483/1000000 [06:30<55:53, 255.11it/s]  

evaluating


 15%|█▌        | 150917/1000000 [07:00<49:35, 285.38it/s]  

evaluating


 16%|█▌        | 157825/1000000 [07:50<1:05:17, 214.96it/s]

evaluating


 16%|█▋        | 164263/1000000 [08:20<1:04:30, 215.93it/s]

evaluating


 17%|█▋        | 172736/1000000 [09:10<1:04:47, 212.79it/s]

evaluating


 18%|█▊        | 178849/1000000 [09:50<1:11:33, 191.27it/s]

evaluating


 19%|█▊        | 186215/1000000 [10:30<1:05:48, 206.09it/s]

evaluating


 19%|█▉        | 193643/1000000 [11:40<1:25:26, 157.29it/s]

evaluating


 20%|██        | 201628/1000000 [12:30<1:18:18, 169.92it/s]

evaluating


 21%|██        | 209964/1000000 [13:30<1:13:40, 178.70it/s]

evaluating


 22%|██▏       | 219165/1000000 [14:10<1:01:09, 212.81it/s]

evaluating


 23%|██▎       | 226903/1000000 [15:10<1:13:03, 176.39it/s]

evaluating


 24%|██▎       | 235225/1000000 [16:10<1:15:37, 168.56it/s]

evaluating


 24%|██▍       | 244109/1000000 [17:10<1:08:47, 183.12it/s]

evaluating


 25%|██▌       | 253138/1000000 [18:20<1:12:57, 170.63it/s]

evaluating


 26%|██▌       | 261061/1000000 [19:20<1:17:47, 158.33it/s]

evaluating


 27%|██▋       | 269607/1000000 [20:30<1:22:20, 147.85it/s]

evaluating


 28%|██▊       | 279395/1000000 [21:30<1:08:13, 176.02it/s]

evaluating


 29%|██▉       | 289353/1000000 [22:40<1:07:36, 175.20it/s]

evaluating


 30%|██▉       | 298180/1000000 [23:50<1:13:06, 159.99it/s]

evaluating


 31%|███       | 306312/1000000 [24:50<1:14:50, 154.48it/s]

evaluating


 32%|███▏      | 316249/1000000 [26:00<1:10:18, 162.08it/s]

evaluating


 33%|███▎      | 325188/1000000 [27:10<1:11:07, 158.12it/s]

evaluating


 34%|███▎      | 335188/1000000 [28:20<1:03:49, 173.60it/s]

evaluating


 34%|███▍      | 344569/1000000 [29:20<1:03:38, 171.66it/s]

evaluating


 35%|███▌      | 352915/1000000 [30:40<1:09:50, 154.43it/s]

evaluating


 36%|███▌      | 361448/1000000 [31:40<1:07:25, 157.83it/s]

evaluating


 37%|███▋      | 371442/1000000 [32:50<1:05:50, 159.10it/s]

evaluating


 38%|███▊      | 380321/1000000 [34:10<1:07:21, 153.33it/s]

evaluating


 39%|███▉      | 389412/1000000 [35:00<1:00:33, 168.06it/s]

evaluating


 40%|███▉      | 398936/1000000 [36:00<55:13, 181.37it/s]  

evaluating


 41%|████      | 407285/1000000 [37:00<57:23, 172.13it/s]  

evaluating


 42%|████▏     | 415605/1000000 [38:00<50:55, 191.23it/s]  

evaluating


 43%|████▎     | 425328/1000000 [39:20<56:32, 169.40it/s]  

evaluating


 43%|████▎     | 433942/1000000 [40:10<53:08, 177.55it/s]  

evaluating


 44%|████▍     | 441555/1000000 [41:20<1:00:33, 153.67it/s]

evaluating


 45%|████▍     | 449319/1000000 [42:20<59:46, 153.53it/s]  

evaluating


 46%|████▌     | 458397/1000000 [43:40<59:52, 150.75it/s]  

evaluating


 47%|████▋     | 466209/1000000 [44:40<1:02:24, 142.56it/s]

evaluating


 47%|████▋     | 474685/1000000 [45:50<58:32, 149.54it/s]  

evaluating


 48%|████▊     | 484685/1000000 [47:00<54:33, 157.41it/s]  

evaluating


 49%|████▉     | 494685/1000000 [48:10<52:16, 161.10it/s]  

evaluating


 50%|█████     | 504685/1000000 [49:20<48:28, 170.31it/s]  

evaluating


 51%|█████▏    | 514610/1000000 [50:30<46:33, 173.77it/s]  

evaluating


 52%|█████▏    | 524610/1000000 [51:40<47:26, 167.02it/s]  

evaluating


 53%|█████▎    | 534088/1000000 [53:00<48:35, 159.79it/s]  

evaluating


 54%|█████▍    | 544088/1000000 [54:10<47:03, 161.45it/s]  

evaluating


 55%|█████▌    | 554088/1000000 [55:20<45:37, 162.91it/s]  

evaluating


 56%|█████▋    | 564088/1000000 [56:30<44:22, 163.72it/s]  

evaluating


 57%|█████▋    | 574088/1000000 [57:50<43:09, 164.50it/s]  

evaluating


 58%|█████▊    | 584088/1000000 [59:00<42:07, 164.57it/s]

evaluating


 59%|█████▉    | 594088/1000000 [1:00:10<41:07, 164.54it/s]

evaluating


 60%|██████    | 604088/1000000 [1:01:20<40:06, 164.49it/s]

evaluating


 61%|██████▏   | 613025/1000000 [1:02:40<40:50, 157.92it/s]

evaluating


 62%|██████▏   | 623025/1000000 [1:03:50<38:53, 161.55it/s]

evaluating


 63%|██████▎   | 633025/1000000 [1:05:00<37:31, 162.96it/s]

evaluating


 64%|██████▍   | 642527/1000000 [1:06:10<36:41, 162.39it/s]

evaluating


 65%|██████▌   | 652527/1000000 [1:07:20<35:20, 163.89it/s]

evaluating


 66%|██████▋   | 662527/1000000 [1:08:40<34:14, 164.30it/s]

evaluating


 67%|██████▋   | 672527/1000000 [1:09:50<33:09, 164.60it/s]

evaluating


 68%|██████▊   | 682527/1000000 [1:10:50<30:36, 172.84it/s]

evaluating


 69%|██████▉   | 692527/1000000 [1:12:10<30:19, 168.95it/s]

evaluating


 70%|███████   | 702527/1000000 [1:13:10<28:30, 173.86it/s]

evaluating


 71%|███████▏  | 712527/1000000 [1:14:30<28:36, 167.48it/s]

evaluating


 72%|███████▏  | 721598/1000000 [1:15:40<29:42, 156.20it/s]

evaluating


 73%|███████▎  | 731598/1000000 [1:16:50<28:05, 159.26it/s]

evaluating


 74%|███████▍  | 741598/1000000 [1:18:10<26:49, 160.57it/s]

evaluating


 75%|███████▌  | 751598/1000000 [1:19:20<25:23, 163.03it/s]

evaluating


 76%|███████▌  | 761136/1000000 [1:20:30<24:55, 159.68it/s]

evaluating


 77%|███████▋  | 771136/1000000 [1:21:50<23:41, 161.04it/s]

evaluating


 78%|███████▊  | 780855/1000000 [1:23:00<22:49, 160.00it/s]

evaluating


 79%|███████▉  | 790855/1000000 [1:24:10<21:04, 165.44it/s]

evaluating


 80%|████████  | 800357/1000000 [1:25:20<20:44, 160.41it/s]

evaluating


 81%|████████  | 810357/1000000 [1:26:30<18:53, 167.38it/s]

evaluating


 82%|████████▏ | 818781/1000000 [1:27:40<19:55, 151.62it/s]

evaluating


 83%|████████▎ | 828781/1000000 [1:29:00<18:08, 157.25it/s]

evaluating


 84%|████████▍ | 838499/1000000 [1:30:10<17:06, 157.29it/s]

evaluating


 85%|████████▍ | 848162/1000000 [1:31:20<15:55, 158.89it/s]

evaluating


 86%|████████▌ | 858162/1000000 [1:32:40<14:34, 162.29it/s]

evaluating


 87%|████████▋ | 867232/1000000 [1:33:50<14:00, 157.91it/s]

evaluating


 88%|████████▊ | 877232/1000000 [1:35:00<12:40, 161.50it/s]

evaluating


 89%|████████▊ | 887232/1000000 [1:36:10<11:24, 164.78it/s]

evaluating


 90%|████████▉ | 897232/1000000 [1:37:20<09:59, 171.30it/s]

evaluating


 91%|█████████ | 907232/1000000 [1:38:30<09:16, 166.80it/s]

evaluating


 92%|█████████▏| 917232/1000000 [1:39:40<08:23, 164.29it/s]

evaluating


 93%|█████████▎| 927232/1000000 [1:41:00<07:25, 163.43it/s]

evaluating


 94%|█████████▎| 936362/1000000 [1:42:00<06:23, 165.86it/s]

evaluating


 95%|█████████▍| 946012/1000000 [1:43:10<05:24, 166.12it/s]

evaluating


 96%|█████████▌| 955155/1000000 [1:44:20<04:36, 162.48it/s]

evaluating


 96%|█████████▋| 964800/1000000 [1:45:30<03:34, 163.74it/s]

evaluating


 97%|█████████▋| 973834/1000000 [1:46:50<02:48, 155.50it/s]

evaluating


 98%|█████████▊| 983103/1000000 [1:47:50<01:47, 156.76it/s]

evaluating


 99%|█████████▉| 991255/1000000 [1:49:00<00:58, 149.51it/s]

evaluating


1000820it [1:51:02, 150.21it/s]                            

evaluating





In [4]:
def accumulate_index(i, carry):
    """Loop body for the ``lax.fori_loop`` demo below: add the loop index to the carry."""
    return carry + i

a = np.array([1,2,3,4])

# The body function is deliberately not called `f`: an earlier cell defines a
# Q-value helper `f` that `estimate_return` looks up by name, and redefining it
# here would silently shadow that helper for the rest of the session.
lax.fori_loop(0, 10, accumulate_index, 0)


Array(45, dtype=int32, weak_type=True)