# Street Fighter Tester Notebook
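This notebook loads a trained PPO checkpoint and lets it play Street Fighter II against a human on a joystick, so models can be tested side by side while training runs are in progress.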

In [1]:
import retro

In [2]:
from gym import Env
from gym.spaces import Discrete, Box, MultiBinary
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pygame


pygame 2.1.3 (SDL 2.0.22, Python 3.7.17)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
class StreetFighter(Env):
    """Two-player Street Fighter II (Genesis) env for pitting a model against a human.

    Observations are 84x84x1 grayscale frames (see ``preprocess``). Actions are a
    24-element MultiBinary vector: elements 0-11 are the first player's buttons and
    12-23 the second player's, each ordered
    ``B, A, MODE, START, UP, DOWN, LEFT, RIGHT, C, Y, X, Z``.
    A joystick read through pygame supplies the human player's buttons.
    """

    # Health both fighters are assumed to start a round with; reset() uses it as the
    # baseline for the health-based reward computed in step().
    START_HEALTH = 176

    # (joystick button index, action index) pairs. Action indices follow
    #  0    1    2       3       4      5       6       7        8    9    10   11
    # ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
    _JOYSTICK_BUTTON_MAP = (
        (0, 0),   # -> B
        (1, 1),   # -> A
        (2, 9),   # -> Y
        (3, 10),  # -> X
        (4, 4),   # -> UP
        (5, 5),   # -> DOWN
        (6, 6),   # -> LEFT
        (7, 7),   # -> RIGHT
        (8, 8),   # -> C
        (9, 3),   # -> START
    )
    # Stick deflection beyond which a direction counts as pressed.
    _AXIS_THRESHOLD = 0.5

    def __init__(self):
        super().__init__()
        self.observation_space = Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
        # self.action_space = MultiBinary(12)  # one player
        self.action_space = MultiBinary(24)  # two players

        # Initialise pygame and require a joystick BEFORE creating the emulator: if we
        # raised after retro.make(), the emulator would be left alive. NOTE: gym-retro
        # is understood to allow only one emulator per process, so a leaked one would
        # make every retry of StreetFighter() fail.
        pygame.init()
        pygame.joystick.init()
        if pygame.joystick.get_count() == 0:
            # RuntimeError is a subclass of Exception, so existing handlers still catch it.
            raise RuntimeError("Nenhum controle conectado.")
        self.joystick = pygame.joystick.Joystick(0)
        self.joystick.init()

        # ref.: https://retro.readthedocs.io/en/latest/python.html
        # retro.Actions.FILTERED is fine for model training; retro.Actions.ALL is
        # used here for model-vs-player.
        self.game = retro.make(
            game='StreetFighterIISpecialChampionEdition-Genesis',
            players=2,
            use_restricted_actions=retro.Actions.ALL,
        )

    def step(self, actions):
        """Advance the emulator one step with both players' buttons.

        Args:
            actions: 24 button values; 0-11 for the first player and 12-23 for the
                second (per-player order as in the class docstring).

        Returns:
            ``(obs, reward, done, info)``. ``obs`` is the preprocessed frame and
            ``reward`` is twice the damage dealt to the enemy plus the change in
            own health since the previous step; retro's own reward is discarded.
        """
        obs, _, done, info = self.game.step(actions)
        obs = self.preprocess(obs)

        # Damage dealt counts double; damage taken (negative health delta) is subtracted.
        reward = (self.enemy_health - info['enemy_health']) * 2 + (info['health'] - self.health)
        self.health = info['health']
        self.enemy_health = info['enemy_health']

        return obs, reward, done, info

    def render(self, *args, **kwargs):
        """Forward rendering to the underlying retro environment."""
        self.game.render(*args, **kwargs)

    def reset(self):
        """Reset the game, restore both health baselines and return the first frame."""
        obs = self.preprocess(self.game.reset())
        # Not read anywhere in this class; kept so external code relying on it still works.
        self.previous_frame = obs
        self.health = self.START_HEALTH
        self.enemy_health = self.START_HEALTH
        return obs

    def preprocess(self, observation):
        """Convert a raw frame to an (84, 84, 1) grayscale image."""
        # NOTE(review): COLOR_BGR2GRAY treats channel 0 as blue. If retro frames are RGB
        # this swaps the red/blue weights. Left unchanged so inputs presumably match what
        # the loaded checkpoint was trained on; confirm before switching to COLOR_RGB2GRAY.
        gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_CUBIC)
        return np.reshape(resized, (84, 84, 1))

    def _read_joystick_action(self):
        """Poll the joystick and return its state as a (1, 12) float array of 0.0/1.0.

        Inputs the controller does not have (fewer buttons or axes, no hat) are
        skipped instead of raising a pygame error.
        """
        pygame.event.pump()  # refresh pygame's cached joystick state
        joystick = self.joystick
        action = np.zeros(12)

        num_buttons = joystick.get_numbuttons()
        for button, index in self._JOYSTICK_BUTTON_MAP:
            if button < num_buttons and joystick.get_button(button):
                action[index] = 1.0

        # Axis 0: negative -> LEFT, positive -> RIGHT.
        # Axis 1: negative -> UP, positive -> DOWN.
        num_axes = joystick.get_numaxes()
        if num_axes > 0:
            x_axis = joystick.get_axis(0)
            if x_axis < -self._AXIS_THRESHOLD:
                action[6] = 1.0
            if x_axis > self._AXIS_THRESHOLD:
                action[7] = 1.0
        if num_axes > 1:
            y_axis = joystick.get_axis(1)
            if y_axis < -self._AXIS_THRESHOLD:
                action[4] = 1.0
            if y_axis > self._AXIS_THRESHOLD:
                action[5] = 1.0

        # Hat 0: x = -1/+1 -> LEFT/RIGHT, y = -1/+1 -> DOWN/UP.
        if joystick.get_numhats() > 0:
            hat_x, hat_y = joystick.get_hat(0)
            if hat_x == -1:
                action[6] = 1.0
            if hat_x == 1:
                action[7] = 1.0
            if hat_y == -1:
                action[5] = 1.0
            if hat_y == 1:
                action[4] = 1.0

        # (1, n) batch shape, the same format as the model's actions, so they can be concatenated.
        return action.reshape(1, -1)

    def get_player1_action(self):
        """Return the joystick's current buttons as a (1, 12) action array."""
        return self._read_joystick_action()

    def get_player2_action(self):
        """Return the joystick's current buttons as a (1, 12) action array.

        Previously declared without ``self`` and therefore uncallable; it reads the
        same joystick as ``get_player1_action``.
        """
        return self._read_joystick_action()

    def close(self):
        """Close the retro environment and shut pygame down."""
        self.game.close()
        pygame.quit()


In [4]:
import time
# Import PPO for algos
from stable_baselines3 import PPO
# Evaluate Policy
from stable_baselines3.common.evaluation import evaluate_policy
# Import Wrappers
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage

In [5]:
# Checkpoint that plays against the human. Earlier candidates, kept for reference:
#   ./train_nodelta/best_model_770000.zip
#   ./train_nodelta/best_model_170000.zip
#   ./train/best_model_100000.zip
#   ./train_nodelta/best_model_70000.zip
# Original note: "5.46m is the model that performed best 6.16m pretty good as well".
MODEL_PATH = './train_nodelta/best_model_3260000.zip'
model = PPO.load(MODEL_PATH)

In [6]:
def make_env():
    """Build one monitored StreetFighter env (a factory, as DummyVecEnv expects)."""
    return Monitor(StreetFighter())


# Vectorise the single env and stack the last 4 frames on the channel axis,
# turning each (84, 84, 1) observation into (84, 84, 4) per env.
env = VecFrameStack(DummyVecEnv([make_env]), 4, channels_order='last')

In [7]:
# Optional: watch the loaded model play using only its own predicted actions.
# NOTE: despite its name, `do_train` trains nothing; it only runs the model.
do_train = False
n_episodes = 1
if do_train:
    for episode in range(n_episodes):
        obs = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            action, _ = model.predict(obs)
            obs, reward, done, info = env.step(action)
            env.render()
            time.sleep(0.01)
            # The vectorised env returns one reward per sub-env; there is a single env here.
            total_reward += float(reward[0])
        print('Total Reward for episode {} is {}'.format(episode, total_reward))
        time.sleep(2)
    # The env is deliberately not closed here: the joystick game cell below reuses
    # `env` and closes it itself.

In [8]:
# para manter padrao das acoes onde o model foi treinado
# Keep the same button layout the model was trained with.
game_buttons = env.envs[0].game.buttons  # button names reported by the emulator
num_buttons = len(game_buttons)
# NOTE(review): game.buttons may list only ONE player's buttons (TODO confirm), in which
# case halving it would give 6. Derive the per-player count from the MultiBinary(24)
# action space shared by the two players instead.
num_buttons_per_player = env.action_space.n // 2
# NOTE(review): presumably the last action batch stored by DummyVecEnv (None before the
# first step). It is NOT a list of filtered/allowed actions. Unused below.
filtered_actions = env.actions

# Index -> name for one player's 12 buttons; read by display_buttons_status().
# ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
buttons = {
    0: "B", 1: "A", 2: "MODE", 3: "START",
    4: "UP", 5: "DOWN", 6: "LEFT", 7: "RIGHT",
    8: "C", 9: "Y", 10: "X", 11: "Z"
}

def display_buttons_status(actions, label, names=None):
    """Format one player's button states as a single status line.

    Args:
        actions: batch of button vectors; only the first row (``actions[0]``) is
            shown, e.g. the (1, 12) arrays from the joystick reader or the model.
        label: prefix for the line, e.g. ``'Model'`` or ``'Player'``.
        names: optional index -> button-name mapping. Defaults to the notebook's
            global ``buttons`` dict, which keeps existing calls unchanged.

    Returns:
        ``"<label>: <name>: <0|1>, ..."``; a button shows ``1`` only when its value equals 1.
    """
    if names is None:
        names = buttons
    status = ", ".join(
        f"{names[idx]}: {'1' if value == 1 else '0'}"
        for idx, value in enumerate(actions[0])
    )
    return f"{label}: {status}"


In [None]:
import torch
import numpy as np
import time
import pygame

# Reset the environment and get the first (frame-stacked) observation.
obs = env.reset()

# Check pygame is initialised BEFORE touching its joystick subsystem
# (StreetFighter.__init__ is what calls pygame.init()).
if not pygame.get_init():
    raise SystemExit("Erro: o sistema de vídeo do pygame não foi inicializado corretamente.")

# Report which controller drives the human player.
if pygame.joystick.get_count() != 0:
    print(f"Controle detectado: {pygame.joystick.Joystick(0).get_name()}")

do_game = True

if do_game:
    try:
        done = False
        while not done:
            env.render()

            # The model fills buttons 0-11 (first player).
            action_p1, _ = model.predict(obs)
            # Force START released every frame (presumably so the model cannot pause the game).
            action_p1[0][3] = 0

            # The human on the joystick fills buttons 12-23 (second player);
            # get_player1_action only reads the joystick, it does not choose the slot.
            action_p2 = env.envs[0].get_player1_action()

            combined_actions = np.concatenate((action_p1, action_p2), axis=1)

            # Live one-line status of both players' buttons ('\r' overwrites the line).
            print(f"{display_buttons_status(action_p1, 'Model')} {display_buttons_status(action_p2, 'Player')}", end='\r')

            obs, reward, done, info = env.step(combined_actions)

            time.sleep(0.01)  # throttle the loop (~10 ms extra per step)
    finally:
        # Single cleanup path: release the emulator and pygame even if the loop is interrupted.
        env.close()
        pygame.quit()


Controle detectado: Xbox One Controller
Model: B: 0, A: 0, MODE: 1, START: 0, UP: 1, DOWN: 1, LEFT: 0, RIGHT: 1, C: 0, Y: 1, X: 1, Z: 0 Player: B: 0, A: 0, MODE: 0, START: 0, UP: 0, DOWN: 0, LEFT: 0, RIGHT: 0, C: 0, Y: 0, X: 0, Z: 0