Skip to content

Commit

Permalink
Merge pull request #95 from fplll/subsolutions
Browse files Browse the repository at this point in the history
enable sub-solutions
  • Loading branch information
malb committed Aug 10, 2017
2 parents 15fe116 + 911cbd7 commit e32d7eb
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/fpylll/algorithms/bkz.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def svp_call(self, kappa, block_size, params, tracer=dummy_tracer):
try:
enum_obj = Enumeration(self.M)
with tracer.context("enumeration", enum_obj=enum_obj, probability=1.0):
solution, max_dist = enum_obj.enumerate(kappa, kappa + block_size, max_dist, expo)[0]
max_dist, solution = enum_obj.enumerate(kappa, kappa + block_size, max_dist, expo)[0]

except EnumerationError as msg:
if params.flags & BKZ.GH_BND:
Expand Down
2 changes: 1 addition & 1 deletion src/fpylll/algorithms/bkz2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def svp_reduction(self, kappa, block_size, params, tracer=dummy_tracer):
enum_obj=enum_obj,
probability=pruning.expectation,
full=block_size==params.block_size):
solution, max_dist = enum_obj.enumerate(kappa, kappa + block_size, radius, 0,
max_dist, solution = enum_obj.enumerate(kappa, kappa + block_size, radius, 0,
pruning=pruning.coefficients)[0]
with tracer.context("postprocessing"):
self.svp_postprocessing(kappa, block_size, solution, tracer=tracer)
Expand Down
2 changes: 1 addition & 1 deletion src/fpylll/algorithms/pbkz.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def parallel_svp_reduction_worker(self, kappa, block_size, params, rerandomize):
enum_obj=enum_obj,
probability=pruning.expectation,
full=block_size==params.block_size):
solution, max_dist = enum_obj.enumerate(kappa, kappa + block_size, radius, expo,
max_dist, solution = enum_obj.enumerate(kappa, kappa + block_size, radius, expo,
pruning=pruning.coefficients)[0]
with tracer.context("postprocessing"):
# we translate our solution to the canonical basis because our basis is not
Expand Down
2 changes: 1 addition & 1 deletion src/fpylll/algorithms/simple_bkz.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def svp_reduction(self, kappa, block_size):
max_dist, expo = self.m.get_r_exp(kappa, kappa)
delta_max_dist = self.lll_obj.delta * max_dist

solution, max_dist = Enumeration(self.m).enumerate(kappa, kappa + block_size, max_dist, expo, pruning=None)[0]
max_dist, solution = Enumeration(self.m).enumerate(kappa, kappa + block_size, max_dist, expo, pruning=None)[0]

if max_dist >= delta_max_dist * (1<<expo):
return clean
Expand Down
2 changes: 1 addition & 1 deletion src/fpylll/algorithms/simple_dbkz.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def dsvp_reduction(self, kappa, block_size):
expo *= -1.0
delta_max_dist = self.lll_obj.delta * max_dist

solution, max_dist = Enumeration(self.m).enumerate(kappa, kappa + block_size, max_dist, expo,
max_dist, solution = Enumeration(self.m).enumerate(kappa, kappa + block_size, max_dist, expo,
pruning=None, dual=True)[0]
if max_dist >= delta_max_dist:
return clean
Expand Down
110 changes: 75 additions & 35 deletions src/fpylll/fplll/enumeration.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ class EvaluatorStrategy:


cdef class Enumeration:
def __init__(self, MatGSO M, nr_solutions=1, strategy=EvaluatorStrategy.BEST_N_SOLUTIONS):
def __cinit__(self, MatGSO M, int nr_solutions=1,
strategy=EvaluatorStrategy.BEST_N_SOLUTIONS, bool sub_solutions=False):
"""Create new enumeration object
:param MatGSO M: GSO matrix
:param nr_solutions: Number of solutions to be returned by enumeration
:param strategy: EvaluatorStrategy to use when finding new solutions
:param MatGSO M: GSO matrix
:param nr_solutions: Number of solutions to be returned by enumeration
:param strategy: EvaluatorStrategy to use when finding new solutions
:param sub_solutions: Compute sub-solutions
"""

cdef MatGSO_c[Z_NR[mpz_t], FP_NR[d_t]] *m_mpz_d
Expand All @@ -89,45 +91,45 @@ cdef class Enumeration:
if M._type == gso_mpz_d:
m_mpz_d = M._core.mpz_d
self._fe_core.d = new FastEvaluator_c[FP_NR[double]](nr_solutions,
strategy,
False)
strategy,
sub_solutions)
self._core.mpz_d = new Enumeration_c[Z_NR[mpz_t], FP_NR[double]](m_mpz_d[0], self._fe_core.d[0])
elif M._type == gso_long_d:
m_l_d = M._core.long_d
self._fe_core.d = new FastEvaluator_c[FP_NR[double]](nr_solutions,
strategy,
False)
sub_solutions)
self._core.long_d = new Enumeration_c[Z_NR[long], FP_NR[double]](m_l_d[0], self._fe_core.d[0])
elif M._type == gso_mpz_ld:
IF HAVE_LONG_DOUBLE:
m_mpz_ld = M._core.mpz_ld
self._fe_core.ld = new FastEvaluator_c[FP_NR[longdouble]](nr_solutions,
strategy,
False)
strategy,
sub_solutions)
self._core.mpz_ld = new Enumeration_c[Z_NR[mpz_t], FP_NR[ld_t]](m_mpz_ld[0], self._fe_core.ld[0])
ELSE:
raise RuntimeError("MatGSO object '%s' has no core."%self)
elif M._type == gso_long_ld:
IF HAVE_LONG_DOUBLE:
m_l_ld = M._core.long_ld
self._fe_core.ld = new FastEvaluator_c[FP_NR[longdouble]](nr_solutions,
strategy,
False)
strategy,
sub_solutions)
self._core.long_ld = new Enumeration_c[Z_NR[long], FP_NR[ld_t]](m_l_ld[0], self._fe_core.ld[0])
ELSE:
raise RuntimeError("MatGSO object '%s' has no core."%self)
elif M._type == gso_mpz_dpe:
m_mpz_dpe = M._core.mpz_dpe
self._fe_core.dpe = new FastEvaluator_c[FP_NR[dpe_t]](nr_solutions,
strategy,
False)
sub_solutions)
self._core.mpz_dpe = new Enumeration_c[Z_NR[mpz_t], FP_NR[dpe_t]](m_mpz_dpe[0], self._fe_core.dpe[0])
elif M._type == gso_long_dpe:
m_long_dpe = M._core.long_dpe
m_l_dpe = M._core.long_dpe
self._fe_core.dpe = new FastEvaluator_c[FP_NR[dpe_t]](nr_solutions,
strategy,
False)
self._core.long_dpe = new Enumeration_c[Z_NR[long], FP_NR[dpe_t]](m_long_dpe[0], self._fe_core.dpe[0])
sub_solutions)
self._core.long_dpe = new Enumeration_c[Z_NR[long], FP_NR[dpe_t]](m_l_dpe[0], self._fe_core.dpe[0])
elif M._type == gso_mpz_mpfr:
m_mpz_mpfr = M._core.mpz_mpfr
self._fe_core.mpfr = new FastErrorBoundedEvaluator_c(M.d,
Expand All @@ -136,44 +138,44 @@ cdef class Enumeration:
EVALMODE_SV,
nr_solutions,
strategy,
False)
sub_solutions)
self._core.mpz_mpfr = new Enumeration_c[Z_NR[mpz_t], FP_NR[mpfr_t]](m_mpz_mpfr[0], self._fe_core.mpfr[0])
elif M._type == gso_long_mpfr:
m_long_mpfr = M._core.long_mpfr
m_l_mpfr = M._core.long_mpfr
self._fe_core.mpfr = new FastErrorBoundedEvaluator_c(M.d,
M._core.long_mpfr.get_mu_matrix(),
M._core.long_mpfr.get_r_matrix(),
EVALMODE_SV,
nr_solutions,
strategy,
False)
self._core.long_mpfr = new Enumeration_c[Z_NR[long], FP_NR[mpfr_t]](m_long_mpfr[0], self._fe_core.mpfr[0])
sub_solutions)
self._core.long_mpfr = new Enumeration_c[Z_NR[long], FP_NR[mpfr_t]](m_l_mpfr[0], self._fe_core.mpfr[0])
else:
IF HAVE_QD:
if M._type == gso_mpz_dd:
m_mpz_dd = M._core.mpz_dd
self._fe_core.dd = new FastEvaluator_c[FP_NR[dd_t]](nr_solutions,
strategy,
False)
sub_solutions)
self._core.mpz_dd = new Enumeration_c[Z_NR[mpz_t], FP_NR[dd_t]](m_mpz_dd[0], self._fe_core.dd[0])
elif M._type == gso_mpz_qd:
m_mpz_qd = M._core.mpz_qd
self._fe_core.qd = new FastEvaluator_c[FP_NR[qd_t]](nr_solutions,
strategy,
False)
strategy,
sub_solutions)
self._core.mpz_qd = new Enumeration_c[Z_NR[mpz_t], FP_NR[qd_t]](m_mpz_qd[0], self._fe_core.qd[0])
elif M._type == gso_long_dd:
m_long_dd = M._core.long_dd
m_l_dd = M._core.long_dd
self._fe_core.dd = new FastEvaluator_c[FP_NR[dd_t]](nr_solutions,
strategy,
False)
self._core.long_dd = new Enumeration_c[Z_NR[long], FP_NR[dd_t]](m_long_dd[0], self._fe_core.dd[0])
sub_solutions)
self._core.long_dd = new Enumeration_c[Z_NR[long], FP_NR[dd_t]](m_l_dd[0], self._fe_core.dd[0])
elif M._type == gso_long_qd:
m_long_qd = M._core.long_qd
m_l_qd = M._core.long_qd
self._fe_core.qd = new FastEvaluator_c[FP_NR[qd_t]](nr_solutions,
strategy,
False)
self._core.long_qd = new Enumeration_c[Z_NR[long], FP_NR[qd_t]](m_long_qd[0], self._fe_core.qd[0])
strategy,
sub_solutions)
self._core.long_qd = new Enumeration_c[Z_NR[long], FP_NR[qd_t]](m_l_qd[0], self._fe_core.qd[0])
else:
raise RuntimeError("MatGSO object '%s' has no core."%self)
ELSE:
Expand Down Expand Up @@ -306,7 +308,7 @@ cdef class Enumeration:
cur_sol = []
for j in range(deref(solutions_d).second.size()):
cur_sol.append(deref(solutions_d).second[j].get_d())
solutions.append([tuple(cur_sol), cur_dist])
solutions.append((cur_dist, tuple(cur_sol)))
inc(solutions_d)

IF HAVE_LONG_DOUBLE:
Expand All @@ -332,7 +334,7 @@ cdef class Enumeration:
cur_sol = []
for j in range(deref(solutions_ld).second.size()):
cur_sol.append(deref(solutions_ld).second[j].get_d())
solutions.append([tuple(cur_sol), cur_dist])
solutions.append((cur_dist, tuple(cur_sol)))
inc(solutions_ld)

if self.M._type == gso_mpz_dpe:
Expand All @@ -357,7 +359,7 @@ cdef class Enumeration:
cur_sol = []
for j in range(deref(solutions_dpe).second.size()):
cur_sol.append(deref(solutions_dpe).second[j].get_d())
solutions.append([tuple(cur_sol), cur_dist])
solutions.append((cur_dist, tuple(cur_sol)))
inc(solutions_dpe)

IF HAVE_QD:
Expand All @@ -383,7 +385,7 @@ cdef class Enumeration:
cur_sol = []
for j in range(deref(solutions_dd).second.size()):
cur_sol.append(deref(solutions_dd).second[j].get_d())
solutions.append([tuple(cur_sol), cur_dist])
solutions.append((cur_dist, tuple(cur_sol)))
inc(solutions_dd)

if self.M._type == gso_mpz_qd:
Expand All @@ -408,7 +410,7 @@ cdef class Enumeration:
cur_sol = []
for j in range(deref(solutions_qd).second.size()):
cur_sol.append(deref(solutions_qd).second[j].get_d())
solutions.append([tuple(cur_sol), cur_dist])
solutions.append((cur_dist, tuple(cur_sol)))
inc(solutions_qd)

if self.M._type == gso_mpz_mpfr:
Expand All @@ -433,11 +435,49 @@ cdef class Enumeration:
cur_sol = []
for j in range(deref(solutions_mpfr).second.size()):
cur_sol.append(deref(solutions_mpfr).second[j].get_d())
solutions.append([tuple(cur_sol), cur_dist])
solutions.append((cur_dist, tuple(cur_sol)))
inc(solutions_mpfr)

return solutions

@property
def sub_solutions(self):
"""
Return sub-solutions computed in last enumeration call.
>>> from fpylll import *
>>> set_random_seed(1337)
>>> A = IntegerMatrix.random(80, "qary", bits=30, k=40)
>>> _ = LLL.reduction(A)
>>> M = GSO.Mat(A)
>>> _ = M.update_gso()
>>> pruning = Pruning.run(M.get_r(0, 0), 2**40, M.r()[:30], 0.2)
>>> enum = Enumeration(M, strategy=EvaluatorStrategy.BEST_N_SOLUTIONS, sub_solutions=True)
>>> _ = enum.enumerate(0, 30, 0.999*M.get_r(0, 0), 0, pruning=pruning.coefficients)
>>> [int(a) for a,b in enum.sub_solutions[:5]]
[5569754193, 5556022461, 5083806188, 5022873439, 4260865082]
"""
cdef list sub_solutions = []

cdef vector[pair[FP_NR[d_t], vector[FP_NR[d_t]]]].iterator _sub_solutions_d

if self.M._type == gso_mpz_d or self.M._type == gso_long_d:
_sub_solutions_d = self._fe_core.d.sub_solutions.begin()
while _sub_solutions_d != self._fe_core.d.sub_solutions.end():
cur_dist = deref(_sub_solutions_d).first.get_d()
if cur_dist == 0.0:
cur_dist = None
cur_sol = []
for j in range(deref(_sub_solutions_d).second.size()):
cur_sol.append(deref(_sub_solutions_d).second[j].get_d())
sub_solutions.append(tuple([cur_dist, tuple(cur_sol)]))
inc(_sub_solutions_d)
else:
raise NotImplementedError

return tuple(sub_solutions)

def get_nodes(self):
"""Return number of visited nodes in last enumeration call.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/fpylll/fplll/gso.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ cdef class MatGSO:
>>> M.update_gso()
True
>>> M.get_r(1, 0)
833.0
890.0
"""
preprocess_indices(i, j, self.d, self.d)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_cvp():
v0 = CVP.closest_vector(A, t)

E = Enumeration(M)
v1, _ = E.enumerate(0, A.nrows, 2, 40, M.from_canonical(t))[0]
_, v1 = E.enumerate(0, A.nrows, 2, 40, M.from_canonical(t))[0]
v1 = IntegerMatrix.from_iterable(1, A.nrows, map(lambda x: int(round(x)), v1))
v1 = tuple((v1*A)[0])

Expand Down
4 changes: 2 additions & 2 deletions tests/test_multisol_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_multisol():
solutions = []
solutions = Enumeration(m, nr_solutions=200).enumerate(0, 27, 48.5, 0)
assert len(solutions)== 126 / 2
for sol, _ in solutions:
for _, sol in solutions:
sol = IntegerMatrix.from_iterable(1, A.nrows, map(lambda x: int(round(x)), sol))
sol = tuple((sol*A)[0])
dist = sum([x**2 for x in sol])
Expand All @@ -66,7 +66,7 @@ def test_multisol():
solutions = []
solutions = Enumeration(m, nr_solutions=126 / 2).enumerate(0, 27, 100., 0)
assert len(solutions)== 126 / 2
for sol, _ in solutions:
for _, sol in solutions:
sol = IntegerMatrix.from_iterable(1, A.nrows, map(lambda x: int(round(x)), sol))
sol = tuple((sol*A)[0])
dist = sum([x**2 for x in sol])
Expand Down

0 comments on commit e32d7eb

Please sign in to comment.