# First, a quick demo to show that the trained agents learned reasonable policies

In [2]:
%matplotlib
import sys
sys.path.append('..')

import gym
import torch
from gym import ObservationWrapper
from gym.wrappers import AtariPreprocessing
from gym.wrappers.frame_stack import FrameStack
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

from qlearner import TargetQLearning

class TorchWrapper(ObservationWrapper):
    """Observation wrapper that returns every observation as a float torch tensor."""

    def observation(self, obs):
        """Build a tensor from `obs` and cast it with `.float()`."""
        as_tensor = torch.tensor(obs)
        return as_tensor.float()
    
def _build_env(env_id):
    """Create the gym env `env_id` wrapped for Atari preprocessing, 4-frame stacking and tensor output."""
    base_env = gym.make(env_id)
    preprocessed = AtariPreprocessing(base_env)
    stacked = FrameStack(preprocessed, num_stack=4)
    return TorchWrapper(stacked)


def _build_agent(env, q_net):
    """Construct a TargetQLearning agent sized for `env`'s action space and install `q_net` as its Q."""
    agent = TargetQLearning(
        n_actions=env.action_space.n,
        target_lag=1000,
        opt_args={"lr": 0.0001},
        transitions_per_fit=4,
        memory_len=1,
    )
    # Swap in the loaded network in place of whatever the constructor built.
    agent.Q = q_net
    return agent


# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load files from a trusted source.
q1 = torch.load('./MTQ_1_389474_steps.pt')
q2 = torch.load('./MTQ_2_389474_steps.pt')

# Game 1 is Breakout, game 2 is Pong; the numeric suffixes below follow that pairing.
env1_name = 'BreakoutNoFrameskip-v4'
env2_name = 'PongNoFrameskip-v4'

env1 = _build_env(env1_name)
env2 = _build_env(env2_name)

qs = [q1, q2]

agt1 = _build_agent(env1, q1)
agt2 = _build_agent(env2, q2)

Using matplotlib backend: Qt5Agg
Building None
Building None


# Pong

In [3]:
# Run agt2 (which carries the loaded q2 network) in env2, the Pong environment.
# NOTE(review): `play` is defined in qlearner and not shown here; the return value is
# presumably an animation handle that must stay referenced via `anim` — confirm.
anim = agt2.play(env2)

# Breakout

In [4]:
# Run agt1 (which carries the loaded q1 network) in env1, the Breakout environment.
# NOTE(review): `play` is defined in qlearner and not shown here; the return value is
# presumably an animation handle that must stay referenced via `anim` — confirm.
anim = agt1.play(env1)

# Now, let's look at the generated features

In [5]:
def get_features(s, agt):
    """Return the output of the first two stages of `agt.Q` applied to `s`.

    `s` is passed through `agt.Q[0]`, and that result through `agt.Q[1]`;
    later entries of `agt.Q` are not used here.
    """
    first_stage, second_stage = agt.Q[0], agt.Q[1]
    return second_stage(first_stage(s))

In [6]:
# Grab 5000 states from each env

def collect_rollout(env, agt, n_steps=5000):
    """Step `agt` through `env` for `n_steps` transitions, recording what it sees.

    After every step the returned observation and the rendered RGB frame are
    stored; when an episode ends the env is reset and the rollout continues.

    Parameters
    ----------
    env : environment exposing `reset()`, `step(a)` (4-tuple) and `render(mode='rgb_array')`
    agt : agent exposing `get_action(s)`
    n_steps : number of transitions to record

    Returns
    -------
    states : torch tensor of shape (n_steps, 4, 84, 84)
    screens : uint8 numpy array of shape (n_steps, 210, 160, 3)
    games : number of episodes that finished during the rollout
    """
    states = torch.zeros((n_steps, 4, 84, 84))
    # Rendered RGB frames hold values in 0..255, so they need an unsigned 8-bit
    # buffer; a signed 'byte' (int8) buffer wraps every value above 127 to a negative.
    screens = np.zeros((n_steps, 210, 160, 3), dtype=np.uint8)

    games = 0
    s = env.reset()
    for step in range(n_steps):
        a = agt.get_action(s)
        s, r, done, _ = env.step(a)
        states[step] = s
        screens[step] = env.render(mode='rgb_array')
        if done:
            # Episode finished: count it and start a new one.
            games += 1
            s = env.reset()
    return states, screens, games


e1_states, e1_screens, games = collect_rollout(env1, agt1, 5000)
print(f'Played {games}  games of breakout')

e2_states, e2_screens, games = collect_rollout(env2, agt2, 5000)
print(f'Played {games}  games of Pong')

Played 12  games of breakout
Played 1  games of Pong


In [7]:
# Feature extraction is inference only, so run it without autograd: this avoids
# keeping the intermediate activations of two 5000-frame forward passes alive.
with torch.no_grad():
    f1 = get_features(e1_states, agt1)
    f2 = get_features(e2_states, agt2)

# Stack along the first dimension: Breakout rows first, then Pong rows.
fs = torch.cat([f1, f2])
print(fs.shape)

torch.Size([10000, 256])


In [8]:
# Embed the stacked features in 2-D with t-SNE.
# t-SNE is stochastic; fixing random_state makes the embedding (and the figures
# saved from it) reproducible across kernel restarts.
features_np = fs.detach().numpy()
X_embedded = TSNE(
    n_components=2,
    perplexity=30,
    learning_rate=10,
    random_state=0,
).fit_transform(features_np)

In [17]:
# Colour for each embedded point: the max over dim 1 of the agent's Q[2] stage
# applied to that point's features (first 5000 rows are agt1, last 5000 are agt2).
vals = np.zeros(10000)
vals[:5000] = torch.max(agt1.Q[2](f1), dim=1).values.detach().numpy()
vals[5000:] = torch.max(agt2.Q[2](f2), dim=1).values.detach().numpy()

# Draw on a dedicated figure. With a windowed backend, bare plt.* calls reuse
# whatever figure is current, so plots from different cells would pile up on
# the same axes and end up in each other's saved files.
fig, ax = plt.subplots()
points = ax.scatter(X_embedded[:, 0], X_embedded[:, 1], c=vals, alpha=0.3, cmap='magma')

cb = fig.colorbar(points, ax=ax)
cb.set_label('Max Q-value')
ax.set_title('t-SNE Plot for Features Extracted from Pong and Breakout')
fig.savefig('../figs/atari_tsne.png')

In [19]:
# Game label per embedded point: 0 for the first 5000 rows (Breakout), 1 for the rest (Pong).
envs = np.zeros(10000)
envs[:5000] = 0
envs[5000:] = 1.

# Draw on a dedicated figure. With a windowed backend, bare plt.* calls reuse
# the current figure, which would overlay this scatter and its colorbar on the
# previous plot and save both into this file.
fig, ax = plt.subplots()
points = ax.scatter(X_embedded[:, 0], X_embedded[:, 1], c=envs, alpha=0.3, cmap='magma')
cb = fig.colorbar(points, ax=ax, ticks=[0, 1])
cb.ax.set_yticklabels(['Breakout', 'Pong'])
ax.set_title('t-SNE Plot Colored by Game')
fig.savefig('../figs/atari_game_tsne.png')