Skip to content

Commit

Permalink
Merge pull request #6 from catlee/master
Browse files Browse the repository at this point in the history
Code cleanup, and adding tests
  • Loading branch information
Rail Aliiev committed Jan 19, 2015
2 parents b09a6e1 + 72954cf commit e9c1ef4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 27 deletions.
39 changes: 14 additions & 25 deletions mar/mar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@
import hashlib
import tempfile
from subprocess import Popen, PIPE
from functools import partial

import logging
log = logging.getLogger(__name__)


def read_file(fp, blocksize=8192):
"""Yields blocks of data from file object fp"""
for block in iter(partial(fp.read, blocksize), b''):
yield block


def rsa_sign(digest, keyfile):
proc = Popen(['openssl', 'pkeyutl', '-sign', '-inkey', keyfile],
stdin=PIPE, stdout=PIPE)
Expand Down Expand Up @@ -78,10 +85,7 @@ def generate_signature(fp, updatefunc):
fp.read(sigsize)

# Read the rest of the file
while True:
block = fp.read(512 * 1024)
if not block:
break
for block in read_file(fp, 512 * 1024):
updatefunc(block)


Expand Down Expand Up @@ -184,7 +188,7 @@ def __repr__(self):

def to_bytes(self):
return struct.pack(self._member_fmt, self._offset, self.size, self.flags) + \
self.name + "\x00"
self.name.encode("ascii") + b"\x00"


class MarFile:
Expand Down Expand Up @@ -232,7 +236,7 @@ def __init__(self, name, mode="r", signature_versions=[]):
self.index_offset += 4 + 8

# Write the magic and placeholder for the index
self.fileobj.write("MAR1" + packint(self.index_offset))
self.fileobj.write(b"MAR1" + packint(self.index_offset))

# Write placeholder for file size
self.fileobj.write(struct.pack(">Q", 0))
Expand Down Expand Up @@ -341,10 +345,7 @@ def add(self, path, name=None, fileobj=None, flags=None):

f = open(path, 'rb')
self.fileobj.seek(self.index_offset)
while True:
block = f.read(512 * 1024)
if not block:
break
for block in read_file(f, 512 * 1024):
self.fileobj.write(block)
else:
assert flags
Expand All @@ -353,10 +354,7 @@ def add(self, path, name=None, fileobj=None, flags=None):
info.flags = flags
info._offset = self.index_offset
self.fileobj.seek(self.index_offset)
while True:
block = fileobj.read(512 * 1024)
if not block:
break
for block in read_file(fileobj, 512 * 1024):
info.size += len(block)
self.fileobj.write(block)

Expand Down Expand Up @@ -396,13 +394,6 @@ def close(self):
self.fileobj.close()
self.fileobj = None

def __del__(self):
"""Close the file when we're garbage collected"""
try:
self.close()
except IOError:
pass

def __enter__(self):
return self

Expand Down Expand Up @@ -444,6 +435,7 @@ def extract(self, member, path="."):
os.makedirs(dirname)

self.fileobj.seek(member._offset)
# TODO: Should this be done all in memory?
open(dstpath, "wb").write(self.fileobj.read(member.size))
os.chmod(dstpath, member.flags)

Expand Down Expand Up @@ -506,10 +498,7 @@ def add(self, path, name=None, fileobj=None, mode=None):
f = fileobj
comp = bz2.BZ2Compressor(9)
self.fileobj.seek(self.index_offset)
while True:
block = f.read(512 * 1024)
if not block:
break
for block in read_file(f, 512 * 1024):
block = comp.compress(block)
info.size += len(block)
self.fileobj.write(block)
Expand Down
53 changes: 51 additions & 2 deletions tests/test_mar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@
import tempfile
import hashlib

from mar.mar import MarFile
from mar.mar import MarFile, BZ2MarFile, read_file

TEST_MAR = os.path.join(os.path.dirname(__file__), 'test.mar')


def test_read_file():
data = []
for block in read_file(open(__file__, 'rb')):
data.append(block)
assert b''.join(data) == open(__file__, 'rb').read()


def sha1sum(b):
"""Returns the sha1sum of a byte string"""
h = hashlib.new('sha1')
Expand All @@ -20,8 +28,12 @@ def test_list():
assert repr(m.members[0]) == "<update.manifest 664 141 bytes starting at 392>", m.members[0]
assert repr(m.members[1]) == "<defaults/pref/channel-prefs.js 664 76 bytes starting at 533>", m.members[1]

m = BZ2MarFile(TEST_MAR)
assert repr(m.members[0]) == "<update.manifest 664 141 bytes starting at 392>", m.members[0]
assert repr(m.members[1]) == "<defaults/pref/channel-prefs.js 664 76 bytes starting at 533>", m.members[1]


class TestMar(TestCase):
class TestReadingMar(TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.marfile = MarFile(TEST_MAR)
Expand All @@ -43,6 +55,43 @@ def test_extract(self):
self.assertEquals("6a7890e740f1e18a425b51fefbde2f6b86f91a12", h)


class TestReadingBZ2Mar(TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.marfile = BZ2MarFile(TEST_MAR)

def tearDown(self):
shutil.rmtree(self.tmpdir)

def test_extract_bz2(self):
m = self.marfile.members[0]
self.marfile.extract(m, self.tmpdir)
fn = os.path.join(self.tmpdir, m.name)

# The size in the manifest is of the compressed data, so we need to
# check that we've extracted the correct number of uncompressed bytes
# here
self.assertEquals(os.path.getsize(fn), 308)

# Check that the contents match
data = open(fn, 'rb').read()
h = sha1sum(data)
self.assertEquals("5177f5938923e94820d8565a1a0f25d19b4821d1", h)


class TestWritingMar(TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.tmpdir)

def test_add(self):
marfile = os.path.join(self.tmpdir, 'test.mar')
with MarFile(marfile, 'w') as m:
m.add(__file__)


class TestExceptions(TestCase):
def test_badmar(self):
self.assertRaises(ValueError, MarFile, __file__)

0 comments on commit e9c1ef4

Please sign in to comment.