In [2]:
import json
import numpy as np
import math

In [49]:
def _is_base_term(power):
    """Return True if ``power`` selects the plain product ``op1 * op2``.

    That is: the first two exponents are 1 and every expansion exponent
    (index 2 and beyond) is 0, e.g. ``[1, 1]`` or ``[1, 1, 0, 0]``.
    Works for any vector length, so it never hits a numpy shape-mismatch.
    """
    power = np.asarray(power)
    return (
        len(power) >= 2
        and bool(np.all(power[:2] == 1))
        and bool(np.all(power[2:] == 0))
    )


def _build_operands(power, expansion):
    """Turn one exponent vector into the operand expressions of a model term.

    Exponent ``i`` applies to the ``i``-th entry of
    ``[op1, op2, (op1 % expansion), (op2 % expansion)]``.

    Returns:
        ``["1"]`` for the constant term (all exponents zero); otherwise
        exactly two operand expressions, op1-derived factor first. A lone
        factor is paired with ``torch.ones_like`` of the other operand so the
        term can still be passed through the two-operand ``base_func``.

    Raises:
        ValueError: if ``power`` has more than four entries, or more than two
            non-zero exponents (such a term cannot be expressed as a single
            two-operand ``base_func`` call).
    """
    power = np.asarray(power)
    if np.all(power == 0):
        return ["1"]
    if len(power) > 4:
        raise ValueError(f"powers vector too long (max 4 entries): {power.tolist()}")

    candidates = np.array([
        "op1",
        "op2",
        f"(op1 % {expansion})",
        f"(op2 % {expansion})",
    ])[:len(power)]
    nonzero = power != 0
    operands = [
        str(o) if p == 1 else f"{o}**{p}"
        for o, p in zip(candidates[nonzero], power[nonzero])
    ]

    if len(operands) > 2:
        # Previously such terms were silently dropped from the generated model.
        raise ValueError(
            f"cannot express term with powers {power.tolist()} "
            "as a two-operand base_func call"
        )
    if len(operands) == 1:
        # Pad with a neutral factor so base_func always gets two operands.
        other = "torch.ones_like(op2)" if "op1" in operands[0] else "torch.ones_like(op1)"
        operands.append(other)

    # Keep the op1-derived operand in the first position.
    if "op2" in operands[0]:
        operands.reverse()
    return operands


def func_def(name, coeffs, powers, expansion=0):
    """Generate Python source for one approximate-multiplier model function.

    The emitted function ``<name.lower()>(base_func, op1, op2, kwargs)``
    evaluates a model built from ``coeffs`` and ``powers`` (paired
    element-wise). The plain-product term (powers ``[1, 1]`` or
    ``[1, 1, 0, 0]``) is emitted first as
    ``res = coeff * base_func(op1, op2, **kwargs)``. Every other term is then
    added to or subtracted from ``res``.

    Args:
        name: model name; lower-cased for the generated function name.
        coeffs: one coefficient per entry of ``powers``.
        powers: exponent vectors of length 2 (op1, op2) or 4
            (op1, op2, op1 % expansion, op2 % expansion).
        expansion: modulus used for the expansion operands.

    Returns:
        The function source as a string, ending with a newline.

    Raises:
        ValueError: if a term has more than two non-zero exponents or more
            than four exponents.
    """
    # Header
    s = f"def {name.lower()}(base_func, op1, op2, kwargs):\n"
    s += f"    \"\"\"\n    Approximate Multiplication HTP Model for {name}\n    \"\"\"\n"

    # Scale the accurate product by its fitted coefficient. If the fit has no
    # plain-product term, fall back to the unscaled product, as the original
    # notebook intended.
    base_coeff = next((c for c, p in zip(coeffs, powers) if _is_base_term(p)), None)
    if base_coeff is None:
        s += "    res = base_func(op1, op2, **kwargs)\n"
    else:
        s += f"    res = {base_coeff} * base_func(op1, op2, **kwargs)\n"

    for c, p in zip(coeffs, powers):
        if _is_base_term(p):
            continue  # already emitted as the first line

        coeff = f"{abs(c)} * " if abs(c) != 1.0 else ""
        sgn = "+" if c > 0 else "-"
        ops = _build_operands(p, expansion)
        if len(ops) == 2:
            s += f"    res {sgn}= {coeff}base_func({ops[0]}, {ops[1]}, **kwargs)\n"
        else:
            # Constant term: ops == ["1"]
            s += f"    res {sgn}= {coeff}{ops[0]}\n"
    s += "    return res\n"
    return s


In [50]:
def gen_dict(accurate, mul_names, dict_name):
    """Render Python source for a dict literal mapping model keys to functions.

    The dict is named ``dict_name``. Its first entry is ``"accurate"``, bound
    to ``accurate.lower()``. It is followed by one entry per name in
    ``mul_names``, keyed by the name as given and bound to its lower-cased
    identifier. The result ends with a newline.
    """
    entries = [f"    \"accurate\" : {accurate.lower()},"]
    entries.extend(f"    \"{n}\" : {n.lower()}," for n in mul_names)
    return "\n".join([f"{dict_name} = {{", *entries, "}"]) + "\n"

In [51]:
 with open('mul8s_models.json', 'r') as f:
    params = json.load(f)

In [52]:
# Module docstring and torch import that open the generated file.
s = (
    "\"\"\"\n"
    "High-Throughput Models for EvoApprox 8-Bit Signed multipliers (mul8s_*)\n"
    "This file is automatically generated.\n"
    "\"\"\"\n"
    "import torch\n\n\n"
)

# Emit two model functions per multiplier: the HTP model, using its own
# expansion modulus, and the linear baseline, with no expansion. Collect
# their names for the lookup dict.
names = []
for mul in params:
    mul_name = mul['name']

    htp_fit = mul['htp_params']
    htp_name = f"htp_{mul_name}"
    s += func_def(htp_name, htp_fit['coefficients'], htp_fit['powers'], htp_fit['expansion']) + '\n'

    lin_fit = mul['base_params']
    lin_name = f"lin_{mul_name}"
    s += func_def(lin_name, lin_fit['coefficients'], lin_fit['powers'], 0) + '\n'

    names += [htp_name, lin_name]

In [53]:
# The "accurate" entry points at the HTP model of the multiplier whose 'wce'
# metric is zero (presumably worst-case error, i.e. the exact multiplier).
# Note: math.isclose(x, 0.0) with the default abs_tol=0 matches only exactly
# 0.0, so an absolute tolerance is required for the check to be tolerant.
exact_names = [
    mul['name'] for mul in params
    if math.isclose(mul['evoapprox_metrics']['wce'], 0.0, abs_tol=1e-12)
]
if not exact_names:
    raise ValueError("no multiplier with zero 'wce' found in params; cannot pick the accurate model")
accurate = "htp_" + exact_names[0]
s += gen_dict(accurate, names, "htp_models_mul8s")

In [54]:
# Write the generated module. Mode 'w' truncates any previous version; read
# access ('w+') is not needed. An explicit encoding avoids depending on the
# platform default.
with open('htp_models_mul8s.py', 'w', encoding='utf-8') as f:
    f.write(s)