In [None]:
"""
Implementation of the Brute from "Revisiting the Arcade Learning Environment:
Evaluation Protocols and Open Problems for General Agents" by Machado et al.
https://arxiv.org/abs/1709.06009
This is an agent that uses the determinism of the environment in order to do
pretty well at a number of retro games.  It does not save emulator state but
does rely on the same sequence of actions producing the same result when played
back.
"""

import random
import argparse

import numpy as np
import retro
import gym


# Base exploration rate used by select_actions: at a tree node that has been
# visited `visits` times, a random action is taken with probability
# EXPLORATION_PARAM / log(visits + 2), so exploration decays as the node is
# revisited.
EXPLORATION_PARAM = 0.05


class StochasticFrameSkip(gym.Wrapper):
    """
    Hold each agent action for ``n`` emulator steps, with a "sticky" delay.

    Rewards of the ``n`` sub-steps are summed and the last observation/info
    are returned (stopping early if the wrapped env reports ``done``).
    On the first sub-step of a ``step`` call the previously held action is
    kept with probability ``stickprob``; from the second sub-step on the new
    action is always used.  Right after ``reset`` the new action is adopted
    immediately.
    """

    def __init__(self, env, n, stickprob):
        super().__init__(env)
        self.n = n
        self.stickprob = stickprob
        # Action currently being held; None means nothing held since reset.
        self.curac = None
        # RNG for the sticky-action draw; not seeded deterministically
        # unless seed() is called.
        self.rng = np.random.RandomState()
        # Only the presence of the attribute on the wrapped env is checked.
        self.supports_want_render = hasattr(env, "supports_want_render")

    def reset(self, **kwargs):
        self.curac = None
        return self.env.reset(**kwargs)

    def step(self, ac):
        done = False
        totrew = 0
        final_substep = self.n - 1
        for i in range(self.n):
            # Adopt the new action when: nothing is held yet, we are on the
            # second sub-step, or on the first sub-step the sticky draw says
            # so.  Short-circuiting keeps the RNG call only on that last case.
            if self.curac is None or i == 1 or (
                i == 0 and self.rng.rand() > self.stickprob
            ):
                self.curac = ac
            # Skip rendering on every sub-step but the last, when supported.
            skip_render = self.supports_want_render and i < final_substep
            extra = {"want_render": False} if skip_render else {}
            ob, rew, done, info = self.env.step(self.curac, **extra)
            totrew += rew
            if done:
                break
        return ob, totrew, done, info

    def seed(self, s):
        """Seed only the sticky-action RNG (the wrapped env is not seeded)."""
        self.rng.seed(s)

class SnesDiscretizer(gym.ActionWrapper):
    """
    Wrap a gym-retro SNES environment so that it takes discrete actions.

    Each discrete action index selects one fixed combination of SNES
    controller buttons and is translated into the boolean button array the
    wrapped environment receives.
    """
    def __init__(self, env):
        super(SnesDiscretizer, self).__init__(env)
        # Position of each button in the boolean array passed to the wrapped
        # env.  NOTE(review): assumed to match the wrapped env's button order
        # (gym-retro lists SNES buttons upper-case, e.g. "UP") -- TODO confirm.
        buttons = ["B", "Y", "SELECT", "START", "up", "down", "left", "right", "A", "X", "L", "R"]
        actions = [
            ['right'], ['right', 'A'], ['right', 'B'], ['right', 'Y'],
            ['A'], ['B'],
            ['left'], ['left', 'A'], ['left', 'B'], ['left', 'Y'],
            ['A', 'Y'], ['B', 'Y'],
            ['down'], ['up'], ['Y', 'up'], ['B', 'up'], ['A', 'up'],
            ['A', 'Y', 'right'], ['A', 'Y', 'left'],
            ['B', 'Y', 'right'], ['B', 'Y', 'left'],
            ['SELECT'],
        ]

        self._actions = []
        for action in actions:
            # Size derived from the button list instead of a hard-coded 12.
            arr = np.zeros(len(buttons), dtype=bool)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a): # pylint: disable=W0221
        """Map discrete action index ``a`` to a fresh boolean button array."""
        # Copy so the caller cannot mutate the cached combination.
        return self._actions[a].copy()
    
class ProcessFrameMario(gym.Wrapper):
    """
    Coverage-driven reward wrapper for the Super Mario World setup.

    On the first step of every episode it writes two bytes into emulator
    RAM, and it re-writes one of them every 50 steps.  On every step it:

    * returns reward 1 if the first 1000 entries of ``info['trace']`` contain
      at least one value (field ``[2]`` of an entry) not seen earlier in the
      episode, else reward 0 (the wrapped env's own reward is discarded);
    * sets ``info['crash']`` to 1 from the first step at which
      ``info['powerup']`` is greater than 3 and not 22 until the episode
      ends, and to 0 otherwise;
    * forces ``done`` once more than 200 consecutive steps produced no new
      trace value.

    Per-episode bookkeeping is cleared both when an episode ends and on
    ``reset()``, so every rollout starts with the RAM writes and an empty
    coverage set even if the previous rollout stopped without ``done``.

    Args:
        env: wrapped environment; must expose ``data.memory`` and report
            ``'powerup'`` and ``'trace'`` in ``info``.
        reward_type: unused.
        dim: only used for the advertised ``observation_space`` shape.
    """

    # RAM addresses written via retro._retro.Memory.assign.
    # NOTE(review): presumably SMW's reserve-item byte -- TODO confirm.
    _ADDR_7E0DC2 = 0x7E0DC2  # == 8261058
    # NOTE(review): presumably the byte reported as info['powerup'] (the
    # injected value 22 is what the crash check exempts) -- TODO confirm.
    _ADDR_7E0019 = 0x7E0019  # == 8257561
    _INJECTED_POWERUP = 22
    _EPISODE_TIMER = 90000   # initial step timer (counts down every step)
    _REPOKE_PERIOD = 50      # re-write _ADDR_7E0DC2 every this many steps
    _TRACE_WINDOW = 1000     # only the first entries of info['trace'] count
    _STALL_LIMIT = 200       # max consecutive steps without new coverage

    def __init__(self, env=None, reward_type=None, dim=84):
        super(ProcessFrameMario, self).__init__(env)
        # NOTE(review): advertises a (1, dim, dim) frame, but step() returns
        # the wrapped env's observation unchanged.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(1, dim, dim), dtype=np.uint8)
        self._reset_episode_state()

    def _reset_episode_state(self):
        """Restore all per-episode bookkeeping to its initial values."""
        self.timer = self._EPISODE_TIMER
        self.countdown = 0      # consecutive steps without new coverage
        self.multiplier = 0     # steps with new coverage (not used for reward)
        self.fresh = True       # RAM writes still pending for this episode
        self.x = 0              # kept for compatibility; not read here
        self.s = 0              # kept for compatibility; not read here
        self.code_covered = set()
        self.crashed = False

    def _poke(self, address, value):
        """Write one unsigned byte into the wrapped emulator's memory."""
        retro._retro.Memory.assign(self.env.data.memory, address, "uint8", value)

    def reset(self, **kwargs):
        # Without this, a rollout that ends without `done` would leak its
        # coverage set and skip the RAM writes in the next episode.
        self._reset_episode_state()
        return self.env.reset(**kwargs)

    def step(self, action): #pylint: disable=method-hidden
        if self.timer % self._REPOKE_PERIOD == 0:
            self._poke(self._ADDR_7E0DC2, 1)

        if self.fresh:
            self._poke(self._ADDR_7E0019, self._INJECTED_POWERUP)
            self._poke(self._ADDR_7E0DC2, 1)
            self.fresh = False

        obs, _, done, info = self.env.step(action)

        powerup = info['powerup']
        if powerup != self._INJECTED_POWERUP and powerup > 3:
            self.crashed = True
        info['crash'] = 1 if self.crashed else 0

        self.timer -= 1

        # Reward 1 if this step hit any trace value not seen this episode.
        trace = info['trace'][:self._TRACE_WINDOW]
        new_words = {entry[2] for entry in trace} - self.code_covered
        self.code_covered |= new_words
        reward = 1 if new_words else 0

        if reward:
            self.countdown = 0
            self.multiplier += 1
        else:
            self.countdown += 1

        if self.countdown > self._STALL_LIMIT:
            done = True

        if done:
            self._reset_episode_state()

        return obs, reward, done, info


class Frameskip(gym.Wrapper):
    """
    Repeat each action for ``skip`` consecutive steps of the wrapped env.

    Rewards of the repeated steps are summed; the loop stops early as soon
    as the wrapped env reports ``done``.  The last observation and info are
    returned.  Unlike StochasticFrameSkip, there is no randomness here.
    """

    def __init__(self, env, skip=4):
        super().__init__(env)
        self._skip = skip
        # Episode-progress marker reset when an episode ends; not read by
        # this wrapper (initialised so the attribute always exists).
        self.x = 0

    def reset(self, **kwargs):
        # Forward kwargs, consistent with the other wrappers in this file.
        return self.env.reset(**kwargs)

    def step(self, act):
        total_rew = 0.0
        done = None
        for _ in range(self._skip):
            obs, rew, done, info = self.env.step(act)
            total_rew += rew
            if done:
                self.x = 0
                break

        return obs, total_rew, done, info


class TimeLimit(gym.Wrapper):
    """
    End episodes after a fixed number of ``step`` calls.

    When the limit is reached, ``done`` is forced to True and
    ``info['TimeLimit.truncated']`` is set.  ``max_episode_steps=None``
    (the default) disables the limit instead of failing on ``int >= None``.
    """

    def __init__(self, env, max_episode_steps=None):
        super().__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        limit = self._max_episode_steps
        if limit is not None and self._elapsed_steps >= limit:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)


class Node:
    """
    One node of the Brute search tree.

    Attributes:
        value: best episode reward recorded through this node (``-inf``
            until a rollout has passed through it).
        visits: number of recorded rollouts through this node.
        children: dict mapping an action to the child ``Node``.
    """

    def __init__(self, value=-np.inf, children=None):
        self.value = value
        self.visits = 0
        if children is None:
            # Fresh dict per node, never a shared default.
            children = {}
        self.children = children

    def __repr__(self):
        summary = (self.value, self.visits, len(self.children))
        return "<Node value=%f visits=%d len(children)=%d>" % summary


def select_actions(root, action_space, max_episode_steps, exploration_param=None):
    """
    Select a sequence of actions by walking the search tree.

    At each step the greedy action is taken: the child with the highest
    ``value``, with unexplored actions counting as ``-inf`` and ties broken
    uniformly at random.  With probability
    ``exploration_param / log(node.visits + 2)`` a random action from
    ``action_space`` is taken instead.  Once the walk leaves the explored
    part of the tree, every further action is sampled at random.

    Exactly ``max_episode_steps`` actions are produced.  They are normally
    not all used: callers truncate them to the length of the actual episode
    before updating the tree.

    Args:
        root: tree root (a ``Node``), or None for an empty tree.
        action_space: discrete space providing ``n`` and ``sample()``.
        max_episode_steps: number of actions to generate.
        exploration_param: base exploration rate; defaults to the
            module-level ``EXPLORATION_PARAM``.

    Returns:
        list of ``max_episode_steps`` actions.
    """
    if exploration_param is None:
        exploration_param = EXPLORATION_PARAM

    node = root
    acts = []
    while len(acts) < max_episode_steps:
        if node is None:
            # Fell off the explored area of the tree: act at random.
            act = action_space.sample()
        else:
            epsilon = exploration_param / np.log(node.visits + 2)
            if random.random() < epsilon:
                act = action_space.sample()
            else:
                # Greedy: value of every action, -inf where unexplored.  Ties
                # are collected in ascending action order before the draw.
                children = node.children
                values = [
                    children[a].value if a in children else -np.inf
                    for a in range(action_space.n)
                ]
                best_value = max(values)
                best_acts = [a for a, v in enumerate(values) if v == best_value]
                act = random.choice(best_acts)
            node = node.children.get(act)
        acts.append(act)

    return acts


def rollout(env, acts):
    """
    Replay ``acts`` in ``env`` starting from a fresh reset.

    Stops at the first step that reports ``done``, or when ``acts`` runs out.

    Returns:
        (steps, total_rew): how many actions were actually executed and the
        sum of their rewards.
    """
    env.reset()
    total_rew = 0
    steps = 0
    for steps, act in enumerate(acts, start=1):
        _, rew, done, _ = env.step(act)
        total_rew += rew
        if done:
            break
    return steps, total_rew


def update_tree(root, executed_acts, total_rew):
    """
    Record one finished rollout in the tree.

    The root and every node along the path spelled out by ``executed_acts``
    get ``visits`` incremented and ``value`` raised to ``total_rew`` if that
    is larger.  Missing nodes on the path are created.

    Returns:
        the number of nodes newly created.
    """
    def _record(tree_node):
        tree_node.value = max(total_rew, tree_node.value)
        tree_node.visits += 1

    _record(root)
    created = 0
    cursor = root
    for act in executed_acts:
        if act not in cursor.children:
            cursor.children[act] = Node()
            created += 1
        cursor = cursor.children[act]
        _record(cursor)
    return created


class Brute:
    """
    The Brute agent: a tree of action sequences searched by replaying them.

    Each ``run`` picks an action sequence from the tree (``select_actions``),
    replays it from reset (``rollout``), and records the episode reward
    along the executed path (``update_tree``).
    """

    def __init__(self, env, max_episode_steps):
        self._env = env
        self._max_episode_steps = max_episode_steps
        self._root = Node()
        # Number of nodes in the tree, including the root.
        self.node_count = 1

    def run(self):
        """
        Play one episode and update the tree.

        Returns:
            (executed_acts, total_rew): the actions actually executed before
            the episode ended (or the planned list ran out) and their summed
            reward.
        """
        planned = select_actions(
            self._root, self._env.action_space, self._max_episode_steps
        )
        n_executed, total_rew = rollout(self._env, planned)
        executed_acts = planned[:n_executed]
        self.node_count += update_tree(self._root, executed_acts, total_rew)
        return executed_acts, total_rew

# Experiment: run the Brute on SuperMarioWorld-Snes (YoshiIsland4 state)
# with the coverage-based reward of ProcessFrameMario.
max_episode_steps = 800

env = retro.make( game='SuperMarioWorld-Snes', state= 'YoshiIsland4.state', use_restricted_actions=retro.Actions.ALL)
# NOTE(review): StochasticFrameSkip draws sticky actions from its own RNG
# (stickprob=0.25, not seeded here), so replaying the same action list does
# not reproduce the same episode -- which the Brute relies on.  The
# deterministic Frameskip wrapper below is the alternative if exact replay
# is intended.
env = StochasticFrameSkip(env, n=4, stickprob=0.25)

#env = Frameskip(env)
#env = TimeLimit(env, max_episode_steps=max_episode_steps)
env = SnesDiscretizer(env)
env = ProcessFrameMario (env)

timestep_limit=1e12

brute = Brute(env, max_episode_steps=max_episode_steps)
timesteps = 0
best_rew = float('-inf')
try:
    while True:
        # One iteration == one Brute episode; `timesteps` counts episodes.
        # (Previously the body was duplicated, running two episodes per
        # iteration that both printed the same counter.)
        acts, rew = brute.run()
        timesteps += 1
        print ("timesteps = ", timesteps)
        print (rew)
        if rew > best_rew:
            print("new best reward {} => {}".format(best_rew, rew))
            best_rew = rew
        if timesteps > timestep_limit:
            print("timestep limit exceeded")
            break
finally:
    # Release the emulator even on interrupt; gym-retro presumably refuses
    # a second emulator instance in the same process, so this lets the cell
    # be re-run -- TODO confirm for the installed retro version.
    env.close()



timesteps =  1
2
new best reward -inf => 2
timesteps =  1
4
new best reward 2 => 4
timesteps =  2
8
new best reward 4 => 8
timesteps =  2
18
new best reward 8 => 18
timesteps =  3
4
timesteps =  3
3
timesteps =  4
3
timesteps =  4
2
timesteps =  5
4
timesteps =  5
2
timesteps =  6
2
timesteps =  6
2
timesteps =  7
2
timesteps =  7
5
timesteps =  8
4
timesteps =  8
5
timesteps =  9
8
timesteps =  9
2
timesteps =  10
2
timesteps =  10
4
timesteps =  11
3
timesteps =  11
5
timesteps =  12
4
timesteps =  12
2
timesteps =  13
4
timesteps =  13
15
timesteps =  14
6
timesteps =  14
2
timesteps =  15
2
timesteps =  15
9
timesteps =  16
4
timesteps =  16
2
timesteps =  17
2
timesteps =  17
4
timesteps =  18
3
timesteps =  18
14
timesteps =  19
5
timesteps =  19
2
timesteps =  20
2
timesteps =  20
5
timesteps =  21
2
timesteps =  21
2
timesteps =  22
4
timesteps =  22
2
timesteps =  23
12
timesteps =  23
4
timesteps =  24
4
timesteps =  24
12
timesteps =  25
4
timesteps =  25
2
timesteps =  26
2