In [1]:
import numpy as np
import torch
import pychop
from pychop.chop import chop
from pychop.quant import quant
from time import time

In [2]:
X_np = np.random.randn(5000, 5000) # Numpy array
X = torch.Tensor(X_np) # torch array

In [3]:
pychop.to_fixed_point(X)

numpy


tensor([[ 0.1250, -0.5000,  0.4375,  ..., -0.3750,  0.0625,  0.7500],
        [ 0.2500, -2.1875, -0.8125,  ...,  0.5625,  1.0000, -0.3750],
        [ 0.4375,  0.7500,  1.1875,  ...,  0.1250,  1.8125, -0.5625],
        ...,
        [ 0.2500, -1.5000,  0.2500,  ...,  3.6250, -0.3125, -0.1875],
        [-0.0000,  0.5625, -0.5000,  ..., -0.5625,  1.1250, -0.3125],
        [-0.9375,  0.3125,  0.2500,  ..., -0.6250,  0.1250, -0.2500]])

In [4]:
pychop.backend('numpy')
pyq_f = chop('h')
pyq_f(X_np)

array([[ 0.07757568,  0.14916992, -0.06958008, ...,  0.57763672,
        -1.36425781,  0.35083008],
       [ 0.74804688, -1.18164062,  0.85009766, ...,  1.35253906,
         0.41113281,  1.20117188],
       [-0.00433731, -1.33886719, -1.32519531, ...,  1.15332031,
         0.37011719,  0.92480469],
       ...,
       [-0.39306641, -0.02983093, -0.00984192, ...,  0.05960083,
        -1.05664062,  0.19311523],
       [-1.38476562, -0.38208008,  0.40283203, ..., -0.17810059,
         1.23046875, -0.29150391],
       [ 0.37915039,  0.05490112, -0.01125336, ..., -0.69726562,
         0.05987549,  3.1640625 ]])

In [5]:
pychop.backend('torch')
pyq_f = chop('h', device='cuda')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_gpu = X.to(device)

In [6]:
pyq_f(X_gpu)

tensor([[-1.5205,  0.2493,  0.5713,  ...,  0.6880, -1.0352, -0.0603],
        [-0.5664, -1.6699, -0.8940,  ...,  1.1475, -0.7046,  0.5986],
        [ 0.7534, -0.4973, -0.1455,  ...,  0.5312, -0.4556, -2.1406],
        ...,
        [ 2.1777,  0.6377,  1.6133,  ..., -0.4016,  1.3779, -1.3828],
        [-0.3909, -0.1877,  1.8037,  ..., -1.7988,  0.0443, -0.1127],
        [-0.0245, -1.0361, -0.2117,  ..., -0.9258, -1.0400,  0.7295]],
       device='cuda:0')

In [12]:
from pychop.chop import chop
pyq_f = chop('h', flip=0)
pyq_f(X_np)

array([[ 0.22277832, -0.89160156,  2.75195312, ..., -1.64160156,
        -1.18652344,  0.68457031],
       [ 0.15332031, -0.79199219, -0.70849609, ..., -1.54785156,
         0.88720703, -0.77490234],
       [-1.37402344, -2.13867188, -1.8359375 , ..., -1.13574219,
         0.43041992,  0.98046875],
       ...,
       [-0.36279297, -1.02636719, -1.31445312, ...,  0.33862305,
         0.34448242, -0.10742188],
       [-1.31738281, -0.40893555,  0.35620117, ..., -0.90966797,
         1.24609375, -1.453125  ],
       [ 0.08325195,  1.33789062,  0.33520508, ..., -0.27880859,
         0.859375  ,  1.75      ]])

In [8]:
pyq_f = quant()

In [9]:
X_q = pyq_f(X)
X_q

tensor([[ 21,  10, -14,  ...,   7,  11, -17],
        [ 42,  17, -40,  ...,  -2,  10, -40],
        [ -7,  -5,  -2,  ..., -47,   4,  16],
        ...,
        [ -7,  48, -14,  ..., -25,  36, -12],
        [  8,  36,   9,  ..., -21, -22,  -4],
        [ 24,   9,   1,  ...,  -6,   5,  11]], dtype=torch.int8)

In [10]:
X_q = pyq_f(X_gpu)
X_q

tensor([[ 21,  10, -14,  ...,   7,  11, -17],
        [ 42,  17, -40,  ...,  -2,  10, -40],
        [ -7,  -5,  -2,  ..., -47,   4,  16],
        ...,
        [ -7,  48, -14,  ..., -25,  36, -12],
        [  8,  36,   9,  ..., -21, -22,  -4],
        [ 24,   9,   1,  ...,  -6,   5,  11]], device='cuda:0',
       dtype=torch.int8)

In [11]:
pyq_f.dequant(X_q)

tensor([[ 1.1499e+00,  6.8225e-01, -3.3806e-01,  ...,  5.5471e-01,
          7.2476e-01, -4.6560e-01],
        [ 2.0427e+00,  9.7984e-01, -1.4434e+00,  ...,  1.7210e-01,
          6.8225e-01, -1.4434e+00],
        [-4.0468e-02,  4.4558e-02,  1.7210e-01,  ..., -1.7410e+00,
          4.2717e-01,  9.3733e-01],
        ...,
        [-4.0468e-02,  2.2977e+00, -3.3806e-01,  ..., -8.0570e-01,
          1.7876e+00, -2.5303e-01],
        [ 5.9723e-01,  1.7876e+00,  6.3974e-01,  ..., -6.3565e-01,
         -6.7816e-01,  8.7071e-02],
        [ 1.2774e+00,  6.3974e-01,  2.9964e-01,  ...,  2.0448e-03,
          4.6969e-01,  7.2476e-01]], device='cuda:0')

In [9]:
from pychop.jx.chop import chop

In [10]:
pyq_f = chop('h', flip=0)

In [11]:
import jax

pyq_f(jax.numpy.asarray(X_np))

Array([[ 0.22277832, -0.89160156,  2.7519531 , ..., -1.6416016 ,
        -1.1865234 ,  0.6845703 ],
       [ 0.15332031, -0.7919922 , -0.7084961 , ..., -1.5478516 ,
         0.88720703, -0.77490234],
       [-1.3740234 , -2.1386719 , -1.8359375 , ..., -1.1357422 ,
         0.43041992,  0.98046875],
       ...,
       [-0.36279297, -1.0263672 , -1.3144531 , ...,  0.33862305,
         0.34448242, -0.10742188],
       [-1.3173828 , -0.40893555,  0.35620117, ..., -0.90966797,
         1.2460938 , -1.453125  ],
       [ 0.08325195,  1.3378906 ,  0.33520508, ..., -0.2788086 ,
         0.859375  ,  1.75      ]], dtype=float32)

In [7]:

jnp = jax.numpy.asarray(X_np)
jnp[jnp > 1] = 3

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [16]:
from jax import random
key = random.key(1)
random.randint(key, shape=(3, ), minval=0, maxval=1)


Array([0, 0, 0], dtype=int32)

In [9]:
jnp 

Array([[-2.0479598 ,  0.        ,  0.        , ...,  0.89571846,
        -0.17137416, -0.01463014],
       [ 0.        ,  0.5509474 ,  0.04508554, ..., -0.96783084,
         0.9529769 ,  0.65047157],
       [-0.26722014,  0.        ,  0.36949125, ...,  0.        ,
        -2.0251095 ,  0.6836732 ],
       ...,
       [-0.8001436 ,  0.45967364, -1.3158927 , ..., -1.3393523 ,
        -1.2479155 , -0.10151188],
       [ 0.65616345, -1.1558628 ,  0.22463392, ...,  0.06328919,
         0.862858  ,  0.        ],
       [-2.9082816 ,  0.        , -0.6819175 , ...,  0.40110233,
        -0.04349303, -1.7842377 ]], dtype=float32)

In [7]:
jnp = jax.numpy.asarray(X_np)


Array([[False, False, False, ..., False, False, False],
       [ True, False, False, ...,  True, False, False],
       [False, False,  True, ..., False, False, False],
       ...,
       [False, False, False, ..., False,  True, False],
       [False, False, False, ..., False, False,  True],
       [ True, False, False, ..., False, False, False]], dtype=bool)