Skip to content

Commit

Permalink
Various fixes to work with python 3
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo committed Dec 22, 2016
1 parent b22d71f commit 7586353
Show file tree
Hide file tree
Showing 18 changed files with 173 additions and 107 deletions.
5 changes: 3 additions & 2 deletions labrad/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import getpass
from builtins import input

from labrad import constants

Expand Down Expand Up @@ -77,8 +78,8 @@ def get_username_and_password(host, port, prompt=True):
break
return passwords[user]
elif prompt:
user = raw_input('Enter username, or blank for the global user '
'({}:{}): '.format(host, port))
user = input('Enter username, or blank for the global user '
'({}:{}): '.format(host, port))
password = _prompt_for_password(host, port, user)
return Password(user, password)
else:
Expand Down
14 changes: 5 additions & 9 deletions labrad/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
from labrad.support import (mangle, indent, PrettyMultiDict, FlatPacket,
PacketRecord, PacketResponse, hexdump)

class NotFoundError(Error):
code = 10
def __init__(self, key):
self.msg = 'Could not find "%s".' % key

def unwrap(s, after='|'):
def trim(line):
Expand Down Expand Up @@ -140,10 +136,7 @@ def __getitem__(self, key):
# force refresh and try again
if self._parent:
self._parent.refresh(now=True)
try:
return super(DynamicAttrDict, self).__getitem__(key)
except KeyError:
raise NotFoundError(key)
return super(DynamicAttrDict, self).__getitem__(key)


class HasDynamicAttrs(object):
Expand Down Expand Up @@ -243,7 +236,10 @@ def __dir__(self):
return sorted(set(self._attrs.keys() + self.__dict__.keys() + dir(type(self))))

def __getattr__(self, key):
return self._attrs[key]
try:
return self._attrs[key]
except KeyError:
raise AttributeError(key)

def __getitem__(self, key):
return self._attrs[key]
Expand Down
1 change: 1 addition & 0 deletions labrad/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import types

import twisted.internet.defer as defer
from past.builtins import basestring # for python 2/3 compatibility

from labrad import types as T, util

Expand Down
21 changes: 15 additions & 6 deletions labrad/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,28 @@
from __future__ import absolute_import
from __future__ import print_function

import BaseHTTPServer
import json
import logging
import os
import threading
import time
import urllib
import urlparse
import webbrowser
from builtins import input

import requests
from concurrent import futures

try:
import BaseHTTPServer as http_server
except ImportError:
import http.server as http_server # python 3

try:
import urlparse
except ImportError:
import urllib.parse as urlparse


log = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,16 +123,16 @@ def do_headless_login():
login_uri = _create_login_uri(client_id, redirect_uri)
print('To obtain an OAuth authorization code, please navigate to the '
'following URL in a browser:\n\n{}\n'.format(login_uri))
code = raw_input('When you have completed the login flow, enter the '
'code here: ')
code = input('When you have completed the login flow, enter the '
'code here: ')
return redirect_uri, code

if headless:
redirect_uri, code = do_headless_login()
else:
code_future = futures.Future()

class OAuthHandler(BaseHTTPServer.BaseHTTPRequestHandler):
class OAuthHandler(http_server.BaseHTTPRequestHandler):
def do_GET(self):
parsed_path = urlparse.urlparse(self.path)
params = urlparse.parse_qs(parsed_path.query)
Expand All @@ -138,7 +147,7 @@ def log_request(self, *args, **kw):
pass

# Start local http server to receive redirect on random port
httpd = BaseHTTPServer.HTTPServer(('localhost', 0), OAuthHandler)
httpd = http_server.HTTPServer(('localhost', 0), OAuthHandler)
_, local_port = httpd.server_address

redirect_uri = 'http://localhost:{}'.format(local_port)
Expand Down
19 changes: 13 additions & 6 deletions labrad/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
from __future__ import print_function

import hashlib
from builtins import input

from twisted.internet import reactor, protocol, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python import failure, log
import labrad.types as T

import labrad.types as T
from labrad import auth, constants as C, crypto, errors, oauth, support, util
from labrad.stream import packetStream, flattenPacket

Expand All @@ -50,7 +51,7 @@ def __init__(self):
self.request_handler = None
# create a generator to assemble the packets
self.packetStream = packetStream(self.packetReceived, self.endianness)
self.packetStream.next() # start the packet stream
next(self.packetStream) # start the packet stream

self.onDisconnect = util.DeferredSignal()

Expand Down Expand Up @@ -87,7 +88,7 @@ def connectionLost(self, reason):
called, or because of some network error.
"""
self.disconnected = True
for d in self.requests.values():
for d in list(self.requests.values()):
d.errback(Exception('Connection lost.'))
if reason == protocol.connectionDone:
self.onDisconnect.callback(None)
Expand Down Expand Up @@ -429,7 +430,10 @@ def require_secure_connection(auth_type):
# send password response
m = hashlib.md5()
m.update(challenge)
m.update(credential.password)
if isinstance(credential.password, bytes):
m.update(credential.password)
else:
m.update(credential.password.encode('UTF-8'))
try:
resp = yield self._sendManagerRequest(0, m.digest())
except Exception:
Expand Down Expand Up @@ -494,7 +498,10 @@ def _doLogin(self, *ident):
# Store name, which is always the first identification param.
self.name = ident[0]
# Send identification.
self.ID = yield self._sendManagerRequest(0, (1L,) + ident)
data = (1,) + ident
tag = 'w' + 's'*len(ident)
flat = T.flatten(data, tag)
self.ID = yield self._sendManagerRequest(0, flat)


class MessageContext(object):
Expand Down Expand Up @@ -604,7 +611,7 @@ def ping(p):
print('SHA1 Fingerprint={}'.format(crypto.fingerprint(cert)))
print()
while True:
ans = raw_input(
ans = input(
'Accept server certificate for host "{}"? (accept just '
'this [O]nce; [S]ave and always accept this cert; '
'[R]eject) '.format(host))
Expand Down
5 changes: 5 additions & 0 deletions labrad/ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@

import operator

try:
_ = long
except NameError:
long = int # python 3 has no long type

def gcd(a, b):
"""Compute the greatest common divisor of two ints."""
a, b = b, a % b
Expand Down
10 changes: 5 additions & 5 deletions labrad/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

HEADER_TYPE = T.parseTypeTag('(ww)iww')
PACKET_TYPE = T.parseTypeTag('(ww)iws')
RECORD_TYPE = T.parseTypeTag('wss')
RECORD_TYPE = T.parseTypeTag('wsy')

def packetStream(packetHandler, endianness='>'):
"""A generator that assembles packets.
Accepts a function packetHandler that will be called with four arguments
whenever a packet is completed: source, context, request, records.
"""
buf = ''
buf = b''
while True:
# get packet header (20 bytes)
while len(buf) < 20:
Expand Down Expand Up @@ -52,13 +52,13 @@ def flattenPacket(target, context, request, records, endianness='>'):
data = records
else:
kw = {'endianness': endianness}
data = ''.join(flattenRecord(*rec, **kw) for rec in records)
data = b''.join(flattenRecord(*rec, **kw) for rec in records)
flat = PACKET_TYPE.flatten((context, request, target, data), endianness)
return flat.bytes

def flattenRecords(records, endianness='>'):
kw = {'endianness': endianness}
return ''.join(flattenRecord(*rec, **kw) for rec in records)
return b''.join(flattenRecord(*rec, **kw) for rec in records)

def flattenRecord(ID, data, types=[], endianness='>'):
"""Flatten a piece of data into a record with datatype and property."""
Expand All @@ -67,7 +67,7 @@ def flattenRecord(ID, data, types=[], endianness='>'):
except T.FlatteningError as e:
e.msg = e.msg + "\nSetting ID %s." % (ID,)
raise
flat_record = RECORD_TYPE.flatten((ID, str(flat.tag), str(flat.bytes)),
flat_record = RECORD_TYPE.flatten((ID, str(flat.tag), bytes(flat.bytes)),
endianness)
return flat_record.bytes

2 changes: 1 addition & 1 deletion labrad/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def testCompoundPacket(self):

# test using keys to refer to parts of a packet
pkt2 = pts.packet()
resp = pkt2.echo(1L, key='one')\
resp = pkt2.echo(1, key='one')\
.echo_delay(T.Value(0.1, 's'))\
.delayed_echo('blah', key='two')\
.send()
Expand Down
10 changes: 5 additions & 5 deletions labrad/test/test_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Queue
import queue
import random

import pytest
Expand All @@ -10,13 +10,13 @@

def test_server_expire_context_method_is_called():
"""Ensure that server's expireContext method is called when client disconnects."""
queue = Queue.Queue()
q = queue.Queue()

class TestServer(LabradServer):
name = "TestServer"

def expireContext(self, c):
queue.put(c.ID)
q.put(c.ID)

@setting(100)
def echo(self, c, data):
Expand All @@ -25,9 +25,9 @@ def echo(self, c, data):
with util.syncRunServer(TestServer()):
with labrad.connect() as cxn:
# create a random context owned by this connection
request_context = (cxn.ID, random.randint(0, 0xFFFFFFFFL))
request_context = (cxn.ID, random.randint(0, 2**32-1))
cxn.testserver.echo('hello, world!', context=request_context)
expired_context = queue.get(block=True, timeout=1)
expired_context = q.get(block=True, timeout=1)
assert expired_context == request_context


Expand Down
16 changes: 8 additions & 8 deletions labrad/test/test_signal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Queue
import queue

import pytest

Expand Down Expand Up @@ -32,26 +32,26 @@ def _test_signals(signal_setting, fire_setting):
msg_id = 12345

# add a listener to enqueue all messages
queue = Queue.Queue()
q = queue.Queue()
def handler(message_ctx, msg):
queue.put((message_ctx, msg))
q.put((message_ctx, msg))
cxn._backend.cxn.addListener(handler, ID=msg_id)

# we don't get messages before signing up
server[fire_setting]('not listening')
with pytest.raises(Queue.Empty):
queue.get(block=True, timeout=1)
with pytest.raises(queue.Empty):
q.get(block=True, timeout=1)

server[signal_setting](msg_id)
server[fire_setting]('listening')
msg_ctx, msg = queue.get(block=True, timeout=1)
msg_ctx, msg = q.get(block=True, timeout=1)
assert msg_ctx.source == server.ID
assert msg == 'listening'

server[signal_setting]()
server[fire_setting]('not listening')
with pytest.raises(Queue.Empty):
queue.get(block=True, timeout=1)
with pytest.raises(queue.Empty):
q.get(block=True, timeout=1)


def test_signal_can_be_defined_on_server_class():
Expand Down
13 changes: 6 additions & 7 deletions labrad/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def testDefaultFlatAndBack(self):
None,
True, False,
1, -1, 2, -2, 0x7FFFFFFF, -0x80000000,
1L, 2L, 3L, 4L, 0L, 0xFFFFFFFFL,
'', 'a', '\x00\x01\x02\x03',
datetime.now(),

Expand Down Expand Up @@ -145,7 +144,7 @@ def testDefaultFlatAndBack(self):
[['a', 'bb', 'ccc'], ['dddd', 'eeeee', 'ffffff']],

# more complex stuff
[(1L, 'a'), (2L, 'b')],
[(1, 'a'), (2, 'b')],
]
for data_in in tests:
data_out = T.unflatten(*T.flatten(data_in))
Expand Down Expand Up @@ -261,7 +260,7 @@ def testTypeHints(self):
# handle unknown pieces inside clusters and lists
(['a', 'b'], ['*?'], '*s'),
((1, 2, 'a'), ['ww?'], 'wws'),
((1, 1L), ['??'], 'iw'),
((1, 2), ['??'], 'iw'),
]
for data, hints, tag in passingTests:
self.assertEqual(T.flatten(data, hints)[1], T.parseTypeTag(tag))
Expand Down Expand Up @@ -312,8 +311,8 @@ def testNumpySupport(self):
a = np.array([1, 2, 3, 4, 5], dtype='int32')
b = T.unflatten(*T.flatten(a))
self.assertTrue(np.all(a == b))
self.assertTrue(T.flatten(np.int32(5))[0] == '\x00\x00\x00\x05')
self.assertTrue(T.flatten(np.int64(-5))[0] == '\xff\xff\xff\xfb')
self.assertTrue(T.flatten(np.int32(5))[0] == b'\x00\x00\x00\x05')
self.assertTrue(T.flatten(np.int64(-5))[0] == b'\xff\xff\xff\xfb')
self.assertTrue(len(T.flatten(np.float64(3.15))[0]) == 8)
with self.assertRaises(T.FlatteningError):
T.flatten(np.int64(-5), T.LRWord())
Expand Down Expand Up @@ -353,8 +352,8 @@ def testUnicodeBytes(self):
foo = T.flatten('foo bar')
self.assertEquals(foo, T.flatten(u'foo bar'))
self.assertEquals(str(foo.tag), 's')
self.assertEquals(T.unflatten(foo.bytes, 'y'), 'foo bar')
self.assertEquals(T.unflatten(*T.flatten('foo bar', ['y'])), 'foo bar')
self.assertEquals(T.unflatten(foo.bytes, 'y'), b'foo bar')
self.assertEquals(T.unflatten(*T.flatten(b'foo bar', ['y'])), b'foo bar')

def testFlattenIntArrayToValueArray(self):
x = np.array([1, 2, 3, 4], dtype='int64')
Expand Down
Loading

0 comments on commit 7586353

Please sign in to comment.