In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Seed every RNG once, up front, so training/evaluation runs are reproducible.
from mingpt.utils import set_seed

SEED = 44
set_seed(SEED)

In [3]:
import os
import math
import time
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn import functional as F
from data import get_othello
from data.othello import permit, start_hands, OthelloBoardState, permit_reverse
from mingpt.dataset import CharDataset
from mingpt.utils import sample
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig

In [4]:
# Build the game corpus and the character-level dataset over it.
# ood_num=-1 means use as many simulated games as possible (from "data/othello_synthetic/")
othello = get_othello(ood_num=-1, data_root=None, wthor=True)
train_dataset = CharDataset(othello)

# Original OthelloGPT size: 8 layers, 8 heads, 512-dim embeddings.
GPT_HPARAMS = dict(n_layer=8, n_head=8, n_embd=512)
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, **GPT_HPARAMS)
model = GPT(mconf)

  0%|          | 0/43 [00:00<?, ?it/s]

Mem Used: 3.695 GB: 100%|██████████| 43/43 [00:15<00:00,  2.77it/s]


Deduplicating...
Deduplicating finished with 4263809 games left
Using 3411047 for training, 852762 for validation
Dataset created has 3411047 sequences, 61 unique words.


In [5]:
# load_res = model.load_state_dict(torch.load(f"./ckpts/grok/synth_e3.ckpt"))
# Place the model on the current CUDA device when one exists; otherwise warn
# and leave it where it is.
if not torch.cuda.is_available():
    print("NO GPU FOUND")
else:
    device = torch.cuda.current_device()
    model = model.to(device)

In [6]:
# Training configuration.
max_epochs = 40
# NOTE: the timestamp tag already begins with "_", so the checkpoint file is
# named "bias80__YYYYmmdd_HHMMSS.ckpt" (double underscore).
t_start = time.strftime("_%Y%m%d_%H%M%S")
ckpt_path = f"./ckpts/bias80_{t_start}.ckpt"

# The LR schedule is expressed in tokens; one pass over the data is
# len(train_dataset) * block_size tokens.
tokens_per_epoch = len(train_dataset) * train_dataset.block_size
trainer_kwargs = dict(
    max_epochs=max_epochs,
    batch_size=512 * 4,  # 512 per GPU, using 4 gpus
    learning_rate=5e-4,
    lr_decay=True,
    warmup_tokens=tokens_per_epoch * 5,  # warm up over the first ~5 epochs
    final_tokens=tokens_per_epoch * max_epochs,
    num_workers=0,
    ckpt_path=ckpt_path,
    # saved_epochs=[1, 4, 9, 14, 34, 74],
)
tconf = TrainerConfig(**trainer_kwargs)

trainer = Trainer(model, train_dataset, None, tconf)
# Later evaluation cells reuse whatever device the Trainer selected.
device = trainer.device
print(t_start)

_20230703_235156


In [7]:
# Run the training schedule configured above (max_epochs epochs over train_dataset).
# NOTE(review): checkpoint writing to ckpt_path is presumably handled inside
# mingpt.trainer.Trainer — confirm there.
trainer.train()

  0%|          | 0/1666 [00:00<?, ?it/s]

epoch 1 iter 1665: train loss 1.06821. lr 1.000000e-04: 100%|██████████| 1666/1666 [07:26<00:00,  3.74it/s]
epoch 2 iter 1665: train loss 0.92007. lr 2.000000e-04: 100%|██████████| 1666/1666 [07:16<00:00,  3.82it/s]
epoch 3 iter 1665: train loss 0.82917. lr 3.000000e-04: 100%|██████████| 1666/1666 [07:24<00:00,  3.75it/s]
epoch 4 iter 1665: train loss 0.77929. lr 4.000000e-04: 100%|██████████| 1666/1666 [07:17<00:00,  3.81it/s]
epoch 5 iter 1665: train loss 0.74992. lr 5.000000e-04: 100%|██████████| 1666/1666 [07:17<00:00,  3.81it/s]
epoch 6 iter 1665: train loss 0.72719. lr 4.989936e-04: 100%|██████████| 1666/1666 [07:18<00:00,  3.80it/s]
epoch 7 iter 1665: train loss 0.70349. lr 4.959824e-04: 100%|██████████| 1666/1666 [07:18<00:00,  3.80it/s]
epoch 8 iter 1665: train loss 0.70068. lr 4.909907e-04: 100%|██████████| 1666/1666 [07:16<00:00,  3.81it/s]
epoch 9 iter 1665: train loss 0.69452. lr 4.840587e-04: 100%|██████████| 1666/1666 [07:17<00:00,  3.81it/s]
epoch 10 iter 1665: train lo

In [None]:
# Load a previously trained checkpoint into `model`.
# NOTE(review): the training cell above saves to "./ckpts/bias80_<t_start>.ckpt",
# while this cell reads an older "custom_" run — adjust t_load / prefix to the
# run you actually want to evaluate.
t_load = "20230628_230636"
# map_location="cpu" lets a checkpoint saved from GPU tensors deserialize on a
# CPU-only machine; the model is moved to the GPU below when one is available.
state_dict = torch.load(f"./ckpts/custom_{t_load}.ckpt", map_location="cpu")
load_res = model.load_state_dict(state_dict)
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    model = model.to(device)
else:
    print("NO GPU FOUND")

In [None]:
# Legality check of the current `model`: for every proper prefix of the first
# N_VAL_GAMES validation games, sample one next move and count how often the
# extended sequence is accepted by OthelloBoardState().update (which is
# expected to raise on an illegal move).
N_VAL_GAMES = 1000
total_nodes = 0
success_nodes = 0

bar = tqdm(othello.val[:N_VAL_GAMES])
for whole_game in bar:
    for length_of_partial_game in range(1, len(whole_game)):
        total_nodes += 1
        context = whole_game[:length_of_partial_game]
        x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(device)
        y = sample(model, x, 1, temperature=1.0)[0]
        completion = [train_dataset.itos[int(i)] for i in y if i != -1]
        try:
            OthelloBoardState().update(completion, prt=False)
        except Exception:
            # Deliberate best-effort: an illegal continuation only counts
            # towards total_nodes.
            continue
        success_nodes += 1
    # max(..., 1) avoids ZeroDivisionError before any prefix has been scored
    # (e.g. a first game of length <= 1); the rate is then 0.00%.
    bar.set_description(f"{success_nodes/max(total_nodes, 1)*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
print(f"{success_nodes/max(total_nodes, 1)*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")

In [12]:
# Replay a fixed 38-move game prefix and print the board after every move.
# Moves are square indices row * 8 + col (e.g. 19 -> "c4", 34 -> "e3").
partial_game = [
    19, 34, 41, 11, 10, 9, 1, 20, 3, 2,
    8, 0, 13, 4, 29, 12, 5, 6, 14, 15,
    21, 37, 22, 33, 7, 26, 18, 16, 25, 17,
    42, 23, 30, 24, 31, 32, 46, 38,
]
board = OthelloBoardState()
board.update(partial_game, prt=True)

--------------------
[]
a                
b                
c                
d       O X      
e       X O      
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['c4']
a                
b                
c       X        
d       X X      
e       X O      
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['c4', 'e3']
a                
b                
c       X        
d       X X      
e     O O O      
f                
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['c4', 'e3', 'f2']
a                
b                
c       X        
d       X X      
e     X O O      
f   X            
g                
h                
  1 2 3 4 5 6 7 8
--------------------
--------------------
['c4', 'e3', 'f2', 'b4']
a                
b       O        
c       O        
d       O X      
e     X O O      
f   X 

In [12]:
def validate(ckpt, n=1000):
    """Measure how often a checkpoint proposes a legal next move.

    Loads ``./ckpts/{ckpt}.ckpt`` into a fresh ``GPT(mconf)``. For every
    proper prefix of the first ``n`` games in ``othello.val`` it samples one
    next token. It then checks whether the extended sequence is accepted by
    ``OthelloBoardState().update``, which is expected to raise on an illegal
    move.

    Args:
        ckpt: checkpoint name relative to ``./ckpts``, without the ``.ckpt``
            suffix (e.g. ``"grok/synth_e1"``).
        n: number of validation games to scan.

    Returns:
        The pass rate in percent, or ``None`` when no GPU is available
        (evaluation is skipped in that case, after printing "NO GPU FOUND").
    """
    # Check for a GPU *before* touching the checkpoint: previously the load ran
    # first and a GPU-saved checkpoint crashed torch.load on a CPU-only host.
    if not torch.cuda.is_available():
        print("NO GPU FOUND")
        return None
    device = torch.cuda.current_device()

    model = GPT(mconf)
    # Deserialize to CPU so the load does not depend on the saving device index.
    model.load_state_dict(torch.load(f"./ckpts/{ckpt}.ckpt", map_location="cpu"))
    model = model.to(device)
    # Inference mode (disables dropout); harmless if `sample` also sets it.
    model.eval()

    total_nodes = 0
    success_nodes = 0

    bar = tqdm(othello.val[:n])
    for whole_game in bar:
        for length_of_partial_game in range(1, len(whole_game)):
            total_nodes += 1
            context = whole_game[:length_of_partial_game]
            x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(device)
            y = sample(model, x, 1, temperature=1.0)[0]
            completion = [train_dataset.itos[int(i)] for i in y if i != -1]
            try:
                OthelloBoardState().update(completion, prt=False)
            except Exception:
                # Deliberate best-effort: an illegal continuation only counts
                # towards total_nodes.
                continue
            success_nodes += 1
        # max(..., 1) avoids ZeroDivisionError when nothing has been scored yet.
        rate = success_nodes / max(total_nodes, 1) * 100
        bar.set_description(f"{rate:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")

    rate = success_nodes / max(total_nodes, 1) * 100
    print(f"{rate:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
    return rate

In [6]:
# Run the legality check over every checkpoint named "grok/synth_e<k>"
# (k presumably being the training epoch at which it was saved).
model_paths = [f"grok/synth_e{epoch}" for epoch in (1, 2, 3, 4, 5, 7, 10, 15, 20, 40)]
for path in model_paths:
    print("now validating:", path)
    validate(path)

now validating: grok/synth_e1


96.89% pass rate: 57153/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:56<00:00,  2.40it/s]


96.89% pass rate: 57153/58990 among all searched nodes
now validating: grok/synth_e2


98.67% pass rate: 58207/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:54<00:00,  2.41it/s]


98.67% pass rate: 58207/58990 among all searched nodes
now validating: grok/synth_e3


99.31% pass rate: 58582/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:54<00:00,  2.41it/s]


99.31% pass rate: 58582/58990 among all searched nodes
now validating: grok/synth_e4


99.53% pass rate: 58710/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:53<00:00,  2.42it/s]


99.53% pass rate: 58710/58990 among all searched nodes
now validating: grok/synth_e5


99.59% pass rate: 58750/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:53<00:00,  2.42it/s]


99.59% pass rate: 58750/58990 among all searched nodes
now validating: grok/synth_e7


99.82% pass rate: 58885/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:53<00:00,  2.42it/s]


99.82% pass rate: 58885/58990 among all searched nodes
now validating: grok/synth_e10


99.79% pass rate: 58866/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:54<00:00,  2.41it/s]


99.79% pass rate: 58866/58990 among all searched nodes
now validating: grok/synth_e15


99.85% pass rate: 58900/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:54<00:00,  2.42it/s]


99.85% pass rate: 58900/58990 among all searched nodes
now validating: grok/synth_e20


99.92% pass rate: 58940/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:54<00:00,  2.42it/s]


99.92% pass rate: 58940/58990 among all searched nodes
now validating: grok/synth_e40


99.94% pass rate: 58952/58990 among all searched nodes: 100%|██████████| 1000/1000 [06:54<00:00,  2.41it/s]

99.94% pass rate: 58952/58990 among all searched nodes





In [6]:
# Summarize probe results: for each layer's probe, read the JSON log and
# report the error rate (in percent) of the last `test_acc_cont` entry.
import json

root = "ckpts/grok/probes/state_tl256_random"  # plain string: no placeholders needed
N_PROBE_LAYERS = 8  # one probe per layer (model built with n_layer=8)


def _final_error_pct(log_path):
    """Return 100 * (1 - final test_acc_cont) from one probe's JSON log file."""
    with open(log_path, "r") as file:
        log = json.load(file)
    return 100 * (1 - log["test_acc_cont"][-1])


errs = [
    _final_error_pct(f"{root}/layer{layer}/tensorboard.txt")
    for layer in range(1, N_PROBE_LAYERS + 1)
]
print(", ".join(str(e) for e in errs))



25.944670627934276, 25.998496185446008, 26.326511150234744, 26.59829812206572, 26.910798122065728, 26.979478433098592, 27.150491490610328, 27.355670481220663
