Skip to content

Commit

Permalink
more robust grpc discovery with asyncio and proper error handling, ad…
Browse files Browse the repository at this point in the history
…d flops to device capabilities. fixes #23 and progress on #33
  • Loading branch information
AlexCheema committed Jul 19, 2024
1 parent fa9d416 commit 54c9860
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 66 deletions.
4 changes: 2 additions & 2 deletions examples/llama3_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from typing import List
import asyncio
import argparse
Expand All @@ -32,7 +32,7 @@
peer2 = GRPCPeerHandle(
"node2",
"localhost:8081",
DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
)
shard = models[path_or_hf_repo]
request_id = str(uuid.uuid4())
Expand Down
91 changes: 50 additions & 41 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,28 @@
import json
import socket
import time
from typing import List, Dict
from typing import List, Dict, Callable, Tuple, Coroutine
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo import DEBUG_DISCOVERY

class ListenProtocol(asyncio.DatagramProtocol):
def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
super().__init__()
self.on_message = on_message
self.loop = asyncio.get_event_loop()

def connection_made(self, transport):
self.transport = transport

def datagram_received(self, data, addr):
asyncio.create_task(self.on_message(data, addr))


class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
self.node_id = node_id
self.node_port = node_port
self.device_capabilities = device_capabilities
Expand All @@ -24,9 +37,10 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.cleanup_task = None

async def start(self):
self.broadcast_task = asyncio.create_task(self._broadcast_presence())
self.listen_task = asyncio.create_task(self._listen_for_peers())
self.cleanup_task = asyncio.create_task(self._cleanup_peers())
self.device_capabilities = device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
self.listen_task = asyncio.create_task(self.task_listen_for_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())

async def stop(self):
if self.broadcast_task:
Expand Down Expand Up @@ -62,54 +76,49 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:

return list(self.known_peers.values())

async def _broadcast_presence(self):
if not self.device_capabilities:
self.device_capabilities = device_capabilities()

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=('0.0.0.0', 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.settimeout(0.5)

message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": {
"model": self.device_capabilities.model,
"chip": self.device_capabilities.chip,
"memory": self.device_capabilities.memory
}
"device_capabilities": self.device_capabilities.to_dict()
}).encode('utf-8')

while True:
sock.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)

async def _listen_for_peers(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('', self.listen_port))
sock.setblocking(False)

while True:
try:
data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
message = json.loads(data.decode('utf-8'))
if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
if message['type'] == 'discovery' and message['node_id'] != self.node_id:
peer_id = message['node_id']
peer_host = addr[0]
peer_port = message['grpc_port']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.peer_last_seen[peer_id] = time.time()
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
transport.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
except Exception as e:
print(f"Error in peer discovery: {e}")
print(f"Error in broadcast presence: {e}")
import traceback
print(traceback.format_exc())
await asyncio.sleep(self.broadcast_interval / 2)

async def _cleanup_peers(self):
async def on_listen_message(self, data, addr):
message = json.loads(data.decode('utf-8'))
if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
if message['type'] == 'discovery' and message['node_id'] != self.node_id:
peer_id = message['node_id']
peer_host = addr[0]
peer_port = message['grpc_port']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.peer_last_seen[peer_id] = time.time()

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
if DEBUG_DISCOVERY >= 2: print("Started listen task")

async def task_cleanup_peers(self):
while True:
current_time = time.time()
timeout = 15 * self.broadcast_interval
Expand Down
2 changes: 1 addition & 1 deletion exo/networking/grpc/grpc_peer_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
response = await self.stub.CollectTopology(request)
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory)
device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
topology.update_node(node_id, device_capabilities)
for node_id, peers in response.peer_graph.items():
for peer_id in peers.peer_ids:
Expand Down
2 changes: 1 addition & 1 deletion exo/networking/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
topology = await self.node.collect_topology(visited, max_depth)
nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory, flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8)) for node_id, cap in topology.nodes.items()}
peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
Expand Down
7 changes: 7 additions & 0 deletions exo/networking/grpc/node_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ message Peers {
repeated string peer_ids = 1;
}

message DeviceFlops {
float fp32 = 1;
float fp16 = 2;
float int8 = 3;
}

message DeviceCapabilities {
string model = 1;
string chip = 2;
int32 memory = 3;
DeviceFlops flops = 4;
}

message SendResultRequest {
Expand Down
20 changes: 11 additions & 9 deletions exo/networking/grpc/node_service_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions exo/orchestration/standard_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,5 @@ async def send_result_to_peer(peer):
import traceback
traceback.print_exc()

print(f"Broadcast result: {request_id=} {result=} {is_finished=}")
await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
Loading

0 comments on commit 54c9860

Please sign in to comment.