Skip to content

Commit

Permalink
Merge pull request #32 from fjl/eip8
Browse files Browse the repository at this point in the history
EIP-8
  • Loading branch information
konradkonrad committed Feb 22, 2016
2 parents 88c7943 + ea0bd3e commit 3b230c4
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 139 deletions.
10 changes: 5 additions & 5 deletions devp2p/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
hmac_sha256 = pyelliptic.hmac_sha256


class ECIESDecryptionError(Exception):
class ECIESDecryptionError(RuntimeError):
pass


Expand Down Expand Up @@ -104,7 +104,7 @@ def is_valid_key(self, raw_pubkey, raw_privkey=None):
return not failed

@classmethod
def ecies_encrypt(cls, data, raw_pubkey):
def ecies_encrypt(cls, data, raw_pubkey, shared_mac_data=''):
"""
ECIES Encrypt, where P = recipient public key is:
1) generate r = random value
Expand Down Expand Up @@ -149,15 +149,15 @@ def ecies_encrypt(cls, data, raw_pubkey):
msg = chr(0x04) + ephem_pubkey + iv + ciphertext

# the MAC of a message (called the tag) as per SEC 1, 3.5.
tag = hmac_sha256(key_mac, msg[1 + 64:])
tag = hmac_sha256(key_mac, msg[1 + 64:] + shared_mac_data)
assert len(tag) == 32
msg += tag

assert len(msg) == 1 + 64 + 16 + 32 + len(data) == 113 + len(data)
assert len(msg) - cls.ecies_encrypt_overhead_length == len(data)
return msg

def ecies_decrypt(self, data):
def ecies_decrypt(self, data, shared_mac_data=''):
"""
Decrypt data with ECIES method using the local private key
Expand Down Expand Up @@ -190,7 +190,7 @@ def ecies_decrypt(self, data):
assert len(tag) == 32

# 2) verify tag
if not pyelliptic.equals(hmac_sha256(key_mac, data[1 + 64:- 32]), tag):
if not pyelliptic.equals(hmac_sha256(key_mac, data[1 + 64:- 32] + shared_mac_data), tag):
raise ECIESDecryptionError("Fail to verify data")

# 3) decrypt
Expand Down
23 changes: 14 additions & 9 deletions devp2p/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ class DiscoveryProtocol(kademlia.WireInterface):
cmd_id_map = dict(ping=1, pong=2, find_node=3, neighbours=4)
rev_cmd_id_map = dict((v, k) for k, v in cmd_id_map.items())

# number of required top-level list elements for each cmd_id.
# elements beyond this length are trimmed.
cmd_elem_count_map = dict(ping=4, pong=3, find_node=2, neighbours=2)

encoders = dict(cmd_id=chr,
expiration=rlp.sedes.big_endian_int.serialize)

Expand Down Expand Up @@ -300,19 +304,24 @@ def unpack(self, message):
# if not crypto.verify(remote_pubkey, signature, signed_data):
# raise InvalidSignature()
cmd_id = self.decoders['cmd_id'](message[97])
assert cmd_id in self.cmd_id_map.values()
payload = rlp.decode(message[98:])
cmd = self.rev_cmd_id_map[cmd_id]
payload = rlp.decode(message[98:], strict=False)
assert isinstance(payload, list)
expiration = self.decoders['expiration'](payload.pop())
if time.time() > expiration:
raise PacketExpired()
# ignore excessive list elements as required by EIP-8.
payload = payload[:self.cmd_elem_count_map.get(cmd, len(payload))]
return remote_pubkey, cmd_id, payload, mdc

def receive(self, address, message):
log.debug('<<< message', address=address)
assert isinstance(address, Address)
try:
remote_pubkey, cmd_id, payload, mdc = self.unpack(message)
# Note: as of discovery version 4, expiration is the last element for all
# packets. This might not be the case for a later version, but just popping
# the last element is good enough for now.
expiration = self.decoders['expiration'](payload.pop())
if time.time() > expiration:
raise PacketExpired()
except DefectiveMessage:
return
cmd = getattr(self, 'recv_' + self.rev_cmd_id_map[cmd_id])
Expand Down Expand Up @@ -376,10 +385,6 @@ def recv_ping(self, nodeid, payload, mdc):
return
node = self.get_node(nodeid)
log.debug('<<< ping', node=node)
version = rlp.sedes.big_endian_int.deserialize(payload[0])
if version != self.version:
log.error('wrong version', remote_version=version, expected_version=self.version)
return
remote_address = Address.from_endpoint(*payload[1]) # from address
my_address = Address.from_endpoint(*payload[2]) # my address
self.get_node(nodeid).address.update(remote_address)
Expand Down
17 changes: 8 additions & 9 deletions devp2p/muxsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@


class MultiplexedSession(multiplexer.Multiplexer):
def __init__(self, privkey, hello_packet, token_by_pubkey=dict(), remote_pubkey=None):
def __init__(self, privkey, hello_packet, remote_pubkey=None):
self.is_initiator = bool(remote_pubkey)
self.hello_packet = hello_packet
self.message_queue = gevent.queue.Queue() # wire msg egress queue
self.packet_queue = gevent.queue.Queue() # packet ingress queue
ecc = ECCx(raw_privkey=privkey)
self.rlpx_session = RLPxSession(
ecc, is_initiator=bool(remote_pubkey), token_by_pubkey=token_by_pubkey)
ecc, is_initiator=bool(remote_pubkey))
self._remote_pubkey = remote_pubkey
self.token_by_pubkey = token_by_pubkey
multiplexer.Multiplexer.__init__(self, frame_cipher=self.rlpx_session)
if self.is_initiator:
self._send_init_msg()
Expand Down Expand Up @@ -43,19 +42,19 @@ def _add_message_during_handshake(self, msg):
session = self.rlpx_session
if self.is_initiator:
# expecting auth ack message
session.decode_auth_ack_message(msg[:session.auth_ack_message_ct_length])
rest = session.decode_auth_ack_message(msg)
session.setup_cipher()
if len(msg) > session.auth_ack_message_ct_length: # add remains (hello) to queue
self._add_message_post_handshake(msg[session.auth_ack_message_ct_length:])
if len(rest) > 0: # add remains (hello) to queue
self._add_message_post_handshake(rest)
else:
# expecting auth_init
session.decode_authentication(msg[:session.auth_message_ct_length])
rest = session.decode_authentication(msg)
auth_ack_msg = session.create_auth_ack_message()
auth_ack_msg_ct = session.encrypt_auth_ack_message(auth_ack_msg)
self.message_queue.put(auth_ack_msg_ct)
session.setup_cipher()
if len(msg) > session.auth_message_ct_length: # add remains (hello) to queue
self._add_message_post_handshake(msg[session.auth_message_ct_length:])
if len(rest) > 0:
self._add_message_post_handshake(rest)
self.add_message = self._add_message_post_handshake

# send hello
Expand Down
10 changes: 4 additions & 6 deletions devp2p/p2p_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,16 @@ class pong(BaseProtocol.command):

class hello(BaseProtocol.command):
cmd_id = 0

structure = [
('version', sedes.big_endian_int),
('client_version_string', sedes.binary),
('capabilities', sedes.CountableList(sedes.List([sedes.binary, sedes.big_endian_int]))),
('listen_port', sedes.big_endian_int),
('remote_pubkey', sedes.binary)
]
# don't throw for additional list elements as
# mandated by EIP-8.
decode_strict = False

def create(self, proto):
return dict(version=proto.version,
Expand All @@ -115,10 +117,6 @@ def receive(self, proto, data):
if data['remote_pubkey'] == proto.config['node']['id']:
log.debug('connected myself')
return proto.send_disconnect(reason=reasons.connected_to_self)
if data['version'] != proto.version:
log.debug('incompatible network protocols', peer=proto.peer,
expected=proto.version, received=data['version'])
return proto.send_disconnect(reason=reasons.incompatibel_p2p_version)

proto.peer.receive_hello(proto, **data)
# super(hello, self).receive(proto, data)
Expand All @@ -127,7 +125,7 @@ def receive(self, proto, data):
@classmethod
def get_hello_packet(cls, peer):
"special: we need this packet before the protocol can be initalized"
res = dict(version=cls.version,
res = dict(version=55,
client_version_string=peer.config['client_version_string'],
capabilities=peer.capabilities,
listen_port=peer.config['p2p']['listen_port'],
Expand Down
3 changes: 1 addition & 2 deletions devp2p/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def __init__(self, peermanager, connection, remote_pubkey=None):
# create multiplexed encrypted session
privkey = self.config['node']['privkey_hex'].decode('hex')
hello_packet = P2PProtocol.get_hello_packet(self)
self.mux = MultiplexedSession(privkey, hello_packet,
token_by_pubkey=dict(), remote_pubkey=remote_pubkey)
self.mux = MultiplexedSession(privkey, hello_packet, remote_pubkey=remote_pubkey)
self.remote_pubkey = remote_pubkey

# register p2p protocol
Expand Down
4 changes: 3 additions & 1 deletion devp2p/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class command(object):
- define structure for rlp de/endcoding by sedes
- list(arg_name, rlp.sedes.type), ...) # for structs
- sedes.CountableList(sedes.type) # for lists with uniform item type
- if you want non-strict decoding, define decode_strict = False
optionally implement
- create
- receive
Expand All @@ -55,6 +56,7 @@ class command(object):
"""
cmd_id = 0
structure = [] # [(arg_name, rlp.sedes.type), ...]
decode_strict = True

def create(self, proto, *args, **kargs):
"optionally implement create"
Expand Down Expand Up @@ -93,7 +95,7 @@ def decode_payload(cls, rlp_data):
if isinstance(cls.structure, sedes.CountableList):
decoder = cls.structure
else:
decoder = sedes.List([x[1] for x in cls.structure])
decoder = sedes.List([x[1] for x in cls.structure], strict=cls.decode_strict)
try:
data = rlp.decode(str(rlp_data), sedes=decoder)
except (AssertionError, rlp.RLPException, TypeError) as e:
Expand Down
Loading

0 comments on commit 3b230c4

Please sign in to comment.