Merge pull request #100 from fplll/compare
Compare BKZ variants
malb committed Feb 11, 2018
2 parents 5fa8de2 + 8077e04 commit 013f25b
Showing 3 changed files with 495 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/fpylll/algorithms/bkz.py
@@ -18,7 +18,7 @@
from fpylll.tools.bkz_stats import BKZTreeTracer, dummy_tracer


class BKZReduction:
class BKZReduction(object):
"""
An implementation of the BKZ algorithm in Python.
103 changes: 55 additions & 48 deletions src/fpylll/tools/bkz_stats.py
@@ -63,46 +63,46 @@ def pretty_dict(d, keyword_width=None, round_bound=9999):
return u"{" + u", ".join(s) + u"}"


class Statistic(object):
class Accumulator(object):
"""
A ``statistic`` collects observed facts about some random variable (e.g. running time).
An ``Accumulator`` collects observed facts about some random variable (e.g. running time).
In particular,
- minimum,
- maximum,
- mean and
- variance
- minimum,
are stored.
- maximum,
>>> v = Statistic(1.0); v
1.0
- mean and
>>> v += 2.0; v
3.0
- variance
>>> v = Statistic(-5.4, repr="avg"); v
-5.4
are recorded::
>>> v += 0.2
>>> v += 5.2; v
0.0
>>> v.min, v.max
(-5.4, 5.2)
>>> v = Accumulator(1.0); v
1.0
>>> v += 2.0; v
3.0
>>> v = Accumulator(-5.4, repr="avg"); v
-5.4
>>> v += 0.2
>>> v += 5.2; v
0.0
>>> v.min, v.max
(-5.4, 5.2)
"""

def __init__(self, value, repr="sum", count=True):
def __init__(self, value, repr="sum", count=True, bessel_correction=False):
"""
Create a new instance.
:param value: some initial value
:param repr: how to represent this statistic: "min", "max", "avg", "sum" and "variance" are
:param repr: how to represent the data: "min", "max", "avg", "sum" and "variance" are
valid choices
:param count: if ``True`` the provided value is considered as an observed datum, i.e. the
counter is increased by one.
:param bessel_correction: apply Bessel's correction to the variance.
"""

self._min = value
@@ -111,12 +111,13 @@ def __init__(self, value, repr="sum", count=True):
self._sqr = value*value
self._ctr = 1 if count else 0
self._repr = repr
self._bessel_correction = bessel_correction

def add(self, value):
"""
Add value to this statistic.
Add value to the accumulator.
>>> v = Statistic(10.0)
>>> v = Accumulator(10.0)
>>> v.add(5.0)
15.0
@@ -134,7 +135,7 @@ def add(self, value):
@property
def min(self):
"""
>>> v = Statistic(2.0)
>>> v = Accumulator(2.0)
>>> v += 5.0
>>> v.min
2.0
@@ -145,7 +146,7 @@ def min(self):
@property
def max(self):
"""
>>> v = Statistic(2.0)
>>> v = Accumulator(2.0)
>>> v += 5.0
>>> v.max
5.0
@@ -156,18 +157,20 @@ def max(self):
@property
def avg(self):
"""
>>> v = Statistic(2.0)
>>> v = Accumulator(2.0)
>>> v += 5.0
>>> v.avg
3.5
"""
return self._sum/self._ctr

mean = avg

@property
def sum(self):
"""
>>> v = Statistic(2.0)
>>> v = Accumulator(2.0)
>>> v += 5.0
>>> v.sum
7.0
@@ -178,13 +181,17 @@ def sum(self):
@property
def variance(self):
"""
>>> v = Statistic(2.0)
>>> v = Accumulator(2.0)
>>> v += 5.0
>>> v.variance
2.25
"""
return self._sqr/self._ctr - self.avg**2
s = self._sqr/self._ctr - self.avg**2
if self._bessel_correction:
return self._ctr * (s/(self._ctr-1))
else:
return s
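For illustration (not part of the diff): with the two observations from the doctests above, 2.0 and 5.0, the biased variance is 2.25, and Bessel's correction rescales it by n/(n-1) = 2, giving 4.5. A minimal doctest-style check of the new flag, assuming the constructor signature added in this hunk:

>>> v = Accumulator(2.0, repr="variance", bessel_correction=True)
>>> v += 5.0
>>> v.variance  # 2/(2-1) * ((2.0**2 + 5.0**2)/2 - 3.5**2) = 2 * 2.25
4.5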

def __add__(self, other):
"""
@@ -194,7 +201,7 @@ def __add__(self, other):
- ``stat + stat`` returns the sum of their underlying values
- ``stat + value`` inserts ``value`` into ``stat``
>>> v = Statistic(2.0)
>>> v = Accumulator(2.0)
>>> v + None
2.0
>>> v + v
@@ -205,13 +212,13 @@
"""
if other is None:
return copy.copy(self)
elif not isinstance(other, Statistic):
elif not isinstance(other, Accumulator):
ret = copy.copy(self)
return ret.add(other)
else:
if self._repr != other._repr:
raise ValueError("%s != %s"%(self._repr, other._repr))
ret = Statistic(0)
ret = Accumulator(0)
ret._min = min(self.min, other.min)
ret._max = max(self.max, other.max)
ret._sum = self._sum + other._sum
@@ -236,15 +243,15 @@ def __float__(self):
"""
Reduce this stats object down to a float depending on the strategy chosen in the constructor.
>>> v = Statistic(2.0, "min"); v += 3.0; float(v)
>>> v = Accumulator(2.0, "min"); v += 3.0; float(v)
2.0
>>> v = Statistic(2.0, "max"); v += 3.0; float(v)
>>> v = Accumulator(2.0, "max"); v += 3.0; float(v)
3.0
>>> v = Statistic(2.0, "avg"); v += 3.0; float(v)
>>> v = Accumulator(2.0, "avg"); v += 3.0; float(v)
2.5
>>> v = Statistic(2.0, "sum"); v += 3.0; float(v)
>>> v = Accumulator(2.0, "sum"); v += 3.0; float(v)
5.0
>>> v = Statistic(2.0, "variance"); v += 3.0; float(v)
>>> v = Accumulator(2.0, "variance"); v += 3.0; float(v)
0.25
"""
return float(self.__getattribute__(self._repr))
@@ -253,15 +260,15 @@ def __str__(self):
"""
Reduce this stats object down to a string depending on the strategy chosen in the constructor.
>>> v = Statistic(2.0, "min"); v += 3.0; str(v)
>>> v = Accumulator(2.0, "min"); v += 3.0; str(v)
'2.0'
>>> v = Statistic(2.0, "max"); v += 3.0; str(v)
>>> v = Accumulator(2.0, "max"); v += 3.0; str(v)
'3.0'
>>> v = Statistic(2.0, "avg"); v += 3.0; str(v)
>>> v = Accumulator(2.0, "avg"); v += 3.0; str(v)
'2.5'
>>> v = Statistic(2.0, "sum"); v += 3.0; str(v)
>>> v = Accumulator(2.0, "sum"); v += 3.0; str(v)
'5.0'
>>> v = Statistic(2.0, "variance"); v += 3.0; str(v)
>>> v = Accumulator(2.0, "variance"); v += 3.0; str(v)
'0.25'
"""
return str(self.__getattribute__(self._repr))
@@ -707,8 +714,8 @@ def reenter(self, **kwds):
"""

node = self.current
node.data["cputime"] = node.data.get("cputime", 0) + Statistic(-time.clock(), repr="sum", count=False)
node.data["walltime"] = node.data.get("walltime", 0) + Statistic(-time.time(), repr="sum", count=False)
node.data["cputime"] = node.data.get("cputime", 0) + Accumulator(-time.clock(), repr="sum", count=False)
node.data["walltime"] = node.data.get("walltime", 0) + Accumulator(-time.time(), repr="sum", count=False)

def exit(self, **kwds):
"""
@@ -724,19 +731,19 @@ def exit(self, **kwds):
if label == "enumeration":
full = kwds.get("full", True)
if full:
node.data["#enum"] = Statistic(kwds["enum_obj"].get_nodes(), repr="sum") + node.data.get("#enum", None)
node.data["#enum"] = Accumulator(kwds["enum_obj"].get_nodes(), repr="sum") + node.data.get("#enum", None) # noqa
try:
node.data["%"] = Statistic(kwds["probability"], repr="avg") + node.data.get("%", None)
node.data["%"] = Accumulator(kwds["probability"], repr="avg") + node.data.get("%", None)
except KeyError:
pass
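The ``+ node.data.get("#enum", None)`` idiom works because, as the ``__add__`` doctest above shows, adding ``None`` simply returns a copy of the accumulator, so the same line handles both the first enumeration in a node and later merges. A sketch with made-up node counts:

from fpylll.tools.bkz_stats import Accumulator

nodes = Accumulator(1024, repr="sum") + None  # first call: no earlier "#enum", None -> copy
nodes = Accumulator(512, repr="sum") + nodes  # later calls: sums are merged (per ret._sum above)
float(nodes)                                  # 1536.0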

if label[0] == "tour":
data = basis_quality(self.instance.M)
for k, v in data.items():
if k == "/":
node.data[k] = Statistic(v, repr="max")
node.data[k] = Accumulator(v, repr="max")
else:
node.data[k] = Statistic(v, repr="min")
node.data[k] = Accumulator(v, repr="min")

if self.verbosity and label[0] == "tour":
report = OrderedDict()
