In [1]:
%load_ext tensorboard

import math
import gym 
import plotly.express as px
import numpy as np
import warnings
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import random
from collections import Counter, deque
import pandas as pd
from sklearn.model_selection import train_test_split
import datetime
from sklearn.metrics import confusion_matrix
import plotly.figure_factory as ff

import os
import glob
from IPython import display as ipythondisplay
from tqdm.notebook import tqdm
from gym.wrappers import Monitor
from IPython.display import HTML
import base64
import io
import pickle
import torch
from torch import nn
import kornia.augmentation as kaug
warnings.filterwarnings("ignore")

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
from pyvirtualdisplay import Display
display = Display(visible=False, size=(1400, 900))
if torch.cuda.is_available():
    display.start()

In [4]:
import os, sys, copy, argparse, shutil
def parse_arguments():
    parser = argparse.ArgumentParser(description='Deep Q Network Argument Parser')
    parser.add_argument('--seed', dest='seed', type=int, default=1)
    parser.add_argument('--env', dest='env', type=str, default='CartPole-v0')
    parser.add_argument('--save_interval', type=int, default=50, help='save model every n episodes')
    parser.add_argument('--log_interval', type=int, default=10, help='logging every n episodes')
    parser.add_argument('--render', help='render', type=int, default=1)
    parser.add_argument('--batch_size', help='batch_size', type=int, default=32)
    parser.add_argument('--train_freq', help='train_frequency', type=int, default=1)
    parser.add_argument('--max_episode', help='maximum episode', type=int, default=None)
    parser.add_argument('--max_timesteps', help='maximum timestep', type=int, default=100000000)
    parser.add_argument('--lr', dest='lr', type=float, default=0.00025)
    parser.add_argument('--lr_decay', action='store_true', help='decay learning rate')
    parser.add_argument('--gamma', help='discount_factor', type=float, default=0.99)
    parser.add_argument('--warmup_mem', type=int, help='warmup memory size', default=1000)
    parser.add_argument('--frame_skip', type=int, help='number of frames to skip for each action', default=3)
    parser.add_argument('--frame_stack', type=int, help='number of frames to stack', default=4)
    parser.add_argument('--memory', help='memory size', type=int, default=1000000)
    parser.add_argument('--initial_epsilon', '-ie', help='initial_epsilon', type=float, default=0.5)
    parser.add_argument('--final_epsilon', '-fe', help='final_epsilon', type=float, default=0.05)
    parser.add_argument('--max_epsilon_decay_steps', '-eds', help='maximum steps to decay epsilon', type=int, default=100000)
    parser.add_argument('--max_grad_norm', type=float, default=None, help='maximum gradient norm')
    parser.add_argument('--soft_update', '-su', action='store_true', help='soft update target network')
    parser.add_argument('--double_q', '-dq', action='store_true', help='enabling double DQN')
    parser.add_argument('--dueling_net', '-dn', action='store_true', help='enabling dueling network')
    parser.add_argument('--test', action='store_true', help='test the trained model')
    parser.add_argument('--tau', type=float, default=0.01, help='tau for soft target network update')
    parser.add_argument('--hard_update_freq', '-huf', type=int, default=2000, help='hard target network update frequency')
    parser.add_argument('--save_dir', type=str, default='./data')
    parser.add_argument('--resume_step', '-rs', type=int, default=None)
    return parser.parse_args()

In [5]:
#@title Set up constants for env and training
test = False 
save_dir = './data'
render = False
max_episode = None
max_timesteps = 100000000


In [6]:
#@title Augmentations
'''
color_jitter
random_elastic_transform
random_fisheye
random_color_equalize
random_gaussian_blur
random_gaussian_noise
random_horizontal_flip
random_color_invert
random_perspective_shift
random_shift
'''
color_jitter = kaug.ColorJitter(
        brightness=np.random.random(),
        contrast=np.random.random(),
        saturation=np.random.random(),
        hue=np.random.random(),
        p=0.5
        )
random_elastic_transform = kaug.RandomElasticTransform()
random_fisheye = kaug.RandomFisheye(
        center_x=torch.tensor([-.3, .3]).to(device),
        center_y=torch.tensor([-.3, .3]).to(device),
        gamma=torch.tensor([.9, 1.]).to(device),
        )
# need to divide by 255.0
random_color_equalize = lambda obs: kaug.RandomEqualize()(obs / 255.) * 255
random_gaussian_blur = kaug.RandomGaussianBlur(
        kernel_size=(9, 9),
        sigma = (5., 5.)
        )
random_gaussian_noise = kaug.RandomGaussianNoise()
random_horizontal_flip = kaug.RandomHorizontalFlip()
random_color_invert = kaug.RandomInvert()
random_perspective_shift = kaug.RandomPerspective()
get_random_shift = lambda h, w, shift_by: nn.Sequential(kaug.RandomCrop((h - shift_by, w - shift_by)), nn.ReplicationPad2d(20), kaug.RandomCrop((h - shift_by, w - shift_by)))

In [7]:
def tie_weights(src, trg):
    assert type(src) == type(trg)
    trg.weight = src.weight
    trg.bias = src.bias

def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    from math import floor
    if type(kernel_size) is not tuple:
        kernel_size = (kernel_size, kernel_size)
    h = floor( ((h_w[0] + (2 * pad) - ( dilation * (kernel_size[0] - 1) ) - 1 )/ stride) + 1)
    w = floor( ((h_w[1] + (2 * pad) - ( dilation * (kernel_size[1] - 1) ) - 1 )/ stride) + 1)
    return h, w

# for 84 x 84 inputs
OUT_DIM = {2: 39, 4: 35, 6: 31}
# for 64 x 64 inputs
OUT_DIM_64 = {2: 29, 4: 25, 6: 21}

''' TODO change the layer parameters ''' 
class PixelEncoder(nn.Module):
    """Convolutional encoder of pixels observations."""
    def __init__(self, obs_shape, input_channels, feature_dim=50, num_layers=3, num_filters=64, output_logits=False):
        super().__init__()

        assert len(obs_shape) == 3
        self.obs_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
        self.feature_dim = feature_dim
        self.num_layers = num_layers
        
        # 160, 210, 3
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=5, stride=5)
        conv1_shape = conv_output_shape(self.obs_shape[1:], kernel_size=5, stride=5)
        # Input to conv2: 32, 42, 32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        conv2_shape = conv_output_shape(conv1_shape, kernel_size=4, stride=2)
        # Input to conv3: 15, 20, 64
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1)
        conv3_shape = conv_output_shape(conv2_shape, kernel_size=4, stride=1)
        # Output from conv3: 12, 17, 64

        # out_dim = OUT_DIM_64[num_layers] if obs_shape[-1] == 64 else OUT_DIM[num_layers]
        out_dims = conv3_shape
        self.fc = nn.Linear(num_filters * out_dims[0] * out_dims[1], self.feature_dim)
        self.ln = nn.LayerNorm(self.feature_dim)

        self.outputs = dict()
        self.output_logits = output_logits

    def reparameterize(self, mu, logstd):
        std = torch.exp(logstd)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward_conv(self, obs):
        self.outputs['obs'] = obs
        conv1 = torch.relu(self.conv1(obs))
        self.outputs['conv1'] = conv1
        conv2 = torch.relu(self.conv2(conv1))
        self.outputs['conv2'] = conv2
        conv3 = torch.relu(self.conv3(conv2))
        self.outputs['conv3'] = conv3

        h = conv3.reshape(conv3.size(0), -1)
        return h

    def forward(self, obs, detach=False):
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        h_fc = self.fc(h)
        self.outputs['fc'] = h_fc

        h_norm = self.ln(h_fc)
        self.outputs['ln'] = h_norm

        if self.output_logits:
            out = h_norm
        else:
            out = torch.tanh(h_norm)
            self.outputs['tanh'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers"""
        # only tie conv layers
        for i in range(self.num_layers):
            tie_weights(src=source.convs[i], trg=self.convs[i])

    def log(self, L, step, log_freq):
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % k, v, step)
            if len(v.shape) > 2:
                L.log_image('train_encoder/%s_img' % k, v[0], step)

        for i in range(self.num_layers):
            L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
        L.log_param('train_encoder/fc', self.fc, step)
        L.log_param('train_encoder/ln', self.ln, step)

In [8]:
def wrap_env(env, train=True):
    suffix = 'train' if train else 'test'
    monitor_dir = os.path.join(save_dir, 'monitor_%s' % suffix)
    os.makedirs(monitor_dir, exist_ok=True)
    if not train:
        video_save_interval = 10
        env = Monitor(env, directory=monitor_dir,
                      video_callable=lambda episode_id: episode_id % video_save_interval == 0,
                      force=True)
    else:
        if render:
            if max_episode is not None:
                video_save_interval = int(max_episode / 3)
            else:
                video_save_interval = int(max_timesteps / float(env._max_episode_steps) / 3)
            env = Monitor(env, directory=monitor_dir,
                          video_callable=lambda episode_id: episode_id % video_save_interval == 0,
                          force=True)
        else:
            env = Monitor(env, directory=monitor_dir, video_callable=False, force=True)
    return env

In [9]:
class ReplayMemory(object):
    def __init__(self, max_epi_num=50, max_epi_len=200, obs_shape=(210, 160)):
        # capacity is the maximum number of steps in memory
        self.max_epi_num = max_epi_num
        self.max_epi_len = max_epi_len
        # saves each tuple of (state, action, next state, reward)
        self.capacity = 100 # self.max_epi_num * max_epi_len
        self.idx = 0
        # Use 6 frames (stacking 2 x 3 frame obs)
        self.obs_memory = np.zeros((self.capacity, *obs_shape, 3)) # deque(maxlen=self.max_epi_num * max_epi_len)
        self.next_memory = np.zeros((self.capacity, *obs_shape, 3))
        self.act_memory = np.zeros((self.capacity, 1))
        self.reward_memory = np.zeros((self.capacity, 1))
        self.is_av = False
        self.current_epi = 0

    def reset(self):
        self.current_epi = 0
        self.memory.clear()

    ''' deprecated for tuple buffer '''
    def create_new_epi(self):
        pass

    def remember(self, state, next_state, action, reward):
        idx = self.idx % self.capacity
        self.obs_memory[idx] = state.copy()
        self.next_memory[idx] = next_state.copy()
        self.act_memory[idx] = action
        self.reward_memory[idx] = reward
        self.idx += 1

                
    # samples batch_size
    def sample(self, batch_size):
        if batch_size < self.idx:
            max_len = min(self.capacity, self.idx)
            idx = np.random.randint(0, max_len - 1, batch_size)
            return self.obs_memory[idx], self.next_memory[idx], self.act_memory[idx], self.reward_memory[idx]
        return self.obs_memory[:self.idx], self.next_memory[:self.idx], self.act_memory[:self.idx], self.reward_memory[:self.idx]

    def size(self):
        return self.idx

    def is_available(self):
        self.is_av = True
        if self.idx <= 1:
            self.is_av = False
        return self.is_av

    def print_info(self):
        pass

In [10]:
#@title Create a training conv agent
import torch.nn.functional as F

class DQNetworkConv(nn.Module):
    def __init__(self, in_channels, act_dim, dueling=False):
        super(DQNetworkConv, self).__init__()
        self.act_dim = act_dim
        self.dueling = dueling
        # 160, 210, 3 
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=5, stride=5)
        # Input to conv2: 32, 42, 32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        # Input to conv3: 15, 20, 64
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1)
        # Output from conv3: 12, 17, 64
        if self.dueling:
            self.v_fc4 = nn.Linear(12 * 17 * 64, 512)
            self.adv_fc4 = nn.Linear(12 * 17 * 64, 512)
            self.v_fc5 = nn.Linear(512, 1)
            self.adv_fc5 = nn.Linear(512, self.act_dim)
        else:
            self.fc4 = nn.Linear(12 * 17 * 64, 512)
            self.fc5 = nn.Linear(512, self.act_dim)
        self.parameters = (list(self.conv1.parameters())) + (list(self.conv2.parameters())) + (list(self.conv3.parameters())) + (list(self.fc4.parameters())) + (list(self.fc5.parameters()))

    def forward(self, st):
        out = F.relu(self.conv1(st))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = out.view(out.size(0), -1)
        if self.dueling:
            val = F.relu(self.v_fc4(out))
            adv = F.relu(self.adv_fc4(out))
            val = self.v_fc5(val)
            adv = self.adv_fc5(adv)
            out = val.expand_as(adv) + adv - adv.mean(-1, keepdim=True).expand_as(adv)
        else:
            out = F.relu(self.fc4(out))
            out = self.fc5(out)
        return out

In [11]:
#@title Create a training FC agent
import torch.nn.functional as F

class DQNetworkFC(nn.Module):
    def __init__(self, z_dim, act_dim, dueling=False):
        super(DQNetworkFC, self).__init__()
        self.act_dim = act_dim
        self.input_dim = z_dim 
        self.dueling = dueling
        if self.dueling:
            self.v_fc1 = nn.Linear(z_dim, 512)
            self.adv_fc1 = nn.Linear(z_dim, 512)
            self.v_fc2 = nn.Linear(512, 1)
            self.adv_fc2 = nn.Linear(512, 256)
            self.v_fc3 = nn.Linear(256, 1)
            self.adv_fc3 = nn.Linear(256, self.act_dim)
        else:
            self.fc1 = nn.Linear(z_dim, 512)
            self.fc2 = nn.Linear(512, 256)
            self.fc3 = nn.Linear(256, self.act_dim)

    def forward(self, st):
        out = F.relu(self.fc1(st))
        out = F.relu(self.fc2(out))
        ''' Do we need a relu on the last layer if the output is probability over action space? '''
        out = F.relu(self.fc3(out))
        return out

In [12]:
def process_obs(obs, divide=False):
    obs = torch.Tensor(obs / 255. if divide else obs)
    if len(obs.shape) < 4:
        obs = obs.unsqueeze(0)
    obs = obs.permute(0, 3, 1, 2)
    return obs.to(device)

In [13]:
def take_action(env, action):
    state, rew, done, _ = env.step(action)
    obs = env.render(mode='rgb_array')
    return obs, rew, done

In [14]:
MAX_STEPS = 200

In [15]:
class CURL(nn.Module):
    """
    CURL
    """

    def __init__(self, obs_shape, z_dim, batch_size, encoder, output_type="continuous", critic=None, critic_target=None):
        super(CURL, self).__init__()
        self.obs_shape = obs_shape
        self.batch_size = batch_size

        # self.encoder = critic.encoder
        self.encoder = encoder 

        # self.encoder_target = critic_target.encoder 
        self.fc1 = nn.Linear(z_dim * 2, 50)
        self.fc2 = nn.Linear(50, 1)

        # self.W = nn.Parameter(torch.rand(z_dim, z_dim))
        self.output_type = output_type

    def encode(self, x, detach=False, ema=False):
        """
        Encoder: z_t = e(x_t)
        :param x: x_t, x y coordinates
        :return: z_t, value in r2
        """
        if ema:
            with torch.no_grad():
                z_out = self.encoder_target(x)
        else:
            z_out = self.encoder(x)

        if detach:
            z_out = z_out.detach()
        return z_out

    def compute_logits(self, z_a, z_mod):
        """
        Uses logits trick for CURL:
        - compute (B,B) matrix z_a (W z_pos.T)
        - positives are all diagonal elements
        - negatives are all other elements
        - to compute loss use multiclass cross entropy with identity matrix for labels
        """
#         Wz = torch.matmul(self.W, z_mod.T)  # (z_dim,B)
#         logits = torch.matmul(z_a, Wz)  # (B,B)
#         logits = logits - torch.max(logits, 1)[0][:, None]
#         return logits
        input_zs = torch.cat([z_a, z_mod], 1)
        logits = F.relu(self.fc1(input_zs))
        logits = F.sigmoid(self.fc2(logits))
        return logits

In [16]:
#@title Generate a batch of negatively labelled examples given observations

def generate_negatives(obs):
    neg_idx = np.random.randint(len(obs), size=len(obs))
    pos_idx = np.arange(len(obs))
    resample = (neg_idx == pos_idx)
    for (i, r) in enumerate(resample):
        if r:
            idx = neg_idx[i]
        else:
            idx = np.random.randint(0, len(obs), 1)[0]
            while idx == i:
                idx = np.random.randint(0, len(obs), 1)[0]
        neg_idx[i] = idx
    return (obs[neg_idx]).copy()


In [17]:
#@title Create a training agent (wrapper for conv agent)

GAMMA = 0.99

class Agent(object):
    def __init__(self, act_dim, z_dim, in_channels=3, max_epi_num=50, max_epi_len=300, CURL=None, aug=None, conv_net=False, random_shift=None):
        self.N_action = act_dim
        self.max_epi_num = max_epi_num
        self.max_epi_len = max_epi_len
        ''' To decide when to copy weights to the target network '''
        self.num_param_updates = 0
        self.CURL = CURL
        self.aug = aug
        self.random_shift = random_shift
        if conv_net:
            self.conv_net = DQNetworkConv(in_channels, act_dim).to(device)
            self.target = DQNetworkConv(in_channels, act_dim).to(device)
        else:
            ''' if using the encoder head for contrastive loss '''
            self.conv_net = DQNetworkFC(z_dim, act_dim).to(device)
            self.target = DQNetworkFC(z_dim, act_dim).to(device)
        self.buffer = ReplayMemory(max_epi_num=self.max_epi_num, max_epi_len=self.max_epi_len, obs_shape=CURL.obs_shape[:2])
        self.gamma = 0.99
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            list(self.conv_net.parameters()) + 
            list(self.CURL.parameters()) + 
            list(self.CURL.encoder.parameters()), lr=1e-3)

    def remember(self, state, action, reward, next_state):
        self.buffer.remember(state, next_state, action, reward)

    ''' Copy the weights to the target network every 100 updates '''
    def train(self, batch_size=32, target_update_freq=100, use_encoder=True):
        if self.buffer.is_available():
            obs, next_obs, action_list, reward_list = self.buffer.sample(batch_size)
            
            ''' Pass through the encoder to get encodings
             If also training the contrastive loss
             include that here! 
             1. data augmentation to create pos and negative pairs
             2. encoder 
             3. update encoder loss function (using a separate optimizer) or add to the loss computed below
            '''

            losses = []

            # check if obs is a numpy or a torch tensor
            if use_encoder:
                obs_anchor = self.random_shift(process_obs(obs.copy()))
                obs_pos = self.random_shift(process_obs(obs.copy()))
                mixed_obs = generate_negatives(obs)
                mixed_obs = process_obs(mixed_obs)
                obs_neg = self.random_shift(mixed_obs)
                if self.aug is not None:
                    obs_pos = self.aug(obs_pos)
                    obs_neg = self.aug(obs_neg)
                
                # TODO: separae 6 channels stacked frames into 2 sets of 3 channel inputs
                # obs.shape, (1, 6, 380, 580)
                z_a = self.CURL.encode(obs_anchor[:, :3, :, :])
                
                z_pos = self.CURL.encode(obs_pos[:, :3, :, :])
                
                # Mix pairs to generate negative labels
                z_neg = self.CURL.encode(obs_neg[:, :3, :, :])
                
                next_obs = self.random_shift(process_obs(next_obs.copy()))
                z_next = self.CURL.encode(next_obs[:, :3, :, :])

                # logits = self.CURL.compute_logits(z_a, z_pos)
                # labels = torch.arange(logits.shape[0]).long().to(device)
                
                # TODO: concatenate z_a1 and z_a2 into z_a ...
                pos_logits = self.CURL.compute_logits(z_a, z_pos)
                neg_logits = self.CURL.compute_logits(z_a, z_neg)
                # [32, 32]
                pos_labels = torch.ones((pos_logits.shape[0], 1))# .long()
                neg_labels = torch.zeros((neg_logits.shape[0], 1))# .long() 
                # TODO: stack pos and neg logits and labels (double check dim)
                logits = torch.cat([pos_logits, neg_logits], 0)
                labels = torch.cat([pos_labels, neg_labels], 0).to(device)
                
                # pass into the loss function
                # encoding_loss = nn.CrossEntropyLoss()(logits, labels)
                encoding_loss = nn.BCELoss()(logits, labels)

                ''' Combine encoding loss with rl loss below '''
                losses.append(encoding_loss)

                # Then pass that encoding through the conv_net to get Q value estimates
                Qs = self.conv_net(z_a)
                next_Qs = self.target(z_next).detach().max(1)[0]
            
            else:
                ''' If not using the encoder, pass the obs directly to the CNN '''
                obs = process_obs(obs)
                # estimate current q values from observations
                Qs = self.conv_net(obs)
                # find next max q values based on next observations
                next_Qs = self.target(next_obs).detach().max(1)[0]
            
            ''' find target q values ''' 
            next_Qs = next_Qs.cpu().numpy() 
            Qs = torch.gather(Qs, dim=1, index=torch.tensor(action_list, dtype=torch.int64).to(device)).float().to(device)
            target_Qs = torch.tensor(reward_list.squeeze(-1) + GAMMA * next_Qs).float().to(device)
            ''' try to set Qs equal to target_Qs '''
            q_loss = self.loss_fn(Qs, target_Qs)
            losses.append(q_loss)
            
            ''' Loss update for q network and encoder head '''
            losses = torch.stack(losses).sum()
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            self.num_param_updates += 1
            if self.num_param_updates % target_update_freq == 0:
                self.target.load_state_dict(self.conv_net.state_dict())

    # TODO: check the sizes of inputs and outputs
    def get_action(self, obs, epsilon, use_encoding=True):
        ''' 
         If using an encoder, need to pass that thorugh the encoder
         then use the encoding to pass through self.conv_net
        '''
        # obs = torch.tensor(obs)
        if use_encoding:
            obs = self.random_shift(process_obs(obs.copy()))
            # obs1 [1, 512] shape
            obs = self.CURL.encode(obs[:, :3, :, :], detach=True)
            # obs2 = self.CURL.encode(obs[:, 3:, :, :], detach=True)
            # TODO: separae 6 channels stacked frames into 2 sets of 3 channel inputs
            # Concatenate into [1, 1024]
            # obs = torch.hstack((obs1, obs2))

        # Dividing obs by 255 is handled in encoder forward (only needed for use_encoding=False)
        if len(obs.shape) == 1:
            obs = obs.unsqueeze(0)

        # epsilon greedy for selecting which action to take
        if random.random() > epsilon:
            qs = self.conv_net(obs)
            action = qs[0].argmax().data.item()
        else:
            action = random.randint(0, self.N_action-1)

        return action

def get_decay(epi_iter):
    decay = math.pow(0.99, epi_iter)
    return decay

In [18]:
def main(aug=None, train_curve_filename_prefix="default_curl_cartpole"):
    env = gym.make('CartPole-v0')
    env.reset()
    max_epi_iter = 1000
    max_MC_iter = 200
    obs = env.render(mode='rgb_array')
    obs_shape = obs.shape
    shift_by = 20
    random_shift = get_random_shift(*obs_shape[:2], shift_by)
    cropped_obs_shape = (obs_shape[0] - shift_by, obs_shape[1] - shift_by, obs_shape[2])
    
    ''' Replace the pixel encoder with the pretrained resnet18 encoder '''
    resnet18 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).to(device)
    # [0~61] 62 layers in total
    for (i, param) in enumerate(resnet18.parameters()):
        if i < 57:
            param.requires_grad = False
    
    z_dim = resnet18.fc.in_features
    resnet18.fc = nn.Flatten()
    CURL_encoder = CURL(obs_shape=obs_shape, z_dim=z_dim, batch_size=1, encoder=resnet18, output_type="continuous").to(device)
    agent = Agent(act_dim=env.action_space.n, z_dim=z_dim, max_epi_num=max_epi_iter, max_epi_len=max_MC_iter, CURL=CURL_encoder, aug=aug, random_shift=random_shift)
    train_curve = []
    train_steps = []
    exploration_rate = 0.8
    
    for epi_iter in range(max_epi_iter):
        random.seed()
        env.reset()
        obs = env.render(mode='rgb_array')
        ''' For step 0, copy the same observation frames 2 times '''
        # 400, 600, 6
        # stacked_obs = np.concatenate([obs, obs], -1)
        returns = 0.0
        steps = 0
        for MC_iter in range(max_MC_iter):
            exploration_rate *= get_decay(epi_iter)
            exploration_rate = min(0.1, exploration_rate)
            action = agent.get_action(obs, exploration_rate)
            next_obs, reward, done = take_action(env, action)
            returns += reward * agent.gamma ** (MC_iter)
            
            ''' Stack 2 observation frames for input and 2 for output'''
            # stacked_next = np.concatenate([stacked_obs[:,:,-3:], next_obs], -1)
            # agent.remember(stacked_obs, action, reward, stacked_next)
            # obs = next_obs.copy()
            # stacked_obs = stacked_next.copy()
            
            if done or MC_iter >= max_MC_iter-1:
                if MC_iter < max_MC_iter - 1:
                    # penalize for early termination
                    reward = 0 # - max_epi_iter
                steps = MC_iter
                
                
            # reward now includes the penalty for early termination
            # stacked_next = np.concatenate([stacked_obs[:,:,-3:], next_obs], -1)
            agent.remember(obs, action, reward, next_obs)
            obs = next_obs.copy()
            # stacked_obs = stacked_next.copy()
            
            if done:
                break 
                
        print('Episode', epi_iter, 'returns', returns, 'after', steps+1, 'timesteps')
        if epi_iter % 1 == 0:
            train_curve.append(returns)
            train_steps.append(steps)
        if epi_iter % 100 == 0:
            print(f"Saving at episode {epi_iter}")
            np.save(f'{train_curve_filename_prefix}_{max_MC_iter}MC_{max_epi_iter}steps_trainable_resnet18_no_penalty_no_stacking', np.array(train_steps))
            np.save(f'{train_curve_filename_prefix}_{max_MC_iter}MC_{max_epi_iter}returns_trainable_resnet18_no_penalty_no_stacking', np.array(train_curve))
        if agent.buffer.is_available():
            for _ in range(1):
                agent.train()
    env.close()
    np.save(f'{train_curve_filename_prefix}_{max_MC_iter}MC_{max_epi_iter}_steps_trainable_resnet18_no_penalty_no_stacking', np.array(train_curve))
    print(train_curve)

In [19]:
main()

Using cache found in /home/ubuntu/.cache/torch/hub/pytorch_vision_v0.10.0


Episode 0 returns 9.561792499119552 after 10 timesteps
Saving at episode 0
Episode 1 returns 9.561792499119552 after 10 timesteps
Episode 2 returns 8.64827525163591 after 9 timesteps
Episode 3 returns 8.64827525163591 after 9 timesteps
Episode 4 returns 8.64827525163591 after 9 timesteps
Episode 5 returns 11.361512828387072 after 12 timesteps
Episode 6 returns 9.561792499119552 after 10 timesteps
Episode 7 returns 7.72553055720799 after 8 timesteps
Episode 8 returns 8.64827525163591 after 9 timesteps
Episode 9 returns 9.561792499119552 after 10 timesteps
Episode 10 returns 7.72553055720799 after 8 timesteps
Episode 11 returns 8.64827525163591 after 9 timesteps
Episode 12 returns 9.561792499119552 after 10 timesteps
Episode 13 returns 7.72553055720799 after 8 timesteps
Episode 14 returns 7.72553055720799 after 8 timesteps
Episode 15 returns 8.64827525163591 after 9 timesteps
Episode 16 returns 9.561792499119552 after 10 timesteps
Episode 17 returns 8.64827525163591 after 9 timesteps
Epi

Episode 148 returns 8.64827525163591 after 9 timesteps
Episode 149 returns 9.561792499119552 after 10 timesteps
Episode 150 returns 9.561792499119552 after 10 timesteps
Episode 151 returns 8.64827525163591 after 9 timesteps
Episode 152 returns 9.561792499119552 after 10 timesteps
Episode 153 returns 8.64827525163591 after 9 timesteps
Episode 154 returns 8.64827525163591 after 9 timesteps
Episode 155 returns 9.561792499119552 after 10 timesteps
Episode 156 returns 9.561792499119552 after 10 timesteps
Episode 157 returns 9.561792499119552 after 10 timesteps
Episode 158 returns 8.64827525163591 after 9 timesteps
Episode 159 returns 9.561792499119552 after 10 timesteps
Episode 160 returns 8.64827525163591 after 9 timesteps
Episode 161 returns 10.466174574128356 after 11 timesteps
Episode 162 returns 9.561792499119552 after 10 timesteps
Episode 163 returns 10.466174574128356 after 11 timesteps
Episode 164 returns 9.561792499119552 after 10 timesteps
Episode 165 returns 8.64827525163591 afte

Episode 295 returns 9.561792499119552 after 10 timesteps
Episode 296 returns 9.561792499119552 after 10 timesteps
Episode 297 returns 8.64827525163591 after 9 timesteps
Episode 298 returns 9.561792499119552 after 10 timesteps
Episode 299 returns 9.561792499119552 after 10 timesteps
Episode 300 returns 9.561792499119552 after 10 timesteps
Saving at episode 300
Episode 301 returns 8.64827525163591 after 9 timesteps
Episode 302 returns 9.561792499119552 after 10 timesteps
Episode 303 returns 8.64827525163591 after 9 timesteps
Episode 304 returns 9.561792499119552 after 10 timesteps
Episode 305 returns 9.561792499119552 after 10 timesteps
Episode 306 returns 9.561792499119552 after 10 timesteps
Episode 307 returns 9.561792499119552 after 10 timesteps
Episode 308 returns 8.64827525163591 after 9 timesteps
Episode 309 returns 8.64827525163591 after 9 timesteps
Episode 310 returns 9.561792499119552 after 10 timesteps
Episode 311 returns 9.561792499119552 after 10 timesteps
Episode 312 returns

Episode 441 returns 9.561792499119552 after 10 timesteps
Episode 442 returns 8.64827525163591 after 9 timesteps
Episode 443 returns 9.561792499119552 after 10 timesteps
Episode 444 returns 9.561792499119552 after 10 timesteps
Episode 445 returns 9.561792499119552 after 10 timesteps
Episode 446 returns 8.64827525163591 after 9 timesteps
Episode 447 returns 8.64827525163591 after 9 timesteps
Episode 448 returns 8.64827525163591 after 9 timesteps
Episode 449 returns 8.64827525163591 after 9 timesteps
Episode 450 returns 9.561792499119552 after 10 timesteps
Episode 451 returns 10.466174574128356 after 11 timesteps
Episode 452 returns 8.64827525163591 after 9 timesteps
Episode 453 returns 9.561792499119552 after 10 timesteps
Episode 454 returns 7.72553055720799 after 8 timesteps
Episode 455 returns 9.561792499119552 after 10 timesteps
Episode 456 returns 9.561792499119552 after 10 timesteps
Episode 457 returns 8.64827525163591 after 9 timesteps
Episode 458 returns 8.64827525163591 after 9 t

Episode 588 returns 9.561792499119552 after 10 timesteps
Episode 589 returns 8.64827525163591 after 9 timesteps
Episode 590 returns 9.561792499119552 after 10 timesteps
Episode 591 returns 9.561792499119552 after 10 timesteps
Episode 592 returns 9.561792499119552 after 10 timesteps
Episode 593 returns 9.561792499119552 after 10 timesteps
Episode 594 returns 8.64827525163591 after 9 timesteps
Episode 595 returns 9.561792499119552 after 10 timesteps
Episode 596 returns 9.561792499119552 after 10 timesteps
Episode 597 returns 8.64827525163591 after 9 timesteps
Episode 598 returns 8.64827525163591 after 9 timesteps
Episode 599 returns 9.561792499119552 after 10 timesteps
Episode 600 returns 8.64827525163591 after 9 timesteps
Saving at episode 600
Episode 601 returns 8.64827525163591 after 9 timesteps
Episode 602 returns 7.72553055720799 after 8 timesteps
Episode 603 returns 9.561792499119552 after 10 timesteps
Episode 604 returns 8.64827525163591 after 9 timesteps
Episode 605 returns 8.648

Episode 735 returns 9.561792499119552 after 10 timesteps
Episode 736 returns 9.561792499119552 after 10 timesteps
Episode 737 returns 8.64827525163591 after 9 timesteps
Episode 738 returns 9.561792499119552 after 10 timesteps
Episode 739 returns 9.561792499119552 after 10 timesteps
Episode 740 returns 8.64827525163591 after 9 timesteps
Episode 741 returns 7.72553055720799 after 8 timesteps
Episode 742 returns 9.561792499119552 after 10 timesteps
Episode 743 returns 8.64827525163591 after 9 timesteps
Episode 744 returns 7.72553055720799 after 8 timesteps
Episode 745 returns 9.561792499119552 after 10 timesteps
Episode 746 returns 8.64827525163591 after 9 timesteps
Episode 747 returns 9.561792499119552 after 10 timesteps
Episode 748 returns 9.561792499119552 after 10 timesteps
Episode 749 returns 8.64827525163591 after 9 timesteps
Episode 750 returns 9.561792499119552 after 10 timesteps
Episode 751 returns 10.466174574128356 after 11 timesteps
Episode 752 returns 9.561792499119552 after 

Episode 882 returns 9.561792499119552 after 10 timesteps
Episode 883 returns 7.72553055720799 after 8 timesteps
Episode 884 returns 9.561792499119552 after 10 timesteps
Episode 885 returns 8.64827525163591 after 9 timesteps
Episode 886 returns 8.64827525163591 after 9 timesteps
Episode 887 returns 10.466174574128356 after 11 timesteps
Episode 888 returns 9.561792499119552 after 10 timesteps
Episode 889 returns 8.64827525163591 after 9 timesteps
Episode 890 returns 7.72553055720799 after 8 timesteps
Episode 891 returns 7.72553055720799 after 8 timesteps
Episode 892 returns 8.64827525163591 after 9 timesteps
Episode 893 returns 7.72553055720799 after 8 timesteps
Episode 894 returns 8.64827525163591 after 9 timesteps
Episode 895 returns 9.561792499119552 after 10 timesteps
Episode 896 returns 9.561792499119552 after 10 timesteps
Episode 897 returns 8.64827525163591 after 9 timesteps
Episode 898 returns 7.72553055720799 after 8 timesteps
Episode 899 returns 7.72553055720799 after 8 timeste

Code references for DQN:

https://github.com/taochenshh/dqn-pytorch

https://github.com/transedward/pytorch-dqn (for sampling from replay buffer)

CURL code: https://github.com/MishaLaskin/curl