Skip to content

Commit

Permalink
fix: add better type checks to fix #22
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjermiah committed Jan 28, 2024
1 parent 518ffe3 commit f8aaa61
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 33 deletions.
6 changes: 3 additions & 3 deletions src/nbiatoolkit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
self.refresh_expiry = None
self.scope = None

def getToken(self) -> Union[dict, int]:
def getToken(self) -> Union[dict, None]:
"""
Retrieves the access token from the API.
Expand All @@ -112,7 +112,7 @@ def getToken(self) -> Union[dict, int]:
"""
# Check if the access token is valid and not expired
if self.access_token is not None:
return 401 if self.access_token == -1 else self.access_token
return None if self.access_token == None else self.access_token

# Prepare the request data
data = {
Expand All @@ -129,7 +129,7 @@ def getToken(self) -> Union[dict, int]:
response = requests.post(token_url, data=data)
response.raise_for_status() # Raise an HTTPError for bad responses
except requests.exceptions.RequestException as e:
self.access_token = -1
self.access_token = None
raise requests.exceptions.RequestException(
f"Failed to get access token. Status code:\
{response.status_code}"
Expand Down
10 changes: 6 additions & 4 deletions src/nbiatoolkit/dicomsort/dicomsort.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import re, os, sys, shutil
import pydicom

from pydicom.filereader import InvalidDicomError
from pydicom.errors import InvalidDicomError

from .helper_functions import parseDICOMKeysFromFormat, sanitizeFileName, truncateUID

from typing import Optional

class DICOMSorter:
def __init__(
self,
sourceDir: str,
destinationDir: str,
targetPattern: str = None,
targetPattern: str = "%PatientName/%SeriesNumber-%SeriesInstanceUID/%InstanceNumber.dcm",
truncateUID: bool = True,
sanitizeFilename: bool = True,
):
Expand All @@ -25,7 +26,7 @@ def __init__(
self.sanitizeFilename = sanitizeFilename

def generateFilePathFromDICOMAttributes(
self, dataset: pydicom.dataset.FileDataset
self, dataset: pydicom.FileDataset
) -> str:
"""
Generate a file path for the DICOM file by formatting DICOM attributes.
Expand Down Expand Up @@ -57,7 +58,8 @@ def sortSingleDICOMFile(
assert option in ["copy", "move"], "Invalid option: symlink not implemented yet"

try:
dataset = pydicom.dcmread(filePath, stop_before_pixels=True)

dataset : pydicom.FileDataset = pydicom.dcmread(filePath, stop_before_pixels=True)
except InvalidDicomError as e:
print(f"Error reading file {filePath}: {e}")
return False
Expand Down
84 changes: 65 additions & 19 deletions src/nbiatoolkit/nbia.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,47 +35,83 @@ def __init__(
self.log.debug("Setting up OAuth2 client... with username %s", username)

self._oauth2_client = OAuth2(username=username, password=password)
self._api_headers = self._oauth2_client.getToken()

try:
self._api_headers = self._oauth2_client.getToken()
except Exception as e:
self.log.error("Error retrieving access token: %s", e)
self._api_headers = None
raise e

@property
def headers(self):
return self._api_headers

def query_api(self, endpoint: NBIA_ENDPOINTS, params: dict = {}) -> dict:
def query_api(self, endpoint: NBIA_ENDPOINTS, params: dict = {}) -> Union[list, dict, bytes]:
query_url = NBIA_ENDPOINTS.BASE_URL.value + endpoint.value

self.log.debug("Querying API endpoint: %s", query_url)
response : requests.Response
try:
response = requests.get(url=query_url, headers=self.headers, params=params)
response.raise_for_status() # Raise an HTTPError for bad responses
except requests.exceptions.RequestException as e:
self.log.error("Error querying API: %s", e)
raise e

if response.status_code != 200:
self.log.error(
"Error querying API: %s %s", response.status_code, response.reason
)
raise requests.exceptions.RequestException(
f"Failed to get access token. Status code:\
{response.status_code}"
)

try:
if response.headers.get("Content-Type") == "application/json":
response_data = response.json()
response_json : dict | list = response.json()
return response_json
else:
# If response is binary data, return raw response
response_data = response.content
response_data : bytes = response.content
return response_data
except JSONDecodeError as j:
self.log.debug("Response: %s", response.text)
if response.text == "":
self.log.error("Response text is empty.")
return response
self.log.error("Error parsing response as JSON: %s", j)
self.log.debug("Response: %s", response.text)
else:
self.log.error("Error parsing response as JSON: %s", j)
raise j
except Exception as e:
self.log.error("Error querying API: %s", e)
raise e

return response_data

def getCollections(self, prefix: str = "") -> list[dict[str]]:

def getCollections(self, prefix: str = "") -> Union[list[str], None]:
response = self.query_api(NBIA_ENDPOINTS.GET_COLLECTIONS)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
return None

collections = []
for collection in response:
name = collection["Collection"]
if name.lower().startswith(prefix.lower()):
collections.append(name)
return collections


# returns a list of dictionaries with the collection name and patient count
def getCollectionPatientCount(self, prefix: str = "") -> list[dict[str, int]]:
def getCollectionPatientCount(self, prefix: str = "") -> Union[list[dict[str, int]], None]:
response = self.query_api(NBIA_ENDPOINTS.GET_COLLECTION_PATIENT_COUNT)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
return None

patientCounts = []
for collection in response:
name = collection["criteria"]
Expand All @@ -86,24 +122,30 @@ def getCollectionPatientCount(self, prefix: str = "") -> list[dict[str, int]]:
"PatientCount": int(collection["count"]),
}
)

return patientCounts

def getBodyPartCounts(self, Collection: str = "", Modality: str = "") -> list:
def getBodyPartCounts(self, Collection: str = "", Modality: str = "") -> Union[list[dict[str, int]], None]:
PARAMS = self.parsePARAMS(locals())

response = self.query_api(
endpoint=NBIA_ENDPOINTS.GET_BODY_PART_PATIENT_COUNT, params=PARAMS
)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
return None

bodyparts = []
for bodypart in response:
bodyparts.append(
{"BodyPartExamined": bodypart["criteria"], "Count": bodypart["count"]}
{
"BodyPartExamined": bodypart["criteria"],
"Count": int(bodypart["count"]),
}
)
return bodyparts

def getPatients(self, Collection: str, Modality: str) -> list:
def getPatients(self, Collection: str, Modality: str) -> Union[list[str], None]:
assert Collection is not None
assert Modality is not None

Expand All @@ -113,6 +155,9 @@ def getPatients(self, Collection: str, Modality: str) -> list:
endpoint=NBIA_ENDPOINTS.GET_PATIENT_BY_COLLECTION_AND_MODALITY,
params=PARAMS,
)
if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
return None

patientList = [_["PatientId"] for _ in response]
return patientList
Expand All @@ -127,15 +172,16 @@ def getSeries(
BodyPartExamined: str = "",
ManufacturerModelName: str = "",
Manufacturer: str = "",
) -> list:
PARAMS = dict()
) -> Union[list[dict[str, str]], None]:

for key, value in locals().items():
if (value != "") and (key != "self"):
PARAMS[key] = value
PARAMS = self.parsePARAMS(locals())

response = self.query_api(endpoint=NBIA_ENDPOINTS.GET_SERIES, params=PARAMS)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
return None

return response

def downloadSeries(
Expand Down
13 changes: 6 additions & 7 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ def test_failed_oauth(failed_oauth2):
# should raise requests.exceptions.RequestException
with pytest.raises(requests.exceptions.RequestException):
failed_oauth2.getToken()
assert failed_oauth2.getToken() == 401
assert failed_oauth2.access_token == -1
assert failed_oauth2.token == -1
assert failed_oauth2.getToken() == 401
assert failed_oauth2.getToken() is None
assert failed_oauth2.access_token is None
assert failed_oauth2.token is None
assert failed_oauth2.api_headers is None
assert failed_oauth2.expiry_time is None
assert failed_oauth2.refresh_token is None
Expand All @@ -47,9 +46,9 @@ def test_failed_oauth(failed_oauth2):
def test_getToken_valid_token(oauth2):
# Test if the access token is valid and not expired
assert oauth2.getToken() == oauth2.access_token
assert oauth2.getToken() != 401
assert oauth2.access_token != -1
assert oauth2.token != -1
assert oauth2.getToken() is not None
assert oauth2.access_token is not None
assert oauth2.token is not None
assert oauth2.headers is not None
assert oauth2.expiry_time is not None
assert oauth2.refresh_token is not None
Expand Down

0 comments on commit f8aaa61

Please sign in to comment.