Skip to content

Commit

Permalink
checkpoints with bun
Browse files Browse the repository at this point in the history
  • Loading branch information
rb-determined-ai committed Mar 21, 2022
1 parent 51b2a08 commit 30f77ed
Show file tree
Hide file tree
Showing 16 changed files with 678 additions and 131 deletions.
37 changes: 37 additions & 0 deletions master/TODO
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# git grep -liE 'from (raw_|)checkpoints' static/srv internal

DONE: static/srv/get_checkpoint.sql
DONE: static/srv/get_checkpoints_for_trial.sql
DONE: static/srv/get_model_version.sql

static/srv/get_checkpoints_for_experiment.sql
static/srv/get_model_versions.sql
static/srv/get_trial.sql
static/srv/get_trial_metrics.sql
static/srv/insert_model_version.sql
static/srv/proto_get_trial_workloads.sql
static/srv/update_model_version.sql
internal/db/postgres_experiments.go
internal/db/postgres_trial.go

static/srv/get_checkpoint.sql
internal/api_checkpoint.go::GetCheckpoint()
Replaced with ckpts.Get()
internal/api_model.go::PostModelVersion()
Used only to verify checkpoint existence and state==COMPLETED.
Replaced with new method ckpts.State(), which only touches the
checkpoints tables (or at least, there's an XXX to make it so)

static/srv/get_checkpoints_for_trial.sql
internal/api_trials.go::GetTrialCheckpoints()
Replaced with cktps.Count()/ckpts.List() in a bun.Tx
was a 404 possible with no checkpoints? It's not clear to me.

# this was included indirectly, to make get_model_verision easier
static/srv/get_model.sql
internal/api_model.go::GetModel()
replaced with models.ByName()

static/srv/get_model_version.sql
internal/api_model.go::GetModelVerison()
replaced with models.ByNameTx, models.VersionTx, ckpts.ByIDTx
2 changes: 1 addition & 1 deletion master/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/determined-ai/determined/master

go 1.17
go 1.18

require (
cloud.google.com/go v0.94.0
Expand Down
20 changes: 13 additions & 7 deletions master/internal/api_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,29 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/determined-ai/determined/master/internal/ckpts"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/pkg/protoutils"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
)

func (a *apiServer) GetCheckpoint(
_ context.Context, req *apiv1.GetCheckpointRequest) (*apiv1.GetCheckpointResponse, error) {
resp := &apiv1.GetCheckpointResponse{}
resp.Checkpoint = &checkpointv1.Checkpoint{}
switch err := a.m.db.QueryProto("get_checkpoint", resp.Checkpoint, req.CheckpointUuid); err {
case db.ErrNotFound:
ctx context.Context, req *apiv1.GetCheckpointRequest) (*apiv1.GetCheckpointResponse, error) {
ckpt, err := ckpts.ByUUID(ctx, req.CheckpointUuid)
switch {
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "checkpoint %s not found", req.CheckpointUuid)
default:
return resp,
case err != nil:
return nil,
errors.Wrapf(err, "error fetching checkpoint %s from database", req.CheckpointUuid)
}

pc := protoutils.ProtoConverter{}
protoCkpt := ckpt.ToProto(&pc)
resp := &apiv1.GetCheckpointResponse{Checkpoint: &protoCkpt}
return resp, pc.Error()
}

func (a *apiServer) PostCheckpointMetadata(
Expand Down
96 changes: 61 additions & 35 deletions master/internal/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,38 @@ import (

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/uptrace/bun"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/determined-ai/determined/master/internal/ckpts"
"github.com/determined-ai/determined/master/internal/models"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/protoutils"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
"github.com/determined-ai/determined/proto/pkg/modelv1"

structpb "github.com/golang/protobuf/ptypes/struct"
)

func (a *apiServer) GetModel(
_ context.Context, req *apiv1.GetModelRequest) (*apiv1.GetModelResponse, error) {
m := &modelv1.Model{}
switch err := a.m.db.QueryProto("get_model", m, req.ModelName); err {
case db.ErrNotFound:
ctx context.Context, req *apiv1.GetModelRequest) (*apiv1.GetModelResponse, error) {
m, err := models.ByName(ctx, req.ModelName)
switch {
// XXX: I've noticed this isn't actually returned by GetModel at all.
// Maybe database getters should return nullable types to make it much more obvious when
// the thing was not found?
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "model \"%s\" not found", req.ModelName)
default:
return &apiv1.GetModelResponse{Model: m},
errors.Wrapf(err, "error fetching model \"%s\" from database", req.ModelName)
case err != nil:
return nil, errors.Wrapf(err, "error fetching model \"%s\" from database", req.ModelName)
}
pc := protoutils.ProtoConverter{}
mv1 := m.ToProto(&pc)
return &apiv1.GetModelResponse{Model: &mv1}, pc.Error()
}

func (a *apiServer) GetModels(
Expand Down Expand Up @@ -291,23 +300,44 @@ func (a *apiServer) DeleteModel(
}

func (a *apiServer) GetModelVersion(
ctx context.Context, req *apiv1.GetModelVersionRequest) (*apiv1.GetModelVersionResponse, error) {
parentModel, err := a.GetModel(ctx, &apiv1.GetModelRequest{ModelName: req.ModelName})
if err != nil {
return nil, err
}

ctx context.Context, req *apiv1.GetModelVersionRequest,
) (*apiv1.GetModelVersionResponse, error) {
resp := &apiv1.GetModelVersionResponse{}
resp.ModelVersion = &modelv1.ModelVersion{}

switch err := a.m.db.QueryProto(
"get_model_version", resp.ModelVersion, parentModel.Model.Id, req.ModelVersion); {
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "model %s version %d not found", req.ModelName, req.ModelVersion)
default:
return resp, err
}
return resp, db.Bun.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
m, err := models.ByName(ctx, req.ModelName)
// XXX actually handle 404
switch {
case err == db.ErrNotFound:
return status.Errorf(
codes.NotFound, "model %s version %d not found", req.ModelName, req.ModelVersion)
case err != nil:
return err
}

// XXX: make models.Version accept int or int32, with generics
mv, err := models.VersionTx(ctx, tx, m.ID, int(req.ModelVersion))
// XXX actually handle 404
switch {
case err == db.ErrNotFound:
return status.Errorf(
codes.NotFound, "model %s version %d not found", req.ModelName, req.ModelVersion)
case err != nil:
return err
}

ckpt, err := ckpts.ByIDTx(ctx, tx, mv.CheckpointID)
// XXX: not even sure what this should error out as
if err != nil {
return err
}

pc := protoutils.ProtoConverter{}
mvv1 := mv.ToProto(&pc, m, ckpt)
resp.ModelVersion = &mvv1;

return pc.Error()
})
}

func (a *apiServer) GetModelVersions(
Expand Down Expand Up @@ -340,21 +370,17 @@ func (a *apiServer) PostModelVersion(
modelResp.Model.Name)
}

// make sure the checkpoint exists
c := &checkpointv1.Checkpoint{}

switch getCheckpointErr := a.m.db.QueryProto("get_checkpoint", c, req.CheckpointUuid); {
case getCheckpointErr == db.ErrNotFound:
// make sure the checkpoint exists in a COMPLETED state
if state, err := ckpts.State(ctx, req.CheckpointUuid); err != nil {
return nil, err
} else if state == nil {
return nil, status.Errorf(
codes.NotFound, "checkpoint %s not found", req.CheckpointUuid)
case getCheckpointErr != nil:
return nil, getCheckpointErr
}

if c.State != checkpointv1.State_STATE_COMPLETED {
// XXX: the old code used checkpointv1.State_STATE_COMPLETED... why though?
} else if *state != model.CompletedState {
return nil, errors.Errorf(
"checkpoint %s is in %s state. checkpoints for model versions must be in a COMPLETED state",
c.Uuid, c.State,
req.CheckpointUuid, state,
)
}

Expand All @@ -377,7 +403,7 @@ func (a *apiServer) PostModelVersion(
"insert_model_version",
respModelVersion.ModelVersion,
modelResp.Model.Id,
c.Uuid,
req.CheckpointUuid,
req.Name,
req.Comment,
mdata,
Expand Down
115 changes: 80 additions & 35 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import (
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/internal/ckpts"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/sproto"
Expand Down Expand Up @@ -299,7 +301,7 @@ func (a *apiServer) TrialLogsFields(
}

func (a *apiServer) GetTrialCheckpoints(
_ context.Context, req *apiv1.GetTrialCheckpointsRequest,
ctx context.Context, req *apiv1.GetTrialCheckpointsRequest,
) (*apiv1.GetTrialCheckpointsResponse, error) {
switch exists, err := a.m.db.CheckTrialExists(int(req.Id)); {
case err != nil:
Expand All @@ -308,52 +310,95 @@ func (a *apiServer) GetTrialCheckpoints(
return nil, status.Error(codes.NotFound, "trial not found")
}

resp := &apiv1.GetTrialCheckpointsResponse{}
resp.Checkpoints = []*checkpointv1.Checkpoint{}

switch err := a.m.db.QueryProto("get_checkpoints_for_trial", &resp.Checkpoints, req.Id); {
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "no checkpoints found for trial %d", req.Id)
case err != nil:
return nil,
errors.Wrapf(err, "error fetching checkpoints for trial %d from database", req.Id)
}
ext := db.QueryExt{}

a.filter(&resp.Checkpoints, func(i int) bool {
v := resp.Checkpoints[i]
// always filter by trial
ext = ext.Where("trial_id = ?", req.Id)

found := false
for _, state := range req.States {
if state == v.State {
found = true
break
if len(req.ValidationStates) > 0 {
var states []string
for _, vstate := range req.ValidationStates {
switch vstate {
case checkpointv1.State_STATE_ACTIVE:
states = append(states, "'ACTIVE'")
case checkpointv1.State_STATE_COMPLETED:
states = append(states, "'COMPLETED'")
case checkpointv1.State_STATE_ERROR:
states = append(states, "'ERROR'")
case checkpointv1.State_STATE_DELETED:
states = append(states, "'DELETED'")
default:
// XXX: throw an error to the user
}
}
// filter by state
ext = ext.Where("checkpoint_state IN (?)", bun.In(states))
}

// choose an ordering
asc := " ASC"
if req.OrderBy == apiv1.OrderBy_ORDER_BY_DESC {
asc = " DESC"
}
switch req.SortBy {
case apiv1.GetTrialCheckpointsRequest_SORT_BY_UNSPECIFIED:
// noop, use the default
case apiv1.GetTrialCheckpointsRequest_SORT_BY_BATCH_NUMBER:
// noop, this is the default
case apiv1.GetTrialCheckpointsRequest_SORT_BY_START_TIME:
// noop, this field doesn't exist anymore
case apiv1.GetTrialCheckpointsRequest_SORT_BY_END_TIME:
ext = ext.Order("report_time" + asc)
case apiv1.GetTrialCheckpointsRequest_SORT_BY_VALIDATION_STATE:
ext = ext.Order("validation_metrics != 'null'::jsonb" + asc)
case apiv1.GetTrialCheckpointsRequest_SORT_BY_STATE:
ext = ext.Order("state" + asc)
}

// secondary/default ordering is based on latest_batch
// XXX: the get_checkpoints_for_trial also had an ORDER BY clause for report time
ext = ext.Order("(metadata->>'latest_batch')::int8", asc)

resp := &apiv1.GetTrialCheckpointsResponse{}
// Count and paginate within a tx.
return resp, db.Bun.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
total, err := ckpts.CountTx(ctx, tx, ext)
if err != nil {
return errors.Wrapf(err, "error counting checkpoints for trial %d from database", req.Id)
}

if len(req.States) != 0 && !found {
return false
// apply pagination
ext = ext.Limit(int(req.Limit))
ext = ext.Offset(int(req.Offset))

ckpts, err := ckpts.ListTx(ctx, tx, ext)
if err != nil {
return errors.Wrapf(err, "error fetching checkpoints for trial %d from database", req.Id)
}
if len(ckpts) == 0 {
// XXX: This is dumb. Why not return an empty array?
return status.Errorf(codes.NotFound, "no checkpoints found for trial %d", req.Id)
}

found = false
for _, state := range req.ValidationStates {
if state == v.ValidationState {
found = true
break
}
pc := protoutils.ProtoConverter{}
resp.Checkpoints = make([]*checkpointv1.Checkpoint, len(ckpts))
for i, ckpt := range ckpts {
protoCkpt := ckpt.ToProto(&pc)
// XXX taking address of protoCkpt looks buggy to me
resp.Checkpoints[i] = &protoCkpt
}

if len(req.ValidationStates) != 0 && !found {
return false
resp.Pagination = &apiv1.Pagination{
Offset: req.Offset,
Limit: req.Limit,
// XXX wut? Why have both offset and start_idx?
StartIndex: req.Offset,
EndIndex: pc.ToInt32(int(req.Offset) + len(ckpts)),
Total: pc.ToInt32(total),
}

return true
return pc.Error()
})

a.sort(
resp.Checkpoints, req.OrderBy, req.SortBy, apiv1.GetTrialCheckpointsRequest_SORT_BY_BATCH_NUMBER)

return resp, a.paginate(&resp.Pagination, &resp.Checkpoints, req.Offset, req.Limit)
}

func (a *apiServer) KillTrial(
Expand Down

0 comments on commit 30f77ed

Please sign in to comment.