In [2]:
%load_ext autoreload
%autoreload 2
from alphatoe import evals, data, game
import torch
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformer_lens import HookedTransformerConfig, HookedTransformer
import json

In [36]:
#generate games
_, game_list = data.gen_games(gametype="all")

Generating all possible games...
Generated 255168 games
Generated array of moves


In [3]:
model_dict = torch.load("../scripts/models/strategic games-20230711-195141.pt")

In [4]:
#load json config as dictionary
with open("../scripts/models/strategic games-20230711-195141.json") as json_file:
    cfg_dict = json.load(json_file)

cfg_dict

{'experiment_name': 'strategic games',
 'gametype': 'strat',
 'fine_tune': None,
 'n_epochs': 500,
 'lr': 1e-05,
 'weight_decay': 0.0001,
 'batch_size': 32768,
 'train_test_split': 0.8,
 'n_layers': 8,
 'n_heads': 8,
 'd_model': 128,
 'd_head': 16,
 'd_mlp': 512,
 'act_fn': 'relu',
 'normalization_type': None,
 'device': 'cuda',
 'seed': 1337,
 'save_losses': False,
 'save_checkpoints': False,
 'eval_model': False}

In [5]:
cfg = HookedTransformerConfig(
    n_layers = cfg_dict['n_layers'],
    n_heads = cfg_dict['n_heads'],
    d_model = cfg_dict['d_model'],
    d_head = cfg_dict['d_head'],
    d_mlp = cfg_dict['d_mlp'],
    act_fn = cfg_dict['act_fn'],
    #normalization_type=None,
    normalization_type=cfg_dict['normalization_type'],
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device="cuda",
    seed = 1337,
)

model = HookedTransformer(cfg)
model.load_state_dict(model_dict)

<All keys matched successfully>

In [6]:
model.to("cuda")

Moving model to device:  cuda


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
    (1): TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
      

In [41]:
seq_test = [10,3,1,7]
model(torch.tensor(seq_test))[0, -1]

tensor([  1.9625,  -2.4645,   1.9753,  -7.4530,   2.1081,   1.9256,   2.3395,
        -10.6423,   2.0720, -11.8306], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [7]:
sample = evals.sample_games(model, 1, 1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:59<00:00, 16.89it/s]


In [54]:
len(sample[11])

11

In [8]:
evals.eval_model(sample)

100%|██████████| 1000/1000 [00:00<00:00, 860723.17it/s]


100%|██████████| 1000/1000 [00:00<00:00, 48808.43it/s]


{'repeat moves': 0.365,
 'play after player victory': 0.187,
 'play after draw game': 0.004,
 'inappropriate end state': 0.003,
 'total illegal moves, jack': 0.484}

In [71]:
game.play_game(sample[11])
print(sample[11])
print(evals._check_if_illegal_moves(sample[11]))

| O | X | X |
| X | X | O |
| O | X | O |
[10, 2, 6, 3, 5, 7, 0, 4, 8, 1, 9]
False


In [74]:
other_illegal_games = evals.eval_model(sample)

100%|██████████| 1000/1000 [00:00<00:00, 768891.66it/s]
100%|██████████| 1000/1000 [00:00<00:00, 44107.85it/s]


In [75]:
bad_samples = [sample[i] for i in range(len(sample)) if other_illegal_games[i]]

In [77]:
game.play_game(bad_samples[0])
print(bad_samples[0])


Invalid game
| O | O | X |
| X | O | X |
|   | X | O |
[10, 2, 0, 3, 4, 5, 1, 7, 8, 9, 9]
