In [1]:
import pickle
import torch
import sys
import os

sys.path.append('../')

from generate_embeddings_gridworld import get_embeddings_qvalues, min_max_normalization
from probe import ProbeDataset, train_probe, test_probe
from GPT.dataset import EpisodeDataset
from GPT.model import Config

In [2]:
# Vocabulary layout: 0 is padding, 1-81 are the 9x9 grid cells in row-major
# order, and 82-85 are the four movement actions.
token_to_idx = {(row, col): 9 * row + col + 1 for row in range(9) for col in range(9)}
token_to_idx.update({"up": 82, "down": 83, "left": 84, "right": 85})
token_to_idx['<pad>'] = 0  # Padding token

# Hyperparameters used when building the model Config further down.
vocab_size, block_size = 86, 200
embed_size, num_heads, num_layers = 512, 8, 8
dropout = 0.1

In [3]:
# Directory holding the train*.pkl and qhist*.pkl files loaded below; with the
# empty string, os.path.join(path, name) returns just the bare file name.
path = ''

In [4]:
# Load the four pickled agent datasets (file suffixes 00, 08, 80, 88), in that order.
# NOTE(review): pickle.load can execute arbitrary code; only point `path` at trusted files.
_agent_data = []
for _agent_file in ('train00.pkl', 'train08.pkl', 'train80.pkl', 'train88.pkl'):
    with open(os.path.join(path, _agent_file), 'rb') as f:
        _agent_data.append(pickle.load(f))
agent00, agent08, agent80, agent88 = _agent_data

In [5]:
# Load the four pickled Q-value histories (file suffixes 00, 08, 80, 88), in that order.
# NOTE(review): pickle.load can execute arbitrary code; only point `path` at trusted files.
_qhist_data = []
for _qhist_file in ('qhist00.pkl', 'qhist08.pkl', 'qhist80.pkl', 'qhist88.pkl'):
    with open(os.path.join(path, _qhist_file), 'rb') as f:
        _qhist_data.append(pickle.load(f))
qhist00, qhist08, qhist80, qhist88 = _qhist_data

In [6]:
# Fractions of each agent's data assigned to training and validation; whatever
# remains after those two cuts becomes the test split.
train_ratio, valid_ratio = 0.8, 0.1

# Number of entries per agent; every cut index below is derived from these.
d00, d08, d80, d88 = len(agent00), len(agent08), len(agent80), len(agent88)

# Contiguous (unshuffled) train / validation / test slices for each agent.
train00, valid00, test00 = (
    agent00[:int(train_ratio * d00)],
    agent00[int(train_ratio * d00):int((train_ratio + valid_ratio) * d00)],
    agent00[int((train_ratio + valid_ratio) * d00):],
)

train08, valid08, test08 = (
    agent08[:int(train_ratio * d08)],
    agent08[int(train_ratio * d08):int((train_ratio + valid_ratio) * d08)],
    agent08[int((train_ratio + valid_ratio) * d08):],
)

train80, valid80, test80 = (
    agent80[:int(train_ratio * d80)],
    agent80[int(train_ratio * d80):int((train_ratio + valid_ratio) * d80)],
    agent80[int((train_ratio + valid_ratio) * d80):],
)

train88, valid88, test88 = (
    agent88[:int(train_ratio * d88)],
    agent88[int(train_ratio * d88):int((train_ratio + valid_ratio) * d88)],
    agent88[int((train_ratio + valid_ratio) * d88):],
)

In [7]:
# Slice each Q-value history at the same indices as the agent data. The cut
# points come from the agent lengths d00..d88, not from the qhist lengths.
# NOTE(review): this presumably relies on len(qhistXX) == len(agentXX) — confirm.
qtrain00, qvalid00, qtest00 = (
    qhist00[:int(train_ratio * d00)],
    qhist00[int(train_ratio * d00):int((train_ratio + valid_ratio) * d00)],
    qhist00[int((train_ratio + valid_ratio) * d00):],
)

qtrain08, qvalid08, qtest08 = (
    qhist08[:int(train_ratio * d08)],
    qhist08[int(train_ratio * d08):int((train_ratio + valid_ratio) * d08)],
    qhist08[int((train_ratio + valid_ratio) * d08):],
)

qtrain80, qvalid80, qtest80 = (
    qhist80[:int(train_ratio * d80)],
    qhist80[int(train_ratio * d80):int((train_ratio + valid_ratio) * d80)],
    qhist80[int((train_ratio + valid_ratio) * d80):],
)

qtrain88, qvalid88, qtest88 = (
    qhist88[:int(train_ratio * d88)],
    qhist88[int(train_ratio * d88):int((train_ratio + valid_ratio) * d88)],
    qhist88[int((train_ratio + valid_ratio) * d88):],
)

In [8]:
# Per-agent caps: at most `s` training entries and at most `n` entries for each
# of the validation and test splits.
s, n = 100000, 12500

# Concatenate the capped per-agent slices in the fixed order 00, 08, 80, 88.
# Each Q-history split is built with the same cap and order as its data split.
train = (
    train00[:s]
    + train08[:s]
    + train80[:s]
    + train88[:s]
)
qtrain = (
    qtrain00[:s]
    + qtrain08[:s]
    + qtrain80[:s]
    + qtrain88[:s]
)

valid = (
    valid00[:n]
    + valid08[:n]
    + valid80[:n]
    + valid88[:n]
)
qvalid = (
    qvalid00[:n]
    + qvalid08[:n]
    + qvalid80[:n]
    + qvalid88[:n]
)

test = (
    test00[:n]
    + test08[:n]
    + test80[:n]
    + test88[:n]
)
qtest = (
    qtest00[:n]
    + qtest08[:n]
    + qtest80[:n]
    + qtest88[:n]
)

In [9]:
# Wrap each split in an EpisodeDataset; all three share the same token vocabulary.
train_dataset, valid_dataset, test_dataset = [
    EpisodeDataset(episodes, token_to_idx) for episodes in (train, valid, test)
]

In [10]:
# Model configuration. n_head takes num_heads (not num_layers): the two happen
# to be equal here (both 8), but they are independent settings.
config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_heads, n_embd=embed_size)

In [11]:
def _dump_pickle(obj, file_path):
    """Pickle `obj` to `file_path`, overwriting any existing file."""
    with open(file_path, 'wb') as f:
        pickle.dump(obj, f)


def training_pipeline(folder_path: str, positions: list, layers: list, linear, model_load_path, train_random, cutoff: int = 30, epochs: int = 50):
    """Train and evaluate a probe for each requested layer.

    For every entry of `layers` this function:
      1. obtains embeddings and Q-values for the train/valid/test splits via
         `get_embeddings_qvalues`,
      2. min-max normalizes the Q-values, reusing the min/max returned for the
         training split on the validation and test splits,
      3. trains a probe with `train_probe`, pickles its train/validation loss
         histories, and prints the test loss from `test_probe`,
      4. if `train_random` is truthy, repeats step 3 with `torch.randn`
         inputs in place of the embeddings, as a baseline.

    The function reads the notebook-level globals `train`, `valid`, `test`,
    `qtrain`, `qvalid`, `qtest`, `config` and `token_to_idx`.

    Args:
        folder_path: Root output directory; a `Layer_{i}` sub-directory is
            created for each layer.
        positions: Forwarded unchanged to `get_embeddings_qvalues`.
        layers: Layer identifiers to iterate over; each is forwarded to
            `get_embeddings_qvalues` and used in output paths.
        linear: Forwarded to `train_probe` / `test_probe` as `linear`.
        model_load_path: Forwarded to `get_embeddings_qvalues`.
        train_random: Whether to also train the random-input baseline probe.
        cutoff: Forwarded to `get_embeddings_qvalues` (default 30).
        epochs: Number of epochs passed to `train_probe` (default 50).

    Returns:
        None. Results are printed and written under `folder_path`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for i in layers:

        curr_path = os.path.join(folder_path, f"Layer_{i}")
        # exist_ok avoids the check-then-create race and makes re-runs safe.
        os.makedirs(curr_path, exist_ok=True)

        print(f"Layer {i}")

        # Retrieve Embeddings and Normalize Q-Values

        embed_train, qval_train = get_embeddings_qvalues(positions, train, qtrain, i, config, token_to_idx, cutoff = cutoff, model_load_path = model_load_path)
        embed_valid, qval_valid = get_embeddings_qvalues(positions, valid, qvalid, i, config, token_to_idx, cutoff = cutoff, model_load_path = model_load_path)
        embed_test, qval_test = get_embeddings_qvalues(positions, test, qtest, i, config, token_to_idx, cutoff = cutoff, model_load_path = model_load_path)

        # The bounds come from the training split only and are reused for the
        # other splits. Named q_min/q_max so the builtins min/max stay usable.
        qval_train_norm, q_min, q_max = min_max_normalization(qval_train)
        qval_valid_norm = min_max_normalization(qval_valid, q_min, q_max)
        qval_test_norm = min_max_normalization(qval_test, q_min, q_max)

        # Sizes handed to the probe as `params`: length of one normalized
        # Q-value entry and the leading dimension of one embedding.
        d = len(qval_train_norm[0])
        n = embed_train[0].shape[0]

        # Non-Random

        probe_dataset_train = ProbeDataset(embed_train, qval_train_norm)
        probe_dataset_valid = ProbeDataset(embed_valid, qval_valid_norm)
        probe_dataset_test = ProbeDataset(embed_test, qval_test_norm)

        print("\nTraining Normal Probe\n")
        model_path, train_loss, valid_loss = train_probe(probe_dataset_train, probe_dataset_valid, device = device, epochs = epochs, params = (d, n), model_dir = os.path.join(curr_path, f"Non_Random_Layer_{i}"), linear=linear)

        _dump_pickle(train_loss, os.path.join(curr_path, "normal_train_loss"))
        _dump_pickle(valid_loss, os.path.join(curr_path, "normal_valid_loss"))

        test_loss = test_probe(probe_dataset_test, model_path, (d, n), device, linear)
        print(f"MSE Loss: {test_loss:.4f}")

        if train_random:

            # Random
            # NOTE(review): the (512, 1) shape is hard-coded; presumably it has
            # to match the shape of the real embeddings — confirm before
            # changing the model's embedding size.
            random_embeddings_train = [torch.randn(512, 1) for _ in range(len(embed_train))]
            random_embeddings_valid = [torch.randn(512, 1) for _ in range(len(embed_valid))]
            random_embeddings_test = [torch.randn(512, 1) for _ in range(len(embed_test))]

            random_dataset_train = ProbeDataset(random_embeddings_train, qval_train_norm)
            random_dataset_valid = ProbeDataset(random_embeddings_valid, qval_valid_norm)
            random_dataset_test = ProbeDataset(random_embeddings_test, qval_test_norm)

            print("\nTraining Random Probe\n")
            random_path, random_train_loss, random_valid_loss = train_probe(random_dataset_train, random_dataset_valid, device = device, epochs = epochs, params = (d, n), model_dir = os.path.join(curr_path, f"Random_Layer_{i}"), linear=linear)

            _dump_pickle(random_train_loss, os.path.join(curr_path, "random_train_loss"))
            _dump_pickle(random_valid_loss, os.path.join(curr_path, "random_valid_loss"))

            rand_loss = test_probe(random_dataset_test, random_path, (d, n), device, linear)
            print(f"Random MSE Loss: {rand_loss:.4f}\n")

In [12]:
training_pipeline(folder_path = 'Nonlinear_Probe', positions = [(2, 2), (2, 6), (6, 2), (6, 6)], layers = list(range(1, 9)), linear = False, model_load_path = 'Model_12.pth', train_random = True)

Layer 1

Training Normal Probe

Best Epoch: 5, Min MSE Loss: 0.11358589092565267
MSE Loss: 0.1162

Training Random Probe

Best Epoch: 3, Min MSE Loss: 0.14671152819063624
Random MSE Loss: 0.1487

Layer 2

Training Normal Probe

Best Epoch: 9, Min MSE Loss: 0.11317466846774966
MSE Loss: 0.1144

Training Random Probe

Best Epoch: 3, Min MSE Loss: 0.14671152819063624
Random MSE Loss: 0.1487

Layer 3

Training Normal Probe

Best Epoch: 9, Min MSE Loss: 0.11355891705450623
MSE Loss: 0.1147

Training Random Probe

Best Epoch: 3, Min MSE Loss: 0.14671152819063624
Random MSE Loss: 0.1487

Layer 4

Training Normal Probe

Best Epoch: 14, Min MSE Loss: 0.11391765904622314
MSE Loss: 0.1138

Training Random Probe

Best Epoch: 3, Min MSE Loss: 0.14671152819063624
Random MSE Loss: 0.1487

Layer 5

Training Normal Probe

Best Epoch: 14, Min MSE Loss: 0.11385590993179523
MSE Loss: 0.1136

Training Random Probe

Best Epoch: 3, Min MSE Loss: 0.14671152819063624
Random MSE Loss: 0.1487

Layer 6

Training 