In [1]:
import pickle
import torch
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

sys.path.append('../')

from generate_embeddings_gridworld import get_embeddings_qvalues
from sklearn.manifold import TSNE
from GPT.dataset import EpisodeDataset
from GPT.model import Config



In [2]:
# Vocabulary lookup: every (first, second) coordinate pair of the 9x9 grid gets
# an index in 1..81 (first coordinate major), the four move tokens take 82..85,
# and index 0 is reserved for the padding token.
token_to_idx = {
    **{(row, col): 9 * row + col + 1 for row in range(9) for col in range(9)},
    "up": 82,
    "down": 83,
    "left": 84,
    "right": 85,
    '<pad>': 0,  # Padding token
}

# Model hyperparameters (handed to Config further down in the notebook).
vocab_size = 86   # 81 grid cells + 4 moves + 1 padding token
block_size = 200  # presumably the maximum sequence length -- confirm against Config
embed_size = 512  # width of each embedding vector
num_heads = 8
num_layers = 8
dropout = 0.1

In [3]:
# Directory holding the pickled input files. It is passed to os.path.join when
# the files are opened, so the empty string resolves every file name relative
# to the current working directory.
path = ''

In [4]:
def _load_agent_episodes(filename):
    """Unpickle and return the object stored in `filename` inside `path`.

    NOTE: pickle.load can execute arbitrary code embedded in the file, so only
    point this at files you trust.
    """
    with open(os.path.join(path, filename), 'rb') as handle:
        return pickle.load(handle)


# One pickled collection per agent; the two-digit suffix identifies the agent
# (presumably a grid corner such as (0, 8) -- confirm against the data source).
agent00 = _load_agent_episodes('train00.pkl')
agent08 = _load_agent_episodes('train08.pkl')
agent80 = _load_agent_episodes('train80.pkl')
agent88 = _load_agent_episodes('train88.pkl')

In [5]:
def _load_qhist(filename):
    """Unpickle and return the object stored in `filename` inside `path`.

    NOTE: pickle.load can execute arbitrary code embedded in the file, so only
    point this at files you trust.
    """
    with open(os.path.join(path, filename), 'rb') as handle:
        return pickle.load(handle)


# Per-agent "qhist" collections (presumably Q-value histories -- confirm),
# loaded in the same agent order as the episode files.
qhist00, qhist08, qhist80, qhist88 = [
    _load_qhist(filename)
    for filename in ('qhist00.pkl', 'qhist08.pkl', 'qhist80.pkl', 'qhist88.pkl')
]

In [6]:
# Fractions of each agent's data used for training and validation; whatever
# remains after those two slices becomes the test split.
train_ratio = 0.8
valid_ratio = 0.1

# Number of entries available for each agent.
d00, d08, d80, d88 = (
    len(episodes) for episodes in (agent00, agent08, agent80, agent88)
)


def _split_by_ratio(episodes, total):
    """Cut `episodes` into consecutive (train, valid, test) slices.

    The boundaries are the truncated products of `total` with `train_ratio`
    and with `train_ratio + valid_ratio`; no shuffling is performed.
    """
    train_end = int(train_ratio * total)
    valid_end = int((train_ratio + valid_ratio) * total)
    return episodes[:train_end], episodes[train_end:valid_end], episodes[valid_end:]


train00, valid00, test00 = _split_by_ratio(agent00, d00)
train08, valid08, test08 = _split_by_ratio(agent08, d08)
train80, valid80, test80 = _split_by_ratio(agent80, d80)
train88, valid88, test88 = _split_by_ratio(agent88, d88)

In [7]:
def _three_way_slices(total):
    """Return the (train, valid, test) slice objects for a sequence of size `total`.

    Boundaries are the truncated products of `total` with `train_ratio` and
    with `train_ratio + valid_ratio`.
    """
    train_end = int(train_ratio * total)
    valid_end = int((train_ratio + valid_ratio) * total)
    return slice(train_end), slice(train_end, valid_end), slice(valid_end, None)


# The boundaries are derived from the agent lengths (d00, d08, ...), not from
# the qhist lengths, so each qhist split lines up index-for-index with the
# matching agent split -- assumes both sequences have equal length; TODO confirm.
qtrain00, qvalid00, qtest00 = (qhist00[window] for window in _three_way_slices(d00))
qtrain08, qvalid08, qtest08 = (qhist08[window] for window in _three_way_slices(d08))
qtrain80, qvalid80, qtest80 = (qhist80[window] for window in _three_way_slices(d80))
qtrain88, qvalid88, qtest88 = (qhist88[window] for window in _three_way_slices(d88))

In [8]:
# Subsample size: at most the first `s` entries of each agent's training split
# are kept (slicing past the end of a shorter split simply returns all of it).
s = 5000

# Concatenate the per-agent subsamples in the fixed order 00, 08, 80, 88.
# `qtrain` is assembled with the same slices and order so that it stays
# aligned with `train` -- assumes each qtrainXX has the same length as its
# trainXX counterpart; TODO confirm.
train = train00[:s] + train08[:s] + train80[:s] + train88[:s]
qtrain = qtrain00[:s] + qtrain08[:s] + qtrain80[:s] + qtrain88[:s]

# Wrap the combined training data together with the token vocabulary.
train_dataset = EpisodeDataset(train, token_to_idx)

In [9]:
# Model configuration assembled from the hyperparameters defined earlier in the
# notebook. `n_head` takes the attention-head count (`num_heads`), which is
# configured independently of the layer count (`num_layers`).
config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_heads, n_embd=embed_size)

In [10]:
# For every layer, collect the embeddings returned for four probe grid
# positions, project them to 2-D with t-SNE at several perplexities, and save
# one colour-coded scatter plot per (layer, perplexity) pair.
layers = range(1, 9)  # Layers from 1 to 8
perplexities = [20, 30, 40]  # Perplexity values
model_load_path = 'Model_12.pth'
tsne_seed = 0  # fixed seed so that re-running the notebook reproduces the plots

# Probe positions keyed by the two-digit label used in the legend (the two
# coordinates written side by side), and the colour assigned to each label.
probe_positions = {22: (2, 2), 26: (2, 6), 62: (6, 2), 66: (6, 6)}
color_map = {22: 'blue', 26: 'orange', 62: 'green', 66: 'purple'}
legend_labels = ['Position ' + str(label) for label in color_map.keys()]

for layer in layers:

    # Gather the embeddings of all probe positions for this layer, remembering
    # which position every embedding came from.
    embeddings = []
    labels = []
    for label, position in probe_positions.items():
        position_embeddings, _ = get_embeddings_qvalues([position], train, qtrain, layer, config, token_to_idx, cutoff=30, model_load_path=model_load_path)
        embeddings.extend(position_embeddings)
        labels.extend([label] * len(position_embeddings))

    # detach()/cpu() make the conversion work even if the tensors track
    # gradients or live on a GPU; both are no-ops otherwise.
    embeddings_tensor = torch.stack(embeddings)
    embeddings_array = embeddings_tensor.detach().cpu().numpy().reshape((-1, embed_size))

    # Fail before the expensive t-SNE fit if the reshape produced a different
    # number of rows than labels (the scatter call would reject it anyway).
    if len(embeddings_array) != len(labels):
        raise ValueError(
            f'Layer {layer}: {len(embeddings_array)} embedding rows but {len(labels)} labels'
        )

    colors = [color_map[label] for label in labels]

    for perplexity in perplexities:
        # NOTE(review): recent scikit-learn releases rename `n_iter` to
        # `max_iter`; switch the keyword if the installed version rejects it.
        lwr_dimensional_embedding = TSNE(n_components=2, perplexity=perplexity, n_iter=2000, n_iter_without_progress=500, random_state=tsne_seed).fit_transform(embeddings_array)

        fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
        ax.scatter(lwr_dimensional_embedding[:, 0], lwr_dimensional_embedding[:, 1], c=colors, s=50)

        # Proxy markers: the scatter is drawn in a single call, so it offers
        # no per-position artists for the legend to pick up.
        legend_handles = [plt.Line2D([0], [0], marker='o', color='w', label=str(label),
                                     markerfacecolor=color_map[label], markersize=10)
                          for label in color_map.keys()]

        ax.legend(handles=legend_handles, labels=legend_labels, title='Positions', fontsize=12)

        ax.set_xlabel('Dimension 1', fontsize=14)
        ax.set_ylabel('Dimension 2', fontsize=14)
        ax.set_title(f't-SNE Visualization (Layer {layer}, Perplexity {perplexity})', fontsize=16)

        fig.tight_layout()
        fig.savefig(f'tsne_plot_layer{layer}_perplexity{perplexity}.png')
        # Close explicitly: 24 figures at 300 dpi would otherwise pile up in memory.
        plt.close(fig)