In [1]:
import copy
import types
import logging

import hunter
import numpy

from cicada.arithmetic import Field
from cicada import transcript

In [2]:
# Emit bare messages (no level or logger-name prefix) so each logged line is a
# plain Python statement that can be pasted straight into a test.
logging.basicConfig(format="{message}", style="{", level=logging.INFO)

# The transcript is written through the root logger. Its level is also set
# explicitly, because basicConfig() is a no-op when the root logger already
# has handlers attached (as can happen inside a notebook kernel).
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# NOTE(review): propagate has no effect on the root logger (it has no parent).
logger.propagate = False

class LogExpressions(hunter.actions.Action):
    """Hunter action that logs traced calls as replayable Python statements.

    For every traced function that returns (and is not filtered out), one or
    more statements are written to ``logger`` at INFO level, typically of the
    form ``cicada.transcript.assert_equal(<call>, <result>)``, followed by a
    blank line. Arguments are captured at the "call" event, so functions that
    mutate their inputs (such as ``Field.inplace_add``) can be logged with the
    inputs they actually received.
    """

    # Fully-qualified names of functions whose calls are never logged.
    _HIDDEN_FUNCTIONS = frozenset({
        "cicada.arithmetic.Field.bytes",
        "cicada.arithmetic.Field.dtype",
        "cicada.arithmetic.Field.order",
    })

    def __init__(self):
        # One entry per traced call still in progress: a snapshot of the
        # callee's locals (i.e. its arguments) taken at the "call" event and
        # popped again at the matching "return" event.
        self.stack = []

    @staticmethod
    def _snapshot(frame_locals):
        """Return an independent copy of a frame's locals, deep when possible."""
        # Convert to a plain dict first: the frame-locals object itself may not
        # be copyable (e.g. a proxy object on newer Pythons).
        plain = dict(frame_locals)
        try:
            return copy.deepcopy(plain)
        except (TypeError, copy.Error):
            # Some objects (locks, open files, ...) cannot be deep-copied. Fall
            # back to a shallow copy so the call/return stack stays balanced.
            return plain

    def __call__(self, event):
        """Handle one hunter ``call`` or ``return`` event.

        Raises:
            NotImplementedError: for a logged function that is neither
                special-cased nor takes a ``self`` argument.
        """
        if not hasattr(event.function_object, "__qualname__"):
            return

        if event.kind == "call":
            self.stack.append(self._snapshot(event.locals))
            return

        if event.kind != "return":
            return

        # A return without a matching call (e.g. a frame that was already
        # running when tracing started) has no captured arguments to log.
        if not self.stack:
            return
        args = self.stack.pop()

        name = event.function_object.__name__
        fqname = event.module + "." + event.function_object.__qualname__
        frame_locals = event.locals
        result = event.arg

        # Hide __init__ functions.
        if name == "__init__":
            return

        # Hide private functions.
        if name.startswith("_") and not name.startswith("__"):
            return

        # Hide unimportant functions.
        if fqname in self._HIDDEN_FUNCTIONS:
            return

        if fqname == "cicada.arithmetic.Field.inplace_add":
            # lhs is modified in place: log its value before the call (from the
            # captured arguments) and after it (from the frame at return time).
            o = self.repr(args["self"])
            lhs = self.repr(args["lhs"])
            rhs = self.repr(args["rhs"])
            updated = self.repr(frame_locals["lhs"])

            logger.info(f"lhs = {lhs}")
            logger.info(f"{o}.{name}(lhs=lhs, rhs={rhs})")
            logger.info(f"cicada.transcript.assert_equal(lhs, {updated})")

        elif fqname == "cicada.arithmetic.Field.uniform":
            # When the deep copy at call time succeeded, this is the generator's
            # state *before* sampling, so restoring it should replay the draw.
            o = self.repr(args["self"])
            size = self.repr(args["size"])
            bit_generator = args["generator"].bit_generator
            bg = self.repr(bit_generator)
            state = self.repr(bit_generator.state)
            result = self.repr(result)

            logger.info(f"bg = {bg}")
            logger.info(f"bg.state = {state}")
            logger.info(f"cicada.transcript.assert_equal({o}.{name}(size={size}, generator=numpy.random.Generator(bg)), {result})")

        elif "self" in args:
            o = self.repr(args["self"])
            signature = ", ".join(
                f"{key}={self.repr(value)}"
                for key, value in args.items()
                if key != "self"
            )
            result = self.repr(result)

            logger.info(f"cicada.transcript.assert_equal({o}.{name}({signature}), {result})")

        else:
            raise NotImplementedError(f"Don't know how to log a call to {fqname}.")

        logger.info("")

    def repr(self, o):
        """Return Python source text that evaluates to an equivalent of ``o``.

        numpy arrays are rendered with their contents and dtype. numpy bit
        generators are rendered as a fresh constructor call without their
        state (the ``uniform`` branch logs the state separately). Anything else
        falls back to the builtin ``repr``.
        """
        if isinstance(o, numpy.ndarray):
            # ``dtype=object`` is valid source as-is; other dtypes are spelled as
            # a string (e.g. 'int64') because a bare ``int64`` would be undefined.
            dtype = "object" if o.dtype == object else repr(str(o.dtype))
            return f"numpy.array({o.tolist()}, dtype={dtype})"
        if isinstance(o, numpy.random.Generator):
            return f"numpy.random.Generator({self.repr(o.bit_generator)})"
        if isinstance(o, numpy.random.BitGenerator):
            cls_name = type(o).__name__
            # Only use the short spelling for numpy's own public bit generators.
            if getattr(numpy.random, cls_name, None) is type(o):
                return f"numpy.random.{cls_name}()"
        return repr(o)


# Exercise a few Field operations under a hunter trace restricted to the
# cicada package; LogExpressions turns each traced call into transcript lines.
with hunter.trace(
    module_startswith="cicada",
    kind_in=("call", "return"),
    action=LogExpressions(),
):
    f = Field(order=127)
    a = f.ones(shape=3)
    b = f.uniform(size=3, generator=numpy.random.default_rng())
    f.inplace_add(lhs=a, rhs=b)

cicada.transcript.assert_equal(cicada.arithmetic.Field(order=127).ones(shape=3), numpy.array([1, 1, 1], dtype=object))

bg = numpy.random.PCG64()
bg.state = {'bit_generator': 'PCG64', 'state': {'state': 230203933409659715196253238105416408377, 'inc': 181113771745324753386854769399717100255}, 'has_uint32': 0, 'uinteger': 0}
cicada.transcript.assert_equal(cicada.arithmetic.Field(order=127).uniform(size=3, generator=numpy.random.Generator(bg)), numpy.array([42, 58, 227], dtype=object))

lhs = numpy.array([1, 1, 1], dtype=object)
cicada.arithmetic.Field(order=127).inplace_add(lhs=lhs, rhs=numpy.array([42, 58, 227], dtype=object))
cicada.transcript.assert_equal(lhs, numpy.array([43, 59, 101], dtype=object))

