Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 56 additions & 118 deletions allways/contract_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@

from allways.classes import Swap, SwapStatus
from allways.constants import CONTRACT_ADDRESS, MIN_BALANCE_FOR_TX_RAO
from allways.utils.scale import (
ACCOUNT_ID_BYTES,
U32_BYTES,
U64_BYTES,
U128_BYTES,
compact_encode_len,
decode_account_id,
decode_string,
decode_u32,
decode_u64,
decode_u128,
encode_bytes,
encode_str,
encode_u128,
strip_hex_prefix,
)

# =========================================================================
# Contract selectors (from metadata — deterministic per contract build)
Expand Down Expand Up @@ -208,23 +224,6 @@
_EXTRINSIC_NOT_FOUND = tuple(t for t in [ExtrinsicNotFound, AsyncExtrinsicNotFound] if t is not None)


def compact_encode_len(length: int) -> bytes:
"""SCALE compact-encode a length prefix. Shared by contract client and axon handlers."""
if length < 64:
return bytes([length << 2])
elif length < 16384:
return bytes([((length << 2) | 1) & 0xFF, length >> 6])
else:
return bytes(
[
((length << 2) | 2) & 0xFF,
(length >> 6) & 0xFF,
(length >> 14) & 0xFF,
(length >> 22) & 0xFF,
]
)


# ContractExecResult byte layout offsets (after gas prefix)
_GAS_PREFIX_BYTES = 16 # Skip gas consumed/required
_RESULT_OK_OFFSET = 10 # Byte indicating Ok(0x00) vs Err in Result
Expand Down Expand Up @@ -366,7 +365,7 @@ def raw_contract_read(
if not result.get('result'):
return None

raw = bytes.fromhex(result['result'].replace('0x', ''))
raw = bytes.fromhex(strip_hex_prefix(result['result']))
if len(raw) < 32:
return None

Expand All @@ -376,7 +375,7 @@ def raw_contract_read(
if len(r) < _DATA_COMPACT_OFFSET or r[_RESULT_OK_OFFSET] != 0x00:
return None

flags = struct.unpack_from('<I', r, _FLAGS_OFFSET)[0]
flags, _ = decode_u32(r, _FLAGS_OFFSET)
is_revert = bool(flags & 1)

data_compact = r[_DATA_COMPACT_OFFSET]
Expand Down Expand Up @@ -505,27 +504,25 @@ def encode_value(self, value, type_tag: str) -> bytes:
return struct.pack('B', int(value))
elif type_tag == 'hash':
if isinstance(value, str):
return bytes.fromhex(value.replace('0x', ''))
return bytes(value)[:32].ljust(32, b'\x00')
return bytes.fromhex(strip_hex_prefix(value))
return bytes(value)[:ACCOUNT_ID_BYTES].ljust(ACCOUNT_ID_BYTES, b'\x00')
elif type_tag == 'bytes':
data = value if isinstance(value, (bytes, bytearray)) else value.encode('utf-8')
return compact_encode_len(len(data)) + data
return encode_bytes(data)
elif type_tag == 'u32':
return struct.pack('<I', int(value))
elif type_tag == 'u64':
return struct.pack('<Q', int(value))
elif type_tag == 'u128':
v = int(value)
return struct.pack('<QQ', v & 0xFFFFFFFFFFFFFFFF, v >> 64)
return encode_u128(int(value))
elif type_tag == 'bool':
return b'\x01' if value else b'\x00'
elif type_tag == 'AccountId':
if isinstance(value, str):
return bytes.fromhex(self.subtensor.substrate.ss58_decode(value))
return bytes(value)
elif type_tag == 'str':
data = value.encode('utf-8') if isinstance(value, str) else value
return compact_encode_len(len(data)) + data
return encode_str(value) if isinstance(value, str) else encode_bytes(value)
elif type_tag == 'vec_u64':
items = list(value)
encoded = compact_encode_len(len(items))
Expand All @@ -535,108 +532,58 @@ def encode_value(self, value, type_tag: str) -> bytes:
raise ValueError(f'Unsupported type: {type_tag}')

def extract_u32(self, data: bytes) -> Optional[int]:
if not data or len(data) < 4:
if not data or len(data) < U32_BYTES:
return None
return struct.unpack_from('<I', data, 0)[0]
return decode_u32(data, 0)[0]

def extract_u64(self, data: bytes) -> Optional[int]:
if not data or len(data) < 8:
if not data or len(data) < U64_BYTES:
return None
return struct.unpack_from('<Q', data, 0)[0]
return decode_u64(data, 0)[0]

def extract_u128(self, data: bytes) -> Optional[int]:
if not data or len(data) < 16:
if not data or len(data) < U128_BYTES:
return None
low = struct.unpack_from('<Q', data, 0)[0]
high = struct.unpack_from('<Q', data, 8)[0]
return low + (high << 64)
return decode_u128(data, 0)[0]

def extract_bool(self, data: bytes) -> Optional[bool]:
if not data:
return None
return data[0] != 0

def extract_account_id(self, data: bytes) -> Optional[str]:
if not data or len(data) < 32:
if not data or len(data) < ACCOUNT_ID_BYTES:
return None
return self.subtensor.substrate.ss58_encode(data[:32].hex())

def decode_string(self, data: bytes, offset: int) -> Tuple[str, int]:
"""Decode a SCALE compact-prefixed string. Returns (string, new_offset)."""
if offset >= len(data):
return '', offset
first = data[offset]
mode = first & 0x03
if mode == 0:
str_len = first >> 2
offset += 1
elif mode == 1:
if offset + 1 >= len(data):
return '', offset
str_len = (data[offset] | (data[offset + 1] << 8)) >> 2
offset += 2
else:
if offset + 3 >= len(data):
return '', offset
str_len = (
data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24)
) >> 2
offset += 4
if offset + str_len > len(data):
return '', offset
s = data[offset : offset + str_len].decode('utf-8', errors='replace')
return s, offset + str_len
return decode_account_id(data, 0)[0]

def decode_swap_data(self, data: bytes, offset: int = 0) -> Optional[Swap]:
"""Decode a SwapData struct from raw SCALE bytes."""
try:
o = offset

swap_id = struct.unpack_from('<Q', data, o)[0]
o += 8
user = self.subtensor.substrate.ss58_encode(data[o : o + 32].hex())
o += 32
miner = self.subtensor.substrate.ss58_encode(data[o : o + 32].hex())
o += 32
from_chain, o = self.decode_string(data, o)
to_chain, o = self.decode_string(data, o)
from_amount_lo = struct.unpack_from('<Q', data, o)[0]
o += 8
from_amount_hi = struct.unpack_from('<Q', data, o)[0]
o += 8
from_amount = from_amount_lo + (from_amount_hi << 64)
to_amount_lo = struct.unpack_from('<Q', data, o)[0]
o += 8
to_amount_hi = struct.unpack_from('<Q', data, o)[0]
o += 8
to_amount = to_amount_lo + (to_amount_hi << 64)
tao_amount_lo = struct.unpack_from('<Q', data, o)[0]
o += 8
tao_amount_hi = struct.unpack_from('<Q', data, o)[0]
o += 8
tao_amount = tao_amount_lo + (tao_amount_hi << 64)
user_from_address, o = self.decode_string(data, o)
user_to_address, o = self.decode_string(data, o)
miner_from_address, o = self.decode_string(data, o)
miner_to_address, o = self.decode_string(data, o)
rate, o = self.decode_string(data, o)
from_tx_hash, o = self.decode_string(data, o)
from_tx_block = struct.unpack_from('<I', data, o)[0]
o += 4
to_tx_hash, o = self.decode_string(data, o)
to_tx_block = struct.unpack_from('<I', data, o)[0]
o += 4
swap_id, o = decode_u64(data, o)
user, o = decode_account_id(data, o)
miner, o = decode_account_id(data, o)
from_chain, o = decode_string(data, o)
to_chain, o = decode_string(data, o)
from_amount, o = decode_u128(data, o)
to_amount, o = decode_u128(data, o)
tao_amount, o = decode_u128(data, o)
user_from_address, o = decode_string(data, o)
user_to_address, o = decode_string(data, o)
miner_from_address, o = decode_string(data, o)
miner_to_address, o = decode_string(data, o)
rate, o = decode_string(data, o)
from_tx_hash, o = decode_string(data, o)
from_tx_block, o = decode_u32(data, o)
to_tx_hash, o = decode_string(data, o)
to_tx_block, o = decode_u32(data, o)
status_byte = data[o]
o += 1
status = SwapStatus(status_byte) if status_byte <= 3 else SwapStatus.ACTIVE
initiated_block = struct.unpack_from('<I', data, o)[0]
o += 4
timeout_block = struct.unpack_from('<I', data, o)[0]
o += 4
fulfilled_block = struct.unpack_from('<I', data, o)[0]
o += 4
completed_block = struct.unpack_from('<I', data, o)[0]
o += 4
initiated_block, o = decode_u32(data, o)
timeout_block, o = decode_u32(data, o)
fulfilled_block, o = decode_u32(data, o)
completed_block, o = decode_u32(data, o)

return Swap(
id=swap_id,
Expand Down Expand Up @@ -863,7 +810,7 @@ def get_cooldown(self, from_address: str) -> Tuple[int, int]:
if data is None or len(data) < 5:
return (0, 0)
strike_count = data[0]
last_expired = struct.unpack_from('<I', data, 1)[0]
last_expired, _ = decode_u32(data, 1)
return (strike_count, last_expired)

def get_accumulated_fees(self) -> int:
Expand Down Expand Up @@ -912,18 +859,9 @@ def get_reservation_data(self, miner_hotkey: str) -> Optional[Tuple[int, int, in
if data[0] != 0x01:
return None
o = 1
# 3 x u128
tao_lo = struct.unpack_from('<Q', data, o)[0]
tao_hi = struct.unpack_from('<Q', data, o + 8)[0]
tao_amount = tao_lo + (tao_hi << 64)
o += 16
src_lo = struct.unpack_from('<Q', data, o)[0]
src_hi = struct.unpack_from('<Q', data, o + 8)[0]
from_amount = src_lo + (src_hi << 64)
o += 16
dst_lo = struct.unpack_from('<Q', data, o)[0]
dst_hi = struct.unpack_from('<Q', data, o + 8)[0]
to_amount = dst_lo + (dst_hi << 64)
tao_amount, o = decode_u128(data, o)
from_amount, o = decode_u128(data, o)
to_amount, _ = decode_u128(data, o)
return (tao_amount, from_amount, to_amount)

# =========================================================================
Expand Down
111 changes: 111 additions & 0 deletions allways/utils/scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""SCALE codec primitives shared by the contract client, event watcher, and axon handlers.

Only the subset of SCALE we need for the ink! swap manager contract — not a
general SCALE implementation. Streaming decoders return ``(value, new_offset)``
so compound structs can chain reads without manual offset bookkeeping.
"""

import struct
from typing import Tuple

from substrateinterface.utils.ss58 import ss58_encode

# SS58 prefix for Bittensor (matches substrate.ss58_format on all configured networks).
SS58_PREFIX = 42

# Byte widths of fixed-size SCALE primitives.
U32_BYTES = 4
U64_BYTES = 8
U128_BYTES = 16
ACCOUNT_ID_BYTES = 32


def strip_hex_prefix(s: str) -> str:
"""Remove a leading ``0x`` from a hex string if present."""
return s[2:] if s.startswith('0x') else s


def compact_encode_len(length: int) -> bytes:
"""SCALE compact-encode a length prefix."""
if length < 64:
return bytes([length << 2])
if length < 16384:
return bytes([((length << 2) | 1) & 0xFF, length >> 6])
return bytes(
[
((length << 2) | 2) & 0xFF,
(length >> 6) & 0xFF,
(length >> 14) & 0xFF,
(length >> 22) & 0xFF,
]
)


def encode_bytes(data: bytes) -> bytes:
"""SCALE-encode raw bytes as compact length prefix + bytes."""
return compact_encode_len(len(data)) + data


def encode_str(s: str) -> bytes:
"""SCALE-encode a UTF-8 string as compact length prefix + bytes."""
return encode_bytes(s.encode('utf-8'))


def encode_u128(value: int) -> bytes:
"""SCALE-encode a u128 as 16 little-endian bytes."""
return value.to_bytes(U128_BYTES, 'little')


# ─── Streaming decoders ────────────────────────────────────────────────────


def decode_u32(data: bytes, offset: int) -> Tuple[int, int]:
return struct.unpack_from('<I', data, offset)[0], offset + U32_BYTES


def decode_u64(data: bytes, offset: int) -> Tuple[int, int]:
return struct.unpack_from('<Q', data, offset)[0], offset + U64_BYTES


def decode_u128(data: bytes, offset: int) -> Tuple[int, int]:
lo = struct.unpack_from('<Q', data, offset)[0]
hi = struct.unpack_from('<Q', data, offset + U64_BYTES)[0]
return lo + (hi << 64), offset + U128_BYTES


def decode_bool(data: bytes, offset: int) -> Tuple[bool, int]:
return data[offset] != 0, offset + 1


def decode_account_id(data: bytes, offset: int) -> Tuple[str, int]:
raw = data[offset : offset + ACCOUNT_ID_BYTES]
return ss58_encode(raw, SS58_PREFIX), offset + ACCOUNT_ID_BYTES


def decode_string(data: bytes, offset: int) -> Tuple[str, int]:
"""SCALE-decode a compact-length-prefixed UTF-8 string.

Returns ``('', offset)`` on truncated or out-of-bounds input so callers
streaming composite structs degrade cleanly instead of raising.
"""
if offset >= len(data):
return '', offset
first = data[offset]
mode = first & 0x03
if mode == 0:
str_len = first >> 2
offset += 1
elif mode == 1:
if offset + 1 >= len(data):
return '', offset
str_len = (data[offset] | (data[offset + 1] << 8)) >> 2
offset += 2
else:
if offset + 3 >= len(data):
return '', offset
str_len = (data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24)) >> 2
offset += 4
if offset + str_len > len(data):
return '', offset
s = data[offset : offset + str_len].decode('utf-8', errors='replace')
return s, offset + str_len
Loading
Loading