In [16]:
import pickle
import torch
import sys
import os

sys.path.append('../')

from probe import LinearProbe, NonLinearProbe, ProbeDataset, train_probe, test_probe
from GPT.dataset import EpisodeDataset
from GPT.model import Config, GPTModel
from RL_Training_ConnectFour.env import Connect4Game
from RL_Training_ConnectFour.dqn import DQNAgent, get_random_move

In [17]:
token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)}
token_to_idx['<pad>'] = 0  # Padding token

vocab_size = 43
block_size = 42
embed_size = 512
num_heads = 8
num_layers = 8
dropout = 0.1

In [18]:
path = ''

In [25]:
config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size)

In [26]:
def take_turns(model, layer, probe, device):
    X = []
    while True:
        env = Connect4Game()
        move_1 = get_random_move(env.state.board)
        _, done = env.step(move_1, 1)
        if done: #hit a terminal state
            return(env.state.find_winner(env.state.board, env.state.last_move[0], env.state.last_move[1]))
        X.append(env.state.last_move)
        
        X_idx = [token_to_idx[token] for token in X]
        X_idx = torch.tensor(X_idx, dtype=torch.long).to(device)
        X_idx = X_idx.unsqueeze(0)
        embedding = model(X_idx, layer)[:, len(X) - 1, :]
        cpu_embed = embedding.cpu()
        pred = probe.predict(cpu_embed, device)
        possible_actions = [xy[1] for xy in get_valid_locations(self.env.state.board)]
        mask = torch.full(pred.shape, float('-inf')).to('cuda')
        mask[0][possible_actions] = pred[0][possible_actions]
        move_2 = torch.argmax(mask).item()

        _, done = env.step(move_2, -1)
        if done:
            return(env.state.find_winner(env.state.board, env.state.last_move[0], env.state.last_move[1]))
        X.append(env.state.last_move)

In [27]:
def decisions_validate(probe_model_path, gpt_model_path, config, linear):
    success_count = 0
    total_attempts = 0

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
            
    if linear:
        probe = LinearProbe(7, 512).to(device)
    else:
        probe = NonLinearProbe(7, 512).to(device)
    probe.load_state_dict(torch.load(probe_model_path, map_location = device))

    model = GPTModel(config).to(device)
    model.load_state_dict(torch.load(gpt_model_path))

    itrs = 10000

    for _ in range(itrs):
        total_attempts += 1
        try:
            x = take_turns(model, 6, probe, device)
            if x == -1:
                success_count += 1
        except (KeyError, AssertionError):
            continue

    success_rate = success_count / total_attempts   
    print(f"Success rate: {success_rate:.2f} ({success_count}/{total_attempts})")


In [28]:
decisions_validate(probe_model_path = 'Mega Probe/Layer_7/Linear_Layer_7/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = True)
decisions_validate(probe_model_path = 'Mega Probe/Layer_7/Nonlinear_Layer_7/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = False)

Success rate: 0.97 (7800/8000)
Success rate: 0.66 (5251/8000)


In [29]:
decisions_validate(probe_model_path = 'Mega Probe/Layer_8/Linear_Layer_8/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = True)
decisions_validate(probe_model_path = 'Mega Probe/Layer_8/Nonlinear_Layer_8/best_model.pth', gpt_model_path = 'Model_12.pth', config = config, linear = False)

Success rate: 0.95 (7569/8000)
Success rate: 0.91 (7262/8000)
