Support for LEB128 compression of feature transformer parameters. #251

Merged · 1 commit · Jun 23, 2023
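This PR adds optional LEB128 (Little Endian Base 128) compression for the feature transformer parameters when serializing to .nnue. LEB128 is a variable-length integer encoding: each byte carries 7 payload bits, the high bit marks continuation, and in the signed variant bit 0x40 of the final byte carries the sign. Quantized FT parameters are mostly small in magnitude, so most of them encode in one byte instead of two (int16) or four (int32). Compression is opt-in via a new --ft_compression flag; compressed tensors are prefixed with the magic string COMPRESSED_LEB128, and the reader detects it and decompresses transparently, so existing uncompressed networks keep loading unchanged.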
serialize.py: 121 changes (103 additions, 18 deletions)
@@ -2,17 +2,19 @@
 import features
 import math
 import model as M
 import numpy
 import struct
 import torch
 import io
 from torch import nn
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader
 from functools import reduce
 import operator
+import numpy as np
+from numba import njit

 def ascii_hist(name, x, bins=6):
-    N,X = numpy.histogram(x, bins=bins)
+    N,X = np.histogram(x, bins=bins)
     total = 1.0*len(x)
     width = 50
     nmax = N.max()
@@ -23,6 +25,36 @@ def ascii_hist(name, x, bins=6):
         xi = '{0: <8.4g}'.format(xi).ljust(10)
         print('{0}| {1}'.format(xi,bar))

+@njit
+def encode_leb_128_array(arr):
+    res = []
+    for v in arr:
+        while True:
+            byte = v & 0x7f
+            v = v >> 7
+            if (v == 0 and byte & 0x40 == 0) or (v == -1 and byte & 0x40 != 0):
+                res.append(byte)
+                break
+            res.append(byte | 0x80)
+    return res
+
+@njit
+def decode_leb_128_array(arr, n):
+    ints = np.zeros(n)
+    k = 0
+    for i in range(n):
+        r = 0
+        shift = 0
+        while True:
+            byte = arr[k]
+            k = k + 1
+            r |= (byte & 0x7f) << shift
+            shift += 7
+            if (byte & 0x80) == 0:
+                ints[i] = r if (byte & 0x40) == 0 else r | ~((1 << shift) - 1)
+                break
+    return ints
+
 # hardcoded for now
 VERSION = 0x7AF32F20
 DEFAULT_DESCRIPTION = "Network trained with the https://github.com/glinscott/nnue-pytorch trainer."
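A note on the two helpers above: the loop emits 7 payload bits per byte with the high bit as a continuation flag, and stops once the remainder of v is pure sign extension, i.e. v == 0 with the 0x40 bit of the last byte clear (non-negative) or v == -1 with it set (negative), which is exactly signed LEB128. The @njit decorators matter because these are per-byte Python loops over every FT parameter. A round-trip sketch, assuming this file is importable as serialize (sample values are mine):

    import numpy as np
    from serialize import encode_leb_128_array, decode_leb_128_array

    # 64 needs two bytes (0xC0 0x00) because a lone 0x40 would decode as -64;
    # -64 itself therefore fits in a single byte.
    assert list(encode_leb_128_array(np.array([64], dtype=np.int64))) == [0xC0, 0x00]
    assert list(encode_leb_128_array(np.array([-64], dtype=np.int64))) == [0x40]

    # Round trip across the int16 range the FT weights live in.
    values = np.array([0, 1, -1, 63, -65, 32767, -32768], dtype=np.int64)
    encoded = np.array(encode_leb_128_array(values), dtype=np.uint8)
    assert np.array_equal(decode_leb_128_array(encoded, len(values)), values)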
@@ -31,7 +63,7 @@ class NNUEWriter():
     """
     All values are stored in little endian.
     """
-    def __init__(self, model, description=None):
+    def __init__(self, model, description=None, ft_compression='none'):
         if description is None:
             description = DEFAULT_DESCRIPTION

@@ -43,7 +75,7 @@ def __init__(self, model, description=None):
         fc_hash = self.fc_hash(model)
         self.write_header(model, fc_hash, description)
         self.int32(model.feature_set.hash ^ (M.L1*2)) # Feature transformer hash
-        self.write_feature_transformer(model)
+        self.write_feature_transformer(model, ft_compression)
         for l1, l2, output in model.layer_stacks.get_coalesced_layer_stacks():
             self.int32(fc_hash) # FC layers hash
             self.write_fc_layer(model, l1)
@@ -76,7 +108,21 @@ def write_header(self, model, fc_hash, description):
         self.int32(len(encoded_description)) # Network definition
         self.buf.extend(encoded_description)

-    def write_feature_transformer(self, model):
+    def write_leb_128_array(self, arr):
+        buf = encode_leb_128_array(arr)
+        self.int32(len(buf))
+        self.buf.extend(buf)
+
+    def write_tensor(self, arr, compression='none'):
+        if compression == 'none':
+            self.buf.extend(arr.tobytes())
+        elif compression == 'leb128':
+            self.buf.extend('COMPRESSED_LEB128'.encode('utf-8'))
+            self.write_leb_128_array(arr)
+        else:
+            raise Exception('Invalid compression method.')
+
+    def write_feature_transformer(self, model, ft_compression):
         layer = model.input

         bias = layer.bias.data[:M.L1]
@@ -93,10 +139,11 @@ def write_feature_transformer(self, model):
         ascii_hist('ft weight:', weight.numpy())
         ascii_hist('ft psqt weight:', psqt_weight.numpy())

-        self.buf.extend(bias.flatten().numpy().tobytes())
         # Weights stored as [num_features][outputs]
-        self.buf.extend(weight.flatten().numpy().tobytes())
-        self.buf.extend(psqt_weight.flatten().numpy().tobytes())
+
+        self.write_tensor(bias.flatten().numpy(), ft_compression)
+        self.write_tensor(weight.flatten().numpy(), ft_compression)
+        self.write_tensor(psqt_weight.flatten().numpy(), ft_compression)

     def write_fc_layer(self, model, layer, is_output=False):
         # FC layers are stored as int8 weights, and int32 biases
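The writer-side framing, then, is: an uncompressed tensor is written as raw little-endian bytes exactly as before, while a compressed one becomes the ASCII magic COMPRESSED_LEB128, an int32 count of compressed bytes, and the LEB128 stream. A minimal sketch of the matching reader side (the helper name is mine; the PR's own reader follows further down):

    import struct

    def read_tensor_block(f, raw_size):
        # Mirrors write_tensor(): a leb128 block is magic + int32 length + payload,
        # anything else is raw little-endian tensor data of known size.
        magic = b'COMPRESSED_LEB128'
        head = f.read(len(magic))
        if head == magic:
            (length,) = struct.unpack('<i', f.read(4))
            return 'leb128', f.read(length)
        f.seek(-len(head), 1)  # rewind: this tensor was not compressed
        return 'none', f.read(raw_size)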
@@ -170,20 +217,51 @@ def read_header(self, feature_set, fc_hash):
         desc_len = self.read_int32()
         description = self.f.read(desc_len)

+    def read_leb_128_array(self, dtype, shape):
+        l = self.read_int32()
+        d = self.f.read(l)
+        if len(d) != l:
+            raise Exception('Unexpected end of file when reading compressed data.')
+
+        res = torch.FloatTensor(decode_leb_128_array(d, reduce(operator.mul, shape, 1)))
+        res = res.reshape(shape)
+        return res
+
+    def peek(self, length=1):
+        pos = self.f.tell()
+        data = self.f.read(length)
+        self.f.seek(pos)
+        return data
+
+    def determine_compression(self):
+        leb128_magic = b'COMPRESSED_LEB128'
+        if self.peek(len(leb128_magic)) == leb128_magic:
+            self.f.read(len(leb128_magic)) # actually advance the file pointer
+            return 'leb128'
+        else:
+            return 'none'
+
     def tensor(self, dtype, shape):
-        d = numpy.fromfile(self.f, dtype, reduce(operator.mul, shape, 1))
-        d = torch.from_numpy(d.astype(numpy.float32))
-        d = d.reshape(shape)
-        return d
+        compression = self.determine_compression()
+
+        if compression == 'none':
+            d = np.fromfile(self.f, dtype, reduce(operator.mul, shape, 1))
+            d = torch.from_numpy(d.astype(np.float32))
+            d = d.reshape(shape)
+            return d
+        elif compression == 'leb128':
+            return self.read_leb_128_array(dtype, shape)
+        else:
+            raise Exception('Invalid compression method.')

     def read_feature_transformer(self, layer, num_psqt_buckets):
         shape = layer.weight.shape

-        bias = self.tensor(numpy.int16, [layer.bias.shape[0]-num_psqt_buckets]).divide(self.model.quantized_one)
+        bias = self.tensor(np.int16, [layer.bias.shape[0]-num_psqt_buckets]).divide(self.model.quantized_one)
         # weights stored as [num_features][outputs]
-        weights = self.tensor(numpy.int16, [shape[0], shape[1]-num_psqt_buckets])
+        weights = self.tensor(np.int16, [shape[0], shape[1]-num_psqt_buckets])
         weights = weights.divide(self.model.quantized_one)
-        psqt_weights = self.tensor(numpy.int32, [shape[0], num_psqt_buckets])
+        psqt_weights = self.tensor(np.int32, [shape[0], num_psqt_buckets])
         psqt_weights = psqt_weights.divide(self.model.nnue2score * self.model.weight_scale_out)

         layer.bias.data = torch.cat([bias, torch.tensor([0]*num_psqt_buckets)])
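Because tensor() now sniffs every tensor for the magic before reading, networks written with and without compression are both readable with no format version bump. Decoding goes through a float64 buffer (np.zeros(n)) and then torch.FloatTensor, which is lossless here since every int16/int32 parameter is exactly representable in float64. The detection itself is just a peek-and-compare; an illustrative standalone version (stream contents are mine):

    import io
    import struct

    # One compressed block (value 5 is LEB128 b'\x05') and one raw int16.
    magic = b'COMPRESSED_LEB128'
    compressed = io.BytesIO(magic + struct.pack('<i', 1) + b'\x05')
    raw = io.BytesIO(struct.pack('<h', 5))

    def sniff(f):
        pos = f.tell()
        head = f.read(len(magic))
        f.seek(pos)  # same trick as peek(): read, then restore the position
        return 'leb128' if head == magic else 'none'

    assert sniff(compressed) == 'leb128'
    assert sniff(raw) == 'none'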
@@ -202,8 +280,8 @@ def read_fc_layer(self, layer, is_output=False):
         non_padded_shape = layer.weight.shape
         padded_shape = (non_padded_shape[0], ((non_padded_shape[1]+31)//32)*32)

-        layer.bias.data = self.tensor(numpy.int32, layer.bias.shape).divide(kBiasScale)
-        layer.weight.data = self.tensor(numpy.int8, padded_shape).divide(kWeightScale)
+        layer.bias.data = self.tensor(np.int32, layer.bias.shape).divide(kBiasScale)
+        layer.weight.data = self.tensor(np.int8, padded_shape).divide(kWeightScale)

         # Strip padding.
         layer.weight.data = layer.weight.data[:non_padded_shape[0], :non_padded_shape[1]]
@@ -219,6 +297,7 @@ def main():
     parser.add_argument("source", help="Source file (can be .ckpt, .pt or .nnue)")
     parser.add_argument("target", help="Target file (can be .pt or .nnue)")
     parser.add_argument("--description", default=None, type=str, dest='description', help="The description string to include in the network. Only works when serializing into a .nnue file.")
+    parser.add_argument("--ft_compression", default='none', type=str, dest='ft_compression', help="Compression method to use for FT weights and biases. Either 'none' or 'leb128'. Only allowed if saving to .nnue.")
     features.add_argparse_args(parser)
     args = parser.parse_args()

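With the flag wired through, a conversion with compression looks like python serialize.py model.ckpt model.nnue --ft_compression=leb128 (file names illustrative). Doing the same from Python mirrors the .nnue branch of main() just below:

    # Sketch: `nnue` is a loaded model, as main() produces from the source file.
    writer = NNUEWriter(nnue, description=None, ft_compression='leb128')
    with open('model.nnue', 'wb') as f:
        f.write(writer.buf)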
@@ -238,12 +317,18 @@ def main():
     else:
         raise Exception('Invalid network input format.')

+    if args.ft_compression != 'none' and not args.target.endswith('.nnue'):
+        raise Exception('Compression only allowed for .nnue target.')
+
+    if args.ft_compression not in ['none', 'leb128']:
+        raise Exception('Invalid compression method.')
+
     if args.target.endswith('.ckpt'):
         raise Exception('Cannot convert into .ckpt')
     elif args.target.endswith('.pt'):
         torch.save(nnue, args.target)
     elif args.target.endswith('.nnue'):
-        writer = NNUEWriter(nnue, args.description)
+        writer = NNUEWriter(nnue, args.description, ft_compression=args.ft_compression)
         with open(args.target, 'wb') as f:
             f.write(writer.buf)
     else: