Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ dist/
.idea/
.vscode/
*.DS_Store

# venv
.venv/
12 changes: 3 additions & 9 deletions datasets/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
/2_QAID_1.masked.reshaped.squared.224.png
/9-mnist-example.png
/CIFAR10/
/cifar10-agents
/cifar10-example.png
/cifar10-labels.csv
/cifar10*
/simple_face
/simple_face-example.png
/titanic_test.csv
/titanic_train.csv
/titanic_train_with_nan.csv
/titanic_test_with_nan.csv
/titanic_wrong_number_columns.csv
/titanic_wrong_passengerID.csv
/titanic*
/mnist*

# wikitext
/wikitext/
Expand Down
11 changes: 8 additions & 3 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
await this.waitForParticipantsIfNeeded()
// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()
// Exchange weight updates with peers and return aggregated weights
// Wait StartWeightSharing message from the server before exchanging weight updates
await waitMessage(this.server, type.StartWeightSharing)
// Exchange weight updates with peers and return aggregated weights // and then send out the contributions
return await this.exchangeWeightUpdates(weights)
}

Expand All @@ -178,8 +180,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
try {
debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
const receivedMessage = await waitMessage(this.server, type.PeersForRound)

const peers = Set(receivedMessage.peers)
debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray());

if (this.ownId !== undefined && peers.has(this.ownId)) {
throw new Error('received peer list contains our own id')
Expand All @@ -198,7 +201,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
(conn) => this.receivePayloads(conn)
)

debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
// Signal server that all connections with other peers in the round are established
this.server.send({ type: type.ConnectionsReady });
debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS());
this.#connections = connections
} catch (e) {
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
Expand Down
19 changes: 17 additions & 2 deletions discojs/src/client/decentralized/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ export interface PeersForRound {
aggregationRound: number
}

// peer sends to server to signal all the connections to other peers
// are established
export interface ConnectionsReady {
type: type.ConnectionsReady
}

// Server signal each peer to start weight update sharing
export interface StartWeightSharing {
type: type.StartWeightSharing;
}

/// Phase 1 communication (between peers)

export interface Payload {
Expand All @@ -55,13 +66,15 @@ export type MessageFromServer =
SignalForPeer |
PeersForRound |
WaitingForMoreParticipants |
EnoughParticipants
EnoughParticipants |
StartWeightSharing

export type MessageToServer =
ClientConnected |
SignalForPeer |
PeerIsReady |
JoinRound
JoinRound |
ConnectionsReady

export type PeerMessage = Payload

Expand All @@ -80,6 +93,7 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer {
return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID)
case type.WaitingForMoreParticipants:
case type.EnoughParticipants:
case type.StartWeightSharing:
return true
}

Expand All @@ -97,6 +111,7 @@ export function isMessageToServer (o: unknown): o is MessageToServer {
'signal' in o // TODO check signal content?
case type.JoinRound:
case type.PeerIsReady:
case type.ConnectionsReady:
return true
}

Expand Down
4 changes: 4 additions & 0 deletions discojs/src/client/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ export enum type {
// Message forwarded by the server from a client to another client
// to establish a peer-to-peer (WebRTC) connection
SignalForPeer,
// Message sent by nodes to server to signal all connections are established
ConnectionsReady,
// Sent by the server to signal nodes proceed to weight update sharing
StartWeightSharing,
// The weight update
Payload,

Expand Down
42 changes: 40 additions & 2 deletions server/src/controllers/decentralized_controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export class DecentralizedController<
// the node has already sent a PeerIsReady message)
// We wait for all peers to be ready to exchange weight updates
#roundPeers = Map<client.NodeID, boolean>()
#connectFinishedNodes = Map<client.NodeID, boolean>()
#aggregationRound = 0

handle (ws: WebSocket): void {
Expand Down Expand Up @@ -84,6 +85,11 @@ export class DecentralizedController<
this.connections.get(msg.peer)?.send(msgpack.encode(forward))
break
}
case MessageTypes.ConnectionsReady: {
this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true)
this.signalWeightSharing()
break
}
default: {
const _: never = msg
throw new Error('should never happen')
Expand Down Expand Up @@ -145,9 +151,41 @@ export class DecentralizedController<
}
return [conn, encoded] as [WebSocket, Buffer]
}).forEach(([conn, encoded]) => { conn.send(encoded) })

// Initialize connectFinishedNodes with all peers set to false
this.#connectFinishedNodes = this.#roundPeers.map(() => false)
this.#aggregationRound++
}

/**
* Check if all the participants of the round finished connecting
* with other peers in the round
* If so, send StartWeightSharing message to signal peers to proceed
*/
private signalWeightSharing(): void {
if (!this.#connectFinishedNodes.every((ready) => ready))
return
this.#roundPeers.keySeq()
.map((id) => {
const startSignal = {
type: MessageTypes.StartWeightSharing,
}
debug("Signaling weight sharing to: %o", id.slice(0, 4))

const encoded = msgpack.encode(startSignal)
return [id, encoded] as [client.NodeID, Buffer]
})
.map(([id, encoded]) => {
const conn = this.connections.get(id)
if (conn === undefined) {
throw new Error(`peer ${id} marked as ready but not connection to it`)
}
return [conn, encoded] as [WebSocket, Buffer]
})
.forEach(([conn, encoded]) => {conn.send(encoded)})

// empty the list of peers for the next round
this.#roundPeers = Map()
this.#aggregationRound++
this.#connectFinishedNodes = Map()
}
}