Skip to content

Commit

Permalink
fix: slot stats are not filled in everywhere (#9070)
Browse files Browse the repository at this point in the history
causes unmarshalling of get agent and enable/disable agent to fail
in python generated client.
introduced in e8dba6d

(cherry picked from commit 6c4bc44)
  • Loading branch information
hamidzr authored and determined-ci committed Mar 28, 2024
1 parent d2e3a5c commit b8db2e6
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 52 deletions.
5 changes: 3 additions & 2 deletions master/internal/api_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import (

"github.com/stretchr/testify/assert"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/proto/pkg/agentv1"
"github.com/determined-ai/determined/proto/pkg/containerv1"
"github.com/determined-ai/determined/proto/pkg/devicev1"
)

func TestSummarizeSlots_EmptySlots(t *testing.T) {
slots := make(map[string]*agentv1.Slot)
stats := SummarizeSlots(slots)
stats := model.SummarizeSlots(slots)

assert.Equal(t, 0, len(stats.TypeStats))
assert.Equal(t, 0, len(stats.BrandStats))
Expand Down Expand Up @@ -47,7 +48,7 @@ func TestSummarizeSlots_VariousStates(t *testing.T) {
},
}

stats := SummarizeSlots(slots)
stats := model.SummarizeSlots(slots)

assert.Equal(t, 2, int(stats.TypeStats[devicev1.Type_TYPE_CUDA.String()].Total))
assert.Equal(t, 1, int(stats.TypeStats[devicev1.Type_TYPE_CPU.String()].Total))
Expand Down
50 changes: 0 additions & 50 deletions master/internal/api_agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,9 @@ import (
"github.com/determined-ai/determined/master/internal/cluster"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/rm/rmerrors"
"github.com/determined-ai/determined/proto/pkg/agentv1"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

type slotStats map[string]*agentv1.DeviceStats

// SummarizeSlots a set of slots.
func SummarizeSlots(slots map[string]*agentv1.Slot) *agentv1.SlotStats {
stats := agentv1.SlotStats{
TypeStats: make(slotStats),
BrandStats: make(slotStats),
}

if len(slots) == 0 {
return &stats
}
for _, slot := range slots {
deviceType := slot.Device.Type.String()
deviceTypeStats, ok := stats.TypeStats[deviceType]
if !ok {
deviceTypeStats = &agentv1.DeviceStats{
States: make(map[string]int32),
}
stats.TypeStats[deviceType] = deviceTypeStats
}
deviceBrand := slot.Device.Brand
deviceBrandStats, ok := stats.BrandStats[deviceBrand]
if !ok {
deviceBrandStats = &agentv1.DeviceStats{
States: make(map[string]int32),
}
stats.BrandStats[deviceBrand] = deviceBrandStats
}
deviceBrandStats.Total++
deviceTypeStats.Total++

if !slot.Enabled {
deviceBrandStats.Disabled++
deviceTypeStats.Disabled++
}
if slot.Draining {
deviceBrandStats.Draining++
deviceTypeStats.Draining++
}
if slot.Container != nil {
deviceBrandStats.States[slot.Container.State.String()]++
deviceTypeStats.States[slot.Container.State.String()]++
}
}
return &stats
}

func (a *apiServer) GetAgents(
ctx context.Context, req *apiv1.GetAgentsRequest,
) (*apiv1.GetAgentsResponse, error) {
Expand Down Expand Up @@ -92,7 +43,6 @@ func (a *apiServer) GetAgents(

// PERF: can perhaps be done before RBAC.
for _, agent := range resp.Agents {
agent.SlotStats = SummarizeSlots(agent.Slots)
if req.ExcludeSlots {
agent.Slots = nil
}
Expand Down
49 changes: 49 additions & 0 deletions master/pkg/model/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,54 @@ type AgentSummary struct {
Version string `json:"version"`
}

type slotStats map[string]*agentv1.DeviceStats

// SummarizeSlots a set of slots.
func SummarizeSlots(slots map[string]*agentv1.Slot) *agentv1.SlotStats {
stats := agentv1.SlotStats{
TypeStats: make(slotStats),
BrandStats: make(slotStats),
}

if len(slots) == 0 {
return &stats
}
for _, slot := range slots {
deviceType := slot.Device.Type.String()
deviceTypeStats, ok := stats.TypeStats[deviceType]
if !ok {
deviceTypeStats = &agentv1.DeviceStats{
States: make(map[string]int32),
}
stats.TypeStats[deviceType] = deviceTypeStats
}
deviceBrand := slot.Device.Brand
deviceBrandStats, ok := stats.BrandStats[deviceBrand]
if !ok {
deviceBrandStats = &agentv1.DeviceStats{
States: make(map[string]int32),
}
stats.BrandStats[deviceBrand] = deviceBrandStats
}
deviceBrandStats.Total++
deviceTypeStats.Total++

if !slot.Enabled {
deviceBrandStats.Disabled++
deviceTypeStats.Disabled++
}
if slot.Draining {
deviceBrandStats.Draining++
deviceTypeStats.Draining++
}
if slot.Container != nil {
deviceBrandStats.States[slot.Container.State.String()]++
deviceTypeStats.States[slot.Container.State.String()]++
}
}
return &stats
}

// ToProto converts an agent summary to a proto struct.
func (a AgentSummary) ToProto() *agentv1.Agent {
slots := make(map[string]*agentv1.Slot)
Expand All @@ -39,6 +87,7 @@ func (a AgentSummary) ToProto() *agentv1.Agent {
Id: a.ID,
RegisteredTime: protoutils.ToTimestamp(a.RegisteredTime),
Slots: slots,
SlotStats: SummarizeSlots(slots),
Containers: containers,
ResourcePools: a.ResourcePool,
Addresses: a.Addresses,
Expand Down

0 comments on commit b8db2e6

Please sign in to comment.