In [111]:
%load_ext autoreload
%autoreload 2

In [1]:
import json
import qat

from qat.utils.pydantic import WarnOnExtraFieldsModel

from pydantic import ConfigDict, Field, model_validator, field_serializer, ValidationInfo, field_validator, PlainSerializer
from typing import Dict, List, Tuple, Union, Optional
import typing 
from typing import Annotated

import numpy as np

In [2]:
# Component identifiers are plain strings.
componentid = str


class HasId(WarnOnExtraFieldsModel):
    # Minimal model identified by a string ``id``. The name ``HasId`` is rebound
    # to a richer, context-aware class in a later cell; anything evaluated before
    # that rebinding (e.g. eagerly evaluated annotations) still sees this one.
    id: componentid


T = typing.TypeVar('T')


def _dict_values_to_ids(components):
    """Serialise a ``{key: component}`` mapping as ``{key: component.id}``."""
    return dict((key, component.id) for key, component in components.items())


def _list_items_to_ids(components):
    """Serialise a sequence of components as the list of their ids."""
    ids = []
    for component in components:
        ids.append(component.id)
    return ids


# Fields typed with these aliases hold full component objects in memory but are
# always dumped as ids only: dicts as {key: id}, lists as [id, ...].
SerializeDictToIds = PlainSerializer(_dict_values_to_ids, return_type=dict, when_used='always')
ComponentDict = Annotated[Dict[componentid, T], SerializeDictToIds]

SerializeListToIds = PlainSerializer(_list_items_to_ids, return_type=list, when_used='always')
ComponentList = Annotated[List[T], SerializeListToIds]

In [3]:
def get_field_type_and_container(cls, field):
    """Split a model field's annotation into its container and element type.

    Args:
        cls: a pydantic model class (anything exposing ``model_fields``).
        field: name of the field to inspect.

    Returns:
        ``(container, element_type)`` where ``container`` is
        ``typing.get_origin`` of the annotation (``list``, ``dict``,
        ``typing.Union``, ``None`` for a bare type, ...). ``element_type`` is
        the item type for a ``list`` or the value type for a ``dict``, and
        ``None`` for any other container.
    """
    annotation = cls.model_fields[field].annotation
    origin = typing.get_origin(annotation)
    # Position of the element type inside the generic's arguments:
    # List[X] -> X is first, Dict[K, V] -> V is second.
    element_position = {list: 0, dict: 1}.get(origin)
    if element_position is None:
        return origin, None
    return origin, typing.get_args(annotation)[element_position]

def get_or_create(cls, context, data : "str | dict | HasId"):
    """Resolve one serialised component reference into a component instance.

    The annotation is a string so it is resolved lazily and names the final
    ``HasId`` class, not the placeholder defined in an earlier cell.

    Args:
        cls: model class used to build the component when ``data`` is a dict.
        context: mapping of id -> already-built component.
        data: an id string (looked up in ``context``), an existing ``HasId``
            instance (returned unchanged), or a dict of constructor kwargs
            (built with ``cls``, sharing ``context``).

    Returns:
        The resolved component.

    Raises:
        KeyError: ``data`` is an id string that is not present in ``context``.
    """
    if isinstance(data, str):
        try:
            return context[data]
        except KeyError:
            # Fail here with a clear message instead of returning None, which
            # would only blow up later and far from the bad reference.
            raise KeyError(
                f'Unknown component id {data!r}: it must be defined before it is referenced'
            ) from None
    if isinstance(data, HasId):
        return data
    return cls(context=context, **data)

class HasId(WarnOnExtraFieldsModel):
    """A model identified by a string ``id`` that can be rebuilt from a dump.

    Fields annotated as a ``list`` or ``dict`` of ``HasId`` subclasses may be
    supplied as already-built instances, as nested dicts (constructed on the
    fly), or as id strings (looked up in a shared ``context``). Every resolved
    component is registered in ``context``, so a field can refer by id to
    components that appeared in an earlier field.
    """
    id: str

    def __init__(self, context = None, *args, **kw):
        # Resolve component references before pydantic validates the fields.
        kw = self.populate(context=context, **kw)
        super().__init__(*args, **kw)

    @classmethod
    def populate(cls, context = None, **data):
        '''Resolve component references in ``data`` ready for validation.

        Fields are visited in ``model_fields`` order. For each field whose
        annotation is a ``list``/``dict`` of ``HasId`` subclasses, every element
        goes through ``get_or_create`` and the results are added to ``context``
        under their ids.

        Args:
            context: mapping of id -> component shared across one load; a new
                dict is used when None.
            **data: keyword arguments destined for the model constructor.

        Returns:
            ``data`` with component containers resolved.

        Raises:
            TypeError: a component field holds a value whose container type
                does not match its annotation.
        '''
        if context is None:
            context = {}

        for field in cls.model_fields:
            # Absent fields are left for pydantic to fill from defaults or report.
            if field not in data:
                continue
            container, field_cls = get_field_type_and_container(cls, field)
            # field_cls is None for non-list/dict annotations (Optional, Tuple, ...)
            # and may be a generic alias such as List[int]; issubclass needs a class.
            if not (isinstance(field_cls, type) and issubclass(field_cls, HasId)):
                continue

            value = data[field]
            if container is dict and isinstance(value, dict):
                resolved = {k: get_or_create(field_cls, context, v) for (k, v) in value.items()}
                new_context_items = resolved.values()
            elif container is list and isinstance(value, list):
                resolved = [get_or_create(field_cls, context, d) for d in value]
                new_context_items = resolved
            else:
                found_type = type(value)
                raise TypeError(f'Container is {container}, found {found_type}')

            data[field] = resolved
            # Make this field's components available to later fields by id.
            context.update({new.id: new for new in new_context_items})
        return data

In [4]:
class At(HasId):
    """Leaf component: an ``id`` plus an integer payload; holds no component references."""
    x: int

class Bt(HasId):
    """Component referencing ``At`` components.

    ``As`` is a ``ComponentDict``, so it is dumped as ``{key: At.id}``. On load,
    the ids are looked up in the context shared through ``HasId.populate``.
    """
    x: int
    # Dumped as ids only; the referenced At objects must be loaded first.
    As: ComponentDict[At]

class Ct(HasId):
    """Component referencing both ``At`` (by key) and ``Bt`` (as an ordered list).

    ``As`` is dumped as ``{key: At.id}`` and ``Bs`` as ``[Bt.id, ...]``. On load,
    both are resolved through the shared context in ``HasId.populate``.
    """
    x: int
    As: ComponentDict[At]
    Bs: ComponentList[Bt]

class Dt(HasId):
    """Top-layer component referencing ``Ct`` components; dumped as ``{key: Ct.id}``."""
    x: int
    Cs: ComponentDict[Ct]

In [5]:
class Outer(HasId):
    """Container holding every component layer in full.

    These fields are plain ``List`` annotations with no id serialiser, so each
    layer is dumped as complete nested dicts. ``HasId.populate`` walks fields in
    declaration order and registers each layer's components in the shared
    context. A layer must therefore be declared after every layer it references
    by id.
    """
    # Order is load-bearing: B refers to A, C to A and B, D to C. Do not reorder.
    A: List[At]
    B: List[Bt]
    C: List[Ct]
    D: List[Dt]

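The component types form a DAG: `Bt` references `At`, `Ct` references `At` and `Bt`, and `Dt` references `Ct`. This means a model can be serialised in layers. Each layer is dumped in full, while references to other components are stored as ids. On load, the layers are rebuilt in dependency order, so every id can be resolved.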

In [6]:
# Seeded generator so the random fixtures below are identical on every Restart & Run All.
RNG_SEED = 0
_rng = np.random.default_rng(RNG_SEED)


def pick(L, size=3, rng=None):
    """Pick ``size`` distinct components from ``L`` at random, keyed by id.

    Sampling is without replacement. With replacement, repeated draws would
    collapse onto the same id key and silently return fewer than ``size``
    entries. Indices are sampled, rather than the objects themselves, so the
    components never pass through a numpy object array.

    Args:
        L: sequence of components exposing an ``id`` attribute.
        size: number of distinct components to pick.
        rng: numpy ``Generator``; defaults to the seeded module-level one.

    Returns:
        ``{component.id: component}`` with exactly ``size`` entries.

    Raises:
        ValueError: ``size`` exceeds ``len(L)``.
    """
    rng = _rng if rng is None else rng
    indices = rng.choice(len(L), size=size, replace=False)
    return {L[i].id: L[i] for i in indices}

# Fixture sizes per layer, and how many components each item references.
N_A, N_B, N_C, N_D = 10, 10, 10, 5
REFS_PER_ITEM = 3

# Layers are built bottom-up so every reference points at an existing component.
A = [At(x=n, id=f'A{n}') for n in range(N_A)]
B = [
    Bt(x=n, id=f'B{n}', As=pick(A, REFS_PER_ITEM))
    for n in range(N_B)
]
C = [
    Ct(
        x=n,
        id=f'C{n}',
        As=pick(A, REFS_PER_ITEM),
        Bs=[*pick(B, REFS_PER_ITEM).values()],
    )
    for n in range(N_C)
]
D = [
    Dt(x=n, id=f'D{n}', Cs=pick(C, REFS_PER_ITEM))
    for n in range(N_D)
]

In [8]:
# Round trip 1: python-mode dump -> reload. Component references inside
# Bt/Ct/Dt are dumped as ids and re-resolved from the layers listed before them.
O1 = Outer(A=A, B=B, C=C, D=D, id='outer')
blob = O1.model_dump()
O2 = Outer(**blob)
assert O2 == O1, 'python-mode dump/load round trip changed the model'

# Round trip 2: the same model through a JSON string.
O3 = Outer(**json.loads(O1.model_dump_json()))
assert O3 == O1, 'JSON dump/load round trip changed the model'