diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8500dc7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,41 @@ +*.py[cod] + +# C extensions +*.so + +# Packages +*.egg +*.egg-info +dist +build +eggs +parts +bin +var +sdist +develop-eggs +.installed.cfg +lib +lib64 + +# Installer logs +pip-log.txt + +# Unit test / coverage reports +.coverage +.tox +nosetests.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Pycharm +.idea + +# Backup files +*~ diff --git a/LICENSE.txt b/LICENSE.txt deleted file mode 100644 index 4c3ab8f..0000000 --- a/LICENSE.txt +++ /dev/null @@ -1,12 +0,0 @@ - Copyright (c) 2010 Yahoo! Inc. All rights reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. See accompanying LICENSE file. diff --git a/setup.py b/setup.py index 0915cef..f9591a9 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( name="sshmap", - version="0.5.8", + version="0.5.50", author="Dwight Hubbard", author_email="dhubbard@yahoo-inc.com", url="https://github.com/dwighthubbard/sshmap", diff --git a/sshmap/__init__.py b/sshmap/__init__.py old mode 100644 new mode 100755 index 07978b1..fa25fe7 --- a/sshmap/__init__.py +++ b/sshmap/__init__.py @@ -1,8 +1,8 @@ -#Copyright (c) 2012 Yahoo! Inc. All rights reserved. +#Copyright (c) 2012-2013 Yahoo! Inc. All rights reserved. #Licensed under the Apache License, Version 2.0 (the "License"); #you may not use this file except in compliance with the License. #You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software @@ -10,4 +10,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. See accompanying LICENSE file. -from sshmap import * + +#from sshmap import run, run_command, ssh_result, ssh_results, fastSSHClient +import sshmap +import callback +import utility +import runner + +# For backwards compatibility +from callback import summarize_failures as callback_summarize_failures +from callback import aggregate_output as callback_aggregate_output +from callback import exec_command as callback_exec_command +from callback import filter_match as callback_filter_match +from callback import status_count as callback_status_count + +# The actual used sshmap functions +from sshmap import run, run_command diff --git a/sshmap/callback.py b/sshmap/callback.py new file mode 100755 index 0000000..07437e7 --- /dev/null +++ b/sshmap/callback.py @@ -0,0 +1,228 @@ +#Copyright (c) 2012 Yahoo! Inc. All rights reserved. +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. See accompanying LICENSE file. +__author__ = 'dhubbard' +""" +sshmap built in callback handlers +""" +import os +import sys +import hashlib +import json +import stat +import base64 +import subprocess + +import utility + + +# Filter callback handlers +def flowthrough(result): + """ + Builtin Callback, return the raw data passed + + >>> result=flowthrough(ssh_result(["output"], ["error"],"foo", 0)) + >>> result.dump() + foo output error 0 0 None + """ + return result + + +def summarize_failures(result): + """ + Builtin Callback, put a summary of failures into parm + """ + failures = result.setting('failures') + if not failures: + result.parm['failures'] = [] + failures = [] + if result.ssh_retcode: + failures.append(result.host) + result.parm['failures'] = failures + return result + + +def exec_command(result): + """ + Builtin Callback, pass the results to a command/script + :param result: + """ + script = result.setting("callback_script") + if not script: + return result + utility.status_clear() + result_out, result_err = subprocess.Popen( + script + " " + result.host, + shell=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ).communicate( + result.out_string() + result.err_string() + ) + result.out = [result_out] + result.err = [result_err] + print(result.out_string()) + return result + + +def aggregate_output(result): + """ Builtin Callback, Aggregate identical results """ + aggregate_hosts = result.setting('aggregate_hosts') + if not aggregate_hosts: + aggregate_hosts = {} + collapsed_output = result.setting('collapsed_output') + if not collapsed_output: + collapsed_output = {} + h = hashlib.md5() + h.update(result.out_string()) + h.update(result.err_string()) + if result.ssh_retcode: + h.update(result.ssh_error_message()) + digest = h.hexdigest() + if digest in aggregate_hosts.keys(): + aggregate_hosts[digest].append(result.host) + else: + aggregate_hosts[digest] = [result.host] + if result.ssh_retcode: + error = [] + if result.err: + error = result.err + error.append(result.ssh_error_message()) + collapsed_output[digest] = (result.out, error) + else: + collapsed_output[digest] = (result.out, result.err) + result.parm['aggregate_hosts'] = aggregate_hosts + if collapsed_output: + result.parm['collapsed_output'] = collapsed_output + return result + + +def filter_match(result): + """ + Builtin Callback, remove all output if the string is not found in the + output + similar to grep + :param result: + """ + if result.out_string().find(result.setting('match')) == -1 and \ + result.err_string().find(result.setting('match')) == -1: + result.out = '' + result.err = '' + return result + + +def filter_json(result): + """ + Builtin Callback, change stdout to json + + >>> result=filter_json(ssh_result(["output"], ["error"],"foo", 0)) + >>> result.dump() + foo [["output"], ["error"], 0] error 0 0 None + """ + result.out = [json.dumps((result.out, result.err, result.retcode))] + return result + + +def filter_base64(result): + """ + Builtin Callback, base64 encode the info in out and err streams + """ + result.out = [base64.b64encode(result.out_string)] + result.err = [base64.b64encode(result.err_string)] + return result + + +#Status callback handlers +def status_count(result): + """ + Builtin Callback, show the count complete/remaining + :param result: + """ + # The master process inserts the status into the + # total_host_count and completed_host_count variables + sys.stderr.write('\x1b[0G\x1b[0K%s/%s' % ( + result.setting('completed_host_count'), + result.setting('total_host_count'))) + sys.stderr.flush() + return result + + +#Output callback handlers +def output_prefix_host(result): + """ + Builtin Callback, print the output with the hostname: prefixed to each line + :param result: + hostname: out + + >>> result=sshmap.callback.output_prefix_host(ssh_result(['out'],['err'], 'hostname', 0)) + >>> result.dump() + """ + output = [] + error = [] + utility.status_clear() + # If summarize_failures option is set don't print ssh errors inline + if result.setting('summarize_failed') and result.ssh_retcode: + return result + if result.setting('print_rc'): + rc = ' SSH_Returncode: %d\tCommand_Returncode: %d' % ( + result.ssh_retcode, result.retcode) + else: + rc = '' + if result.ssh_retcode: + print >> sys.stderr, '%s: %s' % (result.host, result.ssh_error_message()) + error = ['%s: %s' % (result.host, result.ssh_error_message())] + if len(result.out_string()): + for line in result.out: + if line: + print '%s:%s %s' % (result.host, rc, line.strip()) + output.append('%s:%s %s\n' % (result.host, rc, line.strip())) + if len(result.err_string()): + for line in result.err: + if line: + print >> sys.stderr, '%s:%s %s' % (result.host, rc, line.strip()) + error.append('%s:%s Error: %s\n' % (result.host, rc, line.strip())) + if result.setting('output'): + if not len(result.out_string()) and not len(result.err_string()) and not result.setting( + 'only_output') and result.setting('print_rc'): + print '%s:%s' % (result.host, rc) + sys.stdout.flush() + sys.stderr.flush() + result.out = output + result.err = error + return result + + +def read_conf(key=None, prompt=True): + """ Read settings from the config file + :param key: + :param prompt: + """ + try: + conf = json.load(open(os.path.expanduser('~/.sshmap.conf'), 'r')) + except IOError: + conf = sshmap.conf_defaults + if key: + try: + return conf[key].encode('ascii') + except KeyError: + pass + else: + return conf + if key and prompt: + conf[key] = raw_input(sshmap.conf_desc[key] + ': ') + fh = open(os.path.expanduser('~/.sshmap2.conf'), 'w') + os.fchmod(fh.fileno(), stat.S_IRUSR | stat.S_IWUSR) + json.dump(conf, fh) + fh.close() + return conf[key] + else: + return None diff --git a/sshmap/defaults.py b/sshmap/defaults.py new file mode 100644 index 0000000..3b03399 --- /dev/null +++ b/sshmap/defaults.py @@ -0,0 +1,67 @@ +#Copyright (c) 2012 Yahoo! Inc. All rights reserved. +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. See accompanying LICENSE file. +""" +Default Values for sshmap +""" +__author__ = 'dhubbard' +import os + +# Defaults +JOB_MAX = 100 +# noinspection PyBroadException +try: + for line in open('/proc/%d/limits' % os.getpid(), 'r').readlines(): + if line.startswith('Max processes'): + JOB_MAX = int(line.strip().split()[2]) / 4 +except: + pass + +# Return code values +RUN_OK = 0 +RUN_FAIL_AUTH = 1 +RUN_FAIL_TIMEOUT = 2 +RUN_FAIL_CONNECT = 3 +RUN_FAIL_SSH = 4 +RUN_SUDO_PROMPT = 5 +RUN_FAIL_UNKNOWN = 6 +RUN_FAIL_NOPASSWORD = 7 +RUN_FAIL_BADPASSWORD = 8 + +# Text return codes +RUN_CODES = ['Ok', 'Authentication Error', 'Timeout', 'SSH Connection Failed', + 'SSH Failure', + 'Sudo did not send a password prompt', 'Connection refused', + 'Sudo password required', + 'Invalid sudo password'] + +# Configuration file field descriptions +conf_desc = { + "username": "IRC Server username", + "password": "IRC Server password", + "channel": "sshmap", +} + +# Configuration file defaults +conf_defaults = { + "address": "chat.freenode.net", + "port": "6667", + "use_ssl": False, +} + +sudo_message = [ + 'We trust you have received the usual lecture from the local System', + 'Administrator. It usually boils down to these three things:', + '#1) Respect the privacy of others.', + '#2) Think before you type.', + '#3) With great power comes great responsibility.' +] diff --git a/sshmap/runner.py b/sshmap/runner.py new file mode 100644 index 0000000..7e5d1b3 --- /dev/null +++ b/sshmap/runner.py @@ -0,0 +1,48 @@ +#Copyright (c) 2012 Yahoo! Inc. All rights reserved. +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. See accompanying LICENSE file. +""" +Python rpc remote runner handlers +""" +__author__ = 'dhubbard' +import os +import base64 + + +script_stdin = """import os +os.popen(\"{command}\".decode('base64').decode('{compressor}'),'w').write(\"\"\"{input}\"\"\".decode('base64').decode('{compressor}')) +""" + +script_sudo = """import os +command = \"{command}\".decode('base64').decode('{compressor}') +fh = os.popen(\"sudo -S \" + command,'w') +fh.write(\"{password}\\n\") +fh.write(\"\"\"{input}\"\"\".decode('base64').decode('{compressor}')) +""" + +def get_runner(command, input, password='', runner_script=None, + compressor='bz2'): + if not runner_script: + runner_script = script_stdin + + if compressor not in ['bz2', 'zlilb']: + compressor = 'bz2' + + base64_command = base64.b64encode(command.encode('bz2')) + base64_input = base64.b64encode(input.encode('bz2')) + + return runner_script.format( + command=base64_command, + input=base64_input, + compressor=compressor, + password=password + ) diff --git a/sshmap/sshmap b/sshmap/sshmap index 9f2aae5..745eabf 100755 --- a/sshmap/sshmap +++ b/sshmap/sshmap @@ -3,7 +3,7 @@ #Licensed under the Apache License, Version 2.0 (the "License"); #you may not use this file except in compliance with the License. #You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software @@ -15,18 +15,20 @@ Python based ssh multiplexer optimized for map operations command line utility. """ +import os import sys -from optparse import OptionParser +import optparse import getpass -from sshmap import * +import sshmap +import hostlists if __name__ == "__main__": - parser = OptionParser() - #parser.add_option("--html", dest="html", default=False,action="store_true", help="Use HTML for formatting") - parser.add_option("--output_json", dest = "output_json", default = False, action = "store_true", + parser = optparse.OptionParser() + parser.add_option("--output_json", dest="output_json", + default=False, action="store_true", help = "Output in JSON format") - parser.add_option("--output_base64", dest = "output_base64", default = False, action = "store_true", - help = "Output in base64 format") + parser.add_option("--output_base64", dest="output_base64", default=False, + action="store_true", help="Output in base64 format") parser.add_option("--summarize_failed", dest = "summarize_failed", default = False, action = "store_true", help = "Print a list of hosts that failed at the end") parser.add_option("--aggregate_output", "--collapse", dest = "aggregate_output", default = False, @@ -50,10 +52,11 @@ if __name__ == "__main__": help = "Script to process the output of each host. The hostname will be passed as the first argument and the stdin/stderr from the host will be passed as stdin/stderr of the script") parser.add_option("--no_status", dest = "show_status", default = True, action = "store_false", help = "Don't show a status count as the command progresses") - parser.add_option("--sudo", dest = "sudo", default = False, action = "store_true", + parser.add_option("--sudo", dest="sudo", default=False, + action="store_true", help = "Use sudo to run the command as root") - parser.add_option("--password", dest = "password", default = None, action = "store_true", - help = "Prompt for a password") + parser.add_option("--password", dest="password", default=False, + action="store_true", help="Prompt for a password") (options, args) = parser.parse_args() @@ -63,54 +66,52 @@ if __name__ == "__main__": command = firstline[2:] args.append(command) - #if len(args) != 2: - # if len(args) and args[0] in ['test']: - # Unit testing - # import doctest - - # doctest.testmod() - #sys.exit(0) - # Default option values options.password = None options.username = getpass.getuser() options.output = True # Create our callback pipeline based on the options passed - callback = [callback_summarize_failures] + callback = [sshmap.callback.summarize_failures] if options.match: - callback.append(callback_filter_match) + callback.append(sshmap.callback.filter_match) if options.output_base64: - callback.append(callback_filter_base64) + callback.append(sshmap.callback.filter_base64) if options.output_json: - callback.append(callback_filter_json) + callback.append(sshmap.callback.filter_json) if options.callback_script: - callback.append(callback_exec_command) + callback.append(sshmap.callback.exec_command) else: if options.aggregate_output: - callback.append(callback_aggregate_output) + callback.append(sshmap.callback.aggregate_output) else: - callback.append(callback_output_prefix_host) + callback.append(sshmap.callback.output_prefix_host) if options.show_status: - callback.append(callback_status_count) + callback.append(sshmap.callback.status_count) # Get the password if the options passed indicate it might be needed if options.sudo: # Prompt for password, we really need to add a password file option try: - options.password = os.environ['OS_PASSWORD'] + options.password = os.environ['SSHMAP_SUDO_PASSWORD'] except KeyError: - options.password = getpass.getpass('Enter sudo password for user ' + getpass.getuser() + ': ') + options.password = getpass.getpass( + 'Enter sudo password for user ' + getpass.getuser() + ': ') elif options.password: # Prompt for password, we really need to add a password file option try: - options.password = os.environ['OS_PASSWORD'] + options.password = os.environ['SSHMAP_SUDO_PASSWORD'] except KeyError: - options.password = getpass.getpass('Enter password for user ' + getpass.getuser() + ': ') + options.password = getpass.getpass( + 'Enter password for user ' + getpass.getuser() + ': ') command = ' '.join(args[1:]) range = args[0] - results = run(args[0], command, username = options.username, password = options.password, sudo = options.sudo, - timeout = options.timeout, script = options.runscript, jobs = options.jobs, sort = options.sort, - shuffle = options.shuffle, output_callback = callback, parms = vars(options)) + results = sshmap.run( + args[0], command, username=options.username, + password=options.password, sudo=options.sudo, + timeout=options.timeout, script=options.runscript, jobs=options.jobs, + sort=options.sort, + shuffle=options.shuffle, output_callback=callback, parms=vars(options) + ) if options.aggregate_output: aggregate_hosts = results.setting('aggregate_hosts') collapsed_output = results.setting('collapsed_output') @@ -121,7 +122,12 @@ if __name__ == "__main__": print ','.join(hostlists.compress(aggregate_hosts[md5])) print "-" * (int(columns)-2) stdout, stderr = collapsed_output[md5] - if len(stdout): print '\n'.join(stdout) - if len(stderr): print >> sys.stderr, '\n'.join(stderr) - if options.summarize_failed and 'failures' in results.parm.keys() and len(results.parm['failures']): - print 'SSH Failed to: %s' % hostlists.compress(results.parm['failures']) + if len(stdout): + print ''.join(stdout) + if len(stderr): + print >> sys.stderr, '\n'.join(stderr) + if options.summarize_failed and 'failures' in results.parm.keys() and \ + len(results.parm['failures']): + print( + 'SSH Failed to: %s' % hostlists.compress(results.parm['failures']) + ) diff --git a/sshmap/sshmap.py b/sshmap/sshmap.py index b407cae..0a5dbd1 100755 --- a/sshmap/sshmap.py +++ b/sshmap/sshmap.py @@ -3,7 +3,7 @@ #Licensed under the Apache License, Version 2.0 (the "License"); #you may not use this file except in compliance with the License. #You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software @@ -14,25 +14,23 @@ """ Python based ssh multiplexer optimized for map operations """ -#disable deprecated warning messages -import warnings +# Pull in the python3 print function for python3 compatibility +from __future__ import print_function +#disable deprecated warning messages that occur during some of the imports +import warnings warnings.filterwarnings("ignore") # Python Standard Library imports -import sys import os -import stat +import sys import getpass import socket import types -import base64 import random import signal -import hashlib -import json +import inspect import multiprocessing -import subprocess import logging # Imports from external python extension modules @@ -40,48 +38,10 @@ # Imports from other sshmap modules import hostlists - -# Defaults -JOB_MAX = 100 -# noinspection PyBroadException -try: - for line in open('/proc/%d/limits' % os.getpid(), 'r').readlines(): - if line.startswith('Max processes'): - JOB_MAX = int(line.strip().split()[2]) / 4 -except: - pass - -# Return code values -RUN_OK = 0 -RUN_FAIL_AUTH = 1 -RUN_FAIL_TIMEOUT = 2 -RUN_FAIL_CONNECT = 3 -RUN_FAIL_SSH = 4 -RUN_SUDO_PROMPT = 5 -RUN_FAIL_UNKNOWN = 6 -RUN_FAIL_NOPASSWORD = 7 -RUN_FAIL_BADPASSWORD = 8 - -# Text return codes -RUN_CODES = ['Ok', 'Authentication Error', 'Timeout', 'SSH Connection Failed', - 'SSH Failure', - 'Sudo did not send a password prompt', 'Connection refused', - 'Sudo password required', - 'Invalid sudo password'] - -# Configuration file field descriptions -conf_desc = { - "username": "IRC Server username", - "password": "IRC Server password", - "channel": "sshmap", -} - -# Configuration file defaults -conf_defaults = { - "address": "chat.freenode.net", - "port": "6667", - "use_ssl": False, -} +import utility +import callback +import defaults +import runner # Fix to make ctrl-c correctly terminate child processes # spawned by the multiprocessing module @@ -89,7 +49,15 @@ def wrapper(func): + """ + Simple timeout wrapper for multiprocessing + :param func: + """ def wrap(self, timeout=None): + """ + The wrapper method + :param timeout: + """ return func(self, timeout=timeout if timeout is not None else 1e100) return wrap @@ -98,7 +66,7 @@ def wrap(self, timeout=None): IMapIterator.next = wrapper(IMapIterator.next) -class ssh_result: +class ssh_result(object): """ ssh_result class, that holds the output from the ssh_call. This is passed to all the callback functions. @@ -128,22 +96,28 @@ def err_string(self): def setting(self, key): """ Get a setting from the parm dict or return None if it doesn't exist + :param key: """ - return get_parm_val(self.parm, key) + return utility.get_parm_val(self.parm, key) def ssh_error_message(self): """ Return the ssh_error_message for the error code """ - return RUN_CODES[self.ssh_retcode] + return defaults.RUN_CODES[self.ssh_retcode] def dump(self, return_parm=True, return_retcode=True): - """ Print all our public values """ - print self.host, self.out_string().replace('\n', ''), self.err_string().replace('\n', ''), + """ Print all our public values + :param return_parm: + :param return_retcode: + """ + sys.stdout.write(self.host+' ') + sys.stdout.write(self.out_string().replace('\n', '')+' ') + sys.stderr.write(self.err_string().replace('\n', '')+' ') if return_retcode: - print self.retcode, + sys.stdout.write(self.retcode+' ') if return_parm: - print self.ssh_retcode, self.parm + sys.stdout.write(self.ssh_retcode+' '+self.parm) else: - print + sys.stdout.write('\n') def print_output(self): """ Print output from the commands """ @@ -166,10 +140,12 @@ def dump(self): """ Dump all the result objects """ for item in self.__iter__(): item.dump(return_parm=False, return_retcode=False) - print self.parm + print(self.parm) def print_output(self, summarize_failures=False): - """ Print all the objects """ + """ Print all the objects + :param summarize_failures: + """ for item in self.__iter__(): item.print_output() if summarize_failures: @@ -182,8 +158,9 @@ def print_output(self, summarize_failures=False): def setting(self, key): """ Get a setting from the parm dict or return None if it doesn't exist + :param key: """ - return get_parm_val(self.parm, key) + return utility.get_parm_val(self.parm, key) def agent_auth(transport, username): @@ -191,6 +168,8 @@ def agent_auth(transport, username): Attempt to authenticate to the given transport using any of the private keys available from an SSH agent or from a local private RSA key file (assumes no pass phrase). + :param transport: + :param username: """ agent = ssh.Agent() @@ -214,6 +193,14 @@ class fastSSHClient(ssh.SSHClient): """ ssh SSHClient class extended with timeout support """ def exec_command(self, command, bufsize=-1, timeout=None, pty=False): + """ + Execute a command + :param command: + :param bufsize: + :param timeout: + :param pty: + :return: + """ chan = self._transport.open_session() chan.settimeout(timeout) if pty: @@ -226,34 +213,40 @@ def exec_command(self, command, bufsize=-1, timeout=None, pty=False): def _term_readline(handle): - #print '_iterm_readline' - #print type(handle) char = handle.read(1) - #print '%s' % (char), type(char) buf = "" - #print '_iterm_readline: starting loop' try: while char: - #print '_item_readline: appending',type(buf),type(char) buf += char if char in ['\r', '\n']: - #print '_iterm_readline - Found line', len(buf), char, buf return buf char = handle.read(1) - except Exception, message: - print Exception, message - #print '_item_readline - Exit', buf + except Exception as message: + print('%s %s' % (Exception, message)) return buf -def run_command(host, command="uname -a", username=None, password=None, sudo=False, script=None, timeout=None, - parms=None, client=None, bufsize=-1, cwd='/tmp', logging=False): +def run_command(host, command="uname -a", username=None, password=None, + sudo=False, script=None, timeout=None, parms=None, client=None, + bufsize=-1, logging=False): """ Run a command or script on a remote node via ssh + :param host: + :param command: + :param username: + :param password: + :param sudo: + :param script: + :param timeout: + :param parms: + :param client: + :param bufsize: + :param logging: """ # Guess any parameters not passed that can be if isinstance(host, types.TupleType): - host, command, username, password, sudo, script, timeout, parms, client = host + host, command, username, password, sudo, script, timeout, parms, \ + client = host if timeout == 0: timeout = None if not username: @@ -261,13 +254,12 @@ def run_command(host, command="uname -a", username=None, password=None, sudo=Fal if bufsize == -1 and script: bufsize = os.path.getsize(script) + 1024 + script_parameters = None if script: temp = command.split() if len(temp) > 1: command = temp[0] script_parameters = temp - else: - script_parameters = None # Get a result object to put our output in result = ssh_result(host=host, parm=parms) @@ -282,7 +274,7 @@ def run_command(host, command="uname -a", username=None, password=None, sudo=Fal client = fastSSHClient() except: result.err = ['Error creating client'] - result.ssh_retcode = RUN_FAIL_UNKNOWN + result.ssh_retcode = defaults.RUN_FAIL_UNKNOWN return result client.set_missing_host_key_policy(ssh.AutoAddPolicy()) # load_system_host_keys slows things way down @@ -290,70 +282,58 @@ def run_command(host, command="uname -a", username=None, password=None, sudo=Fal close_client = True # noinspection PyBroadException try: - client.connect(host, username=username, password=password, timeout=timeout) + client.connect(host, username=username, password=password, + timeout=timeout) except ssh.AuthenticationException: - result.ssh_retcode = RUN_FAIL_AUTH + result.ssh_retcode = defaults.RUN_FAIL_AUTH return result except ssh.SSHException: - result.ssh_retcode = RUN_FAIL_CONNECT + result.ssh_retcode = defaults.RUN_FAIL_CONNECT return result except AttributeError: - result.ssh_retcode = RUN_FAIL_SSH + result.ssh_retcode = defaults.RUN_FAIL_SSH return result except socket.error: - result.ssh_retcode = RUN_FAIL_CONNECT + result.ssh_retcode = defaults.RUN_FAIL_CONNECT return result - except Exception, message: - result.ssh_retcode = RUN_FAIL_UNKNOWN + except Exception as message: + logging.debug('Got unknown exception %s', message) + result.ssh_retcode = defaults.RUN_FAIL_UNKNOWN return result try: - # We have to force a sudo -k first or we can't reliably know we'll be prompted for our password + # We have to force a sudo -k first or we can't reliably know we'll be + # prompted for our password if sudo: - stdin, stdout, stderr, chan = client.exec_command('sudo -k %s' % command, timeout=timeout, bufsize=bufsize, - pty=True) + stdin, stdout, stderr, chan = client.exec_command( + 'sudo -k -S %s' % command, + timeout=timeout, bufsize=bufsize, pty=False + ) if not chan: - result.ssh_retcode = RUN_FAIL_CONNECT + result.ssh_retcode = defaults.RUN_FAIL_CONNECT return result else: - stdin, stdout, stderr, chan = client.exec_command(command, timeout=timeout, bufsize=bufsize) + stdin, stdout, stderr, chan = client.exec_command( + command, timeout=timeout, bufsize=bufsize) if not chan: - result.ssh_retcode = RUN_FAIL_CONNECT + result.ssh_retcode = defaults.RUN_FAIL_CONNECT result.err = ["WTF, this shouldn't happen\n"] return result - except ssh.SSHException, ssh.transport.SSHException: - result.ssh_retcode = RUN_FAIL_SSH + except (ssh.SSHException, ssh.transport.SSHException): + result.ssh_retcode = defaults.RUN_FAIL_SSH return result if sudo: try: # Send the password stdin.write(password + '\r') stdin.flush() - - # Remove the password prompt and password from the output - prompt = _term_readline(stdout) - seen_password = False - seen_password_prompt = False - #print 'READ:',prompt - while 'assword:' in prompt or password in prompt or 'try again' in prompt or len(prompt.strip()) == 0: - if 'try again' in prompt: - result.ssh_retcode = RUN_FAIL_BADPASSWORD - return result - prompt_new = _term_readline(stdout) - if 'assword:' in prompt: - seen_password_prompt = True - if password in prompt: - seen_password = True - if seen_password_prompt and seen_password: - break - prompt = prompt_new except socket.timeout: result.err = ['Timeout during sudo connect, likely bad password'] - result.ssh_retcode = RUN_FAIL_TIMEOUT + result.ssh_retcode = defaults.RUN_FAIL_TIMEOUT return result if script: - # Pass the script over stdin and close the channel so the receving end gets an EOF - # process it as a django template with the arguments passed + # Pass the script over stdin and close the channel so the receiving end + # gets an EOF process it as a django template with the arguments passed # noinspection PyBroadException try: import django.template @@ -361,291 +341,122 @@ def run_command(host, command="uname -a", username=None, password=None, sudo=Fal import django.conf django.conf.settings.configure() - template = open(script, 'r').read() + if os.path.exists(script): + template = open(script, 'r').read() + else: + template = script if script_parameters: - c = django.template.Context({ 'argv': script_parameters }) + c = django.template.Context({'argv': script_parameters}) else: - c = django.template.Context({ }) + c = django.template.Context({}) stdin.write(django.template.Template(template).render(c)) - except Exception, e: - stdin.write(open(script, 'r').read()) + except: + if os.path.exists(script): + stdin.write(open(script, 'r').read()) + else: + stdin.write(script) stdin.flush() stdin.channel.shutdown_write() try: # Read the output from stdout,stderr and close the connection result.out = stdout.readlines() result.err = stderr.readlines() + if sudo: + # Remove any passwords or prompts from the start of the stderr + # output + err = [] + check_prompt = True + skip = False + for el in result.err: + if check_prompt: + if password in el or 'assword:' in el or \ + '[sudo] password' in el or el.strip() == '' or \ + el.strip() in defaults.sudo_message: + skip = True + else: + check_prompt = False + if not skip: + err.append(el) + skip = False + result.err = err + result.retcode = chan.recv_exit_status() if close_client: client.close() except socket.timeout: - result.ssh_retcode = RUN_FAIL_TIMEOUT + result.ssh_retcode = defaults.RUN_FAIL_TIMEOUT return result - result.ssh_retcode = RUN_OK - return result - - -# Handy utility functions -def get_parm_val(parm=None, key=None): - """ - Return the value of a key - - >>> get_parm_val(parm={'test':'val'},key='test') - 'val' - >>> get_parm_val(parm={'test':'val'},key='foo') - >>> - """ - if parm and key in parm.keys(): - return parm[key] - else: - return None - - -def status_info(callbacks, text): - """ - Update the display line at the cursor - """ - #print callbacks,text - #return - if isinstance(callbacks, list) and callback_status_count in callbacks: - status_clear() - sys.stderr.write(text) - sys.stderr.flush() - - -def status_clear(): - """ - Clear the status line (current line) - """ - sys.stderr.write('\x1b[0G\x1b[0K') - #sys.stderr.flush() - - -# Built in callbacks -# Filter callback handlers -def callback_flowthrough(result): - """ - Builtin Callback, return the raw data passed - - >>> result=callback_flowthrough(ssh_result(["output"], ["error"],"foo", 0)) - >>> result.dump() - foo output error 0 0 None - """ + result.ssh_retcode = defaults.RUN_OK return result -def callback_summarize_failures(result): - """ - Builtin Callback, put a summary of failures into parm - """ - failures = result.setting('failures') - if not failures: - result.parm['failures'] = [] - failures = [] - if result.ssh_retcode: - failures.append(result.host) - result.parm['failures'] = failures - return result - - -def callback_exec_command(result): - """ - Builtin Callback, pass the results to a command/script - :param result: - """ - script = result.setting("callback_script") - if not script: - return result - status_clear() - result_out, result_err = subprocess.Popen(script + " " + result.host, shell = True, stdin = subprocess.PIPE, - stdout = subprocess.PIPE, stderr = subprocess.PIPE).communicate( - result.out_string() + result.err_string()) - result.out = [result_out] - result.err = [result_err] - print result.out_string() - return result - - -def callback_aggregate_output(result): - """ Builtin Callback, Aggregate identical results """ - aggregate_hosts = result.setting('aggregate_hosts') - if not aggregate_hosts: - aggregate_hosts = {} - collapsed_output = result.setting('collapsed_output') - if not collapsed_output: - collapsed_output = {} - h = hashlib.md5() - h.update(result.out_string()) - h.update(result.err_string()) - if result.ssh_retcode: - h.update(result.ssh_error_message()) - digest = h.hexdigest() - if digest in aggregate_hosts.keys(): - aggregate_hosts[digest].append(result.host) - else: - aggregate_hosts[digest] = [result.host] - if result.ssh_retcode: - error = [] - if result.err: - error = result.err - error.append(result.ssh_error_message()) - collapsed_output[digest] = (result.out, error) - else: - collapsed_output[digest] = (result.out, result.err) - result.parm['aggregate_hosts'] = aggregate_hosts - if collapsed_output: - result.parm['collapsed_output'] = collapsed_output - return result - - -def callback_filter_match(result): - """ - Builtin Callback, remove all output if the string is not found in the output - similar to grep - """ - if result.out_string().find(result.setting('match')) == -1 and result.err_string().find( - result.setting('match')) == -1: - result.out = '' - result.err = '' - return result - - -def callback_filter_json(result): - """ - Builtin Callback, change stdout to json - - >>> result=callback_filter_json(ssh_result(["output"], ["error"],"foo", 0)) - >>> result.dump() - foo [["output"], ["error"], 0] error 0 0 None - """ - result.out = [json.dumps((result.out, result.err, result.retcode))] - return result - - -def callback_filter_base64(result): - """ - Builtin Callback, base64 encode the info in out and err streams - """ - result.out = [base64.b64encode(result.out_string)] - result.err = [base64.b64encode(result.err_string)] - return result - - -#Status callback handlers -def callback_status_count(result): - """ - Builtin Callback, show the count complete/remaining - :param result: - """ - # The master process inserts the status into the - # total_host_count and completed_host_count variables - sys.stderr.write('\x1b[0G\x1b[0K%s/%s' % ( - result.setting('completed_host_count'), result.setting('total_host_count'))) - sys.stderr.flush() - return result - - -#Output callback handlers -def callback_output_prefix_host(result): - """ - Builtin Callback, print the output with the hostname: prefixed to each line - :param result: - - >>> result=callback_output_prefix_host(ssh_result(['out'],['err'], 'hostname', 0)) - hostname: out - >>> result.dump() - hostname hostname: out hostname: Error: err 0 0 None - """ - output = [] - error = [] - status_clear() - # If summarize_failures option is set don't print ssh errors inline - if result.setting('summarize_failed') and result.ssh_retcode: - return result - if result.setting('print_rc'): - rc = ' SSH_Returncode: %d\tCommand_Returncode: %d' % (result.ssh_retcode, result.retcode) - else: - rc = '' - if result.ssh_retcode: - print >> sys.stderr, '%s: %s' % (result.host, result.ssh_error_message()) - error = ['%s: %s' % (result.host, result.ssh_error_message())] - if len(result.out_string()): - for line in result.out: - if line: - print '%s:%s %s' % (result.host, rc, line.strip()) - output.append('%s:%s %s\n' % (result.host, rc, line.strip())) - if len(result.err_string()): - for line in result.err: - if line: - print >> sys.stderr, '%s:%s %s' % (result.host, rc, line.strip()) - error.append('%s:%s Error: %s\n' % (result.host, rc, line.strip())) - if result.setting('output'): - if not len(result.out_string()) and not len(result.err_string()) and not result.setting( - 'only_output') and result.setting('print_rc'): - print '%s:%s' % (result.host, rc) - sys.stdout.flush() - sys.stderr.flush() - result.out = output - result.err = error - return result - - -def read_conf(key=None, prompt=True): - """ Read settings from the config file """ - try: - conf = json.load(open(os.path.expanduser('~/.sshmap.conf'), 'r')) - except IOError: - conf = conf_defaults - if key: - try: - return conf[key].encode('ascii') - except KeyError: - pass - else: - return conf - if key and prompt: - conf[key] = raw_input(conf_desc[key] + ': ') - fh = open(os.path.expanduser('~/.sshmap2.conf'), 'w') - os.fchmod(fh.fileno(), stat.S_IRUSR | stat.S_IWUSR) - json.dump(conf, fh) - fh.close() - return conf[key] - else: - return None - - def init_worker(): """ Set up the signal handler for new worker threads """ signal.signal(signal.SIGINT, signal.SIG_IGN) -def run(host_range, command, username=None, password=None, sudo=False, script=None, timeout=None, sort=False, - bufsize=-1, cwd='/tmp', jobs=None, output_callback=callback_summarize_failures, parms=None, shuffle=False, - chunksize=None): +def run_with_runner(*args, **kwargs): + """ + Run a command with a python runner script + :param args: + :param kwargs: + """ + if 'runner' in kwargs.keys() and isinstance( + kwargs['runner'], type.FunctionType): + kwargs['script'] = runner.get_runner( + command=args[1], + input="", + password=kwargs['password'], + runner_script=kwargs['runner'], + compressor='bz2' + ) + del kwargs['runner'] + return run(*args, **kwargs) + + +def run(host_range, command, username=None, password=None, sudo=False, + script=None, timeout=None, sort=False, + jobs=None, output_callback=[callback.summarize_failures], + parms=None, shuffle=False, chunksize=None): """ Run a command on a hostlists host_range of hosts + :param host_range: + :param command: + :param username: + :param password: + :param sudo: + :param script: + :param timeout: + :param sort: + :param jobs: + :param output_callback: + :param parms: + :param shuffle: + :param chunksize: + >>> res=run(host_range='localhost',command="echo ok") - >>> print res[0].dump() - localhost ok 0 0 {'failures': [], 'total_host_count': 1, 'completed_host_count': 1} - None + >>> print(res[0].dump()) + localhost ok 0 0 {'failures': [], 'total_host_count': 1, + 'completed_host_count': 1} """ - status_info(output_callback, 'Looking up hosts') + utility.status_info(output_callback, 'Looking up hosts') hosts = hostlists.expand(hostlists.range_split(host_range)) if shuffle: random.shuffle(hosts) - status_clear() + utility.status_clear() results = ssh_results() if parms: results.parm = parms else: - results.parm = { } + results.parm = {} if sudo and not password: for host in hosts: - result=ssh_result() - result.err='Sudo password required' - result.retcode = RUN_FAIL_NOPASSWORD + result = ssh_result() + result.host = host + result.err = 'Sudo password required' + result.retcode = defaults.RUN_FAIL_NOPASSWORD results.append(result) results.parm['total_host_count'] = len(hosts) results.parm['completed_host_count'] = 0 @@ -654,8 +465,8 @@ def run(host_range, command, username=None, password=None, sudo=False, script=No if jobs < 1: jobs = 1 - if jobs > JOB_MAX: - jobs = JOB_MAX + if jobs > defaults.JOB_MAX: + jobs = defaults.JOB_MAX # Set up our ssh client #status_info(output_callback,'Setting up the SSH client') @@ -667,15 +478,14 @@ def run(host_range, command, username=None, password=None, sudo=False, script=No results.parm['total_host_count'] = len(hosts) results.parm['completed_host_count'] = 0 - status_clear() - status_info(output_callback, 'Spawning processes') + utility.status_clear() + utility.status_info(output_callback, 'Spawning processes') if jobs > len(hosts): jobs = len(hosts) - pool = multiprocessing.Pool(processes = jobs, initializer = init_worker) + pool = multiprocessing.Pool(processes=jobs, initializer=init_worker) if not chunksize: - chunksize = 1 if jobs >= len(hosts): chunksize = 1 else: @@ -692,40 +502,49 @@ def run(host_range, command, username=None, password=None, sudo=False, script=No else: map_command = pool.imap_unordered - if isinstance(output_callback, types.ListType) and callback_status_count in output_callback: - callback_status_count(ssh_result(parm=results.parm)) + if isinstance(output_callback, types.ListType) and \ + callback.status_count in output_callback: + callback.status_count(ssh_result(parm=results.parm)) # Create a process pool and pass the parameters to it - status_clear() - status_info(output_callback, 'Sending %d commands to each process' % chunksize) - if callback_status_count in output_callback: - callback_status_count(ssh_result(parm=results.parm)) + utility.status_clear() + utility.status_info( + output_callback, 'Sending %d commands to each process' % chunksize) + if callback.status_count in output_callback: + callback.status_count(ssh_result(parm=results.parm)) try: - for result in map_command(run_command, - [(host, command, username, password, sudo, script, timeout, results.parm, client) for - host in hosts], chunksize): - #results.parm['active_processes']=len(multiprocessing.active_children()) + for result in map_command( + run_command, + [ + ( + host, command, username, password, sudo, script, timeout, + results.parm, client + ) for host in hosts + ], + chunksize + ): results.parm['completed_host_count'] += 1 result.parm = results.parm if isinstance(output_callback, types.ListType): - for callback in output_callback: - result = callback(result) + for cb in output_callback: + result = cb(result) else: result = output_callback(result) results.parm = result.parm results.append(result) pool.close() except KeyboardInterrupt: - print 'ctrl-c pressed' + print('ctrl-c pressed') pool.terminate() - #except Exception,e: + #except Exception as e: # print 'unknown error encountered',Exception,e # pass pool.terminate() - if isinstance(output_callback, types.ListType) and callback_status_count in output_callback: - status_clear() + if isinstance(output_callback, types.ListType) and \ + callback.status_count in output_callback: + utility.status_clear() return results diff --git a/sshmap/utility.py b/sshmap/utility.py new file mode 100755 index 0000000..5073544 --- /dev/null +++ b/sshmap/utility.py @@ -0,0 +1,54 @@ +#Copyright (c) 2012 Yahoo! Inc. All rights reserved. +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. See accompanying LICENSE file. +""" +sshmap utility functions +""" +import sys +import callback + +__author__ = 'dhubbard' + + +def get_parm_val(parm=None, key=None): + """ + Return the value of a key + + >>> get_parm_val(parm={'test':'val'},key='test') + 'val' + >>> get_parm_val(parm={'test':'val'},key='foo') + >>> + """ + if parm and key in parm.keys(): + return parm[key] + else: + return None + + +def status_info(callbacks, text): + """ + Update the display line at the cursor + """ + #print callbacks,text + #return + if isinstance(callbacks, list) and \ + callback.status_count in callbacks: + status_clear() + sys.stderr.write(text) + sys.stderr.flush() + + +def status_clear(): + """ + Clear the status line (current line) + """ + sys.stderr.write('\x1b[0G\x1b[0K') + #sys.stderr.flush() \ No newline at end of file diff --git a/tests/test.py b/tests/test.py new file mode 100755 index 0000000..d126442 --- /dev/null +++ b/tests/test.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +#Copyright (c) 2012 Yahoo! Inc. All rights reserved. +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. See accompanying LICENSE file. +""" +Unit tests of sshmap +""" +__author__ = 'dhubbard' +import sshmap +import os +import unittest + + +class TestSSH(unittest.TestCase): + """ + sshmap unit tests + """ + def set_up(self): + pass + + def test_shell_command_as_user(self): + """Run a ssh command to localhost and verify it works """ + result = os.popen('sshmap/sshmap localhost echo hello').read().strip() + self.assertEqual('localhost: hello', result) + + def test_shell_command_sudo(self): + """Run a ssh command to localhost using sudo and verify it works""" + result = os.popen('sshmap/sshmap localhost --sudo id').read().strip() + self.assert_( + 'localhost: uid=0(root) gid=0(root) groups=0(root)' in result) + + def test_shell_script_as_user(self): + # Run a ssh command to localhost and verify it works + open('testscript.test', 'w').write('#!/bin/bash\necho hello\n') + result = os.popen( + 'sshmap/sshmap localhost --runscript testscript.test' + ).read().strip() + self.assertEqual('localhost: hello', result) + os.remove('testscript.test') + + def test_shell_script_sudo(self): + """Run a ssh command to localhost and verify it works """ + open('testscript.test', 'w').write('#!/bin/bash\nid\n') + result = os.popen( + 'sshmap/sshmap localhost --runscript testscript.test --sudo ' + '--timeout 15' + ).read().strip() + self.assert_( + 'localhost: uid=0(root) gid=0(root) groups=0(root)' in result) + os.remove('testscript.test') + + def run_with_runner(self): + """ + Execute an rpc call without arguments via ssh + """ + result = sshmap.run_with_runner( + 'localhost', + 'uname -n', + runner=sshmap.runner.script_stdin + ) + self.assertEqual(result, os.popen('uname -n').read()) + + +if __name__ == '__main__': + unittest.main()