In [1]:
import ray 
ray.init()

{'node_ip_address': '127.0.0.1',
 'raylet_ip_address': '127.0.0.1',
 'redis_address': '127.0.0.1:6379',
 'object_store_address': '/tmp/ray/session_2022-02-10_16-27-28_126872_9064/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-02-10_16-27-28_126872_9064/sockets/raylet',
 'webui_url': None,
 'session_dir': '/tmp/ray/session_2022-02-10_16-27-28_126872_9064',
 'metrics_export_port': 62493,
 'gcs_address': '127.0.0.1:60919',
 'node_id': '04610d67fc11f8e272fcb8a81af8cc4edbaac313a911c7837a9bf54e'}

In [2]:
from ray import tune
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gym 
from typing import Sequence
from ray.rllib.utils.typing import ModelConfigDict
import torch 
import torch.nn as nn 

# custom model is processed before the fully connected layer 

class CustomModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,

        ):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.linear1 = nn.Linear(self.obs_space.shape[0], 30)
        self.linear2 = nn.Linear(30,50)
        self.linear3 = nn.Linear(50, model_config['fcnet_hiddens'][0])
            
    def forward(self, input_dict, state, seq_lens):
        x =input_dict['obs']
        x = self.linear1(x)
        x = nn.functional.relu(x)
        x = self.linear2(x)
        x = nn.functional.relu(x)   
        x = self.linear3(x)
        return x, state

In [23]:
from ray.rllib.agents.dqn import DQNTrainer

config={"env":"CartPole-v0", 
                "framework":"torch", 
                "hiddens":[128],
                "dueling":False,
                "double_q":False,
                "model":{"custom_model": CustomModel, 
                          "fcnet_hiddens":[32],
                          "fcnet_activation":"relu",
                         }
                 }
tune.run("DQN", 
         config=config, 
         stop={"training_iteration":5} , 
         checkpoint_freq=2,
         local_dir=".",
         verbose=1)

2022-02-10 16:35:23,248	INFO tune.py:636 -- Total run time: 10.67 seconds (10.10 seconds for the tuning loop).


<ray.tune.analysis.experiment_analysis.ExperimentAnalysis at 0x7f9838dba640>

In [25]:
trainer = DQNTrainer(config=config)
trainer.restore("DQN/DQN_CartPole-v0_fb565_00000_0_2022-02-10_16-35-12/checkpoint_000004/checkpoint-4")

2022-02-10 16:35:29,198	INFO trainable.py:472 -- Restored on 127.0.0.1 from checkpoint: DQN/DQN_CartPole-v0_fb565_00000_0_2022-02-10_16-35-12/checkpoint_000004/checkpoint-4
2022-02-10 16:35:29,200	INFO trainable.py:480 -- Current state after restoring: {'_iteration': 4, '_timesteps_total': 128, '_time_total': 6.641242742538452, '_episodes_total': 196}


In [26]:
policy = trainer.get_policy()
policy.model

CustomModel_as_DQNTorchModel(
  (linear1): Linear(in_features=4, out_features=30, bias=True)
  (linear2): Linear(in_features=30, out_features=50, bias=True)
  (linear3): Linear(in_features=50, out_features=32, bias=True)
  (advantage_module): Sequential(
    (dueling_A_0): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=32, out_features=128, bias=True)
        (1): ReLU()
      )
    )
    (A): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=128, out_features=2, bias=True)
      )
    )
  )
)