<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/AIM2_simplify_strip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### buffer

In [1]:
# https://github.com/iDurugkar/adversarial-intrinsic-motivation/blob/main/grid_world_experiments/buffers.py
import random
from typing import List, Union
import numpy as np
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

class ReplayBuffer(object):
    def __init__(self, size: int):
        """Implements a ring buffer (FIFO).
        :param size: (int)  Max number of transitions to store in the buffer. When the buffer overflows the old memories are dropped."""
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self) -> int:
        return len(self._storage)

    @property
    def storage(self):
        """[(Union[np.ndarray, int], Union[np.ndarray, int], float, Union[np.ndarray, int], bool)]:
         content of the replay buffer"""
        return self._storage

    @property
    def buffer_size(self) -> int:
        """float: Max capacity of the buffer"""
        return self._maxsize

    def can_sample(self, n_samples: int) -> bool:
        return len(self) >= n_samples

    def is_full(self) -> int:
        return len(self) == self.buffer_size

    def add(self, obs_t, action, reward, obs_tp1, done):
        data = (obs_t, action, reward, obs_tp1, done)
        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def extend(self, obs_t, action, reward, obs_tp1, done):
        for data in zip(obs_t, action, reward, obs_tp1, done):
            if self._next_idx >= len(self._storage):
                self._storage.append(data)
            else:
                self._storage[self._next_idx] = data
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes: Union[List[int], np.ndarray]):
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            data = self._storage[i]
            obs_t, action, reward, obs_tp1, done = data
            obses_t.append(np.array(obs_t, copy=False))
            actions.append(np.array(action, copy=False))
            rewards.append(reward)
            obses_tp1.append(np.array(obs_tp1, copy=False))
            dones.append(done)

        obses_t = np.array(obses_t)
        actions = np.array(actions)
        obses_tp1 = np.array(obses_tp1)
        return (torch.tensor(obses_t).type(torch.float),
                torch.tensor(actions).type(torch.float),
                torch.tensor(rewards).type(torch.float),
                torch.tensor(obses_tp1).type(torch.float),
                torch.tensor(dones).type(torch.float))

    def sample(self, batch_size: int, **_kwargs):
        """Sample a batch of experiences.
        :param batch_size: (int) How many transitions to sample.
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episodeand 0 otherwise."""
        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)



#### policy

In [2]:
# https://github.com/iDurugkar/adversarial-intrinsic-motivation/blob/main/grid_world_experiments/policy.py
from abc import ABC
import torch
import torch.nn.functional as f
import torch.nn as nn
import numpy as np
import random


# class Mepol(nn.Module):
#     def __init__(self, s_size, a_size, h_size=32):
#         super(Mepol, self).__init__()
#         self.model=nn.Sequential(
#             nn.Linear(s_size, h_size), nn.ReLU(),
#             nn.Linear(h_size, a_size),
#             nn.Softmax(dim=0),
#         )
    
#     def forward(self, state): # og discrete
#         # state = torch.from_numpy(state).float().unsqueeze(0).to(device)
#         probs = self.model(state).cpu()
#         m = torch.distributions.Categorical(probs)
#         action = m.sample() # can't use action = np.argmax(m) use  m.sample(), sample an action with prob dist P(.|s)
#         return action.item()

class MlpNetwork(nn.Module):
    def __init__(self, input_dim, output_dim=1, activ=f.relu, n_units=64):
        super(MlpNetwork, self).__init__()
        # n_units = 512
        self.h1 = nn.Linear(input_dim, n_units)
        self.h2 = nn.Linear(n_units, n_units)
        # self.h3 = nn.Linear(n_units, n_units)
        self.out = nn.Linear(n_units, output_dim)
        self.activ = activ

    def forward(self, x):
        x = self.activ(self.h1(x))
        x = self.activ(self.h2(x))
        # x = self.activ(self.h3(x))
        x = self.out(x)
        x = f.log_softmax(x, dim=-1)
        return x


class SoftQLearning(nn.Module, ABC):
    """Learns a soft Q-function. Samples from softmax distribution of Q-values for policy"""
    def __init__(self, x_dim=1, out_dim=2, max_state=9., min_state=0, ent_coef=0.01, target_update=1e-1):
        super(SoftQLearning, self).__init__()
        self.diff_state = np.array(max_state - min_state).astype(np.float32)
        self.mean_state = np.asarray(self.diff_state / 2 + min_state).astype(np.float32)
        self.input_dim = x_dim
        self.num_actions = out_dim
        self.alpha = ent_coef
        self.q = MlpNetwork(self.input_dim, output_dim=out_dim, n_units=64)
        self.q_target = MlpNetwork(self.input_dim, output_dim=out_dim, n_units=64)
        self.target_params = self.q_target.parameters()
        self.q_params = self.q.parameters()
        self.target_update_rate = target_update

    def parameters(self, recurse: bool = True):
        return self.q_params

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        x = x.type(torch.float32)
        x = (x - self.mean_state) / self.diff_state
        return x

    def forward(self, x: torch.Tensor):
        x = self.normalize(x)
        q = self.q(x)
        v = self.alpha * torch.logsumexp(q / self.alpha, dim=-1)
        # self.alpha = max(0.01, 0.99 * self.alpha + 0.01 * (torch.mean(torch.abs(q)).detach().numpy() / 10.))
        qt = self.q_target(x)
        vt = self.alpha * torch.logsumexp(qt / self.alpha, dim=-1)
        return q, v, qt, vt
        # return q

    # def pi_loss(self, x: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    #     """Return log_pi for the policy gradient"""
    #     x = self.normalize(x)
    #     logits = self.pi(x)
    #     actions = actions.type(torch.long)
    #     log_pi = logits.gather(dim=-1, index=actions)
    #     return log_pi

    def sample_action(self, x: torch.Tensor) -> torch.Tensor:
        """Sample from policy"""
        x = self.normalize(x)
        q = self.q(x)
        v = self.alpha * torch.logsumexp(q / self.alpha, dim=-1)
        logits = 1. / self.alpha * (q - v)
        pi = torch.exp(logits)
        action = pi.multinomial(1)
        return action

    def entropy(self, x: torch.Tensor) -> torch.Tensor:
        x = self.normalize(x)
        q = self.q(x)
        v = self.alpha * torch.logsumexp(q / self.alpha, dim=-1)
        logits = 1. / self.alpha * (q - torch.unsqueeze(v, dim=-1))
        entropy_kl = torch.sum(torch.log(torch.ones_like(logits) / self.num_actions) - logits, dim=-1)
        # pi = torch.exp(logits)
        # pisum = torch.sum(pi, dim=-1)
        # entropy = -torch.sum(pi * logits, dim=-1)
        return entropy_kl

    def update_target(self):
        """update the target network using polyak averaging"""
        with torch.no_grad():
            for c, t in zip(self.q.parameters(), self.q_target.parameters()):
                t.data.copy_((1. - self.target_update_rate) * t.data + self.target_update_rate * c.data)


#### mian

In [None]:
# https://github.com/iDurugkar/adversarial-intrinsic-motivation/blob/main/grid_world_experiments/main.py
import numpy as np
import torch
from torch import nn
# from torch.nn import utils
import torch.nn.functional as f
import random
import matplotlib.pyplot as plt
import argparse
import os
from os import path

seed=1123
reward='aim' # ['gail', 'airl', 'fairl', 'aim', 'none']
dir='/content'

torch.set_default_dtype(torch.float32)
# Set random seeds
seed = 42 * seed
print(seed)
torch.manual_seed(seed)
random.seed = seed
np.random.seed = seed
reward_to_use = reward  # use one of ['gail', 'airl', 'fairl', 'none']
print(reward_to_use)

def wasserstein_reward(d):
    return d
reward_dict = {'aim': wasserstein_reward}


class Discriminator(nn.Module):
    """The discriminator used to learn the potentials or the reward functions"""
    def __init__(self, x_dim=1, max_state=10., min_state=0):
        super(Discriminator, self).__init__()
        self.mean_state = torch.tensor((max_state - min_state) / 2 + min_state, dtype=torch.float32)
        self.diff_state = torch.tensor(max_state - min_state, dtype=torch.float32)
        self.input_dim = x_dim
        self.d = MlpNetwork(self.input_dim, n_units=64)  # , activ=f.tanh)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        x = x.type(torch.float32)
        x = (x - self.mean_state) / self.diff_state
        return x

    def forward(self, x):
        x = self.normalize(x)
        output = self.d(x)
        return output


import gym
# !pip install git+https://github.com/ntasfi/PyGame-Learning-Environment.git
# import gym_pygame


class GAIL:
    """Class to take the continuous MDP and use gail to match given target distribution"""
    def __init__(self):
        self.env = gym.make("CartPole-v1")
        # self.env = gym.make("Pendulum-v0") #continuous
        # self.env = gym.make("MountainCar-v0") #discrete
        self.s_size = self.env.observation_space.shape[0]
        self.a_size = self.env.action_space.n
        max_state = 10
        min_state = 0
        self.policy = SoftQLearning(x_dim=self.s_size, out_dim=self.a_size, max_state=max_state, min_state=min_state, ent_coef=.3, target_update=3e-2)
        # self.policy = Mepol(self.s_size, self.a_size)
        print("self.s_size",self.s_size)
        self.discriminator = Discriminator(x_dim=self.s_size, max_state=max_state, min_state=min_state)
        self.discount = 0.99
        self.check_state = set()
        self.agent_buffer = ReplayBuffer(size=5000)
        self.target_buffer = ReplayBuffer(size=5000)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters())  # , lr=3e-4)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters())  # , lr=1e-4)
        self.max_r = 0.
        self.min_r = -1.

    def gather_data(self, num_trans=100): # line 4-
        t = 0
        while t < num_trans:
            s = self.env.reset()
            s = torch.tensor(s).type(torch.float32)
            done = False
            while not done:
                # self.states.append(deepcopy(s))
                # print("sssssssss",s)
                action = self.policy.sample_action(s)
                # self.actions.append(a)
                a = np.squeeze(action.data.detach().numpy())
                s_p, r, done, _ = self.env.step(a)
                s_p = torch.tensor(s_p).type(torch.float32)
                # d = self.discriminator(sp)
                # i_r = gail_reward(d)
                # self.next_states.append(deepcopy(s))
                # self.rewards.append(i_r)  # deepcopy(r))
                # self.dones.append(deepcopy(done))
                # print(s.squeeze(), action.reshape([-1]).detach(), r, s_p.squeeze(), done)
                # tensor([ 0.0144, -0.0008, -0.0287,  0.0008]) tensor([0]) 1.0 tensor([ 0.0144, -0.1955, -0.0286,  0.2843]) False
                self.agent_buffer.add(s.squeeze(), action.reshape([-1]).detach(), r, s_p.squeeze(), done)
                if s_p not in self.check_state:
                    self.check_state.add(s_p)
                    self.target_buffer.add(s, a, r, s_p, done)
                s = s_p
                t += 1
            # self.states.append(s)


    def compute_td_targets(self, states, next_states, dones, rewards=None):
        """Compute the value of the current states and the TD target based on one step reward and value of next states
        :return: value of current states v, TD target targets"""
        # states = states.reshape([-1, self.env.dims])
        states = states.reshape([-1, self.s_size])
        # next_states = next_states.reshape([-1, self.env.dims])
        next_states = next_states.reshape([-1, self.s_size])
        v = self.policy(states)[0]
        v_prime = self.policy(next_states)[-1]
        if rewards is not None:
            dones = rewards.type(torch.float32).reshape([-1, 1])
        else:
            dones = dones.type(torch.float32).reshape([-1, 1])
        reward_func = reward_dict[reward_to_use]
        if reward_func is not None:
            # d0 = self.discriminator(states)
            d1 = self.discriminator(next_states)
            # Compute rewards
            # r0 = reward_func(d0)
            r1 = reward_func(d1)
            rewards = rewards.type(torch.float32).reshape([-1, 1]) + ((r1 - self.max_r) / (self.max_r - self.min_r))
        targets = rewards.type(torch.float32).reshape([-1, 1])
        # print("in compute_td_targets",dones, self.discount, v_prime) #100?*[1.] 0.99 tensor([-0.6525, -0.7356]
        targets += (1. - dones) * self.discount * v_prime.reshape([-1, 1])
        return v, targets.detach()

    def fit_v_func(self):
        """This function will train the value function using the collected data"""
        self.policy_optimizer.zero_grad()
        s, a, r, s_p, dones = self.agent_buffer.sample(100)
        q, targets = self.compute_td_targets(s, s_p, dones, rewards=r)
        actions = torch.tensor(a, dtype=torch.long)
        v = q.gather(dim=-1, index=actions)
        loss = torch.mean(0.5 * (targets - v) ** 2)
        print("fit_v_func loss",loss)
        loss.backward()
        self.policy_optimizer.step()
        self.policy.update_target()
        return

    def optimize_policy(self): #line 27-29
        """This function will optimize the policy to maximize returns Based on collected data"""
        self.policy_optimizer.zero_grad()
        s, a, r, s_p, dones = self.agent_buffer.sample(100)
        v, targets = self.compute_td_targets(s, s_p, dones, rewards=r)
        advantages = (targets - v).detach()
        a = a.reshape([-1, 1]).detach()
        print("s, a",s.shape, a.shape) #[100, 4]float -1 1 [100, 1] 1./0.
        # neg_log_pi = -1. * self.policy.pi_loss(s.reshape([-1, self.env.dims]), a)
        # neg_log_pi = -1. * self.policy.pi_loss(s, a)
        # entropy_kl = self.policy.entropy(s.reshape([-1, self.env.dims]))
        neg_log_pi = nn.NLLLoss()(s, a) # me

        entropy_kl = self.policy.entropy(s)
        loss = torch.mean(advantages * neg_log_pi) + 1e-1 * torch.mean(entropy_kl)
        loss.backward()
        self.policy_optimizer.step()
        return

    def compute_aim_pen(self, target_state, prev_state, next_state_state, lambda_=10.): #equation 8 pt 2
        """Computes values of the discriminator at different points and constraints the difference to be 0.1"""
        prev_out = self.discriminator(prev_state)
        next_out = self.discriminator(next_state_state)
        penalty = lambda_ * torch.max(torch.abs(next_out - prev_out) - 0.1, torch.tensor(0.)).pow(2).mean()
        return penalty


    def optimize_discriminator(self): # line 32-33
        """Optimize the discriminator based on the memory and target_distribution"""
        num_samples = 100
        self.discriminator_optimizer.zero_grad()
        # _, _, _, target_distribution, _ = self.target_buffer.sample(100)
        # target_dist = np.reshape(self.env.target_distribution(), (-1,))
        target_distribution=np.array([4,0,0,0])
        target_dist=target_distribution.squeeze()
        # import torch.nn.functional as f
        # f.normalize(input, p=2, dim=2)
        # print("target_dist",target_dist) # [4,0,0,0]
        # p=target_dist/sum(target_dist) # [4,0,0,0]
        # print("p",p)
        # target_distribution = np.random.choice(target_dist.shape[0], num_samples, p=target_dist)
        # print("target_distribution",target_distribution) #[100* 0]
        states, _, _, next_states, _ = self.agent_buffer.sample(num_samples)
        # target_distribution = sample_target_distribution(mean=self.env.target_mean, std=self.env.target_std, num=100)
        target_distribution = target_distribution.reshape([-1, 1])

        target_distribution = np.tile(target_distribution,(1,self.s_size))
        # print("target_distribution",target_distribution)
        # next_states = next_states.reshape([-1, self.env.dims])
        # print("next_states",next_states.shape) #[100, 4]
        
        ones = torch.tensor(target_distribution).type(torch.float32)
        zeros = torch.tensor(next_states).type(torch.float32)
        zeros_prev = torch.tensor(states).type(torch.float32)
        
        # print(ones.shape, zeros.shape, zeros_prev.shape) #[100, 4] [100, 4] [100, 4]

        # ####### WGAN loss
        pred_ones = self.discriminator(ones)
        pred_zeros = self.discriminator(zeros)
        preds = torch.cat([pred_zeros, pred_ones], dim=0)
        self.max_r = torch.max(preds).detach().cpu().numpy() + 0.1
        self.min_r = torch.min(preds).detach().cpu().numpy() - 0.1
        wgan_loss = torch.mean(pred_zeros) + torch.mean(pred_ones * (-1.)) # equation 8 pt 1
        aim_penalty = self.compute_aim_pen(ones, zeros_prev, zeros) # equation 8 pt 2
        # grad_penalty = self.compute_grad_pen(ones, zeros)
        loss = wgan_loss + aim_penalty  # + grad_penalty
        print("optimize_discriminator loss",loss)
        # loss = torch.mean(- labels * pred.log() - (1 - labels) * (1. - pred).log())
        loss.backward()
        # utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=0.5)
        self.discriminator_optimizer.step()

    def act(self, state):
        # state = torch.tensor(state).type(torch.float32)
        action = self.policy.sample_action(state)
        # a = np.squeeze(action.data.detach().numpy())
        # s_p, r, done, _ = self.env.step(a)
        return action


gail = GAIL() #line 1-2
# .to(device)
gail.gather_data(num_trans=50) #500
print('')
for i in range(100): #500 line 3
    for _ in range(5):
        # gail.gather_data(num_trans=500)
        gail.optimize_discriminator() #line 31-34
        # gail.optimize_discriminator(target_states, policy_states, policy_next_states)
    for _ in range(10):
        gail.gather_data(num_trans=50)
        gail.fit_v_func() # update policy line 29

        # Useful only if using a separate policy
        # gail.gather_data(num_trans=500)
        # gail.optimize_policy()


47166
aim
self.s_size 4

optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
fit_v_func loss tensor(0.7173, grad_fn=<MeanBackward0>)




fit_v_func loss tensor(0.7123, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7111, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7151, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7003, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7028, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.6993, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7074, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7057, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.6887, grad_fn=<MeanBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
optimize_discriminator loss tensor(0., grad_fn=<AddBackward0>)
fit_v_func loss tensor(0.7154, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7138, grad_fn=<MeanBackward0>)
fit_v_func loss tensor(0.7195, grad_fn=<MeanBackward0>)
fit_v_func lo

In [22]:
gail.gather_data(num_trans=50)
print('')
for i in range(100): #500
    if reward_to_use != 'none':
        for _ in range(5):
            # gail.gather_data(num_trans=500)
            gail.optimize_discriminator()
            # gail.optimize_discriminator(target_states, policy_states, policy_next_states)
    for _ in range(10):
        gail.gather_data(num_trans=500)
        gail.fit_v_func()







#### eval

In [23]:

env = gym.make("CartPole-v1")
model=gail
rewards=0
s = env.reset()
s = torch.tensor(s.copy()).type(torch.float32)
done = False
while not done:
    # print("sssssssss",s)
    action = model.act(s)
    a = np.squeeze(action.data.detach().numpy())
    s_p, r, done, _ = env.step(a)
    s_p = torch.tensor(s_p).type(torch.float32)
    s = s_p
    rewards+=r
print(rewards)
# Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity


9.0


#### video

In [15]:
!pip install imageio-ffmpeg
!apt install python-opengl ffmpeg xvfb
!pip3 install pyvirtualdisplay
# Virtual display
from pyvirtualdisplay import Display
virtual_display = Display(visible=0, size=(500, 500))
virtual_display.start()
import imageio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Reading package lists... Done
Building dependency tree       
Reading state information... Done
python-opengl is already the newest version (3.1.0+dfsg-1).
ffmpeg is already the newest version (7:3.4.11-0ubuntu0.1).
xvfb is already the newest version (2:1.19.6-1ubuntu4.11).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 19 not upgraded.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [16]:
# make mp4

env = gym.make("CartPole-v1")
model=gail
rewards=0
images = []  
done = False
s = env.reset()
img = env.render(mode='rgb_array')
images.append(img)
s = torch.tensor(s.copy()).type(torch.float32)
done = False
while not done:
    # print("sssssssss",s)
    action = model.act(s)
    a = np.squeeze(action.data.detach().numpy())
    s_p, r, done, _ = env.step(a)
    s_p = torch.tensor(s_p).type(torch.float32)
    s = s_p
    rewards+=r

    img = env.render(mode='rgb_array')
    images.append(img)
# print('Episode: {} \t\t Reward: {}'.format(ep, round(ep_reward, 2)))
imageio.mimsave("video.mp4", [np.array(img) for i, img in enumerate(images)], fps=30)

print(rewards)
# Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity




10.0


In [None]:

from IPython.display import HTML
from base64 import b64encode
mp4 = open('video.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

### mario

#### wrappers

In [None]:

class MarioSparse(gym.Wrapper):
    def __init__(self, env):
        # super().__init__(env)
        super(MarioSparse, self).__init__(env)
        self.env = env
        self.total_score = 0
    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        life = info['life']
        score = info['score']
        self.total_score += score
        if life<2:
            print("MarioSparse: died")
            # return observation, score, True, info # lost one life, end env
            done = True
        # else:
            # self.total_score = 0
        return observation, score, done, info
    def reset(self):
        self.total_score = 0
        return self.env.reset()
# env = MarioSparse(env)

class MarioEarlyStop(gym.Wrapper):
    def __init__(self, env):
        # super().__init__(env)
        super(MarioEarlyStop, self).__init__(env)
        self.env = env
        self.max_pos = 0
        self.count_step = 0
    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        x_pos = info['x_pos']
        if x_pos <= self.max_pos: self.count_step += 1
        else:
            self.max_pos = x_pos
            self.count_step = 0
        if self.count_step > 500:
            print("MarioEarlyStop: early stop ", self.max_pos)
            # return observation, reward, True, info # early stop
            done = True
        # else:
        return observation, reward, done, info
    def reset(self):
        self.max_pos = 0
        self.count_step = 0
        return self.env.reset()
# env = MarioEarlyStop(env)



class PosState(gym.Wrapper):
    def __init__(self, env):
        super(PosState, self).__init__(env)
        self.env = env
    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        x_pos = info['x_pos']
        y_pos = info['y_pos']
        return [x_pos,y_pos], score, done, info
    def reset(self):
        return self.env.reset()
# env = PosState(env)


#### setup

In [None]:
!pip install gym-super-mario-bros nes-py
# https://github.com/Kautenja/gym-super-mario-bros
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

env = gym_super_mario_bros.make('SuperMarioBros-v3') #og v0 pixel v3
env = JoypadSpace(env, COMPLEX_MOVEMENT) # SIMPLE_MOVEMENT COMPLEX_MOVEMENT
env = MarioSparse(env)
env = MarioEarlyStop(env)
env = PosState(env)


