Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fix availability checker #1574

Merged
merged 14 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/lib/config/src/configs/fri_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct FriProverConfig {
pub queue_capacity: usize,
pub witness_vector_receiver_port: u16,
pub zone_read_url: String,
pub availability_check_interval_in_secs: u32,
pub availability_check_interval_in_secs: Option<u32>,

// whether to write to public GCS bucket for https://github.com/matter-labs/era-boojum-validator-cli
pub shall_save_to_public_bucket: bool,
Expand Down
2 changes: 1 addition & 1 deletion core/lib/env_config/src/fri_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mod tests {
},
max_retries: 5,
}),
availability_check_interval_in_secs: 1_800,
availability_check_interval_in_secs: Some(1_800),
}
}

Expand Down
2 changes: 1 addition & 1 deletion core/lib/protobuf_config/src/proto/prover.proto
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ message Prover {
optional uint64 queue_capacity = 10; // required
optional uint32 witness_vector_receiver_port = 11; // required; u16
optional string zone_read_url = 12; // required
optional uint32 availability_check_interval_in_secs = 21; // required; s
optional uint32 availability_check_interval_in_secs = 21; // optional; s
optional bool shall_save_to_public_bucket = 13; // required
optional config.object_store.ObjectStore object_store = 20;
reserved 5, 6, 9; reserved "base_layer_circuit_ids_to_be_verified", "recursive_layer_circuit_ids_to_be_verified", "witness_vector_generator_thread_count";
Expand Down
7 changes: 2 additions & 5 deletions core/lib/protobuf_config/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,7 @@ impl ProtoRepr for proto::Prover {
zone_read_url: required(&self.zone_read_url)
.context("zone_read_url")?
.clone(),
availability_check_interval_in_secs: *required(
&self.availability_check_interval_in_secs,
)
.context("availability_check_interval_in_secs")?,
availability_check_interval_in_secs: self.availability_check_interval_in_secs,
shall_save_to_public_bucket: *required(&self.shall_save_to_public_bucket)
.context("shall_save_to_public_bucket")?,
object_store,
Expand All @@ -341,7 +338,7 @@ impl ProtoRepr for proto::Prover {
queue_capacity: Some(this.queue_capacity.try_into().unwrap()),
witness_vector_receiver_port: Some(this.witness_vector_receiver_port.into()),
zone_read_url: Some(this.zone_read_url.clone()),
availability_check_interval_in_secs: Some(this.availability_check_interval_in_secs),
availability_check_interval_in_secs: this.availability_check_interval_in_secs,
shall_save_to_public_bucket: Some(this.shall_save_to_public_bucket),
object_store: this.object_store.as_ref().map(ProtoRepr::build),
}
Expand Down
6 changes: 5 additions & 1 deletion prover/prover_fri/src/gpu_prover_availability_checker.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#[cfg(feature = "gpu")]
pub mod availability_checker {
use std::time::Duration;
use std::{sync::Arc, time::Duration};

use prover_dal::{ConnectionPool, Prover, ProverDal};
use tokio::sync::Notify;
use zksync_types::prover_dal::{GpuProverInstanceStatus, SocketAddress};

use crate::metrics::{KillingReason, METRICS};
Expand Down Expand Up @@ -34,7 +35,10 @@ pub mod availability_checker {
pub async fn run(
self,
stop_receiver: tokio::sync::watch::Receiver<bool>,
init_notifier: Arc<Notify>,
) -> anyhow::Result<()> {
init_notifier.notified().await;

while !*stop_receiver.borrow() {
let status = self
.pool
Expand Down
47 changes: 34 additions & 13 deletions prover/prover_fri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use local_ip_address::local_ip;
use prometheus_exporter::PrometheusExporterConfig;
use prover_dal::{ConnectionPool, Prover, ProverDal};
use tokio::{
sync::{oneshot, watch::Receiver},
sync::{oneshot, watch::Receiver, Notify},
Artemka374 marked this conversation as resolved.
Show resolved Hide resolved
task::JoinHandle,
};
use zksync_config::configs::{
Expand Down Expand Up @@ -142,18 +142,23 @@ async fn main() -> anyhow::Result<()> {
.await
.context("failed to build a connection pool")?;
let port = prover_config.witness_vector_receiver_port;

let notify = Arc::new(Notify::new());

let prover_tasks = get_prover_tasks(
prover_config,
stop_receiver.clone(),
object_store_factory,
public_blob_store,
pool,
circuit_ids_for_round_to_be_proven,
notify,
)
.await
.context("get_prover_tasks()")?;

let mut tasks = vec![tokio::spawn(exporter_config.run(stop_receiver))];

tasks.extend(prover_tasks);

let mut tasks = ManagedTasks::new(tasks);
Expand All @@ -176,6 +181,7 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}

#[allow(clippy::too_many_arguments)]
#[cfg(not(feature = "gpu"))]
async fn get_prover_tasks(
prover_config: FriProverConfig,
Expand All @@ -184,6 +190,7 @@ async fn get_prover_tasks(
public_blob_store: Option<Arc<dyn ObjectStore>>,
pool: ConnectionPool<Prover>,
circuit_ids_for_round_to_be_proven: Vec<CircuitIdRoundTuple>,
_init_notifier: Arc<Notify>,
) -> anyhow::Result<Vec<JoinHandle<anyhow::Result<()>>>> {
use zksync_vk_setup_data_server_fri::commitment_utils::get_cached_commitments;

Expand All @@ -210,6 +217,7 @@ async fn get_prover_tasks(
Ok(vec![tokio::spawn(prover.run(stop_receiver, None))])
}

#[allow(clippy::too_many_arguments)]
#[cfg(feature = "gpu")]
async fn get_prover_tasks(
prover_config: FriProverConfig,
Expand All @@ -218,6 +226,7 @@ async fn get_prover_tasks(
public_blob_store: Option<Arc<dyn ObjectStore>>,
pool: ConnectionPool<Prover>,
circuit_ids_for_round_to_be_proven: Vec<CircuitIdRoundTuple>,
init_notifier: Arc<Notify>,
) -> anyhow::Result<Vec<JoinHandle<anyhow::Result<()>>>> {
use gpu_prover_job_processor::gpu_prover;
use socket_listener::gpu_socket_listener;
Expand Down Expand Up @@ -263,17 +272,29 @@ async fn get_prover_tasks(
prover_config.specialized_group_id,
zone.clone(),
);
let availability_checker =
gpu_prover_availability_checker::availability_checker::AvailabilityChecker::new(
address,
zone,
prover_config.availability_check_interval_in_secs,
pool,
);

Ok(vec![
tokio::spawn(socket_listener.listen_incoming_connections(stop_receiver.clone())),

let mut tasks = vec![
tokio::spawn(
socket_listener
.listen_incoming_connections(stop_receiver.clone(), init_notifier.clone()),
),
tokio::spawn(prover.run(stop_receiver.clone(), None)),
tokio::spawn(availability_checker.run(stop_receiver.clone())),
])
];

// TODO(PLA-874): remove the check after making the availability checker required
if let Some(check_interval) = prover_config.availability_check_interval_in_secs {
let availability_checker =
gpu_prover_availability_checker::availability_checker::AvailabilityChecker::new(
address,
zone,
check_interval,
pool,
);

tasks.push(tokio::spawn(
availability_checker.run(stop_receiver.clone(), init_notifier),
));
}

Ok(tasks)
}
10 changes: 6 additions & 4 deletions prover/prover_fri/src/socket_listener.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#[cfg(feature = "gpu")]
pub mod gpu_socket_listener {
use std::{net::SocketAddr, time::Instant};
use std::{net::SocketAddr, sync::Arc, time::Instant};

use anyhow::Context as _;
use prover_dal::{ConnectionPool, Prover, ProverDal};
use tokio::{
io::copy,
net::{TcpListener, TcpStream},
sync::watch,
sync::{watch, Notify},
};
use zksync_object_store::bincode;
use zksync_prover_fri_types::WitnessVectorArtifacts;
Expand Down Expand Up @@ -42,7 +42,7 @@ pub mod gpu_socket_listener {
zone,
}
}
async fn init(&self) -> anyhow::Result<TcpListener> {
async fn init(&self, init_notifier: Arc<Notify>) -> anyhow::Result<TcpListener> {
let listening_address = SocketAddr::new(self.address.host, self.address.port);
tracing::info!(
"Starting assembly receiver at host: {}, port: {}",
Expand All @@ -65,14 +65,16 @@ pub mod gpu_socket_listener {
self.zone.clone(),
)
.await;
init_notifier.notify_one();
Ok(listener)
}

pub async fn listen_incoming_connections(
self,
stop_receiver: watch::Receiver<bool>,
init_notifier: Arc<Notify>,
) -> anyhow::Result<()> {
let listener = self.init().await.context("init()")?;
let listener = self.init(init_notifier).await.context("init()")?;
let mut now = Instant::now();
loop {
if *stop_receiver.borrow() {
Expand Down
Loading