<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/curiousity_lstm_back.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### setup

In [None]:
# # https://github.com/kimhc6028/pytorch-noreward-rl
# https://stackoverflow.com/questions/67808779/running-gym-atari-in-google-colab
%pip install -U gym
%pip install -U gym[atari,accept-rom-license]
# !pip install gym[box2d]
import gym

!pip install gym-super-mario-bros nes-py
# https://github.com/Kautenja/gym-super-mario-bros
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
# env = gym_super_mario_bros.make('SuperMarioBros-v0')
# env = JoypadSpace(env, SIMPLE_MOVEMENT)

!pip install colabgymrender
!pip install perceiver-pytorch

import torch
# Use the GPU when one is available; "cpu" otherwise. Other cells move tensors/models to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Weights & Biases logging switch; off by default (the wandb setup below is commented out).
# NOTE(review): `from math import pi, log` further down in this cell rebinds the name `log`,
# so after the cell has run `log` is the math function (truthy), not this flag.
# Consider renaming the flag before re-enabling any `if log:` code.
log=False
# !pip install wandb
# import wandb
# wandb.login() # 
# wandb.init(project="curiousity_simple", entity="bobdole")
# log=True

!pip install einops
from math import pi, log
from functools import wraps
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat



Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# functions

#### gym wrappers

In [None]:
# https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/2_gym_wrappers_saving_loading.ipynb
import gym
class SparseEnv(gym.Wrapper):
    """Wrapper that withholds rewards until the episode ends.

    Every step's reward from the wrapped env is accumulated internally. The
    wrapper reports 0 on every non-final step and the accumulated total on the
    step where `done` is true.

    Pattern adapted from https://alexandervandekleut.github.io/gym-wrappers/
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        self.total_rewards = 0  # running sum of the wrapped env's rewards this episode

    def step(self, action):
        """Step the wrapped env; return (observation, sparse_reward, done, info)."""
        obs, step_reward, done, info = self.env.step(action)
        self.total_rewards += step_reward
        # Only the terminal step carries a reward: the whole episode's sum.
        sparse_reward = self.total_rewards if done else 0
        return obs, sparse_reward, done, info

    def reset(self):
        """Clear the accumulated reward and reset the wrapped env."""
        self.total_rewards = 0
        return self.env.reset()
# env = SparseEnv(gym.make("LunarLander-v2"))

class MarioSparse(gym.Wrapper):
    """Wrapper that reports `info['score']` as the reward and ends the episode on death.

    The wrapped env's own reward is discarded. The episode is forced to end as
    soon as `info['life']` drops below 2.
    """

    def __init__(self, env):
        super(MarioSparse, self).__init__(env)
        self.env = env
        # NOTE(review): accumulated in step() but never read inside this class.
        self.total_score = 0

    def step(self, action):
        """Step the wrapped env; return (observation, info['score'], done, info)."""
        observation, _env_reward, done, info = self.env.step(action)
        lives_left = info['life']
        score = info['score']
        self.total_score += score
        died = lives_left < 2
        if died:
            print("MarioSparse: died")
            done = True  # treat the lost life as the end of the episode
        return observation, score, done, info

    def reset(self):
        """Clear the score accumulator and reset the wrapped env."""
        self.total_score = 0
        return self.env.reset()
# env = MarioSparse(env)

class MarioEarlyStop(gym.Wrapper):
    """Wrapper that ends the episode when `info['x_pos']` stops improving.

    Tracks the largest `x_pos` seen this episode and the number of consecutive
    steps without a new maximum. Once that count exceeds 500 the step is
    reported as `done`.
    """

    def __init__(self, env):
        super(MarioEarlyStop, self).__init__(env)
        self.env = env
        self.max_pos = 0      # largest x_pos seen since the last reset
        self.count_step = 0   # consecutive steps without beating max_pos

    def step(self, action):
        """Step the wrapped env; force `done` after more than 500 stalled steps."""
        observation, reward, done, info = self.env.step(action)
        x_pos = info['x_pos']
        if x_pos > self.max_pos:
            # New furthest position: remember it and restart the stall counter.
            self.max_pos = x_pos
            self.count_step = 0
        else:
            self.count_step += 1
        if self.count_step > 500:
            print("MarioEarlyStop: early stop ", self.max_pos)
            done = True
        return observation, reward, done, info

    def reset(self):
        """Clear the progress trackers and reset the wrapped env."""
        self.max_pos = 0
        self.count_step = 0
        return self.env.reset()
# env = MarioEarlyStop(env)


#### cnn

In [None]:

class Conv_Encoder(nn.Module):
    """Small CNN that compresses an image into a single 16x16 feature map.

    The adaptive average pool in the middle forces a 64x64 map, and the two
    stride-2 convolutions after it halve that twice. The output therefore always
    has one channel of 16x16 (256 values when flattened), whatever the input
    height and width.

    Args:
        in_channels: number of channels of the input image (default 1).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Layer order is kept stable so the Sequential indices (state_dict keys) do not change.
        stages = [
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
            nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2),
            nn.ELU(),
            nn.AdaptiveAvgPool2d((64, 64)),  # fixes the spatial size to 64x64
            nn.Conv2d(16, 8, kernel_size=7, stride=2, padding=3),  # 64 -> 32
            nn.ELU(),
            nn.Conv2d(8, 1, kernel_size=5, stride=2, padding=2),  # 32 -> 16
            nn.ELU(),
        ]
        self.conv_encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode `x` of shape [..., in_channels, H, W] into [..., 1, 16, 16]."""
        return self.conv_encoder(x)


# models

#### model simpler

In [None]:
# model.py
# https://github.com/kimhc6028/pytorch-noreward-rl/blob/master/model.py
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorCritic(torch.nn.Module):
    """Actor-critic network with an LSTM policy core and LSTM-based ICM heads.

    NOTE: a later cell in this notebook defines another class with the same
    name; whichever cell runs last is the one the rest of the notebook uses.

    Args:
        in_shape: indexable shape descriptor. `in_shape[0]` is used as the number
            of input channels of the conv stack and `in_shape[1]` as the size of
            the flattened conv feature vector fed to the LSTM cells.
        action_space: object exposing `.n`, the number of discrete actions.
    """

    def __init__(self, in_shape, action_space):
        super(ActorCritic, self).__init__()
        self.in_dim = in_shape
        hidden_size = 256  # hidden/cell size of every LSTMCell below
        self.conv = nn.Sequential( # embed pi
            nn.Conv2d(in_shape[0], 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            )
        self.lstm = nn.LSTMCell(in_shape[1], hidden_size)
        num_outputs = action_space.n
        self.critic_linear = nn.Linear(hidden_size, 1) # -> value
        self.actor_linear = nn.Linear(hidden_size, num_outputs) # -> action logits

        self.inv_lstm = nn.LSTMCell(in_shape[1], hidden_size)
        self.fwd_lstm = nn.LSTMCell(in_shape[1], hidden_size)
        # Input widths follow what forward() concatenates: an LSTM hidden state plus
        # either the encoded next state or the one-hot action. (Equal to the previous
        # `in_shape[1] + ...` widths whenever in_shape[1] == 256.)
        self.inv_linear = nn.Sequential( # inverse model: predict the action taken
            nn.Linear(hidden_size + in_shape[1], 256), nn.ReLU(),
            # dim=1 is what the implicit-dim Softmax picked for 2-D input; made explicit.
            nn.Linear(256, num_outputs), nn.Softmax(dim=1)
            )
        self.fwd_linear = nn.Sequential( # forward model: predict the encoded next state
            nn.Linear(hidden_size + num_outputs, 256), nn.ReLU(),
            nn.Linear(256, in_shape[1])
            )

    def forward(self, inputs, icm):
        """Run either the policy path or the ICM path.

        Args:
            inputs: when `icm` is falsy, `(st, (hx, cx))`; when truthy,
                `((inv_hx, inv_cx), (fwd_hx, fwd_cx), st1, at)` where `at` is the
                one-hot action with shape [batch, num_actions].
            icm: selects the ICM path when truthy, the actor-critic path otherwise.

        Returns:
            Policy path: `(critic, actor_logits, (hx1, cx1))`.
            ICM path: `(vec_st1, inverse, forward, (inv_hx1, inv_cx1), (fwd_hx1, fwd_cx1))`
            where `inverse` is a probability distribution over actions and
            `forward` is the predicted encoding of `st1`.
        """
        if not icm: # A3C
            st, (a3c_hx, a3c_cx) = inputs
            vec_st = self.conv(st).view(-1, self.in_dim[1])
            a3c_hx1, a3c_cx1 = self.lstm(vec_st, (a3c_hx, a3c_cx))
            critic = self.critic_linear(a3c_hx1)
            actor = self.actor_linear(a3c_hx1)
            return critic, actor, (a3c_hx1, a3c_cx1)

        # ICM
        (inv_hx, inv_cx), (fwd_hx, fwd_cx), st1, at = inputs
        vec_st1 = self.conv(st1).view(-1, self.in_dim[1])
        # Fix: the original referenced undefined `icm_hx`/`icm_cx` here (NameError);
        # each ICM LSTM now advances from its own incoming state.
        inv_hx1, inv_cx1 = self.inv_lstm(vec_st1, (inv_hx, inv_cx)) # inverse model state
        fwd_hx1, fwd_cx1 = self.fwd_lstm(vec_st1, (fwd_hx, fwd_cx)) # forward model state

        # The heads read the *incoming* hidden states (history before st1), as the
        # original `icm_hx` did, so the forward head does not see st1 itself.
        inv_vec = torch.cat((inv_hx, vec_st1), 1) # -> predict at
        fwd_vec = torch.cat((fwd_hx, at), 1) # -> predict vec_st1
        inverse = self.inv_linear(inv_vec)
        forward = self.fwd_linear(fwd_vec)
        return vec_st1, inverse, forward, (inv_hx1, inv_cx1), (fwd_hx1, fwd_cx1)


#### back lstm

In [None]:

class ActorCritic(torch.nn.Module):
    """Actor-critic + ICM network that encodes single frames with `Conv_Encoder`.

    A frame is encoded to a 256-vector, fed through an LSTM cell, and the hidden
    state drives the actor and critic heads. Two further LSTM cells keep separate
    recurrent state for the inverse and forward (curiosity) models.

    NOTE: this redefines the `ActorCritic` name declared in an earlier cell.

    Args:
        in_shape: observation shape; `in_shape[2]` is used as the channel count
            (observations are indexed as height, width, channels in `conv_encode`).
        action_space: object exposing `.n`, the number of discrete actions.
    """

    def __init__(self, in_shape, action_space):
        super(ActorCritic, self).__init__()
        self.in_dim = in_shape # mario (240, 256, 3)
        phist_size = 256   # length of the flattened Conv_Encoder output
        hidden_size = 512  # hidden/cell size of every LSTMCell below

        # Fix: the original also built `self.conv = Conv(in_shape[2])`. `Conv` is not
        # defined anywhere in this notebook (NameError on a fresh kernel) and the module
        # was never used in forward(), so it is dropped. Checkpoints saved with the old
        # class may carry extra `conv.*` keys; load those with strict=False.
        self.conv_encoder = Conv_Encoder(in_shape[2])
        num_outputs = action_space.n
        self.lstmcell = nn.LSTMCell(phist_size, hidden_size)

        self.actor_linear = nn.Linear(hidden_size, num_outputs) # hidden -> action logits
        self.critic_linear = nn.Linear(hidden_size, 1) # hidden -> value

        self.inv_lstm = nn.LSTMCell(phist_size, hidden_size)
        self.fwd_lstm = nn.LSTMCell(phist_size, hidden_size)

        self.inv_linear = nn.Sequential( # inverse model: predict the action taken
            nn.Linear(hidden_size + phist_size, 512), nn.ReLU(),
            # dim=1 is what the implicit-dim Softmax picked for 2-D input; made explicit.
            nn.Linear(512, num_outputs), nn.Softmax(dim=1)
            )
        self.fwd_linear = nn.Sequential( # forward model: predict the encoded next state
            nn.Linear(hidden_size + num_outputs, 512), nn.ReLU(),
            nn.Linear(512, phist_size)
            )

    def conv_encode(self, st):
        """Encode one (height, width, channels) frame into a [1, 256] feature row."""
        st = st.permute(2, 0, 1) # -> (channels, height, width), as Conv2d expects
        phist = self.conv_encoder(st) # unbatched input -> [1, 16, 16]
        return phist.reshape(1, -1)

    def forward(self, inputs, icm):
        """Run either the policy path or the ICM path.

        Args:
            inputs: when `icm` is falsy, `(st, (hx, cx))`; when truthy,
                `((inv_hx, inv_cx), (fwd_hx, fwd_cx), st1, at)` where `at` is the
                one-hot action with shape [1, num_actions].
            icm: selects the ICM path when truthy, the actor-critic path otherwise.

        Returns:
            Policy path: `(critic, actor_logits, (hx1, cx1))`.
            ICM path: `(phist, inverse, forward, (inv_hx1, inv_cx1), (fwd_hx1, fwd_cx1))`
            with `phist` the [1, 256] encoding of `st1`, `inverse` a probability
            distribution over actions and `forward` the predicted encoding.
        """
        if not icm: # A3C
            st, (a3c_hx, a3c_cx) = inputs
            phist = self.conv_encode(st)
            vec_st, a3c_cx1 = self.lstmcell(phist, (a3c_hx, a3c_cx))
            critic = self.critic_linear(vec_st)
            actor = self.actor_linear(vec_st)
            return critic, actor, (vec_st, a3c_cx1)

        # ICM
        (inv_hx, inv_cx), (fwd_hx, fwd_cx), st1, at = inputs
        phist = self.conv_encode(st1) # [1, 256]

        inv_hx1, inv_cx1 = self.inv_lstm(phist, (inv_hx, inv_cx)) # inverse model state
        fwd_hx1, fwd_cx1 = self.fwd_lstm(phist, (fwd_hx, fwd_cx)) # forward model state
        # NOTE(review): both heads read the *updated* hidden states, which have already
        # consumed phist; the forward head is therefore predicting a vector its input
        # has seen. Behaviour kept as in the original — confirm this is intended.
        inv_vec = torch.cat((inv_hx1, phist), 1) # -> predict at
        fwd_vec = torch.cat((fwd_hx1, at), 1) # -> predict phist
        inverse = self.inv_linear(inv_vec)
        forward = self.fwd_linear(fwd_vec)
        return phist, inverse, forward, (inv_hx1, inv_cx1), (fwd_hx1, fwd_cx1)


# wwwwwwwwwwwww

#### train

In [None]:
# train.py
# https://github.com/kimhc6028/pytorch-noreward-rl/blob/master/train.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

def train(env, args, model, optimizer=None):
    """Train `model` with an actor-critic loss plus an ICM (curiosity) loss.

    Runs `num_episodes` episodes. Each episode is rolled out until the env
    reports `done` or `max_episode_length` steps have elapsed; one optimizer
    step is then taken on the losses summed over that episode.

    Hyper-parameters are read from module-level globals at call time:
    `num_episodes`, `max_episode_length`, `device`, `lr`, `gamma`, `tau`,
    `eta`, `beta`, `lmbda`.

    Args:
        env: environment whose `step` returns `(obs, reward, done, info)` and
            whose `action_space` exposes `.n`.
        args: unused; kept so existing call sites keep working.
        model: network called as `model((state, latent), icm=False)` and
            `model((inv_latent, fwd_latent, state, one_hot_action), icm=True)`.
        optimizer: optimizer over `model`'s parameters. When None, an Adam
            optimizer is created for `model` with learning rate `lr`.

    Returns:
        None. Progress is printed per episode.
    """
    if optimizer is None:
        # Fix: optimise the model that was passed in, not the global `shared_model`.
        optimizer = torch.optim.Adam(model.parameters(), lr)
    model.train()
    for _episode in range(num_episodes):
        # Fresh (h, c) recurrent state for the policy LSTM and the two ICM LSTMs.
        latent = (torch.zeros(1, 512).to(device), torch.zeros(1, 512).to(device))
        inv_latent = (torch.zeros(1, 512).to(device), torch.zeros(1, 512).to(device))
        fwd_latent = (torch.zeros(1, 512).to(device), torch.zeros(1, 512).to(device))
        values = []
        log_probs = []
        rewards = []
        entropies = []
        inverses = []
        forwards = []
        actions = []
        vec_st1s = []
        episode_length = 0

        state = env.reset()
        state = torch.tensor(state.copy()).type(torch.float).to(device)
        while True:
            episode_length += 1
            value, logit, latent = model((state, latent), icm = False)
            prob = F.softmax(logit, dim=1)
            log_prob = F.log_softmax(logit, dim=1)
            entropy = -(log_prob * prob).sum(1) # entropy of the action distribution
            entropies.append(entropy.cpu())
            action = prob.multinomial(1).detach() # sample an action; no gradient through sampling
            log_prob = log_prob.gather(1, action)
            oh_action = torch.zeros(1, env.action_space.n)
            oh_action[0][action.item()] = 1.0
            actions.append(oh_action)
            state, reward, done, _ = env.step(action.item())
            state = torch.tensor(state.copy()).type(torch.float).to(device)
            done = done or episode_length >= max_episode_length
            # ICM pass on the state reached by the action just taken.
            vec_st1, inverse, forward, inv_latent, fwd_latent = model((inv_latent, fwd_latent, state, oh_action.to(device)), icm = True)
            # Curiosity bonus: scaled squared error between encoded and predicted state.
            reward_intrinsic = eta * ((vec_st1 - forward).pow(2)).sum(1) / 2.
            reward += reward_intrinsic.item()
            values.append(value.cpu()) # critic estimate
            log_probs.append(log_prob.cpu())
            rewards.append(reward) # extrinsic + intrinsic reward
            vec_st1s.append(vec_st1.cpu()) # encoded state
            inverses.append(inverse.cpu()) # predicted action distribution
            forwards.append(forward.cpu()) # predicted encoded state
            if done:
                print(episode_length)
                episode_length = 0
                break

        # The episode is treated as terminal: bootstrap value is zero.
        R = torch.zeros(1, 1)
        values.append(R)
        policy_loss = 0
        value_loss = 0
        inverse_loss = 0
        forward_loss = 0
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(rewards))):
            R = gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)
            # Generalized Advantage Estimation
            delta_t = torch.tensor(rewards[i]) + gamma * values[i + 1].detach() - values[i].detach()
            gae = gae * gamma * tau + delta_t
            policy_loss = policy_loss - log_probs[i] * gae - 0.01 * entropies[i]
            # Inverse model: cross-entropy between the taken action and its prediction.
            cross_entropy = - (actions[i] * torch.log(inverses[i] + 1e-15)).sum(1)
            inverse_loss = inverse_loss + cross_entropy
            # Forward model: squared error of the predicted encoding.
            forward_err = forwards[i] - vec_st1s[i]
            forward_loss = forward_loss + 0.5 * (forward_err.pow(2)).sum(1)

        optimizer.zero_grad()
        inv_loss = (1-beta) * inverse_loss + beta * forward_loss
        pol_loss = lmbda * (policy_loss + 0.5 * value_loss)
        (inv_loss + pol_loss).backward()
        print([torch.argmax(a).item() for a in actions])
        print("inv_loss: ", inv_loss.item(), " ,pol_loss: ", pol_loss.item())
        # if log:
        #     wandb.log({"inv_loss": inv_loss.item(), "pol_loss": pol_loss.item()})
        # Fix: `clip_grad_norm` is the deprecated spelling; use the in-place `clip_grad_norm_`.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 40)
        optimizer.step()
        # Drop references to this episode's tensors before the next rollout starts.
        del inv_loss, pol_loss, state
        del value, logit, latent, vec_st1, inverse, forward, inv_latent, fwd_latent
        del values, log_probs, rewards, entropies, inverses, forwards, actions, vec_st1s


#### test

In [None]:
# test.py
# https://github.com/kimhc6028/pytorch-noreward-rl/blob/master/test.py
import numpy as np
import torch
import torch.nn.functional as F
import time

def test(env, args, model):
    """Roll out one evaluation episode with `model` and save its weights.

    Actions are sampled from the policy's softmax distribution. The episode
    ends when the env reports `done` or after `max_episode_length` steps. At the
    end the elapsed time, total reward, episode length and the action sequence
    are printed, and `model.state_dict()` is written to 'model.pth'.

    Reads the module-level globals `device` and `max_episode_length`.

    Args:
        env: environment whose `step` returns `(obs, reward, done, info)`.
        args: unused; kept so existing call sites keep working.
        model: network called as `model((state, (hx, cx)), icm=False)`.
            It is left in eval mode.

    Returns:
        None.
    """
    model.eval()
    state = env.reset()
    state = torch.tensor(state.copy()).type(torch.float).to(device)
    reward_sum = 0
    start_time = time.time()
    actions = []
    episode_length = 0
    a3c_hx = torch.zeros(1, 512).to(device)
    a3c_cx = torch.zeros(1, 512).to(device)
    # Fix: evaluate without autograd. Previously the recurrent state carried a graph
    # that grew with every step of the episode even though nothing was back-propagated.
    with torch.no_grad():
        while True:
            episode_length += 1
            value, logit, (a3c_hx, a3c_cx) = model((state, (a3c_hx, a3c_cx)), icm = False)
            prob = F.softmax(logit, dim=1)
            action = prob.multinomial(1)
            state, reward, done, _ = env.step(action.item())
            state = torch.tensor(state.copy()).type(torch.float).to(device)

            done = done or episode_length >= max_episode_length
            reward_sum += reward
            actions.append(action.item())
            if done:
                end_time = time.time()
                print("Time {}, episode reward {}, episode length {}".format(
                    time.strftime("%Hh %Mm %Ss", time.gmtime(end_time - start_time)), reward_sum, episode_length))
                torch.save(model.state_dict(), 'model.pth')
                print(actions)
                break


#### main

In [None]:
# main.py
# https://github.com/kimhc6028/pytorch-noreward-rl/blob/master/main.py
# import os, sys, cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym

# Hyper-parameters and run configuration. train() and test() read several of these
# as module-level globals, so this cell must run before either is called.
lr=0.01 # Adam learning rate (alternative: 0.001)
gamma=1 # discount factor used for returns and GAE in train() (alternative: 0.99)
tau=1.00 # GAE smoothing factor in train()
seed=1 # passed to torch.manual_seed below
num_processes=4 # only referenced by the commented-out multiprocessing code
num_steps=20 # not referenced by the code visible in this notebook
max_episode_length=500 # step cap per episode in train()/test(); reassigned in the run section (alternative: 10000)
env_name='PongDeterministic-v4'
# env_name='LunarLander-v2'
# env_name='MontezumaRevengeDeterministic-v4'
# env_name='MontezumaRevengeDeterministic-ram-v4'

no_shared=False # when True, no optimizer is built here and train() creates its own
eta=0.01 # scale of the intrinsic (curiosity) reward in train()
beta=0.2 # weight of the forward loss vs. the inverse loss in train()
lmbda=0.1 # weight of the policy/value loss relative to the ICM loss in train()
outdir="output" # not referenced by the code visible in this notebook
record='store_true' # not referenced by the code visible in this notebook
num_episodes=20 # episodes per train() call (alternative: 100)

# Seeds torch only; the environment itself is not seeded here.
torch.manual_seed(seed)
env = gym.make(env_name)
env = SparseEnv(env) # reward is withheld until the end of each episode
# env = gym_super_mario_bros.make('SuperMarioBros-v0')
# env = JoypadSpace(env, COMPLEX_MOVEMENT) # SIMPLE_MOVEMENT COMPLEX_MOVEMENT
# env = MarioSparse(env)
# env = MarioEarlyStop(env)
# query_environment("MountainCar-v0")

# print(env)
print(env.observation_space.shape, env.action_space) # (210, 160, 3) Discrete(18); mario complex (240, 256, 3) Discrete(12)

# `ActorCritic` resolves to whichever model cell was executed most recently.
shared_model = ActorCritic(env.observation_space.shape, env.action_space).to(device)
# shared_model.share_memory()
if no_shared:
    optimizer = None
else:
    optimizer = torch.optim.Adam(shared_model.parameters(), lr=lr)
    # optimizer.share_memory()
args=None # placeholder for the `args` parameter of train()/test(), which they do not use
# train(0, args, shared_model, optimizer)

# processes = []
# import torch.multiprocessing as mp
# p = mp.Process(target=test, args=(num_processes, args, shared_model))
# p.start()
# processes.append(p)
# for rank in range(0, num_processes):
#     p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
#     p.start()
#     processes.append(p)
# for p in processes:
#     p.join()



(210, 160, 3) Discrete(6)


#### run

In [None]:
max_episode_length=2000 # replaces the 500 set in the main cell; train()/test() read this global at call time (alternative: 10000)


In [None]:
# train(env, args, shared_model)
# train(env, args, shared_model, optimizer)

# Each train() call runs `num_episodes` episodes, so this loop trains for
# 25 * num_episodes episodes in total, reusing the same optimizer throughout.
for x in range(25):
    train(env, args, shared_model, optimizer)
# One evaluation episode; test() also writes the weights to 'model.pth'.
test(env, args, shared_model)


# early stop ard = self.fwd_lstm(torch.cat((phi


765
[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

KeyboardInterrupt: ignored

#### save

In [None]:

from google.colab import drive
# Mount Google Drive so checkpoints can be written under PATH.
drive.mount('/content/gdrive')
PATH="/content/gdrive/MyDrive/curious/" # for saving to google drive
name='model_mario_lstm_back.pth' # checkpoint file name, appended to PATH
# PATH="/content/" # for saving on colab only
# name='model.pth'

# Save / load are left commented out; uncomment the pair you need.
# model=shared_model
# torch.save(model.state_dict(), PATH+name)

# model.load_state_dict(torch.load(PATH+name))
# shared_model=model


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


#### video

In [None]:

# Build the environment used for the recorded rollout and reset it.
import gym
from colabgymrender.recorder import Recorder
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

# "MontezumaRevengeDeterministic-v4"
# `env_name` is not defined in this cell; it must already exist in the kernel.
env = gym.make(env_name)
# SparseEnv (defined in the "gym wrappers" section) returns reward 0 on every step
# and the accumulated episode total on the step where `done` is true.
env = SparseEnv(env)
# Disabled alternative: Super Mario Bros with a joypad action set and Mario-specific wrappers.
# env = gym_super_mario_bros.make('SuperMarioBros-v0')
# env = JoypadSpace(env, COMPLEX_MOVEMENT) # SIMPLE_MOVEMENT COMPLEX_MOVEMENT
# env = MarioSparse(env)
# env = MarioEarlyStop(env)
# Outermost wrapper: colabgymrender's Recorder.
# NOTE(review): './video' is presumably the directory the recording is written to -- confirm.
env = Recorder(env, './video')

state = env.reset()
# Disabled: rebuild an ActorCritic on CPU from shared_model's weights to drive the rollout with the policy.
# device='cpu'
# model = ActorCritic(env.observation_space.shape, env.action_space)#.to(device)
# model.load_state_dict(shared_model.state_dict())
# model.eval()
# latent = None
# torch.manual_seed(6)
# Step counter: indexes into `acts` in the replay loop below and is printed at the end.
x=0

# acts=[10, 2, 0, 5, 10, 0, 1, 8, 2, 5, 3, 8, 1, 9, 0, 0, 0, 2, 5, 0, 11, 6, 0, 3, 1, 2, 0, 6, 3, 1, 1, 8, 8, 1, 11, 2, 8, 8, 10, 1, 3, 10, 9, 11, 5, 2, 8, 9, 4, 7, 8, 0, 0, 6, 11, 5, 10, 9, 10, 1, 10, 7, 0, 6, 3, 10, 9, 11, 10, 9, 7, 3, 2, 0, 1, 7, 6, 3, 2, 0, 6, 2, 8, 11, 4, 1, 6, 11, 5, 11, 10, 0, 9, 1, 7, 3, 10, 6, 1, 0, 6, 8, 0, 0, 11, 9, 2, 6, 3, 0, 7, 10, 1, 3, 1, 9, 1, 3, 0, 3, 10, 10, 6, 1, 3, 6, 9, 9, 2, 7, 3, 10, 11, 11, 5, 10, 0, 1, 1, 1, 4, 9, 10, 0, 6, 7, 5, 2, 8, 0, 7, 8, 0, 6, 8, 7, 10, 10, 3, 2, 0, 9, 10, 6, 1, 3, 8, 8, 6, 0, 10, 1, 4, 9, 3, 7, 3, 6, 0, 3, 6, 5, 10, 0, 10, 0, 3, 1, 7, 10, 5, 0, 0, 1, 5, 9, 4, 3, 0, 0, 5, 8, 7, 3, 0, 0, 7, 10, 7, 1, 3, 2, 0, 3, 5, 7, 4, 6, 3, 1, 1, 7, 10, 10, 6, 11, 3, 1, 10, 1, 1, 0, 9, 0, 8, 10, 6, 7, 6, 0, 9, 3, 3, 0, 1, 6, 10, 1, 8, 0, 2, 0, 5, 0, 9, 4, 2, 5, 0, 3, 0, 8, 5, 0, 0, 8, 0, 1, 1, 9, 5, 3, 2, 4, 5, 1, 6, 5, 2, 11, 11, 11, 4, 2, 7, 6, 0, 6, 4, 1, 0, 10, 1, 3, 3, 9, 3, 3, 4, 3, 5, 1, 5, 10, 9, 6, 7, 1, 0, 9, 8, 11, 11, 0, 10, 0, 8, 10, 7, 5, 0, 6, 10, 3, 2, 3, 0, 1, 1, 5, 1, 0, 0, 6, 0, 8, 2, 8, 11, 2, 3, 0, 6, 1, 8, 3, 1, 9, 6, 7, 2, 1, 0, 8, 6, 3, 10, 4, 0, 2, 2, 2, 0, 10, 1, 3, 10, 1, 0, 3, 6, 10, 0, 3, 2, 7, 1, 5, 6, 7, 0, 1, 1, 9, 9, 9, 9, 9, 8, 0, 7, 9, 11, 1, 1, 0, 0, 0, 0, 3, 0, 1, 9, 5, 1, 5, 9, 1, 11, 1, 5, 6, 1, 1, 6, 5, 4, 1, 3, 0, 1, 6, 10, 3, 3, 0, 1, 0, 2, 1, 3, 3, 11, 6, 9, 2, 2, 9, 2, 3, 9, 9, 9, 10, 11, 3, 10, 0, 0, 3, 0, 1, 6, 7, 0, 9, 2, 8, 10, 4, 7, 11, 0, 1, 8, 2, 1, 6, 0, 4, 6, 0, 10, 8, 7, 10, 10, 6, 5, 10, 0, 0, 2, 0, 11, 1, 0, 0, 9, 4, 10, 1, 0, 6, 1, 4, 9, 10, 3, 0, 4, 9, 8, 0, 5, 8, 3, 6, 0, 6, 0, 0, 10, 1, 0, 3, 0, 0, 2, 3, 1, 1, 11, 3, 5, 11, 9, 8, 3, 0, 6, 3, 7, 2, 1, 11, 1, 6, 0, 10, 1, 7, 10, 8, 3, 8, 7, 6, 0, 9, 10, 0, 0, 5, 0, 10, 9, 0, 6, 0, 8, 11, 6, 9, 4, 7, 0, 0, 5, 6, 10, 11, 2, 4, 10, 1, 3, 10, 11, 9, 1, 8, 3, 11, 1, 11, 1, 11, 6, 1, 10, 3, 1, 2, 4, 10, 1, 3, 11, 4, 1, 10, 10, 4, 10, 0, 0, 0, 0, 8, 11, 6, 4, 6, 0, 6, 9, 0, 6, 10, 9, 1, 0, 3, 3, 6, 8, 6, 1, 9, 10, 
8, 11, 9, 1, 0, 1, 0, 8, 9, 1, 9, 10, 3, 8, 5, 1, 9, 5, 10, 0, 9, 0, 10, 0, 2, 8, 0, 2, 6, 0, 5, 6, 0, 10, 10, 4, 4, 0, 6, 0, 11, 0, 9, 0, 8, 10, 9, 1, 8, 4, 10, 10, 0, 3, 8, 10, 0, 0, 7, 10, 10, 8, 3, 8, 4, 6, 8, 0, 5, 0, 9, 0, 7, 9, 8, 8, 9, 9, 0, 9, 4, 8, 10, 8, 9, 9, 0, 7, 3, 4, 0, 10, 6, 10, 3, 9, 4, 6, 8, 10, 0, 1, 5, 10, 10, 3, 9, 0, 0, 1, 8, 6, 7, 8, 6, 3, 6, 8, 0, 4, 0, 0, 10, 8, 11, 2, 3, 0, 2, 9, 0, 1, 0, 3, 5, 9, 0, 6, 9, 0, 10, 0, 0, 3, 11, 3, 2, 5, 2, 7, 11, 2, 5, 4, 10, 1, 0, 3, 8, 9, 10, 1, 1, 10, 5, 2, 10, 2, 1, 10, 4, 6, 7, 4, 11, 9, 4, 0, 2, 1, 6, 11, 1, 2, 0, 6, 1, 8, 6, 9, 10, 1, 0, 1, 4, 1, 1, 8, 1, 2, 3, 10, 6, 2, 10, 10, 8, 1, 11, 9, 0, 10, 1, 0, 2, 7, 9, 1, 3, 0, 8, 10, 8, 0, 1, 1, 0, 9, 10, 11, 0, 2, 6, 1, 6, 3, 8, 9, 6, 10, 8, 5, 0, 10, 10, 5, 0, 9, 1, 9, 7, 9, 0, 6, 1, 9, 1, 0, 5, 0, 0, 11, 0, 8, 1, 7, 7, 1, 6, 0, 8, 3, 5, 6, 9, 0, 2, 10, 3, 8, 10, 6, 0, 1, 0, 5, 9, 1, 0, 8, 4, 3, 11, 5, 1, 1, 11, 0, 6, 4, 10, 5, 5, 9, 0, 0, 8, 11, 4, 0, 6, 1, 3, 9, 0, 0, 7, 0, 11, 6, 10, 1, 3, 0, 0, 2, 7, 9, 10, 10, 3, 8, 11, 7, 1, 6, 0, 0, 6, 5, 10, 3, 0, 0, 0, 5, 3, 10, 6, 9, 1, 1, 0, 3, 9, 3, 1, 3, 9, 11, 10, 10, 1, 0, 10, 6, 6, 1, 7, 8, 3, 6, 11, 8, 3, 8, 1, 9, 2, 5, 0, 6, 9, 10, 6, 3, 1, 7, 0, 6, 3, 5, 0, 11, 0, 0, 4, 6, 1, 5, 10, 0, 1, 0, 7, 7, 8, 0, 1, 10, 0, 0, 3, 11, 3, 3, 0, 0, 8, 0, 3, 10, 0, 8, 8, 1, 10, 4, 2, 1, 10, 5, 5, 3, 9, 11, 9, 6, 2, 10, 0, 8, 10, 6, 7, 1, 0, 9, 0, 3, 5, 9, 10, 9, 2, 8, 6, 3, 8, 0, 8, 6, 3, 10, 1, 3, 5, 9, 8, 10, 9, 3, 10, 1, 10, 0, 6, 9, 10, 0, 3, 0, 10, 2, 4, 0, 9, 11, 0, 0, 0, 5, 8, 5, 6, 6, 8, 3, 6, 9, 6, 3, 0, 1, 10, 10, 6, 0, 9, 0, 3, 0, 11, 1, 7, 0, 1, 4, 10, 10, 9, 10, 2, 11, 9, 0, 7, 0, 8, 1]
# acts=[3, 1, 1, 2, 4, 10, 2, 1, 0, 4, 10, 4, 10, 11, 0, 0, 2, 0, 10, 0, 10, 0, 10, 2, 3, 0, 3, 4, 1, 10, 10, 10, 0, 10, 4, 0, 0, 10, 2, 2, 0, 10, 2, 10, 0, 4, 6, 10, 3, 2, 0, 9, 9, 4, 2, 10, 3, 10, 4, 0, 2, 0, 0, 10, 10, 2, 10, 0, 0, 4, 2, 0, 10, 0, 0, 2, 0, 0, 10, 0, 9, 9, 6, 3, 10, 3, 6, 0, 2, 4, 4, 3, 2, 10, 2, 1, 0, 3, 4, 3, 0, 10, 1, 3, 0, 10, 2, 9, 0, 9, 10, 3, 4, 10, 0, 3, 0, 0, 2, 7, 0, 0, 3, 2, 1, 3, 0, 2, 3, 9, 7, 0, 10, 0, 7, 4, 2, 2, 0, 0, 2, 2, 2, 10, 2, 3, 6, 1, 10, 10, 0, 2, 10, 10, 4, 2, 2, 2, 7, 4, 9, 0, 0, 0, 0, 2, 10, 0, 0, 6, 10, 2, 0, 6, 10, 0, 0, 10, 9, 0, 10, 0, 5, 3, 0, 10, 6, 3, 1, 0, 2, 10, 6, 0, 4, 2, 4, 4, 4, 1, 2, 0, 4, 0, 0, 3, 0, 1, 2, 0, 0, 3, 0, 0, 0, 10, 10, 2, 10, 0, 1, 2, 0, 10, 10, 0, 1, 6, 2, 2, 0, 0, 0, 0, 10, 10, 3, 0, 2, 4, 4, 1, 4, 2, 1, 3, 2, 3, 10, 3, 0, 2, 0, 10, 10, 0, 0, 2, 0, 10, 4, 2, 2, 10, 0, 3, 3, 2, 0, 2, 0, 2, 9, 4, 10, 10, 4, 0, 10, 9, 1, 6, 0, 0, 0, 10, 2, 0, 10, 2, 2, 0, 3, 10, 4, 2, 6, 0, 2, 10, 10, 2, 2, 6, 1, 4, 10, 3, 9, 0, 10, 2, 9, 0, 4, 4, 0, 0, 4, 2, 9, 10, 0, 2, 4, 1, 2, 1, 2, 4, 3, 2, 10, 0, 0, 10, 4, 0, 2, 9, 3, 4, 0, 10, 3, 0, 10, 0, 0, 2, 4, 10, 1, 0, 2, 0, 2, 1, 10, 1, 1, 6, 2, 0, 2, 2, 6, 10, 10, 0, 1, 10, 2, 10, 2, 3, 10, 11, 1, 1, 4, 1, 1, 2, 10, 3, 4, 0, 9, 0, 1, 0, 2, 0, 0, 2, 2, 4, 2, 11, 2, 2, 6, 0, 0, 0, 10, 2, 0, 0, 0, 0, 1, 0, 10, 4, 2, 0, 0, 2, 2, 2, 10, 0, 11, 9, 0, 2, 2, 9, 0, 2, 4, 10, 10, 4, 0, 7, 6, 2, 4, 0, 1, 0, 10, 2, 3, 10, 10, 2, 0, 0, 1, 10, 4, 4, 10, 3, 9, 1, 10, 0, 10, 4, 10, 0, 7, 0, 10, 2, 2, 0, 3, 0, 9, 4, 2, 10, 0, 2, 2, 4, 10, 0, 1, 4, 0, 10, 2, 0, 2, 10, 4, 2, 4, 2, 4, 10, 2, 2, 2, 2, 2, 6, 2, 0, 1, 0, 4, 6, 0, 0, 10, 0, 2, 4, 2, 0, 2, 10, 6, 0, 3, 10, 4, 11, 0, 4, 2, 1, 2, 2, 10, 2, 1, 0, 10, 3, 2, 0, 9, 0, 4, 2, 10, 0, 2, 10, 2, 1, 6, 10, 2, 0, 2, 0, 2, 4, 0, 0, 2, 10, 3, 0, 10, 0, 4, 4, 10, 2, 4, 8, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 10, 0, 4, 10, 0, 0, 1, 10, 9, 3, 4, 10, 0, 0, 4, 3, 8, 0, 10, 0, 0, 4, 9, 2, 0, 0, 2, 4, 1, 10, 4, 4, 0, 0, 2, 4, 4, 0, 4, 0, 1, 1, 4, 
1, 6, 4, 3, 6, 10, 0, 4, 4, 9, 10, 2, 0, 2, 3, 10, 10, 10, 3, 3, 10, 2, 3, 0, 0, 2, 2, 10, 10, 4, 10, 2, 4, 10, 10, 2, 9, 2, 2, 0, 1, 3, 0, 8, 6, 0, 10, 1, 2, 4, 0, 9, 3, 4, 0, 4, 0, 0, 0, 10, 4, 2, 0, 1, 0, 1, 0, 10, 2, 10, 0, 10, 10, 3, 10, 3, 0, 10, 1, 3, 2, 10, 0, 0, 2, 10, 0, 3, 9, 2, 10, 10, 0, 2, 9, 9, 0, 0, 0, 9, 0, 0, 1, 4, 0, 6, 7, 3, 0, 10, 0, 2, 11, 3, 0, 2, 0, 10, 1, 3, 0, 10, 0, 0, 0, 3, 3, 0, 11, 2, 10, 9, 2, 0, 2, 3, 10, 2, 11, 3, 0, 0, 0, 6, 3, 0, 3, 2, 0, 2, 0, 1, 3, 2, 0, 2, 0, 1, 2, 10, 1, 2, 2, 0, 0, 0, 0, 0, 0, 2, 9, 3, 1, 9, 9, 8, 0, 2, 2, 10, 10, 3, 0, 4, 2, 10, 4, 3, 0, 0, 0, 2, 10, 2, 10, 10, 10, 3, 3, 4, 0, 4, 10, 0, 2, 2, 0, 10, 4, 9, 0, 0, 1, 0, 2, 2, 0, 2, 0, 0, 10, 9, 8, 2, 4, 3, 2, 0, 2, 2, 3, 0, 0, 0, 10, 0, 0, 1, 10, 0, 8, 3, 10, 4, 0, 0, 10, 0, 0, 10, 10, 0, 2, 0, 4, 3, 2, 0, 0, 0, 0, 4, 0, 2, 10, 2, 2, 0, 0, 0, 1, 2, 3, 9, 10, 9, 10, 2, 10, 0, 0, 10, 0, 0, 0, 2, 10, 0, 10, 2, 0, 10, 9, 7, 2, 0, 2, 11, 0, 0, 10, 9, 0, 10, 0, 0, 0, 10, 0, 10, 0, 0, 10, 4, 1, 2, 1, 4, 0, 10, 2, 0, 8, 10, 10, 2, 9, 2, 3, 10, 0, 10, 9, 0, 2, 10, 10, 2, 0, 4, 1, 3, 0, 0, 0, 9, 6, 10, 2, 10, 2, 11, 0, 4, 0, 0, 2, 9, 0, 10, 0, 10, 3, 6, 2, 10, 3, 0, 10, 9, 10, 10, 10, 1, 10, 0, 2, 0, 0, 10, 0, 1, 4, 0, 10, 0, 10, 10, 2, 2, 4, 9, 3, 4, 0, 0, 4, 10, 11, 0, 0, 0, 9, 4, 3, 2, 9, 0, 10, 10, 3, 4, 4, 10, 7, 0, 0, 2, 0, 0, 2, 0, 0, 1, 10, 0, 0, 0, 0, 2, 10, 9, 1, 0, 3, 2, 2, 0, 3, 4, 10, 10, 10, 2, 2, 10, 0, 10, 2, 4, 2, 10, 10, 0, 2, 2, 10, 0, 0, 2, 2, 1, 11, 9, 0, 10, 4, 3, 1, 0, 10, 9, 2, 8, 0, 0, 0, 1, 0, 7, 4, 0, 0, 0, 0, 0, 10, 10, 0, 2, 9, 0, 0, 6, 4, 4, 10, 10, 10, 2, 2, 0, 3]
# acts=[0, 0, 2, 0, 0, 0, 2, 4, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 9, 9, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 9, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 0, 10, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 10, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 2, 0, 4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 4, 0, 10, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 10, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 4, 7, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 10, 0, 0, 0, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 10, 0]
# acts=[3, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 0, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 
10, 10, 10, 10, 10, 10, 10, 10, 1, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 0, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 0, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 1, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 6, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 0, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 4, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]


acts=[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]



# Replay loop: feed the recorded action list `acts` to the recording environment one
# step at a time until the episode reports `done`, then play the recording back and
# print how many steps were taken. Relies on `env`, `acts` and the step counter `x`
# set up earlier in this cell.
while True:
    # Disabled alternative: sample each action from the trained policy instead of replaying `acts`.
    # state = torch.from_numpy(state.copy()).type(torch.float)#.to(device)
    # value, logit, latent = model((state, latent), icm = False)
    # prob = F.softmax(logit, dim=1) #from train
    # action = prob.multinomial(1).data
    # state, reward, done, _ = env.step(action.item())

    # Use the scripted action while one is available; once the script is exhausted,
    # fall back to action 0 (10 was the other default tried). An explicit bounds check
    # replaces the former bare `except:`, which also swallowed KeyboardInterrupt and any
    # genuine error, making the loop impossible to interrupt at this line.
    if x < len(acts):
        action = int(acts[x])
    else:
        action = 0 #10
    # # print("action",action)
    # # action = env.action_space.sample()
    state, reward, done, info = env.step(action)
    x += 1
    if done:
        break
# NOTE(review): env.play() is the Recorder's playback call -- presumably renders the captured video inline; confirm.
env.play()
print(x)



765


In [None]:
# Query PyTorch's CUDA memory counters for `device`. The results are neither assigned,
# printed, nor the last expression of the cell, so nothing is displayed for them.
# NOTE(review): `device` may be the string "cpu" (see the setup cell) -- confirm these calls accept that.
torch.cuda.memory_allocated(device)
torch.cuda.memory_stats(device)

# https://stackoverflow.com/questions/48750199/google-colaboratory-misleading-information-about-its-gpu-only-5-ram-available
# Symlink Colab's nvidia-smi into /usr/bin, then install the packages used by printm() below.
# NOTE(review): the symlink is presumably needed so GPUtil can find nvidia-smi -- confirm.
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# Colab exposes at most one GPU and does not guarantee one, so tolerate an empty
# list instead of failing with IndexError on GPUs[0].
gpu = GPUs[0] if GPUs else None
def printm():
    """Print host RAM and GPU memory figures for the current session.

    Line 1: available system RAM and this process's resident set size, both
    formatted by humanize.naturalsize.
    Line 2: free / used / utilisation / total memory of the module-level `gpu`
    object, or a short notice when no GPU was detected (`gpu` is None).
    """
    process = psutil.Process(os.getpid())
    # Two print arguments on purpose: print() inserts the separating space between them.
    print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
    if gpu is None:
        print("GPU RAM Free: n/a (no GPU detected)")
        return
    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gputil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25l[?25hdone
  Created wheel for gputil: filename=GPUtil-1.4.0-py3-none-any.whl size=7411 sha256=86344311ee07f3a14dce5fa2b550237ecd81c973d748088f322871ac5d557045
  Stored in directory: /root/.cache/pip/wheels/6e/f8/83/534c52482d6da64622ddbf72cd93c35d2ef2881b78fd08ff0c
Successfully built gputil
Installing collected packages: gputil
Successfully installed gputil-1.4.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Gen RAM Free: 12.4 GB  | Proc size: 380.9 MB
GPU RAM Free: 15106MB | Used: 3MB | Util   0% | Total 15109MB
