Skip to content

Commit

Permalink
Enforce type hints in daf.py, tests/test_dap.py
Browse files Browse the repository at this point in the history
One notable issue: The signature for `run_daf()` was previously

```
def run_daf(Daf,
            agg_param: Daf.AggParam,
            measurements: list[Daf.Measurement],
            nonces: list[bytes]):
```

We want to treat `Daf` as a generic parameter, but that's not how it's
interpreted by mypy. The first argument, `Daf`, shadows `class Daf`
above, but the type hints here actually refer to the shadowed variable.
To resolve this, simply remove the type hints.
  • Loading branch information
cjpatton committed Jun 11, 2024
1 parent b56f3b1 commit cf93535
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 37 deletions.
35 changes: 23 additions & 12 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -934,35 +934,46 @@ longer) if there was an Aggregator class which behaved like an actual aggregator
but with messages being sent by calling functions.
-->
~~~
def run_daf(Daf,
agg_param: Daf.AggParam,
measurements: list[Daf.Measurement],
nonces: list[bytes]):
out_shares = [[] for j in range(Daf.SHARES)]
def run_daf(daf,
agg_param,
measurements,
nonces):
"""
Run a DAF on a list of measurements.
Pre-conditions:
- `type(agg_param) == daf.AggParam`
- `type(measurement) == daf.Measurement` for each
`measurement` in `measurements`
- `len(nonce) == daf.NONCE_SIZE` for each `nonce` in `nonces`
- `len(nonces) == len(measurements)`
"""
out_shares = [[] for j in range(daf.SHARES)]
for (measurement, nonce) in zip(measurements, nonces):
# Each Client shards its measurement into input shares and
# distributes them among the Aggregators.
rand = gen_rand(Daf.RAND_SIZE)
rand = gen_rand(daf.RAND_SIZE)
(public_share, input_shares) = \
Daf.shard(measurement, nonce, rand)
daf.shard(measurement, nonce, rand)

# Each Aggregator prepares its input share for aggregation.
for j in range(Daf.SHARES):
for j in range(daf.SHARES):
out_shares[j].append(
Daf.prep(j, agg_param, nonce,
daf.prep(j, agg_param, nonce,
public_share, input_shares[j]))

# Each Aggregator aggregates its output shares into an aggregate
# share and sends it to the Collector.
agg_shares = []
for j in range(Daf.SHARES):
agg_share_j = Daf.aggregate(agg_param,
for j in range(daf.SHARES):
agg_share_j = daf.aggregate(agg_param,
out_shares[j])
agg_shares.append(agg_share_j)

# Collector unshards the aggregate result.
num_measurements = len(measurements)
agg_result = Daf.unshard(agg_param, agg_shares,
agg_result = daf.unshard(agg_param, agg_shares,
num_measurements)
return agg_result
~~~
Expand Down
54 changes: 29 additions & 25 deletions poc/daf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Definition of DAFs."""

from __future__ import annotations
from typing import Any

from common import gen_rand

Expand All @@ -9,37 +9,37 @@ class Daf:
"""A DAF"""

# Algorithm identifier for this DAF, a 32-bit integer.
ID: int = None
ID: int

# The number of Aggregators.
SHARES: int = None
SHARES: int

# Length of the nonce.
NONCE_SIZE = None
NONCE_SIZE: int

# Number of random bytes consumed by `shard()`.
RAND_SIZE = None
RAND_SIZE: int

# The measurement type.
Measurement = None
Measurement: Any = None

# The aggregation parameter type.
AggParam = None
AggParam: Any = None

# The public share type.
PublicShare = None
PublicShare: Any = None

# The input share type.
InputShare = None
InputShare: Any = None

# The output share type.
OutShare = None
OutShare: Any = None

# The aggregate share type.
AggShare = None
AggShare: Any = None

# The aggregate result type.
AggResult = None
AggResult: Any = None

@classmethod
def shard(Daf,
Expand Down Expand Up @@ -112,41 +112,45 @@ def unshard(Daf,
raise NotImplementedError()


def run_daf(Daf,
agg_param: Daf.AggParam,
measurements: list[Daf.Measurement],
nonces: list[bytes]):
def run_daf(daf,
agg_param,
measurements,
nonces):
"""
Run a DAF on a list of measurements.
Pre-conditions:
- `len(nonce) == Daf.NONCE_SIZE` for each `nonce` in `nonces`
- `type(agg_param) == daf.AggParam`
- `type(measurement) == daf.Measurement` for each
`measurement` in `measurements`
- `len(nonce) == daf.NONCE_SIZE` for each `nonce` in `nonces`
- `len(nonces) == len(measurements)`
"""
out_shares = [[] for j in range(Daf.SHARES)]
out_shares = [[] for j in range(daf.SHARES)]
for (measurement, nonce) in zip(measurements, nonces):
# Each Client shards its measurement into input shares and
# distributes them among the Aggregators.
rand = gen_rand(Daf.RAND_SIZE)
rand = gen_rand(daf.RAND_SIZE)
(public_share, input_shares) = \
Daf.shard(measurement, nonce, rand)
daf.shard(measurement, nonce, rand)

# Each Aggregator prepares its input share for aggregation.
for j in range(Daf.SHARES):
for j in range(daf.SHARES):
out_shares[j].append(
Daf.prep(j, agg_param, nonce,
daf.prep(j, agg_param, nonce,
public_share, input_shares[j]))

# Each Aggregator aggregates its output shares into an aggregate
# share and sends it to the Collector.
agg_shares = []
for j in range(Daf.SHARES):
agg_share_j = Daf.aggregate(agg_param,
for j in range(daf.SHARES):
agg_share_j = daf.aggregate(agg_param,
out_shares[j])
agg_shares.append(agg_share_j)

# Collector unshards the aggregate result.
num_measurements = len(measurements)
agg_result = Daf.unshard(agg_param, agg_shares,
agg_result = daf.unshard(agg_param, agg_shares,
num_measurements)
return agg_result

0 comments on commit cf93535

Please sign in to comment.