In [89]:
import json
import numpy as np
import math

In [96]:
def func_def(name, coeffs, powers, expansion):
    """
    Generate the Python source of one approximate-multiplier model function.

    The emitted function is named ``name.lower()`` and has the signature
    ``(base_func, op1, op2, kwargs)``. It starts from
    ``base_func(op1, op2, **kwargs)`` and then adds or subtracts one scaled
    ``base_func`` call for every remaining (coefficient, powers) pair.

    Args:
        name: Model name; lower-cased for the function name and used verbatim
            in the generated docstring.
        coeffs: Sequence of numeric coefficients, one per term.
        powers: Sequence of 4-element exponent rows aligned with ``coeffs``.
            The entries are the exponents of ``op1``, ``op2``,
            ``(op1 % expansion)`` and ``(op2 % expansion)``, in that order.
        expansion: Modulus interpolated into the ``%`` expressions.

    Returns:
        str: Source text of the generated function, ending with a newline.

    Terms that do not resolve to exactly two operands are printed together
    with ``name`` and left out of the generated code.
    """
    def build_operands(power, expansion):
        """Return the operand expressions (first argument, second argument) for one term."""
        exps = np.array([
            "op1",
            "op2",
            f"(op1 % {expansion})",
            f"(op2 % {expansion})",
        ])
        active = power != 0
        # An exponent of 1 is left implicit; anything else is spelled out as `**p`.
        res = [
            expr if p == 1 else f"{expr}**{p}"
            for expr, p in zip(exps[active], power[active])
        ]
        if not res:
            # All exponents are zero: there is no operand to inspect below.
            # Return the empty list so the caller reports and skips the term
            # instead of crashing with an IndexError on res[0].
            return res
        if len(res) == 1:
            # Only one side is present: fill the other side with ones.
            if 'op1' in res[0]:
                res.append('torch.ones_like(op2)')
            else:
                res.append('torch.ones_like(op1)')
        # Make sure the op1-derived expression is passed first.
        # NOTE(review): two factors built from the same operand (e.g. powers
        # [1, 0, 1, 0]) are still emitted as a plain pair — confirm whether
        # such rows can occur in the fitted parameters.
        if 'op2' in res[0]:
            res = res[::-1]
        return res

    s = f"def {name.lower()}(base_func, op1, op2, kwargs):\n"
    s += f"    \"\"\"\n    Approximate Multiplication HTP Model for {name}\n    \"\"\"\n"
    s += "    res = base_func(op1, op2, **kwargs)\n"
    for c, p in zip(coeffs, powers):
        if np.all(p == np.array([1, 1, 0, 0])):
            # regular multiplication, already handled in first row
            # NOTE(review): the coefficient of this term is never emitted, so
            # it is presumably always 1 — TODO confirm against the fit.
            continue
        if c == 0:
            # A zero coefficient contributes nothing; do not emit a base_func
            # call whose result would only be multiplied by zero.
            continue
        coeff = f"{abs(c)} * " if abs(c) != 1.0 else ""
        sgn = "+" if c > 0 else "-"
        ops = build_operands(np.array(p), expansion)
        if len(ops) != 2:
            # Unsupported term shape: report it and leave it out.
            print(name, ops)
            continue
        s += f"    res {sgn}= {coeff}base_func({ops[0]}, {ops[1]}, **kwargs)\n"
    s += "    return res\n"
    return s


In [97]:
def gen_dict(accurate, mul_names, dict_name):
    """
    Generate the source of a dict literal mapping model names to functions.

    Args:
        accurate: Name whose lower-cased form is bound to the "accurate" key.
        mul_names: Iterable of names; each becomes a key (verbatim) whose
            value is the lower-cased name, emitted as a bare identifier.
        dict_name: Variable name the dict literal is assigned to.

    Returns:
        str: The assignment statement, one entry per line, ending with a newline.
    """
    lines = [f"{dict_name} = {{"]
    lines.append(f"    \"accurate\" : {accurate.lower()},")
    lines.extend(f"    \"{mul}\" : {mul.lower()}," for mul in mul_names)
    lines.append("}")
    return "\n".join(lines) + "\n"

In [98]:
 # Load the fitted parameters; later cells iterate over `params` and read
 # each record's 'name', 'htp_params' and 'evoapprox_metrics' entries.
 # The encoding is given explicitly so the result does not depend on the
 # platform's locale default.
 with open('htp_params_mul8s.json', 'r', encoding='utf-8') as f:
    params = json.load(f)

In [99]:
# Module header of the generated file: docstring, torch import, two blank lines.
s = (
    '"""\n'
    'High-Throughput Models for EvoApprox 8-Bit Signed multipliers (mul8s_*)\n'
    'This file is automatically generated.\n'
    '"""\n'
    'import torch\n'
    '\n\n'
)

# Append one generated model function per parameter record, each followed by
# a blank line.
for mul in params:
    name = mul['name']
    htp = mul['htp_params']
    s += func_def(name, htp['coefficients'], htp['powers'], htp['expansion']) + '\n'

In [100]:
# Pick the first multiplier whose 'wce' metric is zero as the accurate one.
# math.isclose() against 0.0 only behaves as a tolerance check when abs_tol is
# given; with the defaults it is an exact equality test.
accurate = next(
    (
        mul['name']
        for mul in params
        if math.isclose(mul['evoapprox_metrics']['wce'], 0.0, abs_tol=1e-9)
    ),
    None,
)
if accurate is None:
    # Fail with a clear message instead of an IndexError on an empty list.
    raise ValueError(
        "no multiplier with zero 'wce' found in params; cannot select the accurate model"
    )
s += gen_dict(accurate, [mul['name'] for mul in params], "htp_models_mul8s")

In [101]:
# Write the generated module. The file is only written, never read back, so
# plain 'w' is enough; UTF-8 is set explicitly so the output does not depend
# on the platform's locale default.
with open('htp_models_mul8s.py', 'w', encoding='utf-8') as f:
    f.write(s)