Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions api/design_components.partials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ Role:
replica:
format: int32
type: integer
groupAssociation:
type: array
items:
$ref: '#/components/schemas/GroupAssociation'
required:
- name
example:
Expand All @@ -137,6 +141,17 @@ Role:
description: These are responsible for aggregating the updates from trainer nodes.
replica: 2

#########################
# GroupAssociation
#########################
GroupAssociation:
type: object
additionalProperties:
type: string
example:
"param-channel": "red"
"global-channel": "black"

#########################
# Channel between roles
#########################
Expand Down
22 changes: 11 additions & 11 deletions cmd/controller/app/job/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,19 @@ func (b *JobBuilder) getTaskTemplates() ([]string, map[string]*taskTemplate) {

for _, role := range b.schema.Roles {
template := &taskTemplate{}
JobConfig := &template.JobConfig
jobConfig := &template.JobConfig

JobConfig.Configure(b.jobSpec, b.jobParams.Brokers, b.jobParams.Registry, role, b.schema.Channels)
jobConfig.Configure(b.jobSpec, b.jobParams.Brokers, b.jobParams.Registry, role, b.schema.Channels)

// check channels and set default group if channels don't have groupby attributes set
for i := range JobConfig.Channels {
if len(JobConfig.Channels[i].GroupBy.Value) > 0 {
// check channels and set default group if channels don't have groupBy attributes set
for i := range jobConfig.Channels {
if len(jobConfig.Channels[i].GroupBy.Value) > 0 {
continue
}

// since there is no groupby attribute, set default
JobConfig.Channels[i].GroupBy.Type = groupByTypeTag
JobConfig.Channels[i].GroupBy.Value = append(JobConfig.Channels[i].GroupBy.Value, defaultGroup)
// since there is no groupBy attribute, set default
jobConfig.Channels[i].GroupBy.Type = groupByTypeTag
jobConfig.Channels[i].GroupBy.Value = append(jobConfig.Channels[i].GroupBy.Value, defaultGroup)
}

template.isDataConsumer = role.IsDataConsumer
Expand All @@ -192,7 +192,7 @@ func (b *JobBuilder) getTaskTemplates() ([]string, map[string]*taskTemplate) {
}
template.ZippedCode = b.roleCode[role.Name]
template.Role = role.Name
template.JobId = JobConfig.Job.Id
template.JobId = jobConfig.Job.Id

templates[role.Name] = template
}
Expand All @@ -205,9 +205,9 @@ func (b *JobBuilder) preCheck(dataRoles []string, templates map[string]*taskTemp
// This function will evolve as more invariants are defined
// Before processing templates, the following invariants should be met:
// 1. At least one data consumer role should be defined.
// 2. a role shouled be associated with a code.
// 2. a role should be associated with a code.
// 3. template should be connected.
// 4. when graph traversal starts at a data role template, the depth of groupby tag
// 4. when graph traversal starts at a data role template, the depth of groupBy tag
// should strictly decrease from one channel to another.
// 5. two different data roles cannot be connected directly.

Expand Down
17 changes: 17 additions & 0 deletions cmd/controller/app/objects/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type JobConfig struct {
Job JobIdName `json:"job"`
Role string `json:"role"`
Realm string `json:"realm"`
Groups map[string]string `json:"groups"`
Channels []openapi.Channel `json:"channels"`

MaxRunTime int32 `json:"maxRunTime,omitempty"`
Expand Down Expand Up @@ -101,6 +102,22 @@ func (cfg *JobConfig) Configure(jobSpec *openapi.JobSpec, brokers []config.Broke
// Realm will be updated when datasets are handled
cfg.Realm = ""
cfg.Channels = cfg.extractChannels(role.Name, channels)

// configure the groups of the job based on the groups associated with the assigned role
cfg.Groups = cfg.extractGroups(role.GroupAssociation)
}

// extractGroups flattens a role's group associations into a single map.
// Every key/value pair of every entry in groupAssociation is copied into the
// result; when the same key occurs in more than one entry, the value from the
// later entry wins. The returned map is always non-nil, even for empty input.
func (cfg *JobConfig) extractGroups(groupAssociation []map[string]string) map[string]string {
	merged := map[string]string{}
	for idx := 0; idx < len(groupAssociation); idx++ {
		for name, group := range groupAssociation[idx] {
			merged[name] = group
		}
	}
	return merged
}

func (cfg *JobConfig) extractChannels(role string, channels []openapi.Channel) []openapi.Channel {
Expand Down
71 changes: 44 additions & 27 deletions examples/adult/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,51 @@
"name": "A simple two-tier topology schema",
"description": "a sample schema to demonstrate a TAG layout",
"roles": [
{
"name": "trainer",
"description": "It consumes the data and trains local model",
"isDataConsumer": true
},
{
"name": "aggregator",
"description": "It aggregates the updates from trainers"
}
{
"name": "trainer",
"description": "It consumes the data and trains local model",
"isDataConsumer": true,
"groupAssociation": [
{
"param-channel": "default"
}
]
},
{
"name": "aggregator",
"description": "It aggregates the updates from trainers",
"replica": 1,
"groupAssociation": [
{
"param-channel": "default"
}
]
}
],
"channels": [
{
"name": "param-channel",
"description": "Model update is sent from trainer to aggregator and vice-versa",
"pair": [
"trainer",
"aggregator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"trainer": ["fetch", "upload"],
"aggregator": ["distribute", "aggregate"]
}
}
{
"name": "param-channel",
"description": "Model update is sent from trainer to aggregator and vice-versa",
"pair": [
"trainer",
"aggregator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"trainer": [
"fetch",
"upload"
],
"aggregator": [
"distribute",
"aggregate"
]
}
}
]
}
21 changes: 14 additions & 7 deletions examples/distributed_training/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
"name": "A simple schema for distributed training with MQTT backend",
"description": "This implementation is on Keras using MNIST dataset.",
"roles": [
{
"name": "trainer",
"description": "It consumes the data and trains local model",
"isDataConsumer": true
}
{
"name": "trainer",
"description": "It consumes the data and trains local model",
"isDataConsumer": true,
"groupAssociation": [
{
"param-channel": "default/us"
}
]
}
],
"channels": [
{
Expand All @@ -22,8 +27,10 @@
"trainer",
"trainer"
],
"funcTags": {
"trainer": ["ring_allreduce"]
"funcTags": {
"trainer": [
"ring_allreduce"
]
}
}
]
Expand Down
130 changes: 80 additions & 50 deletions examples/hier_mnist/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,87 @@
"name": "A simple hierarchical FL MNIST example schema v1.0.0",
"description": "a sample schema to demonstrate the hierarchical FL setting",
"roles": [
{
"name": "trainer",
"description": "It consumes the data and trains local model",
"isDataConsumer": true
},
{
"name": "middle-aggregator",
"description": "It aggregates the updates from trainers"
},
{
"name": "top-aggregator",
"description": "It aggregates the updates from middle-aggregator"
}
{
"name": "trainer",
"description": "It consumes the data and trains local model",
"isDataConsumer": true,
"groupAssociation": [
{
"param-channel": "default/eu"
}
]
},
{
"name": "middle-aggregator",
"description": "It aggregates the updates from trainers",
"replica": 1,
"groupAssociation": [
{
"param-channel": "default/us",
"global-channel": "default"
}
]
},
{
"name": "top-aggregator",
"description": "It aggregates the updates from middle-aggregator",
"replica": 1,
"groupAssociation": [
{
"global-channel": "default"
}
]
}
],
"channels": [
{
"name": "param-channel",
"description": "Model update is sent from trainer to middle-aggregator and vice-versa",
"pair": [
"trainer",
"middle-aggregator"
],
"groupBy": {
"type": "tag",
"value": [
"default/eu",
"default/na"
]
},
"funcTags": {
"trainer": ["fetch", "upload"],
"middle-aggregator": ["distribute", "aggregate"]
}
},
{
"name": "global-channel",
"description": "Model update is sent from middle-aggregator to top-aggregator and vice-versa",
"pair": [
"top-aggregator",
"middle-aggregator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"top-aggregator": ["distribute", "aggregate"],
"middle-aggregator": ["fetch", "upload"]
}
}
{
"name": "param-channel",
"description": "Model update is sent from trainer to middle-aggregator and vice-versa",
"pair": [
"trainer",
"middle-aggregator"
],
"groupBy": {
"type": "tag",
"value": [
"default/eu",
"default/na"
]
},
"funcTags": {
"trainer": [
"fetch",
"upload"
],
"middle-aggregator": [
"distribute",
"aggregate"
]
}
},
{
"name": "global-channel",
"description": "Model update is sent from middle-aggregator to top-aggregator and vice-versa",
"pair": [
"top-aggregator",
"middle-aggregator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"top-aggregator": [
"distribute",
"aggregate"
],
"middle-aggregator": [
"fetch",
"upload"
]
}
}
]
}
Loading