In [1]:
"""Implementations of algorithms for continuous control."""
import functools
from jaxrl_m.typing import *

import jax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import Policy, Critic, ensemblize

import flax
import flax.linen as nn

class Temperature(nn.Module):
    """Learnable scalar temperature, parameterised by its logarithm.

    The module owns a single scalar parameter named ``log_temp`` and returns
    its exponential, so the value it produces is always positive.

    Attributes:
        initial_temperature: Temperature returned before any training; the
            parameter is initialised to its natural logarithm.
    """
    initial_temperature: float = 1.0

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        """Return ``exp(log_temp)`` as a zero-dimensional array."""

        def _initial_log_temp(key):
            # The starting value is deterministic, so the PRNG key is ignored.
            del key
            return jnp.full((), jnp.log(self.initial_temperature))

        log_temperature = self.param('log_temp', init_fn=_initial_log_temp)
        return jnp.exp(log_temperature)

class SACAgent(flax.struct.PyTreeNode):
    """Soft actor-critic agent: networks, PRNG state and hyperparameters.

    The update methods do not mutate the agent; each returns a new agent
    built with ``replace`` together with a dict of scalar diagnostics.
    Note that the first parameter of every method is named ``agent`` rather
    than ``self``.
    """
    rng: PRNGKey  # PRNG state; split and replaced by every update call
    critic: TrainState
    target_critic: TrainState
    actor: TrainState
    temp: TrainState
    # Hyperparameters. Read with plain Python control flow inside the jitted
    # methods below (keys: 'discount', 'backup_entropy', 'target_update_rate',
    # 'target_entropy').
    config: dict = nonpytree_field()

    @jax.jit
    def update_critic(agent, batch: Batch):
        """Take one gradient step on the critic and update the target critic.

        Args:
            batch: Mapping providing 'observations', 'actions', 'rewards',
                'masks' and 'next_observations'.

        Returns:
            A ``(new_agent, info)`` pair. ``new_agent`` carries the new rng,
            critic and target critic; ``info`` holds 'critic_loss' and 'q1'.
        """
        # One independent key per random draw: next-action sampling, critic
        # dropout and target-critic dropout must never share a key.
        new_rng, curr_key, next_key, target_key = jax.random.split(agent.rng, 4)

        # The TD target only involves the actor, the target critic and the
        # temperature evaluated with their own stored parameters, so it is
        # independent of the critic parameters being differentiated and can
        # be computed once, outside the loss function.
        next_dist = agent.actor(batch['next_observations'])
        next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=next_key)

        # NOTE(review): the positional ``True`` is presumably the critic's
        # training flag (it is passed together with a 'dropout' rng) --
        # confirm against jaxrl_m.networks.Critic.
        next_q1, next_q2 = agent.target_critic(
            batch['next_observations'], next_actions, True,
            params=None, rngs={'dropout': target_key})
        next_q = jnp.minimum(next_q1, next_q2)

        # masks multiply the bootstrap term, so a mask of 0 removes it.
        bootstrap_weight = agent.config['discount'] * batch['masks']
        target_q = batch['rewards'] + bootstrap_weight * next_q

        if agent.config['backup_entropy']:
            target_q = target_q - bootstrap_weight * next_log_probs * agent.temp()

        def critic_loss_fn(critic_params):
            q1, q2 = agent.critic(
                batch['observations'], batch['actions'], True,
                params=critic_params, rngs={'dropout': curr_key})
            critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()

            return critic_loss, {
                'critic_loss': critic_loss,
                'q1': q1.mean(),
            }

        new_critic, critic_info = agent.critic.apply_loss_fn(
            loss_fn=critic_loss_fn, has_aux=True)
        # The target is moved towards the critic as it was *before* this
        # gradient step (agent.critic, not new_critic).
        new_target_critic = target_update(
            agent.critic, agent.target_critic, agent.config['target_update_rate'])

        new_agent = agent.replace(
            rng=new_rng, critic=new_critic, target_critic=new_target_critic)
        return new_agent, {**critic_info}

    @jax.jit
    def update_actor(agent, batch: Batch):
        """Take one gradient step on the actor, then one on the temperature.

        Args:
            batch: Mapping providing 'observations'.

        Returns:
            A ``(new_agent, info)`` pair. ``new_agent`` carries the new rng,
            actor and temperature; ``info`` holds 'actor_loss', 'entropy',
            'temp_loss' and 'temperature'.
        """
        new_rng, curr_key = jax.random.split(agent.rng)

        # Every observation is repeated so that several actions are sampled
        # per state when estimating the actor loss.
        num_action_samples = 5
        observations = jnp.repeat(
            batch['observations'], num_action_samples, axis=0)

        def actor_loss_fn(actor_params):
            dist = agent.actor(observations, params=actor_params)
            actions, log_probs = dist.sample_and_log_prob(seed=curr_key)
            # Unlike update_critic, the critic is called here without the
            # extra positional flag and without a dropout rng.
            q1, q2 = agent.critic(observations, actions)
            q = jnp.minimum(q1, q2)

            actor_loss = (log_probs * agent.temp() - q).mean()
            return actor_loss, {
                'actor_loss': actor_loss,
                'entropy': -1 * log_probs.mean(),
            }

        new_actor, actor_info = agent.actor.apply_loss_fn(
            loss_fn=actor_loss_fn, has_aux=True)

        # The entropy estimate comes from the actor-loss evaluation above,
        # i.e. from the actor parameters before this gradient step.
        entropy = actor_info['entropy']
        target_entropy = agent.config['target_entropy']

        def temp_loss_fn(temp_params):
            temperature = agent.temp(params=temp_params)
            temp_loss = (temperature * (entropy - target_entropy)).mean()
            return temp_loss, {
                'temp_loss': temp_loss,
                'temperature': temperature,
            }

        new_temp, temp_info = agent.temp.apply_loss_fn(
            loss_fn=temp_loss_fn, has_aux=True)

        new_agent = agent.replace(rng=new_rng, actor=new_actor, temp=new_temp)
        return new_agent, {**actor_info, **temp_info}

    @jax.jit
    def sample_actions(agent,
                       observations: np.ndarray,
                       *,
                       seed: PRNGKey,
                       temperature: float = 1.0,
                       ) -> jnp.ndarray:
        """Sample actions for ``observations`` from the actor's distribution.

        Args:
            observations: Observations forwarded to the actor.
            seed: PRNG key used for sampling.
            temperature: Forwarded to the actor as ``temperature``.

        Returns:
            The sampled actions (not clipped).
        """
        return agent.actor(observations, temperature=temperature).sample(seed=seed)



def create_learner(
    seed: int,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    actor_lr: float = 3e-4,
    critic_lr: float = 3e-4,
    temp_lr: float = 3e-4,
    hidden_dims: Sequence[int] = (256, 256),
    discount: float = 0.99,
    tau: float = 0.005,
    target_entropy: float = None,
    backup_entropy: bool = True,
    **kwargs,
):
    """Build a freshly initialised ``SACAgent``.

    Args:
        seed: Integer seed for the PRNG key all initialisation derives from.
        observations: Example observations used to initialise the networks.
        actions: Example actions; the size of the last axis is the action
            dimension.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        temp_lr: Adam learning rate for the temperature.
        hidden_dims: Hidden layer sizes passed to both actor and critic.
        discount: Stored in the agent config as 'discount'.
        tau: Stored in the agent config as 'target_update_rate'.
        target_entropy: Stored as 'target_entropy'; when ``None`` it becomes
            ``-0.5 * action_dim``.
        backup_entropy: Stored in the agent config as 'backup_entropy'.
        **kwargs: Any other keyword arguments; they are printed and
            otherwise ignored.

    Returns:
        A ``SACAgent`` holding the actor, critic, target critic, temperature
        and a frozen config.
    """
    print('Extra kwargs:', kwargs)

    base_key = jax.random.PRNGKey(seed)
    rng, actor_key, critic_key = jax.random.split(base_key, 3)

    action_dim = actions.shape[-1]
    if target_entropy is None:
        target_entropy = -0.5 * action_dim

    # Actor.
    actor_def = Policy(
        hidden_dims,
        action_dim=action_dim,
        log_std_min=-10.0,
        state_dependent_std=True,
        tanh_squash_distribution=True,
        final_fc_init_scale=1.0,
    )
    actor = TrainState.create(
        actor_def,
        actor_def.init(actor_key, observations)['params'],
        tx=optax.adam(learning_rate=actor_lr),
    )

    # Two-member critic ensemble; the target critic starts from the same
    # parameters and is created without an optimiser.
    critic_def = ensemblize(Critic, num_qs=2, split_rngs={"dropout": True})(hidden_dims)
    critic_params = critic_def.init(critic_key, observations, actions)['params']
    critic = TrainState.create(
        critic_def,
        critic_params,
        tx=optax.adam(learning_rate=critic_lr),
    )
    target_critic = TrainState.create(critic_def, critic_params)

    # Temperature. The same ``rng`` is also handed to the agent below;
    # Temperature's initialiser does not use its key.
    temp_def = Temperature()
    temp = TrainState.create(
        temp_def,
        temp_def.init(rng)['params'],
        tx=optax.adam(learning_rate=temp_lr),
    )

    config = flax.core.FrozenDict({
        'discount': discount,
        'target_update_rate': tau,
        'target_entropy': target_entropy,
        'backup_entropy': backup_entropy,
    })

    return SACAgent(
        rng,
        critic=critic,
        target_critic=target_critic,
        actor=actor,
        temp=temp,
        config=config,
    )

def get_default_config():
    """Return the default SAC hyperparameters as an ``ml_collections.ConfigDict``.

    ``target_entropy`` is declared as a float placeholder without a value;
    every other entry has a concrete default.
    """
    import ml_collections

    defaults = dict(
        actor_lr=3e-4,
        critic_lr=3e-4,
        temp_lr=3e-4,
        hidden_dims=(256, 256),
        discount=0.99,
        tau=0.005,
        target_entropy=ml_collections.config_dict.placeholder(float),
        backup_entropy=True,
    )
    return ml_collections.ConfigDict(defaults)

In [2]:
import os
from functools import partial
import numpy as np
import jax
import tqdm
import gymnasium as gym

#import examples.mujoco.sac as learner

from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
from jaxrl_m.dataset import ReplayBuffer

#from ml_collections import config_flags
import pickle
#from flax.training import checkpoints


#FLAGS = flags.FLAGS
# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
env_name = 'Hopper-v4'
seed = np.random.choice(1000000)
eval_episodes = 10
batch_size = 256
max_steps = int(1e6)
start_steps = int(1e4)  # random-action warm-up length / minimum buffer fill
log_interval = 10000
eval_interval = 10000
reward_scale = 300  # environment rewards are divided by this before storage
critic_updates_per_step = 5  # critic gradient steps per actor gradient step

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'd4rl_test',
    'group': 'sac_test',
    # Plain (non-f) string: the braces are stored literally.
    'name': 'sac_{env_name}',
})

env = EpisodeMonitor(gym.make(env_name))
eval_env = EpisodeMonitor(gym.make(env_name))
# NOTE(review): wandb_config above is built but never passed to
# setup_wandb, so its project/group/name have no effect on this run.
# Confirm the intended call against jaxrl_m.wandb.setup_wandb.
setup_wandb({"bonjour":1})

# A single dummy transition fixes the layout of the replay buffer and the
# example inputs used to initialise the networks.
example_transition = dict(
    observations=env.observation_space.sample(),
    actions=env.action_space.sample(),
    rewards=0.0,
    masks=1.0,
    next_observations=env.observation_space.sample(),
)

replay_buffer = ReplayBuffer.create(example_transition, size=int(1e6))

# [None] adds a leading axis of size 1 to the example arrays.
# max_steps is not a named parameter of create_learner; it ends up in its
# **kwargs and is only printed.
agent = create_learner(seed,
                example_transition['observations'][None],
                example_transition['actions'][None],
                max_steps=max_steps,
                )

exploration_metrics = dict()
obs, info = env.reset()
exploration_rng = jax.random.PRNGKey(0)

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
for i in tqdm.tqdm(range(1, max_steps + 1),
                    smoothing=0.1,
                    dynamic_ncols=True):

    # Uniform random actions during warm-up, policy samples afterwards.
    if i < start_steps:
        action = env.action_space.sample()
    else:
        exploration_rng, key = jax.random.split(exploration_rng)
        action = agent.sample_actions(obs, seed=key)

    # Gymnasium's step returns (obs, reward, terminated, truncated, info);
    # `done` therefore holds the *terminated* flag only.
    next_obs, reward, done, truncated, info = env.step(action)
    reward = reward / reward_scale
    # Only true termination switches off bootstrapping. Truncation is
    # reported separately via `truncated`, so the legacy-gym
    # 'TimeLimit.truncated' info lookup is not needed.
    mask = float(not done)

    replay_buffer.add_transition(dict(
        observations=obs,
        actions=action,
        rewards=reward,
        masks=mask,
        next_observations=next_obs,
    ))
    obs = next_obs

    if done or truncated:
        exploration_metrics = {f'exploration/{k}': v for k, v in flatten(info).items()}
        obs, info = env.reset()

    # No gradient updates until the buffer holds start_steps transitions.
    if replay_buffer.size < start_steps:
        continue

    # Several critic steps, each on a fresh batch, then one actor step.
    for _ in range(critic_updates_per_step):
        batch = replay_buffer.sample(batch_size)
        agent, critic_update_info = agent.update_critic(batch)

    batch = replay_buffer.sample(batch_size)
    agent, actor_update_info = agent.update_actor(batch)

    # Only the diagnostics of the last critic step are kept.
    update_info = {**critic_update_info, **actor_update_info}

    if i % log_interval == 0:
        train_metrics = {f'training/{k}': v for k, v in update_info.items()}
        wandb.log(train_metrics, step=i)
        wandb.log(exploration_metrics, step=i)
        exploration_metrics = dict()

    if i % eval_interval == 0:
        # Evaluation samples with temperature=0.0.
        policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
        eval_info = evaluate(policy_fn, eval_env, num_episodes=eval_episodes)
        eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
        wandb.log(eval_metrics, step=i, commit=True)



2023-11-28 13:04:25.972393: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmahdikallel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Extra kwargs: {'max_steps': 1000000}


  9%|▉         | 94972/1000000 [04:39<44:18, 340.38it/s]  


KeyboardInterrupt: 

: 