In [1]:
"""Implementations of algorithms for continuous control."""
import functools
from jaxrl_m.typing import *

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import Policy, Critic, ensemblize

import flax
import flax.linen as nn
from functools import partial

# Number of rollouts requested from `rollout_policy` on each training-loop iteration.
NUM_ROLLOUTS = 8
# Size of the critic ensemble: critic parameters are stacked along a leading
# axis of this length and vmapped over in the update/evaluation code.
NUM_CRITICS = 5

class Temperature(nn.Module):
    """Learnable non-negative scalar temperature (entropy coefficient).

    Despite its name, the single parameter ``log_temp`` is not exponentiated:
    the module returns its absolute value directly.
    """

    initial_temperature: float = 1e-3

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        def scalar_init(key):
            # The PRNG key is unused; the parameter starts at a fixed value.
            return jnp.full((), self.initial_temperature)

        raw_temperature = self.param('log_temp', init_fn=scalar_init)
        return jnp.abs(raw_temperature)
        

class SACAgent(flax.struct.PyTreeNode):
    """SAC-style agent with an ensemble of independently trained critics.

    The agent is an immutable pytree: every update method returns a new agent
    (via ``replace``) instead of mutating ``self``. ``critic`` holds
    NUM_CRITICS critics whose parameters are stacked along a leading axis;
    the update and evaluation code ``vmap``s over that axis.
    """

    rng: PRNGKey
    critic: TrainState
    target_critic: TrainState  # NOTE(review): not read by any method in this class
    actor: TrainState
    temp: TrainState
    config: dict = nonpytree_field()

    @jax.jit
    def reset_critic_optimizer(agent):
        """Return a copy of ``agent`` with freshly initialised critic optimizer state."""
        new_opt_state = agent.critic.tx.init(agent.critic.params)
        new_critic = agent.critic.replace(opt_state=new_opt_state)
        return agent.replace(critic=new_critic)

    @partial(jax.jit, static_argnames=('num_steps',))
    def update_many_critics(agent, transitions: Batch, idxs: jnp.array, num_steps: int, R2):
        """Re-initialise the worst critic, then train every critic for ``num_steps`` steps.

        Args:
            transitions: replay-buffer contents (dict of arrays indexed by row).
            idxs: per-critic minibatch row indices; ``idxs[c][i]`` selects the
                batch for critic ``c`` at step ``i``.
            num_steps: number of gradient steps per critic (static under jit).
            R2: per-critic R^2 scores; the critic at ``argmin(R2)`` is re-initialised.

        Returns:
            ``(new_agent, {})``. Each critic gets a freshly created Adam
            TrainState, so previous optimizer state is discarded.
        """
        new_rng, reset_key, next_key = jax.random.split(agent.rng, 3)
        critic = agent.critic

        def update_one_critic(critic, idxs, agent, transitions, num_steps):
            """Run ``num_steps`` TD updates on a single (un-stacked) critic."""

            def one_update(step, agent, critic, batch: Batch):
                # Fresh noise for the next-action sample at every step instead of
                # reusing one key for the whole loop.
                step_key = jax.random.fold_in(next_key, step)

                def critic_loss_fn(critic_params):
                    next_dist = agent.actor(batch['next_observations'])
                    next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=step_key)

                    # One forward pass scores both (s, a) and (s', a').
                    concat_actions = jnp.concatenate([batch["actions"], next_actions])
                    concat_observations = jnp.concatenate([batch["observations"], batch["next_observations"]])
                    concat_q = agent.critic(concat_observations, concat_actions,
                                            True, params=critic_params)
                    q, next_q = jnp.split(concat_q, 2, axis=0)  ## axis=1 for ensemble

                    target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_q
                    if agent.config['backup_entropy']:
                        target_q = target_q - agent.config['discount'] * batch['masks'] * next_log_probs * agent.temp()
                    # The target is bootstrapped from the critic being trained, so
                    # gradients must not flow through it.
                    target_q = jax.lax.stop_gradient(target_q)
                    critic_loss = ((target_q - q) ** 2).mean()

                    return critic_loss, {
                        'critic_loss': critic_loss,
                        'q1': q.mean(),
                    }

                new_critic, _ = critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True)
                return agent, new_critic

            def get_batch(step):
                return jax.tree_util.tree_map(lambda x: x[idxs[step]], transitions)

            _, new_critic = jax.lax.fori_loop(
                0, num_steps,
                lambda step, carry: one_update(step, *carry, get_batch(step)),
                (agent, critic))
            return new_critic

        ###### Reset critic params ######
        reset = lambda rng, params: critic.init(rng,
                                                agent.config["observations"], agent.config["actions"], False)["params"]
        no_reset = lambda rng, params: params
        maybe_reset = lambda flag, rng, params: lax.cond(flag, reset, no_reset, rng, params)
        # `.at[].set()` returns a new array; its result must be kept, otherwise
        # the mask stays all-False and no critic is ever reset.
        reset_mask = jnp.zeros((NUM_CRITICS,), dtype=bool).at[jnp.argmin(R2)].set(True)
        rngs = jax.random.split(reset_key, NUM_CRITICS)
        critic_params = jax.vmap(maybe_reset, in_axes=(0, 0, 0))(reset_mask, rngs, critic.params)
        ###################################

        # NOTE(review): hidden sizes and learning rate are hard-coded here and
        # must match what create_learner used to build the critics.
        critic_def = Critic((256, 256))
        critics = jax.vmap(TrainState.create, in_axes=(None, 0, None))(
            critic_def, critic_params, optax.adam(learning_rate=3e-4))
        train_one = partial(update_one_critic, agent=agent, transitions=transitions, num_steps=num_steps)
        new_critics = jax.vmap(train_one, in_axes=(0, 0))(critics, idxs)

        return agent.replace(rng=new_rng, critic=new_critics), {}

    @jax.jit
    def update_actor(agent, batch: Batch, R2):
        """Update the temperature, then the actor against the R^2-weighted critic ensemble.

        The temperature step uses the policy entropy under the current actor.
        The actor step then uses the *updated* temperature.

        Returns:
            ``(new_agent, info)`` with actor and temperature diagnostics.
        """
        new_rng, curr_key, _ = jax.random.split(agent.rng, 3)

        def actor_loss_fn(actor_params, R2, temp_state):
            # Each observation is repeated 10 times so 10 actions are sampled per state.
            observations = jnp.repeat(batch['observations'], 10, axis=0)
            discounts = jnp.repeat(batch['discounts'], 10, axis=0)

            dist = agent.actor(observations, params=actor_params)
            actions, log_probs = dist.sample_and_log_prob(seed=curr_key)

            call_one_critic = lambda observations, actions, params: agent.critic(observations, actions, params=params)
            q_all = jax.vmap(call_one_critic, in_axes=(None, None, 0))(observations, actions, agent.critic.params)
            # Softmax over R^2: critics that predict returns better get more weight.
            q_weights = jax.nn.softmax(R2, axis=0)
            q = jnp.sum(q_weights.reshape(-1, 1) * q_all, axis=0)

            actor_loss = (discounts * (log_probs * temp_state() - q)).sum() / discounts.sum()

            return actor_loss, {
                'actor_loss': actor_loss,
                'entropy': -1 * ((discounts * log_probs) / jnp.sum(discounts)).sum(),
            }

        def temp_loss_fn(temp_params, entropy, target_entropy):
            temperature = agent.temp(params=temp_params)
            entropy_diff = entropy - target_entropy
            temp_loss = (temperature * entropy_diff).mean()
            return temp_loss, {
                'temp_loss': temp_loss,
                'temperature': temperature,
                'entropy_diff': entropy_diff,
            }

        # Forward pass only: the temperature update needs just the current entropy,
        # not an actor gradient.
        _, pre_update_info = actor_loss_fn(agent.actor.params, R2, agent.temp)

        temp_loss = partial(temp_loss_fn, entropy=pre_update_info['entropy'],
                            target_entropy=agent.config['target_entropy'])
        new_temp, temp_info = agent.temp.apply_loss_fn(loss_fn=temp_loss, has_aux=True)
        # Keep the temperature parameter in [1e-6, 1]. tree_map rebuilds the
        # params container rather than mutating it in place (in-place item
        # assignment fails on a FrozenDict). Temperature only defines `log_temp`.
        new_temp = new_temp.replace(
            params=jax.tree_util.tree_map(lambda p: jnp.clip(p, 1e-6, 1), new_temp.params))

        actor_loss = partial(actor_loss_fn, R2=R2, temp_state=new_temp)
        new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss, has_aux=True)

        new_agent = agent.replace(rng=new_rng, temp=new_temp, actor=new_actor)
        return new_agent, {**actor_info, **temp_info}

    @jax.jit
    def sample_actions(agent,
                       observations: np.ndarray,
                       seed: PRNGKey,
                       random: bool = False,
                       temperature: float = 1.0,
                       ) -> jnp.ndarray:
        """Sample actions from the actor at the given ``temperature``.

        ``random`` is accepted for call-site compatibility but is not used.
        """
        return agent.actor(observations, temperature=temperature).sample(seed=seed)



def create_learner(
        seed: int,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        temp_lr: float = 3e-1,  ## Test
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        target_entropy: float = None,
        backup_entropy: bool = True,
        **kwargs):
    """Build a ``SACAgent`` with an actor, a NUM_CRITICS-critic ensemble and a temperature.

    Args:
        seed: integer seed for parameter initialisation.
        observations: example observation batch used to initialise the networks.
        actions: example action batch; its last dimension sets the action size.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for each critic.
        temp_lr: currently ignored (see the note where the temperature is built).
        hidden_dims: hidden-layer sizes for both the actor and the critics.
        discount: discount factor stored in ``config``.
        tau: stored in ``config`` as ``target_update_rate``.
        target_entropy: entropy target for the temperature update; defaults
            to ``-0.5 * action_dim``.
        backup_entropy: whether critic targets include the entropy term.
        **kwargs: not used beyond being printed.

    Returns:
        A freshly initialised ``SACAgent``.
    """
    print('Extra kwargs:', kwargs)

    rng = jax.random.PRNGKey(seed)
    rng, actor_key, critic_key = jax.random.split(rng, 3)

    action_dim = actions.shape[-1]
    actor_def = Policy(hidden_dims, action_dim=action_dim,
                       log_std_min=-10.0, state_dependent_std=True,
                       tanh_squash_distribution=True, final_fc_init_scale=1.0)
    actor_params = actor_def.init(actor_key, observations)['params']
    actor = TrainState.create(actor_def, actor_params, tx=optax.adam(learning_rate=actor_lr))

    # One independent parameter set per critic, stacked along a leading axis.
    critic_def = Critic(hidden_dims)
    critic_keys = jax.random.split(critic_key, NUM_CRITICS)
    critic_params = jax.vmap(critic_def.init, in_axes=(0, None, None))(critic_keys, observations, actions)['params']
    critics = jax.vmap(TrainState.create, in_axes=(None, 0, None))(
        critic_def, critic_params, optax.adam(learning_rate=critic_lr))

    temp_def = Temperature()
    temp_params = temp_def.init(rng)['params']
    # NOTE(review): temp_lr is intentionally not wired in. Its default (3e-1,
    # flagged "## Test") differs from the 1e-3 SGD rate actually used here.
    temp = TrainState.create(temp_def, temp_params, tx=optax.sgd(learning_rate=1e-3))

    if target_entropy is None:
        target_entropy = -0.5 * action_dim

    config = flax.core.FrozenDict(dict(
        discount=discount,
        target_update_rate=tau,
        target_entropy=target_entropy,
        backup_entropy=backup_entropy,
        observations=observations,
        actions=actions,
    ))

    return SACAgent(rng, critic=critics, target_critic=critics, actor=actor, temp=temp, config=config)



In [2]:

from jaxrl_m.rollout import PolicyRollout


def f(anc_agent, obs, actor_params, critic_params, seed):
    """Score actions sampled from a given actor with a given critic.

    Args:
        anc_agent: agent supplying the actor/critic callables; the parameters
            actually used are the explicit ``actor_params`` / ``critic_params``.
        obs: batch of observations.
        actor_params: actor parameters used to build the action distribution.
        critic_params: critic parameters used to score the sampled actions.
        seed: PRNG key for sampling the actions.

    Returns:
        The critic output for ``(obs, a)`` with ``a`` sampled from the actor.
    """
    policy = anc_agent.actor(obs, params=actor_params)
    sampled_actions = policy.sample_and_log_prob(seed=seed)[0]
    return anc_agent.critic(obs, sampled_actions, params=critic_params)
    
@jax.jit
def estimate_return(acq_rollout,
                    anc_agent, anc_critic_params, anc_return, seed):
    """Predict the return of a rollout's policy from an anchor agent's critic.

    The prediction is ``anc_return`` plus an advantage term. That term is the
    difference between the critic's score of actions from the rollout's policy
    (``acq_rollout.policy_params``) and from the anchor actor. Both are scored
    with critic parameters ``anc_critic_params`` on the rollout's observations.
    Each observation is repeated ``n_samples`` times, and both actors sample
    with the same ``seed``.

    Returns:
        ``(predicted_return, acq_rollout.policy_return)``.
    """
    n_samples = 10  # actions sampled per observation
    acq_obs = jnp.repeat(acq_rollout.observations, n_samples, axis=0)
    acq_disc_masks = jnp.repeat(acq_rollout.disc_masks, n_samples, axis=0)

    acq_q = f(anc_agent, acq_obs, acq_rollout.policy_params, anc_critic_params, seed)
    anc_q = f(anc_agent, acq_obs, anc_agent.actor.params, anc_critic_params, seed)

    # Sum the disc_mask-weighted Q differences and normalise per rollout and sample.
    adv = ((acq_q - anc_q) * acq_disc_masks).sum() / (acq_rollout.num_rollouts * n_samples)
    acq_return_pred = anc_return + adv

    return acq_return_pred, acq_rollout.policy_return

@jax.jit
def evaluate_one_critic(anc_critic_params, anc_agent,
                        anc_return, policy_rollouts, seed):
    """Score one critic by how well it predicts the returns of past rollouts.

    For every rollout in ``policy_rollouts`` (batched along the leading axis),
    ``estimate_return`` predicts the return from this critic. The prediction
    is compared with the observed return by a weighted R^2, with weights
    ``1 / rollout.variance``.

    Returns:
        ``(R2, bias)`` where ``bias = mean(y_pred - y)``.
    """
    predict_rollout = partial(estimate_return,
                              anc_agent=anc_agent,
                              anc_critic_params=anc_critic_params,
                              anc_return=anc_return, seed=seed)
    y_pred, y = jax.vmap(predict_rollout)(policy_rollouts)
    var = jax.vmap(lambda rollout: rollout.variance)(policy_rollouts)
    weights = 1 / (var + 1e-8)

    # Weighted R^2: the total sum of squares is taken around the *weighted*
    # mean. Any other centre (e.g. the plain mean) gives a larger denominator
    # and so inflates R^2.
    y_bar = (weights * y).sum() / weights.sum()
    a2 = (weights * (y - y_pred) ** 2).sum()
    b2 = jnp.maximum((weights * (y - y_bar) ** 2).sum(), 1e-8)

    R2 = 1 - (a2 / b2)
    bias = (y_pred - y).mean()

    return R2, bias

#@jax.jit
def evaluate_many_critics(anc_agent, anc_return, policy_rollouts):
    """Evaluate every critic of ``anc_agent``'s ensemble with ``evaluate_one_critic``.

    Critics are scored one at a time by slicing the stacked critic parameters
    along their leading axis. All critics share the sampling key ``anc_agent.rng``.

    Returns:
        ``(R2, bias)``: two 1-D arrays of length NUM_CRITICS. This matches the
        shape of ``jnp.ones(NUM_CRITICS)`` used to initialise R2 in the training
        loop, so jitted consumers of R2 are not retraced on a shape change.
    """
    score_critic = partial(evaluate_one_critic,
                           anc_agent=anc_agent,
                           anc_return=anc_return,
                           policy_rollouts=policy_rollouts,
                           seed=anc_agent.rng)
    stacked_params = anc_agent.critic.params

    r2_scores, biases = [], []
    for i in range(NUM_CRITICS):
        critic_params = jax.tree_util.tree_map(lambda x: x[i], stacked_params)
        r2, bias = score_critic(critic_params)
        r2_scores.append(r2)
        biases.append(bias)

    return jnp.stack(r2_scores), jnp.stack(biases)




In [3]:
import os
from functools import partial
import numpy as np
import jax
import tqdm
import gymnasium as gym


from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
from jaxrl_m.dataset import ReplayBuffer
from collections import deque
from jax import config
from jaxrl_m.utils import flatten_rollouts
# config.update("jax_debug_nans", True)
# config.update("jax_enable_x64", True)
        
from jaxrl_m.rollout import rollout_policy2,rollout_policy

# ---- Experiment configuration ----
env_name = 'Walker2d-v4'
seed = np.random.choice(1000000)  # fresh random seed per run; logged to wandb below
eval_episodes = 10
batch_size = 256
max_steps = int(1e6)
start_steps = 10000  # warmup steps; updates begin once the replay buffer exceeds this
log_interval = 5000  # environment steps between evaluation/logging rounds

wandb_config = {
    'project': 'on_policy_sac99_sgdtemp2',
    'name': 'sac_{env_name}_{seed}'.format(env_name=env_name, seed=seed),
    'hyperparam_dict': {'env_name': env_name, 'seed': seed},
}

# Training episodes are truncated at 625 steps; evaluation uses the env's default limit.
env = EpisodeMonitor(gym.make(env_name, max_episode_steps=625))
eval_env = EpisodeMonitor(gym.make(env_name))
setup_wandb(**wandb_config)

# Template transition used to size the replay buffers and initialise the networks.
example_transition = dict(
    observations=env.observation_space.sample(),
    actions=env.action_space.sample(),
    rewards=0.0,
    masks=1.0,
    next_observations=env.observation_space.sample(),
    discounts=1.0,
)

replay_buffer = ReplayBuffer.create(example_transition, size=int(500_000))
actor_buffer = ReplayBuffer.create(example_transition, size=int(10e3))

agent = create_learner(seed,
                       example_transition['observations'][None],
                       example_transition['actions'][None],
                       max_steps=max_steps,
                       )

exploration_metrics = dict()
obs, info = env.reset()
# Derive the exploration key from the run seed. A fixed PRNGKey(0) gives every
# run the same exploration randomness regardless of `seed`. fold_in keeps this
# stream distinct from the agent's own PRNGKey(seed).
exploration_rng = jax.random.fold_in(jax.random.PRNGKey(seed), 1)
i = 0
unlogged_steps = 0
policy_rollouts = deque([], maxlen=25)
warmup = True
R2 = jnp.ones(NUM_CRITICS)

with tqdm.tqdm(total=max_steps) as pbar:

    while i < max_steps:

        warmup = (i < start_steps)
        # Use a fresh key each iteration; passing the same exploration_rng every
        # time would replay identical random draws in every rollout batch.
        exploration_rng, rollout_rng = jax.random.split(exploration_rng)
        (replay_buffer, actor_buffer, policy_rollout, policy_return, variance,
         undisc_policy_return, num_steps) = rollout_policy(
            agent, env, rollout_rng,
            replay_buffer, actor_buffer, warmup=warmup,
            num_rollouts=NUM_ROLLOUTS, random=False,
        )

        if not warmup:
            policy_rollouts.append(policy_rollout)
        unlogged_steps += num_steps
        i += num_steps
        pbar.update(num_steps)

        if replay_buffer.size > start_steps and len(policy_rollouts) > 0:

            ### Update critics ###
            critic_steps = 10000  # gradient steps per critic per iteration
            transitions = replay_buffer.get_all()
            sample_idxs = partial(jax.random.choice, a=replay_buffer.size,
                                  shape=(critic_steps, 256), replace=True)
            idxs = jax.vmap(sample_idxs)(jax.random.split(agent.rng, NUM_CRITICS))
            agent, critic_update_info = agent.update_many_critics(transitions, idxs, critic_steps, R2)

            ### Update critic weights ###
            if len(policy_rollouts) >= 10:
                with jax.default_matmul_precision('bfloat16'):
                    flattened_rollouts = flatten_rollouts(policy_rollouts)
                    R2, bias = evaluate_many_critics(agent, policy_rollout.policy_return, flattened_rollouts)

            ### Update actor ###
            actor_batch = actor_buffer.get_all()
            agent, actor_update_info = agent.update_actor(actor_batch, R2)
            update_info = {**critic_update_info, **actor_update_info}

            if unlogged_steps > log_interval:

                policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
                eval_info = evaluate(policy_fn, eval_env, num_episodes=eval_episodes)

                eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
                exploration_metrics = {'exploration/disc_return': policy_return, 'training/std': jnp.sqrt(variance)}
                train_metrics = {f'training/{k}': v for k, v in update_info.items()}
                train_metrics['training/undisc_return'] = undisc_policy_return
                if len(policy_rollouts) >= 10:
                    # Positional clip bounds work on both old and new jnp.clip signatures.
                    R2_train_info = {'R2/max': jnp.max(R2), 'R2/bias': bias[jnp.argmax(R2)],
                                     "R2/histogram": wandb.Histogram(jnp.clip(R2, -1, 1)),
                                     }
                    wandb.log(R2_train_info, step=int(i), commit=False)
                wandb.log(exploration_metrics, step=int(i), commit=False)
                wandb.log(train_metrics, step=int(i), commit=False)
                wandb.log(eval_metrics, step=int(i), commit=True)
                unlogged_steps = 0

2024-02-02 10:38:29.227892: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmahdikallel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Extra kwargs: {'max_steps': 1000000}


  2%|▏         | 16851/1000000 [01:16<1:13:54, 221.69it/s]


KeyboardInterrupt: 