Skip to content

Commit

Permalink
clean up the curl download logic
Browse files Browse the repository at this point in the history
Summary:
Create `curl_plugin_shared.download_file` to get rid of some duplication
between curl fetch and reup. Add progress printing to this function, so
that curl now plays nicely with the fancy terminal.

Test Plan: New tests included for the shared logic.

Reviewers: sean

Differential Revision: https://phabricator.buildinspace.com/D95
  • Loading branch information
oconnor663 committed Oct 1, 2014
1 parent 09bb6ab commit f294542
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 30 deletions.
35 changes: 35 additions & 0 deletions peru/resources/plugins/curl/curl_plugin_shared.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import os
import re
from urllib.parse import urlsplit
Expand All @@ -14,3 +15,37 @@ def get_request_filename(request):
return piece[len('filename='):]
# If no filename was specified, pick a reasonable default.
return os.path.basename(urlsplit(request.url).path) or 'index.html'


def format_bytes(num_bytes):
    """Return a short human-readable size string such as '999B' or '1.5MB'.

    Decimal (power-of-1000) units are used.  Sizes of at least 1KB are shown
    with exactly one decimal digit, truncated rather than rounded, so
    999999 bytes is '999.9KB' and never '1000.0KB'.  Anything smaller than
    1000 is formatted as-is with a 'B' suffix.
    """
    for threshold, unit in ((10**9, 'GB'), (10**6, 'MB'), (10**3, 'KB')):
        if num_bytes >= threshold:
            # Count whole tenths of the unit with floor division.  Doing the
            # truncation arithmetically (instead of slicing str(float)) keeps
            # it exact for ints and cannot trip over scientific notation
            # such as '1e+16', which contains no '.' to slice at.
            tenths = int(num_bytes * 10 // threshold)
            whole, tenth = divmod(tenths, 10)
            return '{}.{}{}'.format(whole, tenth, unit)
    return '{}B'.format(num_bytes)


def download_file(request, output_file):
    """Read `request` to exhaustion, printing progress, and return its SHA-1.

    request: response object providing `info()` (a mapping of headers) and
        `read(size)`, e.g. what `urllib.request.urlopen` returns.
    output_file: binary file object that receives every chunk, or None to
        hash the content without saving it.

    One 'downloaded ...' progress line is printed to stdout per chunk read.
    Returns the hex SHA-1 digest of the complete content.
    """
    # Content-Length only feeds the progress display, so a missing,
    # malformed or non-positive value just means "total size unknown"
    # rather than a failed download.
    try:
        file_size = int(request.info().get('Content-Length'))
    except (TypeError, ValueError):
        file_size = None
    if file_size is not None and file_size <= 0:
        file_size = None

    digest = hashlib.sha1()
    bytes_read = 0
    while True:
        buf = request.read(4096)
        if not buf:
            break
        digest.update(buf)
        # Compare against None explicitly: callers pass None to mean "do not
        # save", and a file-like object is not guaranteed to be truthy.
        if output_file is not None:
            output_file.write(buf)
        bytes_read += len(buf)
        percentage = ''
        total = ''
        if file_size is not None:
            percentage = ' {}%'.format(round(100 * bytes_read / file_size))
            total = '/' + format_bytes(file_size)
        print('downloaded{} {}{}'.format(
            percentage, format_bytes(bytes_read), total))
    return digest.hexdigest()
19 changes: 6 additions & 13 deletions peru/resources/plugins/curl/fetch.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
#! /usr/bin/env python3

import hashlib
import os
import sys
import urllib.request

from curl_plugin_shared import get_request_filename
import curl_plugin_shared


url = os.environ['PERU_MODULE_URL']
sha1 = os.environ['PERU_MODULE_SHA1']
filename = os.environ['PERU_MODULE_FILENAME']

digest = hashlib.sha1()
with urllib.request.urlopen(url) as request:
if not filename:
filename = get_request_filename(request)
filename = curl_plugin_shared.get_request_filename(request)
full_filepath = os.path.join(os.environ['PERU_FETCH_DEST'], filename)
with open(full_filepath, 'wb') as outfile:
while True:
buf = request.read(4096)
if not buf:
break
outfile.write(buf)
digest.update(buf)
with open(full_filepath, 'wb') as output_file:
digest = curl_plugin_shared.download_file(request, output_file)

if sha1 and digest.hexdigest() != sha1:
if sha1 and digest != sha1:
print('Bad checksum!\n url: {}\nexpected: {}\n actual: {}'
.format(url, sha1, digest.hexdigest()), file=sys.stderr)
.format(url, sha1, digest), file=sys.stderr)
sys.exit(1)
12 changes: 4 additions & 8 deletions peru/resources/plugins/curl/reup.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
#! /usr/bin/env python3

import hashlib
import os
import urllib.request

import curl_plugin_shared

reup_output = os.environ['PERU_REUP_OUTPUT']

url = os.environ['PERU_MODULE_URL']
sha1 = os.environ['PERU_MODULE_SHA1']

digest = hashlib.sha1()
with urllib.request.urlopen(url) as request:
while True:
buf = request.read(4096)
if not buf:
break
digest.update(buf)
digest = curl_plugin_shared.download_file(request, None)

with open(reup_output, 'w') as output_file:
print('sha1:', digest.hexdigest(), file=output_file)
print('sha1:', digest, file=output_file)
63 changes: 55 additions & 8 deletions tests/test_curl_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import contextlib
import hashlib
import imp
import io
from os.path import abspath, join, dirname
import unittest

Expand All @@ -10,16 +13,31 @@
curl_plugin_shared = imp.load_source('_curl_plugin_shared', CURL_SHARED_PATH)


class CurlPluginTest(unittest.TestCase):
def test_get_request_filename(self):
class MockRequest:
_info = {}
class MockRequest:
def __init__(self, url, info, response):
self.url = url
self._info = info
self._response_buffer = io.BytesIO(response)

def info(self):
return self._info

def info(self):
return self._info
def read(self, *args):
return self._response_buffer.read(*args)

request = MockRequest()
request.url = 'http://www.example.com/'

class CurlPluginTest(unittest.TestCase):
def test_format_bytes(self):
    """Check unit selection and one-decimal truncation of format_bytes."""
    cases = (
        (0, '0B'),
        (999, '999B'),
        (1000, '1.0KB'),
        (999999, '999.9KB'),
        (10**6, '1.0MB'),
        (10**9, '1.0GB'),
        (10**12, '1000.0GB'),
    )
    for num_bytes, expected in cases:
        actual = curl_plugin_shared.format_bytes(num_bytes)
        self.assertEqual(expected, actual)

def test_get_request_filename(self):
request = MockRequest('http://www.example.com/', {}, b'junk')
self.assertEqual('index.html',
curl_plugin_shared.get_request_filename(request))
request.url = 'http://www.example.com/foo'
Expand All @@ -29,3 +47,32 @@ def info(self):
'attachment; filename=bar'}
self.assertEqual('bar',
curl_plugin_shared.get_request_filename(request))

def test_download_file_with_length(self):
    """With a Content-Length, progress lines carry a percentage and total."""
    payload = b'xy' * 4096
    headers = {'Content-Length': len(payload)}
    request = MockRequest('some url', headers, payload)
    sink = io.BytesIO()
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        sha1 = curl_plugin_shared.download_file(request, sink)
    expected_progress = (
        'downloaded 50% 4.0KB/8.1KB\n'
        'downloaded 100% 8.1KB/8.1KB\n')
    self.assertEqual(expected_progress, captured.getvalue())
    # The payload must be written through unchanged and hashed in full.
    self.assertEqual(payload, sink.getvalue())
    self.assertEqual(hashlib.sha1(payload).hexdigest(), sha1)

def test_download_file_without_length(self):
    """Without a Content-Length, progress shows only the bytes read so far."""
    payload = b'foo'
    request = MockRequest('some url', {}, payload)
    sink = io.BytesIO()
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        sha1 = curl_plugin_shared.download_file(request, sink)
    expected_progress = 'downloaded 3B\n'
    self.assertEqual(expected_progress, captured.getvalue())
    # The payload must be written through unchanged and hashed in full.
    self.assertEqual(payload, sink.getvalue())
    self.assertEqual(hashlib.sha1(payload).hexdigest(), sha1)
2 changes: 1 addition & 1 deletion tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_curl_plugin_fetch(self):
# Now run it with the wrong hash, and confirm that there's an error.
fields['sha1'] = 'wrong hash'
with self.assertRaises(plugin.PluginRuntimeError):
self.do_plugin_test('curl', fields, curl_content)
self.do_plugin_test('curl', fields, {'newname': 'content'})

def test_curl_plugin_reup(self):
curl_content = {'myfile': 'content'}
Expand Down

0 comments on commit f294542

Please sign in to comment.