In [1]:
from torch import nn
import torch
import itertools
import numpy as np
import random
import gym
import os
import sys
from collections import deque
from sklearn.manifold import TSNE
import matplotlib
#matplotlib.use('TkAgg')
from matplotlib import cm
import matplotlib.pyplot as plt

sys.path.append('../')
from model import DQN
from model_utility import init_weights

In [2]:
from make_env import *

In [3]:
# Name of the game handed to make_env() when the environment is built.
GAME = 'Boxing'
# Checkpoint file that load_model() reads with torch.load.
SAVED_MODEL_PATH = '../boxing_final_experiment/boxing_model'
# Device used for the model and for observations. The commented-out line is the
# "CUDA if available" alternative; the notebook currently pins everything to CPU.
#DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cpu")

In [4]:
# Build the environment for GAME (make_env comes from the star import of make_env).
env = make_env(GAME)
# First two entries of the observation shape; passed to DQN(...) in load_model().
# NOTE(review): presumably (height, width) -- confirm against make_env's wrappers.
image_size = env.observation_space.shape[:2]
# Number of actions in the environment's action space; also passed to DQN(...).
action_size = env.action_space.n

Boxing


In [5]:
def load_model():
    """Build a DQN, load the saved checkpoint into it and return it with the env.

    Reads the module-level names ``image_size``, ``action_size``, ``DEVICE``,
    ``SAVED_MODEL_PATH`` and ``env``.

    Returns:
        tuple: ``(model, env)`` -- the network with the checkpoint's state dict
        loaded, and the module-level ``env`` object as-is.
    """
    network = DQN(image_size, action_size).to(DEVICE)
    # init_weights runs first; load_state_dict below then replaces the
    # parameters with the ones stored in the checkpoint.
    network.apply(init_weights)
    # NOTE(review): torch.load unpickles the file -- only point
    # SAVED_MODEL_PATH at checkpoints from a trusted source.
    network.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=DEVICE))
    # NOTE(review): the network is returned in training mode; if DQN contains
    # dropout/batch-norm layers, callers may want network.eval() -- confirm.
    return network, env

In [6]:
def env_reset(env):
    """Reset ``env`` and return its first observation as a batched tensor.

    The array returned by ``env.reset()`` has its last axis moved to the front
    (axes ``(2, 0, 1)``), is converted to float32, gets a leading batch axis of
    size 1 and is placed on the module-level ``DEVICE``.

    Args:
        env: environment whose ``reset()`` returns a 3-D numpy array.

    Returns:
        torch.Tensor: float32 tensor of shape ``(1, d2, d0, d1)`` on ``DEVICE``,
        where ``(d0, d1, d2)`` is the shape of the array from ``env.reset()``.
    """
    first_frame = env.reset().transpose(2, 0, 1)
    batched = torch.from_numpy(first_frame).float().unsqueeze(0)
    return batched.to(DEVICE)

In [7]:
def transform_obs(obs):
    """Convert one numpy observation into a batched float32 tensor.

    Moves the last axis of ``obs`` to the front (axes ``(2, 0, 1)``), converts
    to float32 and prepends a batch axis of size 1. The result is not moved to
    any particular device.

    Args:
        obs: 3-D numpy array, e.g. an image frame returned by ``env.step``.

    Returns:
        torch.Tensor: float32 tensor of shape ``(1, d2, d0, d1)`` for an input
        of shape ``(d0, d1, d2)``.
    """
    channels_first = obs.transpose(2, 0, 1)
    as_float = torch.from_numpy(channels_first).float()
    return as_float.unsqueeze(0)

def create_random_obs(env, env_reset, num_frames, num_samples, device):
    """Collect transitions by stepping ``env`` with randomly sampled actions.

    Every stored observation is a float32 tensor on ``device``. Previously that
    converted tensor was computed each step but only pushed into a deque that
    was never read, so the returned transitions held the unconverted tensors
    (reset observations on one device, stepped observations on another).

    Args:
        env: environment exposing ``action_space.sample()`` and a
            ``step(action)`` that returns ``(obs, reward, done, info)``.
        env_reset: callable ``env_reset(env)`` returning the first observation
            of a new episode as a tensor (anything ``torch.as_tensor`` accepts).
        num_frames: unused; kept so existing callers keep working.
        num_samples: number of environment steps, i.e. transitions, to record.
        device: torch device on which the stored observations are placed.

    Returns:
        list: ``num_samples`` tuples ``(obs, action, reward, done, new_obs)``.
    """
    def _on_device(observation):
        # One place for the dtype/device coercion applied to every observation.
        return torch.as_tensor(observation, dtype=torch.float32).to(device)

    samples = []
    obs = _on_device(env_reset(env))
    for _ in range(num_samples):
        action = env.action_space.sample()
        new_obs, rew, done, _info = env.step(action)
        new_obs = _on_device(transform_obs(new_obs))
        samples.append((obs, action, rew, done, new_obs))
        # When the episode ends, the next transition starts from a fresh reset
        # rather than from the terminal observation.
        obs = _on_device(env_reset(env)) if done else new_obs
    return samples

In [8]:
def action_for_obs(samples, network, device):
    """Run ``network`` on each sample's observation and record its greedy action.

    The forward passes run under ``torch.no_grad()``: this is inference only,
    so no autograd graph is built or kept alive by the returned tensors.

    Args:
        samples: iterable of transitions whose first element is the observation
            (a tensor, or anything ``torch.as_tensor`` accepts).
        network: callable mapping an observation tensor to Q-values.
        device: torch device the observation is moved to before the forward pass.

    Returns:
        list: one ``(q_values, pred_action)`` pair per sample. ``q_values`` is
        the network output with a leading size-1 axis squeezed away;
        ``pred_action`` is a 0-dim tensor holding the argmax over all elements
        of the network output.
        NOTE(review): that equals the action index only when the output is
        ``(1, num_actions)`` -- confirm for batched inputs.
    """
    q_values_action = []
    with torch.no_grad():
        for sample in samples:
            obs = torch.as_tensor(sample[0], dtype=torch.float32).to(device)
            q_values = network(obs)
            pred_action = torch.argmax(q_values)
            q_values_action.append((q_values.squeeze(0), pred_action))
    return q_values_action

In [12]:
def create_tsne_plot(q_values_action, random_state=None, out_path='boxing_tsne.png'):
    """Project Q-value vectors to 2-D with t-SNE and scatter them by action.

    One scatter series is drawn per action index ``0 .. num_actions - 1``,
    where ``num_actions`` is the length of a Q-value vector. The figure is
    saved to ``out_path`` and then shown. The number of points is printed.

    Args:
        q_values_action: non-empty list of ``(q_values, action)`` pairs, as
            returned by ``action_for_obs``. ``q_values`` is a 1-D tensor (any
            device); ``action`` is a 0-dim tensor or an int.
        random_state: forwarded to ``TSNE``; ``None`` keeps the unseeded
            behaviour, an int makes the projection reproducible.
        out_path: file the figure is written to.

    Raises:
        ValueError: if ``q_values_action`` is empty.
    """
    if len(q_values_action) == 0:
        raise ValueError("create_tsne_plot needs at least one (q_values, action) pair")

    # .cpu() keeps this working for CUDA tensors, where .numpy() alone raises;
    # int() likewise turns a 0-dim tensor on any device into a plain int.
    q_values = np.stack([q.detach().cpu().numpy() for q, _ in q_values_action])
    actions = np.array([int(act) for _, act in q_values_action])

    tsne_proj = TSNE(n_components=2, random_state=random_state).fit_transform(q_values)
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which recent Matplotlib
    # releases no longer provide.
    cmap = plt.get_cmap('tab20')

    fig, ax = plt.subplots(figsize=(8, 8))
    num_actions = q_values.shape[1]
    print(len(actions))

    for lab in range(num_actions):
        indices = actions == lab
        ax.scatter(tsne_proj[indices, 0], tsne_proj[indices, 1],
                   color=cmap(lab),
                   label=lab,
                   alpha=.5)

    ax.set_xlabel('t-SNE dimension 1')
    ax.set_ylabel('t-SNE dimension 2')
    ax.set_title('t-SNE of Q-values, coloured by greedy action')
    ax.legend(title='action', fontsize='large', markerscale=2)
    fig.savefig(out_path)
    plt.show()

In [13]:
def run(network, env, env_reset, device, num_frame=100000, num_samples=1000):
    """Sample random transitions, score them with ``network`` and plot the t-SNE.

    Chains ``create_random_obs`` -> ``action_for_obs`` -> ``create_tsne_plot``.

    Args:
        network: model passed to ``action_for_obs`` to compute Q-values.
        env: environment that ``create_random_obs`` steps through.
        env_reset: callable ``env_reset(env)`` forwarded to ``create_random_obs``.
        device: torch device forwarded to both helper calls.
        num_frame: forwarded as ``num_frames`` to ``create_random_obs``.
        num_samples: number of transitions to collect.

    Returns:
        None.
    """
    transitions = create_random_obs(env, env_reset, num_frame, num_samples, device)
    scored = action_for_obs(transitions, network, device)
    create_tsne_plot(scored)

In [14]:
# Load the trained network; load_model() also hands back the module-level env,
# so `env` is rebound to the same environment object here.
model, env = load_model()
# Collect random-play transitions, compute their Q-values and draw the t-SNE
# plot, using run()'s default num_frame / num_samples.
run(model, env, env_reset, DEVICE)

1000


  plt.show()
