Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codecs (part 14): abi-based encoding & decoding #63

Merged
merged 4 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 245 additions & 0 deletions multiversx_sdk/abi/abi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, cast

from multiversx_sdk.abi.abi_definition import (AbiDefinition,
EndpointDefinition,
EnumDefinition,
ParameterDefinition,
StructDefinition)
from multiversx_sdk.abi.address_value import AddressValue
from multiversx_sdk.abi.biguint_value import BigUIntValue
from multiversx_sdk.abi.bool_value import BoolValue
from multiversx_sdk.abi.bytes_value import BytesValue
from multiversx_sdk.abi.enum_value import EnumValue
from multiversx_sdk.abi.fields import Field
from multiversx_sdk.abi.interface import IPayloadHolder
from multiversx_sdk.abi.list_value import ListValue
from multiversx_sdk.abi.multi_value import MultiValue
from multiversx_sdk.abi.option_value import OptionValue
from multiversx_sdk.abi.optional_value import OptionalValue
from multiversx_sdk.abi.serializer import Serializer
from multiversx_sdk.abi.small_int_values import *
from multiversx_sdk.abi.string_value import StringValue
from multiversx_sdk.abi.struct_value import StructValue
from multiversx_sdk.abi.token_identifier_value import TokenIdentifierValue
from multiversx_sdk.abi.tuple_value import TupleValue
from multiversx_sdk.abi.type_formula import TypeFormula
from multiversx_sdk.abi.type_formula_parser import TypeFormulaParser
from multiversx_sdk.abi.variadic_values import VariadicValues
from multiversx_sdk.core.constants import ARGS_SEPARATOR


class Abi:
def __init__(self, definition: AbiDefinition) -> None:
self._type_formula_parser = TypeFormulaParser()
self._serializer = Serializer(parts_separator=ARGS_SEPARATOR)

self.definition = definition
self.custom_types_prototypes_by_name: Dict[str, Any] = {}
self.endpoints_prototypes_by_name: Dict[str, EndpointPrototype] = {}

for name in definition.types.enums:
self.custom_types_prototypes_by_name[name] = self._create_custom_type_prototype(name)

for struct_type in definition.types.structs:
self.custom_types_prototypes_by_name[struct_type] = self._create_custom_type_prototype(struct_type)

self.constructor_prototype = EndpointPrototype(
input_parameters=self._create_endpoint_input_prototypes(definition.constructor),
output_parameters=self._create_endpoint_output_prototypes(definition.constructor)
)

self.upgrade_constructor_prototype = EndpointPrototype(
input_parameters=self._create_endpoint_input_prototypes(definition.upgrade_constructor),
output_parameters=self._create_endpoint_output_prototypes(definition.upgrade_constructor)
)

for endpoint in definition.endpoints:
input_prototype = self._create_endpoint_input_prototypes(endpoint)
output_prototype = self._create_endpoint_output_prototypes(endpoint)

endpoint_prototype = EndpointPrototype(
input_parameters=input_prototype,
output_parameters=output_prototype
)

self.endpoints_prototypes_by_name[endpoint.name] = endpoint_prototype

def _create_custom_type_prototype(self, name: str) -> Any:
if name in self.definition.types.enums:
definition = self.definition.types.enums[name]
return self._create_enum_prototype(definition)
if name in self.definition.types.structs:
definition = self.definition.types.structs[name]
return self._create_struct_prototype(definition)

raise ValueError(f"cannot create prototype for custom type {name} not found")

def _create_enum_prototype(self, enum_definition: EnumDefinition) -> Any:
return EnumValue(fields_provider=lambda discriminant: self._provide_fields_for_enum_prototype(discriminant, enum_definition))

def _provide_fields_for_enum_prototype(self, discriminant: int, enum_definition: EnumDefinition) -> List[Field]:
for variant in enum_definition.variants:
if variant.discriminant != discriminant:
continue

fields_prototypes: List[Field] = []

for field_definition in variant.fields:
type_formula = self._type_formula_parser.parse_expression(field_definition.type)
field_value_prototype = self._create_prototype(type_formula)
field_prototype = Field(name=field_definition.name, value=field_value_prototype)
fields_prototypes.append(field_prototype)

return fields_prototypes

raise ValueError(f"cannot provide fields from enum {enum_definition.name}: variant with discriminant {discriminant} not found")

def _create_struct_prototype(self, struct_definition: StructDefinition) -> Any:
fields_prototypes: List[Field] = []

for field_definition in struct_definition.fields:
type_formula = self._type_formula_parser.parse_expression(field_definition.type)
field_value_prototype = self._create_prototype(type_formula)
field_prototype = Field(name=field_definition.name, value=field_value_prototype)
fields_prototypes.append(field_prototype)

return StructValue(fields_prototypes)

def _create_endpoint_input_prototypes(self, endpoint: EndpointDefinition) -> List[Any]:
prototypes: List[Any] = []

for parameter in endpoint.inputs:
parameter_prototype = self._create_parameter_prototype(parameter)
prototypes.append(parameter_prototype)

return prototypes

def _create_endpoint_output_prototypes(self, endpoint: EndpointDefinition) -> List[Any]:
prototypes: List[Any] = []

for parameter in endpoint.outputs:
parameter_prototype = self._create_parameter_prototype(parameter)
prototypes.append(parameter_prototype)

return prototypes

def _create_parameter_prototype(self, parameter: ParameterDefinition) -> Any:
type_formula = self._type_formula_parser.parse_expression(parameter.type)
return self._create_prototype(type_formula)

def encode_constructor_input_parameters(self, values: List[Any]) -> List[bytes]:
return self._do_encode_endpoint_input_parameters("constructor", self.constructor_prototype, values)

def encode_upgrade_constructor_input_parameters(self, values: List[Any]) -> List[bytes]:
return self._do_encode_endpoint_input_parameters("upgrade", self.upgrade_constructor_prototype, values)

def encode_endpoint_input_parameters(self, endpoint_name: str, values: List[Any]) -> List[bytes]:
endpoint_prototype = self._get_endpoint_prototype(endpoint_name)
return self._do_encode_endpoint_input_parameters(endpoint_name, endpoint_prototype, values)

def _do_encode_endpoint_input_parameters(self, endpoint_name: str, endpoint_prototype: 'EndpointPrototype', values: List[Any]):
if len(values) != len(endpoint_prototype.input_parameters):
raise ValueError(f"for {endpoint_name}, invalid value length: expected {len(endpoint_prototype.input_parameters)}, got {len(values)}")

input_values = deepcopy(endpoint_prototype.input_parameters)
input_values_as_native_object_holders = cast(List[IPayloadHolder], input_values)

# Populate the input values with the provided arguments
for i, arg in enumerate(values):
input_values_as_native_object_holders[i].set_payload(arg)

input_values_encoded = self._serializer.serialize_to_parts(input_values)
return input_values_encoded

def decode_endpoint_output_parameters(self, endpoint_name: str, encoded_values: List[bytes]) -> List[Any]:
endpoint_prototype = self._get_endpoint_prototype(endpoint_name)
output_values = deepcopy(endpoint_prototype.output_parameters)
self._serializer.deserialize_parts(encoded_values, output_values)

output_values_as_native_object_holders = cast(List[IPayloadHolder], output_values)
output_native_values = [value.get_payload() for value in output_values_as_native_object_holders]
return output_native_values

def _get_custom_type_prototype(self, type_name: str) -> Any:
type_prototype = self.custom_types_prototypes_by_name.get(type_name)

if not type_prototype:
raise ValueError(f"custom type '{type_name}' not found")

return type_prototype

def _get_endpoint_prototype(self, endpoint_name: str) -> 'EndpointPrototype':
endpoint_prototype = self.endpoints_prototypes_by_name.get(endpoint_name)

if not endpoint_prototype:
raise ValueError(f"endpoint '{endpoint_name}' not found")

return endpoint_prototype

def _create_prototype(self, type_formula: TypeFormula) -> Any:
name = type_formula.name

if name == "bool":
return BoolValue()
if name == "u8":
return U8Value()
if name == "u16":
return U16Value()
if name == "u32":
return U32Value()
if name == "u64":
return U64Value()
if name == "i8":
return I8Value()
if name == "i16":
return I16Value()
if name == "i32":
return I32Value()
if name == "BigUint":
return BigUIntValue()
if name == "BigInt":
return BigUIntValue()
if name == "bytes":
return BytesValue()
if name == "utf-8 string":
return StringValue()
if name == "Address":
return AddressValue()
if name == "TokenIdentifier":
return TokenIdentifierValue()
if name == "CodeMetadata":
return BytesValue()
if name == "tuple":
return TupleValue([self._create_prototype(type_parameter) for type_parameter in type_formula.type_parameters])
if name == "Option":
type_parameter = type_formula.type_parameters[0]
return OptionValue(self._create_prototype(type_parameter))
if name == "List":
type_parameter = type_formula.type_parameters[0]
return ListValue([], item_creator=lambda: self._create_prototype(type_parameter))
if name == "optional":
# The prototype of an optional is provided a value (the placeholder).
type_parameter = type_formula.type_parameters[0]
return OptionalValue(self._create_prototype(type_parameter))
if name == "variadic":
type_parameter = type_formula.type_parameters[0]
return VariadicValues([], item_creator=lambda: self._create_prototype(type_parameter))
if name == "multi":
return MultiValue([self._create_prototype(type_parameter) for type_parameter in type_formula.type_parameters])

# Handle custom types
type_prototype = self._get_custom_type_prototype(name)
return deepcopy(type_prototype)

@classmethod
def load(cls, path: Path) -> 'Abi':
definition = AbiDefinition.load(path)
return cls(definition)


class EndpointPrototype:
def __init__(self, input_parameters: List[Any], output_parameters: List[Any]) -> None:
self.input_parameters = input_parameters
self.output_parameters = output_parameters
Loading
Loading