In [None]:
#@title Run to import libraries
!pip install Box2D > /dev/null 2>&1
!pip install gym[box2d] pyvirtualdisplay > /dev/null 2>&1
!pip install plotly --upgrade

In [None]:
!pip install gym[atari]

In [None]:
!pip install tensorboard

In [None]:
!pip install pandas

In [None]:
!pip install pyvirtualdisplay

In [None]:
!pip install torch

In [None]:
!pip install gym[accept-rom-license]

In [None]:
!pip install kornia

In [1]:
%load_ext tensorboard

import math
import gym 
import plotly.express as px
import numpy as np
import warnings
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import random
from collections import Counter, deque
import pandas as pd
from sklearn.model_selection import train_test_split
import datetime
from sklearn.metrics import confusion_matrix
import plotly.figure_factory as ff

import os
import glob
from IPython import display as ipythondisplay
from tqdm.notebook import tqdm
from gym.wrappers import Monitor
from IPython.display import HTML
import base64
import io
import pickle
import torch
from torch import nn
import kornia.augmentation as aug
warnings.filterwarnings("ignore")

In [2]:
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
# display.start()

In [3]:
import os, sys, copy, argparse, shutil
def parse_arguments():
    parser = argparse.ArgumentParser(description='Deep Q Network Argument Parser')
    parser.add_argument('--seed', dest='seed', type=int, default=1)
    parser.add_argument('--env', dest='env', type=str, default='CartPole-v0')
    parser.add_argument('--save_interval', type=int, default=50, help='save model every n episodes')
    parser.add_argument('--log_interval', type=int, default=10, help='logging every n episodes')
    parser.add_argument('--render', help='render', type=int, default=1)
    parser.add_argument('--batch_size', help='batch_size', type=int, default=32)
    parser.add_argument('--train_freq', help='train_frequency', type=int, default=1)
    parser.add_argument('--max_episode', help='maximum episode', type=int, default=None)
    parser.add_argument('--max_timesteps', help='maximum timestep', type=int, default=100000000)
    parser.add_argument('--lr', dest='lr', type=float, default=0.00025)
    parser.add_argument('--lr_decay', action='store_true', help='decay learning rate')
    parser.add_argument('--gamma', help='discount_factor', type=float, default=0.99)
    parser.add_argument('--warmup_mem', type=int, help='warmup memory size', default=1000)
    parser.add_argument('--frame_skip', type=int, help='number of frames to skip for each action', default=3)
    parser.add_argument('--frame_stack', type=int, help='number of frames to stack', default=4)
    parser.add_argument('--memory', help='memory size', type=int, default=1000000)
    parser.add_argument('--initial_epsilon', '-ie', help='initial_epsilon', type=float, default=0.5)
    parser.add_argument('--final_epsilon', '-fe', help='final_epsilon', type=float, default=0.05)
    parser.add_argument('--max_epsilon_decay_steps', '-eds', help='maximum steps to decay epsilon', type=int, default=100000)
    parser.add_argument('--max_grad_norm', type=float, default=None, help='maximum gradient norm')
    parser.add_argument('--soft_update', '-su', action='store_true', help='soft update target network')
    parser.add_argument('--double_q', '-dq', action='store_true', help='enabling double DQN')
    parser.add_argument('--dueling_net', '-dn', action='store_true', help='enabling dueling network')
    parser.add_argument('--test', action='store_true', help='test the trained model')
    parser.add_argument('--tau', type=float, default=0.01, help='tau for soft target network update')
    parser.add_argument('--hard_update_freq', '-huf', type=int, default=2000, help='hard target network update frequency')
    parser.add_argument('--save_dir', type=str, default='./data')
    parser.add_argument('--resume_step', '-rs', type=int, default=None)
    return parser.parse_args()

In [4]:
#@title Set up constants for env and training
test = False 
save_dir = './data'
render = False
max_episode = None
max_timesteps = 100000000


In [5]:
#@title Augmentations

color_jitter = aug.ColorJitter(
        brightness=np.random.random(),
        contrast=np.random.random(),
        saturation=np.random.random(),
        hue=np.random.random(),
        p=0.5
        )
random_elastic_transform = aug.RandomElasticTransform()
random_fisheye = aug.RandomFisheye(
        center_x=torch.tensor([-.3, .3]),
        center_y=torch.tensor([-.3, .3]),
        gamma=torch.tensor([.9, 1.]),
        )
# need to divide by 255.0
random_color_equalize = aug.RandomEqualize()
random_gaussian_blur = aug.RandomGaussianBlur(
        kernel_size=(9, 9),
        sigma = (5., 5.)
        )
random_gaussian_noise = aug.RandomGaussianNoise()
random_horizontal_flip = aug.RandomHorizontalFlip()
random_color_invert = aug.RandomInvert()
random_perspective_shift = aug.RandomPerspective()
random_shift = nn.Sequential(aug.RandomCrop((190, 140)), nn.ReplicationPad2d(20), aug.RandomCrop((210, 160)))

In [6]:
def tie_weights(src, trg):
    assert type(src) == type(trg)
    trg.weight = src.weight
    trg.bias = src.bias


# for 84 x 84 inputs
OUT_DIM = {2: 39, 4: 35, 6: 31}
# for 64 x 64 inputs
OUT_DIM_64 = {2: 29, 4: 25, 6: 21}

''' TODO change the layer parameters ''' 
class PixelEncoder(nn.Module):
    """Convolutional encoder of pixels observations."""
    # [210, 160] --> crop [190, 140]
    def __init__(self, obs_shape, feature_dim=50, num_layers=3, num_filters=64, output_logits=False):
        super().__init__()

        assert len(obs_shape) == 3
        self.obs_shape = obs_shape
        self.feature_dim = feature_dim
        self.num_layers = num_layers
        
        # 160, 210, 3 
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=5)
        # Input to conv2: 32, 42, 32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        # Input to conv3: 15, 20, 64
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1)
        # Output from conv3: 12, 17, 64

        # out_dim = OUT_DIM_64[num_layers] if obs_shape[-1] == 64 else OUT_DIM[num_layers]
        out_dims = (12, 17)
        self.fc = nn.Linear(num_filters * out_dims[0] * out_dims[1], self.feature_dim)
        self.ln = nn.LayerNorm(self.feature_dim)

        self.outputs = dict()
        self.output_logits = output_logits

    def reparameterize(self, mu, logstd):
        std = torch.exp(logstd)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward_conv(self, obs):
        self.outputs['obs'] = obs
        conv1 = torch.relu(self.conv1(obs))
        self.outputs['conv1'] = conv1
        conv2 = torch.relu(self.conv2(conv1))
        self.outputs['conv2'] = conv2
        conv3 = torch.relu(self.conv3(conv2))
        self.outputs['conv3'] = conv3

        h = conv3.view(conv3.size(0), -1)
        return h

    def forward(self, obs, detach=False):
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        h_fc = self.fc(h)
        self.outputs['fc'] = h_fc

        h_norm = self.ln(h_fc)
        self.outputs['ln'] = h_norm

        if self.output_logits:
            out = h_norm
        else:
            out = torch.tanh(h_norm)
            self.outputs['tanh'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers"""
        # only tie conv layers
        for i in range(self.num_layers):
            tie_weights(src=source.convs[i], trg=self.convs[i])

    def log(self, L, step, log_freq):
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % k, v, step)
            if len(v.shape) > 2:
                L.log_image('train_encoder/%s_img' % k, v[0], step)

        for i in range(self.num_layers):
            L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
        L.log_param('train_encoder/fc', self.fc, step)
        L.log_param('train_encoder/ln', self.ln, step)

In [7]:
def wrap_env(env, train=True):
    suffix = 'train' if train else 'test'
    monitor_dir = os.path.join(save_dir, 'monitor_%s' % suffix)
    os.makedirs(monitor_dir, exist_ok=True)
    if not train:
        video_save_interval = 10
        env = Monitor(env, directory=monitor_dir,
                      video_callable=lambda episode_id: episode_id % video_save_interval == 0,
                      force=True)
    else:
        if render:
            if max_episode is not None:
                video_save_interval = int(max_episode / 3)
            else:
                video_save_interval = int(max_timesteps / float(env._max_episode_steps) / 3)
            env = Monitor(env, directory=monitor_dir,
                          video_callable=lambda episode_id: episode_id % video_save_interval == 0,
                          force=True)
        else:
            env = Monitor(env, directory=monitor_dir, video_callable=False, force=True)
    return env

In [8]:
class ReplayMemory(object):
    def __init__(self, max_epi_num=2000, max_epi_len=200):
        # capacity is the maximum number of steps in memory
        self.max_epi_num = max_epi_num
        self.max_epi_len = max_epi_len
        # saves each tuple of (state, action, next state, reward)
        self.capacity = 1000 # self.max_epi_num * max_epi_len
        self.idx = 0
        self.obs_memory = np.zeros((self.capacity, 210, 160, 3)) # deque(maxlen=self.max_epi_num * max_epi_len)
        self.next_memory = np.zeros((self.capacity, 210, 160, 3))
        self.act_memory = np.zeros((self.capacity, 1))
        self.reward_memory = np.zeros((self.capacity, 1))
        self.is_av = False
        self.current_epi = 0

    def reset(self):
        self.current_epi = 0
        self.memory.clear()

    ''' deprecated for tuple buffer '''
    def create_new_epi(self):
        pass

    def remember(self, state, next_state, action, reward):
        if self.idx == self.capacity:
            self.idx = 0
        self.obs_memory[self.idx] = state.copy()
        self.next_memory[self.idx] = next_state.copy()
        self.act_memory[self.idx] = action
        self.reward_memory[self.idx] = reward
        self.idx += 1
        
        '''
        if len(self.memory) < self.capacity:
            new_sample = np.array([state, action, reward, next_state])
            if len(self.memory) == 0:
                self.memory = [new_sample]
            else:
                length = len(self.memory)
                self.memory.append(new_sample)
        '''
                
    # samples batch_size
    def sample(self, batch_size):
        if batch_size < self.idx:
            idx = np.random.randint(0, self.idx - 1, batch_size)
            return self.obs_memory[idx], self.next_memory[idx], self.act_memory[idx], self.reward_memory[idx]
        return self.obs_memory, self.next_memory, self.act_memory, self.reward_memory

    def size(self):
        return self.idx

    def is_available(self):
        self.is_av = True
        if self.idx <= 1:
            self.is_av = False
        return self.is_av

    def print_info(self):
        pass

In [9]:
#@title Create a training conv agent
import torch.nn.functional as F

class DQNetworkConv(nn.Module):
    def __init__(self, in_channels, act_dim, dueling=False):
        super(DQNetworkConv, self).__init__()
        self.act_dim = act_dim
        self.dueling = dueling
        # 160, 210, 3 
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=5, stride=5)
        # Input to conv2: 32, 42, 32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        # Input to conv3: 15, 20, 64
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1)
        # Output from conv3: 12, 17, 64
        if self.dueling:
            self.v_fc4 = nn.Linear(12 * 17 * 64, 512)
            self.adv_fc4 = nn.Linear(12 * 17 * 64, 512)
            self.v_fc5 = nn.Linear(512, 1)
            self.adv_fc5 = nn.Linear(512, self.act_dim)
        else:
            self.fc4 = nn.Linear(12 * 17 * 64, 512)
            self.fc5 = nn.Linear(512, self.act_dim)
        self.parameters = (list(self.conv1.parameters())) + (list(self.conv2.parameters())) + (list(self.conv3.parameters())) + (list(self.fc4.parameters())) + (list(self.fc5.parameters()))

    def forward(self, st):
        out = F.relu(self.conv1(st))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = out.view(out.size(0), -1)
        if self.dueling:
            val = F.relu(self.v_fc4(out))
            adv = F.relu(self.adv_fc4(out))
            val = self.v_fc5(val)
            adv = self.adv_fc5(adv)
            out = val.expand_as(adv) + adv - adv.mean(-1, keepdim=True).expand_as(adv)
        else:
            out = F.relu(self.fc4(out))
            out = self.fc5(out)
        return out

In [10]:
#@title Create a training FC agent
import torch.nn.functional as F

class DQNetworkFC(nn.Module):
    def __init__(self, z_dim, act_dim, dueling=False):
        super(DQNetworkFC, self).__init__()
        self.act_dim = act_dim
        self.input_dim = z_dim 
        self.dueling = dueling
        if self.dueling:
            self.v_fc1 = nn.Linear(z_dim, 512)
            self.adv_fc1 = nn.Linear(z_dim, 512)
            self.v_fc2 = nn.Linear(512, 1)
            self.adv_fc2 = nn.Linear(512, 256)
            self.v_fc3 = nn.Linear(256, 1)
            self.adv_fc3 = nn.Linear(256, self.act_dim)
        else:
            self.fc1 = nn.Linear(z_dim, 512)
            self.fc2 = nn.Linear(512, 256)
            self.fc3 = nn.Linear(256, self.act_dim)

    def forward(self, st):
        out = F.relu(self.fc1(st))
        out = F.relu(self.fc2(out))
        ''' Do we need a relu on the last layer if the output is probability over action space? '''
        out = F.relu(self.fc3(out))
        return out

In [11]:
def process_obs(obs):
    obs = torch.Tensor(obs / 255.)
    if len(obs.shape) < 4:
        obs = obs.unsqueeze(0)
    obs = obs.permute(0, 3, 1, 2)
    return obs

In [12]:
def take_action(env, action):
    state, rew, done, _ = env.step(action)
    obs = env.render(mode='rgb_array')
    return obs, rew, done

In [13]:
MAX_STEPS = 200

In [14]:
class CURL(nn.Module):
    """
    CURL
    """

    def __init__(self, obs_shape, z_dim, batch_size, encoder, output_type="continuous", critic=None, critic_target=None):
        super(CURL, self).__init__()
        self.batch_size = batch_size

        # self.encoder = critic.encoder
        self.encoder = encoder 

        # self.encoder_target = critic_target.encoder 

        self.W = nn.Parameter(torch.rand(z_dim, z_dim))
        self.output_type = output_type

    def encode(self, x, detach=False, ema=False):
        """
        Encoder: z_t = e(x_t)
        :param x: x_t, x y coordinates
        :return: z_t, value in r2
        """
        if ema:
            with torch.no_grad():
                z_out = self.encoder_target(x)
        else:
            z_out = self.encoder(x)

        if detach:
            z_out = z_out.detach()
        return z_out

    def compute_logits(self, z_a, z_mod):
        """
        Uses logits trick for CURL:
        - compute (B,B) matrix z_a (W z_pos.T)
        - positives are all diagonal elements
        - negatives are all other elements
        - to compute loss use multiclass cross entropy with identity matrix for labels
        """
        Wz = torch.matmul(self.W, z_mod.T)  # (z_dim,B)
        logits = torch.matmul(z_a, Wz)  # (B,B)
        logits = logits - torch.max(logits, 1)[0][:, None]
        return logits

In [15]:
#@title Generate a batch of negatively labelled examples given observations

def generate_negatives(obs):
    neg_idx = np.random.randint(len(obs), size=len(obs))
    pos_idx = np.arange(len(obs))
    resample = (neg_idx == pos_idx)
    for (i, r) in enumerate(resample):
        if r:
            idx = neg_idx[i]
        else:
            idx = np.random.randint(0, len(obs), 1)[0]
            while idx == i:
                idx = np.random.randint(0, len(obs), 1)[0]
        neg_idx[i] = idx
    return (obs[neg_idx]).copy()


In [16]:
#@title Create a training agent (wrapper for conv agent)

GAMMA = 0.99

class Agent(object):
    def __init__(self, act_dim, in_channels=3, max_epi_num=50, max_epi_len=300, CURL=None, conv_net=False):
        self.N_action = act_dim
        self.max_epi_num = max_epi_num
        self.max_epi_len = max_epi_len
        ''' To decide when to copy weights to the target network '''
        self.num_param_updates = 0
        self.CURL = CURL
        if conv_net:
            self.conv_net = DQNetworkConv(in_channels, act_dim)
            self.target = DQNetworkConv(in_channels, act_dim)
        else:
            ''' if using the encoder head for contrastive loss '''
            self.conv_net = DQNetworkFC(self.CURL.encoder.feature_dim, act_dim)
            self.target = DQNetworkFC(self.CURL.encoder.feature_dim, act_dim)
        self.buffer = ReplayMemory(max_epi_num=self.max_epi_num, max_epi_len=self.max_epi_len)
        self.gamma = 0.99
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(list(self.conv_net.parameters()) + list(self.CURL.parameters()) + list(self.CURL.encoder.parameters()), lr=1e-3)

    def remember(self, state, action, reward, next_state):
        self.buffer.remember(state, next_state, action, reward)

    ''' Copy the weights to the target network every 100 updates '''
    def train(self, batch_size=32, target_update_freq=100, use_encoder=True):
        if self.buffer.is_available():
            obs, next_obs, action_list, reward_list = self.buffer.sample(batch_size)
            
            ''' Pass through the encoder to get encodings
             If also training the contrastive loss
             include that here! 
             1. data augmentation to create pos and negative pairs
             2. encoder 
             3. update encoder loss function (using a separate optimizer) or add to the loss computed below
            '''

            losses = [] 

            # check if obs is a numpy or a torch tensor
            if use_encoder:
                obs_anchor = process_obs(obs.copy()) # / 255.)
                obs_pos = random_shift(obs_anchor)
                # mixed_obs = generate_negatives(obs)
                # mixed_obs = process_obs(mixed_obs)
                # obs_neg = random_shift(mixed_obs)
                z_a = self.CURL.encode(obs_anchor)
                z_pos = self.CURL.encode(obs_pos)
                # Mix pairs to generate negative labels
                # z_neg = self.CURL.encode(obs_neg)
                next_obs = process_obs(next_obs.copy())
                z_next = self.CURL.encode(next_obs)

                logits = self.CURL.compute_logits(z_a, z_pos)
                labels = torch.arange(logits.shape[0]).long()
                '''
                pos_logits = self.CURL.compute_logits(z_a, z_pos)
                neg_logits = self.CURL.compute_logits(z_a, z_neg)
                # [32, 32]
                pos_labels = torch.ones(pos_logits.shape[0]).long()
                neg_labels = torch.zeros(neg_logits.shape[0]).long() 
                # TODO: stack pos and neg logits and labels (double check dim)
                logits = torch.stack([pos_logits, neg_logits]).squeeze(0)
                labels = torch.stack([pos_labels, neg_labels]).squeeze(0)
                '''
                
                # pass into the loss function
                encoding_loss = nn.CrossEntropyLoss()(logits, labels)

                ''' Combine encoding loss with rl loss below '''
                losses.append(encoding_loss)

                # Then pass that encoding through the conv_net to get Q value estimates
                Qs = self.conv_net(z_a)
                next_Qs = self.target(z_next).detach().max(1)[0]
            
            else:
                ''' If not using the encoder, pass the obs directly to the CNN '''
                obs = process_obs(obs)
                # estimate current q values from observations
                Qs = self.conv_net(obs)
                # find next max q values based on next observations
                next_Qs = self.target(next_obs).detach().max(1)[0]
            
            ''' find target q values ''' 
            next_Qs = next_Qs.numpy()
            Qs = torch.gather(Qs, dim=1, index=torch.tensor(action_list, dtype=torch.int64))
            target_Qs = torch.tensor(reward_list.squeeze(-1) + GAMMA * next_Qs).long()
            ''' try to set Qs equal to target_Qs '''
            q_loss = self.loss_fn(Qs, target_Qs).long()
            losses.append(q_loss)
            
            ''' Loss update for q network and encoder head '''
            losses = torch.stack(losses).sum()
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            self.num_param_updates += 1
            if self.num_param_updates % target_update_freq == 0:
                self.target.load_state_dict(self.conv_net.state_dict())

    # TODO: check the sizes of inputs and outputs
    def get_action(self, obs, epsilon, use_encoding=True):
        ''' 
         If using an encoder, need to pass that thorugh the encoder
         then use the encoding to pass through self.conv_net
        '''
        # obs = torch.tensor(obs)
        if use_encoding:
            obs = process_obs(obs)
            obs = self.CURL.encode(obs).detach()

        # Dividing obs by 255 is handled in encoder forward (only needed for use_encoding=False)
        if len(obs.shape) == 1:
            obs = obs.unsqueeze(0)

        # epsilon greedy for selecting which action to take
        if random.random() > epsilon:
            qs = self.conv_net(obs)
            action = qs[0].argmax().data.item()
        else:
            action = random.randint(0, self.N_action-1)

        return action

def get_decay(epi_iter):
    decay = math.pow(0.999, epi_iter)
    if decay < 0.05:
        decay = 0.05
    return decay

In [17]:
def main():
    env = gym.make('ALE/SpaceInvaders-v5')
    max_epi_iter = 4000
    max_MC_iter = 400
    obs = env.render(mode='rgb_array')
    encoder = PixelEncoder(obs_shape=obs.shape, feature_dim=50, num_layers=2, num_filters=64, output_logits=False)
    CURL_encoder = CURL(obs_shape=obs.shape, z_dim=50, batch_size=1, encoder=encoder, output_type="continuous")
    agent = Agent(act_dim=env.action_space.n, max_epi_num=10000, max_epi_len=max_MC_iter, CURL=CURL_encoder)
    train_curve = []
    for epi_iter in range(max_epi_iter):
        random.seed()
        env.reset()
        obs = env.render(mode='rgb_array')
        returns = 0.0
        for MC_iter in range(max_MC_iter):
            action = agent.get_action(obs, get_decay(epi_iter))
            next_obs, reward, done = take_action(env, action)
            returns += reward * agent.gamma ** (MC_iter)
            agent.remember(obs, action, reward, next_obs)
            obs = next_obs.copy()
            if done or MC_iter >= max_MC_iter-1:
                agent.buffer.create_new_epi()
                break
        print('Episode', epi_iter, 'returns', returns)
        if epi_iter % 1 == 0:
            train_curve.append(returns)
        if agent.buffer.is_available():
            agent.train()
        
    print(train_curve)

In [18]:
main()

A.L.E: Arcade Learning Environment (version +978d2ce)
[Powered by Stella]


Episode 0 returns 51.54540406576974
Episode 1 returns 173.57588974650156
Episode 2 returns 301.7792550142323
Episode 3 returns 108.58535312846712
Episode 4 returns 98.54669305938216
Episode 5 returns 51.366155668833315
Episode 6 returns 115.33674990993892
Episode 7 returns 120.78674959584366
Episode 8 returns 96.0214342684786
Episode 9 returns 26.36302651333793
Episode 10 returns 94.38063102233812
Episode 11 returns 65.69998943463104
Episode 12 returns 35.36232880432375
Episode 13 returns 147.1739753945593
Episode 14 returns 89.49836256097485
Episode 15 returns 122.63505123309535
Episode 16 returns 114.93800648538009
Episode 17 returns 46.204274575136786
Episode 18 returns 68.7230447034982
Episode 19 returns 119.00689668489433
Episode 20 returns 45.943030805533354
Episode 21 returns 13.384461478169396
Episode 22 returns 113.09578373945791
Episode 23 returns 131.806795458997
Episode 24 returns 86.0781584575677
Episode 25 returns 146.5043245915619
Episode 26 returns 64.02480596088192
Epi

Episode 217 returns 332.10777073250193
Episode 218 returns 81.36285077249737
Episode 219 returns 83.45838378607532
Episode 220 returns 94.19772899931
Episode 221 returns 105.01880046140896
Episode 222 returns 63.13268747739718
Episode 223 returns 76.44782276713892
Episode 224 returns 26.295938847421503
Episode 225 returns 114.07586757273297
Episode 226 returns 66.24437185138194
Episode 227 returns 144.70193120770452
Episode 228 returns 74.60012650715478
Episode 229 returns 89.25554135946633
Episode 230 returns 84.92068651930767
Episode 231 returns 98.79281783733849
Episode 232 returns 70.9244993409438
Episode 233 returns 70.04999512418277
Episode 234 returns 120.59138936367219
Episode 235 returns 96.32443710907063
Episode 236 returns 130.52327654241245
Episode 237 returns 92.74843671658687
Episode 238 returns 42.8822413501605
Episode 239 returns 52.83951372334718
Episode 240 returns 115.05256584819051
Episode 241 returns 38.19281106747375
Episode 242 returns 13.026747199791863
Episode 

Episode 431 returns 122.42778752621406
Episode 432 returns 125.29168837489412
Episode 433 returns 124.00172972277288
Episode 434 returns 103.6036856113859
Episode 435 returns 50.210968101700175
Episode 436 returns 76.7444677812315
Episode 437 returns 253.67510707972414
Episode 438 returns 85.38690574548467
Episode 439 returns 100.92643281233339
Episode 440 returns 57.196455657485366
Episode 441 returns 16.84665708597546
Episode 442 returns 90.17829294189157
Episode 443 returns 263.17267490820245
Episode 444 returns 145.4556998804628
Episode 445 returns 141.36687266092707
Episode 446 returns 70.07425996153607
Episode 447 returns 168.87971813299313
Episode 448 returns 307.98003706917206
Episode 449 returns 141.95545995986802
Episode 450 returns 39.66489722596145
Episode 451 returns 28.53292359106002
Episode 452 returns 124.890934832606
Episode 453 returns 113.31192820487007
Episode 454 returns 44.07845561339368
Episode 455 returns 120.62417066015594
Episode 456 returns 39.05427245347745


Episode 646 returns 92.37529059047185
Episode 647 returns 102.96396961841896
Episode 648 returns 58.956754796855115
Episode 649 returns 68.95215730992102
Episode 650 returns 48.39526728224671
Episode 651 returns 39.51694265543781
Episode 652 returns 75.13650877206442
Episode 653 returns 87.66972608801834
Episode 654 returns 72.22737461613923
Episode 655 returns 97.55966449539001
Episode 656 returns 71.37701013293483
Episode 657 returns 259.05106742017836
Episode 658 returns 163.06781862384523
Episode 659 returns 77.01669515387948
Episode 660 returns 92.2490142641935
Episode 661 returns 35.83311011672011
Episode 662 returns 45.507273424056294
Episode 663 returns 59.163301836320485
Episode 664 returns 109.37183272526126
Episode 665 returns 33.85177833372542
Episode 666 returns 90.19033823200056
Episode 667 returns 93.84580007164483
Episode 668 returns 4.425471700390396
Episode 669 returns 98.78285304516062
Episode 670 returns 26.803732998263705
Episode 671 returns 107.4294689336592
Episo

Episode 860 returns 158.12277742999984
Episode 861 returns 38.943744220913636
Episode 862 returns 30.69642518287098
Episode 863 returns 52.19941641783069
Episode 864 returns 115.94993655692859
Episode 865 returns 68.54883381504756
Episode 866 returns 34.57368674060712
Episode 867 returns 105.29221829988113
Episode 868 returns 57.94519112542265
Episode 869 returns 78.93932299669514
Episode 870 returns 27.158409456676544
Episode 871 returns 29.101813870988654
Episode 872 returns 153.49699785721137
Episode 873 returns 29.812240521681446
Episode 874 returns 73.6132806447662
Episode 875 returns 116.35658305236066
Episode 876 returns 24.989681035361563
Episode 877 returns 30.649439389424245
Episode 878 returns 83.38622453155769
Episode 879 returns 63.20397349184056
Episode 880 returns 191.1282002849001
Episode 881 returns 64.13049008511273
Episode 882 returns 85.91068860061868
Episode 883 returns 37.801334098802606
Episode 884 returns 16.994006413607032
Episode 885 returns 62.78218240638739


Episode 1072 returns 41.023910400977144
Episode 1073 returns 65.29982047428487
Episode 1074 returns 97.85521064809481
Episode 1075 returns 43.21134667162733
Episode 1076 returns 22.31117218608589
Episode 1077 returns 97.44947472549393
Episode 1078 returns 53.103421538823895
Episode 1079 returns 61.98735108543954
Episode 1080 returns 68.84359650798427
Episode 1081 returns 29.443851243063985
Episode 1082 returns 106.79004050173742
Episode 1083 returns 31.718504666900134
Episode 1084 returns 80.90925006115253
Episode 1085 returns 31.984786067123128
Episode 1086 returns 75.76933588665165
Episode 1087 returns 4.847302681479113
Episode 1088 returns 42.96118723820662
Episode 1089 returns 76.91611793331639
Episode 1090 returns 63.0485575094777
Episode 1091 returns 59.108252086202945
Episode 1092 returns 17.931917479478326
Episode 1093 returns 25.697096142001072
Episode 1094 returns 12.433813355986189
Episode 1095 returns 45.89004137034692
Episode 1096 returns 29.842977774010837
Episode 1097 re

Episode 1281 returns 45.240056432400515
Episode 1282 returns 99.06016075474255
Episode 1283 returns 17.977268895243327
Episode 1284 returns 87.6180669413798
Episode 1285 returns 84.56169213458281
Episode 1286 returns 121.40954824018552
Episode 1287 returns 31.18506105099667
Episode 1288 returns 17.184487788890678
Episode 1289 returns 265.2699247259007
Episode 1290 returns 28.055103509579954
Episode 1291 returns 66.86879643783487
Episode 1292 returns 17.483505805832337
Episode 1293 returns 43.862744880932624
Episode 1294 returns 65.4088788028849
Episode 1295 returns 111.42137618909183
Episode 1296 returns 82.77251511495186
Episode 1297 returns 89.63507607709616
Episode 1298 returns 91.64801937173469
Episode 1299 returns 37.33906719202844
Episode 1300 returns 48.48264019805906
Episode 1301 returns 90.062354593934
Episode 1302 returns 51.23029577485504
Episode 1303 returns 92.99904862640055
Episode 1304 returns 81.79271115556998
Episode 1305 returns 43.378118524267144
Episode 1306 returns

Episode 1491 returns 94.6309595125336
Episode 1492 returns 71.27313324101625
Episode 1493 returns 45.727767793667326
Episode 1494 returns 118.44610439475964
Episode 1495 returns 12.937964580597763
Episode 1496 returns 189.2173626425645
Episode 1497 returns 55.072526128183284
Episode 1498 returns 69.83722329044275
Episode 1499 returns 89.81793976462629
Episode 1500 returns 29.41547540254126
Episode 1501 returns 34.89630628981067
Episode 1502 returns 314.97832360773356
Episode 1503 returns 109.52795206852255
Episode 1504 returns 21.815416686510606
Episode 1505 returns 49.23719329142241
Episode 1506 returns 110.47282450636885
Episode 1507 returns 48.91024972124091
Episode 1508 returns 120.09747214778332
Episode 1509 returns 100.53275480712098
Episode 1510 returns 113.09521652196861
Episode 1511 returns 53.80794536424257
Episode 1512 returns 80.2211972797947
Episode 1513 returns 76.6167981046959
Episode 1514 returns 74.49502767063386
Episode 1515 returns 149.28238330802398
Episode 1516 ret

Episode 1699 returns 47.65756442208084
Episode 1700 returns 48.78743912443984
Episode 1701 returns 114.30240944166381
Episode 1702 returns 63.93369092751142
Episode 1703 returns 21.876988519056454
Episode 1704 returns 70.13352570082172
Episode 1705 returns 30.736725079091546
Episode 1706 returns 37.46431433362157
Episode 1707 returns 87.10478655063473
Episode 1708 returns 38.75519478422824
Episode 1709 returns 187.45125392687024
Episode 1710 returns 22.841156804496883
Episode 1711 returns 68.11431808545922
Episode 1712 returns 88.35386673761343
Episode 1713 returns 72.77178466942007
Episode 1714 returns 103.51388950401893
Episode 1715 returns 81.2592021207906
Episode 1716 returns 77.43191090419849
Episode 1717 returns 111.8141551064493
Episode 1718 returns 71.59226579954769
Episode 1719 returns 107.73818642513717
Episode 1720 returns 89.50722836381102
Episode 1721 returns 75.73415196141372
Episode 1722 returns 115.55130966940379
Episode 1723 returns 107.20167820127254
Episode 1724 retu

Episode 1908 returns 118.97790662533393
Episode 1909 returns 51.64257919122141
Episode 1910 returns 7.226271621120381
Episode 1911 returns 212.73936683219006
Episode 1912 returns 55.529416099340835
Episode 1913 returns 66.37922213081656
Episode 1914 returns 75.60709209630498
Episode 1915 returns 27.87293606030179
Episode 1916 returns 0.0
Episode 1917 returns 70.87230732629487
Episode 1918 returns 31.620159952420067
Episode 1919 returns 23.953521968818436
Episode 1920 returns 72.13683313058864
Episode 1921 returns 17.65835040698847
Episode 1922 returns 90.38362649367784
Episode 1923 returns 60.740251191846866
Episode 1924 returns 100.70559486735586
Episode 1925 returns 4.4476655961697045
Episode 1926 returns 211.16426866373553
Episode 1927 returns 39.26083857985108
Episode 1928 returns 59.899870158858654
Episode 1929 returns 25.904331505580814
Episode 1930 returns 50.31086935637187
Episode 1931 returns 43.119548366616314
Episode 1932 returns 58.230166107119516
Episode 1933 returns 65.24

Episode 2117 returns 62.40569053114031
Episode 2118 returns 9.060644977231139
Episode 2119 returns 100.49308089018797
Episode 2120 returns 49.81763730284111
Episode 2121 returns 110.97217147094618
Episode 2122 returns 187.92107247469244
Episode 2123 returns 77.7035107203261
Episode 2124 returns 71.92698021296509
Episode 2125 returns 54.0115355737597
Episode 2126 returns 68.25260481251442
Episode 2127 returns 12.507558124805982
Episode 2128 returns 19.225940218219794
Episode 2129 returns 42.40152246547207
Episode 2130 returns 27.470937983299983
Episode 2131 returns 38.00068747960697
Episode 2132 returns 0.0
Episode 2133 returns 29.350771325757464
Episode 2134 returns 38.13461767750329
Episode 2135 returns 31.02373290683165
Episode 2136 returns 23.906826256973552
Episode 2137 returns 30.103122481468315
Episode 2138 returns 42.79578869552824
Episode 2139 returns 17.878760379018885
Episode 2140 returns 17.69284010419218
Episode 2141 returns 26.015896749699028
Episode 2142 returns 33.807258

Episode 2327 returns 12.615549881559744
Episode 2328 returns 17.269438349337555
Episode 2329 returns 4.222187488964646
Episode 2330 returns 18.967181880576888
Episode 2331 returns 78.04993104021551
Episode 2332 returns 85.18428429905113
Episode 2333 returns 0.0
Episode 2334 returns 209.89215595809924
Episode 2335 returns 35.257572408224064
Episode 2336 returns 84.10355605436908
Episode 2337 returns 35.60479370480073
Episode 2338 returns 26.766211940971534
Episode 2339 returns 14.873634855541411
Episode 2340 returns 15.919690558901342
Episode 2341 returns 25.77936325464283
Episode 2342 returns 26.390220096625242
Episode 2343 returns 82.86687195587356
Episode 2344 returns 21.232350671054945
Episode 2345 returns 18.966530891984902
Episode 2346 returns 27.899659371806607
Episode 2347 returns 14.959764527396018
Episode 2348 returns 17.04023250289829
Episode 2349 returns 19.507362696047597
Episode 2350 returns 22.938747854025735
Episode 2351 returns 4.5831859722216945
Episode 2352 returns 35

Episode 2536 returns 26.907305739971314
Episode 2537 returns 66.84569804915263
Episode 2538 returns 0.0
Episode 2539 returns 16.69422334750061
Episode 2540 returns 17.785063279886852
Episode 2541 returns 21.140074986353216
Episode 2542 returns 8.275100760764754
Episode 2543 returns 38.74916654587142
Episode 2544 returns 28.704781188705663
Episode 2545 returns 13.672563118977802
Episode 2546 returns 18.444908204684268
Episode 2547 returns 17.302895835484456
Episode 2548 returns 18.961679763921694
Episode 2549 returns 42.914410549976516
Episode 2550 returns 4.316152426553716
Episode 2551 returns 27.07198768170091
Episode 2552 returns 21.00748341327194
Episode 2553 returns 16.387878752880084
Episode 2554 returns 86.03423780282442
Episode 2555 returns 31.668407436175713
Episode 2556 returns 11.395044876457774
Episode 2557 returns 55.91641948622225
Episode 2558 returns 10.558421067320104
Episode 2559 returns 42.807708396432176
Episode 2560 returns 15.734985472024714
Episode 2561 returns 51.

Episode 2746 returns 11.494159358902296
Episode 2747 returns 92.5392538206254
Episode 2748 returns 60.38656227193891
Episode 2749 returns 23.077687819411203
Episode 2750 returns 13.01494580817416
Episode 2751 returns 16.984816659693646
Episode 2752 returns 58.41507263796756
Episode 2753 returns 8.682896325355337
Episode 2754 returns 244.89118243447845
Episode 2755 returns 34.404035336031434
Episode 2756 returns 62.992995140941716
Episode 2757 returns 55.84333706176627
Episode 2758 returns 35.05889647897105
Episode 2759 returns 64.40115483323841
Episode 2760 returns 13.7542973077462
Episode 2761 returns 12.283470872860013
Episode 2762 returns 52.44414684666306
Episode 2763 returns 80.20126637981463
Episode 2764 returns 12.310469664522966
Episode 2765 returns 26.947159496474956
Episode 2766 returns 33.8393057889934
Episode 2767 returns 41.16449074633722
Episode 2768 returns 22.802196020710518
Episode 2769 returns 13.166470587095528
Episode 2770 returns 4.403388552372854
Episode 2771 retu

Episode 2956 returns 8.881548621316258
Episode 2957 returns 63.99906318393898
Episode 2958 returns 27.928554092712922
Episode 2959 returns 35.22834838864207
Episode 2960 returns 70.40257779505122
Episode 2961 returns 17.40797293353543
Episode 2962 returns 12.076012491647532
Episode 2963 returns 170.06260660119338
Episode 2964 returns 25.089910169178943
Episode 2965 returns 35.91257958538756
Episode 2966 returns 33.41391416184585
Episode 2967 returns 25.87693710323461
Episode 2968 returns 30.64124817360793
Episode 2969 returns 26.634430808656205
Episode 2970 returns 78.70536041413472
Episode 2971 returns 23.220631482715827
Episode 2972 returns 59.097631352136546
Episode 2973 returns 67.72825166169208
Episode 2974 returns 30.138845971153504
Episode 2975 returns 15.528758204434615
Episode 2976 returns 33.11004803904874
Episode 2977 returns 59.6298697413945
Episode 2978 returns 55.02452940483734
Episode 2979 returns 69.89633902572085
Episode 2980 returns 25.486836537618444
Episode 2981 ret

Episode 3165 returns 26.068617248941244
Episode 3166 returns 33.49857289122071
Episode 3167 returns 61.91854282300676
Episode 3168 returns 42.29995677235106
Episode 3169 returns 8.179274063417058
Episode 3170 returns 15.849068984033496
Episode 3171 returns 57.912915341331804
Episode 3172 returns 22.305176617476278
Episode 3173 returns 29.69217893625147
Episode 3174 returns 8.011825873833976
Episode 3175 returns 12.49880156391717
Episode 3176 returns 71.05207223552898
Episode 3177 returns 12.8792638612154
Episode 3178 returns 73.69336889091299
Episode 3179 returns 75.95171481953022
Episode 3180 returns 24.883526500790765
Episode 3181 returns 52.752403881100534
Episode 3182 returns 8.824487780649049
Episode 3183 returns 36.30524289894426
Episode 3184 returns 40.17517746298269
Episode 3185 returns 280.65198716395554
Episode 3186 returns 45.59907004190664
Episode 3187 returns 8.582358239134349
Episode 3188 returns 76.72998683119384
Episode 3189 returns 90.5944919008677
Episode 3190 returns

Episode 3374 returns 41.40273133637707
Episode 3375 returns 83.73133012687767
Episode 3376 returns 81.94644559854375
Episode 3377 returns 71.59994362227354
Episode 3378 returns 46.05967688087448
Episode 3379 returns 36.870940409122014
Episode 3380 returns 52.49161701959309
Episode 3381 returns 18.002691666619178
Episode 3382 returns 8.589440213225089
Episode 3383 returns 16.617866467844717
Episode 3384 returns 8.816934618655623
Episode 3385 returns 23.53705423529192
Episode 3386 returns 30.050428511449198
Episode 3387 returns 56.02761262462121
Episode 3388 returns 16.155062524856817
Episode 3389 returns 37.76640255888978
Episode 3390 returns 42.38898412594453
Episode 3391 returns 0.0
Episode 3392 returns 15.608365970176525
Episode 3393 returns 29.072402173356927
Episode 3394 returns 78.94391750968774
Episode 3395 returns 33.81225092428542
Episode 3396 returns 28.760928618136898
Episode 3397 returns 9.14521339670122
Episode 3398 returns 29.51267079927071
Episode 3399 returns 35.40055160

Episode 3583 returns 84.57638545973803
Episode 3584 returns 39.487159817269855
Episode 3585 returns 17.06260463393178
Episode 3586 returns 8.419273652122257
Episode 3587 returns 4.222187488964646
Episode 3588 returns 40.20064580746606
Episode 3589 returns 26.735078219468665
Episode 3590 returns 36.17389732508097
Episode 3591 returns 13.021629418949999
Episode 3592 returns 31.600086752635534
Episode 3593 returns 94.51118868874953
Episode 3594 returns 19.873507977486117
Episode 3595 returns 38.622408950069726
Episode 3596 returns 39.301322369649085
Episode 3597 returns 7.919779927594238
Episode 3598 returns 93.02112777523463
Episode 3599 returns 3.5546880106336753
Episode 3600 returns 54.19443025377799
Episode 3601 returns 26.95297268081422
Episode 3602 returns 26.50053289576921
Episode 3603 returns 98.71405805466068
Episode 3604 returns 4.412208557278855
Episode 3605 returns 19.99558232365227
Episode 3606 returns 100.09341129076687
Episode 3607 returns 48.652715155948115
Episode 3608 re

Episode 3792 returns 12.866760422504314
Episode 3793 returns 38.72021107955493
Episode 3794 returns 31.0542780695698
Episode 3795 returns 52.338100413907995
Episode 3796 returns 41.9230421864078
Episode 3797 returns 104.5249053036379
Episode 3798 returns 21.75814273967615
Episode 3799 returns 69.38980278765777
Episode 3800 returns 30.42800794157622
Episode 3801 returns 39.72086092746633
Episode 3802 returns 29.989913781698196
Episode 3803 returns 30.809800178679254
Episode 3804 returns 17.470462514544877
Episode 3805 returns 21.92074467633124
Episode 3806 returns 27.07550142668496
Episode 3807 returns 44.55345386703864
Episode 3808 returns 39.35194522849129
Episode 3809 returns 51.76884766532793
Episode 3810 returns 44.6070004330123
Episode 3811 returns 65.04876027297128
Episode 3812 returns 76.69434520692396
Episode 3813 returns 69.02748474317895
Episode 3814 returns 35.454851829628254
Episode 3817 returns 53.558314830011966
Episode 3818 returns 79.31466429770559
Episode 3819 returns 

In [None]:
def old_main():
    env = gym.make('MountainCars-v0')
    if len(env.observation_space.shape) >= 3:
        env = WrapAtariEnv(env=env, noop_max=30, frameskip=3, framestack=4, test=test)
    if not test:
        dele = input("Do you wanna recreate ckpt and log folders? (y/n)")
        if dele == 'y':
            if os.path.exists(save_dir):
                shutil.rmtree(save_dir)

    env = wrap_env(env, train=not test)
    print(env.observation_space.shape)
    if len(env.observation_space.shape) >= 3:
        q_net = DQNetworkConv
    else:
        assert(False)
    agent = DQNAgent(env=env, qnet=q_net)
    if args.test:
        agent.rollout(episodes=100, render=render)
    else:
        agent.train()
    agent.env.close()

In [None]:
def main_mountaincar():
    env = gym.make('ALE/SpaceInvaders-v5')
    env.reset()

    for i in range(5):
        env.step(env.action_space.sample())
        obs = env.render(mode='rgb_array')
        print("Step ", i, obs)
    env.close()

In [None]:
main_mountaincar()

Code references for DQN:

https://github.com/taochenshh/dqn-pytorch

https://github.com/transedward/pytorch-dqn (for sampling from replay buffer)

CURL code: https://github.com/MishaLaskin/curl