From 57dce44c9814ee5cc5dd1d9a8280745cc8a9f35a Mon Sep 17 00:00:00 2001 From: Hong Minhee Date: Fri, 13 Feb 2015 23:53:09 +0900 Subject: [PATCH] Add save check for AuthorizedKeyList Close #5 --- docs/changes.rst | 2 ++ geofront/remote.py | 38 ++++++++++++++++++-------- tests/remote_test.py | 65 ++++++++++++++++++++++++++------------------ 3 files changed, 67 insertions(+), 38 deletions(-) diff --git a/docs/changes.rst b/docs/changes.rst index 40c074a..d760989 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -7,6 +7,8 @@ Version 0.3.0 To be released. - Geofront becomes to require Paramiko 1.15.0 or higher. +- Added save check for :class:`~geofront.remote.AuthorizedKeyList`. + [:issue:`5`] Version 0.2.2 diff --git a/geofront/remote.py b/geofront/remote.py index 8ab3010..cb012b5 100644 --- a/geofront/remote.py +++ b/geofront/remote.py @@ -32,6 +32,7 @@ import datetime import io import itertools +import logging import numbers import threading import time @@ -147,9 +148,25 @@ def _iterate_lines(self): if line: yield line - def _save(self, authorized_keys: str): - with io.BytesIO(authorized_keys.encode()) as fo: - self.sftp_client.putfo(fo, self.FILE_PATH) + def _save(self, lines: list, existing_lines=None): + check = frozenset(line.strip() for line in lines) + if existing_lines is None: + mode = 'w' + else: + assert lines[:len(existing_lines)] == existing_lines + lines = lines[len(existing_lines):] + mode = 'a' + with self.sftp_client.open(self.FILE_PATH, mode) as fo: + if mode == 'a': + fo.write('\n') + for line in lines: + fo.write(line + '\n') + actual = frozenset(self._iterate_lines()) + if actual != check: + logger = logging.getLogger(__name__ + '.AuthorizedKeyList._save') + logger.debug('file check error: expected = %r, actual = %r', + check, actual) + raise IOError('failed to write to ' + self.FILE_PATH) def __iter__(self): for line in self._iterate_lines(): @@ -193,7 +210,7 @@ def __setitem__(self, index, value): 'authorized_keys indices must be integers, not ' '{0.__module__}.{0.__qualname__}'.format(type(index)) ) - self._save('\n'.join(lines)) + self._save(lines) def insert(self, index, value): if not isinstance(index, numbers.Integral): @@ -203,14 +220,13 @@ def insert(self, index, value): ) lines = list(self._iterate_lines()) lines.insert(index, format_openssh_pubkey(value)) - self._save('\n'.join(lines)) + self._save(lines) def extend(self, values): - lines = itertools.chain( - self._iterate_lines(), - map(format_openssh_pubkey, values) - ) - self._save('\n'.join(lines)) + existing_lines = list(self._iterate_lines()) + appended_lines = map(format_openssh_pubkey, values) + lines = itertools.chain(existing_lines, appended_lines) + self._save(list(lines), existing_lines=existing_lines) def __delitem__(self, index): if not isinstance(index, (numbers.Integral, slice)): @@ -220,7 +236,7 @@ def __delitem__(self, index): ) lines = list(self._iterate_lines()) del lines[index] - self._save('\n'.join(lines)) + self._save(lines) @typed diff --git a/tests/remote_test.py b/tests/remote_test.py index 9536db6..c226186 100644 --- a/tests/remote_test.py +++ b/tests/remote_test.py @@ -36,6 +36,16 @@ def test_remote(b, equal): assert (hash(a) == hash(b)) is equal +def get_next_line(fo): + line = '' + while not line: + line = fo.readline() + if not line: + return line + line = line.strip() + return line + + def test_authorized_keys_list_iter(fx_authorized_sftp): sftp_client, path, keys = fx_authorized_sftp key_list = AuthorizedKeyList(sftp_client) @@ -98,22 +108,22 @@ def test_authorized_keys_list_setitem(fx_authorized_sftp): key_list[3:] = [] with path.join('.ssh', 'authorized_keys').open() as f: for i in range(3): - assert parse_openssh_pubkey(f.readline().strip()) == keys[i] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[i] + assert not get_next_line(f) # Positive index key_list[2] = keys[3] with path.join('.ssh', 'authorized_keys').open() as f: - assert parse_openssh_pubkey(f.readline().strip()) == keys[0] - assert parse_openssh_pubkey(f.readline().strip()) == keys[1] - assert parse_openssh_pubkey(f.readline().strip()) == keys[3] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[0] + assert parse_openssh_pubkey(get_next_line(f)) == keys[1] + assert parse_openssh_pubkey(get_next_line(f)) == keys[3] + assert not get_next_line(f) # Negative index key_list[-1] = keys[4] with path.join('.ssh', 'authorized_keys').open() as f: - assert parse_openssh_pubkey(f.readline().strip()) == keys[0] - assert parse_openssh_pubkey(f.readline().strip()) == keys[1] - assert parse_openssh_pubkey(f.readline().strip()) == keys[4] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[0] + assert parse_openssh_pubkey(get_next_line(f)) == keys[1] + assert parse_openssh_pubkey(get_next_line(f)) == keys[4] + assert not get_next_line(f) def test_authorized_keys_list_insert(fx_authorized_sftp): @@ -122,12 +132,12 @@ def test_authorized_keys_list_insert(fx_authorized_sftp): new_key = RSAKey.generate(1024) key_list.insert(2, new_key) with path.join('.ssh', 'authorized_keys').open() as f: - assert parse_openssh_pubkey(f.readline().strip()) == keys[0] - assert parse_openssh_pubkey(f.readline().strip()) == keys[1] - assert parse_openssh_pubkey(f.readline().strip()) == new_key + assert parse_openssh_pubkey(get_next_line(f)) == keys[0] + assert parse_openssh_pubkey(get_next_line(f)) == keys[1] + assert parse_openssh_pubkey(get_next_line(f)) == new_key for i in range(2, 6): - assert parse_openssh_pubkey(f.readline().strip()) == keys[i] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[i] + assert not get_next_line(f) def test_authorized_keys_list_extend(fx_authorized_sftp): @@ -137,10 +147,10 @@ def test_authorized_keys_list_extend(fx_authorized_sftp): key_list.extend(new_keys) with path.join('.ssh', 'authorized_keys').open() as f: for i in range(6): - assert parse_openssh_pubkey(f.readline().strip()) == keys[i] + assert parse_openssh_pubkey(get_next_line(f)) == keys[i] for i in range(3): - assert parse_openssh_pubkey(f.readline().strip()) == new_keys[i] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == new_keys[i] + assert not get_next_line(f) def test_authorized_keys_list_delitem(fx_authorized_sftp): @@ -150,19 +160,19 @@ def test_authorized_keys_list_delitem(fx_authorized_sftp): del key_list[3:] with path.join('.ssh', 'authorized_keys').open() as f: for i in range(3): - assert parse_openssh_pubkey(f.readline().strip()) == keys[i] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[i] + assert not get_next_line(f) # Positive index del key_list[2] with path.join('.ssh', 'authorized_keys').open() as f: - assert parse_openssh_pubkey(f.readline().strip()) == keys[0] - assert parse_openssh_pubkey(f.readline().strip()) == keys[1] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[0] + assert parse_openssh_pubkey(get_next_line(f)) == keys[1] + assert not get_next_line(f) # Negative index del key_list[-1] with path.join('.ssh', 'authorized_keys').open() as f: - assert parse_openssh_pubkey(f.readline().strip()) == keys[0] - assert not f.readline().strip() + assert parse_openssh_pubkey(get_next_line(f)) == keys[0] + assert not get_next_line(f) def test_authorize(fx_sftpd): @@ -180,8 +190,9 @@ def test_authorize(fx_sftpd): timeout=datetime.timedelta(seconds=5) ) with authorized_keys_path.open() as f: - saved_keys = map(parse_openssh_pubkey, f) - assert frozenset(saved_keys) == (public_keys | {master_key}) + saved_keys = frozenset(parse_openssh_pubkey(l) + for l in f if l.strip()) + assert saved_keys == (public_keys | {master_key}) while datetime.datetime.now(datetime.timezone.utc) <= expires_at: time.sleep(1) time.sleep(1)