In [1]:
"""Implementations of algorithms for continuous control."""
import functools
from jaxrl_m.typing import *

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import Policy, Critic, ensemblize

import flax
import flax.linen as nn

from functools import partial

class Temperature(nn.Module):
    """Learnable strictly-positive scalar (the SAC entropy coefficient).

    The value is stored as ``log_temp`` and exponentiated on every call, so
    gradient steps can never drive the temperature to zero or below.
    """
    initial_temperature: float = 1.

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        def _init_log_temp(key):
            # Scalar parameter; the PRNG key is unused because init is deterministic.
            return jnp.full((), jnp.log(self.initial_temperature))

        log_temperature = self.param('log_temp', _init_log_temp)
        return jnp.exp(log_temperature)

class SACAgent(flax.struct.PyTreeNode):
    """Soft Actor-Critic agent state plus its jitted update / inference methods.

    The agent is an immutable pytree: every method returns a *new* agent (via
    ``replace``) instead of mutating ``self``, so callers must rebind the
    result, e.g. ``agent = agent.reset_critic_optimizer()``.

    Fields:
        rng: PRNG key consumed and advanced by the update methods.
        critic: train state of the two-headed Q-network.
        target_critic: slowly-moving copy of ``critic`` used for bootstrap targets.
        actor: train state of the stochastic policy.
        temp: train state of the learnable entropy temperature.
        config: static hyper-parameters (``discount``, ``target_update_rate``,
            ``target_entropy``, ``backup_entropy``); not part of the pytree.
    """
    rng: PRNGKey
    critic: TrainState
    target_critic: TrainState
    actor: TrainState
    temp: TrainState
    config: dict = nonpytree_field()

    @jax.jit
    def reset_critic_optimizer(agent):
        """Return a copy of the agent whose critic optimizer state is rebuilt.

        Critic parameters are untouched; only ``opt_state`` is re-created with
        ``tx.init`` from the current parameters.
        """
        fresh_opt_state = agent.critic.tx.init(agent.critic.params)
        return agent.replace(critic=agent.critic.replace(opt_state=fresh_opt_state))

    @partial(jax.jit, static_argnames=('num_steps',))
    def update_critic2(agent, transitions: Batch, idxs: jnp.array, num_steps: int):
        """Run ``num_steps`` sequential critic + target-critic updates in one jitted call.

        Args:
            transitions: dict of arrays (e.g. the whole replay buffer); every
                leaf is indexed along its leading axis to form minibatches.
            idxs: integer array whose row ``i`` holds the transition indices of
                the i-th minibatch. Must have at least ``num_steps`` rows; only
                the first ``num_steps`` are used.
            num_steps: number of gradient steps. Static (compile-time), so a new
                value triggers a recompile.

        Returns:
            ``(new_agent, info)`` where ``info`` holds ``critic_loss`` and ``q1``
            averaged over the ``num_steps`` updates.

        Raises:
            ValueError: if ``idxs`` has fewer than ``num_steps`` rows.
        """
        if idxs.shape[0] < num_steps:
            raise ValueError(f'idxs has {idxs.shape[0]} rows but num_steps={num_steps}')

        def gather(batch_idx):
            # Select one minibatch from every leaf of the transition dict.
            return jax.tree_util.tree_map(lambda x: x[batch_idx], transitions)

        def one_update(agent, batch_idx):
            batch = gather(batch_idx)
            new_rng, curr_key, next_key = jax.random.split(agent.rng, 3)

            def critic_loss_fn(critic_params):
                next_dist = agent.actor(batch['next_observations'])
                next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=next_key)

                # Clipped double-Q target: element-wise min of the two target heads.
                next_q1, next_q2 = agent.target_critic(batch['next_observations'], next_actions, True,
                                                       params=None, rngs={'dropout': next_key})
                next_q = jnp.minimum(next_q1, next_q2)
                target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_q

                if agent.config['backup_entropy']:
                    # Soft Bellman backup: subtract the temperature-weighted log-prob of the next action.
                    target_q = target_q - agent.config['discount'] * batch['masks'] * next_log_probs * agent.temp()

                q1, q2 = agent.critic(batch['observations'], batch['actions'], True,
                                      params=critic_params, rngs={'dropout': curr_key})
                critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()

                return critic_loss, {
                    'critic_loss': critic_loss,
                    'q1': q1.mean(),
                }

            new_critic, critic_info = agent.critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True)
            # NOTE(review): averages from the *pre-update* critic (agent.critic, not
            # new_critic), so the target trails by one extra step — confirm intended.
            new_target_critic = target_update(agent.critic, agent.target_critic, agent.config['target_update_rate'])
            new_agent = agent.replace(rng=new_rng, critic=new_critic, target_critic=new_target_critic)
            return new_agent, critic_info

        # scan (rather than fori_loop) so each step's metrics are collected too.
        agent, per_step_info = jax.lax.scan(one_update, agent, idxs[:num_steps])
        return agent, jax.tree_util.tree_map(jnp.mean, per_step_info)

    @jax.jit
    def update_actor(agent, batch: Batch):
        """One actor gradient step followed by one temperature gradient step.

        Each observation in ``batch`` is repeated ``num_action_samples`` times so
        the policy loss averages over several sampled actions per state.

        Returns:
            ``(new_agent, info)`` with ``actor_loss``, ``entropy``, ``temp_loss``
            and ``temperature``.
        """
        num_action_samples = 5
        # 3-way split kept (third key unused) so the agent's RNG stream is unchanged.
        new_rng, curr_key, _ = jax.random.split(agent.rng, 3)

        def actor_loss_fn(actor_params):
            observations = jnp.repeat(batch['observations'], num_action_samples, axis=0)

            dist = agent.actor(observations, params=actor_params)
            actions, log_probs = dist.sample_and_log_prob(seed=curr_key)
            q1, q2 = agent.critic(observations, actions)
            # Policy objective uses the mean (not the min) of the two Q-heads.
            q = (q1 + q2) / 2

            actor_loss = (log_probs * agent.temp() - q).mean()
            return actor_loss, {
                'actor_loss': actor_loss,
                'entropy': -1 * log_probs.mean(),
            }

        def temp_loss_fn(temp_params, entropy, target_entropy):
            temperature = agent.temp(params=temp_params)
            temp_loss = (temperature * (entropy - target_entropy)).mean()
            return temp_loss, {
                'temp_loss': temp_loss,
                'temperature': temperature,
            }

        new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True)
        bound_temp_loss_fn = functools.partial(temp_loss_fn, entropy=actor_info['entropy'],
                                               target_entropy=agent.config['target_entropy'])
        new_temp, temp_info = agent.temp.apply_loss_fn(loss_fn=bound_temp_loss_fn, has_aux=True)

        return agent.replace(rng=new_rng, actor=new_actor, temp=new_temp), {**actor_info, **temp_info}

    @jax.jit
    def sample_actions(agent,
                       observations: np.ndarray,
                       *,
                       seed: PRNGKey,
                       temperature: float = 1.0,
                       ) -> jnp.ndarray:
        """Sample actions from the current policy; ``temperature`` is forwarded to the actor."""
        return agent.actor(observations, temperature=temperature).sample(seed=seed)



def create_learner(
        seed: int,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        temp_lr: float = 3e-2,  # experimental: 100x the actor/critic learning rate
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        target_entropy: float = None,
        backup_entropy: bool = True,
        actor_hidden_dims: Sequence[int] = (64, 64),
        **kwargs):
    """Build a freshly initialised ``SACAgent``.

    Args:
        seed: integer seed for ``jax.random.PRNGKey``.
        observations: example observation batch (leading batch axis); used
            only to initialise network shapes.
        actions: example action batch; ``actions.shape[-1]`` is the action dim.
        actor_lr, critic_lr, temp_lr: Adam learning rates.
        hidden_dims: hidden layer sizes of the critic (the actor uses
            ``actor_hidden_dims``).
        discount: bootstrap discount factor.
        tau: target-critic update rate (stored as ``target_update_rate``).
        target_entropy: entropy target for the temperature loss; defaults to
            ``-0.5 * action_dim``.
        backup_entropy: whether the critic target includes the entropy term.
        actor_hidden_dims: hidden layer sizes of the policy network; defaults
            to the previously hard-coded ``(64, 64)``.
        **kwargs: ignored (printed) so extra config entries don't break callers.

    Returns:
        SACAgent with actor, two-headed critic, target critic (initialised with
        the critic's params), learnable temperature and a frozen config.
    """
    print('Extra kwargs:', kwargs)

    rng = jax.random.PRNGKey(seed)
    rng, actor_key, critic_key = jax.random.split(rng, 3)

    action_dim = actions.shape[-1]
    actor_def = Policy(actor_hidden_dims, action_dim=action_dim,
                       log_std_min=-10.0, state_dependent_std=True,
                       tanh_squash_distribution=True, final_fc_init_scale=1.0)
    actor_params = actor_def.init(actor_key, observations)['params']
    actor = TrainState.create(actor_def, actor_params, tx=optax.adam(learning_rate=actor_lr))

    # Two Q-heads; split_rngs presumably gives each head its own dropout rng.
    critic_def = ensemblize(Critic, num_qs=2, split_rngs={"dropout": True})(hidden_dims)
    critic_params = critic_def.init(critic_key, observations, actions)['params']
    critic = TrainState.create(critic_def, critic_params, tx=optax.adam(learning_rate=critic_lr))
    # No optimizer: the target critic is only moved via target_update.
    target_critic = TrainState.create(critic_def, critic_params)

    temp_def = Temperature()
    # Reusing `rng` here is harmless: Temperature's initializer ignores its key.
    temp_params = temp_def.init(rng)['params']
    temp = TrainState.create(temp_def, temp_params, tx=optax.adam(learning_rate=temp_lr))

    if target_entropy is None:
        target_entropy = -0.5 * action_dim

    config = flax.core.FrozenDict(dict(
        discount=discount,
        target_update_rate=tau,
        target_entropy=target_entropy,
        backup_entropy=backup_entropy,
    ))

    return SACAgent(rng, critic=critic, target_critic=target_critic, actor=actor, temp=temp, config=config)





In [2]:
from jaxrl_m.dataset import ReplayBuffer
from flax import struct
import chex



@struct.dataclass
class PolicyRollout:
    """Snapshot of one policy together with the on-policy data it generated.

    Instances are built by ``rollout_policy`` and consumed by the jitted
    ``estimate_return``; every field is part of the pytree, so the whole
    object can be passed through ``jax.jit``.
    """

    policy_params : chex.Array  # actor params pytree (agent.actor.params), not necessarily a single array
    num_rollouts : chex.Array  # number of completed episodes collected (a plain int when constructed)
    policy_return : chex.Array  # mean discounted return per episode
    observations : chex.Array  # per-step observations in fixed per-episode slots; unused slots stay zero
    disc_masks : chex.Array  # per-step cumulative discount weight of each slot (0 for unused slots)
    disc_masks : chex.Array
    

def rollout_policy(agent,env,exploration_rng,
                   replay_buffer,actor_buffer,
                   warmup=False,num_rollouts=10,
                   max_episode_steps=1000,discount=None):
    """Collect ``num_rollouts`` complete episodes, feeding every transition to both buffers.

    Args:
        agent: SACAgent whose ``sample_actions`` chooses actions (not queried
            when ``warmup`` is True; actions then come from
            ``env.action_space.sample()``).
        env: gymnasium-style environment (``reset`` -> (obs, info), 5-tuple ``step``).
        exploration_rng: PRNG key, split once per policy step. The advanced
            key is NOT returned, so callers should pass a fresh key per call.
        replay_buffer: every transition is appended via ``add_transition``.
        actor_buffer: ``reset()`` is called on it first, then it receives the
            same transitions (presumably leaving only this call's data).
        warmup: sample random actions instead of querying the policy.
        num_rollouts: number of episodes to collect.
        max_episode_steps: slots reserved per episode in the returned arrays.
        discount: per-step discount for ``disc_masks`` / ``policy_return``;
            defaults to ``agent.config['discount']`` so it matches the critic.

    Returns:
        ``(replay_buffer, actor_buffer, policy_rollout, policy_return, n_steps)``
        where ``policy_return`` is the mean discounted return per episode and
        ``n_steps`` is the total number of environment steps taken.

    Raises:
        ValueError: if an episode runs longer than ``max_episode_steps``
            (it would otherwise overwrite the next episode's slots).
    """
    if discount is None:
        discount = agent.config['discount']

    actor_buffer.reset()
    obs, _ = env.reset()
    n_steps, n_rollouts, episode_step, disc = 0, 0, 0, 1.
    capacity = num_rollouts * max_episode_steps
    observations = np.zeros((capacity, obs.shape[0]))
    disc_masks = np.zeros((capacity,))
    rewards = np.zeros((capacity,))

    while n_rollouts < num_rollouts:
        if episode_step >= max_episode_steps:
            raise ValueError(
                f'episode exceeded max_episode_steps={max_episode_steps}; '
                'pass a larger max_episode_steps to rollout_policy')

        if warmup:
            action = env.action_space.sample()
        else:
            exploration_rng, key = jax.random.split(exploration_rng)
            action = agent.sample_actions(obs, seed=key)

        next_obs, reward, done, truncated, _info = env.step(action)
        # mask is 0 only on termination (`done`); a truncated step keeps mask=1.
        mask = float(not done)

        transition = dict(observations=obs, actions=action,
                          rewards=reward, masks=mask, next_observations=next_obs, discounts=disc)
        replay_buffer.add_transition(transition)
        actor_buffer.add_transition(transition)

        slot = max_episode_steps * n_rollouts + episode_step
        observations[slot] = obs
        disc_masks[slot] = disc
        rewards[slot] = reward

        obs = next_obs
        disc *= discount * mask
        episode_step += 1
        n_steps += 1

        if done or truncated:
            obs, _ = env.reset()
            n_rollouts += 1
            episode_step = 0
            disc = 1.

    policy_return = (disc_masks * rewards).sum() / num_rollouts
    policy_rollout = PolicyRollout(policy_params=agent.actor.params,
                                   policy_return=policy_return,
                                   observations=observations,
                                   disc_masks=disc_masks,
                                   num_rollouts=num_rollouts)

    return replay_buffer, actor_buffer, policy_rollout, policy_return, n_steps

In [3]:




def f(anc_agent, obs, actor):
    """Mean two-head Q-value of actions drawn from policy parameters ``actor``.

    Actions are sampled from ``anc_agent.actor`` evaluated with ``actor`` as
    its params, using the fixed key ``PRNGKey(0)`` so sampling is identical
    across calls with the same input. Q-values come from ``anc_agent.critic``
    run with the *target* critic's parameters, and the two heads are averaged.
    """
    policy = anc_agent.actor(obs, params=actor)
    sampled_actions, _ = policy.sample_and_log_prob(seed=jax.random.PRNGKey(0))
    q_first, q_second = anc_agent.critic(obs, sampled_actions,
                                         params=anc_agent.target_critic.params)
    return (q_first + q_second) / 2
    

@jax.jit
def estimate_return(anc_agent,anc_return,acq_rollout:PolicyRollout,):
    """Predict the return of ``acq_rollout``'s policy from the anchor agent's critic.

    Uses ``J(acq) ~= J(anc) + (1/N) * sum_{episodes, t} disc_t * (Q(s_t, a~acq) - Q(s_t, a~anc))``
    with Q the anchor's target critic (via ``f``) and N the number of episodes
    in ``acq_rollout``. Each observation is repeated ``n_action_samples`` times
    so the Q difference is averaged over several sampled actions.

    Returns:
        ``(predicted_return, observed_return)`` for the acquired policy.
    """
    n_action_samples = 10
    acq_obs = jnp.repeat(acq_rollout.observations, n_action_samples, axis=0)
    acq_masks = jnp.repeat(acq_rollout.disc_masks, n_action_samples, axis=0)

    acq_q = f(anc_agent, acq_obs, acq_rollout.policy_params)
    anc_q = f(anc_agent, acq_obs, anc_agent.actor.params)

    # The sum runs over n_action_samples copies of every step of every episode,
    # so normalise by both to get a per-episode advantage on the same scale as
    # policy_return (which rollout_policy divides by num_rollouts).
    adv = ((acq_q - anc_q) * acq_masks).sum() / (n_action_samples * acq_rollout.num_rollouts)
    acq_return_pred = anc_return + adv

    return acq_return_pred, acq_rollout.policy_return


def r2_score(y, y_pred, min_sq_err=1e-4):
    """Coefficient of determination ``1 - SS_res / SS_tot`` of ``y_pred`` against ``y``.

    Each squared residual is floored at ``min_sq_err`` before summing. Returns
    ``nan`` when R2 is undefined: no samples, or ``y`` has zero variance
    (e.g. a single rollout), where the plain formula divides by zero.
    """
    y = np.asarray(y)
    y_pred = np.asarray(y_pred)
    if y.size == 0:
        return float('nan')
    ss_res = np.maximum((y - y_pred) ** 2, min_sq_err).sum()
    ss_tot = ((y - y.mean()) ** 2).sum()
    if ss_tot == 0:
        return float('nan')
    return float(1 - ss_res / ss_tot)


def evaluate_critic(anc_agent,anc_return,policy_rollouts):
    """R2 of the anchor critic's return predictions over ``policy_rollouts``.

    For each stored rollout, ``estimate_return`` predicts its return from
    ``anc_agent``'s critic anchored at ``anc_return``; predictions are scored
    against the observed returns with ``r2_score`` (nan when undefined).
    """
    predicted, observed = [], []
    for policy_rollout in policy_rollouts:
        pred_return, true_return = estimate_return(anc_agent, anc_return, policy_rollout)
        predicted.append(pred_return)
        observed.append(true_return)
    return r2_score(observed, predicted)


In [4]:
import os
from functools import partial
import numpy as np
import jax
import tqdm
import gymnasium as gym


from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
from jaxrl_m.dataset import ReplayBuffer
from collections import deque


# ---- Experiment configuration ------------------------------------------------
env_name = 'InvertedPendulum-v4'
seed = int(np.random.choice(1000000))  # drawn fresh each run; logged to wandb for reproducibility
eval_episodes = 10
batch_size = 256
max_steps = int(1e6)
start_steps = 5000          # env steps of uniform-random warmup before any training
log_interval = 5000
num_critic_updates = 10000  # critic gradient steps per round of rollouts
#eval_interval = 10000

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'd4rl_test',
    'group': 'sac_test',
    'name': f'sac_{env_name}',
})

hyperparams = dict(env_name=env_name, seed=seed, batch_size=batch_size,
                   max_steps=max_steps, start_steps=start_steps,
                   num_critic_updates=num_critic_updates)

env = EpisodeMonitor(gym.make(env_name))
eval_env = EpisodeMonitor(gym.make(env_name))
setup_wandb(hyperparams, project=wandb_config['project'],
            group=wandb_config['group'], name=wandb_config['name'])

example_transition = dict(
    observations=env.observation_space.sample(),
    actions=env.action_space.sample(),
    rewards=0.0,
    masks=1.0,
    next_observations=env.observation_space.sample(),
    discounts=1.0,
)

replay_buffer = ReplayBuffer.create(example_transition, size=int(1e6))
actor_buffer = ReplayBuffer.create(example_transition, size=int(5e3))

agent = create_learner(seed,
                       example_transition['observations'][None],
                       example_transition['actions'][None],
                       max_steps=max_steps,
                       )

# ---- Training loop -------------------------------------------------------------
# One PRNG stream for the run; fresh sub-keys every iteration so neither the
# exploration noise nor the minibatch indices repeat across iterations.
rng = jax.random.PRNGKey(seed)
i = 0
unlogged_steps = 0
policy_rollouts = deque([], maxlen=10)  # recent policies used to validate the critic (R2)
with tqdm.tqdm(total=max_steps) as pbar:

    while i < max_steps:
        rng, exploration_rng, idx_rng = jax.random.split(rng, 3)
        replay_buffer, actor_buffer, policy_rollout, policy_return, num_steps = rollout_policy(
            agent, env, exploration_rng,
            replay_buffer, actor_buffer,
            warmup=(i < start_steps))
        policy_rollouts.append(policy_rollout)
        unlogged_steps += num_steps
        i += num_steps
        pbar.update(num_steps)

        if replay_buffer.size > start_steps:
            transitions = replay_buffer.get_all()
            idxs = jax.random.choice(key=idx_rng, a=replay_buffer.size,
                                     shape=(num_critic_updates, batch_size), replace=True)
            # SACAgent is immutable: the result must be rebound or the reset is lost.
            agent = agent.reset_critic_optimizer()
            agent, critic_update_info = agent.update_critic2(transitions, idxs, num_critic_updates)
            # Anchor = most recent policy; score the critic's return predictions on every stored rollout.
            R2 = evaluate_critic(agent, policy_rollouts[-1].policy_return, policy_rollouts)

            actor_batch = actor_buffer.get_all()
            agent, actor_update_info = agent.update_actor(actor_batch)

            update_info = {**critic_update_info, **actor_update_info, 'R2_validation': R2}

            if unlogged_steps > log_interval:
                exploration_metrics = {'exploration/disc_return': policy_return}
                wandb.log(exploration_metrics, step=int(i), commit=False)
                train_metrics = {f'training/{k}': v for k, v in update_info.items()}
                wandb.log(train_metrics, step=int(i), commit=False)
                policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
                eval_info = evaluate(policy_fn, eval_env, num_episodes=eval_episodes)
                eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
                print('evaluating')
                wandb.log(eval_metrics, step=int(i), commit=True)
                unlogged_steps = 0


2023-12-15 00:17:52.257123: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmahdikallel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Extra kwargs: {'max_steps': 1000000}


  1%|          | 5100/1000000 [00:04<17:52, 928.05it/s]  

evaluating


  1%|          | 10251/1000000 [01:29<3:10:33, 86.56it/s]

evaluating


  2%|▏         | 15599/1000000 [02:13<1:57:19, 139.83it/s]

evaluating


  2%|▏         | 21052/1000000 [02:43<1:24:32, 192.98it/s]

evaluating


  3%|▎         | 26620/1000000 [03:06<1:06:27, 244.08it/s]

evaluating


  3%|▎         | 31792/1000000 [03:23<50:37, 318.78it/s]  

evaluating


  4%|▎         | 36960/1000000 [03:36<39:31, 406.17it/s]

evaluating


  4%|▍         | 42455/1000000 [03:46<28:00, 569.96it/s]

evaluating


  5%|▍         | 47903/1000000 [03:54<24:00, 660.78it/s]

evaluating


  5%|▌         | 54379/1000000 [04:04<24:03, 655.19it/s]

evaluating


  6%|▌         | 61189/1000000 [04:14<22:55, 682.45it/s]

evaluating


  7%|▋         | 67757/1000000 [04:23<19:53, 780.93it/s]

evaluating


  7%|▋         | 73212/1000000 [04:29<18:41, 826.22it/s]

evaluating


  8%|▊         | 79319/1000000 [04:36<16:59, 902.72it/s]

evaluating


  9%|▊         | 86569/1000000 [04:45<17:51, 852.31it/s]

evaluating


  9%|▉         | 93081/1000000 [04:52<16:20, 924.78it/s]

evaluating


 10%|█         | 100416/1000000 [04:59<15:19, 978.02it/s]

evaluating


 11%|█         | 108477/1000000 [05:07<14:24, 1031.42it/s]

evaluating


 12%|█▏        | 115327/1000000 [05:14<14:54, 988.51it/s] 

evaluating


 12%|█▏        | 121625/1000000 [05:21<14:57, 978.25it/s]

evaluating


 13%|█▎        | 131308/1000000 [05:29<13:05, 1105.98it/s]

evaluating


 14%|█▍        | 140467/1000000 [05:38<13:13, 1083.34it/s]

evaluating


 15%|█▍        | 145682/1000000 [05:44<13:43, 1037.29it/s]

evaluating


 15%|█▌        | 151202/1000000 [05:50<14:00, 1010.04it/s]

evaluating


 16%|█▌        | 158054/1000000 [05:56<13:50, 1013.30it/s]

evaluating


 16%|█▋        | 164027/1000000 [06:03<14:07, 986.72it/s] 

evaluating


 17%|█▋        | 173708/1000000 [06:11<12:49, 1073.84it/s]

evaluating


 18%|█▊        | 180098/1000000 [06:18<13:21, 1023.34it/s]

evaluating


 19%|█▉        | 188796/1000000 [06:25<12:35, 1074.11it/s]

evaluating


 20%|█▉        | 195676/1000000 [06:32<12:49, 1045.74it/s]

evaluating


 20%|██        | 203305/1000000 [06:40<12:37, 1052.31it/s]

evaluating


 21%|██▏       | 212844/1000000 [06:47<11:43, 1119.32it/s]

evaluating


 22%|██▏       | 222566/1000000 [06:55<11:03, 1171.08it/s]

evaluating


 23%|██▎       | 232330/1000000 [07:02<10:35, 1208.73it/s]

evaluating


 24%|██▍       | 241318/1000000 [07:10<10:28, 1206.96it/s]

evaluating


 25%|██▍       | 249262/1000000 [07:17<10:38, 1176.36it/s]

evaluating


 26%|██▌       | 258255/1000000 [07:24<10:21, 1194.11it/s]

evaluating


 27%|██▋       | 268130/1000000 [07:32<10:00, 1219.69it/s]

evaluating


 28%|██▊       | 277866/1000000 [07:40<09:46, 1232.00it/s]

evaluating


 29%|██▉       | 287608/1000000 [07:47<09:33, 1242.99it/s]

evaluating


 30%|██▉       | 297351/1000000 [07:55<09:21, 1251.82it/s]

evaluating


 31%|███       | 307104/1000000 [08:03<09:14, 1250.65it/s]

evaluating


 32%|███▏      | 317104/1000000 [08:11<09:02, 1259.86it/s]

evaluating


 33%|███▎      | 327104/1000000 [08:18<08:45, 1280.50it/s]

evaluating


 34%|███▎      | 337104/1000000 [08:26<08:39, 1276.26it/s]

evaluating


 35%|███▍      | 347104/1000000 [08:34<08:30, 1279.39it/s]

evaluating


 36%|███▌      | 357104/1000000 [08:41<08:13, 1302.95it/s]

evaluating


 36%|███▌      | 357104/1000000 [08:47<15:50, 676.52it/s] 


KeyboardInterrupt: 

In [None]:
# Scratch check: rebuild the critic's optimizer state from its current params
# (the same steps SACAgent.reset_critic_optimizer performs inside jit).
critic_state = agent.critic
new_opt_state = critic_state.tx.init(critic_state.params)
new_critic = critic_state.replace(opt_state=new_opt_state)
