In [16]:
import pickle
import torch
import sys
import os

sys.path.append('../')

from generate_embeddings_gridworld import get_embeddings_qvalues, min_max_normalization
from probe import LinearProbe, NonLinearProbe, ProbeDataset, train_probe, test_probe
from GPT.dataset import EpisodeDataset
from GPT.model import Config, GPTModel

In [17]:
token_to_idx = {(i, j): i * 9 + j + 1 for i in range(9) for j in range(9)} | {"up": 82, "down": 83, "left": 84, "right": 85}
token_to_idx['<pad>'] = 0  # Padding token

vocab_size = 86
block_size = 200
embed_size = 512
num_heads = 8
num_layers = 8
dropout = 0.1

In [18]:
path = ''

In [19]:
with open(os.path.join(path, 'train00.pkl'), 'rb') as f:
    agent00 = pickle.load(f)
with open(os.path.join(path, 'train08.pkl'), 'rb') as f:
    agent08 = pickle.load(f)
with open(os.path.join(path, 'train80.pkl'), 'rb') as f:
    agent80 = pickle.load(f)
with open(os.path.join(path, 'train88.pkl'), 'rb') as f:
    agent88 = pickle.load(f)

In [20]:
with open(os.path.join(path, 'qhist00.pkl'), 'rb') as f:
    qhist00 = pickle.load(f)
with open(os.path.join(path, 'qhist08.pkl'), 'rb') as f:
    qhist08 = pickle.load(f)
with open(os.path.join(path, 'qhist80.pkl'), 'rb') as f:
    qhist80 = pickle.load(f)
with open(os.path.join(path, 'qhist88.pkl'), 'rb') as f:
    qhist88 = pickle.load(f)

In [21]:
train_ratio = 0.8
valid_ratio = 0.1

d00 = len(agent00)
d08 = len(agent08)
d80 = len(agent80)
d88 = len(agent88)

train00 = agent00[:int(train_ratio * d00)]
valid00 = agent00[int(train_ratio * d00):int((train_ratio + valid_ratio) * d00) ]
test00 = agent00[int((train_ratio + valid_ratio) * d00): ]

train08 = agent08[:int(train_ratio * d08)]
valid08 = agent08[int(train_ratio * d08):int((train_ratio + valid_ratio) * d08) ]
test08 = agent08[int((train_ratio + valid_ratio) * d08): ]

train80 = agent80[:int(train_ratio * d80)]
valid80 = agent80[int(train_ratio * d80):int((train_ratio + valid_ratio) * d80) ]
test80 = agent80[int((train_ratio + valid_ratio) * d80): ]

train88 = agent88[:int(train_ratio * d88)]
valid88 = agent88[int(train_ratio * d88):int((train_ratio + valid_ratio) * d88) ]
test88 = agent88[int((train_ratio + valid_ratio) * d88): ]

In [22]:
qtrain00 = qhist00[:int(train_ratio * d00)]
qvalid00 = qhist00[int(train_ratio * d00):int((train_ratio + valid_ratio) * d00)]
qtest00 = qhist00[int((train_ratio + valid_ratio) * d00):]

qtrain08 = qhist08[:int(train_ratio * d08)]
qvalid08 = qhist08[int(train_ratio * d08):int((train_ratio + valid_ratio) * d08)]
qtest08 = qhist08[int((train_ratio + valid_ratio) * d08):]

qtrain80 = qhist80[:int(train_ratio * d80)]
qvalid80 = qhist80[int(train_ratio * d80):int((train_ratio + valid_ratio) * d80)]
qtest80 = qhist80[int((train_ratio + valid_ratio) * d80):]

qtrain88 = qhist88[:int(train_ratio * d88)]
qvalid88 = qhist88[int(train_ratio * d88):int((train_ratio + valid_ratio) * d88)]
qtest88 = qhist88[int((train_ratio + valid_ratio) * d88):]

In [23]:
# Subsample Sizes
s = 100000
n = 12500

train = train00[:s] + train08[:s] + train80[:s] + train88[:s]
valid = valid00[:n] + valid08[:n] + valid80[:n] + valid88[:n]
test = test00[:n] + test08[:n] + test80[:n] + test88[:n]


qtrain = qtrain00[:s] + qtrain08[:s] + qtrain80[:s] + qtrain88[:s]
qvalid = qvalid00[:n] + qvalid08[:n] + qvalid80[:n] + qvalid88[:n]
qtest = qtest00[:n] + qtest08[:n] + qtest80[:n] + qtest88[:n]

In [24]:
train_dataset = EpisodeDataset(train, token_to_idx)
valid_dataset = EpisodeDataset(valid, token_to_idx)
test_dataset = EpisodeDataset(test, token_to_idx)

In [25]:
config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size)

In [11]:
def mega_training_pipeline(folder_path: str, positions: list, layers: list, model_load_path):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for i in layers:

        curr_path = os.path.join(folder_path, f"Layer_{i}")

        if not os.path.exists(curr_path):
            os.makedirs(curr_path)
        
        print(f"Layer {i}")
        
        # Retreive Embeddings and Normalize Q-Values

        embed_train, qval_train = get_embeddings_qvalues(positions, train, qtrain, i, config, token_to_idx, cutoff = 30, model_load_path = model_load_path)
        embed_valid, qval_valid = get_embeddings_qvalues(positions, valid, qvalid, i, config, token_to_idx, cutoff = 30, model_load_path = model_load_path)
        embed_test, qval_test = get_embeddings_qvalues(positions, test, qtest, i, config, token_to_idx, cutoff = 30, model_load_path = model_load_path)

        qval_train_norm, min, max = min_max_normalization(qval_train)
        qval_valid_norm = min_max_normalization(qval_valid, min, max)
        qval_test_norm = min_max_normalization(qval_test, min, max)

        d = len(qval_train_norm[0])
        n = embed_train[0].shape[0]

        # Non-Random

        probe_dataset_train = ProbeDataset(embed_train, qval_train_norm)
        probe_dataset_valid = ProbeDataset(embed_valid, qval_valid_norm)
        probe_dataset_test = ProbeDataset(embed_test, qval_test_norm)

        print("\nTraining Linear Probe\n")
        model_path_linear, _, _ = train_probe(probe_dataset_train, probe_dataset_valid, device = device, epochs = 100, params = (d, n), model_dir = os.path.join(curr_path, f"Linear_Layer_{i}"), linear=True)
        test_loss_linear = test_probe(probe_dataset_test, model_path_linear, (d, n), device, linear = True)
        print(f"MSE Loss Linear: {test_loss_linear:.4f}")

        print("\nTraining Nonlinear Probe\n")
        model_path_nonlin, _, _ = train_probe(probe_dataset_train, probe_dataset_valid, device = device, epochs = 100, params = (d, n), model_dir = os.path.join(curr_path, f"Nonlinear_Layer_{i}"), linear=False)
        test_loss_nonlin = test_probe(probe_dataset_test, model_path_nonlin, (d, n), device, linear = False)
        print(f"MSE Loss Nonlinear: {test_loss_nonlin:.4f}")



In [None]:
mega_training_pipeline('Mega Probe', [(0, 0), (8, 0), (0, 8), (8, 8), (2, 2), (2, 6), (6, 2), (6, 6), (4, 2), (4, 6), (2, 4), (6, 4), (4, 0), (4, 8), (0, 4), (0, 8)], layers = [7, 8], model_load_path = 'Model_12.pth')

Layer 7

Training Linear Probe

Best Epoch: 90, Min MSE Loss: 0.06121562799219897
MSE Loss Linear: 0.0607

Training Nonlinear Probe

Best Epoch: 16, Min MSE Loss: 0.05889081619028584
MSE Loss Nonlinear: 0.0578
Layer 8

Training Linear Probe

Best Epoch: 82, Min MSE Loss: 0.061162354209258585
MSE Loss Linear: 0.0614

Training Nonlinear Probe

Best Epoch: 9, Min MSE Loss: 0.058891325965519645


In [26]:
def take_turns(starting_pos, model, layer, probe, device):

    get_direction = lambda x: {0: "up", 1: "down", 2: "left", 3: "right"}[torch.argmax(x).item()]

    directions = {
        "up": (0, 1),
        "down": (0, -1),
        "left": (-1, 0),
        "right": (1, 0),
    }

    curr_pos = starting_pos
    X = [starting_pos]

    while curr_pos != (4, 4):
        
        X_idx = [token_to_idx[token] for token in X]
        X_idx = torch.tensor(X_idx, dtype=torch.long).to(device)
        X_idx = X_idx.unsqueeze(0)

        embedding = model(X_idx, layer)[:, len(X) - 1, :]
        cpu_embed = embedding.cpu()
        pred = probe.predict(cpu_embed, device)

        direction = get_direction(pred)

        dx, dy = directions[direction]
        new_pos = (curr_pos[0] + dx, curr_pos[1] + dy)

        curr_pos = new_pos
        X.append(direction)
        X.append(curr_pos) 


In [27]:
def decisions_validate(probe_model_path, gpt_model_path, config, linear):
    success_count = 0
    total_attempts = 0

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
            
    if linear:
        probe = LinearProbe(4, 512).to(device)
    else:
        probe = NonLinearProbe(4, 512).to(device)
    probe.load_state_dict(torch.load(probe_model_path, map_location = device))

    model = GPTModel(config).to(device)
    model.load_state_dict(torch.load(gpt_model_path))

    itrs = 100

    for _ in range(itrs):
        for i in range(9):
            for j in range(9):
                if (i, j) == (4, 4):
                    continue   
                total_attempts += 1
                try:
                    take_turns((i, j), model, 6, probe, device)
                    success_count += 1
                except (KeyError, AssertionError):
                    continue

    success_rate = success_count / total_attempts   
    print(f"Success rate: {success_rate:.2f} ({success_count}/{total_attempts})")


In [28]:
decisions_validate(probe_model_path = 'Mega Probe/Layer_7/Linear_Layer_7/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = True)
decisions_validate(probe_model_path = 'Mega Probe/Layer_7/Nonlinear_Layer_7/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = False)

Success rate: 0.97 (7800/8000)
Success rate: 0.66 (5251/8000)


In [29]:
decisions_validate(probe_model_path = 'Mega Probe/Layer_8/Linear_Layer_8/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = True)
decisions_validate(probe_model_path = 'Mega Probe/Layer_8/Nonlinear_Layer_8/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = False)

Success rate: 0.95 (7569/8000)
Success rate: 0.91 (7262/8000)
