In [1]:
"""Implementations of algorithms for continuous control."""
import functools
from jaxrl_m.typing import *

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import DeterministicPolicy,Policy, Critic, ensemblize

import flax
import flax.linen as nn

from functools import partial

NUM_CRITICS = 5
NUM_ROLLOUTS = 5
MAX_SIZE = 20

class SACAgent(flax.struct.PyTreeNode):
    """Deterministic-actor agent with two critic ensembles ("critic" and "o_critic").

    The actor is called as ``agent.actor(obs)`` and returns actions directly
    (``create_learner`` builds it as a ``DeterministicPolicy``). Both critic
    TrainStates hold NUM_CRITICS parameter sets stacked on the leading axis and
    are trained with ``jax.vmap`` over that axis.
    """

    rng: PRNGKey
    critic: TrainState  # ensemble; params stacked on axis 0 (NUM_CRITICS entries)
    o_critic: TrainState  # second ensemble, rebuilt by update_many_o_critics
    actor: TrainState
    config: dict = nonpytree_field()  # discount, example observations/actions, ...

    @partial(jax.jit, static_argnames=('num_steps',))
    def update_many_critics(agent, transitions: Batch, idxs: jnp.ndarray, num_steps: int, R2):
        """Reset the worst-scoring critic, then train every critic for ``num_steps`` steps.

        Args:
            transitions: replay data; every leaf is indexed with ``idxs[k][i]``.
            idxs: integer minibatch indices with one row of ``num_steps`` batches per
                critic (the training loop passes shape (NUM_CRITICS, num_steps,
                batch_size) — TODO confirm for other callers).
            num_steps: gradient steps per critic (static under jit).
            R2: per-critic validation scores; the critic at ``argmin(R2)`` gets freshly
                initialised parameters before training.

        Returns:
            ``(new_agent, {})`` with a new ``rng`` and the trained critic ensemble.
            Every critic gets a brand-new Adam optimizer state on each call.
        """

        def update_one_critic(critic, idxs, agent, transitions, num_steps):
            """Run ``num_steps`` TD updates on a single (un-batched) critic."""

            def one_update(agent, critic, batch: Batch):
                def critic_loss_fn(critic_params):
                    next_actions = agent.actor(batch['next_observations'])
                    # Evaluate Q(s, a) and Q(s', pi(s')) in a single forward pass.
                    concat_actions = jnp.concatenate([batch["actions"], next_actions])
                    concat_observations = jnp.concatenate(
                        [batch["observations"], batch["next_observations"]])

                    concat_q = critic(concat_observations, concat_actions,
                                      True, params=critic_params)
                    q, next_q = jnp.split(concat_q, 2, axis=0)  ## axis=1 for ensemble

                    # Bootstrapped from the same critic (no separate target network).
                    target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_q
                    target_q = jax.lax.stop_gradient(target_q)

                    critic_loss = ((target_q - q) ** 2).mean()
                    return critic_loss, {}

                new_critic, _ = critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True)
                return agent, new_critic

            get_batch = lambda transitions, idx: jax.tree_util.tree_map(lambda x: x[idx], transitions)

            _, new_critic = jax.lax.fori_loop(
                0, num_steps,
                lambda i, args: one_update(*args, get_batch(transitions, idxs[i])),
                (agent, critic))
            return new_critic

        new_rng, curr_key, _ = jax.random.split(agent.rng, 3)
        critic = agent.critic

        ###### Reset the worst critic's params ######
        reset = lambda rng, params: critic.init(
            rng, agent.config["observations"], agent.config["actions"], False)["params"]
        no_reset = lambda rng, params: params
        f = lambda mask, rng, params: lax.cond(mask, reset, no_reset, rng, params)
        # JAX arrays are immutable: `.at[].set()` returns a new array, so build the
        # (boolean) mask directly. Previously the result was discarded, the mask
        # stayed all zeros and no critic was ever reset.
        mask = jnp.arange(NUM_CRITICS) == jnp.argmin(R2)
        # Use the freshly split key instead of re-splitting agent.rng.
        rngs = jax.random.split(curr_key, NUM_CRITICS)
        critic_params = jax.vmap(f, in_axes=(0, 0, 0))(mask, rngs, critic.params)
        #############################################

        critic_def = Critic((256, 256))
        critics = jax.vmap(TrainState.create, in_axes=(None, 0, None))(
            critic_def, critic_params, optax.adam(learning_rate=3e-4))
        train_one = partial(update_one_critic, agent=agent, transitions=transitions,
                            num_steps=num_steps)
        new_critics = jax.vmap(train_one, in_axes=(0, 0))(critics, idxs)
        agent = agent.replace(rng=new_rng, critic=new_critics)

        return agent, {}

    @partial(jax.jit, static_argnames=('num_steps',))
    def update_many_o_critics(agent, transitions: Batch, idxs: jnp.ndarray, num_steps: int, R2, R2_o):
        """Re-initialise the o-critic ensemble and train each member for ``num_steps`` steps.

        The bootstrap target is optimistic: the next-state value is the maximum of the
        o-critic's own estimate and a weighted max over the main critic ensemble
        (``agent.critic.params``), where the weights are ``softmax(R2)`` rescaled so
        the best-scoring critic has weight 1.

        Args:
            transitions, idxs, num_steps: as in ``update_many_critics``.
            R2: per-critic scores of the main critic ensemble.
            R2_o: currently unused (only referenced by the commented-out variant).

        Returns:
            ``(new_agent, {})`` with a new ``rng`` and the trained o-critic ensemble.
        """

        def update_o_critic(o_critic, idxs, agent, transitions, num_steps, R2):
            """Run ``num_steps`` optimistic TD updates on a single o-critic."""

            def one_update(agent, o_critic, R2: float, batch: Batch):
                def critic_loss_fn(o_critic_params, R2):
                    next_actions = agent.actor(batch['next_observations'])
                    q = o_critic(batch["observations"], batch["actions"],
                                 True, params=o_critic_params)
                    next_q1 = o_critic(batch["next_observations"], next_actions,
                                       False, params=o_critic_params)
                    # Evaluate every main critic on (s', pi(s')) by swapping in its params.
                    call_one_critic = lambda observations, actions, params: o_critic(
                        observations, actions, params=params)
                    next_q = jax.vmap(call_one_critic, in_axes=(None, None, 0))(
                        batch['next_observations'], next_actions, agent.critic.params)
                    q_weights = jax.nn.softmax(R2, axis=0)
                    q_weights /= jnp.max(q_weights)
                    next_q = jnp.max(q_weights.reshape(-1, 1) * next_q, axis=0)
                    next_q = jnp.maximum(next_q, next_q1)
                    target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_q
                    target_q = jax.lax.stop_gradient(target_q)

                    critic_loss = ((target_q - q) ** 2).mean()
                    return critic_loss, {}

                loss_fn = partial(critic_loss_fn, R2=R2)
                new_critic, _ = o_critic.apply_loss_fn(loss_fn=loss_fn, has_aux=True)
                return agent, new_critic, R2

            get_batch = lambda transitions, idx: jax.tree_util.tree_map(lambda x: x[idx], transitions)
            _, new_critic, _ = jax.lax.fori_loop(
                0, num_steps,
                lambda i, args: one_update(*args, get_batch(transitions, idxs[i])),
                (agent, o_critic, R2))
            return new_critic

        new_rng, curr_key, _ = jax.random.split(agent.rng, 3)
        critic_keys = jax.random.split(curr_key, NUM_CRITICS)
        critic_def = Critic((256, 256))
        o_critic_params = jax.vmap(critic_def.init, in_axes=(0, None, None))(
            critic_keys, agent.config["observations"], agent.config["actions"])["params"]
        critics = jax.vmap(TrainState.create, in_axes=(None, 0, None))(
            critic_def, o_critic_params, optax.adam(learning_rate=3e-4))
        train_one = partial(update_o_critic, agent=agent, transitions=transitions,
                            num_steps=num_steps, R2=R2)
        new_critics = jax.vmap(train_one, in_axes=(0, 0))(critics, idxs)
        agent = agent.replace(rng=new_rng, o_critic=new_critics)
        #######################################################
        # Alternative: keep the o-critics and only reset the worst one (by R2_o).
        # o_critic = agent.o_critic
        # reset = lambda rng,params : o_critic.init(rng,
        #                                         agent.config["observations"], agent.config["actions"],False)["params"]
        # no_reset = lambda rng,params: params
        # f = lambda  mask,rng,params :lax.cond(mask,reset,no_reset,rng,params)
        # mask = jnp.arange(NUM_CRITICS) == jnp.argmin(R2_o)
        # rngs = jax.random.split(curr_key, NUM_CRITICS)
        # critic_params = jax.vmap(f,in_axes=(0,0,0))(mask,rngs,o_critic.params)
        # critic_def = Critic((256,256))
        # critics = jax.vmap(TrainState.create,in_axes=(None,0,None))(critic_def,critic_params,optax.adam(learning_rate=3e-4))
        # tmp = partial(update_o_critic,agent=agent,transitions=transitions,num_steps=num_steps,R2=R2)
        # new_critics = jax.vmap(tmp,in_axes=(0,0))(critics,idxs)
        # agent = agent.replace(rng=new_rng,o_critic=new_critics)

        return agent, {}

    @jax.jit
    def update_actor(agent, critic_params, batch: Batch, R2):
        """Take one gradient step on the actor against the weighted critic ensemble.

        The loss is ``-mean(masks * sum_k w_k * Q_k(s, pi(s)))`` with
        ``w = softmax(R2, axis=0)``. The training loop passes ``R2`` with shape
        (NUM_CRITICS, 1) so the weights broadcast over the batch — TODO confirm
        for other callers.

        Returns:
            ``(new_agent, {'actor_loss': ...})``.
        """
        new_rng, _, _ = jax.random.split(agent.rng, 3)

        def actor_loss_fn(actor_params, critic_params, R2):
            actions = agent.actor(batch['observations'], params=actor_params)

            call_one_critic = lambda observations, actions, params: agent.critic(
                observations, actions, params=params)
            q = jax.vmap(call_one_critic, in_axes=(None, None, 0))(
                batch['observations'], actions, critic_params)
            q_weights = jax.nn.softmax(R2, axis=0)
            q = jnp.sum(q_weights * q, axis=0)
            q = q * batch['masks']

            actor_loss = (-q).mean()
            return actor_loss, {'actor_loss': actor_loss}

        loss_fn = partial(actor_loss_fn, critic_params=critic_params, R2=R2)
        new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=loss_fn, has_aux=True)

        return agent.replace(rng=new_rng, actor=new_actor), {**actor_info}

    @jax.jit
    def sample_actions(agent, observations: np.ndarray, seed: np.ndarray) -> jnp.ndarray:
        """Return the actor's actions for ``observations``; ``seed`` is accepted but unused."""
        return agent.actor(observations)
    
 

def create_learner(
        seed: int,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        **kwargs):
    """Build a SACAgent with a deterministic actor and two critic ensembles.

    Args:
        seed: seed for ``jax.random.PRNGKey``.
        observations, actions: example inputs used to initialise the networks
            (the training script passes ``x[None]``, i.e. with a leading batch
            axis). They are also stored in ``config`` so critics can be
            re-initialised later.
        actor_lr, critic_lr: Adam learning rates.
        hidden_dims: hidden sizes of the critics; the actor uses (64, 64).
        discount: TD discount stored as ``config['discount']``.
        tau: stored as ``config['target_update_rate']``; SACAgent does not read it.
        **kwargs: ignored apart from being printed.

    Returns:
        A ``SACAgent`` whose ``critic`` and ``o_critic`` each hold NUM_CRITICS
        parameter sets stacked on the leading axis.
    """
    print('Extra kwargs:', kwargs)

    rng = jax.random.PRNGKey(seed)
    rng, actor_key, critic_key = jax.random.split(rng, 3)
    # Independent key for the o-critic ensemble: reusing `critic_keys` made every
    # o-critic an exact copy of the corresponding critic at initialisation.
    rng, o_critic_key = jax.random.split(rng)

    action_dim = actions.shape[-1]
    actor_def = DeterministicPolicy((64, 64), action_dim=action_dim, final_fc_init_scale=1.0)
    actor_params = actor_def.init(actor_key, observations)['params']
    actor = TrainState.create(actor_def, actor_params, tx=optax.adam(learning_rate=actor_lr))

    critic_def = Critic(hidden_dims)
    init_ensemble = jax.vmap(critic_def.init, in_axes=(0, None, None))
    create_ensemble = jax.vmap(TrainState.create, in_axes=(None, 0, None))

    critic_keys = jax.random.split(critic_key, NUM_CRITICS)
    critic_params = init_ensemble(critic_keys, observations, actions)['params']
    critics = create_ensemble(critic_def, critic_params, optax.adam(learning_rate=critic_lr))

    o_critic_keys = jax.random.split(o_critic_key, NUM_CRITICS)
    o_critic_params = init_ensemble(o_critic_keys, observations, actions)['params']
    o_critics = create_ensemble(critic_def, o_critic_params, optax.adam(learning_rate=critic_lr))

    config = flax.core.FrozenDict(dict(
        discount=discount,
        target_update_rate=tau,
        observations=observations,
        actions=actions,
    ))

    return SACAgent(rng, critic=critics, o_critic=o_critics, actor=actor, config=config)

In [2]:
import functools

def compute_q(anc_agent, obs, actor_params, critic_params):
    """Evaluate ``Q(obs, pi(obs))`` for a given actor/critic parameter pair.

    The agent's actor and critic modules are reused, but with the supplied
    parameters swapped in. The critic is called with ``False`` as its third
    positional argument.
    """
    policy_actions = anc_agent.actor(obs, params=actor_params)
    return anc_agent.critic(obs, policy_actions, False, params=critic_params)

def estimate_return(acq_rollout,
                    anc_agent, anc_critic_params, anc_return):
    """Predict the return of the policy behind ``acq_rollout`` relative to an anchor.

    The anchor actor (``anc_agent.actor.params``) and the rollout's own actor
    (``acq_rollout.policy_params``) are both scored with the same critic
    parameters on the rollout's observations. The mask-weighted Q difference is
    summed, divided by NUM_ROLLOUTS and added to ``anc_return``.

    Returns:
        ``(predicted_return, acq_rollout.policy_return)``.
    """
    states = acq_rollout.observations

    def q_under(actor_params):
        return compute_q(anc_agent, states, actor_params, anc_critic_params)

    anchor_q = q_under(anc_agent.actor.params)
    acquired_q = q_under(acq_rollout.policy_params)

    # disc_masks presumably hold per-step discount/validity weights — TODO confirm
    advantage = jnp.sum((acquired_q - anchor_q) * acq_rollout.disc_masks) / NUM_ROLLOUTS
    return anc_return + advantage, acq_rollout.policy_return


def evaluate_critic(anc_agent, anc_critic_params,
                    anc_return, policy_rollouts, min_var: float = 1e-8):
    """Score one critic by how well it predicts the returns of stored policies.

    For every rollout in ``policy_rollouts`` (a pytree with a leading policy
    axis, as produced by ``flatten_rollouts``), ``estimate_return`` predicts
    that policy's return. The predictions are then compared with the measured
    returns.

    Args:
        anc_agent: agent whose actor serves as the anchor policy.
        anc_critic_params: parameters of the single critic being scored.
        anc_return: measured return of the anchor policy.
        policy_rollouts: stacked rollouts to validate against.
        min_var: floor on the total sum of squares of the measured returns.
            When every stored return is identical the unguarded denominator is
            0 and R2 became -inf/NaN; the floor keeps it finite.

    Returns:
        ``(R2, bias)``: coefficient of determination of the predictions (each
        squared residual floored at 1e-4) and the mean error ``pred - true``.
    """
    predict = partial(estimate_return,
                      anc_agent=anc_agent,
                      anc_critic_params=anc_critic_params,
                      anc_return=anc_return)
    y_pred, y = jax.vmap(predict)(policy_rollouts)

    ss_res = jnp.maximum((y - y_pred) ** 2, 1e-4).sum()
    ss_tot = jnp.maximum(((y - y.mean()) ** 2).sum(), min_var)
    R2 = 1 - ss_res / ss_tot
    bias = (y_pred - y).mean()

    return R2, bias

@jax.jit
def evaluate_critics(anc_agent, anc_critic_params,
                     anc_return, policy_rollouts):
    """Run ``evaluate_critic`` for every member of a stacked critic ensemble.

    ``anc_critic_params`` carries the ensemble on its leading axis; all other
    arguments are shared by every critic. Returns ``(R2, bias)`` arrays with
    one entry per critic.
    """
    score_each = jax.vmap(evaluate_critic, in_axes=(None, 0, None, None))
    r2_scores, biases = score_each(anc_agent, anc_critic_params, anc_return, policy_rollouts)
    return r2_scores, biases



def merge(x, y):
    """Vertically stack two pytrees with identical structure, leaf by leaf.

    Each pair of leaves is combined with ``jnp.vstack``. That function promotes
    1-D leaves to rows, so two ``(N,)`` leaves become ``(2, N)``.

    Uses ``jax.tree_util.tree_map``: the ``jax.tree_map`` alias is deprecated
    and has been removed from recent JAX releases.
    """
    return jax.tree_util.tree_map(lambda a, b: jnp.vstack([a, b]), x, y)

def flatten_rollouts(policy_rollouts):
    """Stack a sequence of rollout pytrees into one pytree with a leading policy axis.

    Every leaf of the result has shape ``(n_policies, *leaf.shape)``, with policies
    in the order given. This is a single ``jnp.stack`` per leaf. It replaces a
    reduce/vstack/split/reshape chain that recopied the data for every policy
    and failed for a single rollout with scalar leaves.

    Args:
        policy_rollouts: non-empty sequence (e.g. list or deque) of pytrees with
            identical structure and leaf shapes.

    Raises:
        ValueError: if ``policy_rollouts`` is empty.
    """
    rollouts = list(policy_rollouts)
    if not rollouts:
        raise ValueError("flatten_rollouts needs at least one rollout")
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *rollouts)


In [3]:
import os
from functools import partial
import numpy as np
import jax
import tqdm
import gymnasium as gym

from jaxrl_m.common import CodeTimer
from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
from jaxrl_m.dataset import ReplayBuffer
from collections import deque
from jaxrl_m.rollout import PolicyRollout,rollout_policy

env_name = 'Ant-v4'
seed = np.random.choice(1000000)
eval_episodes = 10
batch_size = 256
max_steps = int(1e6)
start_steps = 50000
log_interval = 5000
critic_steps = 10000  # gradient steps per critic per update_many_critics call
#eval_interval = 10000

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'delete2',
    # f-string: a plain string here logged the literal text "{env_name}".
    'name': f'dpg_{env_name}',
    'hyperparam_dict': {},
})


env = EpisodeMonitor(gym.make(env_name))
eval_env = EpisodeMonitor(gym.make(env_name))
setup_wandb(**wandb_config)

example_transition = dict(
    observations=env.observation_space.sample(),
    actions=env.action_space.sample(),
    rewards=0.0,
    masks=1.0,
    next_observations=env.observation_space.sample(),
    discounts=1.0,
)

replay_buffer = ReplayBuffer.create(example_transition, size=int(1e6))
actor_buffer = ReplayBuffer.create(example_transition, size=int(5e3))

agent = create_learner(seed,
                       example_transition['observations'][None],
                       example_transition['actions'][None],
                       max_steps=max_steps,
                       #**FLAGS.config
                       )

exploration_metrics = dict()
obs, info = env.reset()
# One seeded key stream for the script, split for every use. A constant
# PRNGKey(0) used to be passed each time, replaying identical randomness for
# every rollout and (nearly) identical critic minibatch indices every update.
rng = jax.random.fold_in(jax.random.PRNGKey(seed), 1)
i = 0
num_grad_updates = 0
unlogged_steps = 0
policy_rollouts = deque([], maxlen=MAX_SIZE)
R2, R2_o = jnp.ones(NUM_CRITICS), jnp.ones(NUM_CRITICS)

warmup_steps = 0
while warmup_steps < start_steps:
    rng, exploration_rng = jax.random.split(rng)
    # NOTE(review): warmup=False is passed even during warm-up — confirm intended.
    replay_buffer, _, _, policy_return, undisc_policy_return, num_steps = rollout_policy(
        agent, env, exploration_rng,
        replay_buffer, actor_buffer,
        warmup=False, num_rollouts=NUM_ROLLOUTS)
    warmup_steps += num_steps

with tqdm.tqdm(total=max_steps) as pbar:

    while i < max_steps:
        # Recreated every iteration, so update_actor only sees the data returned
        # by this iteration's rollout_policy call.
        actor_buffer = ReplayBuffer.create(example_transition, size=int(5e3))

        rng, exploration_rng = jax.random.split(rng)
        (replay_buffer, actor_buffer, policy_rollout, policy_return,
         undisc_policy_return, num_steps) = rollout_policy(
            agent, env, exploration_rng,
            replay_buffer, actor_buffer, warmup=False, num_rollouts=NUM_ROLLOUTS)
        policy_rollouts.append(policy_rollout)
        # The first rollout is stored twice (presumably so the R2 evaluation
        # always has at least two policies to compare).
        if i == 0:
            policy_rollouts.append(policy_rollout)

        unlogged_steps += num_steps
        i += num_steps
        pbar.update(num_steps)

        if replay_buffer.size <= start_steps:
            continue

        # with CodeTimer('update_critic'):
        rng, idx_key = jax.random.split(rng)
        transitions = replay_buffer.get_all()
        sample_idxs = partial(jax.random.choice, a=replay_buffer.size,
                              shape=(critic_steps, batch_size), replace=True)
        idxs = jax.vmap(sample_idxs)(jax.random.split(idx_key, NUM_CRITICS))
        agent, critic_update_info = agent.update_many_critics(transitions, idxs, critic_steps, R2)

        # with CodeTimer('evaluate_critic'):
        critic_params = agent.critic.params
        anc_return = policy_rollouts[-1].policy_return
        flattened_rollouts = flatten_rollouts(policy_rollouts)
        R2, bias = evaluate_critics(agent, critic_params, anc_return, flattened_rollouts)
        # agent, critic_update_info = agent.update_many_o_critics(transitions, idxs, critic_steps, R2, R2_o)
        # o_critic_params = agent.o_critic.params
        # R2_o, bias_o = evaluate_critics(agent, o_critic_params, anc_return, flattened_rollouts)

        # with CodeTimer('update_actor'):
        actor_batch = actor_buffer.get_all()
        # NOTE(review): the critic weights come from R2_o, which stays all-ones
        # while the o-critic update above is commented out (uniform average).
        # Confirm whether R2 was intended here.
        agent, actor_update_info = agent.update_actor(agent.critic.params, actor_batch, R2_o.reshape(-1, 1))
        num_grad_updates += 1

        if unlogged_steps > log_interval:
            # with CodeTimer('eval'):
            update_info = {**critic_update_info, **actor_update_info,
                           'R2_validation': jnp.max(R2), 'bias': bias[jnp.argmax(R2)],
                           'num_grad_updates': num_grad_updates,
                           #'R2_opt': jnp.max(R2_o), "bias_opt": bias_o[jnp.argmax(R2_o)],
                           }
            exploration_metrics = {'exploration/disc_return': policy_return}
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}

            # eval_info = evaluate(agent.actor, eval_env, num_episodes=eval_episodes)
            # eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}

            # with CodeTimer('logging'):
            # Everything is logged at the same explicit step; the histograms used to
            # be logged without `step` and ended up on a different wandb step.
            step = int(i)
            wandb.log(train_metrics, step=step, commit=False)
            wandb.log({"undisc_return": undisc_policy_return}, step=step, commit=False)
            wandb.log({
                **exploration_metrics,
                "R2_hist": wandb.Histogram(jnp.clip(R2, -1, 1)),
                "R2_o_hist": wandb.Histogram(jnp.clip(R2_o, -1, 1)),
            }, step=step, commit=True)
            #wandb.log({"bias_hist": wandb.Histogram(bias)})
            #wandb.log(eval_metrics, step=step, commit=True)
            unlogged_steps = 0

2024-01-09 00:04:09.417052: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmahdikallel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Extra kwargs: {'max_steps': 1000000}


 21%|██        | 209708/1000000 [07:06<26:47, 491.77it/s] 


KeyboardInterrupt: 