@@ -17,7 +17,7 @@ use scannerlib::{
    storage::redis::{RedisAddAdvisory, RedisAddNvt, RedisCtx, RedisWrapper},
};
use sqlx::{QueryBuilder, Row, SqlitePool, query, query_scalar};
-use tokio::sync::mpsc::Sender;
+use tokio::{sync::mpsc::Sender, time::MissedTickBehavior};

use crate::{config::Config, crypt::Crypt};
mod nasl;
@@ -27,6 +27,13 @@ struct ScanScheduler<Scanner, Cryptor> {
    cryptor: Arc<Cryptor>,
    scanner: Arc<Scanner>,
    max_concurrent_scan: usize,
+    // Exists to prevent a bug in which scans were accidentally started twice due to deferred
+    // writes in sqlite. To address this, the field keeps track of all requested scans manually
+    // instead of relying on sqlite for this information.
+    //
+    // This is a hack; if a better solution works as reliably, this field can be removed
+    // without functional harm.
+    requested_guard: Arc<RwLock<Vec<i64>>>,
}

#[derive(Debug)]
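The guard addresses a read-skew problem: with sqlite's deferred writes, a scan can still look 'requested' in a stale snapshot after it was already flipped to 'running'. A minimal standalone sketch of the in-process guard pattern (illustrative names, not this crate's API):

use std::sync::{Arc, RwLock};

#[derive(Clone, Default)]
struct StartedGuard(Arc<RwLock<Vec<i64>>>);

impl StartedGuard {
    /// Returns true if the id was newly marked, false if it was already present.
    fn try_mark_started(&self, id: i64) -> bool {
        let mut ids = self.0.write().unwrap();
        if ids.contains(&id) {
            return false;
        }
        ids.push(id);
        true
    }

    fn unmark(&self, id: i64) {
        let mut ids = self.0.write().unwrap();
        if let Some(i) = ids.iter().position(|x| *x == id) {
            // swap_remove is O(1); the order of tracked ids does not matter here.
            ids.swap_remove(i);
        }
    }
}

fn main() {
    let guard = StartedGuard::default();
    assert!(guard.try_mark_started(42)); // first start wins
    assert!(!guard.try_mark_started(42)); // a second attempt is rejected
    guard.unmark(42);
}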
@@ -98,38 +105,6 @@ impl<T, C> ScanScheduler<T, C> {
        Ok(())
    }

-    async fn scan_update_status(&self, id: i64, status: models::Status) -> R<()> {
-        let host_info = status.host_info.unwrap_or_default();
-
-        let row = query(
-            r#"
-            UPDATE scans SET
-                start_time = COALESCE(?, start_time),
-                end_time = COALESCE(?, end_time),
-                host_dead = COALESCE(NULLIF(?, 0), host_dead),
-                host_alive = COALESCE(NULLIF(?, 0), host_alive),
-                host_queued = COALESCE(NULLIF(?, 0), host_queued),
-                host_excluded = COALESCE(NULLIF(?, 0), host_excluded),
-                host_all = COALESCE(NULLIF(?, 0), host_all),
-                status = COALESCE(NULLIF(?, 'stored'), status)
-            WHERE id = ?
-            "#,
-        )
-        .bind(status.start_time.map(|x| x as i64))
-        .bind(status.end_time.map(|x| x as i64))
-        .bind(host_info.dead as i64)
-        .bind(host_info.alive as i64)
-        .bind(host_info.queued as i64)
-        .bind(host_info.excluded as i64)
-        .bind(host_info.all as i64)
-        .bind(status.status.as_ref())
-        .bind(id)
-        .execute(&self.pool)
-        .await?;
-        tracing::debug!(id, rows_affected = row.rows_affected(), status = %status.status, "Set status.");
-        Ok(())
-    }
-
    async fn scan_insert_results(
        &self,
        id: i64,
@@ -207,16 +182,7 @@ where
    Scanner: ScanStarter + ScanStopper + ScanDeleter + ScanResultFetcher + Send + Sync + 'static,
    C: Crypt + Send + Sync + 'static,
{
-    /// Checks for scans that are requested and may start them
-    ///
-    /// After verifying the number of concurrently running scans it starts a scan. If the scan
-    /// was started successfully it is set to 'running'; if the start fails it is set to 'failed'.
-    ///
-    /// In case the ScanStarter implementation blocks, start_scan is spawned as a
-    /// background task.
-    async fn requested_to_running(&self) -> R<()> {
-        let mut tx = self.pool.begin().await?;
-
+    async fn fetch_requested(&self) -> R<Vec<i64>> {
        let ids: Vec<i64> = match self.max_concurrent_scan {
            0 => query_scalar(
                "SELECT id FROM scans WHERE status = 'requested' ORDER BY created_at ASC",
@@ -240,34 +206,81 @@ where
            )
            .bind(m as i64),
        }
-        .fetch_all(&mut *tx)
+        .fetch_all(&self.pool)
        .await?;
+
+        Ok(ids)
+    }
+
+    async fn is_already_started(&self, id: i64) -> bool {
+        let guard = self.requested_guard.clone();
+        tokio::task::spawn_blocking(move || {
+            let cached = guard.read().unwrap();
+            cached.contains(&id)
+        })
+        .await
+        .unwrap()
+    }
+
+    async fn set_to_running(&self, id: i64) -> R<()> {
+        let row = query("UPDATE scans SET status = 'running' WHERE id = ?")
+            .bind(id)
+            .execute(&self.pool)
+            .await?;
+
+        let mut tx = self.pool.begin().await?;
+        let scan = super::get_scan(&mut tx, self.cryptor.as_ref(), id).await?;
+        tx.commit().await?;
+        if self.is_already_started(id).await {
+            tracing::trace!(id, "Has already been started, skipping");
+            return Ok(());
+        }
+
+        tracing::info!(id, running = row.rows_affected(), "Started scan");
+
+        let guard = self.requested_guard.clone();
+        tokio::task::spawn_blocking(move || {
+            let mut x = guard.write().unwrap();
+            x.push(id);
+        })
+        .await
+        .unwrap();
+
+        scan_start(self.pool.clone(), self.scanner.clone(), id, scan).await;
+        Ok(())
+    }
+    /// Checks for scans that are requested and may start them
+    ///
+    /// After verifying the number of concurrently running scans it starts a scan. If the scan
+    /// was started successfully it is set to 'running'; if the start fails it is set to 'failed'.
+    async fn requested_to_running(&self) -> R<()> {
+        let ids = self.fetch_requested().await?;
+
        for id in ids {
-            if !self.scanner.can_start_scan().await {
+            // To prevent an accidental state change from running -> requested based on an old
+            // snapshot, we only do a resource check when a scan has not already been started.
+            if !self.is_already_started(id).await && !self.scanner.can_start_scan().await {
                break;
            }

-            let scan = super::get_scan(&mut tx, self.cryptor.as_ref(), id).await?;
-            let row = query("UPDATE scans SET status = 'running' WHERE id = ?")
-                .bind(id)
-                .execute(&mut *tx)
-                .await?;
-            tracing::info!(id, running = row.rows_affected(), "Started scan");
-            tokio::task::spawn(scan_start(
-                self.pool.clone(),
-                self.scanner.clone(),
-                id,
-                scan,
-            ));
+            self.set_to_running(id).await?;
        }

-        tx.commit().await?;
-
        Ok(())
    }

+    async fn remove_id_from_guard(&self, id: i64) {
+        let cache = self.requested_guard.clone();
+        tokio::task::spawn_blocking(move || {
+            let mut cache = cache.write().unwrap();
+            if let Some(index) = cache.iter().position(|x| x == &id) {
+                cache.swap_remove(index);
+            }
+        });
+    }
+
    async fn scan_import_results(&self, internal_id: i64, scan_id: String) -> R<()> {
-        let mut results = match self.scanner.fetch_results(scan_id).await {
+        let mut results = match self.scanner.fetch_results(scan_id.clone()).await {
            Ok(x) => x,
            Err(scannerlib::scanner::Error::ScanNotFound(scan_id)) => {
                let reason = format!("Tried to get results of an unknown scan ({scan_id})");
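The guard is a std::sync::RwLock, so the reads and writes above go through tokio::task::spawn_blocking to keep the lock off the executor threads. For critical sections this short, an async-aware lock would avoid the spawn round-trip; a sketch of that alternative (an assumption, not what this patch does), keeping the Arc<RwLock<Vec<i64>>> shape but using tokio::sync::RwLock:

use std::sync::Arc;
use tokio::sync::RwLock;

// tokio's RwLock can be awaited directly from async code; no spawn_blocking needed.
async fn is_already_started(guard: Arc<RwLock<Vec<i64>>>, id: i64) -> bool {
    guard.read().await.contains(&id)
}

#[tokio::main]
async fn main() {
    let guard = Arc::new(RwLock::new(vec![7]));
    assert!(is_already_started(guard.clone(), 7).await);
    assert!(!is_already_started(guard, 9).await);
}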
@@ -277,6 +290,7 @@ where
        };

        let kind = self.scanner.scan_result_status_kind();
+
        self.scan_insert_results(internal_id, results.results, &kind)
            .await?;
        let previous_status = super::scan_get_status(&self.pool, internal_id).await?;
@@ -289,7 +303,11 @@ where
            }
        };

-        self.scan_update_status(internal_id, status).await
+        self.scan_update_status(internal_id, &status).await?;
+        if status.is_done() {
+            self.scanner_delete_scan(internal_id, scan_id).await?;
+        }
+        Ok(())
    }

    async fn import_results(&self) -> R<()> {
@@ -311,14 +329,52 @@ where
        Ok(())
    }

+    async fn scan_update_status(&self, id: i64, status: &models::Status) -> R<()> {
+        let host_info = status.host_info.clone().unwrap_or_default();
+
+        let row = query(
+            r#"
+            UPDATE scans SET
+                start_time = COALESCE(?, start_time),
+                end_time = COALESCE(?, end_time),
+                host_dead = COALESCE(NULLIF(?, 0), host_dead),
+                host_alive = COALESCE(NULLIF(?, 0), host_alive),
+                host_queued = COALESCE(NULLIF(?, 0), host_queued),
+                host_excluded = COALESCE(NULLIF(?, 0), host_excluded),
+                host_all = COALESCE(NULLIF(?, 0), host_all),
+                status = COALESCE(NULLIF(?, 'stored'), status)
+            WHERE id = ?
+            "#,
+        )
+        .bind(status.start_time.map(|x| x as i64))
+        .bind(status.end_time.map(|x| x as i64))
+        .bind(host_info.dead as i64)
+        .bind(host_info.alive as i64)
+        .bind(host_info.queued as i64)
+        .bind(host_info.excluded as i64)
+        .bind(host_info.all as i64)
+        .bind(status.status.as_ref())
+        .bind(id)
+        .execute(&self.pool)
+        .await?;
+        tracing::debug!(id, rows_affected = row.rows_affected(), status = %status.status, "Set status.");
+        Ok(())
+    }
+
+    async fn scanner_delete_scan(&self, internal_id: i64, scan_id: String) -> R<()> {
+        self.remove_id_from_guard(internal_id).await;
+        self.scanner.delete_scan(scan_id).await?;
+        Ok(())
+    }
+
    async fn scan_stop(&self, id: i64) -> R<()> {
        let scan_id: String = query_scalar("SELECT scan_id FROM client_scan_map WHERE id = ?")
            .bind(id)
            .fetch_one(&self.pool)
            .await?;
        self.scanner.stop_scan(scan_id.clone()).await?;
        self.scan_import_results(id, scan_id.clone()).await?;
-        self.scanner.delete_scan(scan_id.clone()).await?;
+
        self.scan_running_to_stopped(id).await?;

        Ok(())
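The COALESCE(NULLIF(?, 0), column) idiom in scan_update_status keeps the stored counter whenever the incoming value is 0, so a partial status update cannot zero out host statistics that were already recorded. A standalone sketch (hypothetical table, in-memory database) demonstrating the idiom with sqlx:

use sqlx::{SqlitePool, query, query_scalar};

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    let pool = SqlitePool::connect("sqlite::memory:").await?;
    query("CREATE TABLE scans (id INTEGER PRIMARY KEY, host_alive INTEGER NOT NULL DEFAULT 0)")
        .execute(&pool)
        .await?;
    query("INSERT INTO scans (id, host_alive) VALUES (1, 5)")
        .execute(&pool)
        .await?;

    // Binding 0: NULLIF(0, 0) yields NULL, so COALESCE falls back to the stored value (5).
    query("UPDATE scans SET host_alive = COALESCE(NULLIF(?, 0), host_alive) WHERE id = 1")
        .bind(0_i64)
        .execute(&pool)
        .await?;
    let alive: i64 = query_scalar("SELECT host_alive FROM scans WHERE id = 1")
        .fetch_one(&pool)
        .await?;
    assert_eq!(alive, 5); // the zero did not overwrite the stored counter
    Ok(())
}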
@@ -353,6 +409,11 @@ where
        tracing::warn!(%error, "Unable to set not stopped runs from a previous session to failed.")
    }
    let mut interval = tokio::time::interval(check_interval);
+    // The default missed-tick behavior is Burst: when ticks were missed, the interval fires
+    // immediately for each missed tick instead of waiting for the next period, which would
+    // call scheduler.on_schedule right away. On a missed tick we would rather wait a full
+    // interval before checking again.
+    interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
    let (sender, mut recv) = tokio::sync::mpsc::channel(10);
    tokio::spawn(async move {
        loop {
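A small self-contained sketch (assumed timings, illustrative only) of what MissedTickBehavior::Delay changes compared to the default Burst behavior:

use std::time::Duration;
use tokio::time::{self, MissedTickBehavior};

#[tokio::main]
async fn main() {
    let mut interval = time::interval(Duration::from_millis(100));
    interval.set_missed_tick_behavior(MissedTickBehavior::Delay);

    for _ in 0..3 {
        interval.tick().await;
        // Simulate work that overruns the period. With Burst (the default) the next
        // tick() would return immediately to catch up; with Delay it waits a full
        // period again, so the scheduler is never invoked back-to-back.
        time::sleep(Duration::from_millis(250)).await;
    }
}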
@@ -397,6 +458,7 @@ where
        cryptor: crypter,
        max_concurrent_scan: config.scheduler.max_queued_scans.unwrap_or(0),
        scanner: Arc::new(scanner),
+        requested_guard: Arc::new(RwLock::new(vec![])),
    };

    run_scheduler(config.scheduler.check_interval, scheduler).await
@@ -610,6 +672,7 @@ pub(crate) mod tests {
            scanner,
            cryptor,
            max_concurrent_scan: 4,
+            requested_guard: Arc::new(RwLock::new(vec![])),
        };
        let known_scans = prepare_scans(pool.clone(), &config).await;
        Ok((under_test, known_scans))