Skip to content

Commit

Permalink
Add save check for AuthorizedKeyList
Browse files Browse the repository at this point in the history
Close #5
  • Loading branch information
dahlia committed Feb 13, 2015
1 parent ac8fde7 commit 57dce44
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 38 deletions.
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Version 0.3.0
To be released.

- Geofront becomes to require Paramiko 1.15.0 or higher.
- Added save check for :class:`~geofront.remote.AuthorizedKeyList`.
[:issue:`5`]


Version 0.2.2
Expand Down
38 changes: 27 additions & 11 deletions geofront/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import datetime
import io
import itertools
import logging
import numbers
import threading
import time
Expand Down Expand Up @@ -147,9 +148,25 @@ def _iterate_lines(self):
if line:
yield line

def _save(self, authorized_keys: str):
with io.BytesIO(authorized_keys.encode()) as fo:
self.sftp_client.putfo(fo, self.FILE_PATH)
def _save(self, lines: list, existing_lines=None):
check = frozenset(line.strip() for line in lines)
if existing_lines is None:
mode = 'w'
else:
assert lines[:len(existing_lines)] == existing_lines
lines = lines[len(existing_lines):]
mode = 'a'
with self.sftp_client.open(self.FILE_PATH, mode) as fo:
if mode == 'a':
fo.write('\n')
for line in lines:
fo.write(line + '\n')
actual = frozenset(self._iterate_lines())
if actual != check:
logger = logging.getLogger(__name__ + '.AuthorizedKeyList._save')
logger.debug('file check error: expected = %r, actual = %r',
check, actual)
raise IOError('failed to write to ' + self.FILE_PATH)

def __iter__(self):
for line in self._iterate_lines():
Expand Down Expand Up @@ -193,7 +210,7 @@ def __setitem__(self, index, value):
'authorized_keys indices must be integers, not '
'{0.__module__}.{0.__qualname__}'.format(type(index))
)
self._save('\n'.join(lines))
self._save(lines)

def insert(self, index, value):
if not isinstance(index, numbers.Integral):
Expand All @@ -203,14 +220,13 @@ def insert(self, index, value):
)
lines = list(self._iterate_lines())
lines.insert(index, format_openssh_pubkey(value))
self._save('\n'.join(lines))
self._save(lines)

def extend(self, values):
lines = itertools.chain(
self._iterate_lines(),
map(format_openssh_pubkey, values)
)
self._save('\n'.join(lines))
existing_lines = list(self._iterate_lines())
appended_lines = map(format_openssh_pubkey, values)
lines = itertools.chain(existing_lines, appended_lines)
self._save(list(lines), existing_lines=existing_lines)

def __delitem__(self, index):
if not isinstance(index, (numbers.Integral, slice)):
Expand All @@ -220,7 +236,7 @@ def __delitem__(self, index):
)
lines = list(self._iterate_lines())
del lines[index]
self._save('\n'.join(lines))
self._save(lines)


@typed
Expand Down
65 changes: 38 additions & 27 deletions tests/remote_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ def test_remote(b, equal):
assert (hash(a) == hash(b)) is equal


def get_next_line(fo):
line = ''
while not line:
line = fo.readline()
if not line:
return line
line = line.strip()
return line


def test_authorized_keys_list_iter(fx_authorized_sftp):
sftp_client, path, keys = fx_authorized_sftp
key_list = AuthorizedKeyList(sftp_client)
Expand Down Expand Up @@ -98,22 +108,22 @@ def test_authorized_keys_list_setitem(fx_authorized_sftp):
key_list[3:] = []
with path.join('.ssh', 'authorized_keys').open() as f:
for i in range(3):
assert parse_openssh_pubkey(f.readline().strip()) == keys[i]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[i]
assert not get_next_line(f)
# Positive index
key_list[2] = keys[3]
with path.join('.ssh', 'authorized_keys').open() as f:
assert parse_openssh_pubkey(f.readline().strip()) == keys[0]
assert parse_openssh_pubkey(f.readline().strip()) == keys[1]
assert parse_openssh_pubkey(f.readline().strip()) == keys[3]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[0]
assert parse_openssh_pubkey(get_next_line(f)) == keys[1]
assert parse_openssh_pubkey(get_next_line(f)) == keys[3]
assert not get_next_line(f)
# Negative index
key_list[-1] = keys[4]
with path.join('.ssh', 'authorized_keys').open() as f:
assert parse_openssh_pubkey(f.readline().strip()) == keys[0]
assert parse_openssh_pubkey(f.readline().strip()) == keys[1]
assert parse_openssh_pubkey(f.readline().strip()) == keys[4]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[0]
assert parse_openssh_pubkey(get_next_line(f)) == keys[1]
assert parse_openssh_pubkey(get_next_line(f)) == keys[4]
assert not get_next_line(f)


def test_authorized_keys_list_insert(fx_authorized_sftp):
Expand All @@ -122,12 +132,12 @@ def test_authorized_keys_list_insert(fx_authorized_sftp):
new_key = RSAKey.generate(1024)
key_list.insert(2, new_key)
with path.join('.ssh', 'authorized_keys').open() as f:
assert parse_openssh_pubkey(f.readline().strip()) == keys[0]
assert parse_openssh_pubkey(f.readline().strip()) == keys[1]
assert parse_openssh_pubkey(f.readline().strip()) == new_key
assert parse_openssh_pubkey(get_next_line(f)) == keys[0]
assert parse_openssh_pubkey(get_next_line(f)) == keys[1]
assert parse_openssh_pubkey(get_next_line(f)) == new_key
for i in range(2, 6):
assert parse_openssh_pubkey(f.readline().strip()) == keys[i]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[i]
assert not get_next_line(f)


def test_authorized_keys_list_extend(fx_authorized_sftp):
Expand All @@ -137,10 +147,10 @@ def test_authorized_keys_list_extend(fx_authorized_sftp):
key_list.extend(new_keys)
with path.join('.ssh', 'authorized_keys').open() as f:
for i in range(6):
assert parse_openssh_pubkey(f.readline().strip()) == keys[i]
assert parse_openssh_pubkey(get_next_line(f)) == keys[i]
for i in range(3):
assert parse_openssh_pubkey(f.readline().strip()) == new_keys[i]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == new_keys[i]
assert not get_next_line(f)


def test_authorized_keys_list_delitem(fx_authorized_sftp):
Expand All @@ -150,19 +160,19 @@ def test_authorized_keys_list_delitem(fx_authorized_sftp):
del key_list[3:]
with path.join('.ssh', 'authorized_keys').open() as f:
for i in range(3):
assert parse_openssh_pubkey(f.readline().strip()) == keys[i]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[i]
assert not get_next_line(f)
# Positive index
del key_list[2]
with path.join('.ssh', 'authorized_keys').open() as f:
assert parse_openssh_pubkey(f.readline().strip()) == keys[0]
assert parse_openssh_pubkey(f.readline().strip()) == keys[1]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[0]
assert parse_openssh_pubkey(get_next_line(f)) == keys[1]
assert not get_next_line(f)
# Negative index
del key_list[-1]
with path.join('.ssh', 'authorized_keys').open() as f:
assert parse_openssh_pubkey(f.readline().strip()) == keys[0]
assert not f.readline().strip()
assert parse_openssh_pubkey(get_next_line(f)) == keys[0]
assert not get_next_line(f)


def test_authorize(fx_sftpd):
Expand All @@ -180,8 +190,9 @@ def test_authorize(fx_sftpd):
timeout=datetime.timedelta(seconds=5)
)
with authorized_keys_path.open() as f:
saved_keys = map(parse_openssh_pubkey, f)
assert frozenset(saved_keys) == (public_keys | {master_key})
saved_keys = frozenset(parse_openssh_pubkey(l)
for l in f if l.strip())
assert saved_keys == (public_keys | {master_key})
while datetime.datetime.now(datetime.timezone.utc) <= expires_at:
time.sleep(1)
time.sleep(1)
Expand Down

0 comments on commit 57dce44

Please sign in to comment.