crude visualization for narrow act
Sopel97 committed Dec 9, 2021
1 parent 70880e9 commit 0eab12c
Showing 2 changed files with 179 additions and 0 deletions.
29 changes: 29 additions & 0 deletions model.py
@@ -70,6 +70,18 @@ def _init_layers(self):
self.output.weight = nn.Parameter(output_weight)
self.output.bias = nn.Parameter(output_bias)

def get_narrow_preactivations(self, x, ls_indices):
# precompute and cache the offset for gathers
if self.idx_offset is None or self.idx_offset.shape[0] != x.shape[0]:
self.idx_offset = torch.arange(0, x.shape[0] * self.count, self.count, device=ls_indices.device)

indices = ls_indices.flatten() + self.idx_offset

l1s_ = self.l1(x).reshape((-1, self.count, L2))
l1f_ = self.l1_fact(x)
l1c_ = l1s_.view(-1, L2)[indices]
return l1c_ + l1f_

def forward(self, x, ls_indices):
# precompute and cache the offset for gathers
if self.idx_offset is None or self.idx_offset.shape[0] != x.shape[0]:
@@ -241,6 +253,23 @@ def set_feature_set(self, new_feature_set):
else:
raise Exception('Cannot change feature set from {} to {}.'.format(self.feature_set.name, new_feature_set.name))

def get_narrow_preactivations(self, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices):
wp, bp = self.input(white_indices, white_values, black_indices, black_values)
w, wpsqt = torch.split(wp, L1, dim=1)
b, bpsqt = torch.split(bp, L1, dim=1)
l0_ = (us * torch.cat([w, b], dim=1)) + (them * torch.cat([b, w], dim=1))
# clamp acts as a clipped ReLU with range [0.0, 1.0]
l0_ = torch.clamp(l0_, 0.0, 1.0)

psqt_indices_unsq = psqt_indices.unsqueeze(dim=1)
wpsqt = wpsqt.gather(1, psqt_indices_unsq)
bpsqt = bpsqt.gather(1, psqt_indices_unsq)
preact = self.layer_stacks.get_narrow_preactivations(l0_, layer_stack_indices)
bucketed_preact = []
for i in range(self.num_ls_buckets):
mask = (layer_stack_indices == i).repeat(preact.shape[1], 1).t()
bucketed_preact.append(torch.masked_select(preact, mask).reshape((-1, L2)))
return bucketed_preact

def forward(self, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices):
wp, bp = self.input(white_indices, white_values, black_indices, black_values)
w, wpsqt = torch.split(wp, L1, dim=1)
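For reference, the bucket selection above flattens the (position, bucket) pair into a single row index: position i contributes an offset i*count, so ls_indices picks exactly one L2-wide slice per position, and the per-bucket grouping later pulls those slices apart again with a boolean mask. A minimal, self-contained sketch of both steps (the sizes and tensors below are illustrative only, not the model's real dimensions):

import torch

batch_size, count, L2 = 3, 4, 2              # toy sizes, for illustration only
ls_indices = torch.tensor([2, 0, 3])         # layer-stack bucket chosen for each position
l1s_ = torch.arange(batch_size * count * L2, dtype=torch.float32).reshape(batch_size, count, L2)

# same gather trick as LayerStacks.get_narrow_preactivations:
# the offset i*count turns a per-position bucket index into a row of the flattened view
idx_offset = torch.arange(0, batch_size * count, count)
indices = ls_indices.flatten() + idx_offset  # tensor([ 2,  4, 11])
selected = l1s_.view(-1, L2)[indices]        # shape (batch_size, L2), one bucket per position
assert torch.equal(selected, torch.stack([l1s_[i, ls_indices[i]] for i in range(batch_size)]))

# same masking trick as NNUE.get_narrow_preactivations: collect the rows that used bucket 0
mask = (ls_indices == 0).repeat(L2, 1).t()   # shape (batch_size, L2)
bucket0 = torch.masked_select(selected, mask).reshape(-1, L2)
assert bucket0.shape == (1, L2)              # only position 1 uses bucket 0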
150 changes: 150 additions & 0 deletions visualize_narrow_preactivation.py
@@ -0,0 +1,150 @@
import argparse
import chess
import features
import nnue_dataset
import model as M
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from serialize import NNUEReader


class NNUEVisualizer():
def __init__(self, model, args):
self.model = model
self.args = args

self.model.cuda()

import matplotlib as mpl
self.dpi = 100
mpl.rcParams["figure.figsize"] = (
self.args.default_width//self.dpi, self.args.default_height//self.dpi)
mpl.rcParams["figure.dpi"] = self.dpi

def _process_fig(self, name, fig=None):
if self.args.save_dir:
from os.path import join
destname = join(
self.args.save_dir, "{}{}.jpg".format("" if self.args.label is None else self.args.label + "_", name))
print("Saving {}".format(destname))
if fig is not None:
fig.savefig(destname)
else:
plt.savefig(destname)

def get_data(self, count, batch_size):
fen_batch_provider = nnue_dataset.FenBatchProvider(self.args.data, True, 1, batch_size, False, 10)

activations_by_bucket = [[] for i in range(self.model.num_ls_buckets)]
i = 0
while i < count:
fens = next(fen_batch_provider)
batch = nnue_dataset.make_sparse_batch_from_fens(self.model.feature_set, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens))
us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch.contents.get_tensors('cuda')
bucketed_preact = self.model.get_narrow_preactivations(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices)

for a, b in zip(activations_by_bucket, bucketed_preact):
a.append(b.cpu().detach().numpy().clip(0, 1))

i += batch_size
print('{}/{}'.format(i, count))

return activations_by_bucket

def plot(self):
bucketed_preact = self.get_data(self.args.count, self.args.batch_size)
for i, d in enumerate(bucketed_preact):
print('Bucket {} has {} entries.'.format(i, sum(a.shape[0] for a in d)))

fig, axs = plt.subplots(self.model.num_ls_buckets, M.L2, sharex=True, sharey=True, figsize=(20, 20), dpi=100)

for bucket_id, preact in enumerate(bucketed_preact):
for i in range(M.L2):
acts = np.concatenate([v[:,i] for v in preact]).flatten()

ax = axs[bucket_id, i]
ax.hist(acts, density=True, log=True, bins=128)
ax.set_xlim([0, 1])
if i == 0:
ax.set_ylabel('Bucket {}'.format(bucket_id))
if bucket_id == 0:
ax.set_xlabel('Neuron {}'.format(i))
ax.xaxis.set_label_position('top')

# save the figure when --save-dir is given (filename stem chosen here as 'narrow_preact'), then show it
self._process_fig('narrow_preact', fig)
fig.show()

def load_model(filename, feature_set):
if filename.endswith(".pt") or filename.endswith(".ckpt"):
if filename.endswith(".pt"):
model = torch.load(filename)
else:
model = M.NNUE.load_from_checkpoint(
filename, feature_set=feature_set)
model.eval()
elif filename.endswith(".nnue"):
with open(filename, 'rb') as f:
reader = NNUEReader(f, feature_set)
model = reader.model
else:
raise Exception("Invalid filetype: " + str(filename))

return model


def main():
parser = argparse.ArgumentParser(
description="Visualizes narrow-layer preactivations for networks in ckpt, pt and nnue format.")
parser.add_argument(
"model", help="Source model (can be .ckpt, .pt or .nnue)")
parser.add_argument(
"--default-width", default=1600, type=int,
help="Default width of all plots (in pixels).")
parser.add_argument(
"--count", default=1000000, type=int,
help="Number of positions to sample from the dataset.")
parser.add_argument(
"--batch_size", default=5000, type=int,
help="Number of positions per batch.")
parser.add_argument(
"--default-height", default=900, type=int,
help="Default height of all plots (in pixels).")
parser.add_argument(
"--save-dir", type=str, required=False,
help="Save the plots in this directory.")
parser.add_argument(
"--dont-show", action="store_true",
help="Don't show the plots.")
parser.add_argument("--data", type=str, help="Path to a .bin or .binpack dataset.")
parser.add_argument(
"--label", type=str, required=False,
help="Override the label used in plot titles and as prefix of saved files.")
features.add_argparse_args(parser)
args = parser.parse_args()

supported_features = ('HalfKAv2_hm', 'HalfKAv2_hm^')
assert args.features in supported_features
feature_set = features.get_feature_set_from_name(args.features)

from os.path import basename
label = basename(args.model)

model = load_model(args.model, feature_set)

print("Visualizing {}".format(args.model))

if args.label is None:
args.label = label

visualizer = NNUEVisualizer(model, args)

visualizer.plot()

if not args.dont_show:
plt.show()


if __name__ == '__main__':
main()
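
For completeness, an example invocation (the checkpoint name, dataset path, and output directory are placeholders, and --features is assumed to be the flag registered by features.add_argparse_args):

python visualize_narrow_preactivation.py network.ckpt --data training_data.binpack --features=HalfKAv2_hm --count 1000000 --batch_size 5000 --save-dir plots

This samples --count positions in batches of --batch_size, runs them through the network on the GPU, and draws a grid of log-scale histograms of the clipped narrow-layer preactivations, one histogram per (bucket, neuron) pair.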
