Skip to content

Commit

Permalink
allow reverting to other authentication methods if one fails
Browse files Browse the repository at this point in the history
  • Loading branch information
marian-code committed Feb 10, 2022
1 parent e31177f commit cbc3506
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 173 deletions.
14 changes: 6 additions & 8 deletions ssh_utilities/connection.py
Expand Up @@ -187,7 +187,7 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
quiet=quiet)

try:
credentials = cls.available_hosts[ssh_server]
credentials: dict = cls.available_hosts[ssh_server]
except KeyError as e:
raise KeyError(f"couldn't find login credentials for {ssh_server}:"
f" {e}")
Expand All @@ -205,7 +205,10 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
if allow_agent:
log.info(f"no private key supplied for {hostname}, will try "
f"to authenticate through ssh-agent")
pkey_file = None
try:
pkey_file = credentials["identityfile"][0]
except (KeyError, IndexError) as e:
pkey_file = None
else:
log.info(f"private key found for host: {hostname}")
try:
Expand Down Expand Up @@ -411,12 +414,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
server_name=server_name,
quiet=quiet
)
elif allow_agent:
ssh_key_file = None
ssh_password = None
elif ssh_key_file:
ssh_password = None
elif not ssh_password:
elif not any((allow_agent, ssh_key_file, ssh_password)):
ssh_password = getpass.getpass(prompt="Enter password: ")

return SSHConnection(
Expand Down
130 changes: 9 additions & 121 deletions ssh_utilities/remote/_connection_wrapper.py
Expand Up @@ -183,18 +183,17 @@ def connect_wrapper(wrapped_instance: "_CLASS", *args, **kwargs):
except self.exclude_exceptions as e:
# if exception is one of the excluded, re-raise it
raise e from None
except NoValidConnectionsError as e:
except (NoValidConnectionsError, SSHException) as e:
error = e
log.exception(f"Caught paramiko error in {n}: {e}")
except SSHException as e:
error = e
log.exception(f"Caught paramiko error in {n}: {e}")
except AttributeError as e:
error = e
log.exception(f"Caught attribute error in {n}: {e}")
except OSError as e:
error = e
log.exception(f"Caught OS error in {n}: {e}")
"""
except AttributeError as e:
error = e
log.exception(f"Caught attribute error in {n}: {e}")
except OSError as e:
error = e
log.exception(f"Caught OS error in {n}: {e}")
"""
except SFTPError as e:
# garbage packets,
# see: https://github.com/paramiko/paramiko/issues/395
Expand All @@ -213,114 +212,3 @@ def connect_wrapper(wrapped_instance: "_CLASS", *args, **kwargs):
time.sleep(60)

return connect_wrapper


"""
def check_connections(original_function: Optional[Callable] = None, *,
exclude_exceptions: "_EXCTYPE" = ()):
def _decorate(function):
@wraps(function)
def connect_wrapper(self: "_CLASS", *args, **kwargs):
orig_self = self
# we have to deal with branching inner classes where attribute 'c'
# is instance of SSHConnection not paramiko SSHClient
# as of now we have 2 inner levels so we must iterate until we get
# to the base
while True:
print(type(self), type(self.c))
if not isinstance(self.c, SSHClient):
self = self.c
else:
instance = self
break
def negotiate() -> bool:
try:
instance.close(quiet=True)
except Exception as e:
log.exception(f"Couldn't close connection: {e}")
try:
instance._get_ssh()
except ConnectionError:
success = False
else:
success = True
log.debug(f"success 1: {success}")
if not success:
return False
if instance._sftp_open:
log.debug(f"success 2: {success}")
try:
instance.sftp
except SFTPOpenError:
success = False
log.debug(f"success 3: {success}")
else:
log.debug(f"success 4: {success}")
success = True
else:
success = False
log.exception(f"Relevant variables:\n"
f"success: {success}\n"
f"password: {instance.password}\n"
f"address: {instance.address}\n"
f"username: {instance.username}\n"
f"ssh class: {type(instance.c)}\n"
f"sftp class: {type(instance.sftp)}")
if instance._sftp_open:
log.exception(f"remote home: {instance.remote_home}")
return success
n = function.__name__
error = None
try:
return function(orig_self, *args, **kwargs)
except exclude_exceptions as e:
# if exception is one of the excluded, re-raise it
raise e from None
except NoValidConnectionsError as e:
error = e
log.exception(f"Caught paramiko error in {n}: {e}")
except SSHException as e:
error = e
log.exception(f"Caught paramiko error in {n}: {e}")
except AttributeError as e:
error = e
log.exception(f"Caught attribute error in {n}: {e}")
except OSError as e:
error = e
log.exception(f"Caught OS error in {n}: {e}")
except SFTPError as e:
# garbage packets,
# see: https://github.com/paramiko/paramiko/issues/395
log.exception(f"Caught paramiko error in {n}: {e}")
finally:
while error:
log.warning("Connection is down, trying to reconnect")
if negotiate():
log.info("Connection restablished, continuing ..")
connect_wrapper(orig_self, *args, **kwargs)
break
else:
log.warning("Unsuccessful, wait 60 seconds "
"before next try")
time.sleep(60)
return connect_wrapper
if original_function:
return _decorate(original_function)
return _decorate
"""
99 changes: 56 additions & 43 deletions ssh_utilities/remote/remote.py
Expand Up @@ -79,6 +79,7 @@ class SSHConnection(ConnectionABC):

_remote_home: str = ""
__lock: Union[ContextManager[None], RLock]
__AUTH_ATTEMPTS: int = 3

def __init__(self, address: str, username: str,
password: Optional[str] = None,
Expand Down Expand Up @@ -137,22 +138,14 @@ def __init__(self, address: str, username: str,
self.pkey_file = pkey_file
self.allow_agent = allow_agent

# paramiko connection
if allow_agent:
self._pkey = None
self.password = None
elif pkey_file:
for key in _KEYS:
try:
self._pkey = key.from_private_key_file(
self._path2str(pkey_file)
)
except paramiko.SSHException:
log.info(f"could not parse key with {key.__name__}")
elif password:
self._pkey = None
else:
raise RuntimeError("Must input password or path to pkey")
if not allow_agent and not pkey_file and not password:
raise RuntimeError(
"Must allow ssh-agent input password or path to private key"
)


if pkey_file:
self._load_pkey()

self._c = paramiko.client.SSHClient()
self._c.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
Expand Down Expand Up @@ -235,35 +228,55 @@ def close(self, *, quiet: bool = True):
self.c.close()

# * additional methods needed by remote ssh class, not in ABC definition
def _get_ssh(self, authentication_attempts: int = 0):
def _get_ssh(self):

with self.__lock:
def _connect(method: str, **kwargs):

log.info(f"trying to authenticate with {method}")
for _ in range(self.__AUTH_ATTEMPTS):
with self.__lock:
try:
self.c.connect(self.address, username=self.username, **kwargs)
except (paramiko.ssh_exception.AuthenticationException,
paramiko.ssh_exception.NoValidConnectionsError) as e:
log.warning(
f"Error in authentication {e}. Trying again ..."
)
else:
log.info(f"successfully authenticated with: {method}")
return True

log.warning(
f"authentication with {method} failed, reverting to backup methods"
)
return False

# we will try each method maximum three times
# authenticatiom method preference is based on security and convenience
if self.allow_agent and _connect("ssh-agent", allow_agent=True):
return

if self.pkey_file and _connect("private-key", pkey=self._pkey):
return

if self.password and _connect("password", password=self.password, look_for_keys=False):
return

# if none of the authentication methods was sucessfull, raise error
raise ConnectionError(
f"Connection to {self.address} could not be established"
)


def _load_pkey(self):

for key in _KEYS:
try:
if self.allow_agent:
# connect using ssh-agent
self.c.connect(self.address, username=self.username, allow_agent=True)
if self._pkey:
# connect with public key
self.c.connect(self.address, username=self.username,
pkey=self._pkey)
else:
# if password was passed try to connect with it
self.c.connect(self.address, username=self.username,
password=self.password, look_for_keys=False)

except (paramiko.ssh_exception.AuthenticationException,
paramiko.ssh_exception.NoValidConnectionsError) as e:
log.warning(f"Error in authentication {e}. Trying again ...")

# max three attempts to connect at once
authentication_attempts += 1
if authentication_attempts >= 3:
raise ConnectionError(f"Connection to {self.address} "
f"could not be established")
else:
self._get_ssh(
authentication_attempts=authentication_attempts
)
self._pkey = key.from_private_key_file(
self._path2str(self.pkey_file)
)
except paramiko.SSHException:
log.info(f"could not parse key with {key.__name__}")

@property # type: ignore
@check_connections()
Expand Down
2 changes: 1 addition & 1 deletion ssh_utilities/version.py
@@ -1 +1 @@
__version__ = "0.13.0"
__version__ = "0.14.0"

0 comments on commit cbc3506

Please sign in to comment.