In [1]:
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy

In [2]:
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
from ray.rllib.agents.a3c.a2c import A2CTrainer

In [3]:
import gym
from typing import Optional, Dict

import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
    Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping, sequence_mask
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
    PolicyID, LocalOptimizer

torch, nn = try_import_torch()

ImportError: cannot import name 'ValueNetworkMixin' from 'ray.rllib.agents.ppo.ppo_torch_policy' (/home/ml2558/miniconda3/envs/tf-gpu/lib/python3.9/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py)

In [None]:
def after_init(policy: Policy, obs_space: gym.spaces.Space, 
              action_space: gym.spaces.Space, config: TrainerConfigDict)->None:
        policy.past_len = 5        
        policy.past_models = deque(maxlen =policy.past_len)
        policy.timestep = 0
    

In [None]:
def compute_div_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch):
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()
    valid_mask = torch.ones_like(values, dtype=torch.bool)
    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    
    divs = []
    for idx, past_model in enumerate(policy.past_models):
        logits, _ = past_model.from_batch(train_batch)
        values = past_model.value_function()
        valid_mask = torch.ones_like(values, dtype=torch.bool)
        dist = dist_class(logits, past_model)
        past_log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1) 
        div = div_metric(log_probs, past_log_probs).sum(1)
        div = div.mean(0)
        divs.append(div)
    
    divs_sort_idx = np.argsort([d.data[0] for d in divs])
    div_loss_orig = 0
    for idx in divs_sort_idx:
        div_loss += divs[idx]
        div_loss_orig += divs[idx]
    
    div_loss = div_loss / self.past_len
    
    return div_loss

In [None]:
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()
    policy.timestep += 1
    
    if policy.timestep % 100 == 0:
        policy.past_models.append(copy.deepcopy(model))
    
    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))


    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"] - compute_div_loss(policy, model, dist_class, train_batch))

    policy.entropy = entropy
    policy.pi_err = pi_err
    policy.value_err = value_err

    return total_loss

In [None]:
CustomPolicy = A3CTorchPolicy.with_updates(
    name="MyCustomA3CTorchPolicy",
    loss_fn=actor_critic_loss,
    after_init=after_init)
CustomTrainer = A2CTrainer.with_updates(
    default_policy=CustomPolicy)

In [None]:
#tune.run(CustomTrainer, config={"env": 'Frostbite-v0', "num_gpus":1})

#tune.run(CustomTrainer, config={"env": 'Frostbite-v0', "num_gpus":0})#, 'model': { 'custom_model': 'test_model' }})
tune.register_env("cache_guessing_game_env_fix", CacheGuessingGameEnv)#Fix)
# Two ways of training
# method 2b
config = {
    'env': 'cache_guessing_game_env_fix', #'cache_simulator_diversity_wrapper',
    'env_config': {
        'verbose': 1,
        "force_victim_hit": False,
        'flush_inst': False,
        "allow_victim_multi_access": False,
        "attacker_addr_s": 0,
        "attacker_addr_e": 3,
        "victim_addr_s": 0,
        "victim_addr_e": 1,
        "reset_limit": 1,
        "cache_configs": {
                # YAML config file for cache simulaton
            "architecture": {
              "word_size": 1, #bytes
              "block_size": 1, #bytes
              "write_back": True
            },
            "cache_1": {#required
              "blocks": 2, 
              "associativity": 2,  
              "hit_time": 1 #cycles
            },
            "mem": {#required
              "hit_time": 1000 #cycles
            }
        }
    }, 
    #'gamma': 0.9, 
    'num_gpus': 1, 
    'num_workers': 1, 
    'num_envs_per_worker': 1, 
    #'entropy_coeff': 0.001, 
    #'num_sgd_iter': 5, 
    #'vf_loss_coeff': 1e-05, 
    'model': {
        'custom_model': 'test_model',#'rnn', 
        #'max_seq_len': 20, 
        #'custom_model_config': {
        #    'cell_size': 32
        #   }
    }, 
    'framework': 'torch',
}
tune.run(CustomTrainer, config=config)#config={"env": 'Freeway-v0', "num_gpus":1})

