Skip to content

Commit

Permalink
feat(downloadSeries): add parallel download option
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjermiah committed Jan 7, 2024
1 parent e94e356 commit b697aa9
Showing 1 changed file with 90 additions and 3 deletions.
93 changes: 90 additions & 3 deletions src/nbiatoolkit/nbia.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import requests
from requests.exceptions import JSONDecodeError as JSONDecodeError
from typing import Union
import io
import zipfile

Expand Down Expand Up @@ -172,20 +173,65 @@ def getSeries(self,
return response


from tqdm import tqdm

def downloadSeries(
    self,
    SeriesInstanceUID: Union[str, list],
    downloadDir: str = "./NBIA-Download",
    filePattern: str = '%PatientName/%StudyDescription-%StudyDate/%SeriesNumber-%SeriesDescription-%SeriesInstanceUID/%InstanceNumber.dcm',
    overwrite: bool = False,
    nParallel: int = 1
) -> bool:
    """Download one or more series, optionally using several threads.

    Each series is handed to ``self._downloadSingleSeries`` through a
    thread pool and a tqdm progress bar advances as each one finishes.

    Args:
        SeriesInstanceUID: A single series UID, or a list of UIDs.
        downloadDir: Directory passed through to the per-series helper.
        filePattern: File naming pattern passed through to the helper.
        overwrite: Overwrite flag passed through to the helper.
        nParallel: Maximum number of worker threads in the pool.

    Returns:
        True when every series download reported success, False when at
        least one series raised an exception or returned a falsy result.
        An empty list yields True.

    Raises:
        AssertionError: If an argument has the wrong type.
    """
    assert isinstance(SeriesInstanceUID, (str, list)), \
        "SeriesInstanceUID must be a string or list"
    assert isinstance(downloadDir, str), "downloadDir must be a string"
    assert isinstance(filePattern, str), "filePattern must be a string"
    assert isinstance(overwrite, bool), "overwrite must be a boolean"

    import concurrent.futures as cf
    from tqdm import tqdm

    # Normalise the single-UID form so the rest of the code handles a list.
    if isinstance(SeriesInstanceUID, str):
        SeriesInstanceUID = [SeriesInstanceUID]

    allSucceeded = True

    with cf.ThreadPoolExecutor(max_workers=nParallel) as executor:
        # Map each future back to its UID so failures can be reported.
        futures = {
            executor.submit(
                self._downloadSingleSeries,
                SeriesInstanceUID=seriesUID,
                downloadDir=downloadDir,
                filePattern=filePattern,
                overwrite=overwrite): seriesUID
            for seriesUID in SeriesInstanceUID
        }

        # Use tqdm to create a progress bar
        with tqdm(
                total=len(futures),
                desc=f"Downloading {len(futures)} series") as pbar:

            for future in cf.as_completed(futures):
                seriesUID = futures[future]
                try:
                    # result() re-raises anything the worker raised; without
                    # this call a failed download would go unnoticed.
                    succeeded = future.result()
                except Exception as exc:
                    # Best effort: one bad series must not abort the others.
                    self.log.error(
                        "Failed to download series %s: %s", seriesUID, exc)
                    succeeded = False

                # Relies on the helper's declared ``-> bool`` return.
                if not succeeded:
                    allSucceeded = False

                pbar.update(1)

    return allSucceeded



# _downloadSingleSeries is a helper function that downloads a single series.
# It keeps downloadSeries simple and is the unit of work that downloadSeries
# submits to its thread pool for parallel downloads.
def _downloadSingleSeries(
self, SeriesInstanceUID: str, downloadDir: str,
filePattern: str, overwrite: bool) -> bool:

# create temporary directory
from tempfile import TemporaryDirectory

params = dict()
params["SeriesInstanceUID"] = SeriesInstanceUID

self.log.debug("Downloading series: %s", SeriesInstanceUID)
response = self.query_api(
endpoint=NBIA_ENDPOINTS.DOWNLOAD_SERIES,
params=params)
Expand Down Expand Up @@ -226,3 +272,44 @@ def parsePARAMS(self, params: dict) -> dict:
return PARAMS


# main
if __name__ == "__main__":
    # Manual smoke run: fetch a fixed set of series with 8 worker threads
    # and print what ended up in the download directory.
    import os
    from pprint import pprint

    client = NBIAClient(log_level='info')

    # The UIDs below are hard-coded. They could instead be looked up with
    # client.getSeries(Collection="4D-Lung") and reading the
    # 'SeriesInstanceUID' field of each returned entry.
    seriesUIDS = [
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.189721824525842725510380467695',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.336250251691987239290048605884',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.227929163446067537882961857921',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.925990093742075237571072608963',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.139116724721865252687455544825',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.364787732307640672278270360328',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.384197169742944248273003912317',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.149750833495190982103087204448',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.300347070051003027185063750283',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.317831614083862743715273480521',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.736089011729021729851027177073',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.133381852562664457904201355429',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.909088026336573109170906532418',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.953079890279542310843831057254',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.427052348021168186336245283790',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.295010883410722294053941635303',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.263257070197787007872578860295',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.672179203515231442641005032212',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.184961274239908956209701869504',
        '1.3.6.1.4.1.14519.5.2.1.6834.5010.797307942821711099898506950104',
    ]

    # Make sure the target directory exists before downloading into it.
    downloadDir = "./data"
    os.makedirs(downloadDir, exist_ok=True)

    client.downloadSeries(
        SeriesInstanceUID=seriesUIDS,
        downloadDir=downloadDir,
        overwrite=True,
        nParallel=8,
    )

    # Show the top-level entries that the download produced.
    downloadedEntries = os.listdir(downloadDir)
    pprint(downloadedEntries)

0 comments on commit b697aa9

Please sign in to comment.