diff --git a/src/gnn_reco/components/utils.py b/src/gnn_reco/components/utils.py index 0be8ad8f6..45caaf974 100644 --- a/src/gnn_reco/components/utils.py +++ b/src/gnn_reco/components/utils.py @@ -261,18 +261,18 @@ def _predict(self): return out -def make_train_validation_dataloader(db, selection, pulsemap, batch_size, FEATURES, TRUTH, num_workers): +def make_train_validation_dataloader(db, selection, pulsemap, batch_size, FEATURES, TRUTH, num_workers, persistent_workers = True): training_selection, validation_selection = train_test_split(selection, test_size=0.33, random_state=42) training_dataset = SQLiteDataset(db, pulsemap, FEATURES, TRUTH, selection= training_selection) training_dataset.close_connection() training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, - collate_fn=Batch.from_data_list,persistent_workers=True,prefetch_factor=2) + collate_fn=Batch.from_data_list,persistent_workers=persistent_workers,prefetch_factor=2) validation_dataset = SQLiteDataset(db, pulsemap, FEATURES, TRUTH, selection= validation_selection) validation_dataset.close_connection() validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, - collate_fn=Batch.from_data_list,persistent_workers=True,prefetch_factor=2) + collate_fn=Batch.from_data_list,persistent_workers=persistent_workers,prefetch_factor=2) return training_dataloader, validation_dataloader def save_results(db, tag, results, archive,model):