
Commit

feat: Add new CLI commands and improve error handling in DICOMSorter and NBIAClient
jjjermiah committed Jan 30, 2024
1 parent 5d9e2d7 commit 83fd4de
Showing 4 changed files with 216 additions and 39 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -12,7 +12,8 @@ NBIAToolkit = "nbiatoolkit:version"
getCollections = "nbiatoolkit.nbia_cli:getCollections_cli"
getPatients = "nbiatoolkit.nbia_cli:getPatients_cli"
getBodyPartCounts = "nbiatoolkit.nbia_cli:getBodyPartCounts_cli"

downloadSingleSeries = "nbiatoolkit.nbia_cli:downloadSingleSeries_cli"
getSeries = "nbiatoolkit.nbia_cli:getSeries_cli"

[tool.poetry.dependencies]
python = ">=3.11 || 3.12"
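Assuming the entries above sit under Poetry's script table (the [tool.poetry.scripts] header itself is above the visible hunk), each line maps a shell command to a Python callable, so this change installs two new commands, downloadSingleSeries and getSeries, that call the matching *_cli functions. A rough sketch of what the generated wrapper script does; the file layout is illustrative, not part of this commit:

# roughly what the console script generated for
# downloadSingleSeries = "nbiatoolkit.nbia_cli:downloadSingleSeries_cli" does
import sys
from nbiatoolkit.nbia_cli import downloadSingleSeries_cli

if __name__ == "__main__":
    sys.exit(downloadSingleSeries_cli())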
2 changes: 1 addition & 1 deletion src/nbiatoolkit/dicomsort/dicomsort.py
@@ -81,7 +81,7 @@ def sortSingleDICOMFile(
     if os.path.exists(targetFilename) and not overwrite:
         print(f"Source File: {filePath}\n")
         print(f"File {targetFilename} already exists. ")
-        sys.exit(
+        raise ValueError(
             "Pattern is probably not unique or overwrite is set to False. Exiting."
         )

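Replacing sys.exit() with raise ValueError means a filename collision no longer terminates the whole interpreter; calling code can catch the error and keep going. A minimal sketch of that idea, where the import path and call signature are assumed rather than taken from this diff:

from nbiatoolkit.dicomsort.dicomsort import sortSingleDICOMFile  # assumed import path

file_path = "input/scan0001.dcm"                       # example values only
target_pattern = "%PatientID/%SeriesInstanceUID.dcm"

try:
    sortSingleDICOMFile(file_path, target_pattern, overwrite=False)  # hypothetical signature
except ValueError as err:
    # raised when the target file already exists and overwrite is False
    print(f"Skipping {file_path}: {err}")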
38 changes: 23 additions & 15 deletions src/nbiatoolkit/nbia.py
@@ -1,3 +1,4 @@
+from math import e
 from .auth import OAuth2
 from .utils.nbia_endpoints import NBIA_ENDPOINTS
 from .logger.logger import setup_logger
@@ -265,23 +266,30 @@ def downloadSeries(

         with cf.ThreadPoolExecutor(max_workers=nParallel) as executor:
             futures = []
-            for seriesUID in SeriesInstanceUID:
-                future = executor.submit(
-                    self._downloadSingleSeries,
-                    SeriesInstanceUID=seriesUID,
-                    downloadDir=downloadDir,
-                    filePattern=filePattern,
-                    overwrite=overwrite,
-                )
-                futures.append(future)

-            # Use tqdm to create a progress bar
-            with tqdm(
-                total=len(futures), desc=f"Downloading {len(futures)} series"
-            ) as pbar:
-                for future in cf.as_completed(futures):
-                    pbar.update(1)
+            try:
+                os.makedirs(downloadDir)
+
+                for seriesUID in SeriesInstanceUID:
+                    future = executor.submit(
+                        self._downloadSingleSeries,
+                        SeriesInstanceUID=seriesUID,
+                        downloadDir=downloadDir,
+                        filePattern=filePattern,
+                        overwrite=overwrite,
+                    )
+                    futures.append(future)
+
+            except Exception as e:
+                self.log.error("Error creating download directory: %s", e)
+                raise e
+            else:
+                # Use tqdm to create a progress bar
+                with tqdm(
+                    total=len(futures), desc=f"Downloading {len(futures)} series"
+                ) as pbar:
+                    for future in cf.as_completed(futures):
+                        pbar.update(1)
         return True

# _downloadSingleSeries is a helper function that downloads a single series
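The reworked downloadSeries wraps directory creation and job submission in try/except and only runs the progress bar in the else branch, so a failure to create downloadDir is logged and re-raised instead of silently producing an empty progress bar. A standalone sketch of that shape, with stand-in function names rather than the toolkit's API; note that os.makedirs without exist_ok=True raises FileExistsError if the directory already exists:

import concurrent.futures as cf
import os
import time

from tqdm import tqdm

def fake_download(series_uid: str, download_dir: str) -> bool:
    # stand-in for the real per-series HTTP download
    time.sleep(0.1)
    return True

def download_all(series_uids: list[str], download_dir: str) -> bool:
    with cf.ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        try:
            os.makedirs(download_dir)  # raises if download_dir already exists
            for uid in series_uids:
                futures.append(executor.submit(fake_download, uid, download_dir))
        except Exception as err:
            print(f"Error creating download directory: {err}")
            raise
        else:
            # the progress bar only runs when every submission succeeded
            with tqdm(total=len(futures), desc=f"Downloading {len(futures)} series") as pbar:
                for _ in cf.as_completed(futures):
                    pbar.update(1)
    return True

download_all(["1.2.840.1", "1.2.840.2"], "downloads/tmp_series")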
212 changes: 190 additions & 22 deletions src/nbiatoolkit/nbia_cli.py
@@ -10,20 +10,52 @@
import time, threading
from pprint import pprint
from pyfiglet import Figlet
import sys

# import from typing libraries
from typing import List, Dict, Tuple, Union, Optional, Any

done_event = threading.Event()
query: str

output: str | None = None

def version():
    f = Figlet(font="slant")
    print(f.renderText("NBIAToolkit"))
    print("Version: {}".format(__version__))
    return


# create a helper function that will be used if the user ever uses --output <FILE>.tsv
def writeResultsToFile(results: List, output: str) -> None:
    """
    Writes the results of a query to a file.
    Args:
        results: The results of the query.
        output: The path to the output file.
    Returns:
        None
    """

    # open the file in write mode
    with open(output, "w") as f:

        if isinstance(results[0], dict):
            # write the header
            f.write("\t".join(results[0].keys()) + "\n")
            # write the results
            for result in results:
                f.write("\t".join(str(value) for value in result.values()) + "\n")
        else:
            # write the results
            for result in results:
                f.write(str(result) + "\n")

    return
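For illustration only (example data, not from the commit), this is how the helper above lays out results passed as a list of dicts:

rows = [
    {"PatientID": "TCGA-01", "Modality": "CT"},
    {"PatientID": "TCGA-02", "Modality": "MR"},
]
writeResultsToFile(rows, "patients.tsv")
# patients.tsv now holds a tab-separated header row followed by one line per dict;
# a plain list (e.g. collection names) would be written one item per line instead.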


# An abstraction of the getCollections and getPatients functions
# to generalize an interface for the CLI
def getResults_cli(func, **kwargs) -> None:
@@ -39,18 +71,39 @@ def getResults_cli(func, **kwargs) -> None:
    """

    global query
    global output

    # Execute the function
    results = cli_wrapper(func=func, **kwargs)
    if results == True:
        return

    # If the user specified an output file, write the results to the file
    if output and isinstance(results, list) and len(results):
        writeResultsToFile(results, output)
        return
    elif not isinstance(results, list) or not len(results):

        return
    # Print the result
    if isinstance(results, list) and len(results):
        [print(patient) for patient in results]
    elif isinstance(results[0], dict) and len(results):
        print("\t".join(results[0].keys()))
        for result in results:
            print("\t".join(str(value) for value in result.values()))  # type: ignore
        return

    elif(isinstance(results, list)):
        for result in results:
            print(result)
        return


    print(f"No {query} found. Check parameters using -h or try again later.")
    return
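This wrapper is what lets every subcommand share one code path: pass it a bound client method plus its keyword arguments and it handles the spinner, the optional --output file, and printing. For instance, getCollections_cli further below boils down to a call like this (the prefix value is just an example):

query = "collections"          # used in the "No {query} found" message
getResults_cli(func=NBIAClient().getCollections, prefix="TCGA")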




def cli_wrapper(func, **kwargs) -> List[str] | None:
"""
Wraps a function call with a loading animation.
Expand All @@ -68,6 +121,9 @@ def cli_wrapper(func, **kwargs) -> List[str] | None:

# Start the loading animation in a separate thread
loading_thread = threading.Thread(target=loading_animation)

# daemon threads are killed when the main thread exits
loading_thread.daemon = True
loading_thread.start()

# Perform the database query in the main thread
@@ -97,52 +153,58 @@ def loading_animation():
    ]

    # Find the maximum length of the loading animation strings
    max_length = max(len(animation) for animation in animations)
    # print(animations[0])
    while not done_event.is_set():

        if done_event.is_set():
            # clear the line
            print(" " * max_length*2, end="\r", flush=True)
            # sys.stdout.write("\033[F")
            # print( len(animations[0]))
            # print(" " * len(animations[0]), end="\r")
            break

        for animation in animations:
            # Pad the animation string with spaces to the maximum length
            padded_animation = animation.ljust(max_length).rstrip("\n")

            print(padded_animation, end="\r", flush=True)
            time.sleep(0.5)
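The spinner relies on two things shown above: the thread is a daemon, so it cannot keep the process alive if the main thread exits early, and a shared threading.Event tells it when to stop and clear the line. A simplified, self-contained sketch of that pattern (not the toolkit's exact code):

import threading
import time

stop = threading.Event()

def spinner() -> None:
    frames = ["|", "/", "-", "\\"]
    while not stop.is_set():
        for frame in frames:
            print(frame, end="\r", flush=True)
            time.sleep(0.1)
    print(" ", end="\r", flush=True)  # overwrite the last frame

t = threading.Thread(target=spinner, daemon=True)
t.start()
time.sleep(1.0)  # stand-in for the real NBIA query
stop.set()
t.join()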


def getPatients_cli() -> None:
    p = argparse.ArgumentParser(description="NBIAToolkit: get patient names")
    global query
    global output
    query = "patients"
    p = general_argParser()

    p.add_argument(
        "--collection", dest="collection", action="store", required=True, type=str,
    )

    args = p.parse_args()

    global query
    query = "patients"
    if args.version:
        version()
        sys.exit(0)

    if args.output:
        output = args.output

    return getResults_cli(func=NBIAClient().getPatients, Collection=args.collection)


def getCollections_cli() -> None:
    p = argparse.ArgumentParser(description="NBIAtoolkit: get collection names")
    global query
    global output
    query = "collections"

    p = general_argParser()

    p.add_argument(
        "--prefix", dest="prefix", action="store", default="", type=str,
        help = "The prefix to filter collections by, i.e \'TCGA\', \'LIDC\', \'NSCLC\'"
    )
    args = p.parse_args()
    if args.version:
        version()
        sys.exit(0)

    global query
    query = "collections"
    if args.output:
        output = args.output

    return getResults_cli(func=NBIAClient().getCollections, prefix=args.prefix)



def getBodyPartCounts_cli() -> None:
    p = argparse.ArgumentParser(description="NBIAToolkit: get body part counts")

@@ -157,3 +219,109 @@ def getBodyPartCounts_cli() -> None:

    return getResults_cli(func=NBIAClient().getBodyPartCounts, Collection=args.collection)

def getSeries_cli() -> None:

    p = argparse.ArgumentParser(description="NBIAToolkit: get series")

    p.add_argument(
        "--collection", dest="collection", action="store", default = "", type=str,
    )

    p.add_argument(
        "--patientID", dest="patientID", action="store", default = "", type=str,
    )

    p.add_argument(
        "--studyInstanceUID", dest="studyInstanceUID", action="store", default = "", type=str,
    )

    p.add_argument(
        "--modality", dest="modality", action="store", default = "", type=str,
    )

    p.add_argument(
        "--seriesInstanceUID", dest="seriesInstanceUID", action="store", default = "", type=str,
    )

    p.add_argument(
        "--bodyPartExamined", dest="bodyPartExamined", action="store", default = "", type=str,
    )

    p.add_argument(
        "--manufacturerModelName", dest="manufacturerModelName", action="store", default = "", type=str,
    )

    p.add_argument(
        "--manufacturer", dest="manufacturer", action="store", default = "", type=str,
    )

    args = p.parse_args()


    global query
    query = f"series"

    return getResults_cli(
        func=NBIAClient().getSeries,
        Collection=args.collection,
        PatientID=args.patientID,
        StudyInstanceUID=args.studyInstanceUID,
        Modality=args.modality,
        SeriesInstanceUID=args.seriesInstanceUID,
        BodyPartExamined=args.bodyPartExamined,
        ManufacturerModelName=args.manufacturerModelName,
        Manufacturer=args.manufacturer,
    )

def downloadSingleSeries_cli() -> None:
    global query
    query = f"series"
    # use the NBIAClient._downloadSingleSeries function to download a single series

    p = argparse.ArgumentParser(description="NBIAToolkit: download a single series")

    p.add_argument(
        "--seriesUID", dest="seriesUID", action="store", required=True, type=str,
    )

    p.add_argument(
        "--downloadDir", dest="downloadDir", action="store", required=True, type=str,
        help = "The directory to download the series to"
    )

    p.add_argument(
        "--filePattern", dest="filePattern", action="store", type=str,
        default="%PatientID/%StudyInstanceUID/%SeriesInstanceUID/%SOPInstanceUID.dcm",
        help = "The file pattern to use when downloading the series"
    )

    p.add_argument(
        "--overwrite", action="store_true", default=False, help="Overwrite existing files"
    )


    args = p.parse_args()



    return getResults_cli(
        func=NBIAClient()._downloadSingleSeries,
        SeriesInstanceUID=args.seriesUID,
        downloadDir=args.downloadDir,
        filePattern=args.filePattern,
        overwrite=args.overwrite)

def general_argParser():
    global query
    p = argparse.ArgumentParser(description=f"NBIAToolkit: {query} ")

    p.add_argument(
        "--output", dest="output", action="store", type=str, help="Output file (tsv works best)"
    )

    p.add_argument(
        "--version", "-v", action="store_true", help="Print the version number and exit"
    )

    return p
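A quick, illustrative check of what the shared parser gives every subcommand (the values below are made up): --output and --version come for free, and each *_cli function only adds its own flags on top.

query = "demo"                                   # general_argParser reads the module-level query
p = general_argParser()
p.add_argument("--collection", dest="collection", default="", type=str)
ns = p.parse_args(["--collection", "TCGA-BLCA", "--output", "out.tsv"])
print(ns.collection, ns.output, ns.version)      # TCGA-BLCA out.tsv False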
