In [1]:
import os
from mingpt.utils import set_seed
set_seed(44)

import math
import time
import numpy as np
from copy import deepcopy
import pickle
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.nn import functional as F
from torch.utils.data import Subset
from tqdm import tqdm
from matplotlib import pyplot as plt

from data import get_othello, plot_probs, plot_mentals
from data.othello import permit, start_hands, OthelloBoardState, permit_reverse
from mingpt.dataset import CharDataset
from mingpt.model import GPT, GPTConfig, GPTforProbeIA
from mingpt.utils import sample, intervene, print_board
from mingpt.probe_model import BatteryProbeClassification, BatteryProbeClassificationTwoLayer


## Load nonlinear probes

In [4]:
championship = False
mid_dim = 256
how_many_history_step_to_use = 99
exp = f"state_tl{mid_dim}"
if championship:
    exp += "_championship"


probes = {}
layer = 8
probe = BatteryProbeClassificationTwoLayer(torch.cuda.current_device(), probe_class=3, num_task=64, mid_dim=mid_dim)
load_res = probe.load_state_dict(torch.load(f"./ckpts/battery_othello/{exp}/layer{layer}/checkpoint.ckpt"))
probe.eval()


BatteryProbeClassificationTwoLayer(
  (proj): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=192, bias=True)
  )
)

## Load trained models for probing at layer 8

In [5]:
# othello = get_othello(ood_perc=.2, data_root="data/othello_pgn", wthor=False)
othello = get_othello(ood_perc=0., data_root=None, wthor=False, ood_num=1)
train_dataset = CharDataset(othello)

mconf = GPTConfig(61, 59, n_layer=8, n_head=8, n_embd=512)

model = GPTforProbeIA(mconf, probe_layer=layer, disable_last_layer_norm = True)
load_res = model.load_state_dict(torch.load("./ckpts/gpt_no_last_layer_norm.ckpt"))
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    model = model.to(device)
_ = model.eval()

100%|██████████| 1/1 [00:00<00:00, 30.17it/s]


Dataset created has 1 sequences, 61 unique words.


## Validate it: for what percentage of all partial games in validation set, the top-1 prediction is legal

In [6]:
if not championship:  # for GPT trained on both datasets, use the validation set of synthetic for validation
    othello = get_othello(ood_num=-1, data_root=None, wthor=True)

total_nodes = 0
success_nodes = 0

bar = tqdm(othello.val[:1000])
for whole_game in bar:
    length_of_whole_game = len(whole_game)
    for length_of_partial_game in range(1, length_of_whole_game):
        total_nodes += 1
        context = whole_game[:length_of_partial_game]
        x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(device)
        y = sample(model, x, 1, temperature=1.0)[0]
        completion = [train_dataset.itos[int(i)] for i in y if i != -1]
        try:
            OthelloBoardState().update(completion, prt=False)
        except Exception:
            pass
        else:
            success_nodes += 1
    bar.set_description(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
print(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")

Mem Used: 18.02 GB: 100%|██████████| 238/238 [01:15<00:00,  3.17it/s]


Deduplicating...
Deduplicating finished with 23796010 games left


  0%|          | 0/1000 [00:00<?, ?it/s]

Using 20 million for training, 3796010 for validation


99.93% pass rate: 13500/13509 among all searched nodes:  23%|██▎       | 229/1000 [01:40<05:37,  2.29it/s]


KeyboardInterrupt: 

## Load a game from intervention benchmark

In [7]:
with open("intervention_benchmark.pkl", "rb") as input_file:
    dataset = pickle.load(input_file)
    
case_id = 777
completion = dataset[case_id]["history"]

### Check the partial game progression

In [8]:
print(completion)
ab = OthelloBoardState()
ab.update(completion, prt=True)

pre_intv_valids = [permit_reverse(_) for _ in ab.get_valid_moves()]
print("valid moves:", pre_intv_valids)

[37, 29, 18, 42, 19]
--------------------
[]
a                
b                
c                
d       O X      
e       X O      
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['e6']
a                
b                
c                
d       O X      
e       X X X    
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['e6', 'd6']
a                
b                
c                
d       O O O    
e       X X X    
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['e6', 'd6', 'c3']
a                
b                
c     X          
d       X O O    
e       X X X    
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['e6', 'd6', 'c3', 'f3']
a                
b                
c     X          
d       X O O    
e  

In [16]:
if False:
    print(model)

s = torch.tensor([train_dataset.stoi[move] for move in completion], dtype=torch.long).to(device)
s = s[None, :]
print("s: \n", s, s.shape)

h = model.forward_1st_stage(s)
print("h: \n", h[0][-1][:10], h.shape)

out1 = model.head(h)
print("output by running head(h):\n", out1[0][-1][:10], out1.shape)

out2, _ = model(s)
print("output by running the model(s):\n", out2[0][-1][:10], out2.shape)

reconstructed_board, _ = probe((h)[0][-1])
print("reconstructed board:\n", reconstructed_board.squeeze()[:10])
board = torch.argmax(reconstructed_board.squeeze(), dim = -1).reshape(8,8).tolist()
print(board)

"""
['e6', 'd6', 'c3', 'f3', 'c4']
a                
b                
c     X X        
d       X X O    
e       O X X    
f     O          
g                
h                
  1 2 3 4 5 6 7 8
"""
for r in board:
    print("|", end='')
    for c in r:
        if c == 1:
            print(" ", end='')
        elif c==0:
            print("O", end='')
        elif c==2:
            print("X", end='')

    print("|\n", end='')


s: 
 tensor([[34, 28, 19, 39, 20]], device='cuda:0') torch.Size([1, 5])
h: 
 tensor([-0.8545, -0.5327, -0.4665, -0.9722,  1.9432,  1.0412, -0.2281,  2.2611,
        -0.8041, -0.3103], device='cuda:0', grad_fn=<SliceBackward>) torch.Size([1, 5, 512])
output by running head(h):
 tensor([-24.7494,  -1.8201,  -2.2122,  -1.0959,  -2.1191,  -0.8590,  -1.0885,
         -2.1869,  -1.5516,  -1.3359], device='cuda:0',
       grad_fn=<SliceBackward>) torch.Size([1, 5, 61])
output by running the model(s):
 tensor([-24.7494,  -1.8201,  -2.2122,  -1.0959,  -2.1191,  -0.8590,  -1.0885,
         -2.1869,  -1.5516,  -1.3359], device='cuda:0',
       grad_fn=<SliceBackward>) torch.Size([1, 5, 61])
reconstructed board:
 tensor([[ -9.0074,  12.0080,  -5.6608],
        [-13.9133,  18.9850,  -8.2382],
        [-11.3246,  17.7782,  -6.7344],
        [-10.4789,  19.8135,  -8.0795],
        [ -8.3978,  15.3387,  -7.0764],
        [ -7.8179,  16.5528,  -5.1594],
        [-12.8687,  12.0632,  -8.3205],
        [

### Extract the latent space, the head, and the probe to be used for Marabou

In [17]:
class TrinityNet(nn.Module):
    def __init__(self, probe, head, n_embd = 512, vocab_size = 61, mid_dim = 256):
        super().__init__()
        #the probe head
        self.probe = nn.Sequential(
            nn.Linear(n_embd, mid_dim, bias=True),
            nn.ReLU(True),
            nn.Linear(mid_dim, 64*3, bias=True),
        )
        #set weight
        with torch.no_grad():
            self.probe[0].weight = nn.Parameter(probe.proj[0].weight)
            self.probe[0].bias = nn.Parameter(probe.proj[0].bias)
            
            self.probe[2].weight = nn.Parameter(probe.proj[2].weight)
            self.probe[2].bias = nn.Parameter(probe.proj[2].bias)
        #the logits head
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        #set weight
        with torch.no_grad():
            self.head.weight = nn.Parameter(head.weight)
        
    def forward(self, h):
        logits = self.head(h)
        probe = self.probe(h)
        print(logits.shape, probe.shape)
        return torch.cat([logits, probe])
    
    def get_a_board(self):
        return OthelloBoardState()
    
    def play(self, s, is_i= True):
        if isinstance(s, torch.Tensor):
            s = s.squeeze().cpu().numpy()
        if is_i:
            #convert to s
            s = [train_dataset.itos[move] for move in s]
        print("playing the sequence:",s)
        board = self.get_a_board()
        board.update(s, prt=True)
    
trinity = TrinityNet(probe, model.head)
trinity.to(device)

print(trinity(h[0][-1]))

torch.Size([61]) torch.Size([192])
tensor([-2.4749e+01, -1.8201e+00, -2.2122e+00, -1.0959e+00, -2.1191e+00,
        -8.5902e-01, -1.0885e+00, -2.1869e+00, -1.5516e+00, -1.3359e+00,
         1.6450e-02, -1.2670e+00,  8.7403e+00, -8.7246e-01, -1.0914e+00,
        -2.0707e+00, -9.6828e-01, -2.3762e+00,  6.2370e-01, -1.1780e+00,
        -7.2599e-01, -1.2089e-01,  8.9860e+00, -4.0348e-02, -1.5985e+00,
        -5.2786e-01, -1.0326e+00,  8.9397e+00,  9.1746e-01,  1.3736e-01,
        -1.4051e+00, -1.2508e+00, -8.6213e-01, -2.8386e+00, -7.0979e-01,
         8.7975e+00, -2.5311e+00, -7.8931e-01, -1.6974e+00, -3.0895e+00,
         9.0381e+00, -2.4195e-01,  8.8705e+00, -1.7676e+00, -1.3732e+00,
        -2.6736e+00, -1.6396e+00, -1.4218e+00, -8.4368e-01, -2.0650e+00,
        -3.7264e-01, -1.4883e+00, -2.9299e+00, -8.5015e-01, -2.1101e+00,
        -6.4349e-01, -2.5108e+00, -3.8076e-01, -1.9649e+00, -8.7731e-01,
        -1.6228e+00, -9.0074e+00,  1.2008e+01, -5.6608e+00, -1.3913e+01,
         1.8985e

In [18]:
import maraboupy