Skip to content

Commit

Permalink
Fix save/load agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Johan Backman committed Jun 10, 2019
1 parent f16d82b commit 80a5725
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 10 deletions.
2 changes: 1 addition & 1 deletion framework/devices/dolphin/dolphin.py
Expand Up @@ -88,7 +88,7 @@ def __create_fifo_pipe(self, fifo_name: Text) -> Text:
pipes_dir.mkdir()

if fifo_path.exists():
return
return fifo_path

fifo_path = str(fifo_path)
try:
Expand Down
2 changes: 1 addition & 1 deletion framework/devices/dolphin/dolphin_pad.py
Expand Up @@ -64,7 +64,7 @@ def connect(self):
self.pipe = open(self.path, 'w', buffering=1)

def _send_to_pipe(self, msg):
log.info(msg)
# log.info(msg)
current_time = time.time()
sleep_time = self.last_command_time + self.MIN_COOLDOWN - current_time
if sleep_time > 0:
Expand Down
35 changes: 29 additions & 6 deletions framework/games/ssbm/ssbm.py
@@ -1,6 +1,7 @@

import logging
import time
import copy
from struct import pack, unpack
from typing import List

Expand All @@ -10,7 +11,8 @@
from framework.devices.device import Device
from framework.games.game import Game
from framework.games.ssbm.ssbm_menu_helper import SSBMMenuHelper
from framework.games.ssbm.ssbm_observation import SSBMObservation
from framework.games.ssbm.ssbm_observation import SSBMObservation, Position
from framework.games.ssbm.ssbm_reward import SimpleSSBMReward

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,6 +46,8 @@ def __init__(self, device: Device, agents: List[Agent], sampling_window: float):
self.sampling_window = sampling_window
self.machine = self._build_state_machine()
self.menu_helper = SSBMMenuHelper(self.device.pad)
self.reward_calculator = SimpleSSBMReward()
self.all_obervations = []
self.reset_state()

def _build_state_machine(self):
Expand Down Expand Up @@ -77,7 +81,7 @@ def run(self):

while True:
address, value = self.device.read_state()
if address is not None: # None if socket times out first
if address is not None: # None if socket times out first
self.update_observation(address, value)

# Wait for stock information
Expand All @@ -89,9 +93,10 @@ def run(self):
new_time = time.time()
address, value = self.device.read_state()

if address is not None: # None if socket times out first
if address is not None: # None if socket times out first
self.update_observation(address, value)
if self._is_done():
self.save_agents()
self.finish_game()
break

Expand All @@ -103,18 +108,36 @@ def reset_state(self):
self.current_observation = SSBMObservation()
self.last_update = time.time()
self.frame_counter = 0
self.all_obervations = []

def reset(self):
self.device.pad.reset_button_state()
self.reset_state()
self.restart_game()

def save_agents(self):
for agent in self.agents:
agent.save()

def update_agents(self, observation):
for agent in self.agents:
action = agent.act(observation, self.frame_counter)
print(f"Should take action: {action}")
# print(f"Should take action: {action}")
self.device.set_button_state(action.as_slippi_bitmask())

if self.all_obervations:
reward = self.reward_calculator.cost(
observation, self.all_obervations, self.frame_counter)
agent.learn(self.all_obervations[-1],
observation, action, reward, 0.0)
print(f'Reward: {reward}')

previous_observation = SSBMObservation(
Position(observation.player_x, observation.player_y),
Position(observation.enemy_x, observation.enemy_y),
observation.player_stocks, observation.enemy_stocks,
observation.player_percent, observation.enemy_percent)
self.all_obervations.append(previous_observation)
self.frame_counter += 1

def update_observation(self, address, buffer):
Expand All @@ -128,7 +151,7 @@ def update_observation(self, address, buffer):

def _is_done(self):
return self.current_observation.player_stocks <= 0 or \
self.current_observation.enemy_stocks <= 0
self.current_observation.enemy_stocks <= 0

def on_enter_start_menu(self):
self.menu_helper.go_to_character_select()
Expand Down Expand Up @@ -156,4 +179,4 @@ def on_exit_character_preselection(self):
time.sleep(2)

def on_enter_game_done(self):
time.sleep(5) # Wait for stats screen to appear
time.sleep(5) # Wait for stats screen to appear
6 changes: 4 additions & 2 deletions smashrl/ssbm_agent.py
@@ -1,6 +1,7 @@
"""The actual RL Agent"""

import logging
import os
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -54,12 +55,13 @@ def learn(self, observation: SSBMObservation,
return self.q.train(observations, observations_next, actions, [reward], [done])

def load(self, path='./trained_dqn/dqn.ckpt'):
if not Path(path).exists():
if not Path(os.path.dirname(path)).exists():
log.info(
f'No pre-trained agent found in {path}... Running new model')
return

self.q.load()
self.q.load(path)

def save(self, path='./trained_dqn/dqn.ckpt'):
log.info('Saving agent...')
return self.q.save(path)

0 comments on commit 80a5725

Please sign in to comment.