In [1]:
from ray import tune
from ray import air
from deflector_gym.wrappers import ExpandObservation, BestRecorder
from ray.tune import register_env
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.simple_q import SimpleQConfig
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchModel
from ray.rllib.algorithms.simple_q.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from torch import nn
import torch
import numpy as np
from ray.rllib.algorithms.callbacks import DefaultCallbacks
import deflector_gym
from operator import itemgetter

def init_params(net, val=np.sqrt(2), layer_types=(nn.Conv2d, nn.Linear)):
    """Orthogonally initialise the weights of selected layers and zero their biases.

    Args:
        net: module whose submodules (including itself) are visited via
            ``net.modules()``.
        val: gain passed to ``nn.init.orthogonal_``.
        layer_types: tuple of module classes to initialise. Defaults to
            ``(nn.Conv2d, nn.Linear)``; pass e.g. ``(nn.Conv1d,)`` to cover
            1-D convolutions, which are otherwise left untouched.
    """
    for module in net.modules():
        if not isinstance(module, layer_types):
            continue
        nn.init.orthogonal_(module.weight, val)
        # Layers built with bias=False (conv or linear) have bias=None.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class convrelu(nn.Module):
    """Conv1d (kernel 3, circular 'same' padding) -> BatchNorm1d -> ReLU.

    The sequence length is preserved by the 'same' padding.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        stages = [
            nn.Conv1d(nin, nout, 3, padding='same', padding_mode='circular'),
            nn.BatchNorm1d(nout),
            nn.ReLU(inplace=True),
        ]
        # The attribute keeps the name `convrelu` so state_dict keys are stable.
        self.convrelu = nn.Sequential(*stages)

    def forward(self, x):
        block = self.convrelu
        return block(x)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class ShallowUQnet(TorchModelV2, nn.Module):
    """1-D residual U-Net used as a custom RLlib torch model.

    Encoder: four stages, each a Conv1d that sets the channel count followed by
    two ``convrelu`` blocks with a residual add, then ``MaxPool1d(2)``.
    Bottleneck: one more such stage. Decoder: four stages that upsample by 2,
    concatenate the matching encoder activation on the channel axis and apply
    the same residual pattern. A final Conv1d maps to a single channel, which
    is flattened into the model output.

    Expects ``input_dict['obs']`` shaped ``(batch, 1, length)`` (the first conv
    has one input channel). There are four halvings and four doublings, so
    ``length`` must be divisible by 16 for the skip concatenations to match.
    The output width equals ``length``. ``num_outputs`` is forwarded to
    ``TorchModelV2`` but not used to size any layer.
    NOTE(review): any head built from ``num_outputs`` presumably relies on it
    equalling the observation length — confirm.
    """

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
        ):
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
        )
        nn.Module.__init__(self)

        # Bookkeeping attributes; not read anywhere inside this class.
        self.effdata = []
        self.score_sum = []
        self.score_init_final = []
        self.ncells = 256

        # Encoder stage 1: 1 -> 16 channels.
        self.conv1_1 = nn.Conv1d(1, 16, 3, padding='same', bias=True, padding_mode='circular')
        self.conv1_2 = convrelu(16, 16)
        self.conv1_3 = convrelu(16, 16)
        self.pool_1 = nn.MaxPool1d(2)

        # Encoder stage 2: 16 -> 32 channels.
        self.conv2_1 = nn.Conv1d(16, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv2_2 = convrelu(32, 32)
        self.conv2_3 = convrelu(32, 32)
        self.pool_2 = nn.MaxPool1d(2)

        # Encoder stage 3: 32 -> 32 channels.
        self.conv3_1 = nn.Conv1d(32, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv3_2 = convrelu(32, 32)
        self.conv3_3 = convrelu(32, 32)
        self.pool_3 = nn.MaxPool1d(2)

        # Encoder stage 4: 32 -> 32 channels.
        self.conv4_1 = nn.Conv1d(32, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv4_2 = convrelu(32, 32)
        self.conv4_3 = convrelu(32, 32)
        self.pool_4 = nn.MaxPool1d(2)

        # Bottleneck.
        self.conv6_1 = nn.Conv1d(32, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv6_2 = convrelu(32, 32)
        self.conv6_3 = convrelu(32, 32)
        self.upsam_6 = nn.Upsample(scale_factor=2)

        # Decoder stages: input channels = upsampled (32) + encoder skip.
        self.conv8_1 = nn.Conv1d(32 + 32, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv8_2 = convrelu(32, 32)
        self.conv8_3 = convrelu(32, 32)
        self.upsam_8 = nn.Upsample(scale_factor=2)

        self.conv9_1 = nn.Conv1d(32 + 32, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv9_2 = convrelu(32, 32)
        self.conv9_3 = convrelu(32, 32)
        self.upsam_9 = nn.Upsample(scale_factor=2)

        self.conv10_1 = nn.Conv1d(32 + 32, 32, 3, padding='same', bias=True, padding_mode='circular')
        self.conv10_2 = convrelu(32, 32)
        self.conv10_3 = convrelu(32, 32)
        self.upsam_10 = nn.Upsample(scale_factor=2)

        self.conv11_1 = nn.Conv1d(16 + 32, 16, 3, padding='same', bias=True, padding_mode='circular')
        self.conv11_2 = convrelu(16, 16)
        self.conv11_3 = convrelu(16, 16)

        # Project to a single output channel.
        self.conv11_fin = nn.Conv1d(16, 1, 3, padding='same', bias=True, padding_mode='circular')

        # Run after the layers are registered; called before them it could only
        # ever visit `self`. init_params' default types (Conv2d/Linear) match
        # none of the Conv1d/BatchNorm1d layers above, so this does not change
        # their default PyTorch initialisation.
        init_params(self)

    def _residual_stage(self, x, conv_in, conv_mid, conv_out):
        """Apply one U-Net stage: ``conv_in``, then two blocks with a residual add.

        Returns ``conv_out(conv_mid(r)) + r`` where ``r = conv_in(x)``
        (``conv_in`` is evaluated once).
        """
        res = conv_in(x)
        return conv_out(conv_mid(res)) + res

    def forward(self, input_dict, state, seq_lens):
        """Run the U-Net on ``input_dict['obs']``.

        Args:
            input_dict: RLlib input dict; only ``'obs'`` is read.
            state: recurrent state; unused.
            seq_lens: sequence lengths; unused.

        Returns:
            ``(out, [])``, where ``out`` is the single-channel final map
            flattened to ``(batch, length)`` and ``[]`` is the (empty) state.
        """
        img = input_dict['obs']

        # Encoder: keep each stage's pre-pool activation for the decoder skips.
        shortcut1 = self._residual_stage(img, self.conv1_1, self.conv1_2, self.conv1_3)
        shortcut2 = self._residual_stage(self.pool_1(shortcut1), self.conv2_1, self.conv2_2, self.conv2_3)
        shortcut3 = self._residual_stage(self.pool_2(shortcut2), self.conv3_1, self.conv3_2, self.conv3_3)
        shortcut4 = self._residual_stage(self.pool_3(shortcut3), self.conv4_1, self.conv4_2, self.conv4_3)

        # Bottleneck at 1/16 of the input length.
        temp = self._residual_stage(self.pool_4(shortcut4), self.conv6_1, self.conv6_2, self.conv6_3)

        # Decoder: upsample x2, concatenate the matching skip on the channel
        # axis, then refine with a residual stage.
        temp = torch.cat([self.upsam_6(temp), shortcut4], dim=1)
        temp = self._residual_stage(temp, self.conv8_1, self.conv8_2, self.conv8_3)
        temp = torch.cat([self.upsam_8(temp), shortcut3], dim=1)
        temp = self._residual_stage(temp, self.conv9_1, self.conv9_2, self.conv9_3)
        temp = torch.cat([self.upsam_9(temp), shortcut2], dim=1)
        temp = self._residual_stage(temp, self.conv10_1, self.conv10_2, self.conv10_3)
        temp = torch.cat([self.upsam_10(temp), shortcut1], dim=1)
        temp = self._residual_stage(temp, self.conv11_1, self.conv11_2, self.conv11_3)

        out = self.conv11_fin(temp).flatten(1)
        return out, []

In [None]:
# Scratch inspection of gym.spaces.MultiDiscrete. `gym` is imported here because
# this cell runs before the cell that imports it, which would otherwise raise
# NameError on a fresh kernel.
import gym

gym.spaces.MultiDiscrete

In [6]:
import gym

In [None]:
class OneHot(gym.ObservationWrapper):
    """Observation wrapper that replaces -1 entries of the observation with 0.

    Despite the name, this is not a one-hot encoding. Values equal to -1 become
    0 and all other values pass through unchanged, so a {-1, 1} observation
    would become {0, 1}. The name is kept because ``make_env`` refers to it.

    Args:
        env: the environment to wrap.
    """

    def __init__(self, env) -> None:
        # ObservationWrapper requires the wrapped env; make_env calls OneHot(e).
        super().__init__(env)
        # The wrapper only substitutes values, so the shape follows the wrapped
        # env; only the declared range/dtype change.
        self.observation_space = gym.spaces.Box(
            low=0., high=1.,
            shape=env.observation_space.shape,
            dtype=np.float64,
        )

    def observation(self, obs):
        """Return a copy of ``obs`` with -1 mapped to 0, cast to the space dtype.

        A new array is returned so the wrapped env's own observation array is
        never modified in place.
        """
        arr = np.asarray(obs)
        return np.where(arr == -1, 0, arr).astype(self.observation_space.dtype)

In [3]:

class Callbacks(DefaultCallbacks):
    """RLlib callbacks that report the best efficiency seen by a worker's envs."""

    def on_episode_end(
            self,
            *,
            worker,
            base_env,
            policies,
            episode,
            **kwargs,
    ) -> None:
        """Store ``best_efficiency`` in ``episode.custom_metrics``.

        Reads the ``best`` attribute of every sub-environment and picks the
        record with the largest element 0. This covers all sub-environments
        on this worker, not only the one whose episode ended. That record's
        element 0 is logged.
        NOTE(review): ``best`` is presumably an (efficiency, structure) pair
        maintained by the BestRecorder wrapper — confirm.
        """
        records = [sub_env.best for sub_env in base_env.get_sub_environments()]
        top_record = max(records, key=lambda record: record[0])
        episode.custom_metrics['best_efficiency'] = top_record[0]

In [5]:
# Tune stopping criterion: end the trial once `timesteps_total` reaches 200k.
# (Alternatives such as `training_iteration` or `episode_reward_mean` can be
# added as extra keys.)
stop = dict(timesteps_total=200_000)
def make_env(config):
    """Build the wrapped MeentIndex-v0 environment used for training.

    Args:
        config: env config passed in by the registered env creator; unused.

    Returns:
        The base env wrapped, innermost first, by BestRecorder, OneHot,
        ExpandObservation and NormalizeReward.
    """
    # NormalizeReward was used but never imported, so this function raised
    # NameError on a fresh kernel. Imported locally so the creator is
    # self-contained when shipped to Ray workers.
    # NOTE(review): assumes gym's running-return normaliser is the intended
    # wrapper — confirm against the original experiment.
    from gym.wrappers import NormalizeReward

    e = deflector_gym.make('MeentIndex-v0')
    e = BestRecorder(e)
    e = OneHot(e)
    e = ExpandObservation(e)
    e = NormalizeReward(e)
    return e

import os

# Output location and run name. Override the directory with PIRL_RESULTS_DIR
# instead of editing a machine-specific absolute path.
RESULTS_DIR = os.environ.get('PIRL_RESULTS_DIR', '/mnt/8tb/anthony/pirl')
EXPERIMENT_NAME = 'debug-onehot'

# Expose the wrapped env and the custom U-Net to RLlib under string names.
register_env('MeentIndex-v0', make_env)
ModelCatalog.register_custom_model(ShallowUQnet.__name__, ShallowUQnet)

config = (
    DQNConfig()
    .rollouts(horizon=1024)  # cap episodes at 1024 steps
    .framework(framework='torch')
    .environment(env='MeentIndex-v0')
    .resources(num_gpus=1)
    .training(model={'custom_model': ShallowUQnet.__name__})
    .callbacks(Callbacks)
)

tuner = tune.Tuner(
    'DQN',
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop=stop,
        local_dir=RESULTS_DIR,
        name=EXPERIMENT_NAME,
    ),
)
results = tuner.fit()

0,1
Current time:,2022-11-17 21:13:37
Running for:,03:21:17.75
Memory:,62.4/251.5 GiB

Trial name,status,loc,iter,total time (s),ts,reward,num_recreated_workers,episode_reward_max,episode_reward_min
DQN_MeentIndex-v0_24e60_00000,TERMINATED,143.248.153.115:2808284,200,12049.9,200000,0.0339682,0,0.171579,-0.0333579


[2m[36m(DQN pid=2808284)[0m 2022-11-17 17:52:22,447	INFO simple_q.py:307 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting `simple_optimizer=True` if this doesn't work for you.
[2m[36m(DQN pid=2808284)[0m 2022-11-17 17:52:22,449	INFO algorithm.py:457 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_recreated_workers,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
DQN_MeentIndex-v0_24e60_00000,200000,"{'num_env_steps_sampled': 200000, 'num_env_steps_trained': 1592000, 'num_agent_steps_sampled': 200000, 'num_agent_steps_trained': 1592000, 'last_target_update_ts': 199504, 'num_target_updates': 398}","{'best_efficiency_mean': 0.15977086460438988, 'best_efficiency_min': 0.053865344287038174, 'best_efficiency_max': 0.34347248518865314}",2022-11-17_21-13-37,True,1024,{},0.171579,0.0339682,-0.0333579,1,195,9422b69edde14b65bba1eb68cf17018f,rayleigh-B7105F48TV4HR-2T-G,"{'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'mean_q': 15220645888.0, 'min_q': 601132160.0, 'max_q': 43182227456.0, 'cur_lr': 0.0005}, 'model': {}, 'td_error': array([-6.8721050e+08, -2.4957920e+08, -6.8721050e+08, -1.5144269e+08,  -4.2622182e+08, 2.3760384e+08, -3.9450112e+08, -5.5739187e+08,  1.9858637e+08, -8.3678413e+08, -7.0995251e+08, -1.3656602e+08,  -4.2924339e+08, 3.8496768e+07, 8.4837581e+08, 1.4258637e+08,  2.3984640e+07, 6.4837120e+07, -4.5959014e+08, 3.9655731e+08,  -4.5422019e+09, -2.4957920e+08, 7.5013427e+08, -5.8420224e+08,  -6.7792896e+08, -1.1958682e+08, 1.6794726e+08, 1.8855117e+08,  -9.8466816e+08, 4.1121587e+08, 5.6833434e+08, 1.6306790e+09],  dtype=float32), 'mean_td_error': -225499088.0}}, 'num_env_steps_sampled': 200000, 'num_env_steps_trained': 1592000, 'num_agent_steps_sampled': 200000, 'num_agent_steps_trained': 1592000, 'last_target_update_ts': 199504, 'num_target_updates': 398}",200,143.248.153.115,200000,1592000,200000,1000,1592000,8000,0,0,0,8000,"{'cpu_util_percent': 60.02934782608696, 'ram_util_percent': 24.693478260869572, 'gpu_util_percent0': 0.1848913043478261, 'vram_util_percent0': 0.7355468749999998, 'gpu_util_percent1': 0.031521739130434774, 'vram_util_percent1': 0.17841796875000004, 'gpu_util_percent2': 0.00010869565217391305, 'vram_util_percent2': 0.00048828125, 'gpu_util_percent3': 0.0002173913043478261, 'vram_util_percent3': 0.00048828125}",2808284,{},{},{},"{'mean_raw_obs_processing_ms': 
0.47938471272842087, 'mean_inference_ms': 9.066336269441424, 'mean_action_processing_ms': 0.07009383916206997, 'mean_env_wait_ms': 20.188486318527275, 'mean_env_render_ms': 0.0}","{'episode_reward_max': 0.17157877919177966, 'episode_reward_min': -0.03335788297334374, 'episode_reward_mean': 0.033968192783428316, 'episode_len_mean': 1024.0, 'episode_media': {}, 'episodes_this_iter': 1, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {'best_efficiency_mean': 0.15977086460438988, 'best_efficiency_min': 0.053865344287038174, 'best_efficiency_max': 0.34347248518865314}, 'hist_stats': {'episode_reward': [0.009409476661925396, -0.016949291277859296, 0.08036058838381828, 0.0300383704817059, 0.04590879882878562, 0.0569851250964622, 0.01797723453908672, 0.019369686693044966, 0.00642935801993726, 0.07039444943227376, 0.03461076185254807, 0.1077086484081064, 0.022732348504180308, 0.00926802498951599, 0.13266162228969744, -0.03335788297334374, 0.00231473719494326, 0.020058340609632228, 0.01917600119725949, 0.07614703933064804, 0.046638797802705134, -0.002933174534071454, 0.058398265357221174, 0.08404340908652266, 0.008218986580348496, 0.08388369428842916, 0.02643597093361139, -0.00038658347731100346, 0.05136851699444345, 0.17157877919177966, 0.014656120388461202, 0.08836484013967466, 0.047250257676878786, -0.0057960965771535556, 0.007429246653615404, 0.011317919969374875, 0.012162249979924628, 0.02373330900437821, -0.0029389714047068365, 0.05330020542370838, 0.005640154034666344, 0.11593519695766893, 0.08013000637496437, 0.06637419057684901, -0.02383774303459022, 0.03348495180522261, -0.031463688770816886, 0.06576570403263468, 3.539973088098874e-05, 0.035370028677920384, 0.0718978078408373, 0.011860244711461833, -0.0002663473898134594, 0.011881383649662488, -0.02306618835065075, 0.008629080356941807, 0.004247198170419467, -0.008124477940328115, -0.002426806822832162, 0.0068751464260916086, 0.03564666964824132, 0.03553417598985162, 
0.013196688740918089, 0.06696193281610457, 0.06214510734879784, 0.09879451516090258, 0.05773088914461408, 0.015571799032926235, 0.009097372644176491, 0.030231074089595224, 0.029404461964539608, 0.002227846384746244, 0.06430179258311632, 0.1053225713666956, 0.08351545416422784, -0.017436548227893464, 0.013759711443219982, -0.009610052184127478, 3.4699527112302606e-05, 0.039728355696849765, 0.02297069803821069, -0.00604548169096656, 0.002272149769291777, 0.005391850002465006, 0.023546720470885598, 0.014806147863874106, -0.025802469456619688, 0.028228762569675638, 0.006749439155327547, 0.07963682971806689, 0.011986558481403045, 0.10200344412880502, 0.02227611580669326, 0.1112980481014064, 0.05637665946701863, 0.06195065279079312, 0.00897557591562596, 0.0127651947261011, 0.08626993274891717, 0.11809350962384943], 'episode_lengths': [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.47938471272842087, 'mean_inference_ms': 9.066336269441424, 'mean_action_processing_ms': 0.07009383916206997, 'mean_env_wait_ms': 20.188486318527275, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",12049.9,92.8394,12049.9,"{'training_iteration_time_ms': 300.591, 'load_time_ms': 1.896, 'load_throughput': 16874.879, 'learn_time_ms': 114.384, 'learn_throughput': 279.76, 'synch_weights_time_ms': 0.07}",1668687217,0,200000,200,24e60_00000,4.5695


2022-11-17 21:13:38,377	INFO tune.py:777 -- Total run time: 12078.81 seconds (12077.68 seconds for the tuning loop).


In [9]:
# Pass metric/mode explicitly. No TuneConfig default was set, so without them
# the ResultGrid cannot rank results once more than one trial exists.
results.get_best_result(metric='episode_reward_mean', mode='max')

Result(metrics={'custom_metrics': {'best_efficiency_mean': 0.15977086460438988, 'best_efficiency_min': 0.053865344287038174, 'best_efficiency_max': 0.34347248518865314}, 'episode_media': {}, 'num_recreated_workers': 0, 'info': {'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'mean_q': 15220645888.0, 'min_q': 601132160.0, 'max_q': 43182227456.0, 'cur_lr': 0.0005}, 'model': {}, 'td_error': array([-6.8721050e+08, -2.4957920e+08, -6.8721050e+08, -1.5144269e+08,
       -4.2622182e+08,  2.3760384e+08, -3.9450112e+08, -5.5739187e+08,
        1.9858637e+08, -8.3678413e+08, -7.0995251e+08, -1.3656602e+08,
       -4.2924339e+08,  3.8496768e+07,  8.4837581e+08,  1.4258637e+08,
        2.3984640e+07,  6.4837120e+07, -4.5959014e+08,  3.9655731e+08,
       -4.5422019e+09, -2.4957920e+08,  7.5013427e+08, -5.8420224e+08,
       -6.7792896e+08, -1.1958682e+08,  1.6794726e+08,  1.8855117e+08,
       -9.8466816e+08,  4.1121587e+08,  5.6833434e+08,  1.6306790e+09],
      dtype=float