In [10]:
import chess
from model import GPTConfig, GPT
from tokenizers import Tokenizer
import torch
import os
import numpy as np
import chess.pgn
import io
from contextlib import nullcontext
from rl_utils import get_rewards, pad_arrays, get_reverse_discount_rewards

In [11]:
# Run everything in this notebook on the CPU ...
device = 'cpu'
# ... while still restricting CUDA's visible devices to index 7 via the environment.
os.environ.update({'CUDA_VISIBLE_DEVICES': "7"})

In [12]:

# Restore the model from a saved checkpoint file.
ckpt_path = os.path.join("out_chess_llm_q_iteration", 'ckpt900.pt')
checkpoint = torch.load(ckpt_path, map_location=device)

# Rebuild the network from the constructor arguments stored in the checkpoint.
checkpoint_model_args = checkpoint['model_args']
gptconf = GPTConfig(**checkpoint_model_args)
model = GPT(gptconf)

# Some checkpoints store parameter names with an '_orig_mod.' prefix (the original
# author notes the cause is not yet understood). Rename those keys in place so they
# match the freshly built model before loading.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
prefixed_names = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_names:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)
model.load_state_dict(state_dict)

# Bookkeeping values saved alongside the weights.
iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']

# Last expression of the cell: moves the model to `device` and displays its repr.
model.to(device)

number of parameters: 102.85M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(23296, 768)
    (wpe): Embedding(768, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=False)
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=23296, bias=False)
)

In [13]:
# Load the tokenizer from its serialized file.
# NOTE(review): this is a machine-specific absolute path; it must exist on the host running the notebook.
tokenizer_file = "/data/evan/CS285_Final_Project/model/tokenizer.model"
tokenizer = Tokenizer.from_file(tokenizer_file)

In [15]:
# Choose the context manager that wraps model calls: a no-op on CPU,
# otherwise autocast with the dtype below.
device_type = 'cpu'
ptdtype = torch.bfloat16
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [16]:
# Fresh board in the standard starting position.
board = chess.Board()
# One-element int32 context holding token id 0
# (presumably the start-of-game token -- TODO confirm against the tokenizer).
idx_list = torch.tensor([0], dtype=torch.int32).to(device)

In [17]:
# Ask the model for the next move given the current token context and board.
# NOTE(review): the output recorded for this cell ends in
# "IndexError: index 22044 is out of bounds for dimension 0 with size 1", so
# idx_next / logit / prob may not have been assigned by this run. The failure is
# raised inside get_next_move (not visible here) -- confirm the shape it expects
# for idx_list.
with ctx:
    idx_next, logit, prob = model.get_next_move(idx_list, board, tokenizer)

torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1, 23296])
torch.Size([1, 23296])


IndexError: index 22044 is out of bounds for dimension 0 with size 1

In [18]:
# Display idx_next as a plain Python scalar (assumes it holds exactly one element).
# NOTE(review): the recorded output is a CUDA device-side assert even though this
# notebook sets device = 'cpu'; it likely comes from an earlier kernel state --
# re-run on a fresh kernel to confirm.
idx_next.item()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [40]:
# Settings used by the training-step cell.
batch_size = 2
gradient_accumulation_steps = 10

# Binary cross-entropy on raw logits, summed over all elements (no averaging).
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')

In [43]:
with ctx:
    # Sample `batch_size` trajectories from the model; `probs` is used below as the logits
    # passed to the BCE-with-logits loss.
    probs, boards, dones = model.sample_legal_trajectories(batch_size, tokenizer, device=device)

    # calculate rewards for that trajectory
    targets, masks = get_rewards(boards, dones)

    # Convert the per-trajectory masks/targets into float tensors on `device`
    # (pad_arrays presumably pads them to a common length -- TODO confirm).
    full_masks = torch.Tensor(pad_arrays(masks)).to(device)
    full_targets = torch.Tensor(pad_arrays(targets)).to(device)

    # Mask the loss, not the logits. Multiplying the logits by the mask turns every
    # padded position into logit 0, and BCE-with-logits at logit 0 costs log(2), so
    # the summed loss was inflated by log(2) per padded element. Using the mask as a
    # per-element weight removes that constant term; gradients w.r.t. `probs` are
    # unchanged (they were already zero at masked positions).
    masked_bce_sum = torch.nn.functional.binary_cross_entropy_with_logits(
        probs, full_targets, weight=full_masks, reduction='sum'
    )

    # Average over the unmasked positions only.
    loss = masked_bce_sum / full_masks.sum()

    # Scale so gradients accumulated over several steps add up to one averaged step.
    loss = loss / gradient_accumulation_steps

# Backward pass runs outside the autocast/null context.
loss.backward()

In [34]:
# Sample two trajectories and build the reward targets/masks inspected in the cells below.
with ctx:
    probs, boards, dones = model.sample_legal_trajectories(2, tokenizer, device=device)

    # Calculate rewards for those trajectories.
    # targets_d / masks_d come from the reverse-discount variant (0.95 is presumably the
    # discount factor -- TODO confirm); masks_d is not used in the cells visible here.
    targets_d, masks_d = get_reverse_discount_rewards(boards, 0.95)

    targets, masks = get_rewards(boards, dones)

    # Float tensor of the masks, moved to `device` (pad_arrays presumably pads to a common length).
    full_masks = torch.Tensor(pad_arrays(masks)).to(device)

    # Float tensor of the targets, moved to `device`.
    full_target = torch.Tensor(pad_arrays(targets)).to(device)

In [35]:
# Masked BCE-with-logits, averaged over unmasked positions and divided by 10
# (presumably the gradient-accumulation factor -- TODO confirm).
# The mask is applied as a per-element loss weight rather than multiplied into the
# logits: a logit forced to 0 still contributes log(2) to the summed loss, which
# inflated the reported value by log(2) per padded element. Gradients are unchanged.
loss = (
    torch.nn.functional.binary_cross_entropy_with_logits(
        probs, full_target, weight=full_masks, reduction='sum'
    )
    / full_masks.sum()
) / 10

In [36]:
# Check the dtype of probs before combining it with the mask in the loss.
probs.dtype

torch.float32

In [37]:
# Check the dtype of full_masks; it is combined element-wise with probs in the loss.
full_masks.dtype

torch.float32

In [38]:
# Backpropagate the loss computed in the previous cell.
loss.backward()

In [7]:
# Inspect the `dones` list returned by sample_legal_trajectories.
dones

[True, True]

In [8]:
# Inspect the padded targets.
# NOTE(review): this dumps the whole array; consider displaying a slice instead.
full_target

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [9]:
# Result string of the second sampled board.
boards[1].result()

'1/2-1/2'

In [12]:
# NOTE(review): no call parentheses here, so this displays the bound method
# (together with the board's repr) rather than the result string.
boards[1].result

<bound method Board.result of Board('7k/Q5R1/6Qp/7P/8/6P1/5PK1/8 b - - 0 54')>

In [8]:
# Inspect the targets produced by get_reverse_discount_rewards(boards, 0.95).
targets_d

[array([0.05799111, 1.        , 0.06104327, 1.        , 0.06425608,
        1.        , 0.06763798, 1.        , 0.07119787, 1.        ,
        0.07494513, 1.        , 0.07888961, 1.        , 0.08304169,
        1.        , 0.08741231, 1.        , 0.09201296, 1.        ,
        0.09685574, 1.        , 0.10195341, 1.        , 0.10731938,
        1.        , 0.11296777, 1.        , 0.11891344, 1.        ,
        0.12517204, 1.        , 0.13176005, 1.        , 0.13869479,
        1.        , 0.14599451, 1.        , 0.15367843, 1.        ,
        0.16176677, 1.        , 0.17028081, 1.        , 0.17924296,
        1.        , 0.1886768 , 1.        , 0.19860716, 1.        ,
        0.20906017, 1.        , 0.22006333, 1.        , 0.23164562,
        1.        , 0.24383749, 1.        , 0.25667104, 1.        ,
        0.27018004, 1.        , 0.28440005, 1.        , 0.29936847,
        1.        , 0.3151247 , 1.        , 0.33171022, 1.        ,
        0.34916865, 1.        , 0.36754595, 0.99

In [9]:
# Inspect the padded masks (1 = position kept in the loss, 0 = masked out).
full_masks

array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [10]:
# Inspect the raw values returned as `probs` by sample_legal_trajectories.
probs

tensor([[-1.1007e+00, -8.5758e-01, -6.8355e-01, -4.4358e+00, -5.2183e-01,
         -3.3869e-02, -1.7516e+00, -5.8120e-01, -6.8881e+00, -9.0856e-01,
         -8.3140e-01, -9.6628e-01, -1.3158e+00, -4.5910e-01, -1.2073e-01,
         -7.5612e-01, -5.0810e+00, -1.8676e+00, -2.4674e+00, -1.2607e+00,
         -2.3355e-02, -2.1168e+00, -1.3593e+00, -2.5023e-01, -2.0986e+00,
         -1.8771e+00, -3.7050e-01, -1.6851e+00, -4.9095e-03, -4.4382e-01,
         -1.6760e+00, -2.0921e+00, -2.0608e+00, -7.0795e-01, -2.5383e+00,
         -3.6730e-01, -1.5822e+00, -3.3371e+00, -4.6375e-01, -8.0519e-01,
         -9.4756e-01, -1.2960e-02, -9.5324e-01, -7.2157e-01, -3.0367e-01,
         -1.6993e+00, -2.9645e+00, -3.4643e+00, -3.3809e+00, -3.6018e-01,
         -8.3803e-02, -2.0811e+00, -2.9186e+00, -1.0013e+00, -1.0997e+00,
         -3.2565e+00, -2.7185e-01, -8.6739e-01, -5.1926e-01, -3.3354e-01,
         -3.5707e+00, -2.4740e+00, -5.4396e-01, -1.8899e+00, -4.9152e-02,
         -3.4279e+00, -1.3927e-01, -1.

In [11]:
# Inspect the padded targets again (output recorded from a different sampling run).
full_target

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,

In [12]:
# Backpropagate through the mean of probs (checks that gradients flow from probs at all).
probs.mean().backward()

In [18]:
# Shape of probs; it must be compatible with full_masks and full_target in the loss.
probs.shape

torch.Size([4, 96])

In [19]:
# Shape of the padded masks, for comparison with probs.shape.
full_masks.shape

(4, 96)

In [20]:
# Shape of the padded targets, for comparison with probs.shape.
full_target.shape

(4, 96)

In [13]:
# Display the final position of each sampled board.
boards

[Board('r2q1r1k/1pp1p1bp/1n4p1/p3pb2/P3N1n1/3B1NB1/1PP2PPP/R2Q1RK1 w - - 4 15'),
 Board('B1b1rk2/p4pp1/5np1/1p3q2/P7/4B1P1/1P2nPKP/R2R4 w - - 1 22'),
 Board('2r3k1/p1r1bpp1/1pP1p2p/4P3/q1pB2P1/4P1QP/P5B1/2RR2K1 w - - 7 29'),
 Board('5r2/1p2r2B/n1p2kR1/3pN3/3P1P1R/4K3/1P5P/8 b - - 4 48')]

In [14]:
# Append every sampled game to the evaluation PGN file.
# The file is opened once and closed by the context manager; the previous version
# opened a new handle for each game and never closed any of them.
with open("./eval_games.pgn", 'a') as pgn_file:
    for board in boards:
        # Rebuild the game by replaying the board's move history as the main line.
        game = chess.pgn.Game()
        node = game
        for move in board.move_stack:
            node = node.add_main_variation(move)
        # PGN separates consecutive games with a blank line, so end each game with
        # two newlines instead of print's default single newline.
        print(game, file=pgn_file, end="\n\n")