Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cloud Fetch download handler #127

Merged
merged 6 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import logging

import requests
import lz4.frame
import threading
import time

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)


class ResultSetDownloadHandler(threading.Thread):
def __init__(
self,
downloadable_result_settings,
t_spark_arrow_result_link: TSparkArrowResultLink,
):
super().__init__()
self.settings = downloadable_result_settings
self.result_link = t_spark_arrow_result_link
self.is_download_scheduled = False
self.is_download_finished = threading.Event()
self.is_file_downloaded_successfully = False
self.is_link_expired = False
self.is_download_timedout = False
self.result_file = None

def is_file_download_successful(self) -> bool:
"""
Check and report if cloud fetch file downloaded successfully.

This function will block until a file download finishes or until a timeout.
"""
timeout = self.settings.download_timeout
timeout = timeout if timeout and timeout > 0 else None
try:
if not self.is_download_finished.wait(timeout=timeout):
self.is_download_timedout = True
logger.debug(
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
self.settings.download_timeout,
self.result_link.startRowOffset,
self.result_link.startRowOffset + self.result_link.rowCount,
)
)
return False
except Exception as e:
logger.error(e)
return False
return self.is_file_downloaded_successfully

def run(self):
"""
Download the file described in the cloud fetch link.

This function checks if the link has or is expiring, gets the file via a requests session, decompresses the
file, and signals to waiting threads that the download is finished and whether it was successful.
"""
self._reset()

# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)

if not response.ok:
self.is_file_downloaded_successfully = False
return

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
ResultSetDownloadHandler.decompress_data(compressed_data)
if self.settings.is_lz4_compressed
else compressed_data
)
self.result_file = decompressed_data

# The size of the downloaded file should match the size specified from TSparkArrowResultLink
self.is_file_downloaded_successfully = (
len(self.result_file) == self.result_link.bytesNum
)
except Exception as e:
logger.error(e)
self.is_file_downloaded_successfully = False

finally:
session and session.close()
# Awaken threads waiting for this to be true which signals the run is complete
self.is_download_finished.set()

def _reset(self):
"""
Reset download-related flags for every retry of run()
"""
self.is_file_downloaded_successfully = False
self.is_link_expired = False
self.is_download_timedout = False
self.is_download_finished = threading.Event()

@staticmethod
def check_link_expired(
link: TSparkArrowResultLink, expiry_buffer_secs: int
) -> bool:
"""
Check if a link has expired or will expire.

Expiry buffer can be set to avoid downloading files that has not expired yet when the function is called,
but may expire before the file has fully downloaded.
"""
current_time = int(time.time())
if (
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
return True
return False

@staticmethod
def decompress_data(compressed_data: bytes) -> bytes:
"""
Decompress lz4 frame compressed data.

Decompresses data that has been lz4 compressed, either via the whole frame or by series of chunks.
"""
uncompressed_data, bytes_read = lz4.frame.decompress(
compressed_data, return_bytes_read=True
)
# The last cloud fetch file of the entire result is commonly punctuated by frequent end-of-frame markers.
# Full frame decompression above will short-circuit, so chunking is necessary
if bytes_read < len(compressed_data):
d_context = lz4.frame.create_decompression_context()
start = 0
uncompressed_data = bytearray()
while start < len(compressed_data):
data, num_bytes, is_end = lz4.frame.decompress_chunk(
d_context, compressed_data[start:]
)
uncompressed_data += data
start += num_bytes
return uncompressed_data
155 changes: 155 additions & 0 deletions tests/unit/test_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import unittest
from unittest.mock import Mock, patch, MagicMock

import databricks.sql.cloudfetch.downloader as downloader


class DownloaderTests(unittest.TestCase):
"""
Unit tests for checking downloader logic.
"""

@patch('time.time', return_value=1000)
def test_run_link_expired(self, mock_time):
settings = Mock()
result_link = Mock()
# Already expired
result_link.expiryTime = 999
d = downloader.ResultSetDownloadHandler(settings, result_link)
assert not d.is_link_expired
d.run()
assert d.is_link_expired
mock_time.assert_called_once()

@patch('time.time', return_value=1000)
def test_run_link_past_expiry_buffer(self, mock_time):
settings = Mock(link_expiry_buffer_secs=5)
result_link = Mock()
# Within the expiry buffer time
result_link.expiryTime = 1004
d = downloader.ResultSetDownloadHandler(settings, result_link)
assert not d.is_link_expired
d.run()
assert d.is_link_expired
mock_time.assert_called_once()

@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False))))
@patch('time.time', return_value=1000)
def test_run_get_response_not_ok(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
settings.download_timeout = 0
settings.use_proxy = False
result_link = Mock(expiryTime=1001)

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert not d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session',
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 9))))
@patch('time.time', return_value=1000)
def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False)
result_link = Mock(bytesNum=100, expiryTime=1001)

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert not d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))))
@patch('time.time', return_value=1000)
def test_run_compressed_data_length_incorrect(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings.is_lz4_compressed = True
result_link = Mock(bytesNum=100, expiryTime=1001)
mock_session.return_value.get.return_value.content = \
b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00'

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert not d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session',
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 10))))
@patch('time.time', return_value=1000)
def test_run_uncompressed_successful(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings.is_lz4_compressed = False
result_link = Mock(bytesNum=100, expiryTime=1001)

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert d.result_file == b"1234567890" * 10
assert d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))))
@patch('time.time', return_value=1000)
def test_run_compressed_successful(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings.is_lz4_compressed = True
result_link = Mock(bytesNum=100, expiryTime=1001)
mock_session.return_value.get.return_value.content = \
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert d.result_file == b"1234567890" * 10
assert d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session.get', side_effect=ConnectionError('foo'))
@patch('time.time', return_value=1000)
def test_download_connection_error(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True)
result_link = Mock(bytesNum=100, expiryTime=1001)
mock_session.return_value.get.return_value.content = \
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert not d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session.get', side_effect=TimeoutError('foo'))
@patch('time.time', return_value=1000)
def test_download_timeout(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True)
result_link = Mock(bytesNum=100, expiryTime=1001)
mock_session.return_value.get.return_value.content = \
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()

assert not d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch("threading.Event.wait", return_value=True)
def test_is_file_download_successful_has_finished(self, mock_wait):
for timeout in [None, 0, 1]:
with self.subTest(timeout=timeout):
settings = Mock(download_timeout=timeout)
result_link = Mock()
handler = downloader.ResultSetDownloadHandler(settings, result_link)

status = handler.is_file_download_successful()
assert status == handler.is_file_downloaded_successfully

def test_is_file_download_successful_times_outs(self):
settings = Mock(download_timeout=1)
result_link = Mock()
handler = downloader.ResultSetDownloadHandler(settings, result_link)

status = handler.is_file_download_successful()
assert not status
assert handler.is_download_timedout