In [1]:
import torch
import torch.nn as nn
import numpy as np
from board import Board
from features import Features
from gamestate import GameState
from load_model import load_model  # This is an import from a file within the repo

# Function to parse coordinates
# Column letters in board order; note that the letter 'I' is not part of the set.
colstr = 'ABCDEFGHJKLMNOPQRST'
def parse_coord(s, board):
    """Convert a coordinate string such as 'Q16' (or 'pass') to a board location.

    Args:
        s: Coordinate text: a column letter from ``colstr`` followed by a
            1-based row number, or the word 'pass'. Letter case and
            surrounding whitespace are ignored.
        board: Board object providing ``size`` and ``loc(x, y)``.

    Returns:
        ``Board.PASS_LOC`` for a pass, otherwise ``board.loc(x, y)`` where
        ``x`` is the column index in ``colstr`` and ``y = board.size - row``.

    Raises:
        ValueError: If the column letter is unknown, the row is not an
            integer, or the coordinate lies outside the board.
        IndexError: If ``s`` is empty (after stripping whitespace).
    """
    s = s.strip()
    # 'pass' is matched case-insensitively; otherwise 'PASS' would be read as
    # column 'P' and then fail while parsing 'ASS' as a row number.
    if s.lower() == 'pass':
        return Board.PASS_LOC
    x = colstr.index(s[0].upper())
    y = board.size - int(s[1:])
    # NOTE(review): assumes a square board whose side length is board.size.
    if not (0 <= x < board.size and 0 <= y < board.size):
        raise ValueError('coordinate %r is off the %dx%d board' % (s, board.size, board.size))
    return board.loc(x, y)

def str_coord(loc, board):
    """Render a board location as coordinate text (the inverse of parse_coord).

    Returns 'pass' when ``loc`` equals ``Board.PASS_LOC``. Otherwise returns
    the ``colstr`` letter for ``board.loc_x(loc)`` followed by the row number
    ``board.size - board.loc_y(loc)``.
    """
    is_pass = loc == Board.PASS_LOC
    if is_pass:
        return 'pass'
    col_index = board.loc_x(loc)
    row_index = board.loc_y(loc)
    column_letter = colstr[col_index]
    row_number = board.size - row_index
    return '%c%d' % (column_letter, row_number)

# Load model
# NOTE(review): machine-specific absolute path; point this at your local checkpoint.
checkpoint_path = r"D:\KataGo\kata1-b28c512nbt-s9584861952-d4960414494\model.ckpt"
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, swa_model, other_state_dict = load_model(checkpoint_path, use_swa=False, device=device)
# Use the SWA model whenever load_model returns one.
if swa_model is not None:
    model = swa_model

model.eval()
board_size = 19
gs = GameState(board_size, GameState.RULES_TT)


def play_move(gs, pla, loc):
    """Play ``loc`` for ``pla`` on ``gs.board`` and record it in the game history.

    Appends ``(pla, loc)`` to ``gs.moves`` and a copy of the resulting board
    to ``gs.boards``, so both lists stay in step with the board.
    """
    gs.board.play(pla, loc)
    gs.moves.append((pla, loc))
    gs.boards.append(gs.board.copy())


# First move: Black plays Q16.
pla = Board.BLACK
loc = parse_coord('Q16', gs.board)
play_move(gs, pla, loc)

# Take the move stored under "genmove_result" in the model outputs and play it
# for gs.board.pla (presumably the side to move — confirm against Board).
outputs = gs.get_model_outputs(model)
loc = outputs["genmove_result"]
pla = gs.board.pla
play_move(gs, pla, loc)
ret = str_coord(loc, gs.board)
# Show the chosen move as the cell's output instead of silently discarding it.
ret


  state_dict = torch.load(checkpoint_file,map_location="cpu")


In [2]:
# Print the module tree of the loaded network (layer names, types and shapes).
print(model)

Model(
  (conv_spatial): Conv2d(22, 512, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
  (linear_global): Linear(in_features=19, out_features=512, bias=False)
  (blocks): ModuleList(
    (0-1): 2 x NestedBottleneckResBlock(
      (normactconvp): NormActConv(
        (norm): NormMask()
        (act): Mish(inplace=True)
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), padding=same, bias=False)
      )
      (blockstack): ModuleList(
        (0-1): 2 x ResBlock(
          (normactconv1): NormActConv(
            (norm): NormMask()
            (act): Mish(inplace=True)
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
          )
          (normactconv2): NormActConv(
            (norm): NormMask()
            (act): Mish(inplace=True)
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
          )
        )
      )
      (normactconvq): NormActConv(
        (norm

In [7]:
from torchinfo import summary
import torch

# Read the input widths from the model instead of hard-coding them, so the
# dummy tensors stay in sync if a different checkpoint is loaded.
# NOTE(review): unwraps a `.module` attribute in case `model` is a wrapper
# around the real network (e.g. an averaged/SWA model) — confirm against load_model.
net = getattr(model, "module", model)

B = 1
C = net.conv_spatial.in_channels              # spatial input planes (22 for this checkpoint)
board_size = 19
global_feats = net.linear_global.in_features  # global input features (19 for this checkpoint)

# Build the dummy inputs on whatever device the model's weights already live on,
# rather than assuming a CUDA device is present.
model_device = next(model.parameters()).device

# Dummy inputs
spatial = torch.zeros((B, C, board_size, board_size), device=model_device)
globalf = torch.zeros((B, global_feats), device=model_device)

# Print full model summary
summary(
    model,
    input_data=(spatial, globalf),
    device=model_device,
    depth=5,            # increase to see full nested structure
    col_names=("input_size", "output_size", "num_params", "kernel_size", "mult_adds")
)

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
Model                                                   [1, 22, 19, 19]           [1, 6, 362]               --                        --                        --
├─Conv2d: 1-1                                           [1, 22, 19, 19]           [1, 512, 19, 19]          101,376                   [3, 3]                    36,596,736
├─Linear: 1-2                                           [1, 19]                   [1, 512]                  9,728                     --                        9,728
├─ModuleList: 1-3                                       --                        --                        --                        --                        --
│    └─NestedBottleneckResBlock: 2-1                    [1, 512, 19, 19]          [1, 512, 19, 19]          --                        --                        --
│   