Skip to content

Commit

Permalink
Enforce type hints in flp.py, flp_generic.py and tests
Browse files Browse the repository at this point in the history
This revealed that the "associated type" pattern we're
following is not enforceable.

For example, we currently do something like this:

```
class Foo:
    # We expect this to be set by the superclass.
    Measurement = None
    def foo(self, bar: Measurement): pass
```

Running `sage -python -m mypy hella.py`:

```
hella.py:4: error: Variable "hella.Foo.Measurement" is not valid as a type  [valid-type]
hella.py:4: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
Found 1 error in 1 file (checked 1 source file)
```

So let's make `Measurement` a type:
  • Loading branch information
cjpatton committed Jun 8, 2024
1 parent 3ac0362 commit 43f14f4
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 272 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ jobs:

- name: Enforce type hints
working-directory: poc
run: sage -python -m mypy xof.py field.py
# TODO(#59) sage -python -m mypy *.py tests/*.py
run: sage -python -m mypy xof.py field.py flp.py flp_generic.py tests/test_xof.py tests/test_field.py tests/test_flp.py tests/test_flp_generic.py
68 changes: 40 additions & 28 deletions poc/flp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
"""Fully linear proof (FLP) systems."""

from typing import TypeVar, Generic

import field
from common import vec_add, vec_sub
from field import Field


class Flp:
"""The base class for FLPs."""
M = TypeVar('M')
A = TypeVar('A')
F = TypeVar('F', bound=field.Field)


class Flp(Generic[M, A, F]):
"""
The base class for FLPs.
# Generic paraemters
Measurement = None
AggResult = None
Field: field.Field = None
Generic parameters:
- `M` is the measurement
- `A` is the aggregate result
- `F` is the field
"""

# Length of the joint randomness shared by the prover and verifier.
JOINT_RAND_LEN: int
Expand All @@ -34,14 +43,17 @@ class Flp:
# Length of the verifier message.
VERIFIER_LEN: int

def encode(self, measurement: Measurement) -> list[Field]:
# Operational parameters
field: type[F]

def encode(self, measurement: M) -> list[F]:
"""Encode a measurement."""
raise NotImplementedError()

def prove(self,
meas: list[Field],
prove_rand: list[Field],
joint_rand: list[Field]) -> list[Field]:
meas: list[F],
prove_rand: list[F],
joint_rand: list[F]) -> list[F]:
"""
Generate a proof of a measurement's validity.
Expand All @@ -54,11 +66,11 @@ def prove(self,
raise NotImplementedError()

def query(self,
meas: list[Field],
proof: list[Field],
query_rand: list[Field],
joint_rand: list[Field],
num_shares: int) -> list[Field]:
meas: list[F],
proof: list[F],
query_rand: list[F],
joint_rand: list[F],
num_shares: int) -> list[F]:
"""
Generate a verifier message for a measurement and proof.
Expand All @@ -72,7 +84,7 @@ def query(self,
"""
raise NotImplementedError()

def decide(self, verifier: list[Field]) -> bool:
def decide(self, verifier: list[F]) -> bool:
"""
Decide if a verifier message was generated from a valid measurement.
Expand All @@ -82,7 +94,7 @@ def decide(self, verifier: list[Field]) -> bool:
"""
raise NotImplementedError()

def truncate(self, meas: list[Field]) -> list[Field]:
def truncate(self, meas: list[F]) -> list[F]:
"""
Map an encoded measurement to an aggregatable output.
Expand All @@ -92,7 +104,7 @@ def truncate(self, meas: list[Field]) -> list[Field]:
"""
raise NotImplementedError()

def decode(self, output: list[Field], num_measurements: int) -> AggResult:
def decode(self, output: list[F], num_measurements: int) -> A:
"""
Decode an aggregate result.
Expand All @@ -111,9 +123,9 @@ def test_vec_set_type_param(self, test_vec) -> list[str]:
return []


def additive_secret_share(vec: list[Field],
def additive_secret_share(vec: list[F],
num_shares: int,
field: type) -> list[list[Field]]:
field: type[F]) -> list[list[F]]:
shares = [
field.rand_vec(len(vec))
for _ in range(num_shares - 1)
Expand All @@ -126,19 +138,19 @@ def additive_secret_share(vec: list[Field],


# NOTE This is used to generate {{run-flp}}.
def run_flp(flp, meas, num_shares):
def run_flp(flp: Flp[M, A, F], meas: list[F], num_shares: int):
"""Run the FLP on an encoded measurement."""

joint_rand = flp.Field.rand_vec(flp.JOINT_RAND_LEN)
prove_rand = flp.Field.rand_vec(flp.PROVE_RAND_LEN)
query_rand = flp.Field.rand_vec(flp.QUERY_RAND_LEN)
joint_rand = flp.field.rand_vec(flp.JOINT_RAND_LEN)
prove_rand = flp.field.rand_vec(flp.PROVE_RAND_LEN)
query_rand = flp.field.rand_vec(flp.QUERY_RAND_LEN)

# Prover generates the proof.
proof = flp.prove(meas, prove_rand, joint_rand)

# Shard the measurement and the proof.
meas_shares = additive_secret_share(meas, num_shares, flp.Field)
proof_shares = additive_secret_share(proof, num_shares, flp.Field)
meas_shares = additive_secret_share(meas, num_shares, flp.field)
proof_shares = additive_secret_share(proof, num_shares, flp.field)

# Verifier queries the meas shares and proof shares.
verifier_shares = [
Expand All @@ -153,7 +165,7 @@ def run_flp(flp, meas, num_shares):
]

# Combine the verifier shares into the verifier.
verifier = flp.Field.zeros(len(verifier_shares[0]))
verifier = flp.field.zeros(len(verifier_shares[0]))
for verifier_share in verifier_shares:
verifier = vec_add(verifier, verifier_share)

Expand Down
Loading

0 comments on commit 43f14f4

Please sign in to comment.