-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9702094
commit 2fce3ee
Showing
3 changed files
with
224 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package model | ||
|
||
import ( | ||
"encoding/json" | ||
"time" | ||
|
||
"github.com/pkg/errors" | ||
"github.com/uptrace/bun" | ||
|
||
"github.com/determined-ai/determined/master/pkg/protoutils" | ||
"github.com/determined-ai/determined/proto/pkg/checkpointv1" | ||
) | ||
|
||
// Resources maps filenames to file sizes. | ||
type Resources map[string]int64 | ||
|
||
// Scan converts jsonb from postgres into a Resources object. | ||
// TODO: Combine all json.unmarshal-based Scanners into a single Scan implementation. | ||
func (r *Resources) Scan(src interface{}) error { | ||
if src == nil { | ||
*r = nil | ||
return nil | ||
} | ||
bytes, ok := src.([]byte) | ||
if !ok { | ||
return errors.Errorf("unable to convert to []byte: %v", src) | ||
} | ||
obj := make(map[string]int64) | ||
if err := json.Unmarshal(bytes, &obj); err != nil { | ||
return errors.Wrapf(err, "unable to unmarshal Resources: %v", src) | ||
} | ||
*r = Resources(obj) | ||
return nil | ||
} | ||
|
||
// RBCheckpoint is rb's checkpoint. | ||
type RBCheckpoint struct { | ||
bun.BaseModel | ||
|
||
// Basic data, common to all checkpoints | ||
// ID int // `db:"id" json:"id"` | ||
UUID string // `db:"uuid" json:"uuid"` | ||
TaskID *string // `db:"trial_id" json:"task_id"` | ||
AllocationID *string // `db:"trial_id" json:"task_id"` | ||
ReportTime time.Time // `db:"end_time" json:"report_time"` | ||
State State // `db:"state" json:"state" | ||
Resources Resources // `db:"resources" json:"resources"` | ||
Metadata JSONObj // `db:"metadata" json:"metadata"` | ||
|
||
// Training data, flattened to work with the database. Considered empty if TrialID == 0. | ||
TrialID int | ||
ExperimentID int | ||
ExperimentConfig JSONObj | ||
Hparams JSONObj | ||
TrainingMetrics JSONObj | ||
ValidationMetrics JSONObj | ||
} | ||
|
||
func (c RBCheckpoint) ToProto(pc *protoutils.ProtoConverter) checkpointv1.Checkpoint { | ||
if pc.Error() != nil { | ||
return checkpointv1.Checkpoint{} | ||
} | ||
|
||
nilToStr := func(v *string) string { | ||
if v == nil { | ||
return "" | ||
} | ||
return *v | ||
} | ||
|
||
out := checkpointv1.Checkpoint{ | ||
TaskId: nilToStr(c.TaskID), | ||
AllocationId: nilToStr(c.AllocationID), | ||
Uuid: c.UUID, | ||
ReportTime: pc.ToTimestamp(c.ReportTime), | ||
State: pc.ToCheckpointv1State(string(c.State)), | ||
Resources: c.Resources, | ||
Metadata: pc.ToStruct(c.Metadata, "metadata"), | ||
} | ||
if c.TrialID != 0 { | ||
out.Training = &checkpointv1.CheckpointTrainingData{ | ||
TrialId: int32(c.TrialID), | ||
ExperimentId: int32(c.ExperimentID), | ||
ExperimentConfig: pc.ToStruct(c.ExperimentConfig, "experiment config"), | ||
Hparams: pc.ToStruct(c.Hparams, "hparams"), | ||
TrainingMetrics: pc.ToStruct(c.TrainingMetrics, "training metrics"), | ||
ValidationMetrics: pc.ToStruct(c.ValidationMetrics, "validation metrics"), | ||
} | ||
} | ||
|
||
return out | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
syntax = "proto3"; | ||
|
||
package determined.checkpoint.v1; | ||
option go_package = "github.com/determined-ai/determined/proto/pkg/checkpointv1"; | ||
|
||
import "google/protobuf/struct.proto"; | ||
import "google/protobuf/timestamp.proto"; | ||
import "google/protobuf/wrappers.proto"; | ||
import "protoc-gen-swagger/options/annotations.proto"; | ||
|
||
// The current state of the checkpoint. | ||
enum State { | ||
// The state of the checkpoint is unknown. | ||
STATE_UNSPECIFIED = 0; | ||
// The checkpoint is in an active state. | ||
STATE_ACTIVE = 1; | ||
// The checkpoint is persisted to checkpoint storage. | ||
STATE_COMPLETED = 2; | ||
// The checkpoint errored. | ||
STATE_ERROR = 3; | ||
// The checkpoint has been deleted. | ||
STATE_DELETED = 4; | ||
} | ||
|
||
message CheckpointTrainingData { | ||
// The ID of the trial that created this checkpoint. | ||
int32 trial_id = 1; | ||
// The ID of the experiment that created this checkpoint. | ||
int32 experiment_id = 2; | ||
// The configuration of the experiment that created this checkpoint. | ||
google.protobuf.Struct experiment_config = 3; | ||
// Hyperparameter values for the trial that created this checkpoint. | ||
google.protobuf.Struct hparams = 4; | ||
// Training metrics reported at the same latest_batch as the checkpoint. | ||
google.protobuf.Struct training_metrics = 5; | ||
// Validation metrics reported at the same latest_batch as the checkpoint. | ||
google.protobuf.Struct validation_metrics = 6; | ||
// TODO: is this actually a good idea? It couples a lot of parts of the | ||
// system pretty tightly, and if we had a good metrics api it would be a much | ||
// more natural way to serve this information. | ||
google.protobuf.DoubleValue searcher_metric = 17; | ||
} | ||
|
||
// Checkpoint is a collection of files. | ||
message Checkpoint { | ||
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { | ||
json_schema: { required: [] } | ||
}; | ||
// The ID of the task which generated this checkpoint. | ||
// (Future): Might be empty for an imported checkpoint. | ||
string task_id = 1; | ||
// The ID of the allocation which generated this checkpoint. | ||
// (Future): Might be empty for an imported checkpoint. | ||
string allocation_id = 2; | ||
// UUID of the checkpoint. | ||
string uuid = 3; | ||
// Timestamp when the checkpoint was reported. | ||
google.protobuf.Timestamp report_time = 4; | ||
// Dictionary of file paths to file sizes in bytes of all files in the | ||
// checkpoint. | ||
map<string, int64> resources = 5; | ||
// User defined metadata associated with the checkpoint. | ||
google.protobuf.Struct metadata = 6; | ||
// Training-related data for this checkpoint. | ||
// (Future): Present only when a training originated from a Trial. If | ||
// training.trial_id is zero the struct is empty. | ||
CheckpointTrainingData training = 7; | ||
State state = 8; | ||
|
||
// TODO: delete this and fix the master so it still compiles. | ||
State validation_state = 9; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
-- We'll need to key checkpoints by UUID and not by ID; | ||
-- when we see a checkpoint id we'll have to assume its old, and when we see | ||
-- a uuid we should not assume if it is old or new. | ||
-- XXX: foreign key to new_checkpoints? | ||
alter table trials drop column warm_start_checkpoint_uuid; | ||
alter table trials add column warm_start_checkpoint_uuid uuid; | ||
|
||
-- This is fairly well-optimized already, mostly by @stokc | ||
CREATE VIEW checkpoints_expanded AS | ||
SELECT | ||
c.uuid, | ||
t.task_id, | ||
c.end_time as report_time, | ||
c.state, | ||
c.resources, | ||
-- construct a metadata json from the user's metadata plus our training-specific fields that the | ||
-- TrialControllers inject when creating checkpoints. Those values used to be "system" values, | ||
-- but since the release of Core API, the TrialControllers are no longer part of the system | ||
-- proper but are considered userspace tools. | ||
jsonb_build_object( | ||
'latest_batch', c.total_batches, | ||
'framework', c.framework, | ||
'determined_version', c.determined_version | ||
) || COALESCE(c.metadata, '{}'::jsonb) as metadata, | ||
-- .Training substruct | ||
c.trial_id, | ||
t.experiment_id, | ||
e.config as experiment_config, | ||
t.hparams, | ||
s.metrics as training_metrics, | ||
v.metrics as validation_metrics | ||
FROM raw_checkpoints AS c | ||
LEFT JOIN trials AS t on c.trial_id = t.id | ||
LEFT JOIN experiments AS e on t.experiment_id = e.id | ||
LEFT JOIN validations AS v on c.total_batches = v.total_batches and c.trial_id = v.trial_id | ||
-- avoiding the steps view causes Postgres to not "Materialize" in this join. | ||
LEFT JOIN raw_steps AS s on c.total_batches = s.total_batches and c.trial_id = s.trial_id | ||
where s.archived = false | ||
UNION | ||
SELECT | ||
c.uuid, | ||
c.task_id, | ||
c.report_time, | ||
c.state, | ||
c.resources, | ||
c.metadata, | ||
-- .Training substruct | ||
t.id as trial_id, | ||
t.experiment_id, | ||
e.config as experiment_config, | ||
t.hparams, | ||
s.metrics as training_metrics, | ||
v.metrics as validation_metrics | ||
FROM checkpoints_new AS c | ||
LEFT JOIN trials AS t on c.task_id = t.task_id | ||
LEFT JOIN experiments AS e on t.experiment_id = e.id | ||
LEFT JOIN validations AS v on c.metadata->>'latest_batch' = v.total_batches::text and t.id = v.trial_id | ||
-- avoiding the steps view causes Postgres to not "Materialize" in this join. | ||
LEFT JOIN raw_steps AS s on c.metadata->>'latest_batch' = s.total_batches::text and t.id = s.trial_id | ||
where s.archived = false; |