diff --git a/setup.py b/setup.py index ec2f503f..e19a0d23 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ name = 'PyOTA', description = 'IOTA API library for Python', url = 'https://github.com/iotaledger/iota.lib.py', - version = '1.1.1', + version = '1.1.2', packages = find_packages('src'), include_package_data = True, diff --git a/src/iota/api.py b/src/iota/api.py index f7ccf2cb..f3e89589 100644 --- a/src/iota/api.py +++ b/src/iota/api.py @@ -822,8 +822,8 @@ def send_transfer( minWeightMagnitude = min_weight_magnitude, ) - def send_trytes(self, trytes, depth, min_weight_magnitude=18): - # type: (Iterable[TransactionTrytes], int, int) -> dict + def send_trytes(self, trytes, depth, min_weight_magnitude=None): + # type: (Iterable[TransactionTrytes], int, Optional[int]) -> dict """ Attaches transaction trytes to the Tangle, then broadcasts and stores them. @@ -851,6 +851,9 @@ def send_trytes(self, trytes, depth, min_weight_magnitude=18): References: - https://github.com/iotaledger/wiki/blob/master/api-proposal.md#sendtrytes """ + if min_weight_magnitude is None: + min_weight_magnitude = self.default_min_weight_magnitude + return extended.SendTrytesCommand(self.adapter)( trytes = trytes, depth = depth, diff --git a/src/iota/crypto/addresses.py b/src/iota/crypto/addresses.py index 53468103..44db482c 100644 --- a/src/iota/crypto/addresses.py +++ b/src/iota/crypto/addresses.py @@ -4,7 +4,10 @@ import hashlib from abc import ABCMeta, abstractmethod as abstract_method -from typing import Dict, Generator, Iterable, List, MutableSequence, Optional +from contextlib import contextmanager as context_manager +from threading import Lock +from typing import Dict, Generator, Iterable, List, MutableSequence, \ + Optional, Tuple from iota import Address, TRITS_PER_TRYTE, TrytesCompatible from iota.crypto import Curl @@ -23,6 +26,19 @@ class BaseAddressCache(with_metaclass(ABCMeta)): """ Base functionality for classes that cache generated addresses. """ + LockType = Lock + """ + The type of locking mechanism used by :py:meth:`acquire_lock`. + + Defaults to ``threading.Lock``, but you can change it if you want to + use a different mechanism (e.g., multithreading or distributed). + """ + + def __init__(self): + super(BaseAddressCache, self).__init__() + + self._lock = self.LockType() + @abstract_method def get(self, seed, index): # type: (Seed, int) -> Optional[Address] @@ -34,6 +50,18 @@ def get(self, seed, index): 'Not implemented in {cls}.'.format(cls=type(self).__name__), ) + @context_manager + def acquire_lock(self): + """ + Acquires a lock on the cache instance, to prevent invalid cache + misses when multiple threads access the cache concurrently. + + Note: Acquire lock before checking the cache, and do not release it + until after the cache hit/miss is resolved. + """ + with self._lock: + yield + @abstract_method def set(self, seed, index, address): # type: (Seed, int, Address) -> None @@ -45,6 +73,17 @@ def set(self, seed, index, address): 'Not implemented in {cls}.'.format(cls=type(self).__name__), ) + @staticmethod + def _gen_cache_key(seed, index): + # type: (Seed, int) -> binary_type + """ + Generates an obfuscated cache key so that we're not storing seeds + in cleartext. + """ + h = hashlib.new('sha256') + h.update(binary_type(seed) + b':' + binary_type(index)) + return h.digest() + class MemoryAddressCache(BaseAddressCache): """ @@ -63,17 +102,6 @@ def set(self, seed, index, address): # type: (Seed, int, Address) -> None self.cache[self._gen_cache_key(seed, index)] = address - @staticmethod - def _gen_cache_key(seed, index): - # type: (Seed, int) -> binary_type - """ - Generates an obfuscated cache key so that we're not storing seeds - in cleartext. - """ - h = hashlib.new('sha256') - h.update(binary_type(seed) + b':' + binary_type(index)) - return h.digest() - class AddressGenerator(Iterable[Address]): """ @@ -213,18 +241,19 @@ def create_iterator(self, start=0, step=1): while True: if self.cache: - address = self.cache.get(self.seed, key_iterator.current) + with self.cache.acquire_lock(): + address = self.cache.get(self.seed, key_iterator.current) - if not address: - address = self._generate_address(key_iterator) - self.cache.set(self.seed, address.key_index, address) + if not address: + address = self._generate_address(key_iterator) + self.cache.set(self.seed, address.key_index, address) else: address = self._generate_address(key_iterator) yield address @staticmethod - def address_from_digest(digest_trits, key_index): + def address_from_digest_trits(digest_trits, key_index): # type: (List[int], int) -> Address """ Generates an address from a private key digest. @@ -247,13 +276,13 @@ def _generate_address(self, key_iterator): Used in the event of a cache miss. """ - return self.address_from_digest(*self._get_digest_params(key_iterator)) + return self.address_from_digest_trits(*self._get_digest_params(key_iterator)) @staticmethod def _get_digest_params(key_iterator): # type: (KeyIterator) -> Tuple[List[int], int] """ - Extracts parameters for :py:meth:`address_from_digest`. + Extracts parameters for :py:meth:`address_from_digest_trits`. Split into a separate method so that it can be mocked during unit tests. diff --git a/test/crypto/addresses_test.py b/test/crypto/addresses_test.py index 62086fda..24d36791 100644 --- a/test/crypto/addresses_test.py +++ b/test/crypto/addresses_test.py @@ -2,6 +2,8 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals +from threading import Thread +from time import sleep from typing import List, Tuple from unittest import TestCase @@ -44,6 +46,26 @@ def setUp(self): b'CFANWBQFGMFKITZBJDSYLGXYUIQVCMXFWSWFRNHRV' ) + # noinspection SpellCheckingInspection + def test_address_from_digest(self): + """ + Generating an address from a private key digest. + """ + digest =\ + Hash( + b'ABQXVJNER9MPMXMBPNMFBMDGTXRWSYHNZKGAGUOI' + b'JKOJGZVGHCUXXGFZEMMGDSGWDCKJXO9ILLFAKGGZE' + ) + + self.assertEqual( + AddressGenerator.address_from_digest_trits(digest.as_trits(), 0), + + Address( + b'QLOEDSBXXOLLUJYLEGKEPYDRIJJTPIMEPKMFHUVJ' + b'MPMLYYCLPQPANEVDSERQWPVNHCAXYRLAYMBHJLWWR' + ), + ) + def test_get_addresses_single(self): """ Generating a single address. @@ -329,3 +351,53 @@ def test_cache_miss_seed(self): generator2 = AddressGenerator(Seed.random()) generator2.get_addresses(42) self.assertEqual(mock_generate_address.call_count, 2) + + def test_thread_safety(self): + """ + Address cache is thread-safe, eliminating invalid cache misses when + multiple threads attempt to access the cache concurrently. + """ + AddressGenerator.cache = MemoryAddressCache() + + seed = Seed.random() + + generated = [] + + def get_address(): + generator = AddressGenerator(seed) + generated.extend(generator.get_addresses(0)) + + # noinspection PyUnusedLocal + def mock_generate_address(address_generator, key_iterator): + # type: (AddressGenerator, KeyIterator) -> Address + # Insert a teensy delay, to make it more likely that multiple + # threads hit the cache concurrently. + sleep(0.01) + + # Note that in this test, the address generator always returns a + # new instance. + return Address(self.addy, key_index=key_iterator.current) + + with patch( + 'iota.crypto.addresses.AddressGenerator._generate_address', + mock_generate_address, + ): + threads = [Thread(target=get_address) for _ in range(100)] + + for t in threads: + t.start() + + for t in threads: + t.join() + + # Quick sanity check. + self.assertEqual(len(generated), len(threads)) + + # If the cache is operating in a thread-safe manner, then it will + # always return the exact same instance, given the same seed and + # key index. + expected = generated[0] + for actual in generated[1:]: + # Compare `id` values instead of using ``self.assertIs`` because + # the failure message is a bit easier to understand. + self.assertEqual(id(actual), id(expected))