Skip to content

Commit

Permalink
Merge pull request #89 from jjjermiah/52-feature-add-nslt-endpoint
Browse files Browse the repository at this point in the history
52 feature add nslt endpoint
  • Loading branch information
jjjermiah committed Feb 4, 2024
2 parents a6b23ea + 5a992df commit 682e50e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 36 deletions.
22 changes: 18 additions & 4 deletions src/nbiatoolkit/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import requests
import time
from typing import Union

from .utils import NBIA_ENDPOINTS

class OAuth2:
"""
Expand Down Expand Up @@ -68,7 +68,11 @@ class OAuth2:
"""

def __init__(
self, username: str = "nbia_guest", password: str = "", client_id: str = "NBIA"
self,
username: str = "nbia_guest",
password: str = "",
client_id: str = "NBIA",
base_url: Union[str, NBIA_ENDPOINTS] = NBIA_ENDPOINTS.BASE_URL,
):
"""
Initialize the OAuth2 class.
Expand All @@ -82,17 +86,26 @@ def __init__(
The password for authentication. Default is an empty string.
client_id : str, optional
The client ID for authentication. Default is "NBIA".
base_url : str or NBIA_ENDPOINTS, optional. Default is NBIA_ENDPOINTS.BASE_URL
"""
self.client_id = client_id
self.username = username
self.password = password

if isinstance(base_url, NBIA_ENDPOINTS):
self.base_url = base_url.value
else:
self.base_url = base_url

self.access_token = None
self.api_headers = None
self.expiry_time = None
self.refresh_token = None
self.refresh_expiry = None
self.scope = None


def getToken(self) -> Union[dict, None]:
"""
Retrieves the access token from the API.
Expand All @@ -115,14 +128,15 @@ def getToken(self) -> Union[dict, None]:
return None if self.access_token == None else self.access_token

# Prepare the request data
data = {
data: dict[str, str] = {
"username": self.username,
"password": self.password,
"client_id": self.client_id,
"grant_type": "password",
}
token_url = "https://services.cancerimagingarchive.net/nbia-api/oauth/token"
token_url: str = self.base_url + "oauth/token"

response : requests.models.Response
response = requests.post(token_url, data=data)

try:
Expand Down
87 changes: 55 additions & 32 deletions src/nbiatoolkit/nbia.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from .auth import OAuth2
from .logger.logger import setup_logger
from logging import Logger
from .utils import NBIA_ENDPOINTS, validateMD5, clean_html, convertMillis, convertDateFormat
from .dicomsort import DICOMSorter

import requests
from requests.exceptions import JSONDecodeError as JSONDecodeError
from typing import Union
from typing import Union, LiteralString
import io
import zipfile
from tqdm import tqdm
from pyfiglet import Figlet

import os
from datetime import datetime
# set __version__ variable
Expand All @@ -28,47 +28,70 @@ class NBIAClient:
"""

def __init__(
self, username: str = "nbia_guest", password: str = "", log_level: str = "INFO"
self,
username: str = "nbia_guest",
password: str = "",
log_level: str = "INFO"
) -> None:
# Setup logger
self.log = setup_logger(
self._log: Logger = setup_logger(
name="NBIAClient", log_level=log_level, console_logging=True, log_file=None
)

# Setup OAuth2 client
self.log.debug("Setting up OAuth2 client... with username %s", username)

self._log.debug("Setting up OAuth2 client... with username %s", username)
self._oauth2_client = OAuth2(username=username, password=password)

try:
self._api_headers = self._oauth2_client.getToken()
except Exception as e:
self.log.error("Error retrieving access token: %s", e)
self._log.error("Error retrieving access token: %s", e)
self._api_headers = None
raise e

self._base_url : NBIA_ENDPOINTS = NBIA_ENDPOINTS.BASE_URL

@property
def headers(self):
return self._api_headers

# create a setter for the base_url in case user want to use NLST
@property
def base_url(self) -> NBIA_ENDPOINTS:
return self._base_url

@base_url.setter
def base_url(self, nbia_url: NBIA_ENDPOINTS) -> None:
self._base_url = nbia_url

@property
def logger(self) -> Logger:
return self._log

@logger.setter
def logger(self, logger: Logger) -> None:
self._log = logger


def query_api(
self, endpoint: NBIA_ENDPOINTS, params: dict = {}
) -> Union[list, dict, bytes]:
query_url = NBIA_ENDPOINTS.BASE_URL.value + endpoint.value

self.log.debug("Querying API endpoint: %s", query_url)
self.log.debug("Query parameters: %s", params)
# query_url = NBIA_ENDPOINTS.BASE_URL.value + endpoint.value
query_url: LiteralString = self._base_url.value + endpoint.value

self._log.debug("Querying API endpoint: %s", query_url)
self._log.debug("Query parameters: %s", params)
response: requests.Response
try:
response = requests.get(url=query_url, headers=self.headers, params=params)
response.raise_for_status() # Raise an HTTPError for bad responses
except requests.exceptions.RequestException as e:
self.log.error("Error querying API: %s", e)
self._log.error("Error querying API: %s", e)
raise e

if response.status_code != 200:
self.log.error(
self._log.error(
"Error querying API: %s %s", response.status_code, response.reason
)
raise requests.exceptions.RequestException(
Expand All @@ -85,21 +108,21 @@ def query_api(
response_data: bytes = response.content
return response_data
except JSONDecodeError as j:
self.log.debug("Response: %s", response.text)
self._log.debug("Response: %s", response.text)
if response.text == "":
self.log.error("Response text is empty.")
self._log.error("Response text is empty.")
else:
self.log.error("Error parsing response as JSON: %s", j)
self._log.error("Error parsing response as JSON: %s", j)
raise j
except Exception as e:
self.log.error("Error querying API: %s", e)
self._log.error("Error querying API: %s", e)
raise e

def getCollections(self, prefix: str = "") -> Union[list[str], None]:
response = self.query_api(NBIA_ENDPOINTS.GET_COLLECTIONS)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

collections = []
Expand Down Expand Up @@ -140,7 +163,7 @@ def getModalityValues(
)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

modalities = []
Expand All @@ -155,7 +178,7 @@ def getPatients(self, Collection: str = "") -> Union[list[dict[str, str]], None]

response = self.query_api(endpoint=NBIA_ENDPOINTS.GET_PATIENTS, params=PARAMS)
if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

patientList = []
Expand Down Expand Up @@ -218,7 +241,7 @@ def getPatientsByCollectionAndModality(
params=PARAMS,
)
if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

patientList = [_["PatientId"] for _ in response]
Expand All @@ -231,7 +254,7 @@ def getCollectionPatientCount(
response = self.query_api(NBIA_ENDPOINTS.GET_COLLECTION_PATIENT_COUNT)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

patientCounts = []
Expand All @@ -256,7 +279,7 @@ def getBodyPartCounts(
)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

bodyparts = []
Expand All @@ -277,7 +300,7 @@ def getStudies(
response = self.query_api(endpoint=NBIA_ENDPOINTS.GET_STUDIES, params=PARAMS)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

return response
Expand All @@ -298,7 +321,7 @@ def getSeries(
response = self.query_api(endpoint=NBIA_ENDPOINTS.GET_SERIES, params=PARAMS)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

return response
Expand All @@ -321,7 +344,7 @@ def getSeriesMetadata(
)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

metadata.extend(response)
Expand All @@ -343,7 +366,7 @@ def getNewSeries(
response = self.query_api(endpoint=NBIA_ENDPOINTS.GET_UPDATED_SERIES, params=PARAMS)

if not isinstance(response, list):
self.log.error("Expected list, but received: %s", type(response))
self._log.error("Expected list, but received: %s", type(response))
return None

return response
Expand Down Expand Up @@ -412,13 +435,13 @@ def _downloadSingleSeries(
params = dict()
params["SeriesInstanceUID"] = SeriesInstanceUID

self.log.debug("Downloading series: %s", SeriesInstanceUID)
self._log.debug("Downloading series: %s", SeriesInstanceUID)
response = self.query_api(
endpoint=NBIA_ENDPOINTS.DOWNLOAD_SERIES, params=params
)

if not isinstance(response, bytes):
self.log.error(f"Expected binary data, but received: {type(response)}")
self._log.error(f"Expected binary data, but received: {type(response)}")
return False

file = zipfile.ZipFile(io.BytesIO(response))
Expand All @@ -429,7 +452,7 @@ def _downloadSingleSeries(
try:
validateMD5(seriesDir=tempDir)
except Exception as e:
self.log.error("Error validating MD5 hash: %s", e)
self._log.error("Error validating MD5 hash: %s", e)
return False

# Create an instance of DICOMSorter with the desired target pattern
Expand All @@ -442,7 +465,7 @@ def _downloadSingleSeries(
)
# sorter.sortDICOMFiles(option="move", overwrite=overwrite)
if not sorter.sortDICOMFiles(option="move", overwrite=overwrite):
self.log.error(
self._log.error(
"Error sorting DICOM files for series %s\n \
failed files located at %s",
SeriesInstanceUID,
Expand All @@ -455,7 +478,7 @@ def _downloadSingleSeries(
# parsePARAMS is a helper function that takes a locals() dict and returns
# a dict with only the non-empty values
def parsePARAMS(self, params: dict) -> dict:
self.log.debug("Parsing params: %s", params)
self._log.debug("Parsing params: %s", params)
PARAMS = dict()
for key, value in params.items():
if (value != "") and (key != "self"):
Expand Down
1 change: 1 addition & 0 deletions src/nbiatoolkit/utils/nbia_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class NBIA_ENDPOINTS(Enum):
"""

BASE_URL = "https://services.cancerimagingarchive.net/nbia-api/services/"
NLST_URL = "https://nlst.cancerimagingarchive.net/nbia-api/services/"

GET_COLLECTIONS = "v2/getCollectionValues"
GET_COLLECTION_PATIENT_COUNT = "getCollectionValuesAndCounts"
Expand Down

0 comments on commit 682e50e

Please sign in to comment.