Skip to content

Commit

Permalink
fix: add non-scalar metric expectation to protobufs [DET-4893] [DET-4911
Browse files Browse the repository at this point in the history
] (#1876)
  • Loading branch information
hamidzr authored Feb 3, 2021
1 parent 7f7aa2e commit 48418e6
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 3 deletions.
13 changes: 13 additions & 0 deletions docs/release-notes/1876-unexpected-non-scalar-metrics.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
:orphan:

**Bug Fixes**

- API: Fix an issue where requesting checkpoint or trial details of a
trial that had non-scalar metric values associated with it would
fail.

**Breaking Change**

- Metric values returned by the trial and checkpoint APIs can now
return non-float values and are defined using
https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#google.protobuf.Struct
82 changes: 82 additions & 0 deletions master/test/integration/api/api_trials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/determined-ai/determined/master/internal"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/elastic"

"github.com/determined-ai/determined/master/test/testutils"
Expand Down Expand Up @@ -49,6 +50,87 @@ func TestTrialLogAPIElastic(t *testing.T) {
})
}

func TestTrialDetail(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
_, _, cl, creds, err := testutils.RunMaster(ctx, nil)
defer cancel()
assert.NilError(t, err, "failed to start master")

trialDetailAPITests(t, creds, cl, pgDB)
}

func trialDetailAPITests(
t *testing.T, creds context.Context, cl apiv1.DeterminedClient, db *db.PgDB,
) {
type testCase struct {
name string
req *apiv1.GetTrialRequest
metrics map[string]interface{}
}

testCases := []testCase{
{
name: "scalar metric",
metrics: map[string]interface{}{
"myMetricName": 3,
},
},
{
name: "boolean metric",
metrics: map[string]interface{}{
"myMetricName": true,
},
},
{
name: "null metric",
metrics: map[string]interface{}{
"myMetricName": nil,
},
},
{
name: "list of scalars metric",
metrics: map[string]interface{}{
"myMetricName": []float32{1, 2.3, 3},
},
},
}

runTestCase := func(t *testing.T, tc testCase, id int) {
t.Run(tc.name, func(t *testing.T) {
experiment := testutils.ExperimentModel()
err := db.AddExperiment(experiment)
assert.NilError(t, err, "failed to insert experiment")

trial := testutils.TrialModel(experiment.ID, testutils.WithTrialState(model.ActiveState))
err = db.AddTrial(trial)
assert.NilError(t, err, "failed to insert trial")

step := testutils.StepModel(trial.ID)
step.ID = id
err = db.AddStep(step)
assert.NilError(t, err, "failed to insert step")

metrics := map[string]interface{}{
"avg_metrics": tc.metrics,
}
err = db.UpdateStep(trial.ID, step.ID, model.CompletedState, metrics)
assert.NilError(t, err, "failed to update step")

ctx, _ := context.WithTimeout(creds, 10*time.Second)
req := apiv1.GetTrialRequest{TrialId: int32(trial.ID)}

tlCl, err := cl.GetTrial(ctx, &req)
assert.NilError(t, err, "failed to fetch api details")
assert.Equal(t, len(tlCl.Workloads), 1, "mismatching workload length")
})
}

for idx, tc := range testCases {
runTestCase(t, tc, idx)
}

}

func trialLogAPITests(
t *testing.T, creds context.Context, cl apiv1.DeterminedClient, backend internal.TrialLogBackend,
awaitBackend func() error,
Expand Down
18 changes: 18 additions & 0 deletions master/test/testutils/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,21 @@ func TrialModel(eID int, opts ...TrialModelOption) *model.Trial {
}
return t
}

// StepModelOption is an option that can be applied to a step.
type StepModelOption interface {
apply(*model.Step)
}

// StepModel returns a new step with the specified options.
func StepModel(tID int, opts ...StepModelOption) *model.Step {
t := &model.Step{
TrialID: tID,
State: model.ActiveState,
StartTime: time.Now(),
}
for _, o := range opts {
o.apply(t)
}
return t
}
Binary file modified proto/buf.image.bin
Binary file not shown.
2 changes: 1 addition & 1 deletion proto/src/determined/checkpoint/v1/checkpoint.proto
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ message Metrics {
// Number of inputs to the model.
int32 num_inputs = 1;
// Metrics calculated on the validation set.
map<string, float> validation_metrics = 2;
google.protobuf.Struct validation_metrics = 2;
}

// The current state of the checkpoint.
Expand Down
2 changes: 1 addition & 1 deletion proto/src/determined/trial/v1/trial.proto
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ message MetricsWorkload {
// The current validation state.
determined.experiment.v1.State state = 3;
// Metrics.
map<string, float> metrics = 4;
google.protobuf.Struct metrics = 4;
// Number of inputs processed.
int32 num_inputs = 5;
// Number of batches in this workload.
Expand Down
14 changes: 13 additions & 1 deletion webui/react/src/services/decoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import dayjs from 'dayjs';

import * as ioTypes from 'ioTypes';
import * as types from 'types';
import { isNumber, isObject } from 'utils/data';
import { capitalize } from 'utils/string';

import * as Sdk from './api-ts-sdk'; // API Bindings
Expand Down Expand Up @@ -262,10 +263,21 @@ export const decodeExperimentList = (data: Sdk.V1Experiment[]): types.Experiment
return data.map(decodeV1ExperimentToExperimentItem);
};

const filterNonScalarMetrics = (metrics: types.RawJson): types.RawJson | undefined => {
if (!isObject(metrics)) return undefined;
const scalarMetrics: types.RawJson = {};
for (const key in metrics){
if (isNumber(metrics[key])) {
scalarMetrics[key] = metrics[key];
}
}
return scalarMetrics;
};

const decodeMetricsWorkload = (data: Sdk.V1MetricsWorkload): types.MetricsWorkload => {
return {
endTime: data.endTime as unknown as string,
metrics: data.metrics,
metrics: data.metrics ? filterNonScalarMetrics(data.metrics) : undefined,
numBatches: data.numBatches,
priorBatchesProcessed: data.priorBatchesProcessed,
startTime: data.startTime as unknown as string,
Expand Down

0 comments on commit 48418e6

Please sign in to comment.