In [1]:
import sys
sys.path.append("../")

from dataclasses import dataclass, field
from architecture.utils.id_management import generate_id
from architecture.problem.ProblemSolver import ProblemSolver, ProblemBehavior
from architecture.layer.IntepretationLayer import IntepretationLayer
from architecture.layer.ForwardLayer import ForwardLayer
from architecture.layer.RepresentationLayer import RepresentationLayer
from architecture.units.MemoryUnit import MemoryUnit
from architecture.units.CombineRepresent import CombineRepresent
from architecture.units.PropertyUnit import PropertyUnit
from architecture.units.CombineProperty import BooleanUtilization, ScaleUtilization, OptionUtilization
from architecture.units.manage import ReadOnlyUnit
from architecture.utils.list_operator import ReadOnlyList
from architecture.problem.vision.RepresentUnit import FlattenUnit
from torch import nn, no_grad, randn

In [2]:
class ProblemPerson(ProblemSolver):
    @dataclass
    class PropertyPerson:
        name : str = ""
        age : int = 0
        sex : bool = False
        role : str = ""
        extra : dict = field(default_factory=dict)

        @classmethod
        def from_dict(cls, data: dict):
            known_fields = {f.name for f in cls.__dataclass_fields__.values()}
            init_fields = {k: v for k, v in data.items() if k in known_fields}
            extra_fields = {k: v for k, v in data.items() if k not in known_fields}
            obj = cls(**init_fields)
            obj.extra = extra_fields
            return obj


    class NonCode(ProblemBehavior):
        def __init__(self, _id=None, phi_dim : int = None, *args, **kwargs):
            super().__init__(_id, *args, **kwargs)
            self.represent = RepresentationLayer(
                [FlattenUnit(metadata={}, phi_dim=phi_dim, num_heads=2, img_shape=(28, 28, 1)) for _ in range(4)]
            )
            self.mem = MemoryUnit(metadata={}, phi_dim=phi_dim, components=64)
            self._combine_represent = ForwardLayer(
                ReadOnlyUnit([CombineRepresent(metadata={}, mem_unit=self.mem, phi_dim=phi_dim, m_dim=4) for _ in range(3)])
            )
            self._properties = ForwardLayer(
                ReadOnlyUnit([PropertyUnit(metadata={}, components=64, phi_dim=phi_dim) for _ in range(4)])
            )
            self.interpretation = IntepretationLayer(
                [
                    ScaleUtilization(m_dim=4, phi_dim=phi_dim, property_name="age", metadata=dict()),
                    BooleanUtilization(m_dim=4, phi_dim=phi_dim, property_name="sex", metadata=dict()),
                    OptionUtilization(m_dim=4, phi_dim=phi_dim, property_name="role", metadata=dict(), options=ReadOnlyList(["teacher", "coder", "programmer"]))
                ]
            )

        def interpretation_layer(self):
            return self.interpretation

        def representation_layer(self):
            return self.represent
        
        def intepret(self, *args, **kwargs):
            with no_grad():
                r = self._forward(*args, **kwargs)
                raw = self.interpretation_layer().intepret(r)
            return raw

        def _forward(self, x, *args, **kwargs):
            r = self.representation_layer()(x, *args, **kwargs)
            r = self._combine_represent(r)
            r = r.mean(dim=1)
            r = self._properties(r)
            return r

        def forward(self, x, *args, **kwargs):
            r = self._forward(x, *args, **kwargs)
            out = self.interpretation_layer()(r, *args, **kwargs)
            return out
        
        def recognize(self, *args, **kwargs):
            pass
        
        def save(self, *args, **kwargs):
            pass
        
        @property
        def units(self):
            pass

    class Person:
        def __init__(self, *args, **kwargs):
            self._infor = ProblemPerson.PropertyPerson.from_dict(kwargs)
        
        def __str__(self):
            return f"{self._infor.name}, {self._infor.age}"
        
        def hello_world(self):
            print(f"Hello {self._infor.name}")

    def __init__(self, *args, **kwargs):
        _id = generate_id()
        non_coding_behavior = self.NonCode(_id, *args, **kwargs)
        super().__init__(non_coding_behavior=non_coding_behavior, *args, **kwargs)
    
    def as_entity(self, x, *args, **kwargs):
        raw = self.noncode.intepret(x, *args, **kwargs)
        assert self.satisfy_rules(raw)
        raw.update(**kwargs)
        return self.Person(**raw)

    def recognize(self, *args, **kwargs):
        pass

    def satisfy_rules(self, raw_property, *args, **kwargs):
        print(raw_property)
        return True
    
    def intepret(self, *args, **kwargs):
        pass

In [3]:
problem_solver = ProblemPerson(phi_dim=128)

In [4]:
x = randn(32, 28, 28, 1)

In [5]:
y = problem_solver(x)
y

{'age': tensor([[-0.4041],
         [-0.7225],
         [-0.8790],
         [-0.7698],
         [-1.8365],
         [-1.5981],
         [ 0.0817],
         [-0.4488],
         [-1.0494],
         [-0.4833],
         [-0.3567],
         [-0.6146],
         [-1.2225],
         [-1.3197],
         [-0.7765],
         [-0.8190],
         [-1.5614],
         [-1.1196],
         [-0.8656],
         [-0.6057],
         [-1.3940],
         [-0.6412],
         [-1.7297],
         [-0.5954],
         [-1.5787],
         [-0.9156],
         [-0.4824],
         [-0.9475],
         [-0.8723],
         [-0.8882],
         [-1.0135],
         [-1.0594]], grad_fn=<AddmmBackward0>),
 'sex': tensor([[0.4186],
         [0.4898],
         [0.3924],
         [0.4450],
         [0.4894],
         [0.4456],
         [0.3554],
         [0.4282],
         [0.4735],
         [0.4701],
         [0.4627],
         [0.4251],
         [0.4855],
         [0.4491],
         [0.5013],
         [0.4823],
         [0.40

In [4]:
z = randn(1, 28, 28, 1)

In [5]:
person = problem_solver.as_entity(z, name="Loc")

{'age': 0.6224932670593262, 'sex': True, 'sex.raw': 0.56653892993927, 'role': 'programmer', 'role.raw': 0.3963854908943176}


In [6]:
person.hello_world()

Hello Loc


In [7]:
print(person)

Loc, 0.6224932670593262


In [9]:
print(person._infor.extra)
print(person._infor.role)

{'sex.raw': 0.56653892993927, 'role.raw': 0.3963854908943176}
programmer
