Skip to content
This repository was archived by the owner on Jan 13, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
name = 'PyOTA',
description = 'IOTA API library for Python',
url = 'https://github.com/iotaledger/iota.lib.py',
version = '1.1.1',
version = '1.1.2',

packages = find_packages('src'),
include_package_data = True,
Expand Down
7 changes: 5 additions & 2 deletions src/iota/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,8 +822,8 @@ def send_transfer(
minWeightMagnitude = min_weight_magnitude,
)

def send_trytes(self, trytes, depth, min_weight_magnitude=18):
# type: (Iterable[TransactionTrytes], int, int) -> dict
def send_trytes(self, trytes, depth, min_weight_magnitude=None):
# type: (Iterable[TransactionTrytes], int, Optional[int]) -> dict
"""
Attaches transaction trytes to the Tangle, then broadcasts and
stores them.
Expand Down Expand Up @@ -851,6 +851,9 @@ def send_trytes(self, trytes, depth, min_weight_magnitude=18):
References:
- https://github.com/iotaledger/wiki/blob/master/api-proposal.md#sendtrytes
"""
if min_weight_magnitude is None:
min_weight_magnitude = self.default_min_weight_magnitude

return extended.SendTrytesCommand(self.adapter)(
trytes = trytes,
depth = depth,
Expand Down
67 changes: 48 additions & 19 deletions src/iota/crypto/addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import hashlib
from abc import ABCMeta, abstractmethod as abstract_method
from typing import Dict, Generator, Iterable, List, MutableSequence, Optional
from contextlib import contextmanager as context_manager
from threading import Lock
from typing import Dict, Generator, Iterable, List, MutableSequence, \
Optional, Tuple

from iota import Address, TRITS_PER_TRYTE, TrytesCompatible
from iota.crypto import Curl
Expand All @@ -23,6 +26,19 @@ class BaseAddressCache(with_metaclass(ABCMeta)):
"""
Base functionality for classes that cache generated addresses.
"""
LockType = Lock
"""
The type of locking mechanism used by :py:meth:`acquire_lock`.

Defaults to ``threading.Lock``, but you can change it if you want to
use a different mechanism (e.g., multithreading or distributed).
"""

def __init__(self):
super(BaseAddressCache, self).__init__()

self._lock = self.LockType()

@abstract_method
def get(self, seed, index):
# type: (Seed, int) -> Optional[Address]
Expand All @@ -34,6 +50,18 @@ def get(self, seed, index):
'Not implemented in {cls}.'.format(cls=type(self).__name__),
)

@context_manager
def acquire_lock(self):
"""
Acquires a lock on the cache instance, to prevent invalid cache
misses when multiple threads access the cache concurrently.

Note: Acquire lock before checking the cache, and do not release it
until after the cache hit/miss is resolved.
"""
with self._lock:
yield

@abstract_method
def set(self, seed, index, address):
# type: (Seed, int, Address) -> None
Expand All @@ -45,6 +73,17 @@ def set(self, seed, index, address):
'Not implemented in {cls}.'.format(cls=type(self).__name__),
)

@staticmethod
def _gen_cache_key(seed, index):
# type: (Seed, int) -> binary_type
"""
Generates an obfuscated cache key so that we're not storing seeds
in cleartext.
"""
h = hashlib.new('sha256')
h.update(binary_type(seed) + b':' + binary_type(index))
return h.digest()


class MemoryAddressCache(BaseAddressCache):
"""
Expand All @@ -63,17 +102,6 @@ def set(self, seed, index, address):
# type: (Seed, int, Address) -> None
self.cache[self._gen_cache_key(seed, index)] = address

@staticmethod
def _gen_cache_key(seed, index):
# type: (Seed, int) -> binary_type
"""
Generates an obfuscated cache key so that we're not storing seeds
in cleartext.
"""
h = hashlib.new('sha256')
h.update(binary_type(seed) + b':' + binary_type(index))
return h.digest()


class AddressGenerator(Iterable[Address]):
"""
Expand Down Expand Up @@ -213,18 +241,19 @@ def create_iterator(self, start=0, step=1):

while True:
if self.cache:
address = self.cache.get(self.seed, key_iterator.current)
with self.cache.acquire_lock():
address = self.cache.get(self.seed, key_iterator.current)

if not address:
address = self._generate_address(key_iterator)
self.cache.set(self.seed, address.key_index, address)
if not address:
address = self._generate_address(key_iterator)
self.cache.set(self.seed, address.key_index, address)
else:
address = self._generate_address(key_iterator)

yield address

@staticmethod
def address_from_digest(digest_trits, key_index):
def address_from_digest_trits(digest_trits, key_index):
# type: (List[int], int) -> Address
"""
Generates an address from a private key digest.
Expand All @@ -247,13 +276,13 @@ def _generate_address(self, key_iterator):

Used in the event of a cache miss.
"""
return self.address_from_digest(*self._get_digest_params(key_iterator))
return self.address_from_digest_trits(*self._get_digest_params(key_iterator))

@staticmethod
def _get_digest_params(key_iterator):
# type: (KeyIterator) -> Tuple[List[int], int]
"""
Extracts parameters for :py:meth:`address_from_digest`.
Extracts parameters for :py:meth:`address_from_digest_trits`.

Split into a separate method so that it can be mocked during unit
tests.
Expand Down
72 changes: 72 additions & 0 deletions test/crypto/addresses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import absolute_import, division, print_function, \
unicode_literals

from threading import Thread
from time import sleep
from typing import List, Tuple
from unittest import TestCase

Expand Down Expand Up @@ -44,6 +46,26 @@ def setUp(self):
b'CFANWBQFGMFKITZBJDSYLGXYUIQVCMXFWSWFRNHRV'
)

# noinspection SpellCheckingInspection
def test_address_from_digest(self):
"""
Generating an address from a private key digest.
"""
digest =\
Hash(
b'ABQXVJNER9MPMXMBPNMFBMDGTXRWSYHNZKGAGUOI'
b'JKOJGZVGHCUXXGFZEMMGDSGWDCKJXO9ILLFAKGGZE'
)

self.assertEqual(
AddressGenerator.address_from_digest_trits(digest.as_trits(), 0),

Address(
b'QLOEDSBXXOLLUJYLEGKEPYDRIJJTPIMEPKMFHUVJ'
b'MPMLYYCLPQPANEVDSERQWPVNHCAXYRLAYMBHJLWWR'
),
)

def test_get_addresses_single(self):
"""
Generating a single address.
Expand Down Expand Up @@ -329,3 +351,53 @@ def test_cache_miss_seed(self):
generator2 = AddressGenerator(Seed.random())
generator2.get_addresses(42)
self.assertEqual(mock_generate_address.call_count, 2)

def test_thread_safety(self):
"""
Address cache is thread-safe, eliminating invalid cache misses when
multiple threads attempt to access the cache concurrently.
"""
AddressGenerator.cache = MemoryAddressCache()

seed = Seed.random()

generated = []

def get_address():
generator = AddressGenerator(seed)
generated.extend(generator.get_addresses(0))

# noinspection PyUnusedLocal
def mock_generate_address(address_generator, key_iterator):
# type: (AddressGenerator, KeyIterator) -> Address
# Insert a teensy delay, to make it more likely that multiple
# threads hit the cache concurrently.
sleep(0.01)

# Note that in this test, the address generator always returns a
# new instance.
return Address(self.addy, key_index=key_iterator.current)

with patch(
'iota.crypto.addresses.AddressGenerator._generate_address',
mock_generate_address,
):
threads = [Thread(target=get_address) for _ in range(100)]

for t in threads:
t.start()

for t in threads:
t.join()

# Quick sanity check.
self.assertEqual(len(generated), len(threads))

# If the cache is operating in a thread-safe manner, then it will
# always return the exact same instance, given the same seed and
# key index.
expected = generated[0]
for actual in generated[1:]:
# Compare `id` values instead of using ``self.assertIs`` because
# the failure message is a bit easier to understand.
self.assertEqual(id(actual), id(expected))