Skip to content

Commit

Permalink
Enable pushing files over tunnels
Browse files Browse the repository at this point in the history
  • Loading branch information
lordmauve committed Aug 3, 2016
1 parent 46ce3d1 commit e3a6a52
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 59 deletions.
134 changes: 102 additions & 32 deletions chopsticks/bubble.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def exec_(_code_, _globs_=None, _locs_=None):
from collections import namedtuple
import signal
from hashlib import sha1
import traceback

outqueue = Queue(maxsize=10)
tasks = Queue()
outqueue = Queue()
done = object()

running = True
Expand Down Expand Up @@ -109,16 +109,40 @@ def get_source(self, fullname):
sys.path_hooks.append(Loader)


def transmit_errors(func):
    """Decorate a request handler so that failures are reported to the peer.

    The decorated handler must take ``req_id`` as its first positional
    argument. If the handler raises, a ``{'req_id': ..., 'tb': ...}`` message
    holding the formatted traceback is put on ``outqueue`` and the wrapper
    returns ``None``; otherwise the handler's return value is passed through.
    """
    # Imported here so the decorator does not depend on a module-level import.
    import functools

    # Preserve the handler's __name__/__doc__ on the wrapper.
    @functools.wraps(func)
    def wrapper(req_id, *args, **kwargs):
        try:
            return func(req_id, *args, **kwargs)
        except BaseException:
            # Deliberately as broad as the original bare ``except:``: the
            # requester is waiting on req_id, so every failure is reported.
            outqueue.put({
                'req_id': req_id,
                'tb': traceback.format_exc()
            })
    return wrapper


def handle_call(req_id, params):
    """Run a call request on a thread of its own.

    Starts a new thread executing ``handle_call_thread(req_id, params)`` and
    returns without waiting for it.
    """
    worker = threading.Thread(
        target=handle_call_thread,
        args=(req_id, params),
    )
    worker.start()


def handle_call_thread(req_id, params):
    """Decode one call request, run it, and queue exactly one response.

    ``params`` is a base64-encoded pickle of a ``(callable, args, kwargs)``
    triple. On success ``{'req_id': ..., 'ret': ...}`` is put on ``outqueue``;
    if decoding or the call itself raises, ``{'req_id': ..., 'tb': ...}``
    with the formatted traceback is queued instead.
    """
    try:
        # SECURITY NOTE(review): pickle.loads can execute arbitrary code.
        # This presumably relies on the sending side of the tunnel being
        # trusted -- confirm before exposing this to any other input.
        callable, args, kwargs = pickle.loads(base64.b64decode(params))
        ret = callable(*args, **kwargs)
    except BaseException:
        # Same breadth as the original bare ``except:`` so that the
        # requester always receives a reply for this req_id.
        import traceback
        msg = {
            'req_id': req_id,
            'tb': traceback.format_exc()
        }
    else:
        msg = {
            'req_id': req_id,
            'ret': ret,
        }
    outqueue.put(msg)

@transmit_errors
def handle_fetch(req_id, path):
"""Fetch a file by path."""
tasks.put((req_id, do_fetch, (req_id, path,)))
Expand Down Expand Up @@ -146,28 +170,86 @@ def do_fetch(req_id, path):
}


@transmit_errors
def do_call(req_id, callable, args=(), kwargs=None):
    """Invoke ``callable`` once and queue its return value for ``req_id``.

    ``args`` and ``kwargs`` are the positional and keyword arguments for the
    call; ``kwargs`` may be omitted or ``None`` for "no keyword arguments".
    Exceptions are not handled here: the ``transmit_errors`` decorator turns
    them into a traceback response.
    """
    # None instead of a shared mutable ``{}`` default.
    ret = callable(*args, **(kwargs or {}))
    outqueue.put({
        'req_id': req_id,
        'ret': ret,
    })


def handle_imp(mod, exists, is_pkg, file, source):
    """Record an import response for module ``mod``.

    Stores ``Imp(exists, is_pkg, file, source)`` in ``Loader.cache`` under
    ``mod`` and then sets the ``Loader.ev`` event.
    """
    record = Imp(exists, is_pkg, file, source)
    Loader.cache[mod] = record
    Loader.ev.set()


# Uploads currently in progress, keyed by req_id. handle_begin_put stores a
# (file object, path being written, final path, sha1 object) tuple per entry.
active_puts = {}


@transmit_errors
def handle_begin_put(req_id, path, mode):
    """Open the destination file for an upload and register the transfer.

    If ``path`` is ``None`` a named temporary file (not deleted on close) is
    created and its name becomes the destination. Otherwise data is written
    to ``path + '~chopsticks-tmp'``; an ``IOError`` is raised if ``path`` is
    a directory. The file is created under umask ``0o077`` and then given
    permissions ``mode``. The open transfer is stored in ``active_puts``
    under ``req_id``.
    """
    # Create the file with a restrictive umask; the requested mode is applied
    # explicitly below.
    prev_umask = os.umask(0o077)
    try:
        if path is None:
            import tempfile
            f = tempfile.NamedTemporaryFile(delete=False)
            path = wpath = f.name
        else:
            if os.path.isdir(path):
                raise IOError('%s is a directory' % path)
            wpath = path + '~chopsticks-tmp'
            f = open(wpath, 'wb')
    finally:
        os.umask(prev_umask)

    try:
        os.fchmod(f.fileno(), mode)
    except Exception:
        # The transfer was never registered, so nothing else will clean up:
        # release the handle and remove the file we just created.
        f.close()
        try:
            os.unlink(wpath)
        except OSError:
            pass
        raise
    active_puts[req_id] = (f, wpath, path, sha1())


@transmit_errors
def handle_put_data(req_id, data):
    """Append one base64-encoded chunk to the upload registered as ``req_id``.

    The decoded bytes are fed to the transfer's checksum and written to its
    file. If anything fails, the transfer is abandoned: it is removed from
    ``active_puts``, the file is closed, the partially written file is
    deleted, and the original exception is re-raised.

    Raises ``KeyError`` if ``req_id`` has no active transfer.
    """
    f, wpath, path, cksum = active_puts[req_id]
    try:
        data = base64.b64decode(data)
        cksum.update(data)
        f.write(data)
    except BaseException:
        # Forget the transfer so the closed file is never reused.
        active_puts.pop(req_id, None)
        # Close first, and independently of the unlink, so a failing unlink
        # cannot leak the file handle.
        try:
            f.close()
        except (OSError, IOError):
            pass
        try:
            os.unlink(wpath)
        except OSError:
            pass
        raise


class ChecksumMismatch(Exception):
    """Raised when received file data does not match the expected checksum."""


@transmit_errors
def handle_end_put(req_id, sha1sum):
    """Finish the upload registered as ``req_id`` and report the result.

    Removes the transfer from ``active_puts``, closes its file and compares
    the accumulated SHA-1 hex digest with ``sha1sum``. On mismatch the
    written file is deleted and ``ChecksumMismatch`` is raised. Otherwise the
    file is renamed to its final path (when it was written elsewhere) and a
    response with ``remote_path``, ``sha1sum`` and ``size`` is queued.
    """
    fileobj, tmp_path, final_path, hasher = active_puts.pop(req_id)

    # The write position at close time is the number of bytes received.
    size = fileobj.tell()
    fileobj.close()

    actual = hasher.hexdigest()
    if actual != sha1sum:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise ChecksumMismatch('Checksum failed for transfer %s' % final_path)

    if tmp_path != final_path:
        os.rename(tmp_path, final_path)

    result = {
        'remote_path': os.path.abspath(final_path),
        'sha1sum': actual,
        'size': size
    }
    outqueue.put({
        'req_id': req_id,
        'ret': result
    })


def read_msg():
buf = inpipe.read(4)
if not buf:
Expand All @@ -192,11 +274,6 @@ def reader():
handler(**obj)
finally:
outqueue.put(done)
tasks.put(done)
# SIGINT will raise KeyboardInterrupt in the main (ie. task) thread
# TODO: Perhaps give this some timeout, in case operations can complete
# successfully?
os.kill(os.getpid(), signal.SIGINT)


def writer():
Expand All @@ -214,10 +291,3 @@ def writer():

for func in (reader, writer):
threading.Thread(target=func).start()


while True:
task = tasks.get()
if task is done:
break
do_call(*task)
4 changes: 4 additions & 0 deletions chopsticks/ioloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,17 @@ def write_iter(self, iterable):
"""
self.queue.append(iterable)
self.loop.want_write(self.fd, self.on_write)

def on_write(self):
if not self.queue:
return
try:
written = os.write(self.fd, self.queue[0])
except OSError:
# TODO: handle errors properly
import traceback
traceback.print_exc()
return
b = self.queue[0] = self.queue[0][written:]
if not b:
Expand Down
131 changes: 104 additions & 27 deletions chopsticks/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ def connect(self):
"""Connect the tunnel."""
raise NotImplementedError('Subclasses must implement connect()')

def write_msg(self, op, **kwargs):
    """Write one message to the subprocess.

    This uses a chunked JSON protocol.

    ``op`` is stored under the ``'op'`` key alongside the keyword arguments,
    and the resulting dict is handed to ``self.writer.write``.
    """
    kwargs['op'] = op
    self.writer.write(kwargs)

def _next_id(self):
self.req_id += 1
Expand All @@ -85,19 +90,23 @@ def handle_imp(self, mod):
(True, os.path.join(stem, '__init__.py')),
(False, stem + '.py'),
]
for root in sys.path:
for is_pkg, rel in paths:
path = os.path.join(root, rel)
if os.path.exists(path):
self.write_msg(
'imp',
mod=mod,
exists=True,
is_pkg=is_pkg,
file=rel,
source=open(path, 'r').read()
)
return

try:
for root in sys.path:
for is_pkg, rel in paths:
path = os.path.join(root, rel)
if os.path.exists(path):
self.write_msg(
'imp',
mod=mod,
exists=True,
is_pkg=is_pkg,
file=rel,
source=open(path, 'r').read()
)
return
except:
pass
self.write_msg(
'imp',
mod=mod,
Expand Down Expand Up @@ -125,17 +134,22 @@ def on_message(self, msg):
elif 'ret' in msg:
id = msg['req_id']
if id not in self.callbacks:
# TODO: warn
self._warn('response received for unknown req_id %d' % id)
return

self.callbacks.pop(id)(msg['ret'])
self.reader.stop()
elif 'data' in msg:
id = msg['req_id']
if id not in self.callbacks:
# TODO: warn
self._warn('response received for unknown req_id %d' % id)
return
self.callbacks[id].recv(msg['data'])
else:
self._warn('malformed message received: %r' % msg)

def _warn(self, msg):
    """Print ``msg`` to stderr, prefixed with this tunnel's host and a colon."""
    prefix = '%s:' % self.host
    print(prefix, msg, file=sys.stderr)

def call(self, callable, *args, **kwargs):
"""Call the given callable on the remote host.
Expand Down Expand Up @@ -167,6 +181,9 @@ def fetch(self, remote_path, local_path=None):
If local_path is given, it is the local path to write to. Otherwise,
a temporary filename will be used.
This operation supports arbitrarily large files (file data is streamed,
not buffered in memory).
The return value is a dict containing:
* ``local_path`` - the local path written to
Expand All @@ -192,6 +209,75 @@ def _fetch_async(self, on_result, remote_path, local_path=None):
path=remote_path,
)

def put(self, local_path, remote_path=None, mode=0o644):
    """Copy a file to the remote host.

    If remote_path is given, it is the remote path to write to. Otherwise,
    a temporary filename will be used.

    This operation supports arbitrarily large files (file data is streamed,
    not buffered in memory).

    The return value is a dict containing:

    * ``remote_path`` - the absolute remote path
    * ``size`` - the number of bytes received
    * ``sha1sum`` - a sha1 checksum of the file data

    Raises ``RemoteException`` if the loop finishes with an ``ErrorResult``.
    """
    # Start the transfer, then run the loop until the result callback
    # (loop.stop) fires.
    self._put_async(loop.stop, local_path, remote_path, mode)
    result = loop.run()
    if not isinstance(result, ErrorResult):
        return result
    raise RemoteException(result.msg)

def _put_async(self, on_result, local_path, remote_path=None, mode=0o644):
    """Begin sending ``local_path`` to the remote host without blocking.

    Allocates a request id, registers ``on_result`` as its callback, starts
    the reader, writes a ``begin_put`` message and then hands the writer an
    iterator over the file's chunk messages.
    """
    req_id = self._next_id()
    self.callbacks[req_id] = on_result
    self.reader.start()
    self.write_msg('begin_put', req_id=req_id, path=remote_path, mode=mode)
    chunk_messages = iter_chunks(req_id, local_path)
    self.writer.write_iter(chunk_messages)


def iter_chunks(req_id, path, chunk_size=10240):
    """Iterate over chunks of the given file.

    Yields messages suitable for writing to a stream: one ``put_data`` dict
    per chunk, carrying the chunk base64-encoded as text under ``'data'``,
    followed by a single ``end_put`` dict carrying the SHA-1 hex digest of
    the whole file under ``'sha1sum'``. Every message includes ``req_id``.

    ``chunk_size`` is the number of bytes read per chunk; it must be
    positive, otherwise ``ValueError`` is raised when iteration starts.
    """
    if chunk_size <= 0:
        # read(0) would look like end-of-file and silently send nothing.
        raise ValueError('chunk_size must be positive')

    chksum = sha1()
    with open(path, 'rb') as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            chksum.update(chunk)
            data = base64.b64encode(chunk)
            # b64encode returns bytes; where bytes is not str (Python 3),
            # decode so the message carries text.
            if not isinstance(data, str):
                data = data.decode('ascii')
            yield {
                'op': 'put_data',
                'req_id': req_id,
                'data': data
            }

    # The file is already closed by the time the final message is produced.
    yield {
        'op': 'end_put',
        'req_id': req_id,
        'sha1sum': chksum.hexdigest()
    }


class Fetch(object):
def __init__(self, on_result, local_path=None):
Expand Down Expand Up @@ -248,15 +334,6 @@ def on_error(self, err):
for id in list(self.callbacks):
self.callbacks.pop(id)(err)

def write_msg(self, op, **kwargs):
    """Write one message to the subprocess.

    This uses a chunked JSON protocol.

    The message is the keyword arguments plus an ``'op'`` entry, passed to
    ``self.writer.write``.
    """
    message = dict(kwargs, op=op)
    self.writer.write(message)


class SubprocessTunnel(PipeTunnel):
"""A tunnel that connects to a subprocess."""
Expand Down

0 comments on commit e3a6a52

Please sign in to comment.