Skip to content

Commit

Permalink
feat: Display experiment total checkpoint size [WEB-298] (#5554)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored Dec 13, 2022
1 parent c7d1e62 commit 9d38289
Show file tree
Hide file tree
Showing 27 changed files with 322 additions and 70 deletions.
27 changes: 27 additions & 0 deletions harness/determined/common/api/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]:
class trialv1Trial:
bestCheckpoint: "typing.Optional[v1CheckpointWorkload]" = None
bestValidation: "typing.Optional[v1MetricsWorkload]" = None
checkpointCount: "typing.Optional[int]" = None
endTime: "typing.Optional[str]" = None
latestTraining: "typing.Optional[v1MetricsWorkload]" = None
latestValidation: "typing.Optional[v1MetricsWorkload]" = None
Expand All @@ -475,6 +476,7 @@ def __init__(
totalBatchesProcessed: int,
bestCheckpoint: "typing.Union[v1CheckpointWorkload, None, Unset]" = _unset,
bestValidation: "typing.Union[v1MetricsWorkload, None, Unset]" = _unset,
checkpointCount: "typing.Union[int, None, Unset]" = _unset,
endTime: "typing.Union[str, None, Unset]" = _unset,
latestTraining: "typing.Union[v1MetricsWorkload, None, Unset]" = _unset,
latestValidation: "typing.Union[v1MetricsWorkload, None, Unset]" = _unset,
Expand All @@ -495,6 +497,8 @@ def __init__(
self.bestCheckpoint = bestCheckpoint
if not isinstance(bestValidation, Unset):
self.bestValidation = bestValidation
if not isinstance(checkpointCount, Unset):
self.checkpointCount = checkpointCount
if not isinstance(endTime, Unset):
self.endTime = endTime
if not isinstance(latestTraining, Unset):
Expand Down Expand Up @@ -527,6 +531,8 @@ def from_json(cls, obj: Json) -> "trialv1Trial":
kwargs["bestCheckpoint"] = v1CheckpointWorkload.from_json(obj["bestCheckpoint"]) if obj["bestCheckpoint"] is not None else None
if "bestValidation" in obj:
kwargs["bestValidation"] = v1MetricsWorkload.from_json(obj["bestValidation"]) if obj["bestValidation"] is not None else None
if "checkpointCount" in obj:
kwargs["checkpointCount"] = obj["checkpointCount"]
if "endTime" in obj:
kwargs["endTime"] = obj["endTime"]
if "latestTraining" in obj:
Expand Down Expand Up @@ -559,6 +565,8 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]:
out["bestCheckpoint"] = None if self.bestCheckpoint is None else self.bestCheckpoint.to_json(omit_unset)
if not omit_unset or "bestValidation" in vars(self):
out["bestValidation"] = None if self.bestValidation is None else self.bestValidation.to_json(omit_unset)
if not omit_unset or "checkpointCount" in vars(self):
out["checkpointCount"] = self.checkpointCount
if not omit_unset or "endTime" in vars(self):
out["endTime"] = self.endTime
if not omit_unset or "latestTraining" in vars(self):
Expand Down Expand Up @@ -2348,6 +2356,8 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]:
return out

class v1Experiment:
checkpointCount: "typing.Optional[int]" = None
checkpointSize: "typing.Optional[str]" = None
description: "typing.Optional[str]" = None
displayName: "typing.Optional[str]" = None
endTime: "typing.Optional[str]" = None
Expand Down Expand Up @@ -2379,6 +2389,8 @@ def __init__(
startTime: str,
state: "determinedexperimentv1State",
username: str,
checkpointCount: "typing.Union[int, None, Unset]" = _unset,
checkpointSize: "typing.Union[str, None, Unset]" = _unset,
description: "typing.Union[str, None, Unset]" = _unset,
displayName: "typing.Union[str, None, Unset]" = _unset,
endTime: "typing.Union[str, None, Unset]" = _unset,
Expand Down Expand Up @@ -2407,6 +2419,10 @@ def __init__(
self.startTime = startTime
self.state = state
self.username = username
if not isinstance(checkpointCount, Unset):
self.checkpointCount = checkpointCount
if not isinstance(checkpointSize, Unset):
self.checkpointSize = checkpointSize
if not isinstance(description, Unset):
self.description = description
if not isinstance(displayName, Unset):
Expand Down Expand Up @@ -2453,6 +2469,10 @@ def from_json(cls, obj: Json) -> "v1Experiment":
"state": determinedexperimentv1State(obj["state"]),
"username": obj["username"],
}
if "checkpointCount" in obj:
kwargs["checkpointCount"] = obj["checkpointCount"]
if "checkpointSize" in obj:
kwargs["checkpointSize"] = obj["checkpointSize"]
if "description" in obj:
kwargs["description"] = obj["description"]
if "displayName" in obj:
Expand Down Expand Up @@ -2499,6 +2519,10 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]:
"state": self.state.value,
"username": self.username,
}
if not omit_unset or "checkpointCount" in vars(self):
out["checkpointCount"] = self.checkpointCount
if not omit_unset or "checkpointSize" in vars(self):
out["checkpointSize"] = self.checkpointSize
if not omit_unset or "description" in vars(self):
out["description"] = self.description
if not omit_unset or "displayName" in vars(self):
Expand Down Expand Up @@ -3078,6 +3102,7 @@ class v1GetExperimentTrialsRequestSortBy(enum.Enum):
SORT_BY_BATCHES_PROCESSED = "SORT_BY_BATCHES_PROCESSED"
SORT_BY_DURATION = "SORT_BY_DURATION"
SORT_BY_RESTARTS = "SORT_BY_RESTARTS"
SORT_BY_CHECKPOINT_SIZE = "SORT_BY_CHECKPOINT_SIZE"

class v1GetExperimentTrialsResponse:

Expand Down Expand Up @@ -3145,6 +3170,8 @@ class v1GetExperimentsRequestSortBy(enum.Enum):
SORT_BY_FORKED_FROM = "SORT_BY_FORKED_FROM"
SORT_BY_RESOURCE_POOL = "SORT_BY_RESOURCE_POOL"
SORT_BY_PROJECT_ID = "SORT_BY_PROJECT_ID"
SORT_BY_CHECKPOINT_SIZE = "SORT_BY_CHECKPOINT_SIZE"
SORT_BY_CHECKPOINT_COUNT = "SORT_BY_CHECKPOINT_COUNT"

class v1GetExperimentsResponse:

Expand Down
1 change: 1 addition & 0 deletions harness/determined/common/experimental/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class TrialSortBy(enum.Enum):
BATCHES_PROCESSED = _tsb.SORT_BY_BATCHES_PROCESSED.value
DURATION = _tsb.SORT_BY_DURATION.value
RESTARTS = _tsb.SORT_BY_RESTARTS.value
CHECKPOINT_SIZE = _tsb.SORT_BY_CHECKPOINT_SIZE.value

def _to_bindings(self) -> bindings.v1GetExperimentTrialsRequestSortBy:
return _tsb(self.value)
Expand Down
30 changes: 17 additions & 13 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,25 +463,29 @@ func (a *apiServer) GetExperiments(
ColumnExpr("(w.archived OR p.archived) AS parent_archived").
ColumnExpr("p.user_id AS project_owner_id").
Column("e.config").
Column("e.checkpoint_size").
Column("e.checkpoint_count").
Join("JOIN users u ON e.owner_id = u.id").
Join("JOIN projects p ON e.project_id = p.id").
Join("JOIN workspaces w ON p.workspace_id = w.id")

// Construct the ordering expression.
orderColMap := map[apiv1.GetExperimentsRequest_SortBy]string{
apiv1.GetExperimentsRequest_SORT_BY_UNSPECIFIED: "id",
apiv1.GetExperimentsRequest_SORT_BY_ID: "id",
apiv1.GetExperimentsRequest_SORT_BY_DESCRIPTION: "description",
apiv1.GetExperimentsRequest_SORT_BY_NAME: "name",
apiv1.GetExperimentsRequest_SORT_BY_START_TIME: "e.start_time",
apiv1.GetExperimentsRequest_SORT_BY_END_TIME: "e.end_time",
apiv1.GetExperimentsRequest_SORT_BY_STATE: "e.state",
apiv1.GetExperimentsRequest_SORT_BY_NUM_TRIALS: "num_trials",
apiv1.GetExperimentsRequest_SORT_BY_PROGRESS: "COALESCE(progress, 0)",
apiv1.GetExperimentsRequest_SORT_BY_USER: "display_name",
apiv1.GetExperimentsRequest_SORT_BY_FORKED_FROM: "e.parent_id",
apiv1.GetExperimentsRequest_SORT_BY_RESOURCE_POOL: "resource_pool",
apiv1.GetExperimentsRequest_SORT_BY_PROJECT_ID: "project_id",
apiv1.GetExperimentsRequest_SORT_BY_UNSPECIFIED: "id",
apiv1.GetExperimentsRequest_SORT_BY_ID: "id",
apiv1.GetExperimentsRequest_SORT_BY_DESCRIPTION: "description",
apiv1.GetExperimentsRequest_SORT_BY_NAME: "name",
apiv1.GetExperimentsRequest_SORT_BY_START_TIME: "e.start_time",
apiv1.GetExperimentsRequest_SORT_BY_END_TIME: "e.end_time",
apiv1.GetExperimentsRequest_SORT_BY_STATE: "e.state",
apiv1.GetExperimentsRequest_SORT_BY_NUM_TRIALS: "num_trials",
apiv1.GetExperimentsRequest_SORT_BY_PROGRESS: "COALESCE(progress, 0)",
apiv1.GetExperimentsRequest_SORT_BY_USER: "display_name",
apiv1.GetExperimentsRequest_SORT_BY_FORKED_FROM: "e.parent_id",
apiv1.GetExperimentsRequest_SORT_BY_RESOURCE_POOL: "resource_pool",
apiv1.GetExperimentsRequest_SORT_BY_PROJECT_ID: "project_id",
apiv1.GetExperimentsRequest_SORT_BY_CHECKPOINT_SIZE: "checkpoint_size",
apiv1.GetExperimentsRequest_SORT_BY_CHECKPOINT_COUNT: "checkpoint_count",
}
sortByMap := map[apiv1.OrderBy]string{
apiv1.OrderBy_ORDER_BY_UNSPECIFIED: "ASC",
Expand Down
5 changes: 3 additions & 2 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ func (a *apiServer) GetExperimentTrials(
apiv1.GetExperimentTrialsRequest_SORT_BY_BATCHES_PROCESSED: "total_batches_processed",
apiv1.GetExperimentTrialsRequest_SORT_BY_DURATION: "duration",
apiv1.GetExperimentTrialsRequest_SORT_BY_RESTARTS: "restarts",
apiv1.GetExperimentTrialsRequest_SORT_BY_CHECKPOINT_SIZE: "checkpoint_size",
}
sortByMap := map[apiv1.OrderBy]string{
apiv1.OrderBy_ORDER_BY_UNSPECIFIED: "ASC",
Expand Down Expand Up @@ -570,7 +571,7 @@ func (a *apiServer) GetExperimentTrials(
req.Offset,
req.Limit,
); err != nil {
return nil, errors.Wrapf(err, "failed to get trials for experiment %d", req.ExperimentId)
return nil, errors.Wrapf(err, "failed to get trial ids for experiment %d", req.ExperimentId)
} else if len(resp.Trials) == 0 {
return resp, nil
}
Expand All @@ -594,7 +595,7 @@ func (a *apiServer) GetExperimentTrials(
case err == db.ErrNotFound:
return nil, status.Errorf(codes.NotFound, "trials %v not found:", trialIDs)
case err != nil:
return nil, errors.Wrapf(err, "failed to get trials for experiment %d", req.ExperimentId)
return nil, errors.Wrapf(err, "failed to get trials detail for experiment %d", req.ExperimentId)
}

if err = a.enrichTrialState(resp.Trials...); err != nil {
Expand Down
65 changes: 64 additions & 1 deletion master/internal/db/postgres_checkpoints.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package db

import (
"context"
"fmt"

"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/uptrace/bun"

"github.com/determined-ai/determined/master/pkg/model"
)
Expand Down Expand Up @@ -72,7 +74,11 @@ func (db *PgDB) MarkCheckpointsDeleted(deleteCheckpoints []uuid.UUID) error {
if err != nil {
return fmt.Errorf("deleting checkpoints from checkpoints_v2: %w", err)
}

if len(deleteCheckpoints) > 0 {
if err := UpdateCheckpointSize(deleteCheckpoints); err != nil {
return fmt.Errorf("updating checkpoints size: %w", err)
}
}
return nil
}

Expand Down Expand Up @@ -111,3 +117,60 @@ func (db *PgDB) GroupCheckpointUUIDsByExperimentID(checkpoints []uuid.UUID) (

return groupeIDcUUIDS, nil
}

// UpdateCheckpointSize recomputes the denormalized checkpoint_size and
// checkpoint_count columns on the trials that own the given checkpoints, and
// then on the experiments that own them.
//
// Trial totals are recomputed from checkpoints_view (ignoring checkpoints in
// the DELETED state and checkpoints whose resources are JSON null); experiment
// totals are the sums of their trials' totals. An empty checkpoints slice is a
// no-op and returns nil without touching the database.
//
// NOTE(review): the two UPDATE statements run as separate statements with
// context.Background(), i.e. not inside a single transaction; if the second
// one fails, trial totals are already updated while experiment totals are not.
func UpdateCheckpointSize(checkpoints []uuid.UUID) error {
	// Nothing to recompute; skip the queries entirely.
	if len(checkpoints) == 0 {
		return nil
	}

	// Distinct trials that own at least one of the given checkpoints.
	trialID := Bun().NewSelect().Table("checkpoints_view").
		Column("trial_id").
		Where("uuid IN (?)", bun.In(checkpoints)).
		Distinct()

	// One row per (checkpoint, resource entry) for every live checkpoint of
	// the affected trials; jsonb_each expands the resources object.
	sizeTuple := Bun().NewSelect().TableExpr("checkpoints_view AS c").
		ColumnExpr("jsonb_each(c.resources) AS size_tuple").
		Column("experiment_id").
		Column("uuid").
		Column("trial_id").
		Where("state != ?", "DELETED").
		Where("c.resources != 'null'::jsonb").
		Where("trial_id IN (?)", trialID)

	// Per-trial totals. The RIGHT JOIN against trial_ids keeps trials that
	// have no remaining live checkpoints, so they are reset to 0 rather than
	// being left with stale values.
	sizeAndCount := Bun().NewSelect().With("cp_size_tuple", sizeTuple).With("trial_ids", trialID).
		Table("cp_size_tuple").
		ColumnExpr("coalesce(sum((size_tuple).value::text::bigint), 0) AS size").
		ColumnExpr("count(distinct(uuid)) AS count").
		ColumnExpr("trial_ids.trial_id").
		GroupExpr("trial_ids.trial_id").
		Join("RIGHT JOIN trial_ids ON trial_ids.trial_id = cp_size_tuple.trial_id")

	if _, err := Bun().NewUpdate().With("size_and_count", sizeAndCount).
		Table("trials", "size_and_count").
		Set("checkpoint_size = size").
		Set("checkpoint_count = count").
		Where("id IN (?)", trialID).
		Where("trials.id = size_and_count.trial_id").
		Exec(context.Background()); err != nil {
		return fmt.Errorf("updating trial checkpoint size and count: %w", err)
	}

	// Distinct experiments that own at least one of the given checkpoints.
	experimentID := Bun().NewSelect().Table("checkpoints_view").
		Column("experiment_id").
		Where("uuid IN (?)", bun.In(checkpoints)).Distinct()

	// Per-experiment totals, aggregated from the trial totals written above.
	sizeAndCount = Bun().NewSelect().Table("trials").
		ColumnExpr("coalesce(sum(checkpoint_size), 0) AS size").
		ColumnExpr("coalesce(sum(checkpoint_count), 0) AS count").
		Column("experiment_id").
		Group("experiment_id").
		Where("experiment_id IN (?)", experimentID)

	if _, err := Bun().NewUpdate().With("size_and_count", sizeAndCount).
		Table("experiments", "size_and_count").
		Set("checkpoint_size = size").
		Set("checkpoint_count = count").
		Where("id IN (?)", experimentID).
		Where("experiments.id = experiment_id").
		Exec(context.Background()); err != nil {
		return fmt.Errorf("updating experiment checkpoint size and count: %w", err)
	}

	return nil
}
5 changes: 5 additions & 0 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strings"
"time"

"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"

Expand Down Expand Up @@ -310,6 +311,10 @@ VALUES
return errors.Wrap(err, "inserting checkpoint")
}

if err := UpdateCheckpointSize([]uuid.UUID{m.UUID}); err != nil {
return errors.Wrap(err, "updating checkpoint size")
}

return nil
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Down migration: remove the checkpoint_size and checkpoint_count columns
-- from both the experiments and the trials tables.
ALTER TABLE experiments
DROP COLUMN checkpoint_size;
ALTER TABLE experiments
DROP COLUMN checkpoint_count;
ALTER TABLE trials
DROP COLUMN checkpoint_size;
ALTER TABLE trials
DROP COLUMN checkpoint_count;
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
-- Up migration: add checkpoint_size / checkpoint_count columns to experiments
-- and trials, then backfill them from existing checkpoint data.
-- The columns are added without a DEFAULT or NOT NULL, so rows that are not
-- touched by the backfill below hold NULL.
ALTER TABLE experiments
ADD COLUMN checkpoint_size bigint;
ALTER TABLE experiments
ADD COLUMN checkpoint_count int;
ALTER TABLE trials
ADD COLUMN checkpoint_size bigint;
ALTER TABLE trials
ADD COLUMN checkpoint_count int;

-- Backfill trials. The innermost query expands each checkpoint's resources
-- object into one row per entry (jsonb_each), skipping DELETED checkpoints and
-- JSON-null resources; the middle query sums the entry values as bigint and
-- counts distinct checkpoint uuids per trial.
-- NOTE(review): the coalesce(..., 0) is applied inside the per-trial aggregate,
-- before the RIGHT JOIN against all trial ids, so trials with no qualifying
-- checkpoints look like they are assigned NULL (not 0) here -- confirm whether
-- 0 was intended.
UPDATE trials set (checkpoint_size, checkpoint_count) = (size, count) FROM (
SELECT coalesce(sum((size_tuple).value::text::bigint), 0) AS size, count(distinct(uuid)) AS count, trial_id
FROM (
SELECT jsonb_each(c.resources) AS size_tuple, trial_id, uuid
FROM checkpoints_view c
WHERE state != 'DELETED'
AND c.resources != 'null'::jsonb ) r GROUP BY trial_id
) s RIGHT JOIN (SELECT id FROM trials) t ON id = trial_id WHERE
t.id = trials.id;

-- Backfill experiments by summing the per-trial values written above, grouped
-- by experiment_id. Experiments with no rows in trials are not matched by this
-- UPDATE and keep NULL in both columns.
UPDATE experiments set (checkpoint_size, checkpoint_count) = (size, count) FROM (
SELECT coalesce(sum(checkpoint_size), 0) AS size, coalesce(sum(checkpoint_count), 0) AS count, experiment_id
FROM trials GROUP BY experiment_id
) t WHERE experiments.id = experiment_id;
2 changes: 2 additions & 0 deletions master/static/srv/get_experiment.sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ SELECT
e.job_id AS job_id,
e.parent_id AS forked_from,
e.owner_id AS user_id,
e.checkpoint_size AS checkpoint_size,
e.checkpoint_count AS checkpoint_count,
u.username AS username,
(SELECT json_agg(id) FROM trial_ids) AS trial_ids,
(SELECT count(id) FROM trial_ids) AS num_trials,
Expand Down
1 change: 1 addition & 0 deletions master/static/srv/proto_get_trial_ids_for_experiment.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ WITH searcher_info AS (
t.restarts,
t.start_time,
t.end_time,
t.checkpoint_size,
coalesce(t.end_time, now()) - t.start_time AS duration,
(
SELECT coalesce(max(s.total_batches), 0)
Expand Down
18 changes: 2 additions & 16 deletions master/static/srv/proto_get_trials_plus.sql
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ SELECT
t.hparams,
coalesce(new_ckpt.uuid, old_ckpt.uuid) AS warm_start_checkpoint_uuid,
t.task_id,
t.checkpoint_size AS total_checkpoint_size,
t.checkpoint_count,
(
SELECT s.total_batches
FROM steps s
Expand All @@ -175,22 +177,6 @@ SELECT
FROM allocations a
WHERE a.task_id = t.task_id
) AS wall_clock_time,
(
SELECT coalesce(sum((size_tuple).value::text::bigint), 0)
FROM (
SELECT jsonb_each(c.resources) AS size_tuple
FROM checkpoints_old_view c
WHERE c.trial_id = t.id
AND state != 'DELETED'
AND c.resources != 'null'::jsonb
UNION ALL
SELECT jsonb_each(c.resources) as size_tuple
FROM checkpoints_new_view c
WHERE c.trial_id = t.id
AND state != 'DELETED'
AND c.resources != 'null'::jsonb
) r
) AS total_checkpoint_size,
-- `restart` count is incremented before `restart <= max_restarts` stop restart check,
-- so trials in terminal state have restarts = max + 1
LEAST(t.restarts, max_restarts) as restarts
Expand Down
4 changes: 4 additions & 0 deletions proto/src/determined/api/v1/experiment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ message GetExperimentsRequest {
SORT_BY_RESOURCE_POOL = 13;
// Returns experiments sorted by project.
SORT_BY_PROJECT_ID = 14;
// Returns experiments sorted by checkpoint size.
SORT_BY_CHECKPOINT_SIZE = 15;
// Returns experiments sorted by checkpoint count.
SORT_BY_CHECKPOINT_COUNT = 16;
}
// Sort experiments by the given field.
SortBy sort_by = 1;
Expand Down
2 changes: 2 additions & 0 deletions proto/src/determined/api/v1/trial.proto
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ message GetExperimentTrialsRequest {
SORT_BY_DURATION = 10;
// Return the trials sorted by the number of restarts.
SORT_BY_RESTARTS = 11;
// Return the trials sorted by checkpoint size.
SORT_BY_CHECKPOINT_SIZE = 12;
}
// Sort trials by the given field.
SortBy sort_by = 1;
Expand Down
Loading

0 comments on commit 9d38289

Please sign in to comment.