Skip to content

Commit

Permalink
Script for quantizing weights.
Browse files Browse the repository at this point in the history
Use smaller precision to store the weights to decrease the file size.

See discussion in issue #1733.

Pull request #1736.
  • Loading branch information
Ttl authored and gcp committed Aug 16, 2018
1 parent 07c908e commit f85a685
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions training/tf/quantize_weights.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
import sys, os, argparse

def format_n(x):
x = float(x)
x = '{:.3g}'.format(x)
x = x.replace('e-0', 'e-')
if x.startswith('0.'):
x = x[1:]
if x.startswith('-0.'):
x = '-' + x[2:]
return x

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description='Quantize network file to decrease the file size.')
parser.add_argument("input",
help='Input file', type=str)

parser.add_argument("-o", "--output",
help='Output file. Defaults to input + "_quantized"',
required=False, type=str, default=None)

args = parser.parse_args()

if args.output == None:
output_name = os.path.splitext(sys.argv[1])
output_name = output_name[0] + '_quantized' + output_name[1]
else:
output_name = args.output
output = open(output_name, 'w')

calculate_error = True
error = 0

with open(args.input, 'r') as f:
for line in f:
line = line.split(' ')
lineq = list(map(format_n, line))

if calculate_error:
e = sum((float(line[i]) - float(lineq[i]))**2 for i in range(len(line)))
error += e/len(line)
output.write(' '.join(lineq) + '\n')

if calculate_error:
print('Weight file difference L2-norm: {}'.format(error**0.5))

output.close()

0 comments on commit f85a685

Please sign in to comment.