Skip to content

Commit

Permalink
Merge pull request #39 from RasmusOrsoe/main
Browse files Browse the repository at this point in the history
Added persistent_workers as argument
  • Loading branch information
RasmusOrsoe committed Oct 22, 2021
2 parents c2cf966 + 499bb93 commit c8b3ba2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/gnn_reco/components/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,18 @@ def _predict(self):
return out


def make_train_validation_dataloader(db, selection, pulsemap, batch_size, FEATURES, TRUTH, num_workers):
def make_train_validation_dataloader(db, selection, pulsemap, batch_size, FEATURES, TRUTH, num_workers, persistent_workers = True):
training_selection, validation_selection = train_test_split(selection, test_size=0.33, random_state=42)

training_dataset = SQLiteDataset(db, pulsemap, FEATURES, TRUTH, selection= training_selection)
training_dataset.close_connection()
training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
collate_fn=Batch.from_data_list,persistent_workers=True,prefetch_factor=2)
collate_fn=Batch.from_data_list,persistent_workers=persistent_workers,prefetch_factor=2)

validation_dataset = SQLiteDataset(db, pulsemap, FEATURES, TRUTH, selection= validation_selection)
validation_dataset.close_connection()
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
collate_fn=Batch.from_data_list,persistent_workers=True,prefetch_factor=2)
collate_fn=Batch.from_data_list,persistent_workers=persistent_workers,prefetch_factor=2)
return training_dataloader, validation_dataloader

def save_results(db, tag, results, archive,model):
Expand Down

0 comments on commit c8b3ba2

Please sign in to comment.