In [1]:
import poppy
import jax
import jax.numpy as jnp
import functools
import galois
import numpy as np

In [85]:
@jax.jit
def pgetrf2(aperm, inv):
  """Unblocked LU factorisation mod p of a panel, with row pivoting.

  Args:
    aperm: 2-D integer array. The last column is a tag column that is only
      moved by row swaps; all other columns hold the panel entries.
    inv: lookup table of inverses mod p, indexed by residue. The modulus is
      recovered as 1 + inv[-1].

  Returns:
    Array shaped like aperm: multipliers below the diagonal, the upper factor
    on and above it, and the permuted tag column last.
  """
  n_rows = aperm.shape[0]
  n_cols = aperm.shape[1] - 1  # the tag column is not part of the panel.
  rows = jnp.arange(n_rows)
  cols = jnp.arange(n_cols)
  modulus = inv[-1] + 1

  def eliminate(panel, k):
    # Pivot search: rows above k are masked with -1 so they can never win.
    candidates = jnp.where(rows < k, -1, panel[:, k])
    piv = jnp.argmax(candidates)

    # Exchange rows k and piv (tag column included).
    row_k, row_piv = panel[k], panel[piv]
    panel = panel.at[k].set(row_piv).at[piv].set(row_k)

    # Turn the entries below the pivot into multipliers.
    below = rows > k
    scaled = (panel[:, k] * inv[panel[k, k]]) % modulus
    panel = panel.at[:, k].set(jnp.where(below, scaled, panel[:, k]))

    # Rank-1 update of the trailing block; the tag column is left alone.
    trailing = below[:, None] & (cols > k)[None, :]
    delta = jnp.where(trailing, jnp.outer(panel[:, k], panel[k, :-1]), 0)
    panel = panel.at[:, :-1].set((panel[:, :-1] - delta) % modulus)

    return panel, k

  steps = jnp.arange(min(n_rows, n_cols))
  return jax.lax.scan(eliminate, aperm, steps, unroll = False)[0]

In [87]:
DTYPE = poppy.DTYPE
@jax.jit
def ptrsm(a, b, p):
  """Forward substitution mod p, one right-hand-side column at a time.

  Only the entries of `a` strictly below the diagonal take part: the diagonal
  is never divided by, so `a` is treated as unit lower triangular.

  Args:
    a: square 2-D coefficient array.
    b: 2-D array of right-hand sides; each column is solved independently.
    p: modulus.

  Returns:
    Array shaped like b holding the solutions mod p.
  """
  row_ids = jnp.arange(len(a), dtype = DTYPE)

  def solve_column(rhs): # rhs is one column of b.

    def substitute(sol, k):
      # Entries of sol at positions >= k are still zero, so the dot product
      # only picks up the already-solved unknowns.
      sol = sol.at[k].set( (rhs[k] - jnp.dot(a[k], sol)) % p )
      return sol, sol[k]

    seed = jnp.where(row_ids == 0, rhs[0], 0) # first unknown is rhs[0].
    solved, _ = jax.lax.scan(substitute, seed, row_ids[1:])
    return solved

  # Map over the columns of b and write the results back as columns.
  return jax.vmap(solve_column, in_axes = 1, out_axes = 1)(b)

In [119]:
@functools.partial(jax.jit, static_argnums = 2)
def pgetrf(a, inv, b):
  """Blocked LU decomposition mod p with row pivoting.

  Args:
    a: 2-D integer array to factor.
    inv: lookup table of inverses mod p; the modulus is 1 + inv[-1].
    b: block width (static, must be a Python int).

  Returns:
    (l, u, d, iperm): l is the lower factor with ones on the diagonal, u the
    upper factor, d the diagonal of u, and iperm the inverse of the
    accumulated row permutation.
  """
  modulus = inv[-1] + 1
  n_rows = len(a)
  n_piv = min(a.shape)
  perm = jnp.arange(n_rows)

  for lo in range(0, n_piv, b):
    hi = lo + min(n_piv - lo, b)

    # Factor the panel; the extra column tags each row with its current index.
    tags = jnp.arange(lo, n_rows).reshape((-1, 1))
    panel = pgetrf2(jnp.hstack([a[lo:, lo:hi], tags]), inv)
    order = panel[:, -1]

    perm = perm.at[lo:].set(perm[order])
    a = a.at[lo:, :].set(a[order, :])        # apply the panel's row swaps to whole rows.
    a = a.at[lo:, lo:hi].set(panel[:, :-1])  # store the factored panel.

    # Block row to the right of the panel.
    top = ptrsm(a[lo:hi, lo:hi], a[lo:hi, hi:], modulus)
    a = a.at[lo:hi, hi:].set(top)

    # Trailing block update.
    schur = (a[hi:, hi:] - jax.lax.dot(a[hi:, lo:hi], a[lo:hi, hi:])) % modulus
    a = a.at[hi:, hi:].set(schur)

  l = jnp.fill_diagonal(jnp.tril(a), 1, inplace = False)
  u = jnp.triu(a)
  d = jnp.diagonal(u)

  # Invert the permutation: iperm[perm[k]] = k.
  slots = jnp.arange(len(perm))
  iperm = slots.at[perm].set(slots)

  return l, u, d, iperm

In [128]:
@functools.partial(jax.jit, static_argnums = 2)
def inv( a, INV, b):
  """Matrix inverse mod p built from the blocked LU factors.

  Args:
    a: square 2-D integer array.
    INV: lookup table of inverses mod p; the modulus is 1 + INV[-1]
      (the inverse of p-1 is p-1 itself).
    b: block width forwarded to pgetrf (static).

  Returns:
    Array shaped like a holding the inverse mod p.
  """
  modulus = INV[-1] + 1
  lower, upper, pivots, iperm = pgetrf(a, INV, b)
  identity = jnp.eye(len(a), dtype = DTYPE)
  pivot_inv = INV[pivots]

  # Inverse of the unit lower factor.
  lower_inv = ptrsm(lower, identity, modulus)

  # Scale the columns of the upper factor by the inverse pivots and transpose
  # so that ptrsm sees a unit lower triangular system; transpose back after.
  unit_upper_t = ((pivot_inv * upper) % modulus).T
  upper_inv = ptrsm(unit_upper_t, pivot_inv * identity, modulus).T

  product = (upper_inv @ lower_inv) % modulus
  return product[:, iperm] # undo the row pivoting as a column permutation.

@jax.jit
def det(d, p, perm = None):
  """Product of the LU pivots mod p, optionally corrected for row swaps.

  Args:
    d: 1-D integer array of pivots (the diagonal of the upper factor).
    p: modulus.
    perm: optional 1-D permutation of 0..n-1 describing the row pivoting.
      A permutation and its inverse have the same parity, so either may be
      passed. When omitted, the plain product of d is returned.

  Returns:
    Scalar array: prod(d) mod p, negated mod p when perm is odd.
  """
  prod = jax.lax.associative_scan(lambda x, y: (x * y) % p, d)[-1]
  if perm is None:
    return prod

  # Parity of the permutation = parity of its inversion count.
  idx = jnp.arange(perm.shape[0])
  inverted = (idx[:, None] < idx[None, :]) & (perm[:, None] > perm[None, :])
  odd = jnp.sum(inverted) % 2 == 1
  return jnp.where(odd, (p - prod) % p, prod)


In [129]:
# Smoke test: factor a random N x N poppy matrix with block width B and
# display the factors together with the pivot product.
p = 5
N = 33
B = 2
field = poppy.field(p,1)
#GF = galois.Field(p)
ap = poppy.random((N,N), field, poppy.SEED+2)
#ag = GF( np.array( a.proj( ) ) )
# The fourth value returned by pgetrf is the inverse permutation (iperm); it is bound to `perm` here.
l, u, d, perm = pgetrf(ap.lift, field.INV, B)
l,u,d,perm, det(d,p)


(Array([[1, 0, 0, ..., 0, 0, 0],
        [0, 1, 0, ..., 0, 0, 0],
        [0, 0, 1, ..., 0, 0, 0],
        ...,
        [4, 0, 3, ..., 1, 0, 0],
        [2, 2, 1, ..., 0, 1, 0],
        [0, 3, 2, ..., 4, 3, 1]], dtype=int64),
 Array([[4, 2, 0, ..., 2, 4, 1],
        [0, 4, 2, ..., 0, 3, 3],
        [0, 0, 4, ..., 2, 4, 0],
        ...,
        [0, 0, 0, ..., 3, 1, 0],
        [0, 0, 0, ..., 0, 4, 0],
        [0, 0, 0, ..., 0, 0, 4]], dtype=int64),
 Array([4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 3, 3, 4, 2, 3, 3, 4, 4], dtype=int64),
 Array([ 8, 17,  2,  3, 13, 12,  9, 14,  7,  6,  5, 26,  0, 21,  4,  1, 24,
        10, 15, 11, 20, 18, 16, 22, 27, 23, 25, 28, 31, 19, 29, 32, 30],      dtype=int64),
 Array(4, dtype=int64))

In [131]:
# Check: ap.lift times inv(ap.lift) reduced mod p should display the identity matrix.
ap.lift@inv(ap.lift,field.INV,B)%p, det(d,p)

(Array([[1, 0, 0, ..., 0, 0, 0],
        [0, 1, 0, ..., 0, 0, 0],
        [0, 0, 1, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 1, 0, 0],
        [0, 0, 0, ..., 0, 1, 0],
        [0, 0, 0, ..., 0, 0, 1]], dtype=int64),
 Array(4, dtype=int64))

In [122]:
# Check: re-indexing the rows of l@u mod p with perm should reproduce ap.lift.
# NOTE(review): the recorded output below is 3x3 while the visible setup cell uses N = 33,
# so this output appears to come from an earlier run — re-execute after Restart & Run All.
ap.lift, (l@u%p)[perm]

(Array([[4, 1, 2],
        [4, 0, 0],
        [2, 3, 2]], dtype=int64),
 Array([[4, 1, 2],
        [4, 0, 0],
        [2, 3, 2]], dtype=int64))

In [125]:
# Timing run: blocked LU (block width B) of a random N x N matrix over poppy.field(p,1).
p = 12421
N = 34
B = 16
field = poppy.field(p,1)
#GF = galois.Field(p)
a = poppy.random((N,N), field, poppy.SEED+2)
#ag = GF( np.array( a.proj( ) ) )
#lup = pgetrf2( jnp.hstack( [ a.lift, jnp.arange( N ).reshape( ( -1, 1 )) ]), field.INV )
#lu, perm = pgetrf(a.lift, field.INV, B)
# block_until_ready() makes the timing include the actual device computation.
%timeit pgetrf( a.lift, field.INV, B )[0].block_until_ready()

246 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [48]:
@functools.partial( jax.jit, static_argnums = 1 )
def pgetrf2( aperm, p, INV ):
  """Unblocked LU factorisation mod p of a panel, with row pivoting.

  Args:
    aperm: 2-D integer array. The last column is a tag column that is only
      moved by row swaps; all other columns hold panel entries reduced mod p.
    p: modulus (static Python int).
    INV: lookup table of inverses mod p, indexed by residue.

  Returns:
    Array shaped like aperm: multipliers below the diagonal, the upper factor
    on and above it, and the permuted tag column last.
  """
  I = jnp.arange( aperm.shape[ 0 ] )
  J = jnp.arange( aperm.shape[ 1 ] - 1 ) # panel columns only, tag column excluded.
  # One elimination step per pivot, including the last panel column.
  R = jnp.arange( min( aperm.shape[ 0 ], aperm.shape[ 1 ] - 1 ) )

  def f( ap, i ):

    # Mask rows above i with -1 (not 0) so an all-zero column still picks a row >= i.
    j = jnp.argmax( jnp.where( I >= i, ap[ :, i ], -1 ) )
    row_i, row_j = ap[ i ], ap[ j ]
    ap = ap.at[ i ].set( row_j ).at[ j ].set( row_i ) # swp.
    ap = ap.at[ :, i ].set( jnp.where( I > i, ( ap[ :, i ] * INV[ ap[ i, i ] ] ) % p, ap[ :, i ] ) ) # scal.
    rank1 = jnp.where( ( I[ :, None ] > i ) & ( J[ None, : ] > i ), \
                       jnp.outer( ap[ :, i ], ap[ i, :-1 ] ), \
                       0 )
    # Reduce only the panel columns: the tag column holds row indices, not residues.
    ap = ap.at[ :, :-1 ].set( ( ap[ :, :-1 ] - rank1 ) % p ) # geru.

    return ap, i

  return jax.lax.scan( f, aperm, R )[ 0 ]

@functools.partial( jax.jit, static_argnums = ( 1, 2 ) )
def pgetrf( a, field, B ):
  """Blocked LU decomposition mod field.p with row pivoting.

  Args:
    a: 2-D integer array with entries reduced mod field.p.
    field: hashable object (static) exposing `p` (the modulus) and `INV`
      (lookup table of inverses mod p).
    B: block width (static Python int).

  Returns:
    (lu, perm): lu packs both factors (multipliers strictly below the
    diagonal, upper factor on and above it); perm lists, for each row of lu,
    the row of the input it came from, so (L @ U) % p equals a[perm].
  """
  p = field.p

  def ptrsm( aa, b ): # mod p solve with the unit lower triangular part of aa.

    R = jnp.arange( aa.shape[ 0 ] )

    def ptrsm_scan( bc ): # bc is one column of b.

      def f( x, j ):
        # Entries of x at positions >= j are still zero, so the upper part
        # of aa (which holds U) never contributes.
        x = x.at[ j ].set( ( bc[ j ] - jnp.dot( aa[ j ], x ) ) % p )
        return x, x[ j ]

      # The first unknown is bc[0]; the remaining rows are scanned in order.
      return jax.lax.scan( f, jnp.where( R == 0, bc[ 0 ], 0 ), R[ 1: ] )[ 0 ]

    # vmap the columns.
    return jax.vmap( ptrsm_scan, in_axes = 1, out_axes = 1 )( b )

  n = len( a )
  perm = jnp.arange( n )
  m = min( a.shape )

  for i in range( 0, m, B ):

    b = min( m - i, B )
    # Tag rows with their position in the current (already permuted) a.
    tags = jnp.arange( i, n, dtype = a.dtype ).reshape( ( -1, 1 ) )
    ap = pgetrf2( jnp.hstack( [ a[ i:, i:i+b ], tags ] ), field.p, field.INV )
    order = ap[ :, -1 ]
    perm = perm.at[ i: ].set( perm[ order ] )
    a = a.at[ i: ].set( a[ order ] ) # swap whole rows, last column included.
    a = a.at[ i:, i:i+b ].set( ap[ :, :-1 ] ) # update left wall.

    if i + b < a.shape[ 1 ]:
      roof = ptrsm( a[ i:i+b, i:i+b ], a[ i:i+b, i+b: ] )
      house = ( a[ i+b:, i+b: ] - jax.lax.dot( a[ i+b:, i:i+b ], roof ) ) % p
      a = a.at[ i:i+b, i+b: ].set( roof )
      a = a.at[ i+b:, i+b: ].set( house )

  return a, perm

@functools.partial( jax.jit, static_argnums = ( 1, 2 ) )
def pgetrf_scan( a, field, B ):
  """Blocked LU over field.p expressed as a jax.lax.fori_loop over blocks.

  Args:
    a: 2-D array to factor.
    field: static argument; `field.p` is used as the modulus.
    B: static block width.

  Returns:
    The final fori_loop carry, a pair (a1, perm).

  NOTE(review): this looks like an unfinished experiment; see the inline
  notes for inconsistencies that should be resolved before relying on it.
  """

  perm = jnp.arange( len( a ) )
  m = min( a.shape )
  I = jnp.arange( a.shape[ 0 ] )
  J = jnp.arange( a.shape[ 1 ] )
  
  #@functools.partial( jax.jit, static_argnums = 1 )
  def f( i, ap ):
    # fori_loop body: i is the block counter, ap the carry (matrix, permutation).

    a1, perm = ap
    #b = min( m - i, B )
    b = B
    # NOTE(review): the pgetrf2 definitions visible in this notebook take
    # (aperm, inv) or (aperm, p, INV); confirm which variant this
    # two-argument call with `field` expects.
    ap_block = pgetrf2( jnp.hstack( [ jax.lax.dynamic_slice( a1, ( 0, i*b ), ( len( a ), b ) ), perm.reshape( ( -1, 1 ) ) ] ), field )
    perm = ap_block[ : , -1 ]
    # NOTE(review): the panel is read at column offset i*b but written back
    # at column offset i — looks inconsistent; confirm.
    a1 = jax.lax.dynamic_update_slice( a1, ap_block[ : , : -1 ], ( 0, i ) )

    # a00: b x b block at (i*b, i*b); a01: full-width band of b rows starting
    # at row i*b; a10: full-height band of b columns starting at column i.
    a00 = jax.lax.dynamic_slice( a1, ( i*b, i*b ), ( b, b ) )
    a01 = jax.lax.dynamic_slice( a1, ( i*b, 0 ), ( b, len( J ) ) )
    a10 = jax.lax.dynamic_slice( a1, ( 0, i ), ( len( I ), b ) )
    #a00 = jnp.where( a00 >= 0 )

    a01p = poppy.ptrsm( a00, a01, field.p )
    # NOTE(review): the product uses a01 before the triangular solve, not a01p — confirm intent.
    a11p = jax.lax.dot( a10, a01 )
    a01 = jnp.where( J[ None, : ] > i*b, a01p, a01 )
    # NOTE(review): outside the mask a11 equals a1, so the subtraction below
    # yields 0 there instead of leaving those entries unchanged — confirm.
    a11 = jnp.where( ( I[ : , None ] > i*b ) & ( J[ None, : ] > i*b ), a11p, a1 )
    # NOTE(review): the band was read at row offset i*b but is written at row offset i.
    a1 = jax.lax.dynamic_update_slice( a1, a01, ( i, 0 ) )
    a1 = ( a1 - a11 ) % field.p

  

    #a1 = a1.at[ i:i+b, i+b: ].set( poppy.ptrsm( a1[ i:i+b, i:i+b ], a1[ i:i+b, i+b: ], field.p ) )
    #a1 = a1.at[ i+b:, i+b: ].add( -jax.lax.dot( a1[ i+b: , i:i+b ], a1[ i:i+b, i+b: ] ) ) % field.p

    return a1, perm

  # Runs m // B iterations, so a trailing partial block (m % B rows/columns) is not processed.
  return jax.lax.fori_loop(0, m // B, f, ( a, perm ) )
  #a = jax.lax.scan( f, a, jnp.arange( m // B ) * B )[ 0 ]
  
  #return a, perm


In [5]:
# Benchmark setup: a random N x N poppy matrix `a` and the same data as a galois array `ag`.
# NOTE(review): an identical setup cell appears again further down this notebook.
p = 12421
N = 512
B = 128
field = poppy.field( p, 1 )
GF = galois.Field( p )
a = poppy.random( ( N, N ), field )
ag = GF( np.array( a.proj( ) ) )

In [53]:
# Times the pgetrf( a, field, B ) variant with a block width of 32 (not the B set in the setup cell).
%timeit pgetrf( a.lift, field, 32 )[ 0 ].block_until_ready( )

60.1 ms ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [77]:
# Benchmark setup: a random N x N poppy matrix `a` and the same data as a galois array `ag`.
# NOTE(review): this repeats an earlier setup cell verbatim.
p = 12421
N = 512
B = 128
field = poppy.field( p, 1 )
GF = galois.Field( p )
a = poppy.random( ( N, N ), field )
ag = GF( np.array( a.proj( ) ) )

In [158]:
# NOTE(review): this call passes two arguments (a.lift, field); the pgetrf2 definitions
# visible in this notebook take (aperm, inv) or (aperm, p, INV) — confirm which variant was timed.
%timeit pgetrf2( a.lift, field ).block_until_ready( )
#%timeit a.inv( ).lift.block_until_ready( )

259 ms ± 19.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [135]:
# Baseline for comparison: time lu_decompose() on the galois array `ag` built from the same matrix.
%timeit ag.lu_decompose( )

1.09 s ± 279 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
