Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connect using a SSH transport #634

Merged
merged 7 commits into from
Sep 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ Features:
* Set `program_name` connection attribute (Thanks: [Dick Marinus]).
* Use `return` to terminate a generator (Thanks: [Zhongyang Guan]).
* Add `SAVEPOINT` to SQLCompleter (Thanks: [Huachao Mao]).
* Connect using a SSH transport (Thanks: [Dick Marinus]).
* Add `FROM_UNIXTIME` and `UNIX_TIMESTAMP` to SQLCompleter (Thanks: [QiaoHou Peng])
* Seach `${PWD}/.myclirc`, then `${HOME}/.myclirc`, last `/etc/myclirc` (Thanks: [QiaoHao Peng])
* Search `${PWD}/.myclirc`, then `${HOME}/.myclirc`, lastly `/etc/myclirc` (Thanks: [QiaoHao Peng])

Bug Fixes:
----------
Expand Down
4 changes: 3 additions & 1 deletion mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
# Create a new pgexecute method to popoulate the completions.
e = sqlexecute
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
e.socket, e.charset, e.local_infile, e.ssl)
e.socket, e.charset, e.local_infile, e.ssl,
e.ssh_user, e.ssh_host, e.ssh_port,
e.ssh_password, e.ssh_key_filename)

# If callbacks is a single function then push it into a list.
if callable(callbacks):
Expand Down
48 changes: 41 additions & 7 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@
import re
import fileinput

try:
import paramiko
except:
paramiko = False

# Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating'])

Expand Down Expand Up @@ -337,7 +342,9 @@ def merge_ssl_with_cnf(self, ssl, cnf):
return merged

def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl=''):
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
ssh_password='', ssh_key_filename=''):

cnf = {'database': None,
'user': None,
Expand Down Expand Up @@ -390,14 +397,20 @@ def connect(self, database='', user='', passwd='', host='', port='',

def _connect():
try:
self.sqlexecute = SQLExecute(database, user, passwd, host, port,
socket, charset, local_infile, ssl)
self.sqlexecute = SQLExecute(
database, user, passwd, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port,
ssh_password, ssh_key_filename
)
except OperationalError as e:
if ('Access denied for user' in e.args[1]):
new_passwd = click.prompt('Password', hide_input=True,
show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(database, user, new_passwd, host, port,
socket, charset, local_infile, ssl)
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
ssh_port, ssh_password, ssh_key_filename
)
else:
raise e

Expand Down Expand Up @@ -949,6 +962,11 @@ def get_last_query(self):
help='Password to connect to the database.')
@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str,
help='Password to connect to the database.')
@click.option('--ssh-user', help='User name to connect to ssh server.')
@click.option('--ssh-host', help='Host name to connect to ssh server.')
@click.option('--ssh-port', default=22, help='Port to connect to ssh server.')
@click.option('--ssh-password', help='Password to connect to ssh server.')
@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.')
@click.option('--ssl-ca', help='CA file in PEM format.',
type=click.Path(exists=True))
@click.option('--ssl-capath', help='CA directory.')
Expand Down Expand Up @@ -1001,7 +1019,8 @@ def cli(database, user, host, port, socket, password, dbname,
defaults_file, login_path, auto_vertical_output, local_infile,
ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
list_dsn):
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename):
"""A MySQL terminal client with auto-completion and syntax highlighting.

\b
Expand Down Expand Up @@ -1081,6 +1100,16 @@ def cli(database, user, host, port, socket, password, dbname,
if not port:
port = uri.port

if not paramiko and ssh_host:
click.secho(
"Cannot use SSH transport because paramiko isn't installed, "
"please install paramiko or don't use --ssh-host=",
err=True, fg="red"
)
exit(1)

ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)

mycli.connect(
database=database,
user=user,
Expand All @@ -1089,7 +1118,12 @@ def cli(database, user, host, port, socket, password, dbname,
port=port,
socket=socket,
local_infile=local_infile,
ssl=ssl
ssl=ssl,
ssh_user=ssh_user,
ssh_host=ssh_host,
ssh_port=ssh_port,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename
)

mycli.logger.debug('Launch Params: \n'
Expand Down
69 changes: 58 additions & 11 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from pymysql.converters import (convert_mysql_timestamp, convert_datetime,
convert_timedelta, convert_date, conversions,
decoders)
try:
import paramiko
except:
paramiko = False

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,7 +41,8 @@ class SQLExecute(object):
order by table_name,ordinal_position'''

def __init__(self, database, user, password, host, port, socket, charset,
local_infile, ssl=False):
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename):
self.dbname = database
self.user = user
self.password = password
Expand All @@ -49,10 +54,17 @@ def __init__(self, database, user, password, host, port, socket, charset,
self.ssl = ssl
self._server_type = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
self.ssh_port = ssh_port
self.ssh_password = ssh_password
self.ssh_key_filename = ssh_key_filename
self.connect()

def connect(self, database=None, user=None, password=None, host=None,
port=None, socket=None, charset=None, local_infile=None, ssl=None):
port=None, socket=None, charset=None, local_infile=None,
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
ssh_password=None, ssh_key_filename=None):
db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
Expand All @@ -62,16 +74,29 @@ def connect(self, database=None, user=None, password=None, host=None,
charset = (charset or self.charset)
local_infile = (local_infile or self.local_infile)
ssl = (ssl or self.ssl)
_logger.debug('Connection DB Params: \n'
ssh_user = (ssh_user or self.ssh_user)
ssh_host = (ssh_host or self.ssh_host)
ssh_port = (ssh_port or self.ssh_port)
ssh_password = (ssh_password or self.ssh_password)
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
_logger.debug(
'Connection DB Params: \n'
'\tdatabase: %r'
'\tuser: %r'
'\thost: %r'
'\tport: %r'
'\tsocket: %r'
'\tcharset: %r'
'\tlocal_infile: %r'
'\tssl: %r',
db, user, host, port, socket, charset, local_infile, ssl)
'\tssl: %r'
'\tssh_user: %r'
'\tssh_host: %r'
'\tssh_port: %r'
'\tssh_password: %r'
'\tssh_key_filename: %r',
db, user, host, port, socket, charset, local_infile, ssl,
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename
)
conv = conversions.copy()
conv.update({
FIELD_TYPE.TIMESTAMP: lambda obj: (convert_mysql_timestamp(obj) or obj),
Expand All @@ -80,12 +105,34 @@ def connect(self, database=None, user=None, password=None, host=None,
FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
})

conn = pymysql.connect(database=db, user=user, password=password,
host=host, port=port, unix_socket=socket,
use_unicode=True, charset=charset, autocommit=True,
client_flag=pymysql.constants.CLIENT.INTERACTIVE,
local_infile=local_infile,
conv=conv, ssl=ssl, program_name="mycli")
defer_connect = False

if ssh_host:
defer_connect = True

conn = pymysql.connect(
database=db, user=user, password=password, host=host, port=port,
unix_socket=socket, use_unicode=True, charset=charset,
autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE,
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
defer_connect=defer_connect
)

if ssh_host and paramiko:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, ssh_password,
key_filename=ssh_key_filename
)
chan = client.get_transport().open_channel(
'direct-tcpip',
(host, port),
('0.0.0.0', 0),
)
conn.connect(chan)

if hasattr(self, 'conn'):
self.conn.close()
self.conn = conn
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@ def run_tests(self):
'Topic :: Software Development',
'Topic :: Software Development :: Libraries :: Python Modules',
],
extras_require={
'ssh': ['paramiko'],
},
)
8 changes: 5 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from utils import (HOST, USER, PASSWORD, PORT,
CHARSET, create_db, db_connection)
from utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
import mycli.sqlexecute


Expand All @@ -24,4 +24,6 @@ def executor(connection):
return mycli.sqlexecute.SQLExecute(
database='_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False)
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
)
3 changes: 3 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
HOST = os.getenv('PYTEST_HOST', 'localhost')
PORT = os.getenv('PYTEST_PORT', 3306)
CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
SSH_USER = os.getenv('PYTEST_SSH_USER', None)
SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)


def db_connection(dbname=None):
Expand Down