In [1]:
import numpy as np
import torch
import pychop as pychop
from pychop import QuantizedLayer, Rounding
from pychop.chop import chop
from pychop import float_params
from time import time
from numpy import linalg
import jax
# from pychop.chop import chop
# from pychop.quant import quant
from time import time
from scipy.io import savemat
# np.set_printoptions(precision=19)

In [2]:
# Fix NumPy's global RNG so the test matrix is identical on every run.
np.random.seed(0)

# One reference matrix, materialised once per backend.
MATRIX_SHAPE = (5000, 5000)
X_np = np.random.randn(*MATRIX_SHAPE)            # NumPy array (float64)
X_th = torch.tensor(X_np, dtype=torch.float32)   # torch tensor (float32 copy)
X_jx = jax.numpy.asarray(X_np)                   # JAX array
print(X_np)

# Optional export of the reference matrix for external verification:
#savemat("tests/verify.mat", {"X":X_np})

[[ 1.76405235  0.40015721  0.97873798 ...  0.92918181  0.22941801
   0.41440588]
 [ 0.30972382 -0.73745619 -1.53691988 ...  0.51687218 -0.03292069
   1.29811143]
 [-0.20211703 -0.833231    1.73360025 ...  0.75309415 -0.58103281
  -0.19837974]
 ...
 [ 1.07432182  1.188486    0.5092741  ...  0.07053449  0.59975911
  -2.41029925]
 [ 0.32432475 -0.02337844  1.62873399 ... -0.16088168 -1.59772992
   1.414703  ]
 [ 0.63460807  1.38090977  0.54829109 ...  0.30762729 -0.11078251
   0.83859307]]


In [3]:
# Show the first five entries of column 0 of the torch copy of the matrix.
X_th[0:5, 0]

tensor([ 1.7641,  0.3097, -0.2021,  2.4700,  0.3300])

### Print the unit roundoff and range parameters of each supported floating-point format

In [4]:
# Display pychop's table of format parameters (one row per format code such as
# 'h', 's', 'd'; columns u, xmins, xmin, xmax, p, emins, emin, emax).
float_params()

Unnamed: 0,Unnamed: 1,u,xmins,xmin,xmax,p,emins,emin,emax
0,q43,0.0625,0.00195,0.0156,2.40e+02,4,-9,-6,7
1,q52,0.125,1.53e-05,6.1e-05,5.73e+04,3,-16,-14,15
2,b,0.00391,9.18e-41,1.18e-38,3.39e+38,8,-133,-126,127
3,h,0.000488,5.96e-08,6.1e-05,6.55e+04,11,-24,-14,15
4,t,0.000488,1.15e-41,1.18e-38,3.40e+38,11,-136,-126,127
5,s,5.96e-08,1.3999999999999999e-45,1.18e-38,3.40e+38,24,-149,-126,127
6,d,1.11e-16,5e-324,2.23e-308,1.80e+308,53,-1074,-1022,1023
7,q,9.630000000000001e-35,0.0,0.0,inf,113,-16494,-16382,16383


### set backend

In [5]:
# Select the array backend pychop operates on. NumPy is the default option;
# use pychop.backend('torch') instead to switch to PyTorch.
pychop.backend('numpy', 1) # second argument 1: print which backend was loaded

Load NumPy backend.


### run chop

In [6]:
# Half-precision ('h') rounding of the NumPy matrix with chop's default rounding mode.
pyq_f = chop('h')

# Wall-clock time of a single rounding pass over the full matrix.
st = time()
X_bit = pyq_f(X_np)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:10, 0])

runtime: 2.028095006942749
[ 1.76367188  0.30981445 -0.20214844  2.47070312  0.33007812  0.30200195
  0.37133789  0.38720703 -1.93945312  0.56152344]


In [7]:
# Switch to the PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=1 (round to nearest).
pyq_f = chop('h', rmode=1)

# Wall-clock time of a single rounding pass over the full tensor.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:10, 0])

Load Troch backend.
runtime: 0.993048906326294
tensor([ 1.7637,  0.3098, -0.2021,  2.4707,  0.3301,  0.3020,  0.3713,  0.3872,
        -1.9395,  0.5615])


In [8]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=2 (round up).
pyq_f = chop('h', rmode=2)

# Wall-clock time of a single rounding pass over the full tensor.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:10, 0])

Load Troch backend.
runtime: 0.7100889682769775
tensor([ 1.7646,  0.3098, -0.2020,  2.4707,  0.3301,  0.3022,  0.3713,  0.3875,
        -1.9395,  0.5620])


In [9]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=3 (round down).
pyq_f = chop('h', rmode=3)

# Wall-clock time of a single rounding pass over the full tensor.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:10, 0])

Load Troch backend.
runtime: 0.7768659591674805
tensor([ 1.7637,  0.3096, -0.2021,  2.4688,  0.3298,  0.3020,  0.3711,  0.3872,
        -1.9404,  0.5615])


In [10]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=4 (round towards zero).
pyq_f = chop('h', rmode=4)

# Wall-clock time of a single rounding pass over the full tensor.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:10, 0])

Load Troch backend.
runtime: 0.8116700649261475
tensor([ 1.7637,  0.3096, -0.2020,  2.4688,  0.3298,  0.3020,  0.3711,  0.3872,
        -1.9395,  0.5615])


In [11]:
# Cross-check: the bit-level Rounding emulator must reproduce chop('h') exactly
# for each deterministic rounding mode.
# NOTE: torch tensors are passed to chop here, so this cell relies on the torch
# backend selected in the preceding cells.
values = torch.tensor([1.7641, 0.3097, -0.2021, 2.4700, 0.3300])

# half precision simulator (5 exponent bits, 10 mantissa bits)
fp16_sim = Rounding(5, 10)

rounding_modes = ["nearest", "up", "down", "towards_zero", 
                 "stochastic_equal", "stochastic_proportional"]

# Compare with PyTorch's native FP16
fp16_native = values.to(dtype=torch.float16).to(dtype=torch.float32)

print("Input values:      ", values)
print("PyTorch FP16:      ", fp16_native)
print()

print()
# chop's numeric rmode k is paired with the emulator mode name rounding_modes[k-1].
rounding_modes_num = [1, 2, 3, 4, "stochastic_equal", "stochastic_proportional"]

# Only the four deterministic modes can be compared element-for-element.
for mode in rounding_modes_num[:4]:
    mode_name = rounding_modes[mode - 1]
    pyq_f = chop('h', rmode=mode)
    ground_truth = pyq_f(values)                       # reference result from chop
    emulated = fp16_sim.quantize(values, mode_name)    # result from the emulator
    # Report the mode that actually failed (the old message was a print() call,
    # i.e. None, and always said "rmode 3").
    assert np.array_equal(emulated, ground_truth), (
        f"Rounding emulator disagrees with chop for rmode={mode} ({mode_name})"
    )

    # Label each tensor with what it really is: chop = truth, Rounding = emulated.
    print(f"{mode_name:12}, ", "Truth:", f"   {ground_truth}")
    print(f"{mode_name:12}, ", "Emulated:", f"{emulated}")



Input values:       tensor([ 1.7641,  0.3097, -0.2021,  2.4700,  0.3300])
PyTorch FP16:       tensor([ 1.7637,  0.3098, -0.2021,  2.4707,  0.3301])


nearest     ,  Truth:    tensor([ 1.7637,  0.3098, -0.2021,  2.4707,  0.3301])
nearest     ,  Emulated: tensor([ 1.7637,  0.3098, -0.2021,  2.4707,  0.3301])
up          ,  Truth:    tensor([ 1.7646,  0.3098, -0.2020,  2.4707,  0.3301])
up          ,  Emulated: tensor([ 1.7646,  0.3098, -0.2020,  2.4707,  0.3301])
down        ,  Truth:    tensor([ 1.7637,  0.3096, -0.2021,  2.4688,  0.3298])
down        ,  Emulated: tensor([ 1.7637,  0.3096, -0.2021,  2.4688,  0.3298])
towards_zero,  Truth:    tensor([ 1.7637,  0.3096, -0.2020,  2.4688,  0.3298])
towards_zero,  Emulated: tensor([ 1.7637,  0.3096, -0.2020,  2.4688,  0.3298])


In [12]:
# Seed torch's global RNG so the random layer input below is the same on every
# run. NOTE(review): this presumably also pins the two stochastic rounding
# modes and the layer's initial weights, if pychop draws them from torch's
# RNG — confirm.
torch.manual_seed(0)

# (exponent bits, mantissa bits) of each simulated format.
formats = {
    "fp32": (8, 23),    # Standard IEEE 754 float32
    "fp16": (5, 10),    # Standard IEEE 754 float16
    "bf16": (8, 7),     # bfloat16
    "fp8": (5, 2),      # Example 8-bit float
}

# Test different rounding modes
rounding_modes = [
    "nearest",
    "up",
    "down",
    "towards_zero",
    "stochastic_equal",
    "stochastic_proportional"
]

# Create a sample tensor
x = torch.tensor([0.1, 0.3, 1.7, 3.9, -2.5])

# Test quantization: round the sample to bfloat16 under every rounding mode.
mp = Rounding(*formats["bf16"])

print("Original values:", x)
for mode in rounding_modes:
    result = mp.quantize(x, mode)
    print(f"{mode}:", result)

# Test with a layer: 4 inputs -> 2 outputs, bfloat16 with round-to-nearest.
layer = QuantizedLayer(4, 2, *formats["bf16"], rounding_mode="nearest")
input_tensor = torch.randn(3, 4)
output = layer(input_tensor)
print("\nLayer output shape:", output.shape)


Original values: tensor([ 0.1000,  0.3000,  1.7000,  3.9000, -2.5000])
nearest: tensor([ 0.1001,  0.3008,  1.7031,  3.9062, -2.5000])
up: tensor([ 0.1001,  0.3008,  1.7031,  3.9062, -2.5000])
down: tensor([ 0.0996,  0.2988,  1.6953,  3.8906, -2.5000])
towards_zero: tensor([ 0.0996,  0.2988,  1.6953,  3.8906, -2.5000])
stochastic_equal: tensor([ 0.0996,  0.3008,  1.6953,  3.8906, -2.5000])
stochastic_proportional: tensor([ 0.1001,  0.3008,  1.6953,  3.9062, -2.5000])

Layer output shape: torch.Size([3, 2])


In [13]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=1 (round to nearest).
pyq_f = chop('h', rmode=1)

# Time one pass, then show the first five entries of column 0.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:5, 0])

Load Troch backend.
runtime: 0.9264817237854004
tensor([ 1.7637,  0.3098, -0.2021,  2.4707,  0.3301])


In [14]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=2 (round up).
pyq_f = chop('h', rmode=2)

# Time one pass, then show the first five entries of column 0.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:5, 0])

Load Troch backend.
runtime: 1.0012319087982178
tensor([ 1.7646,  0.3098, -0.2020,  2.4707,  0.3301])


In [15]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=3 (round down).
pyq_f = chop('h', rmode=3)

# Time one pass, then show the first five entries of column 0.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:5, 0])

Load Troch backend.
runtime: 1.0078048706054688
tensor([ 1.7637,  0.3096, -0.2021,  2.4688,  0.3298])


In [16]:
# PyTorch backend (1 = print which backend was loaded).
pychop.backend('torch', 1)

# Half precision, rmode=4 (round towards zero).
pyq_f = chop('h', rmode=4)

# Time one pass, then show the first five entries of column 0.
st = time()
X_bit = pyq_f(X_th)
elapsed = time() - st

print("runtime:", elapsed)
print(X_bit[:5, 0])

Load Troch backend.
runtime: 1.1233789920806885
tensor([ 1.7637,  0.3096, -0.2020,  2.4688,  0.3298])


In [17]:
# Switch to the JAX backend (1 = print which backend was loaded).
pychop.backend('jax', 1) # print information

# Half precision ('h') with chop's default rounding mode.
pyq_f = chop('h')

st = time()
# JAX dispatches work asynchronously, so a call can return before the result
# has been computed. Block until the result is ready before stopping the
# clock; otherwise the timing may under-report. jax.block_until_ready returns
# its argument and leaves non-JAX values untouched.
X_bit = jax.block_until_ready(pyq_f(X_jx))
print("runtime:", time() - st)
print(X_bit)

Load JAX backend.
runtime: 7.8559112548828125
[[ 1.7636719   0.40014648  0.9785156  ...  0.9291992   0.22937012
   0.41430664]
 [ 0.30981445 -0.7373047  -1.5371094  ...  0.51708984 -0.03292847
   1.2978516 ]
 [-0.20214844 -0.8330078   1.7333984  ...  0.7529297  -0.5810547
  -0.19836426]
 ...
 [ 1.0742188   1.1884766   0.50927734 ...  0.07055664  0.5996094
  -2.4101562 ]
 [ 0.32421875 -0.02337646  1.6289062  ... -0.16088867 -1.5976562
   1.4150391 ]
 [ 0.6347656   1.3808594   0.54833984 ...  0.3076172  -0.11077881
   0.8383789 ]]


### integer quantization

In [18]:
# 8-bit integer quantization round trip on the NumPy backend.
pychop.backend('numpy')
pyq_f = pychop.quant(bits=8)

X_q = pyq_f(X_np)             # quantize
X_inv = pyq_f.dequant(X_q)    # map back to floating point

# Norm of the round-trip error (numpy.linalg.norm with default arguments).
roundtrip_err = X_inv - X_np
linalg.norm(roundtrip_err)

62.800703887228494

In [19]:
# 8-bit integer quantization round trip on the PyTorch backend.
pychop.backend('torch')
pyq_f = pychop.quant(bits=8)

X_q = pyq_f(X_th)             # quantize
X_inv = pyq_f.dequant(X_q)    # map back to floating point

# Norm of the round-trip error, measured against the original NumPy matrix.
roundtrip_err = X_inv - X_np
linalg.norm(roundtrip_err)

62.800703841249906

In [20]:

# 8-bit integer quantization round trip on the JAX backend.
pychop.backend('jax')
pyq_f = pychop.quant(bits=8)

X_q = pyq_f(X_jx)             # quantize
X_inv = pyq_f.dequant(X_q)    # map back to floating point

# Norm of the round-trip error, measured against the JAX copy of the matrix.
roundtrip_err = X_inv - X_jx
linalg.norm(roundtrip_err)

62.7823

### Fixed-point quantization

In [21]:
# Fixed-point quantizer with its default settings, on the NumPy backend.
pychop.backend('numpy')
pyq_f = pychop.fpoint()

# Quantize the matrix and display the result.
X_fixed = pyq_f(X_np)
X_fixed

array([[ 1.75  ,  0.375 ,  1.    , ...,  0.9375,  0.25  ,  0.4375],
       [ 0.3125, -0.75  , -1.5625, ...,  0.5   , -0.0625,  1.3125],
       [-0.1875, -0.8125,  1.75  , ...,  0.75  , -0.5625, -0.1875],
       ...,
       [ 1.0625,  1.1875,  0.5   , ...,  0.0625,  0.625 , -2.4375],
       [ 0.3125, -0.    ,  1.625 , ..., -0.1875, -1.625 ,  1.4375],
       [ 0.625 ,  1.375 ,  0.5625, ...,  0.3125, -0.125 ,  0.8125]])

In [22]:
# Fixed-point quantizer with its default settings, on the PyTorch backend.
pychop.backend('torch')
pyq_f = pychop.fpoint()

# Quantize the tensor and display the result.
X_fixed = pyq_f(X_th)
X_fixed

tensor([[ 1.7500,  0.3750,  1.0000,  ...,  0.9375,  0.2500,  0.4375],
        [ 0.3125, -0.7500, -1.5625,  ...,  0.5000, -0.0625,  1.3125],
        [-0.1875, -0.8125,  1.7500,  ...,  0.7500, -0.5625, -0.1875],
        ...,
        [ 1.0625,  1.1875,  0.5000,  ...,  0.0625,  0.6250, -2.4375],
        [ 0.3125, -0.0000,  1.6250,  ..., -0.1875, -1.6250,  1.4375],
        [ 0.6250,  1.3750,  0.5625,  ...,  0.3125, -0.1250,  0.8125]])

In [23]:
# Fixed-point quantizer with its default settings, on the JAX backend.
pychop.backend('jax')
pyq_f = pychop.fpoint()

# Quantize the array and display the result.
X_fixed = pyq_f(X_jx)
X_fixed

Array([[ 1.75  ,  0.375 ,  1.    , ...,  0.9375,  0.25  ,  0.4375],
       [ 0.3125, -0.75  , -1.5625, ...,  0.5   , -0.0625,  1.3125],
       [-0.1875, -0.8125,  1.75  , ...,  0.75  , -0.5625, -0.1875],
       ...,
       [ 1.0625,  1.1875,  0.5   , ...,  0.0625,  0.625 , -2.4375],
       [ 0.3125, -0.    ,  1.625 , ..., -0.1875, -1.625 ,  1.4375],
       [ 0.625 ,  1.375 ,  0.5625, ...,  0.3125, -0.125 ,  0.8125]],      dtype=float32)