In [1]:
from torch import nn
import torch
import itertools
import numpy as np
import random
import gym
import os
from collections import deque
from sklearn.manifold import TSNE
from matplotlib import cm
import matplotlib.pyplot as plt

In [2]:
def _reset_obs(env):
    """Reset ``env`` and return only the observation.

    Supports both the classic gym API (``reset() -> obs``) and the
    gym>=0.26 / gymnasium API (``reset() -> (obs, info)``).
    """
    out = env.reset()
    # NOTE(review): an old-API env whose observation is itself a 2-tuple whose
    # second element is a dict would be misread here; no such env is expected.
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def _step(env, action):
    """Step ``env`` and return ``(obs, reward, done, info)`` for either gym step API."""
    out = env.step(action)
    if len(out) == 5:
        # gym>=0.26 / gymnasium: (obs, reward, terminated, truncated, info).
        new_obs, rew, terminated, truncated, info = out
        return new_obs, rew, terminated or truncated, info
    return out


def create_random_obs(env, num_frames, num_samples, device):
    """Collect ``num_samples`` transitions from ``env`` using uniformly random actions.

    The environment is reset once at the start and again whenever an episode ends.

    Args:
        env: gym-style environment exposing ``reset``, ``step`` and ``action_space.sample``.
        num_frames: unused; kept so existing callers keep working.
        num_samples: number of transitions to collect.
        device: unused; kept so existing callers keep working.

    Returns:
        List of ``(obs, action, reward, done, new_obs)`` tuples, in collection order.
    """
    obs = _reset_obs(env)
    samples = []
    for _ in range(num_samples):
        action = env.action_space.sample()
        new_obs, rew, done, _info = _step(env, action)
        samples.append((obs, action, rew, done, new_obs))
        # The next transition starts from a fresh episode once this one ended.
        obs = _reset_obs(env) if done else new_obs
    return samples

In [3]:
def action_for_obs(samples, network, device):
    """Evaluate ``network`` on the starting observation of every sampled transition.

    Args:
        samples: sequence of transitions (e.g. from ``create_random_obs``). Only
            element ``[0]`` (the observation) of each transition is used.
        network: callable mapping an observation tensor to a Q-value tensor.
        device: torch device the observation is moved to before the forward pass.

    Returns:
        List of ``(q_values, pred_action)`` pairs, one per sample. ``pred_action``
        is ``torch.argmax(q_values)``, an index into the flattened Q-value tensor.
        Tensors stay on ``device`` and do not require grad.
    """
    q_values_action = []
    # Inference only: no_grad avoids building (and keeping alive) an autograd
    # graph for every forward pass.
    with torch.no_grad():
        for sample in samples:
            obs = torch.as_tensor(sample[0], dtype=torch.float32).to(device)
            q_values = network(obs)
            pred_action = torch.argmax(q_values)
            q_values_action.append((q_values, pred_action))
    return q_values_action

In [4]:
def create_tsne_plot(q_values_action, random_state=None):
    """Embed Q-value vectors in 2-D with t-SNE and scatter them, coloured by predicted action.

    Args:
        q_values_action: list of ``(q_values, pred_action)`` pairs as returned by
            ``action_for_obs``. ``q_values`` are torch tensors on any device;
            ``pred_action`` must be convertible with ``int()``.
        random_state: optional seed forwarded to ``TSNE`` for a reproducible layout.

    Displays the figure with ``plt.show()``; returns ``None``.
    """
    # .cpu() is required before .numpy() when the network ran on a GPU; flatten
    # so outputs shaped (1, n) still give one row per sample.
    q_values = np.stack(
        [q.detach().cpu().numpy().reshape(-1) for q, _ in q_values_action]
    )
    # int() works for Python ints and 0-d tensors on any device.
    actions = np.array([int(act) for _, act in q_values_action])
    num_actions = q_values.shape[1]

    # sklearn requires perplexity < n_samples; shrink the default (30) for small runs.
    perplexity = min(30.0, max(len(q_values) - 1, 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    tsne_proj = tsne.fit_transform(q_values)
    # cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('tab20')

    fig, ax = plt.subplots(figsize=(8, 8))
    for lab in range(num_actions):
        indices = actions == lab
        ax.scatter(tsne_proj[indices, 0], tsne_proj[indices, 1],
                   # A single RGBA row, so matplotlib does not colormap the values.
                   c=np.array(cmap(lab)).reshape(1, 4),
                   label=lab,
                   alpha=.5)

    ax.set(title='t-SNE of Q-values (coloured by predicted action)',
           xlabel='t-SNE dim 1', ylabel='t-SNE dim 2')
    ax.legend(fontsize='large', markerscale=2)
    plt.show()

In [5]:
def run(network, env, device, num_frame=100000, num_samples=1000):
    """Collect random rollouts, score them with ``network`` and show the t-SNE plot.

    Args:
        network: Q-network applied to each collected observation.
        env: gym-style environment to roll out with random actions.
        device: torch device used for the network forward passes.
        num_frame: forwarded to ``create_random_obs`` as ``num_frames``.
        num_samples: number of transitions to collect and embed.
    """
    create_tsne_plot(
        action_for_obs(
            create_random_obs(env, num_frame, num_samples, device),
            network,
            device,
        )
    )