Skip to content

Commit

Permalink
Make nmod_mat subclass flint_mat
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarbenjamin committed Nov 15, 2023
1 parent 655f392 commit 67b29cf
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/flint/types/nmod_mat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from flint.flint_base.flint_base cimport flint_mat
from flint.flintlib.nmod_mat cimport nmod_mat_t
from flint.flintlib.flint cimport mp_limb_t

cdef class nmod_mat:
cdef class nmod_mat(flint_mat):
cdef nmod_mat_t val
cpdef long nrows(self)
cpdef long ncols(self)
Expand Down
77 changes: 39 additions & 38 deletions src/flint/types/nmod_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ from flint.flintlib.nmod_mat cimport (
nmod_mat_randtest,
)

from flint.utils.conversion cimport matrix_to_str
from flint.utils.typecheck cimport typecheck
from flint.types.fmpz_mat cimport any_as_fmpz_mat
from flint.types.fmpz_mat cimport fmpz_mat
Expand All @@ -50,6 +49,8 @@ from flint.types.nmod_poly cimport nmod_poly
from flint.pyflint cimport global_random_state
from flint.flint_base.flint_context cimport thectx

from flint.flint_base.flint_base cimport flint_mat


ctx = thectx

Expand All @@ -69,10 +70,12 @@ cdef any_as_nmod_mat(obj, nmod_t mod):
return NotImplemented


cdef class nmod_mat:
cdef class nmod_mat(flint_mat):
"""
The nmod_mat type represents dense matrices over Z/nZ for
word-size n. Some operations may assume that n is a prime.
The nmod_mat type represents dense matrices over Z/nZ for word-size n (see
fmpz_mod_mat for larger moduli).
Some operations may assume that n is a prime.
"""

# cdef nmod_mat_t val
Expand Down Expand Up @@ -177,18 +180,6 @@ cdef class nmod_mat:
entries = ', '.join(map(str, self.entries()))
return f"nmod_mat({m}, {n}, [{entries}], {self.modulus()})"

def str(self):
return matrix_to_str(self.table())

def __str__(self):
return self.str()

def __repr__(self):
if ctx.pretty:
return self.str()
else:
return self.repr()

def entries(self):
cdef long i, j, m, n
cdef nmod t
Expand All @@ -202,13 +193,6 @@ cdef class nmod_mat:
L[i*n + j] = t
return L

def table(self):
cdef long i, m, n
m = self.nrows()
n = self.ncols()
L = self.entries()
return [L[i*n:(i+1)*n] for i in range(m)]

def __getitem__(self, index):
cdef long i, j
cdef nmod x
Expand All @@ -229,21 +213,6 @@ cdef class nmod_mat:
else:
raise TypeError("cannot set item of type %s" % type(value))

def det(self):
"""
Returns the determinant of self as an nmod.
>>> nmod_mat(2,2,[1,2,3,4],17).det()
15
"""
if not nmod_mat_is_square(self.val):
raise ValueError("matrix must be square")
return nmod(nmod_mat_det(self.val), self.modulus())

def rank(self):
return nmod_mat_rank(self.val)

def __pos__(self):
return self

Expand Down Expand Up @@ -391,7 +360,29 @@ cdef class nmod_mat:
def __div__(s, t):
return nmod_mat._div_(s, t)

def det(self):
"""
Returns the determinant of self as an nmod.
>>> nmod_mat(2,2,[1,2,3,4],17).det()
15
"""
if not nmod_mat_is_square(self.val):
raise ValueError("matrix must be square")
return nmod(nmod_mat_det(self.val), self.modulus())

def inv(self):
"""
Returns the inverse of self.
>>> from flint import nmod_mat
>>> A = nmod_mat(2,2,[1,2,3,4],17)
>>> A.inv()
[15, 1]
[10, 8]
"""
cdef nmod_mat u
if not nmod_mat_is_square(self.val):
raise ValueError("matrix must be square")
Expand Down Expand Up @@ -486,6 +477,16 @@ cdef class nmod_mat:
rank = nmod_mat_rref((<nmod_mat>res).val)
return res, rank

def rank(self):
"""Return the rank of a matrix.
>>> from flint import nmod_mat
>>> M = nmod_mat([[1, 2], [3, 4]], 11)
>>> M.rank()
2
"""
return nmod_mat_rank(self.val)

def nullspace(self):
"""
Computes a basis for the nullspace of self. Returns (X, nullity)
Expand Down
26 changes: 2 additions & 24 deletions src/flint/utils/conversion.pxd
Original file line number Diff line number Diff line change
@@ -1,40 +1,18 @@
from cpython.version cimport PY_MAJOR_VERSION

cdef inline long prec_to_dps(n):
return max(1, int(round(int(n)/3.3219280948873626)-1))

cdef inline long dps_to_prec(n):
return max(1, int(round((int(n)+1)*3.3219280948873626)))

cdef inline chars_from_str(s):
if PY_MAJOR_VERSION < 3:
return s
else:
return s.encode('ascii')
return s.encode('ascii')

cdef inline str_from_chars(s):
if PY_MAJOR_VERSION < 3:
return str(s)
else:
return bytes(s).decode('ascii')

cdef inline matrix_to_str(tab):
if len(tab) == 0 or len(tab[0]) == 0:
return "[]"
tab = [[str(c) for c in row] for row in tab]
widths = []
for i in xrange(len(tab[0])):
w = max([len(row[i]) for row in tab])
widths.append(w)
for i in xrange(len(tab)):
tab[i] = [s.rjust(widths[j]) for j, s in enumerate(tab[i])]
tab[i] = "[" + (", ".join(tab[i])) + "]"
return "\n".join(tab)
return bytes(s).decode('ascii')

cdef inline _str_trunc(s, trunc=0):
if trunc > 0 and len(s) > 3 * trunc:
left = right = trunc
omitted = len(s) - left - right
return s[:left] + ("{...%s digits...}" % omitted) + s[-right:]
return s

0 comments on commit 67b29cf

Please sign in to comment.