In [None]:
!pip install conway_polynomials
import conway_polynomials
!pip install galois
import galois
import jax.numpy as jnp
from   jax   import jit, random, vmap
import numpy     as  np
import time, math

In [2]:
# BEGIN PRECOMPUTATION
# Arithmetic takes place in the finite field with p**n elements
# ( p = characteristic, n = degree over the prime field ).
p, n = 2, 10
# Keep p**n below 2**30 so every packed field element fits in an int32.
assert p**n < 2**30

INDICES = jnp.arange(  n,      dtype = jnp.int32 )   # 0, 1, ..., n-1
ONES    = jnp.ones(    n,      dtype = jnp.int32 )
ZERO    = jnp.zeros( ( n, n ), dtype = jnp.int32 )
I       = jnp.eye(     n,      dtype = jnp.int32 )
BASIS   = ( p * ONES ) ** INDICES                    # p**0, p**1, ..., p**(n-1)
CONWAY  = conway_polynomials.database()

# XN: the negated, mod-p coefficients of the Conway polynomial with its last
# ( leading ) entry dropped -- presumably the expansion of x**n in the basis
# 1, x, ..., x**(n-1), assuming the database lists coefficients in ascending degree.
XN = jnp.array( [ -c % p for c in CONWAY[ p ][ n ][ :-1 ] ], dtype = jnp.int32 )
# XX: ones on the first subdiagonal, XN in the last column ( companion-matrix layout ).
XX = jnp.eye( n, k = -1, dtype = jnp.int32 ).at[ :, -1 ].set( XN )
# Xs[ i ] = XX**i reduced mod p, for i = 0 .. n-1.
Xs = [ I ]
for i in range( 1, n ):
    Xs.append( ( Xs[ -1 ] @ XX ) % p )
# X stacks those powers into a single ( n, n, n ) array indexed by exponent.
X = jnp.stack( Xs )
# END PRECOMPUTATION


# BEGIN FUNCTION DEFINITIONS
@jit
def int_to_vec( x ):
    """Expand the integer x into its n base-p digits, least significant first.

    Digit k is ( x // p**k ) % p, so the result has the same length as BASIS.
    """
    shifted = ( x * ONES ) // BASIS   # broadcast x to length n, drop the k lowest digits
    return shifted % p
# Batched variant: maps int_to_vec over the leading axis of an array of integers.
vint_to_vec = vmap( int_to_vec )

@jit
def vec_to_int( v ):
    """Pack a base-p digit vector ( least significant first ) into one int32.

    Computes sum_k v[ k ] * p**k; this is the inverse of int_to_vec on digit vectors.
    """
    weighted = v * BASIS
    return weighted.sum( dtype = jnp.int32 )
# Batched variant: maps vec_to_int over the leading axis of a stack of digit vectors.
vvec_to_int = vmap( vec_to_int )

@jit
def M( v ):
    """Return the n x n matrix sum_i v[ i ] * X[ i ], reduced mod p.

    X[ i ] is the i-th precomputed power of XX, so the result is the
    polynomial with coefficient vector v evaluated at the matrix XX.
    """
    combined = jnp.einsum( 'i,ijk->jk', v, X )   # contract v against the exponent axis of X
    return combined % p
# Batched variant: maps M over the leading axis of a stack of coefficient vectors.
vM = vmap( M )
# END FUNCTION DEFINITIONS

In [3]:
# Cross-check one product of two random field elements against the galois library.
GF = galois.GF( p**n )
key1, key2 = random.PRNGKey( 1 ), random.PRNGKey( 3 )
# Two random elements, drawn as coefficient vectors with entries in [ 0, p ).
v1, v2 = ( random.randint( key, ( n, ), 0, p, dtype = jnp.int32 ) for key in ( key1, key2 ) )
# The same two elements as packed integers, as matrices, and as galois scalars.
x1,  x2  = vec_to_int( v1 ), vec_to_int( v2 )
m1,  m2  = M( v1 ), M( v2 )
gx1, gx2 = GF( x1.item( ) ), GF( x2.item( ) )
# Display every representation side by side, ending with the matrix product
# ( mod p ) next to the galois product for visual comparison.
x1, x2, v1, v2, m1, m2, m1@m2 % p, int_to_vec( gx1*gx2 ), gx1*gx2

(Array(958, dtype=int32),
 Array(554, dtype=int32),
 Array([0, 1, 1, 1, 1, 1, 0, 1, 1, 1], dtype=int32),
 Array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32),
 Array([[0, 1, 1, 1, 0, 0, 1, 1, 1, 1],
        [1, 1, 0, 0, 1, 0, 1, 0, 0, 0],
        [1, 0, 0, 1, 0, 1, 1, 0, 1, 1],
        [1, 0, 1, 1, 1, 0, 0, 0, 1, 0],
        [1, 1, 0, 1, 1, 1, 0, 0, 0, 1],
        [1, 0, 0, 1, 1, 1, 0, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 0, 1, 0, 0],
        [1, 0, 0, 1, 1, 1, 1, 0, 1, 0],
        [1, 1, 0, 0, 1, 1, 1, 1, 0, 1],
        [1, 1, 1, 0, 0, 1, 1, 1, 1, 0]], dtype=int32),
 Array([[0, 1, 0, 0, 0, 0, 1, 1, 1, 0],
        [1, 1, 1, 0, 0, 0, 1, 0, 0, 1],
        [0, 0, 1, 1, 0, 0, 1, 0, 1, 0],
        [1, 1, 0, 1, 1, 0, 1, 0, 1, 1],
        [0, 1, 1, 0, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 1, 0, 1, 0, 1, 0, 0],
        [0, 0, 1, 1, 1, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 1, 1, 0, 0, 1],
        [1, 0, 0, 0, 0, 1, 1, 1, 0, 0]], dtype=int32),
 Array([[1, 

In [4]:
# Build an N x N matrix of random field elements in both representations.
N = 400
# NOTE: this draw reuses key1, the same PRNG key used for v1 in the previous cell.
xs = random.randint( key1, ( N * N, ), 0, p**n, dtype = jnp.int32 )
gm = GF( np.array( xs.reshape( N, N ) ) )   # galois matrix, one field element per entry
vs = vint_to_vec( xs )                      # ( N*N, n ) digit vectors
ms = vM( vs )                               # ( N*N, n, n ) per-element matrices
# Lay the per-element matrices out as blocks: the n x n block at block-row a,
# block-column b of m is ms[ a*N + b ], i.e. the matrix for xs[ a*N + b ].
m = ms.reshape( N, N, n, n ).transpose( 0, 2, 1, 3 ).reshape( n * N, n * N )
print( *( arr.shape for arr in ( xs, vs, gm, m ) ) )

(160000,) (160000, 10) (400, 400) (4000, 4000)


In [5]:
# Sanity check: column 0 of the top-left n x n block of m, packed back into an
# integer, is shown next to the galois entry gm[0,0] so the two can be compared.
gm[0,0], vec_to_int(m[:n,0])

(GF(620, order=2^10), Array(620, dtype=int32))

In [6]:
# Square the matrix in both representations: natively in galois, and as an
# integer matrix product reduced mod p.
# NOTE(review): this rebinds m2, which an earlier cell used for M( v2 ).
gm2 = gm@gm
m2  =  m@m % p

In [7]:
# Same comparison as before, now on the squared matrices: the galois entry
# gm2[0,0] next to the integer packed from column 0 of the top-left block of m2.
gm2[0,0], vec_to_int(m2[:n,0])

(GF(90, order=2^10), Array(90, dtype=int32))

In [8]:
# Baseline timing: the N x N matrix product carried out by galois.
%timeit gm@gm

329 ms ± 36.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit  m@m % p