In [9]:
import numpy as np
import collections
import dataclasses
import typing
import einops
import torch
import sys
import pathlib
import logging
import itertools
import os

# Make the local skyjo checkout importable. Set the SKYJO_REPO environment
# variable to point at a different checkout instead of editing this default.
SKYJO_REPO_ROOT = os.environ.get('SKYJO_REPO', '/Users/urolyi1/repos/skyjo')
if SKYJO_REPO_ROOT not in sys.path:
    sys.path.append(SKYJO_REPO_ROOT)

In [10]:
import skyjo as sj
import skynet
import mcts_new as mcts
import train_new as train
import explain as explain

In [3]:
# Fixture for the transfer benchmarks: random ints in [0, 100) allocated directly on the MPS device.
mps_tensor = torch.randint(low=0, high=100, size=(2, 3, 4, 17), device=torch.device('mps'))
mps_tensor.shape

torch.Size([2, 3, 4, 17])

In [4]:
# Same-shaped random int tensor, but on the CPU, for comparison with `mps_tensor`.
cpu_tensor = torch.randint(low=0, high=100, size=(2, 3, 4, 17), device=torch.device('cpu'))
cpu_tensor.shape

torch.Size([2, 3, 4, 17])

In [6]:
%%timeit
mps_tensor.cpu()

363 μs ± 23.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
%%timeit
mps_tensor.cpu().numpy()

371 μs ± 18.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [125]:
# Random CPU inputs shared by every model benchmark below (the MPS cells copy them to the device).
# The spatial dims come from the same skyjo constants used for each model's
# `spatial_input_shape`, so the benchmark input cannot drift out of sync with the model config.
batch_size = 64
spatial_tensor = torch.rand((batch_size, 2, sj.ROW_COUNT, sj.COLUMN_COUNT, sj.FINGER_SIZE))
nonspatial_tensor = torch.rand((batch_size, sj.GAME_SIZE))

In [126]:
# 1-D SkyNet variant, kept on the CPU.
cpu_model = skynet.SkyNet1D(
    policy_output_shape=(sj.MASK_SIZE,),
    value_output_shape=(2,),
    non_spatial_input_shape=(sj.GAME_SIZE,),
    spatial_input_shape=(2, sj.ROW_COUNT, sj.COLUMN_COUNT, sj.FINGER_SIZE),
)
cpu_model.eval()

# Warm-up: 30 untimed forward passes so any one-off start-up cost is paid
# before the %%timeit cells that follow.
for _ in itertools.repeat(None, 30):
    cpu_model(spatial_tensor, nonspatial_tensor)

In [127]:
%%timeit
with torch.no_grad():
    cpu_model.forward(spatial_tensor, nonspatial_tensor)

1.27 ms ± 124 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [128]:
%%timeit
cpu_model.forward(spatial_tensor, nonspatial_tensor)

1.31 ms ± 118 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [129]:
mps_model = skynet.SkyNet1D(
    spatial_input_shape=(2, sj.ROW_COUNT, sj.COLUMN_COUNT, sj.FINGER_SIZE),
    non_spatial_input_shape=(sj.GAME_SIZE,),
    value_output_shape=(2,),
    policy_output_shape=(sj.MASK_SIZE,),
)
# Project helper; presumably moves the model's parameters to the MPS device — confirm in skynet.
mps_model.set_device('mps')
mps_model.eval()

# The warm-up inputs never change, so copy them to the device once rather than
# on every iteration.
spatial_tensor_mps = spatial_tensor.to(device='mps')
nonspatial_tensor_mps = nonspatial_tensor.to(device='mps')
for _ in range(30):
    mps_model(spatial_tensor_mps, nonspatial_tensor_mps)
# MPS kernels are dispatched asynchronously: wait for the queued warm-up work to
# finish so it cannot spill into the first timed iterations of the next cell.
torch.mps.synchronize()

In [130]:
%%timeit
with torch.no_grad():
    mps_model.forward(spatial_tensor.to(device='mps'), nonspatial_tensor.to(device='mps'))

1.94 ms ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [131]:
%%timeit
mps_model.forward(spatial_tensor.to(device='mps'), nonspatial_tensor.to(device='mps'))

2.12 ms ± 170 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [132]:
# 2-D SkyNet variant, kept on the CPU.
cpu_model_2d = skynet.SkyNet2D(
    policy_output_shape=(sj.MASK_SIZE,),
    value_output_shape=(2,),
    non_spatial_input_shape=(sj.GAME_SIZE,),
    spatial_input_shape=(2, sj.ROW_COUNT, sj.COLUMN_COUNT, sj.FINGER_SIZE),
)
cpu_model_2d.eval()

# Warm-up: 30 untimed forward passes so any one-off start-up cost is paid
# before the %%timeit cells that follow.
for _ in itertools.repeat(None, 30):
    cpu_model_2d(spatial_tensor, nonspatial_tensor)

In [133]:
%%timeit
with torch.no_grad():
    cpu_model_2d.forward(spatial_tensor, nonspatial_tensor)

6.32 ms ± 673 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [134]:
%%timeit
cpu_model_2d.forward(spatial_tensor, nonspatial_tensor)

6.83 ms ± 455 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [135]:
mps_model_2d = skynet.SkyNet2D(
    spatial_input_shape=(2, sj.ROW_COUNT, sj.COLUMN_COUNT, sj.FINGER_SIZE),
    non_spatial_input_shape=(sj.GAME_SIZE,),
    value_output_shape=(2,),
    policy_output_shape=(sj.MASK_SIZE,),
)
# Project helper; presumably moves the model's parameters to the MPS device — confirm in skynet.
mps_model_2d.set_device('mps')
mps_model_2d.eval()

# The warm-up inputs never change, so copy them to the device once rather than
# on every iteration.
spatial_tensor_mps = spatial_tensor.to(device='mps')
nonspatial_tensor_mps = nonspatial_tensor.to(device='mps')
for _ in range(30):
    mps_model_2d(spatial_tensor_mps, nonspatial_tensor_mps)
# MPS kernels are dispatched asynchronously: wait for the queued warm-up work to
# finish so it cannot spill into the first timed iterations of the next cell.
torch.mps.synchronize()

In [136]:
%%timeit
with torch.no_grad():
    mps_model_2d.forward(spatial_tensor.to(device='mps'), nonspatial_tensor.to(device='mps'))

2.37 ms ± 78 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [137]:
%%timeit
mps_model_2d.forward(spatial_tensor.to(device='mps'), nonspatial_tensor.to(device='mps'))

2.43 ms ± 79.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
