In [None]:
# This code must be run with IPython; alternatively, open the .ipynb file in IPython notebook mode.
from TrainCatch import CatchEnvironment
from TrainCatch import Qnet
from IPython import display
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pylab as pl
import time
import torch
import torch.nn as nn
import numpy as np
import datetime


# --- Evaluation settings ---
gridSize = 10            # Size of the grid the agent plays the game on.
maxGames = 100           # Number of games the agent will play.
GAME_VISIBLITY = False   # Display option: render every step when True.

# --- Running statistics, updated by the game loop ---
score = 0
max_score = 0
total_score = 0
numberOfGames = 0

# Game environment, constructed with the grid size.
env = CatchEnvironment(gridSize)

# --- Drawing surface shared with drawState() ---
# Baseline coordinate drawState() uses to place the arena and the basket.
ground = 1
plot = pl.figure(figsize=(12, 12))
axis = plot.add_subplot(1, 1, 1, aspect='equal')
axis.set_xlim(-1, 12)
axis.set_ylim(0, 12)


#----------------------------
# Visualize the game arena
#----------------------------
def drawState(fruitRow, fruitColumn, basket):
  """Render one frame of the game into the shared matplotlib axis.

  Args:
    fruitRow: Row of the fruit in matrix-style indexing; it is flipped
      into a y coordinate for plotting.
    fruitColumn: Column of the fruit, used directly as the x coordinate.
    basket: x coordinate on which the 2-unit-wide basket is centred.

  Reads the module-level ``gridSize``, ``ground``, ``axis`` and the score
  statistics for the title. Returns nothing.
  """
  # Column is the x axis.
  fruitX = fruitColumn
  # Invert matrix-style rows into plot coordinates.
  fruitY = gridSize - fruitRow + 1

  average = total_score / float(numberOfGames) if numberOfGames > 0 else 0
  statusTitle = "Score: {:03d}, Max: {:03d}, Avg: {:03.1f}, TotalGame: {:03d}".format(
      score, max_score, average, numberOfGames)
  axis.set_title(statusTitle, fontsize=30)

  # Remove the shapes of the previous frame first. Without this, every call
  # stacks three more patches on the axis, so the artist list grows for the
  # whole run and each redraw gets slower.
  for stale in list(axis.patches):
    stale.remove()

  frame = [
    # Arena background.
    patches.Rectangle(
        (ground - 1, ground), 11, 10,
        facecolor="#000000"      # black
    ),
    # Basket, sitting on the ground line.
    patches.Rectangle(
        (basket - 1, ground), 2, 0.5,
        facecolor="#FF0000"      # red
    ),
    # Fruit, a unit square centred on (fruitX, fruitY).
    patches.Rectangle(
        (fruitX - 0.5, fruitY - 0.5), 1, 1,
        facecolor="#FF0000"      # red
    ),
  ]
  for shape in frame:
    axis.add_patch(shape)

  # Replace the previously shown figure in the notebook output.
  display.clear_output(wait=True)
  display.display(pl.gcf())


#----------------------------
# Write Log
#----------------------------
def write_log(filename, content):
  """Append one line of text to a log file under ``./logs/``.

  Args:
    filename: Name of the log file, relative to ``./logs/``.
    content: Text to append; a trailing newline is added.
  """
  import os  # local import so the function does not rely on file-level imports

  path = './logs/' + filename
  # Create the log directory on first use instead of failing with
  # FileNotFoundError when it does not exist yet.
  os.makedirs(os.path.dirname(path), exist_ok=True)
  with open(path, 'a') as f:
    f.write(content + '\n')


# Restore the trained weights from disk.
q = Qnet()
# map_location='cpu' lets a checkpoint saved from a GPU load on a machine
# without one; the game loop feeds the network CPU tensors.
q.load_state_dict(torch.load('./weights/Train_Catch.pth', map_location='cpu'))
# Switch to inference mode. This only matters if Qnet contains layers that
# behave differently in training (e.g. dropout or batch norm).
q.eval()
print('Load model!')

# Timestamped name of the log file for this test run.
now_dt = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = "test_{}.log".format(now_dt)

# Start playing: run maxGames games and log the score of each one.
while numberOfGames < maxGames:
  # The initial state of the environment.
  score = 0
  isGameOver = False
  fruitRow, fruitColumn, basket = env.reset()
  if GAME_VISIBLITY:
    drawState(fruitRow, fruitColumn, basket)

  # The game lasts until the environment reports game over, or the score
  # reaches 999.
  while not isGameOver and score < 999:
    # Forward the current state through the network. No gradients are
    # needed for evaluation, so skip building the autograd graph.
    s = np.array([fruitRow, fruitColumn, basket])
    with torch.no_grad():
      out = q(torch.from_numpy(s).float())
    # Find the max index (the chosen action).
    action = out.argmax().item()
    nextState, reward, gameOver = env.act(action)
    fruitRow = nextState[0]
    fruitColumn = nextState[1]
    basket = nextState[2]

    # Count game results: a reward of 1 scores a point.
    caught = (reward == 1)
    if caught:
      score += 1
    # Not read again in this script; kept so the global stays defined.
    currentState = nextState
    isGameOver = gameOver

    if GAME_VISIBLITY:
      drawState(fruitRow, fruitColumn, basket)
      time.sleep(0.01)

    if caught:
      # NOTE(review): fruitRow/fruitColumn are not refreshed after this call,
      # so the next forward pass presumably sees the pre-reset fruit
      # position - confirm against CatchEnvironment.resetFruit().
      env.resetFruit()

  # Write test score
  logmsg = "game : {}, score : {}".format(numberOfGames + 1, score)
  print(logmsg)
  write_log(filename, logmsg)

  # Update score stats
  total_score += score
  if max_score < score:
    max_score = score
  numberOfGames += 1


# Clear the display
if GAME_VISIBLITY:
  display.clear_output(wait=True)