Skip to content
This repository has been archived by the owner on Feb 13, 2024. It is now read-only.

Commit

Permalink
Adding option to define HTTP timeout for POST requests to RPC nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
ochaloup committed Jan 19, 2022
1 parent f1b14ac commit 2af3d2b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 12 deletions.
10 changes: 6 additions & 4 deletions mango/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,12 @@ def report_on_transaction(self) -> None:
# A `RPCCaller` extends the HTTPProvider with better error handling.
#
class RPCCaller(HTTPProvider):
def __init__(self, name: str, cluster_url: str, stale_data_pauses_before_retry: typing.Sequence[float], slot_holder: SlotHolder, instruction_reporter: InstructionReporter):
def __init__(self, name: str, cluster_url: str, http_request_timeout: typing.Optional[float], stale_data_pauses_before_retry: typing.Sequence[float], slot_holder: SlotHolder, instruction_reporter: InstructionReporter):
super().__init__(cluster_url)
self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.name: str = name
self.cluster_url: str = cluster_url
self.http_request_timeout: typing.Optional[float] = http_request_timeout
self.stale_data_pauses_before_retry: typing.Sequence[float] = stale_data_pauses_before_retry
self.slot_holder: SlotHolder = slot_holder
self.instruction_reporter: InstructionReporter = instruction_reporter
Expand Down Expand Up @@ -401,7 +402,7 @@ def __make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
# return self._after_request(raw_response=raw_response, method=method)

request_kwargs = self._before_request(method=method, params=params, is_async=False)
raw_response = requests.post(**request_kwargs)
raw_response = requests.post(**request_kwargs, timeout=self.http_request_timeout)

# Some custom exceptions specifically for rate-limiting. This allows calling code to handle this
# specific case if they so choose.
Expand Down Expand Up @@ -530,6 +531,7 @@ def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
self._logger.debug(f"Shifted provider - now using: {self.__providers[0]}")
return result
except (requests.exceptions.HTTPError,
requests.exceptions.ConnectTimeout,
RateLimitException,
NodeIsBehindException,
StaleSlotException,
Expand Down Expand Up @@ -591,11 +593,11 @@ def __init__(self, client: Client, name: str, cluster_name: str, commitment: Com
self.executor: Executor = ThreadPoolExecutor()

@staticmethod
def from_configuration(name: str, cluster_name: str, cluster_urls: typing.Sequence[str], commitment: Commitment, skip_preflight: bool, encoding: str, blockhash_cache_duration: int, stale_data_pauses_before_retry: typing.Sequence[float], instruction_reporter: InstructionReporter) -> "BetterClient":
def from_configuration(name: str, cluster_name: str, cluster_urls: typing.Sequence[str], commitment: Commitment, skip_preflight: bool, encoding: str, blockhash_cache_duration: int, http_request_timeout: typing.Optional[float], stale_data_pauses_before_retry: typing.Sequence[float], instruction_reporter: InstructionReporter) -> "BetterClient":
slot_holder: SlotHolder = SlotHolder()
rpc_callers: typing.List[RPCCaller] = []
for cluster_url in cluster_urls:
rpc_caller: RPCCaller = RPCCaller(name, cluster_url, stale_data_pauses_before_retry,
rpc_caller: RPCCaller = RPCCaller(name, cluster_url, http_request_timeout, stale_data_pauses_before_retry,
slot_holder, instruction_reporter)
rpc_callers += [rpc_caller]

Expand Down
4 changes: 2 additions & 2 deletions mango/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#
class Context:
def __init__(self, name: str, cluster_name: str, cluster_urls: typing.Sequence[str], skip_preflight: bool,
commitment: str, encoding: str, blockhash_cache_duration: int,
commitment: str, encoding: str, blockhash_cache_duration: int, http_request_timeout: typing.Optional[float],
stale_data_pauses_before_retry: typing.Sequence[float], mango_program_address: PublicKey,
serum_program_address: PublicKey, group_name: str, group_address: PublicKey,
gma_chunk_size: Decimal, gma_chunk_pause: Decimal, instrument_lookup: InstrumentLookup,
Expand All @@ -48,7 +48,7 @@ def __init__(self, name: str, cluster_name: str, cluster_urls: typing.Sequence[s
instruction_reporter: InstructionReporter = CompoundInstructionReporter.from_addresses(
mango_program_address, serum_program_address)
self.client: BetterClient = BetterClient.from_configuration(name, cluster_name, cluster_urls, Commitment(
commitment), skip_preflight, encoding, blockhash_cache_duration, stale_data_pauses_before_retry, instruction_reporter)
commitment), skip_preflight, encoding, blockhash_cache_duration, http_request_timeout, stale_data_pauses_before_retry, instruction_reporter)
self.mango_program_address: PublicKey = mango_program_address
self.serum_program_address: PublicKey = serum_program_address
self.group_name: str = group_name
Expand Down
12 changes: 9 additions & 3 deletions mango/contextbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def add_command_line_parameters(parser: argparse.ArgumentParser) -> None:
help="Encoding to request when receiving data from Solana (options are 'base58' (slow), 'base64', 'base64+zstd', or 'jsonParsed')")
parser.add_argument("--blockhash-cache-duration", type=int,
help="How long (in seconds) to cache 'recent' blockhashes")
parser.add_argument("--http-request-timeout", type=float, default=None,
help="What is the timeout for HTTP requests to RPC nodes (in seconds) ")
parser.add_argument("--stale-data-pause-before-retry", type=Decimal,
help="How long (in seconds, e.g. 0.1) to pause after retrieving stale data before retrying")
parser.add_argument("--stale-data-maximum-retries", type=int,
Expand Down Expand Up @@ -107,6 +109,7 @@ def from_command_line_parameters(args: argparse.Namespace) -> Context:
commitment: typing.Optional[str] = args.commitment
encoding: typing.Optional[str] = args.encoding
blockhash_cache_duration: typing.Optional[int] = args.blockhash_cache_duration
http_request_timeout: typing.Optional[float] = args.http_request_timeout
stale_data_pause_before_retry: typing.Optional[Decimal] = args.stale_data_pause_before_retry
stale_data_maximum_retries: typing.Optional[int] = args.stale_data_maximum_retries
gma_chunk_size: typing.Optional[Decimal] = args.gma_chunk_size
Expand All @@ -125,7 +128,7 @@ def from_command_line_parameters(args: argparse.Namespace) -> Context:
actual_stale_data_pauses_before_retry = [float(pause)] * retries

context: Context = ContextBuilder.build(name, cluster_name, cluster_urls, skip_preflight, commitment,
encoding, blockhash_cache_duration,
encoding, blockhash_cache_duration, http_request_timeout,
actual_stale_data_pauses_before_retry,
group_name, group_address, mango_program_address,
serum_program_address, gma_chunk_size, gma_chunk_pause,
Expand All @@ -142,7 +145,7 @@ def default() -> Context:
def from_group_name(context: Context, group_name: str) -> Context:
return ContextBuilder.build(context.name, context.client.cluster_name, context.client.cluster_urls,
context.client.skip_preflight, context.client.commitment,
context.client.encoding, context.client.blockhash_cache_duration,
context.client.encoding, context.client.blockhash_cache_duration, None,
context.client.stale_data_pauses_before_retry,
group_name, None, None, None,
context.gma_chunk_size, context.gma_chunk_pause,
Expand All @@ -160,6 +163,7 @@ def forced_to_devnet(context: Context) -> Context:
context.client.skip_preflight,
context.client.encoding,
context.client.blockhash_cache_duration,
None,
context.client.stale_data_pauses_before_retry,
context.client.instruction_reporter)

Expand All @@ -177,6 +181,7 @@ def forced_to_mainnet_beta(context: Context) -> Context:
context.client.skip_preflight,
context.client.encoding,
context.client.blockhash_cache_duration,
None,
context.client.stale_data_pauses_before_retry,
context.client.instruction_reporter)

Expand All @@ -187,6 +192,7 @@ def build(name: typing.Optional[str] = None, cluster_name: typing.Optional[str]
cluster_urls: typing.Optional[typing.Sequence[str]] = None, skip_preflight: bool = False,
commitment: typing.Optional[str] = None, encoding: typing.Optional[str] = None,
blockhash_cache_duration: typing.Optional[int] = None,
http_request_timeout: typing.Optional[float] = None,
stale_data_pauses_before_retry: typing.Optional[typing.Sequence[float]] = None,
group_name: typing.Optional[str] = None, group_address: typing.Optional[PublicKey] = None,
program_address: typing.Optional[PublicKey] = None, serum_program_address: typing.Optional[PublicKey] = None,
Expand Down Expand Up @@ -307,4 +313,4 @@ def __public_key_or_none(address: typing.Optional[str]) -> typing.Optional[Publi
devnet_serum_market_lookup])
market_lookup: MarketLookup = all_market_lookup

return Context(actual_name, actual_cluster, actual_cluster_urls, actual_skip_preflight, actual_commitment, actual_encoding, actual_blockhash_cache_duration, actual_stale_data_pauses_before_retry, actual_program_address, actual_serum_program_address, actual_group_name, actual_group_address, actual_gma_chunk_size, actual_gma_chunk_pause, instrument_lookup, market_lookup)
return Context(actual_name, actual_cluster, actual_cluster_urls, actual_skip_preflight, actual_commitment, actual_encoding, actual_blockhash_cache_duration, http_request_timeout, actual_stale_data_pauses_before_retry, actual_program_address, actual_serum_program_address, actual_group_name, actual_group_address, actual_gma_chunk_size, actual_gma_chunk_pause, instrument_lookup, market_lookup)
3 changes: 2 additions & 1 deletion tests/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_minimum_balance_for_rent_exemption(size, *args: typing.Any, **kwargs: ty

class MockClient(mango.BetterClient):
def __init__(self) -> None:
rpc = mango.RPCCaller("fake", "http://localhost", [], mango.SlotHolder(), mango.InstructionReporter())
rpc = mango.RPCCaller("fake", "http://localhost", None, [], mango.SlotHolder(), mango.InstructionReporter())
compound = mango.CompoundRPCCaller("fake", [rpc])
super().__init__(MockCompatibleClient(), "test", "local", Commitment("processed"),
False, "base64", 0, compound)
Expand All @@ -51,6 +51,7 @@ def fake_context() -> mango.Context:
commitment="processed",
encoding="base64",
blockhash_cache_duration=0,
http_request_timeout=None,
stale_data_pauses_before_retry=[],
mango_program_address=fake_seeded_public_key("Mango program address"),
serum_program_address=fake_seeded_public_key("Serum program address"),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class FakeRPCCaller(mango.RPCCaller):
def __init__(self) -> None:
super().__init__("Fake", "https://localhost", [0.1, 0.2], mango.SlotHolder(), mango.InstructionReporter())
super().__init__("Fake", "https://localhost", None, [0.1, 0.2], mango.SlotHolder(), mango.InstructionReporter())
self.called = False

def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
Expand All @@ -25,7 +25,7 @@ def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:

class RaisingRPCCaller(mango.RPCCaller):
def __init__(self) -> None:
super().__init__("Fake", "https://localhost", [0.1, 0.2], mango.SlotHolder(), mango.InstructionReporter())
super().__init__("Fake", "https://localhost", None, [0.1, 0.2], mango.SlotHolder(), mango.InstructionReporter())
self.called = False

def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
Expand Down

0 comments on commit 2af3d2b

Please sign in to comment.