In [25]:
from aqt.jax.v2.aqt_dot_general import CalibrationMode
from functools import partial
import aqt.jax.v2.config as aqt_config
from aqt.jax.v2 import aqt_quantizer
from aqt.jax.v2 import calibration
from aqt.jax.v2 import utils as aqt_utils
from aqt.jax.v2.numerics import int_numerics
import jax.numpy as np
import jax

In [77]:
fully_quantized = partial(
    aqt_config.fully_quantized,
    calibration_mode=CalibrationMode.ALL_AXES, use_stochastic_rounding=False,
)

def q_dot_maybe(fwd=None, bwd=None):
    if fwd is None and bwd is None:
        return np.dot
    else:
        return quant_dot_for_dot(fully_quantized(fwd_bits=fwd, bwd_bits=bwd))

def quant_dot_for_dot(general_dot):
    def _dot(a, b):
        contr_dims = ((a.ndim-1,), (0,))
        batch_dims = ((), ())
        return general_dot(a, b, (contr_dims, batch_dims))
    return jax.jit(_dot)

def make_quantizer(bits=8):
    return aqt_quantizer.Quantizer(
        numerics=int_numerics.IntNumerics(
            bits=bits,
            preserve_zero=True,
            # preserve_max_val=True,
            preserve_max_val=False,
            clip=True,
            clip_gradient=True,
            round=True,
            noise_fn=None,
        ),
        # calib_shared_axes=-1,
        calib_shared_axes=None,
        scale_stop_grad=True,
        calibration=calibration.AbsMaxCalibration,
        po2_scale=False,
        context=aqt_utils.Context(key=jax.random.PRNGKey(0), train_step=None)
    )

In [78]:
key = jax.random.PRNGKey(0)
H = 4
a, b = np.split(jax.random.normal(key, (2*H, H)), 2)
print('a', a)
print('b', b)

a [[-0.5338914   0.8417911   0.8115571   0.05308708]
 [ 0.72478807 -0.5391156  -0.21932149  0.5509203 ]
 [ 0.16972555  1.1971722  -1.0609422   0.28213271]
 [-1.0543169   1.0187539  -0.42167255 -2.5889838 ]]
b [[ 0.3031899  -0.7655693   1.3062729  -0.7149365 ]
 [-0.18686387 -1.8082983  -0.46174228  0.17252915]
 [ 0.43107846  0.2948003  -0.8942256  -0.30150604]
 [ 0.27695706 -1.4905776  -0.5799751   0.9487235 ]]


In [93]:
lhs, rhs = 4, 2
qd = q_dot_maybe(fwd=(lhs, None), bwd=8)
qd(np.eye(2), np.eye(2))

Array([[0.9333334, 0.       ],
       [0.       , 0.9333334]], dtype=float32)

In [94]:
q = make_quantizer(bits=4)
iq, _ = q.quant(np.eye(2), calibration_axes=-1)

q = make_quantizer(bits=2)
jq, _ = q.quant(np.eye(2), calibration_axes=-1)

iq.dequant() # @ jq.dequant()

Array([[0.9333334, 0.       ],
       [0.       , 0.9333334]], dtype=float32)

In [63]:
# quantize a and b to the relevant bits, then use q_dot - should be same as full precision passed in directly
ql = make_quantizer(bits=lhs)
qr = make_quantizer(bits=rhs)
aq, _ = ql.quant(a, calibration_axes=-1)
bq, _ = qr.quant(b, calibration_axes=-1)

cqq = qd(aq.dequant(), bq.dequant())
cq = qd(a, b)
c = a@b
c, cq, cqq

(Array([[ 0.0453768 , -0.9533617 , -1.8426027 ,  0.33260754],
        [ 0.37852615, -0.46583915,  1.0723062 , -0.02239281],
        [-0.55145985, -3.028088  ,  0.4540146 ,  0.67275053],
        [-1.4088378 ,  2.6997137 ,  0.03098934, -1.3995585 ]],      dtype=float32),
 Array([[-0.03916679, -0.5091683 , -1.8408391 ,  0.40145963],
        [ 0.38187623, -0.41614717,  1.1309412 ,  0.15177132],
        [-0.5336476 , -2.6878211 ,  0.23989663,  0.812711  ],
        [-1.3316709 ,  2.4234455 , -0.12239623, -1.2827125 ]],      dtype=float32),
 Array([[ 0.06364604, -0.77354413, -1.3806294 ,  0.15666717],
        [ 0.37208453, -0.41614717,  1.1211494 ,  0.15177132],
        [-0.43573058, -2.952197  ,  0.7050023 ,  0.5679185 ],
        [-1.3806294 ,  2.9424055 ,  0.357397  , -1.2386498 ]],      dtype=float32))

In [51]:
cq = np.matmul(aq.qvalue, bq.qvalue, preferred_element_type=np.int8) / (aq.scale[0] * bq.scale[0])
c = a @ b
print(c)
print(cq)

[[ 0.0453768  -0.9533617  -1.8426027   0.33260754]
 [ 0.37852615 -0.46583915  1.0723062  -0.02239281]
 [-0.55145985 -3.028088    0.4540146   0.67275053]
 [-1.4088378   2.6997137   0.03098934 -1.3995585 ]]
[[   0.          -6.3659    -101.8544      19.0977   ]
 [  21.363722   -21.363722    32.04558     -5.3409305]
 [ -26.154976   -58.848698    32.69372     19.616232 ]
 [  -1.8139033   12.697323     5.44171     -7.2556133]]


In [40]:
aq.dequant()

Array([[-0.8417911 ,  0.8417911 ,  0.8417911 ,  0.        ],
       [ 0.72478807, -0.72478807, -0.        ,  0.72478807],
       [ 0.        ,  1.1971722 , -1.1971722 ,  0.        ],
       [-0.        ,  0.        , -0.        , -2.5889838 ]],      dtype=float32)

In [41]:
aq.dequant() / aq.qvalue

Array([[0.8417911 , 0.8417911 , 0.8417911 ,        nan],
       [0.72478807, 0.72478807,        nan, 0.72478807],
       [       nan, 1.1971722 , 1.1971722 ,        nan],
       [       nan,        nan,        nan, 2.5889838 ]], dtype=float32)

In [19]:
qd = q_dot_maybe(fwd=(lhs, rhs), bwd=8)
qd(a, b)

Array([[ 0.       ,  0.       ],
       [-0.9431708,  1.2575611]], dtype=float32)