## MnistGym
In this OpenAI training gym environment, handwritten digits (0 through 9) are displayed to a reinforcement learning agent on a 128x128px canvas. A correct discrete value action for a matching observation receives a reward.

In [133]:
# Load dependencies.
import os
import cv2
import numpy as np
import tensorflow as tf
from tqdm.notebook import trange

import gym
from gym import spaces

from stable_baselines import DQN
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import CnnPolicy

import matplotlib.pyplot as plt
%matplotlib inline

In [137]:
# Import the MNIST data.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Define the gym environment.
class MnistGym(gym.Env):
    def __init__(self, width=128, height=128, channels=1, dataset):
        # Training dataset (Handwritten digits on a 28x28px canvas).
        self.X, self.y = dataset
        
        # Reset the state index, used to step through dataset.
        self.idx = 0
        
        # Digits 0-9 are valid actions.
        self.action_space = spaces.Discrete(10)
        
        # A 1-channel canvas is used for observations.
        self.observation_space = spaces.Box(low=0, high=255, shape=(width, height, channels), dtype=np.uint8)
    def _obs(self):
        # Return a frame at the target dimensions from self.X at the current state index for the CnnPolicy.
        width, height, channels = (self.observation_space.shape[0],
                                   self.observation_space.shape[1],
                                   self.observation_space.shape[2])
        obs = self.X[self.idx]
        
        # Enlarge the observation if the dataset is smaller than the target canvas.
        if obs.shape[0] < width or obs.shape[1] < height:
            obs = cv2.resize(obs, (width, height), interpolation = cv2.INTER_CUBIC)
            obs = obs.reshape(width, height, channels)
        return obs
    def step(self, action):
        # The agent earns 1 point for a correct label.
        reward = 1 if action == self.y[self.idx] else 0
        
        # The state index increments at each step then wraps around at the end of the training dataset.
        self.idx = self.idx + 1 if self.idx < len(self.X) - 1 else 0
        
        # Return the observation, earned reward, terminal state, and info dict.
        return self._obs(), reward, self.idx == 0, {}
    def reset(self):
        # Reset the index to the beginning of the training dataset and return the initial observation.
        self.idx = 0
        return self._obs()
    def render(self, action='', mode='human', close=False):
        # Display the labeled observation.
        width, height = self.observation_space.shape[0], self.observation_space.shape[1]
        fig, ax = plt.subplots(1)
        ax.imshow(self._obs().reshape(width, height), cmap='Greys')
        
        # Label with the correct value and action if supplied. 
        title = '{}-{}'.format(action, self.y[self.idx]) if action != '' else self.y[self.idx]
        ax.set_title(title)
        plt.show()

In [138]:
# Load the custom gym into a vectorized environment.
env = DummyVecEnv([lambda: MnistGym(width=128, height=128, channels=1, dataset=(X_train, y_train))])

# Grab the observation shape for generating evaluation frames.
width, height = env.observation_space.shape[0], env.observation_space.shape[1]

In [153]:
def create_model(pretrained=False, save_model=True, epochs=2):
    model_name = "dqn_cnn_mnist"
    
    # Return a pretrained model if the flag is set. Otherwise, train a new model.
    if pretrained:
        return DQN.load(model_name)

    # Create a model from a DQN agent with a CnnPolicy attached to a tensorboard logger.
    model = DQN(CnnPolicy, env, verbose=1, tensorboard_log="./mnist_log")

    # Train the model on several epochs through the full training dataset.
    model.learn(total_timesteps=len(X_train) * epochs)
    
    # Save the new model if the flag is set.
    if save_model:
        model.save(model_name)

    return model

model = create_model(pretrained=False, save_model=True, epochs=2)

In [149]:
model.save("dqn_cnn_mnist")

In [151]:
# Evaluate the model by counting the total rewards attained on the test dataset.
total_rewards = 0

for idx in trange(len(X_test)):
    # Generate an evaluation observation frame.
    obs = cv2.resize(X_test[idx], (width, height), interpolation = cv2.INTER_CUBIC)
    obs = obs.reshape(width, height, 1)
    
    # Predict an action based on the observation.
    action, _states = model.predict(obs)

    # Score the prediction.
    reward = 1 if action == y_test[idx] else 0
    total_rewards += reward

print('Accuracy: {:.2f}%'.format(total_rewards / len(X_test) * 100.0))

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))


Accuracy: 98.00%
