Skip to content

Commit

Permalink
add better error message for too-long URI (#881)
Browse files Browse the repository at this point in the history
* add better error message for too-long URI

* improve error handling

* improve data download function, fix bugs

* stricter API, more private methods

* incorporate Pieter's feedback
  • Loading branch information
mfeurer authored and PGijsbers committed Nov 14, 2019
1 parent 69d443f commit d79a98c
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 97 deletions.
135 changes: 105 additions & 30 deletions openml/_api_calls.py
@@ -1,15 +1,15 @@
# License: BSD 3-Clause

import time
import hashlib
import logging
import requests
import warnings
import xmltodict
from typing import Dict
from typing import Dict, Optional

from . import config
from .exceptions import (OpenMLServerError, OpenMLServerException,
OpenMLServerNoResult)
OpenMLServerNoResult, OpenMLHashException)


def _perform_api_call(call, request_method, data=None, file_elements=None):
Expand Down Expand Up @@ -47,20 +47,105 @@ def _perform_api_call(call, request_method, data=None, file_elements=None):
url = url.replace('=', '%3d')
logging.info('Starting [%s] request for the URL %s', request_method, url)
start = time.time()

if file_elements is not None:
if request_method != 'post':
raise ValueError('request method must be post when file elements '
'are present')
response = _read_url_files(url, data=data, file_elements=file_elements)
raise ValueError('request method must be post when file elements are present')
response = __read_url_files(url, data=data, file_elements=file_elements)
else:
response = _read_url(url, request_method, data)
response = __read_url(url, request_method, data)

__check_response(response, url, file_elements)

logging.info(
'%.7fs taken for [%s] request for the URL %s',
time.time() - start,
request_method,
url,
)
return response
return response.text


def _download_text_file(source: str,
                        output_path: Optional[str] = None,
                        md5_checksum: Optional[str] = None,
                        exists_ok: bool = True,
                        encoding: str = 'utf8',
                        ) -> Optional[str]:
    """ Download the text file at `source` and store it in `output_path`.

    By default, do nothing if a file already exists in `output_path`.
    The downloaded file can be checked against an expected md5 checksum.

    Parameters
    ----------
    source : str
        url of the file to be downloaded
    output_path : str, optional (default=None)
        full path, including filename, of where the file should be stored. If ``None``,
        this function returns the downloaded file as string.
    md5_checksum : str, optional (default=None)
        If not None, should be a string of hexadecimal digits of the expected digest value.
    exists_ok : bool, optional (default=True)
        If False, raise a FileExistsError if there already exists a file at `output_path`.
    encoding : str, optional (default='utf8')
        The encoding with which the file should be stored.

    Returns
    -------
    str or None
        The downloaded content if `output_path` is None, otherwise None.

    Raises
    ------
    FileExistsError
        If `exists_ok` is False and a file already exists at `output_path`.
    OpenMLHashException
        If `md5_checksum` is given and does not match the downloaded content.
    """
    if output_path is not None:
        try:
            # EAFP: opening the file both detects existence and verifies readability.
            with open(output_path, encoding=encoding):
                if exists_ok:
                    return None
                else:
                    raise FileExistsError(output_path)
        except FileNotFoundError:
            pass

    logging.info('Starting [%s] request for the URL %s', 'get', source)
    start = time.time()
    response = __read_url(source, request_method='get')
    __check_response(response, source, None)
    downloaded_file = response.text

    if md5_checksum is not None:
        md5 = hashlib.md5()
        md5.update(downloaded_file.encode('utf-8'))
        md5_checksum_download = md5.hexdigest()
        if md5_checksum != md5_checksum_download:
            raise OpenMLHashException(
                'Checksum {} of downloaded file is unequal to the expected checksum {}.'
                .format(md5_checksum_download, md5_checksum))

    # Single timing log for both the return-as-string and write-to-disk paths.
    logging.info(
        '%.7fs taken for [%s] request for the URL %s',
        time.time() - start,
        'get',
        source,
    )

    if output_path is None:
        return downloaded_file

    with open(output_path, "w", encoding=encoding) as fh:
        fh.write(downloaded_file)

    del downloaded_file
    return None


def __check_response(response, url, file_elements):
    """Validate a server response: raise a parsed server exception on any
    non-200 status, and log a warning when the payload was not gzip-compressed."""
    if response.status_code != 200:
        raise __parse_server_exception(response, url, file_elements=file_elements)
    # dict.get folds the membership test and the lookup into one step; a
    # missing header (None) compares unequal to 'gzip' just like any other value.
    if response.headers.get('Content-Encoding') != 'gzip':
        logging.warning('Received uncompressed content from OpenML for {}.'.format(url))


def _file_id_to_url(file_id, filename=None):
Expand All @@ -75,7 +160,7 @@ def _file_id_to_url(file_id, filename=None):
return url


def _read_url_files(url, data=None, file_elements=None):
def __read_url_files(url, data=None, file_elements=None):
"""do a post request to url with data
and sending file_elements as files"""

Expand All @@ -85,37 +170,24 @@ def _read_url_files(url, data=None, file_elements=None):
file_elements = {}
# Using requests.post sets header 'Accept-encoding' automatically to
# 'gzip,deflate'
response = send_request(
response = __send_request(
request_method='post',
url=url,
data=data,
files=file_elements,
)
if response.status_code != 200:
raise _parse_server_exception(response, url, file_elements=file_elements)
if 'Content-Encoding' not in response.headers or \
response.headers['Content-Encoding'] != 'gzip':
warnings.warn('Received uncompressed content from OpenML for {}.'
.format(url))
return response.text
return response


def _read_url(url, request_method, data=None):
def __read_url(url, request_method, data=None):
data = {} if data is None else data
if config.apikey is not None:
data['api_key'] = config.apikey

response = send_request(request_method=request_method, url=url, data=data)
if response.status_code != 200:
raise _parse_server_exception(response, url, file_elements=None)
if 'Content-Encoding' not in response.headers or \
response.headers['Content-Encoding'] != 'gzip':
warnings.warn('Received uncompressed content from OpenML for {}.'
.format(url))
return response.text
return __send_request(request_method=request_method, url=url, data=data)


def send_request(
def __send_request(
request_method,
url,
data,
Expand Down Expand Up @@ -149,16 +221,19 @@ def send_request(
return response


def _parse_server_exception(
def __parse_server_exception(
response: requests.Response,
url: str,
file_elements: Dict,
) -> OpenMLServerError:
# OpenML has a sophisticated error system
# where information about failures is provided. try to parse this

if response.status_code == 414:
raise OpenMLServerError('URI too long! ({})'.format(url))
try:
server_exception = xmltodict.parse(response.text)
except Exception:
# OpenML has a sophisticated error system
# where information about failures is provided. try to parse this
raise OpenMLServerError(
'Unexpected server error when calling {}. Please contact the developers!\n'
'Status code: {}\n{}'.format(url, response.status_code, response.text))
Expand Down
8 changes: 3 additions & 5 deletions openml/datasets/functions.py
Expand Up @@ -886,7 +886,7 @@ def _get_dataset_arff(description: Union[Dict, OpenMLDataset],
output_file_path = os.path.join(cache_directory, "dataset.arff")

try:
openml.utils._download_text_file(
openml._api_calls._download_text_file(
source=url,
output_path=output_file_path,
md5_checksum=md5_checksum_fixture
Expand Down Expand Up @@ -1038,13 +1038,11 @@ def _get_online_dataset_arff(dataset_id):
str
A string representation of an ARFF file.
"""
dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id,
'get')
dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id, 'get')
# build a dict from the xml.
# use the url from the dataset description and return the ARFF string
return openml._api_calls._read_url(
return openml._api_calls._download_text_file(
xmltodict.parse(dataset_xml)['oml:data_set_description']['oml:url'],
request_method='get'
)


Expand Down
3 changes: 1 addition & 2 deletions openml/runs/run.py
Expand Up @@ -327,8 +327,7 @@ def get_metric_fn(self, sklearn_fn, kwargs=None):
predictions_file_url = openml._api_calls._file_id_to_url(
self.output_files['predictions'], 'predictions.arff',
)
response = openml._api_calls._read_url(predictions_file_url,
request_method='get')
response = openml._api_calls._download_text_file(predictions_file_url)
predictions_arff = arff.loads(response)
# TODO: make this a stream reader
else:
Expand Down
10 changes: 4 additions & 6 deletions openml/tasks/task.py
Expand Up @@ -116,12 +116,10 @@ def _download_split(self, cache_file: str):
pass
except (OSError, IOError):
split_url = self.estimation_procedure["data_splits_url"]
split_arff = openml._api_calls._read_url(split_url,
request_method='get')

with io.open(cache_file, "w", encoding='utf8') as fh:
fh.write(split_arff)
del split_arff
openml._api_calls._download_text_file(
source=str(split_url),
output_path=cache_file,
)

def download_split(self) -> OpenMLSplit:
"""Download the OpenML split for a given task.
Expand Down
51 changes: 0 additions & 51 deletions openml/utils.py
@@ -1,7 +1,6 @@
# License: BSD 3-Clause

import os
import hashlib
import xmltodict
import shutil
from typing import TYPE_CHECKING, List, Tuple, Union, Type
Expand Down Expand Up @@ -366,53 +365,3 @@ def _create_lockfiles_dir():
except OSError:
pass
return dir


def _download_text_file(source: str,
output_path: str,
md5_checksum: str = None,
exists_ok: bool = True,
encoding: str = 'utf8',
) -> None:
""" Download the text file at `source` and store it in `output_path`.
By default, do nothing if a file already exists in `output_path`.
The downloaded file can be checked against an expected md5 checksum.
Parameters
----------
source : str
url of the file to be downloaded
output_path : str
full path, including filename, of where the file should be stored.
md5_checksum : str, optional (default=None)
If not None, should be a string of hexidecimal digits of the expected digest value.
exists_ok : bool, optional (default=True)
If False, raise an FileExistsError if there already exists a file at `output_path`.
encoding : str, optional (default='utf8')
The encoding with which the file should be stored.
"""
try:
with open(output_path, encoding=encoding):
if exists_ok:
return
else:
raise FileExistsError
except FileNotFoundError:
pass

downloaded_file = openml._api_calls._read_url(source, request_method='get')

if md5_checksum is not None:
md5 = hashlib.md5()
md5.update(downloaded_file.encode('utf-8'))
md5_checksum_download = md5.hexdigest()
if md5_checksum != md5_checksum_download:
raise openml.exceptions.OpenMLHashException(
'Checksum {} of downloaded file is unequal to the expected checksum {}.'
.format(md5_checksum_download, md5_checksum))

with open(output_path, "w", encoding=encoding) as fh:
fh.write(downloaded_file)

del downloaded_file
12 changes: 12 additions & 0 deletions tests/test_openml/test_api_calls.py
@@ -0,0 +1,12 @@
import openml
import openml.testing


class TestConfig(openml.testing.TestBase):
    """Regression tests for low-level API-call error handling."""

    def test_too_long_uri(self):
        # A listing filtered on thousands of dataset ids produces a GET URI
        # longer than the server accepts (HTTP 414), which must surface as a
        # descriptive OpenMLServerError rather than an opaque parse failure.
        excessive_ids = list(range(10000))
        expected = openml.exceptions.OpenMLServerError
        with self.assertRaisesRegex(expected, 'URI too long!'):
            openml.datasets.list_datasets(data_id=excessive_ids)
3 changes: 1 addition & 2 deletions tests/test_runs/test_run_functions.py
Expand Up @@ -119,8 +119,7 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
# downloads the predictions of the old task
file_id = run.output_files['predictions']
predictions_url = openml._api_calls._file_id_to_url(file_id)
response = openml._api_calls._read_url(predictions_url,
request_method='get')
response = openml._api_calls._download_text_file(predictions_url)
predictions = arff.loads(response)
run_prime = openml.runs.run_model_on_task(
model=model_prime,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/test_utils.py
Expand Up @@ -16,7 +16,7 @@ class OpenMLTaskTest(TestBase):
def mocked_perform_api_call(call, request_method):
# TODO: JvR: Why is this not a staticmethod?
url = openml.config.server + '/' + call
return openml._api_calls._read_url(url, request_method=request_method)
return openml._api_calls._download_text_file(url)

def test_list_all(self):
openml.utils._list_all(listing_call=openml.tasks.functions._list_tasks)
Expand Down

0 comments on commit d79a98c

Please sign in to comment.