In [1]:
import gym
import numpy as np
import renom as rm
from renom.cuda import set_cuda_active
from renom_rl.dqn import DQN
from renom_rl.env import BaseEnv
from gym.core import Env
from PIL import Image

# Enable CUDA for ReNom before any network is built.
set_cuda_active(True)

# Gym id of the Atari environment used throughout this notebook.
# NOTE(review): the "NoFrameskip" variant presumably performs no frame skipping
# of its own (CustomEnv.step repeats each action itself) -- confirm against the gym docs.
ENV_ID = 'BreakoutNoFrameskip-v4'
env = gym.make(ENV_ID)

class CustomEnv(BaseEnv):
    """Gym environment wrapper that feeds stacked, preprocessed frames to the agent.

    Every call to ``step`` repeats the chosen action for up to four emulator
    frames, converts each frame to an 84x84 grayscale image scaled to [-1, 1]
    and stacks the four images into one ``(4, 84, 84)`` state. The reward is
    reduced to 0 or 1.
    """

    def __init__(self, env):
        """Store the wrapped environment and declare the action/state shapes.

        Args:
            env: Environment object providing ``reset``, ``step``, ``render``
                and ``action_space.sample`` (as used by the methods below).
        """
        self.env = env
        self.action_shape = 4
        self.state_shape = (4, 84, 84)
        super(CustomEnv, self).__init__()

    def reset(self):
        """Reset the wrapped environment and take 32 random warm-up steps.

        Returns:
            The stacked state produced by the last warm-up step.
        """
        self.env.reset()
        for _warmup in range(32):
            state, _reward, _terminal = self.step(self.env.action_space.sample())
        return state

    def sample(self):
        """Return a uniformly random action index in ``[0, action_shape)``."""
        return int(np.random.rand() * self.action_shape)

    def render(self):
        """Render the wrapped environment."""
        self.env.render()

    def _preprocess(self, state):
        """Convert one raw frame into an 84x84 array with values in [-1, 1].

        The frame is resized to 84 wide by 110 high, converted to 8-bit
        grayscale, rescaled from [0, 255] to [-1, 1], and rows 26..109 are kept.

        Args:
            state: Raw frame accepted by ``PIL.Image.fromarray``
                (presumably an HxWx3 uint8 RGB array -- TODO confirm).

        Returns:
            ``numpy.ndarray`` of shape ``(84, 84)`` and dtype float64.
        """
        resized_image = Image.fromarray(state).resize((84, 110)).convert('L')
        # Converting the image directly yields the (110, 84) pixel grid without
        # materialising a Python sequence of pixels and reshaping it.
        image_array = np.asarray(resized_image, dtype=np.float64) / 255. * 2 - 1
        return image_array[26:110, :]

    def step(self, action):
        """Repeat ``action`` for up to four frames and return the stacked result.

        Expects ``self.env.step`` to return a 4-tuple
        ``(observation, reward, terminal, info)``.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            Tuple ``(state, reward, terminal)`` where ``state`` has shape
            ``(4, 84, 84)``, ``reward`` is 1 if the summed reward over the
            repeated frames is positive and 0 otherwise, and ``terminal`` is
            True if any of the frames ended the episode.
        """
        n_frames = 4
        state_list = []
        reward_list = []
        terminal = False
        for _frame in range(n_frames):
            s, r, t, _info = self.env.step(action)
            state_list.append(self._preprocess(s))
            reward_list.append(r)
            if t:
                # Do not keep stepping an environment whose episode has ended.
                terminal = True
                break
        # Pad with the last observed frame so the state keeps its fixed shape
        # even when the episode ended before all frames were collected.
        while len(state_list) < n_frames:
            state_list.append(state_list[-1])
        state = np.stack(state_list)
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the same dtype.
        return state, (np.sum(reward_list) > 0).astype(int), terminal

# Wrap the gym environment so the agent sees stacked, preprocessed frames.
custom_env = CustomEnv(env)

# Q-network: three convolution + ReLU stages, then a 512-unit hidden layer and
# one output unit per action.
q_layers = [
    rm.Conv2d(32, filter=8, stride=4),
    rm.Relu(),
    rm.Conv2d(64, filter=4, stride=2),
    rm.Relu(),
    rm.Conv2d(64, filter=3, stride=1),
    rm.Relu(),
    rm.Flatten(),
    rm.Dense(512),
    rm.Relu(),
    rm.Dense(custom_env.action_shape),
]
q_network = rm.Sequential(q_layers)

In [2]:
# Build the DQN agent from the wrapped environment and the Q-network.
model = DQN(custom_env, q_network)

In [3]:
# Train the agent with rendering switched on.
# NOTE(review): greedy_step presumably controls how many steps the e-greedy value
# takes to anneal (the log below shows it rising from 0.001 to 0.900) -- confirm
# against the ReNom-RL DQN.train documentation.
model.train(render=True, greedy_step=100000)

Run random 5000 step for storing experiences


episode 001 avg_loss: 0.055 total_reward [train:0.000 test:-] e-greedy:0.001: : 129it [00:03, 38.48it/s]
episode 002 avg_loss: 0.030 total_reward [train:0.000 test:-] e-greedy:0.002: : 123it [00:03, 36.89it/s]
episode 003 avg_loss: 0.025 total_reward [train:1.000 test:-] e-greedy:0.003: : 120it [00:03, 38.83it/s]
episode 004 avg_loss: 0.022 total_reward [train:1.000 test:-] e-greedy:0.005: : 148it [00:03, 38.86it/s]
episode 005 avg_loss: 0.020 total_reward [train:2.000 test:-] e-greedy:0.006: : 182it [00:04, 37.00it/s]
episode 006 avg_loss: 0.018 total_reward [train:1.000 test:-] e-greedy:0.008: : 153it [00:04, 36.93it/s]
episode 007 avg_loss: 0.020 total_reward [train:4.000 test:-] e-greedy:0.010: : 236it [00:06, 38.63it/s]
episode 008 avg_loss: 0.017 total_reward [train:2.000 test:-] e-greedy:0.011: : 175it [00:04, 37.17it/s]
episode 009 avg_loss: 0.016 total_reward [train:1.000 test:-] e-greedy:0.013: : 133it [00:03, 36.57it/s]
episode 010 avg_loss: 0.021 total_reward [train:2.000 t

episode 079 avg_loss: 0.010 total_reward [train:2.000 test:-] e-greedy:0.103: : 187it [00:05, 37.15it/s]
episode 080 avg_loss: 0.010 total_reward [train:0.000 test:-] e-greedy:0.104: : 110it [00:03, 36.20it/s]
episode 081 avg_loss: 0.008 total_reward [train:0.000 test:-] e-greedy:0.105: : 108it [00:02, 38.90it/s]
episode 082 avg_loss: 0.009 total_reward [train:1.000 test:-] e-greedy:0.106: : 159it [00:04, 36.64it/s]
episode 083 avg_loss: 0.008 total_reward [train:2.000 test:-] e-greedy:0.107: : 159it [00:04, 36.89it/s]
episode 084 avg_loss: 0.008 total_reward [train:1.000 test:-] e-greedy:0.109: : 136it [00:04, 33.63it/s]
episode 085 avg_loss: 0.013 total_reward [train:3.000 test:-] e-greedy:0.111: : 223it [00:06, 31.87it/s]
episode 086 avg_loss: 0.010 total_reward [train:0.000 test:-] e-greedy:0.112: : 101it [00:03, 30.79it/s]
episode 087 avg_loss: 0.012 total_reward [train:1.000 test:-] e-greedy:0.113: : 141it [00:04, 31.28it/s]
episode 088 avg_loss: 0.012 total_reward [train:2.000 t

episode 157 avg_loss: 0.008 total_reward [train:1.000 test:-] e-greedy:0.203: : 130it [00:03, 36.67it/s]
episode 158 avg_loss: 0.009 total_reward [train:0.000 test:-] e-greedy:0.204: : 94it [00:02, 35.77it/s]
episode 159 avg_loss: 0.009 total_reward [train:2.000 test:-] e-greedy:0.205: : 171it [00:04, 37.33it/s]
episode 160 avg_loss: 0.010 total_reward [train:2.000 test:-] e-greedy:0.207: : 173it [00:04, 37.19it/s]
episode 161 avg_loss: 0.009 total_reward [train:2.000 test:-] e-greedy:0.208: : 169it [00:04, 35.33it/s]
episode 162 avg_loss: 0.007 total_reward [train:2.000 test:-] e-greedy:0.210: : 173it [00:05, 31.80it/s]
episode 163 avg_loss: 0.008 total_reward [train:4.000 test:-] e-greedy:0.212: : 231it [00:07, 32.06it/s]
episode 164 avg_loss: 0.007 total_reward [train:0.000 test:-] e-greedy:0.213: : 92it [00:02, 32.76it/s]
episode 165 avg_loss: 0.009 total_reward [train:0.000 test:-] e-greedy:0.214: : 95it [00:03, 30.94it/s]
episode 166 avg_loss: 0.008 total_reward [train:1.000 test

episode 235 avg_loss: 0.007 total_reward [train:1.000 test:-] e-greedy:0.309: : 125it [00:03, 34.45it/s]
episode 236 avg_loss: 0.008 total_reward [train:3.000 test:-] e-greedy:0.311: : 220it [00:05, 38.31it/s]
episode 237 avg_loss: 0.008 total_reward [train:1.000 test:-] e-greedy:0.312: : 152it [00:04, 39.04it/s]
episode 238 avg_loss: 0.007 total_reward [train:0.000 test:-] e-greedy:0.313: : 102it [00:02, 35.33it/s]
episode 239 avg_loss: 0.007 total_reward [train:1.000 test:-] e-greedy:0.314: : 147it [00:04, 35.90it/s]
episode 240 avg_loss: 0.005 total_reward [train:3.000 test:-] e-greedy:0.316: : 217it [00:05, 36.57it/s]
episode 241 avg_loss: 0.006 total_reward [train:1.000 test:-] e-greedy:0.317: : 119it [00:03, 35.77it/s]
episode 242 avg_loss: 0.008 total_reward [train:2.000 test:-] e-greedy:0.319: : 170it [00:04, 36.50it/s]
episode 243 avg_loss: 0.008 total_reward [train:1.000 test:-] e-greedy:0.320: : 140it [00:03, 37.06it/s]
episode 244 avg_loss: 0.011 total_reward [train:0.000 t

episode 313 avg_loss: 0.008 total_reward [train:2.000 test:-] e-greedy:0.411: : 148it [00:04, 38.71it/s]
episode 314 avg_loss: 0.008 total_reward [train:1.000 test:-] e-greedy:0.412: : 145it [00:04, 36.15it/s]
episode 315 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.413: : 126it [00:03, 36.08it/s]
episode 316 avg_loss: 0.007 total_reward [train:1.000 test:-] e-greedy:0.414: : 146it [00:04, 36.28it/s]
episode 317 avg_loss: 0.007 total_reward [train:0.000 test:-] e-greedy:0.415: : 98it [00:02, 35.48it/s]
episode 318 avg_loss: 0.006 total_reward [train:1.000 test:-] e-greedy:0.417: : 140it [00:03, 37.72it/s]
episode 319 avg_loss: 0.006 total_reward [train:1.000 test:-] e-greedy:0.418: : 122it [00:03, 35.77it/s]
episode 320 avg_loss: 0.008 total_reward [train:2.000 test:-] e-greedy:0.419: : 165it [00:04, 35.64it/s]
episode 321 avg_loss: 0.006 total_reward [train:2.000 test:-] e-greedy:0.421: : 174it [00:04, 36.10it/s]
episode 322 avg_loss: 0.009 total_reward [train:0.000 te

episode 391 avg_loss: 0.009 total_reward [train:1.000 test:-] e-greedy:0.507: : 126it [00:03, 35.08it/s]
episode 392 avg_loss: 0.006 total_reward [train:3.000 test:-] e-greedy:0.509: : 218it [00:05, 36.53it/s]
episode 393 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.510: : 99it [00:02, 35.33it/s]
episode 394 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.511: : 93it [00:02, 35.15it/s]
episode 395 avg_loss: 0.006 total_reward [train:4.000 test:-] e-greedy:0.513: : 227it [00:06, 36.61it/s]
episode 396 avg_loss: 0.006 total_reward [train:2.000 test:-] e-greedy:0.514: : 149it [00:04, 36.06it/s]
episode 397 avg_loss: 0.007 total_reward [train:0.000 test:-] e-greedy:0.515: : 92it [00:02, 38.01it/s]
episode 398 avg_loss: 0.007 total_reward [train:0.000 test:-] e-greedy:0.516: : 96it [00:02, 37.30it/s]
episode 399 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.517: : 100it [00:02, 37.91it/s]
episode 400 avg_loss: 0.007 total_reward [train:2.000 test:

episode 469 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.603: : 97it [00:02, 34.68it/s]
episode 470 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.604: : 93it [00:02, 34.74it/s]
episode 471 avg_loss: 0.007 total_reward [train:0.000 test:-] e-greedy:0.605: : 95it [00:02, 34.96it/s]
episode 472 avg_loss: 0.005 total_reward [train:2.000 test:-] e-greedy:0.607: : 192it [00:05, 37.41it/s]
episode 473 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.608: : 90it [00:02, 35.02it/s]
episode 474 avg_loss: 0.006 total_reward [train:0.000 test:-] e-greedy:0.608: : 94it [00:02, 34.69it/s]
episode 475 avg_loss: 0.005 total_reward [train:2.000 test:-] e-greedy:0.610: : 173it [00:04, 35.86it/s]
episode 476 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.611: : 90it [00:02, 34.78it/s]
episode 477 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.612: : 139it [00:03, 35.19it/s]
episode 478 avg_loss: 0.008 total_reward [train:1.000 test:-]

episode 547 avg_loss: 0.005 total_reward [train:2.000 test:-] e-greedy:0.700: : 175it [00:04, 35.67it/s]
episode 548 avg_loss: 0.005 total_reward [train:1.000 test:-] e-greedy:0.701: : 139it [00:03, 35.54it/s]
episode 549 avg_loss: 0.006 total_reward [train:2.000 test:-] e-greedy:0.703: : 168it [00:04, 37.19it/s]
episode 550 avg_loss: 0.005 total_reward [train:0.000 test:2.000] e-greedy:0.703: : 92it [00:03, 36.88it/s]
episode 551 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.705: : 127it [00:03, 35.39it/s]
episode 552 avg_loss: 0.005 total_reward [train:2.000 test:-] e-greedy:0.706: : 166it [00:04, 35.47it/s]
episode 553 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.707: : 94it [00:02, 34.65it/s]
episode 554 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.708: : 91it [00:02, 34.69it/s]
episode 555 avg_loss: 0.006 total_reward [train:1.000 test:-] e-greedy:0.709: : 140it [00:03, 36.67it/s]
episode 556 avg_loss: 0.007 total_reward [train:1.000 

episode 625 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.794: : 93it [00:02, 34.46it/s]
episode 626 avg_loss: 0.005 total_reward [train:2.000 test:-] e-greedy:0.795: : 151it [00:04, 35.34it/s]
episode 627 avg_loss: 0.006 total_reward [train:0.000 test:-] e-greedy:0.796: : 93it [00:02, 34.62it/s]
episode 628 avg_loss: 0.003 total_reward [train:1.000 test:-] e-greedy:0.797: : 120it [00:03, 36.45it/s]
episode 629 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.798: : 169it [00:04, 35.51it/s]
episode 630 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.799: : 91it [00:02, 34.57it/s]
episode 631 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.800: : 98it [00:02, 34.69it/s]
episode 632 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.801: : 118it [00:03, 35.20it/s]
episode 633 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.802: : 121it [00:03, 35.10it/s]
episode 634 avg_loss: 0.005 total_reward [train:0.000 test:

episode 703 avg_loss: 0.005 total_reward [train:1.000 test:-] e-greedy:0.889: : 124it [00:03, 36.87it/s]
episode 704 avg_loss: 0.006 total_reward [train:0.000 test:-] e-greedy:0.890: : 92it [00:02, 36.59it/s]
episode 705 avg_loss: 0.006 total_reward [train:0.000 test:-] e-greedy:0.891: : 90it [00:02, 34.56it/s]
episode 706 avg_loss: 0.006 total_reward [train:1.000 test:-] e-greedy:0.892: : 137it [00:03, 35.06it/s]
episode 707 avg_loss: 0.005 total_reward [train:4.000 test:-] e-greedy:0.894: : 232it [00:06, 36.74it/s]
episode 708 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.895: : 91it [00:02, 34.56it/s]
episode 709 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.896: : 155it [00:04, 35.19it/s]
episode 710 avg_loss: 0.005 total_reward [train:2.000 test:-] e-greedy:0.898: : 167it [00:04, 35.36it/s]
episode 711 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.899: : 137it [00:03, 35.17it/s]
episode 712 avg_loss: 0.003 total_reward [train:1.000 test

episode 781 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.900: : 92it [00:02, 36.26it/s]
episode 782 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.900: : 172it [00:04, 36.59it/s]
episode 783 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.900: : 97it [00:02, 34.65it/s]
episode 784 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 94it [00:02, 34.26it/s]
episode 785 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.900: : 141it [00:04, 35.13it/s]
episode 786 avg_loss: 0.007 total_reward [train:1.000 test:-] e-greedy:0.900: : 120it [00:03, 36.61it/s]
episode 787 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.900: : 126it [00:03, 35.02it/s]
episode 788 avg_loss: 0.005 total_reward [train:0.000 test:-] e-greedy:0.900: : 93it [00:02, 34.35it/s]
episode 789 avg_loss: 0.006 total_reward [train:1.000 test:-] e-greedy:0.900: : 147it [00:04, 35.26it/s]
episode 790 avg_loss: 0.005 total_reward [train:1.000 test:

episode 859 avg_loss: 0.005 total_reward [train:1.000 test:-] e-greedy:0.900: : 120it [00:03, 37.26it/s]
episode 860 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.900: : 97it [00:02, 35.68it/s]
episode 861 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.900: : 167it [00:04, 36.43it/s]
episode 862 avg_loss: 0.004 total_reward [train:3.000 test:-] e-greedy:0.900: : 195it [00:05, 36.78it/s]
episode 863 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.900: : 166it [00:04, 36.53it/s]
episode 864 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.900: : 172it [00:04, 38.38it/s]
episode 865 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.900: : 173it [00:04, 36.46it/s]
episode 866 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.900: : 97it [00:02, 35.75it/s]
episode 867 avg_loss: 0.004 total_reward [train:3.000 test:-] e-greedy:0.900: : 201it [00:05, 36.86it/s]
episode 868 avg_loss: 0.004 total_reward [train:3.000 tes

episode 937 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 94it [00:02, 34.50it/s]
episode 938 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.900: : 124it [00:03, 37.00it/s]
episode 939 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.900: : 122it [00:03, 34.64it/s]
episode 940 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 90it [00:02, 34.48it/s]
episode 941 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 167it [00:04, 35.14it/s]
episode 942 avg_loss: 0.005 total_reward [train:3.000 test:-] e-greedy:0.900: : 216it [00:06, 36.90it/s]
episode 943 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.900: : 137it [00:03, 35.10it/s]
episode 944 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 167it [00:04, 35.28it/s]
episode 945 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 94it [00:02, 34.49it/s]
episode 946 avg_loss: 0.003 total_reward [train:1.000 test

episode 1015 avg_loss: 0.004 total_reward [train:2.000 test:-] e-greedy:0.900: : 196it [00:05, 36.75it/s]
episode 1016 avg_loss: 0.002 total_reward [train:0.000 test:-] e-greedy:0.900: : 96it [00:02, 36.02it/s]
episode 1017 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 97it [00:02, 33.83it/s]
episode 1018 avg_loss: 0.003 total_reward [train:1.000 test:-] e-greedy:0.900: : 121it [00:03, 34.87it/s]
episode 1019 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 172it [00:04, 36.07it/s]
episode 1020 avg_loss: 0.003 total_reward [train:1.000 test:-] e-greedy:0.900: : 144it [00:04, 33.11it/s]
episode 1021 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 99it [00:02, 34.22it/s]
episode 1022 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.900: : 97it [00:02, 33.25it/s]
episode 1023 avg_loss: 0.004 total_reward [train:1.000 test:-] e-greedy:0.900: : 124it [00:03, 36.04it/s]
episode 1024 avg_loss: 0.004 total_reward [train:1

episode 1092 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 168it [00:04, 36.71it/s]
episode 1093 avg_loss: 0.003 total_reward [train:0.000 test:-] e-greedy:0.900: : 94it [00:02, 34.34it/s]
episode 1094 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 152it [00:04, 36.72it/s]
episode 1095 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 167it [00:04, 35.20it/s]
episode 1096 avg_loss: 0.003 total_reward [train:1.000 test:-] e-greedy:0.900: : 122it [00:03, 34.90it/s]
episode 1097 avg_loss: 0.002 total_reward [train:1.000 test:-] e-greedy:0.900: : 144it [00:04, 36.75it/s]
episode 1098 avg_loss: 0.004 total_reward [train:0.000 test:-] e-greedy:0.900: : 94it [00:02, 34.38it/s]
episode 1099 avg_loss: 0.003 total_reward [train:2.000 test:-] e-greedy:0.900: : 156it [00:04, 36.56it/s]
episode 1100 avg_loss: 0.003 total_reward [train:2.000 test:1.000] e-greedy:0.900: : 172it [00:05, 36.20it/s]
episode 1101 avg_loss: 0.003 total_reward [t

KeyboardInterrupt: 

In [None]:
import time

# Time how long it takes to draw a random permutation of 10 elements.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# short intervals; time.time() follows the wall clock, which can be adjusted
# and may be too coarse for an operation this small.
start_t = time.perf_counter()
a = np.random.permutation(10)
print(time.perf_counter() - start_t)

# 