Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 139 additions & 34 deletions cake-core/src/cake/sharding/worker.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#[cfg(unix)]
use std::os::fd::AsRawFd;
use std::{
collections::HashMap,
io::ErrorKind,
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};

use crate::cake::{Context, Forwarder};
use super::{Message, WorkerInfo};
use crate::cake::{Context, Forwarder};
use crate::models::Generator;

use anyhow::Result;
Expand Down Expand Up @@ -77,11 +80,45 @@ impl<F: Forwarder> WorkerContext<F> {

/// Cake worker node.
pub struct Worker<G: Generator> {
listener: TcpListener,
listener: Option<TcpListener>,
context: WorkerContext<G::Shardable>,
}

impl<G: Generator + 'static> Worker<G> {
/// Build a human-readable description of a listener for log output.
///
/// The description contains the local socket address (or `<unknown>`
/// when it cannot be resolved) and, on Unix platforms, the raw file
/// descriptor number to help correlate with `lsof`/`ss` output.
fn describe_listener(listener: &TcpListener) -> String {
    let addr = match listener.local_addr() {
        Ok(local) => local.to_string(),
        Err(_) => String::from("<unknown>"),
    };
    #[cfg(unix)]
    {
        let fd = listener.as_raw_fd();
        format!("{addr} fd={fd}")
    }
    #[cfg(not(unix))]
    {
        addr
    }
}

/// Classify an `accept` error: `true` means the error is per-connection
/// (transient) and the accept loop should simply retry on the same
/// listener; `false` means the caller treats it as a listener-level
/// failure (and may drop/rebind the socket).
///
/// Transient cases:
/// - `ConnectionAborted` / `ConnectionReset`: the peer gave up during
///   the handshake — only that one connection is affected. On several
///   platforms a peer reset is reported as `ConnectionReset` from
///   `accept`, which the previous version wrongly treated as fatal,
///   causing an unnecessary listener rebind.
/// - `Interrupted`: a signal interrupted the syscall.
/// - `TimedOut` / `WouldBlock`: spurious readiness or timeout; retry.
fn accept_error_is_transient(kind: ErrorKind) -> bool {
    matches!(
        kind,
        ErrorKind::ConnectionAborted
            | ErrorKind::ConnectionReset
            | ErrorKind::Interrupted
            | ErrorKind::TimedOut
            | ErrorKind::WouldBlock
    )
}

async fn bind_listener(bind_address: &str) -> Result<TcpListener> {
let listener = TcpListener::bind(bind_address).await?;
log::info!(
"worker listener bound on {}",
Self::describe_listener(&listener)
);
Ok(listener)
}

/// Detect how many CUDA devices are available.
fn detect_cuda_device_count() -> usize {
#[cfg(feature = "cuda")]
Expand Down Expand Up @@ -199,17 +236,14 @@ impl<G: Generator + 'static> Worker<G> {
.context()
.bind_to_thread()
.map_err(|e| {
anyhow!(
"failed to bind CUDA context for GPU {gpu_idx}: {e:?}"
)
anyhow!("failed to bind CUDA context for GPU {gpu_idx}: {e:?}")
})?;
}

let mut results = Vec::new();
for layer_name in layers {
log::info!("loading {} on cuda:{} ...", &layer_name, gpu_idx);
let block =
G::Shardable::load(layer_name.clone(), &thread_ctx)?;
let block = G::Shardable::load(layer_name.clone(), &thread_ctx)?;
results.push((layer_name, dev.clone(), block));
}
Ok(results)
Expand Down Expand Up @@ -243,16 +277,24 @@ impl<G: Generator + 'static> Worker<G> {
let listener = {
let taken = ctx.listener_override.lock().unwrap().take();
if let Some(existing) = taken {
log::info!(
"using pre-bound worker listener {}",
Self::describe_listener(&existing)
);
existing
} else {
TcpListener::bind(&ctx.args.address).await?
Self::bind_listener(&ctx.args.address).await?
}
};

log::info!(
"listening on {} (mem:{}) ...",
&ctx.args.address,
human_bytes::human_bytes(memory_stats::memory_stats().map(|m| m.physical_mem).unwrap_or(0) as f64)
human_bytes::human_bytes(
memory_stats::memory_stats()
.map(|m| m.physical_mem)
.unwrap_or(0) as f64
)
);

let device = ctx.device.clone();
Expand All @@ -268,7 +310,10 @@ impl<G: Generator + 'static> Worker<G> {
context: ctx.clone(),
};

Ok(Self { listener, context })
Ok(Self {
listener: Some(listener),
context,
})
}

/// Read a message from the socket and return elapsed time, message size and message.
Expand Down Expand Up @@ -355,12 +400,19 @@ impl<G: Generator + 'static> Worker<G> {
let mut write_buf = Vec::with_capacity(64 * 1024);

// keep reading messages
while let Ok((read_time, read_size, op_message)) = {
let start = Instant::now();
Message::from_reader_buf(&mut socket, &mut read_buf)
.await
.map(|(size, msg)| (start.elapsed(), size, msg))
} {
loop {
let (read_time, read_size, op_message) = match {
let start = Instant::now();
Message::from_reader_buf(&mut socket, &mut read_buf)
.await
.map(|(size, msg)| (start.elapsed(), size, msg))
} {
Ok(result) => result,
Err(e) => {
log::info!("[{}] connection loop ended: {}", &client, e);
break;
}
};
if matches!(op_message, Message::Goodbye) {
log::debug!("[{}] goodbye", &client);
context
Expand Down Expand Up @@ -574,24 +626,67 @@ impl<G: Generator + 'static> Worker<G> {
msg_idx += 1;
}

log::info!("[{}] handler exiting", &client);
Ok(())
}

/// Run the worker server accept loop.
pub async fn run(&mut self) -> Result<()> {
while let Ok((socket, client)) = self.listener.accept().await {
let _ = socket.set_nodelay(true);
log::debug!("{} connected", &client);

let context = self.context.get_client_context();
tokio::spawn(async move {
if let Err(e) = Self::handle_master_client(socket, client, context).await {
log::error!("{}", e);
loop {
let listener_desc = self
.listener
.as_ref()
.map(Self::describe_listener)
.unwrap_or_else(|| "<missing listener>".to_string());
log::info!("worker accept loop awaiting master on {}", listener_desc);

let accept_result = match self.listener.as_mut() {
Some(listener) => listener.accept().await,
None => {
let bind_address = self.context.context.args.address.clone();
log::warn!(
"worker listener missing before accept; rebinding on {}",
bind_address
);
self.listener = Some(Self::bind_listener(&bind_address).await?);
continue;
}
});
}
};

Ok(())
match accept_result {
Ok((socket, client)) => {
let _ = socket.set_nodelay(true);
log::info!("[{}] accepted on {}", &client, listener_desc);

let context = self.context.get_client_context();
tokio::spawn(async move {
if let Err(e) = Self::handle_master_client(socket, client, context).await {
log::error!("{}", e);
}
});
}
Err(e) if Self::accept_error_is_transient(e.kind()) => {
log::warn!(
"transient accept error on {}: {} ({:?})",
listener_desc,
e,
e.kind()
);
}
Err(e) => {
let bind_address = self.context.context.args.address.clone();
log::error!(
"accept failed on {}: {} ({:?}); dropping listener and rebinding {}",
listener_desc,
e,
e.kind(),
bind_address
);
self.listener.take();
self.listener = Some(Self::bind_listener(&bind_address).await?);
}
}
}
}
}

Expand Down Expand Up @@ -742,7 +837,10 @@ mod tests {
// New context should have a fresh cache (as_new clears KV entries)
assert!(client_ctx.context.cache.is_some());
// Device and dtype should be copied
assert_eq!(format!("{:?}", client_ctx.device), format!("{:?}", Device::Cpu));
assert_eq!(
format!("{:?}", client_ctx.device),
format!("{:?}", Device::Cpu)
);
assert_eq!(client_ctx.dtype, DType::F32);
}

Expand All @@ -760,8 +858,9 @@ mod tests {
let (mut server, mut client) = duplex(65536);

// Write from server side
let (write_dur, write_size) =
<Worker<SD>>::write_message_timed(&mut server, msg).await.unwrap();
let (write_dur, write_size) = <Worker<SD>>::write_message_timed(&mut server, msg)
.await
.unwrap();
assert!(write_size > 0);
assert!(write_dur.as_nanos() > 0);

Expand All @@ -773,7 +872,12 @@ mod tests {

// Verify the message was correctly serialized/deserialized
match read_msg {
Message::SingleOp { layer_name, x, index_pos, block_idx } => {
Message::SingleOp {
layer_name,
x,
index_pos,
block_idx,
} => {
assert_eq!(layer_name, "test_layer");
assert_eq!(index_pos, 0);
assert_eq!(block_idx, 0);
Expand All @@ -797,10 +901,11 @@ mod tests {
let msg = Message::from_batch(&tensor, batch);

let (mut server, mut client) = duplex(65536);
<Worker<SD>>::write_message_timed(&mut server, msg).await.unwrap();
<Worker<SD>>::write_message_timed(&mut server, msg)
.await
.unwrap();

let (_dur, _size, read_msg) =
<Worker<SD>>::read_message_timed(&mut client).await.unwrap();
let (_dur, _size, read_msg) = <Worker<SD>>::read_message_timed(&mut client).await.unwrap();
match read_msg {
Message::Batch { x, batch } => {
let t = x.to_tensor(&Device::Cpu).unwrap();
Expand Down