In [1]:
import numpy as np
import gym
from go_ai import policies, game, metrics, data
from go_ai.models import value_model
import os
import random
import torch

# Hyperparameters

In [2]:
BOARD_SIZE = 4

In [3]:
ITERATIONS = 256
EPISODES_PER_ITERATION = 256
NUM_EVAL_GAMES = 128

In [4]:
INIT_TEMP = 1/64
TEMP_DECAY = 3/4
MIN_TEMP = 1/64

In [5]:
BATCH_SIZE = 32

In [6]:
LOAD_SAVED_MODELS = True

# Data Parameters

In [7]:
EPISODES_DIR = 'episodes/'

In [8]:
CHECKPOINT_PATH = 'checkpoints/checkpoint_{}x{}.pt'.format(BOARD_SIZE, BOARD_SIZE)

In [9]:
DEMO_TRAJECTORY_PATH = 'logs/a_trajectory.png'

# Go Environment
Train on a small board for fast training and efficient debugging

In [10]:
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Model

In [11]:
curr_model = value_model.ValueNet(BOARD_SIZE)
checkpoint_model = value_model.ValueNet(BOARD_SIZE)

if LOAD_SAVED_MODELS:
    assert os.path.exists(CHECKPOINT_PATH)
    print("Starting from checkpoint")
else:
    torch.save(curr_model.state_dict(), CHECKPOINT_PATH)
    print("Initialized checkpoint") 

curr_model.load_state_dict(torch.load(CHECKPOINT_PATH))
checkpoint_model.load_state_dict(torch.load(CHECKPOINT_PATH))

optim = torch.optim.Adam(curr_model.parameters(), 1e-3)
curr_model

Initialized checkpoint


ValueNet(
  (convs): Sequential(
    (0): Conv2d(6, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
  )
  (fcs): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
  (criterion): BCEWithLogitsLoss()
)

# Policies

In [12]:
curr_policy = policies.QTempPolicy('Current', curr_model, INIT_TEMP)
checkpoint_policy = policies.QTempPolicy('Checkpoint', checkpoint_model, INIT_TEMP)

random_policy = policies.RandomPolicy()
greedy_policy = policies.QTempPolicy('Greedy', policies.greedy_val_func, temp=0)
human_policy = policies.HumanPolicy()

In [13]:
def decay_temps(policies, temp_decay, min_temp):
    for policy in policies:
        assert hasattr(policy, 'temp')
        policy.temp *= temp_decay
        if policy.temp < min_temp:
            policy.temp = min_temp
        print(f"{policy.name} temp decayed to {policy.temp}")

# Demo and Time Games

Symmetries

In [14]:
%%time
go_env.reset()
action = (1, 1)
next_state, _, _, _ = go_env.step(action)
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 372 ms, sys: 39.7 ms, total: 412 ms
Wall time: 201 ms


With replay memory

In [15]:
%%time
go_env.reset()
_,_ = game.pit(go_env, curr_policy, curr_policy, get_traj=True)

CPU times: user 664 ms, sys: 66.2 ms, total: 730 ms
Wall time: 393 ms


In [16]:
%%time
metrics.gen_traj_fig(go_env, curr_policy, DEMO_TRAJECTORY_PATH)

CPU times: user 3.65 s, sys: 121 ms, total: 3.77 s
Wall time: 3.4 s


# Train

In [17]:
for iteration in range(ITERATIONS):
    print(f"Iteration {iteration}")
    
    # Make and write out the episode data
    _, replay_data = game.play_games(go_env, curr_policy, curr_policy, True, EPISODES_PER_ITERATION)
        
    # Process the data
    random.shuffle(replay_data)
    replay_data = data.replaylist_to_numpy(replay_data)

    # Optimize
    curr_model.optimize(replay_data, optim, BATCH_SIZE)
    
    # Evaluate against checkpoint model and other baselines
    opp_winrate, _ = game.play_games(go_env, curr_policy, checkpoint_policy, False, EPISODES_PER_ITERATION)

    if opp_winrate > 0.6:
        # New parameters are significantly better. Accept it
        torch.save(curr_model.state_dict(), CHECKPOINT_PATH)
        checkpoint_model.load_state_dict(torch.load(CHECKPOINT_PATH))
        print(f"{100*opp_winrate:.1f}% Accepted new model")
        
        # Plot samples of states and response heatmaps
        metrics.gen_traj_fig(go_env, curr_policy, DEMO_TRAJECTORY_PATH)
        print("Plotted sample trajectory")
        
        rand_winrate, _ = game.play_games(go_env, curr_policy, random_policy, False, NUM_EVAL_GAMES)
        greed_winrate, _ = game.play_games(go_env, curr_policy, greedy_policy, False, NUM_EVAL_GAMES)

    elif opp_winrate >= 0.4:
        # Keep trying
        print(f"{100*opp_winrate:.1f}% Continuing to train current weights")
    else:
        # New parameters are significantly worse. Reject it.
        curr_model.load_state_dict(torch.load(CHECKPOINT_PATH))
        print(f"{100*opp_winrate:.1f}% Rejected new model")
        
    # Decay the temperatures if any
    decay_temps([curr_policy, checkpoint_policy], TEMP_DECAY, MIN_TEMP)

Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Iteration 0


Current vs. Current: 100%|██████████| 256/256 [02:02<00:00,  2.08it/s, 54.5%]
Optimizing: 178it [00:27,  6.39it/s, 57.8%, 0.663L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:11<00:00,  1.95it/s, 61.3%]


61.3% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.75
Checkpoint temp decayed to 0.75
Iteration 1


Current vs. Current: 100%|██████████| 256/256 [02:08<00:00,  1.99it/s, 55.3%]
Optimizing: 183it [00:29,  6.19it/s, 62.3%, 0.655L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:04<00:00,  2.05it/s, 50.4%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

50.4% Continuing to train current weights
Current temp decayed to 0.5625
Checkpoint temp decayed to 0.5625
Iteration 2


Current vs. Current: 100%|██████████| 256/256 [02:48<00:00,  1.52it/s, 55.5%]
Optimizing: 168it [01:51,  1.51it/s, 65.5%, 0.637L]
Current vs. Checkpoint: 100%|██████████| 256/256 [10:22<00:00,  2.43s/it, 52.5%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

52.5% Continuing to train current weights
Current temp decayed to 0.421875
Checkpoint temp decayed to 0.421875
Iteration 3


Current vs. Current: 100%|██████████| 256/256 [10:57<00:00,  2.57s/it, 55.9%]
Optimizing: 193it [02:05,  1.54it/s, 63.9%, 0.643L]
Current vs. Checkpoint: 100%|██████████| 256/256 [07:14<00:00,  1.70s/it, 52.0%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

52.0% Continuing to train current weights
Current temp decayed to 0.31640625
Checkpoint temp decayed to 0.31640625
Iteration 4


Current vs. Current: 100%|██████████| 256/256 [42:52<00:00, 10.05s/it, 56.4%]   
Optimizing: 170it [03:13,  1.14s/it, 64.5%, 0.633L]
Current vs. Checkpoint: 100%|██████████| 256/256 [41:28<00:00,  9.72s/it, 52.0%]   
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

52.0% Continuing to train current weights
Current temp decayed to 0.2373046875
Checkpoint temp decayed to 0.2373046875
Iteration 5


Current vs. Current: 100%|██████████| 256/256 [08:28<00:00,  1.99s/it, 54.3%]
Optimizing: 178it [02:18,  1.29it/s, 65.6%, 0.615L]
Current vs. Checkpoint: 100%|██████████| 256/256 [50:26<00:00, 11.82s/it, 55.3%]    
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

55.3% Continuing to train current weights
Current temp decayed to 0.177978515625
Checkpoint temp decayed to 0.177978515625
Iteration 6


Current vs. Current: 100%|██████████| 256/256 [29:52<00:00,  7.00s/it, 51.0%]   
Optimizing: 169it [02:12,  1.27it/s, 63.5%, 0.619L]
Current vs. Checkpoint: 100%|██████████| 256/256 [10:42<00:00,  2.51s/it, 65.8%] 


65.8% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.13348388671875
Checkpoint temp decayed to 0.13348388671875
Iteration 7


Current vs. Current: 100%|██████████| 256/256 [08:35<00:00,  2.01s/it, 58.6%]
Optimizing: 194it [07:41,  2.38s/it, 67.1%, 0.617L]
Current vs. Checkpoint: 100%|██████████| 256/256 [10:53<00:00,  2.55s/it, 49.8%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

49.8% Continuing to train current weights
Current temp decayed to 0.1001129150390625
Checkpoint temp decayed to 0.1001129150390625
Iteration 8


Current vs. Current: 100%|██████████| 256/256 [14:12<00:00,  3.33s/it, 58.6%]   
Optimizing: 162it [02:34,  1.05it/s, 75.3%, 0.508L]
Current vs. Checkpoint: 100%|██████████| 256/256 [14:23<00:00,  3.37s/it, 61.5%]  


61.5% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.07508468627929688
Checkpoint temp decayed to 0.07508468627929688
Iteration 9


Current vs. Current: 100%|██████████| 256/256 [13:41<00:00,  3.21s/it, 57.6%] 
Optimizing: 190it [02:40,  1.19it/s, 71.9%, 0.528L]
Current vs. Checkpoint: 100%|██████████| 256/256 [15:04<00:00,  3.53s/it, 54.5%]  
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

54.5% Continuing to train current weights
Current temp decayed to 0.056313514709472656
Checkpoint temp decayed to 0.056313514709472656
Iteration 10


Current vs. Current: 100%|██████████| 256/256 [12:21<00:00,  2.90s/it, 51.8%] 
Optimizing: 212it [03:35,  1.01s/it, 69.4%, 0.583L]
Current vs. Checkpoint: 100%|██████████| 256/256 [15:54<00:00,  3.73s/it, 44.7%]  
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

44.7% Continuing to train current weights
Current temp decayed to 0.04223513603210449
Checkpoint temp decayed to 0.04223513603210449
Iteration 11


Current vs. Current: 100%|██████████| 256/256 [16:09<00:00,  3.79s/it, 61.3%]   
Optimizing: 128it [07:44,  3.63s/it, 71.9%, 0.518L]
Current vs. Checkpoint: 100%|██████████| 256/256 [14:01<00:00,  3.29s/it, 54.9%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

54.9% Continuing to train current weights
Current temp decayed to 0.03167635202407837
Checkpoint temp decayed to 0.03167635202407837
Iteration 12


Current vs. Current: 100%|██████████| 256/256 [14:06<00:00,  3.31s/it, 59.6%]  
Optimizing: 225it [03:43,  1.01it/s, 60.0%, 0.683L]
Current vs. Checkpoint: 100%|██████████| 256/256 [11:39<00:00,  2.73s/it, 14.5%] 
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

14.5% Rejected new model
Current temp decayed to 0.023757264018058777
Checkpoint temp decayed to 0.023757264018058777
Iteration 13


Current vs. Current: 100%|██████████| 256/256 [12:54<00:00,  3.03s/it, 67.6%] 
Optimizing: 195it [09:29,  2.92s/it, 80.0%, 0.427L]
Current vs. Checkpoint: 100%|██████████| 256/256 [16:49<00:00,  3.94s/it, 84.6%]   


84.6% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.017817948013544083
Checkpoint temp decayed to 0.017817948013544083
Iteration 14


Current vs. Current: 100%|██████████| 256/256 [1:36:21<00:00, 22.58s/it, 91.2%]   
Optimizing: 156it [00:33,  4.70it/s, 91.4%, 0.259L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:19<00:00,  1.83it/s, 63.9%]


63.9% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 15


Current vs. Current: 100%|██████████| 256/256 [02:17<00:00,  1.86it/s, 72.7%]
Optimizing: 203it [00:40,  5.04it/s, 72.9%, 0.592L]
Current vs. Checkpoint: 100%|██████████| 256/256 [01:48<00:00,  2.36it/s, 79.7%]


79.7% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 16


Current vs. Current: 100%|██████████| 256/256 [01:50<00:00,  2.32it/s, 88.1%]
Optimizing: 139it [00:29,  4.77it/s, 88.3%, 0.324L]
Current vs. Checkpoint: 100%|██████████| 256/256 [01:52<00:00,  2.27it/s, 95.1%]


95.1% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 17


Current vs. Current: 100%|██████████| 256/256 [02:12<00:00,  1.93it/s, 76.0%]
Optimizing: 204it [00:40,  5.01it/s, 81.6%, 0.405L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:08<00:00,  2.00it/s, 79.5%]


79.5% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 18


Current vs. Current: 100%|██████████| 256/256 [13:00<00:00,  3.05s/it, 76.0%]  
Optimizing: 185it [00:36,  5.03it/s, 84.4%, 0.370L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:07<00:00,  2.01it/s, 85.9%]


85.9% Accepted new model


Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Plotted sample trajectory
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 19


Current vs. Current: 100%|██████████| 256/256 [02:04<00:00,  2.06it/s, 78.9%]
Optimizing: 180it [00:38,  4.71it/s, 80.3%, 0.408L]
Current vs. Checkpoint:  42%|████▏     | 108/256 [01:02<01:19,  1.87it/s, 37.0%]

KeyboardInterrupt: 

# Evaluate

Play against our AI

In [None]:
set_temps([curr_policy, checkpoint_policy], 0)

In [None]:
game.pit(go_env, human_policy, checkpoint_policy, False)