Skip to content

Commit

Permalink
Fix up beam registrations
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCutter committed Sep 7, 2023
1 parent e356d6e commit a3dbebe
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
14 changes: 10 additions & 4 deletions binary_transparency/firmware/internal/ftmap/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (

func init() {
beam.RegisterFunction(aggregationFn)
beam.RegisterFunction(annotationLogIndexFn)
beam.RegisterFunction(logEntryIndexFn)
beam.RegisterType(reflect.TypeOf((*api.AggregatedFirmware)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*aggregatedFirmwareHashFn)(nil)).Elem())
}
Expand All @@ -40,14 +42,18 @@ func init() {
// - AnnotationMalware: `Good` is true providing there are no malware annotations that claim the
// firmware is bad.
func Aggregate(s beam.Scope, treeID int64, fws, annotationMalwares beam.PCollection) (beam.PCollection, beam.PCollection) {
keyedFws := beam.ParDo(s, func(l *firmwareLogEntry) (uint64, *firmwareLogEntry) { return uint64(l.Index), l }, fws)
keyedAnns := beam.ParDo(s, func(a *annotationMalwareLogEntry) (uint64, *annotationMalwareLogEntry) {
return a.Annotation.FirmwareID.LogIndex, a
}, annotationMalwares)
keyedFws := beam.ParDo(s, logEntryIndexFn, fws)
keyedAnns := beam.ParDo(s, annotationLogIndexFn, annotationMalwares)
annotations := beam.ParDo(s, aggregationFn, beam.CoGroupByKey(s, keyedFws, keyedAnns))
return beam.ParDo(s, &aggregatedFirmwareHashFn{treeID}, annotations), annotations
}

func logEntryIndexFn(l *firmwareLogEntry) (uint64, *firmwareLogEntry) { return uint64(l.Index), l }

func annotationLogIndexFn(a *annotationMalwareLogEntry) (uint64, *annotationMalwareLogEntry) {
return a.Annotation.FirmwareID.LogIndex, a
}

func aggregationFn(fwIndex uint64, fwit func(**firmwareLogEntry) bool, amit func(**annotationMalwareLogEntry) bool) (*api.AggregatedFirmware, error) {
// There will be exactly one firmware entry for the log index.
var fwle *firmwareLogEntry
Expand Down
16 changes: 14 additions & 2 deletions binary_transparency/firmware/internal/ftmap/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,20 @@ import (
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
"github.com/google/trillian-examples/binary_transparency/firmware/api"
)

func init() {
register.Function1x1(testAggregationToStringFn)
}

func TestMain(m *testing.M) {
ptest.Main(m)
}

func TestAggregate(t *testing.T) {
fwEntries := []*firmwareLogEntry{
{Index: 0, Firmware: createFW("dummy", 400)},
Expand Down Expand Up @@ -104,8 +113,7 @@ func TestAggregate(t *testing.T) {
passert.Count(s, entries, "entries", len(fwEntries))
passert.Count(s, aggs, "aggs", len(fwEntries))

aggregationToString := func(a *api.AggregatedFirmware) string { return fmt.Sprintf("%d: %t", a.Index, a.Good) }
passert.Equals(s, beam.ParDo(s, aggregationToString, aggs), beam.CreateList(s, test.wantGood))
passert.Equals(s, beam.ParDo(s, testAggregationToStringFn, aggs), beam.CreateList(s, test.wantGood))

err := ptest.Run(p)
if err != nil {
Expand All @@ -114,3 +122,7 @@ func TestAggregate(t *testing.T) {
})
}
}

func testAggregationToStringFn(a *api.AggregatedFirmware) string {
return fmt.Sprintf("%d: %t", a.Index, a.Good)
}
7 changes: 6 additions & 1 deletion binary_transparency/firmware/internal/ftmap/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
)

func init() {
beam.RegisterFunction(logEntryDeviceIDFn)
beam.RegisterFunction(makeDeviceReleaseLogFn)
beam.RegisterType(reflect.TypeOf((*moduleLogHashFn)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*api.DeviceReleaseLog)(nil)).Elem())
Expand All @@ -44,11 +45,15 @@ func init() {
// 1. the first is of type Entry; the key/value data to include in the map
// 2. the second is of type DeviceReleaseLog.
func MakeReleaseLogs(s beam.Scope, treeID int64, logEntries beam.PCollection) (beam.PCollection, beam.PCollection) {
keyed := beam.ParDo(s, func(l *firmwareLogEntry) (string, *firmwareLogEntry) { return l.Firmware.DeviceID, l }, logEntries)
keyed := beam.ParDo(s, logEntryDeviceIDFn, logEntries)
logs := beam.ParDo(s, makeDeviceReleaseLogFn, beam.GroupByKey(s, keyed))
return beam.ParDo(s, &moduleLogHashFn{TreeID: treeID}, logs), logs
}

func logEntryDeviceIDFn(l *firmwareLogEntry) (string, *firmwareLogEntry) {
return l.Firmware.DeviceID, l
}

type moduleLogHashFn struct {
TreeID int64

Expand Down
16 changes: 12 additions & 4 deletions binary_transparency/firmware/internal/ftmap/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ import (
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
"github.com/google/trillian-examples/binary_transparency/firmware/api"
"github.com/google/trillian/experimental/batchmap"
)

func init() {
register.Function1x1(testLogToStringFn)
register.Function1x1(testRootToStringFn)
}

func TestCreate(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -74,11 +80,9 @@ func TestCreate(t *testing.T) {
t.Fatalf("failed to Create(): %v", err)
}

rootToString := func(t *batchmap.Tile) string { return fmt.Sprintf("%x", t.RootHash) }
passert.Equals(s, beam.ParDo(s, rootToString, result.MapTiles), test.wantRoot)
passert.Equals(s, beam.ParDo(s, testRootToStringFn, result.MapTiles), test.wantRoot)

logToString := func(l *api.DeviceReleaseLog) string { return fmt.Sprintf("%s: %v", l.DeviceID, l.Revisions) }
passert.Equals(s, beam.ParDo(s, logToString, result.DeviceLogs), beam.CreateList(s, test.wantLogs))
passert.Equals(s, beam.ParDo(s, testLogToStringFn, result.DeviceLogs), beam.CreateList(s, test.wantLogs))

err = ptest.Run(p)
if err != nil {
Expand All @@ -87,6 +91,10 @@ func TestCreate(t *testing.T) {
})
}
}
func testRootToStringFn(t *batchmap.Tile) string { return fmt.Sprintf("%x", t.RootHash) }
func testLogToStringFn(l *api.DeviceReleaseLog) string {
return fmt.Sprintf("%s: %v", l.DeviceID, l.Revisions)
}

func createFW(device string, revision uint64) api.FirmwareMetadata {
image := fmt.Sprintf("this image is the firmware at revision %d for device %s.", revision, device)
Expand Down

0 comments on commit a3dbebe

Please sign in to comment.