In [None]:
!pip install conway_polynomials
!pip install galois
import jax.numpy as jnp
from jax import jit, pmap, random, vmap
import numpy as np
import time, math
import galois
import conway_polynomials
import torch

In [82]:
p, n = 2, 40

In [83]:
conway_poly = conway_polynomials.database()

def xn( p, n ):
    return ( -jnp.array( conway_poly[p][n][:-1] ) ) % p

def shift( v, k ):
    return jnp.roll( v, k ).at[:k].set(0) if k >= 0 else jnp.roll( v, k ).at[k:].set(0)

def x_times( v, f, p ):
    return ( shift( v, 1 ) - v[ -1 ] * f ) % p

def X( p, n ):
    Xn = xn( p, n )
    Xlist = [ Xn ]
    for i in range( n - 1 ):
        Xlist = Xlist + [ x_times( Xlist[i], Xn, p ) ]
    return jnp.array( Xlist )

XN = X( p, n ).T
XN

Array([[1, 0, 0, ..., 1, 1, 1],
       [1, 1, 0, ..., 1, 0, 0],
       [0, 1, 1, ..., 1, 1, 0],
       ...,
       [0, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 1, 1, 1]], dtype=int32)

In [84]:
def int_to_vec( x ):
    assert x >= 0
    if x == 0:
        return jnp.zeros( n )
    v = [ ]
    while x:
        v.append( x % p )
        x //= p
    return jnp.array( v + ( n - len( v ) ) * [ 0 ], dtype = jnp.int32 )

def vec_to_int( v ):
    basis = jnp.array( [ p ** k for k in range( n ) ] )
    return jnp.sum( v * basis )

@jit
def M( v ):
    #m = jnp.zeros( ( len( v ), len( v ) ) )
    #a = jnp.array( [ shift( v, k ) for k in range( n ) ] )
    #b = jnp.array( [ jnp.sum( shift( v, k - n ) * XN, axis = 1 ) for k in range( n ) ] )
    #return ( a + b ) % p
    return jnp.array( [ ( shift( v, k ) + jnp.sum( shift( v, k - n ) * XN, axis = 1 ) ) % p for k in range( n ) ] )

In [85]:
@jit
def block( m ):
    return jnp.block( [ [ F_scalar( a ).m for a in b ] for b in m ] )

class F_scalar:
    def __init__( self, v: jnp.array ):
        self.v = v
        self.m = M( self.v )
        
    def times( self, x ):
        return F_scalar( ( self.m @ x.m ) % p, p, n )
        
class F_matrix:
    def __init__( self, m: jnp.array ):
        #self.m = jnp.vstack( [ jnp.hstack( [ F_scalar( int_to_vec( a ) ).m for a in b ] ) for b in m ] )
        self.m = block( m )
        
    def times( self, x ):
        return self.m @ x.m % p

In [None]:

vblock = vmap( block)

In [86]:
d,N =32,1
key = random.PRNGKey(0)
a = random.randint( key, ( d, d, n ), 0, p )

In [None]:
t0 = time.perf_counter( )
mat = F_matrix( a )
print( time.perf_counter( ) - t0 )

In [76]:
t0 = time.perf_counter( )
for i in range( 100 ):
    mat.times( mat )
print( time.perf_counter( ) - t0 )

0.024680529999841383


In [77]:
F = galois.GF( p ** n )
galois_mat = F( np.random.randint( 0, p, ( d, d )) )
#jnp_mats = [ jnp.array( np.random.randint( 0, 13, ( 4*N, 4*N )) ) for _ in range(10 ) ]
#torch_mats = [ torch.Tensor( np.random.randint( 0, 13, ( 8*N, 8*N )) ).cuda() for _ in range(100 ) ]

In [81]:
t0 = time.perf_counter( )
for _ in range( 100 ):
    galois_mat @ galois_mat
print( time.perf_counter( ) - t0 )

0.031226488999891444


In [21]:
t0 = time.perf_counter( )
[ torch.matmul( a, a ) % 13 for a in torch_mats ]
print( time.perf_counter( ) - t0 )

0.005008286999782285


In [None]:
t0 = time.perf_counter( )
[ jnp.matmul( a, a ) for a in jnp_mats ]
print( time.perf_counter( ) - t0 )