Skip to content

Commit

Permalink
refactor: CLI code to improve performance and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjermiah committed Feb 1, 2024
1 parent 4cd3785 commit d5bd973
Showing 1 changed file with 85 additions and 187 deletions.
272 changes: 85 additions & 187 deletions src/nbiatoolkit/nbia_cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import glob
import re

from requests import patch
import io
from .nbia import NBIAClient, __version__


import argparse
import os
import sys
import time, threading
from pprint import pprint
import threading
from pyfiglet import Figlet
import sys
import subprocess
Expand All @@ -18,7 +15,7 @@

done_event = threading.Event()
query: str
output: str | None = None
output: io.TextIOWrapper | None = None


def version():
Expand All @@ -42,35 +39,25 @@ def version():

return

def general_parser(parser: argparse.ArgumentParser) -> argparse.Namespace:
parser.add_argument("-o", "--output", dest="outputfile",
action="store", type=argparse.FileType('w', encoding='UTF-8'), help="Output file (tsv works best)",
)

# create a helper function that will be used if the user ever uses --output <FILE>.tsv
def writeResultsToFile(results: List, output: str) -> None:
"""
Writes the results of a query to a file.
Args:
results: The results of the query.
output: The path to the output file.
Returns:
None
"""
parser.add_argument(
"--version", "-v", action="store_true", help="Print the version number and exit"
)

# open the file in write mode
with open(output, "w") as f:
if isinstance(results[0], dict):
# write the header
f.write("\t".join(results[0].keys()) + "\n")
# write the results
for result in results:
f.write("\t".join(str(value) for value in result.values()) + "\n")
else:
# write the results
for result in results:
f.write(str(result) + "\n")
args = parser.parse_args()
if args.version:
version()
sys.exit(0)

return
if args.outputfile:
global output
output = args.outputfile

return args

# An abstraction of the getCollections and getPatients functions
# to generalize an interface for the CLI
Expand All @@ -88,33 +75,55 @@ def getResults_cli(func, **kwargs) -> None:

global query
global output

# Execute the function
results = cli_wrapper(func=func, **kwargs)

# this is for the downloadSingleSeries function
if results == True:
return

# If the user specified an output file, write the results to the file
if output and isinstance(results, list) and len(results):
writeResultsToFile(results, output)
if not isinstance(results, list) or not len(results):
return
elif not isinstance(results, list) or not len(results):

if output:
writeResultsToFile(results, output)
return
# Print the result
elif isinstance(results[0], dict) and len(results):

if isinstance(results[0], dict):
print("\t".join(results[0].keys()))
for result in results:
print("\t".join(str(value) for value in result.values())) # type: ignore
return

elif isinstance(results, list):
for result in results:
print(result)
return

print(f"No {query} found. Check parameters using -h or try again later.")
return

# create a helper function that will be used if the user ever uses --output <FILE>.tsv
# output should be a io.TextIOWrapper object
def writeResultsToFile(results: List, output: io.TextIOWrapper) -> None:
"""
Writes the results of a query to a file.
Args:
results: The results of the query.
output: The path to the output file.
Returns:
None
"""

# write to the file
if isinstance(results[0], dict):
# write the header
output.write("\t".join(results[0].keys()) + "\n")
# write the results
for result in results:
output.write("\t".join(str(value) for value in result.values()) + "\n")
else:
# write the results
for result in results:
output.write(str(result) + "\n")
return

def cli_wrapper(func, **kwargs) -> List[str] | None:
"""
Expand All @@ -131,97 +140,46 @@ def cli_wrapper(func, **kwargs) -> List[str] | None:
global done_event
global query

# Start the loading animation in a separate thread
loading_thread = threading.Thread(target=loading_animation)
# # Start the loading animation in a separate thread
# loading_thread = threading.Thread(target=loading_animation)

# daemon threads are killed when the main thread exits
loading_thread.daemon = True
loading_thread.start()
# # daemon threads are killed when the main thread exits
# loading_thread.daemon = True
# loading_thread.start()

# Perform the database query in the main thread
result = func(**kwargs)

# Stop the loading animation
done_event.set()
loading_thread.join()
# # Stop the loading animation
# done_event.set()
# loading_thread.join()

return result


def loading_animation():
"""
Displays a loading animation while retrieving data.
This function prints a loading animation to the console while data is being retrieved.
It uses a list of animation strings and continuously prints them in a loop until the
'done_event' is set. The animation strings are padded with spaces to the maximum length
to ensure consistent display. The animation pauses for 0.5 seconds between each iteration.
"""
global query

animations = [
"Retrieving " + query + "." * i + " This may take a few seconds"
for i in range(4)
]

# Find the maximum length of the loading animation strings
# print(animations[0])
while not done_event.is_set():
if done_event.is_set():
# clear the line
# sys.stdout.write("\033[F")
# print( len(animations[0]))
# print(" " * len(animations[0]), end="\r")
break


def getPatients_cli() -> None:
global query
global output
query = "patients"
p = general_argParser()

p.add_argument(
"--collection",
dest="collection",
action="store",
required=True,
type=str,
)
args = p.parse_args()
p = argparse.ArgumentParser(description=f"NBIAToolkit: {query} ")

if args.version:
version()
sys.exit(0)
p.add_argument("-c", "--collection", action="store",
required=True,type=str,)

if args.output:
output = args.output
args = general_parser(p)

return getResults_cli(func=NBIAClient().getPatients, Collection=args.collection)


def getCollections_cli() -> None:
global query
global output
query = "collections"
p = argparse.ArgumentParser(description=f"NBIAToolkit: {query} ")

p = general_argParser()

p.add_argument(
"--prefix",
dest="prefix",
action="store",
default="",
type=str,
p.add_argument("-p", "--prefix",
action="store", default="", type=str,
help="The prefix to filter collections by, i.e 'TCGA', 'LIDC', 'NSCLC'",
)
args = p.parse_args()
if args.version:
version()
sys.exit(0)

if args.output:
output = args.output
args = general_parser(p)

return getResults_cli(func=NBIAClient().getCollections, prefix=args.prefix)

Expand All @@ -231,69 +189,37 @@ def getBodyPartCounts_cli() -> None:
global output
query = f"BodyPartCounts"

p = general_argParser()

p.add_argument(
"--collection",
dest="collection",
action="store",
default="",
type=str,
)

args = p.parse_args()
p = argparse.ArgumentParser(description=f"NBIAToolkit: {query} ")

if args.version:
version()
sys.exit(0)
p.add_argument("-c", "--collection", dest="collection",
action="store", default="", type=str,)

if args.output:
output = args.output
args = general_parser(p)


return getResults_cli(
func=NBIAClient().getBodyPartCounts, Collection=args.collection
)
return getResults_cli(func=NBIAClient().getBodyPartCounts, Collection=args.collection)


def getSeries_cli() -> None:
global query
global output
query = f"series"

p = general_argParser()
p = argparse.ArgumentParser(description=f"NBIAToolkit: {query} ")

p.add_argument(
"--collection",
dest="collection",
action="store",
default="",
type=str,
p.add_argument("-c", "--collection", dest="collection", action="store",
default="", type=str,
)

p.add_argument(
"--patientID",
dest="patientID",
action="store",
default="",
type=str,
)
p.add_argument("-p", "--patientID", dest="patientID",
action="store", default="", type=str,)

p.add_argument(
"--studyInstanceUID",
dest="studyInstanceUID",
action="store",
default="",
type=str,
)
p.add_argument("-m", "--modality", dest="modality",
action="store", default="", type=str,)

p.add_argument("-study", "--studyInstanceUID", dest="studyInstanceUID",
action="store", default="", type=str,)

p.add_argument(
"--modality",
dest="modality",
action="store",
default="",
type=str,
)

p.add_argument(
"--seriesInstanceUID",
Expand Down Expand Up @@ -327,18 +253,7 @@ def getSeries_cli() -> None:
type=str,
)



args = p.parse_args()

if args.version:
version()
sys.exit(0)

if args.output:
output = args.output


args = general_parser(p)
return getResults_cli(
func=NBIAClient().getSeries,
Collection=args.collection,
Expand Down Expand Up @@ -403,20 +318,3 @@ def downloadSingleSeries_cli() -> None:
)


def general_argParser():
global query
p = argparse.ArgumentParser(description=f"NBIAToolkit: {query} ")

p.add_argument(
"--output",
dest="output",
action="store",
type=str,
help="Output file (tsv works best)",
)

p.add_argument(
"--version", "-v", action="store_true", help="Print the version number and exit"
)

return p

0 comments on commit d5bd973

Please sign in to comment.