Skip to content

Commit

Permalink
checkpoints with bun
Browse files Browse the repository at this point in the history
* change naming pattern

Previous:

    internal/ckpts/model.go       # db access
    internal/ckpts/controller.go  # business logic
    internal/ckpts/view.go        # api layer

Current:

    internal/ckpts/ckpts.go       # db access
    internal/ckpts/api.go         # api layer

The reason for the change was that there's basically nothing that's
really in the model layer; bun handles all the database code, so all we
really need to do is business logic (ckpts.go).  The view layer is
commonly called "api_*" in our system and terminates REST endpoints;
it's only really responsible for converting to/from wire format
(protobuf).
  • Loading branch information
rb-determined-ai committed Mar 24, 2022
1 parent 03c68f3 commit 8b35cf5
Show file tree
Hide file tree
Showing 15 changed files with 764 additions and 161 deletions.
39 changes: 39 additions & 0 deletions master/TODO
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# git grep -liE 'from (raw_|)checkpoints' static/srv internal

DONE: static/srv/get_checkpoint.sql
DONE: static/srv/get_checkpoints_for_trial.sql
DONE: static/srv/get_model_version.sql

static/srv/get_checkpoints_for_experiment.sql
static/srv/get_model_versions.sql
static/srv/get_trial.sql
static/srv/get_trial_metrics.sql
static/srv/insert_model_version.sql
static/srv/proto_get_trial_workloads.sql
static/srv/update_model_version.sql
internal/db/postgres_trial.go
internal/db/postgres_experiments.go

static/srv/get_checkpoint.sql
internal/api_checkpoint.go::GetCheckpoint()
moved to internal/ckpts/api.go::Server.GetCheckpoint()
Replaced with ckpts.ByUUID()
internal/api_model.go::PostModelVersion()
Used only to verify checkpoint existence and state==COMPLETED.
Replaced with new method ckpts.State(), which only touches the
checkpoints tables (or at least, there's an XXX to make it so)

static/srv/get_checkpoints_for_trial.sql
internal/api_trials.go::GetTrialCheckpoints()
moved to internal/ckpts/api.go::Server.GetTrialCheckpoints()
Replaced with a Count/Limit.Offset.Scan pattern in a bun.Tx
Was a 404 possible with no checkpoints? It's not clear to me.

# this was included indirectly, to make get_model_version easier
static/srv/get_model.sql
internal/api_model.go::GetModel()
replaced with models.ByName()

static/srv/get_model_version.sql
internal/api_model.go::GetModelVersion()
replaced with models.ByNameTx, models.VersionTx, ckpts.ByIDTx
2 changes: 2 additions & 0 deletions master/internal/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"github.com/pkg/errors"

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/internal/ckpts"
"github.com/determined-ai/determined/master/pkg/actor"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

// apiServer is the receiver type for the master's API handler methods in this package.
type apiServer struct {
// m is the Master whose database handle the handlers query (via a.m.db).
m *Master
// ckpts.Server is embedded so that its methods are promoted onto apiServer.
// NOTE(review): presumably this is how the checkpoint endpoints that were moved into
// internal/ckpts remain reachable through apiServer — confirm against ckpts/api.go.
ckpts.Server
}

// paginate returns a paginated subset of the values and sets the pagination response.
Expand Down
17 changes: 0 additions & 17 deletions master/internal/api_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,12 @@ import (

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
)

// GetCheckpoint fetches a single checkpoint, identified by req.CheckpointUuid, using the
// "get_checkpoint" query and returns it in a GetCheckpointResponse.
//
// It returns a gRPC NotFound status when the query reports db.ErrNotFound. For every other
// outcome it returns the response together with the query error wrapped with context.
func (a *apiServer) GetCheckpoint(
_ context.Context, req *apiv1.GetCheckpointRequest) (*apiv1.GetCheckpointResponse, error) {
resp := &apiv1.GetCheckpointResponse{}
resp.Checkpoint = &checkpointv1.Checkpoint{}
switch err := a.m.db.QueryProto("get_checkpoint", resp.Checkpoint, req.CheckpointUuid); err {
case db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "checkpoint %s not found", req.CheckpointUuid)
default:
// This branch doubles as the success path.
// NOTE(review): relies on errors.Wrapf returning nil for a nil error — confirm. Also note
// that on a real error the (non-nil) resp is still returned alongside it.
return resp,
errors.Wrapf(err, "error fetching checkpoint %s from database", req.CheckpointUuid)
}
}

func (a *apiServer) PostCheckpointMetadata(
ctx context.Context, req *apiv1.PostCheckpointMetadataRequest,
) (*apiv1.PostCheckpointMetadataResponse, error) {
Expand Down
96 changes: 61 additions & 35 deletions master/internal/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,38 @@ import (

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/uptrace/bun"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/determined-ai/determined/master/internal/ckpts"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/models"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/protoutils"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
"github.com/determined-ai/determined/proto/pkg/modelv1"

structpb "github.com/golang/protobuf/ptypes/struct"
)

func (a *apiServer) GetModel(
_ context.Context, req *apiv1.GetModelRequest) (*apiv1.GetModelResponse, error) {
m := &modelv1.Model{}
switch err := a.m.db.QueryProto("get_model", m, req.ModelName); err {
case db.ErrNotFound:
ctx context.Context, req *apiv1.GetModelRequest) (*apiv1.GetModelResponse, error) {
m, err := models.ByName(ctx, req.ModelName)
switch {
// XXX: I've noticed this isn't actually returned by GetModel at all.
// Maybe database getters should return nullable types to make it much more obvious when
// the thing was not found?
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "model \"%s\" not found", req.ModelName)
default:
return &apiv1.GetModelResponse{Model: m},
errors.Wrapf(err, "error fetching model \"%s\" from database", req.ModelName)
case err != nil:
return nil, errors.Wrapf(err, "error fetching model \"%s\" from database", req.ModelName)
}
pc := protoutils.ProtoConverter{}
mv1 := m.ToProto(&pc)
return &apiv1.GetModelResponse{Model: &mv1}, pc.Error()
}

func (a *apiServer) GetModels(
Expand Down Expand Up @@ -291,23 +300,44 @@ func (a *apiServer) DeleteModel(
}

func (a *apiServer) GetModelVersion(
ctx context.Context, req *apiv1.GetModelVersionRequest) (*apiv1.GetModelVersionResponse, error) {
parentModel, err := a.GetModel(ctx, &apiv1.GetModelRequest{ModelName: req.ModelName})
if err != nil {
return nil, err
}

ctx context.Context, req *apiv1.GetModelVersionRequest,
) (*apiv1.GetModelVersionResponse, error) {
resp := &apiv1.GetModelVersionResponse{}
resp.ModelVersion = &modelv1.ModelVersion{}

switch err := a.m.db.QueryProto(
"get_model_version", resp.ModelVersion, parentModel.Model.Id, req.ModelVersion); {
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "model %s version %d not found", req.ModelName, req.ModelVersion)
default:
return resp, err
}
return resp, db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
m, err := models.ByName(ctx, req.ModelName)
// XXX actually handle 404
switch {
case err == db.ErrNotFound:
return status.Errorf(
codes.NotFound, "model %s version %d not found", req.ModelName, req.ModelVersion)
case err != nil:
return err
}

// XXX: make models.Version accept int or int32, with generics
mv, err := models.VersionTx(ctx, tx, m.ID, int(req.ModelVersion))
// XXX actually handle 404
switch {
case err == db.ErrNotFound:
return status.Errorf(
codes.NotFound, "model %s version %d not found", req.ModelName, req.ModelVersion)
case err != nil:
return err
}

ckpt, err := ckpts.ByIDTx(ctx, tx, mv.CheckpointID)
// XXX: not even sure what this should error out as
if err != nil {
return err
}

pc := protoutils.ProtoConverter{}
mvv1 := mv.ToProto(&pc, m, ckpt)
resp.ModelVersion = &mvv1

return pc.Error()
})
}

func (a *apiServer) GetModelVersions(
Expand Down Expand Up @@ -340,21 +370,17 @@ func (a *apiServer) PostModelVersion(
modelResp.Model.Name)
}

// make sure the checkpoint exists
c := &checkpointv1.Checkpoint{}

switch getCheckpointErr := a.m.db.QueryProto("get_checkpoint", c, req.CheckpointUuid); {
case getCheckpointErr == db.ErrNotFound:
// make sure the checkpoint exists in a COMPLETED state
if state, err := ckpts.State(ctx, req.CheckpointUuid); err != nil {
return nil, err
} else if state == nil {
return nil, status.Errorf(
codes.NotFound, "checkpoint %s not found", req.CheckpointUuid)
case getCheckpointErr != nil:
return nil, getCheckpointErr
}

if c.State != checkpointv1.State_STATE_COMPLETED {
// XXX: the old code used checkpointv1.State_STATE_COMPLETED... why though?
} else if *state != model.CompletedState {
return nil, errors.Errorf(
"checkpoint %s is in %s state. checkpoints for model versions must be in a COMPLETED state",
c.Uuid, c.State,
req.CheckpointUuid, state,
)
}

Expand All @@ -377,7 +403,7 @@ func (a *apiServer) PostModelVersion(
"insert_model_version",
respModelVersion.ModelVersion,
modelResp.Model.Id,
c.Uuid,
req.CheckpointUuid,
req.Name,
req.Comment,
mdata,
Expand Down
59 changes: 0 additions & 59 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/determined-ai/determined/master/pkg/protoutils"
"github.com/determined-ai/determined/master/pkg/searcher"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
"github.com/determined-ai/determined/proto/pkg/experimentv1"
"github.com/determined-ai/determined/proto/pkg/trialv1"
)
Expand Down Expand Up @@ -297,64 +296,6 @@ func (a *apiServer) TrialLogsFields(
})
}

// GetTrialCheckpoints returns the checkpoints of the trial identified by req.Id, optionally
// filtered by state and validation state, then sorted and paginated.
//
// It returns a gRPC NotFound status when the trial does not exist or when the
// "get_checkpoints_for_trial" query reports db.ErrNotFound; other query errors are wrapped
// with context and returned as-is.
func (a *apiServer) GetTrialCheckpoints(
_ context.Context, req *apiv1.GetTrialCheckpointsRequest,
) (*apiv1.GetTrialCheckpointsResponse, error) {
// Reject requests for unknown trials before querying checkpoints.
switch exists, err := a.m.db.CheckTrialExists(int(req.Id)); {
case err != nil:
return nil, err
case !exists:
return nil, status.Error(codes.NotFound, "trial not found")
}

resp := &apiv1.GetTrialCheckpointsResponse{}
resp.Checkpoints = []*checkpointv1.Checkpoint{}

switch err := a.m.db.QueryProto("get_checkpoints_for_trial", &resp.Checkpoints, req.Id); {
case err == db.ErrNotFound:
return nil, status.Errorf(
codes.NotFound, "no checkpoints found for trial %d", req.Id)
case err != nil:
return nil,
errors.Wrapf(err, "error fetching checkpoints for trial %d from database", req.Id)
}

// Filtering happens after the query: the predicate returns true for checkpoints that match
// every non-empty filter list in the request.
a.filter(&resp.Checkpoints, func(i int) bool {
v := resp.Checkpoints[i]

// State filter: an empty req.States means "no restriction".
found := false
for _, state := range req.States {
if state == v.State {
found = true
break
}
}

if len(req.States) != 0 && !found {
return false
}

// Validation-state filter: same rule, applied to req.ValidationStates.
found = false
for _, state := range req.ValidationStates {
if state == v.ValidationState {
found = true
break
}
}

if len(req.ValidationStates) != 0 && !found {
return false
}

return true
})

// NOTE(review): the last argument is presumably the default sort key used when req.SortBy is
// unspecified — confirm against a.sort.
a.sort(
resp.Checkpoints, req.OrderBy, req.SortBy, apiv1.GetTrialCheckpointsRequest_SORT_BY_BATCH_NUMBER)

// paginate's error, if any, is returned alongside the response.
return resp, a.paginate(&resp.Pagination, &resp.Checkpoints, req.Offset, req.Limit)
}

func (a *apiServer) KillTrial(
ctx context.Context, req *apiv1.KillTrialRequest,
) (*apiv1.KillTrialResponse, error) {
Expand Down

0 comments on commit 8b35cf5

Please sign in to comment.