In [108]:
from enum import Enum
from pydantic import BaseModel, Field
from typing import Any, Literal, Optional, Union, List, TypeVar, Generic, Type, Dict

class BoundingBox(BaseModel):
    x1: float = Field(..., description="X-coordinate of the top-left corner")
    y1: float = Field(..., description="Y-coordinate of the top-left corner")
    x2: float = Field(..., description="X-coordinate of the bottom-right corner")
    y2: float = Field(..., description="Y-coordinate of the bottom-right corner")


# ==============

class MediaType(str, Enum):
    PDF_IMAGE = "pdf_image"
    RAW_IMAGE = "raw_image"
    PDF_TEXT = "pdf_text"
    VIDEO_TRANSCRIPT = "video_transcript"
    WEBSITE_QA = "website_qa"
    WEBSITE_EXPERIENCE = "website_experience"

class DocumentChunkRegistry:
    _registry: Dict[MediaType, Type['BaseDocumentChunk']] = {}
    _properties: List[str] = []
    @classmethod
    def register(cls, chunk_class: Type['BaseDocumentChunk']):
        if ("media_type" in chunk_class.model_fields):
            media_type = chunk_class.model_fields['media_type'].default
            cls._registry[media_type] = chunk_class

        cls._properties.extend([prop for prop in list(chunk_class.model_fields.keys()) if prop not in cls._properties and prop != "metadata"])

        return chunk_class

    @classmethod
    def get(cls, media_type: MediaType) -> Type['BaseDocumentChunk']:
        return cls._registry.get(media_type)


    @classmethod
    def all(cls) -> List[Type['BaseDocumentChunk']]:
        return list(cls._registry.values())

    @classmethod
    def get_properties(cls) -> List[Type['BaseDocumentChunk']]:
        return cls._properties

    @classmethod
    def media_types(cls) -> List[MediaType]:
        return list(cls._registry.keys())

def register_document_chunk(cls: Type['BaseDocumentChunk']) -> Type['BaseDocumentChunk']:
    return DocumentChunkRegistry.register(cls)

class MetadataRegistry:
    _properties: List[str] = []

    @classmethod
    def register(cls, metadata_class: Type['BaseModel']):
        cls._properties.extend(list(metadata_class.model_fields.keys()))
        return metadata_class

    @classmethod
    def get(cls, media_type: MediaType) -> Type['BaseModel']:
        return cls._properties.get(media_type)

    @classmethod
    def get_properties(cls) -> List[Type['BaseModel']]:
        return cls._properties

def register_metadata(cls: Type['BaseModel']) -> Type['BaseModel']:
    return MetadataRegistry.register(cls)

# ==============


@register_document_chunk
class Document(BaseModel):
    uuid: str
    document_id: str
    public_path: Optional[str] = Field(default="")
    original_path: str
    s3_object_name: Optional[str] = Field(default="")
    media_name: str


class BaseDocumentChunk(BaseModel):
    uuid: str
    document: Document
    text: str = ""
    title: str
    media_type: MediaType
    metadata: Union[dict, BaseModel]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        DocumentChunkRegistry.register(cls)

    @classmethod
    def model_validate(cls, obj: Any, *args, **kwargs):
        if isinstance(obj, dict):
            metadata = {}
            for key in list(obj.keys()):
                if key.startswith('meta_'):
                    metadata[key[5:]] = obj.pop(key)
            if metadata:
                obj['metadata'] = metadata
        return super().model_validate(obj, *args, **kwargs)

@register_metadata
class VideoTranscriptMetadata(BaseModel):
    start: float
    end: float

@register_document_chunk
class VideoTranscriptChunk(BaseDocumentChunk):
    media_type: MediaType = MediaType.VIDEO_TRANSCRIPT
    metadata: VideoTranscriptMetadata

@register_metadata
class PdfImageMetadata(BaseModel):
    s3_object_name: str
    page_number: int
    bbox: BoundingBox

@register_document_chunk
class PdfImageChunk(BaseDocumentChunk):
    media_type: MediaType = MediaType.PDF_IMAGE
    metadata: PdfImageMetadata

@register_metadata
class RawImageMetadata(BaseModel):
    s3_object_name: str

@register_document_chunk
class RawImageChunk(BaseDocumentChunk):
    media_type: MediaType = MediaType.RAW_IMAGE
    metadata: RawImageMetadata

@register_metadata
class PdfTextMetadata(BaseModel):
    page_number: int
    bbox: BoundingBox

@register_document_chunk
class PdfTextChunk(BaseDocumentChunk):
    media_type: MediaType = MediaType.PDF_TEXT
    metadata: PdfTextMetadata

@register_metadata
class WebsiteQAMetadata(BaseModel):
    question: str
    answer: str
    url: str

@register_document_chunk
class WebsiteQAChunk(BaseDocumentChunk):
    media_type: MediaType = MediaType.WEBSITE_QA
    metadata: WebsiteQAMetadata

@register_metadata
class WebsiteExperienceMetadata(BaseModel):
    description: str
    title: str
    type: str
    url: str

@register_document_chunk
class WebsiteExperienceChunk(BaseDocumentChunk):
    media_type: MediaType = MediaType.WEBSITE_EXPERIENCE
    metadata: WebsiteExperienceMetadata

DocumentChunkMetadata = Union[RawImageMetadata, PdfImageMetadata, PdfTextMetadata, VideoTranscriptMetadata, WebsiteQAMetadata, WebsiteExperienceMetadata]
DocumentChunk = Union[RawImageChunk, PdfImageChunk, PdfTextChunk, VideoTranscriptChunk, WebsiteQAChunk, WebsiteExperienceChunk]

ChunkType = TypeVar('ChunkType', bound=BaseDocumentChunk)

class ChunkWithScore(Generic[ChunkType], BaseDocumentChunk):
    metadata: DocumentChunkMetadata
    score: float

    @classmethod
    def model_validate(cls, obj: Any, *args, **kwargs):
        if isinstance(obj, dict) and 'metadata' in obj:
            chunk_type = DocumentChunkRegistry.get(MediaType(obj['media_type']))
            if chunk_type:
                metadata_type = chunk_type.__annotations__['metadata']
                obj['metadata'] = metadata_type.model_validate(obj['metadata'])
        return super().model_validate(obj, *args, **kwargs)

class DocumentWithChunks(Document):
    chunks: List[DocumentChunk]

  class ChunkWithScore(Generic[ChunkType], BaseDocumentChunk):


In [110]:
# print(dir(DocumentChunkRegistry.all()[1].model_fields.get('metadata')))
# DocumentChunkRegistry.all()[1].model_fields.get('metadata
MetadataRegistry.get_properties()
DocumentChunkRegistry.get_properties()

['uuid',
 'document_id',
 'public_path',
 'original_path',
 's3_object_name',
 'media_name',
 'document',
 'text',
 'title',
 'media_type']

In [28]:
document =  {
    "uuid": "32782090-acb0-44b6-9604-b6239a561aa2",
    "document_id": "70cc4c0c-7d44-4ac7-a241-a0fb8820e380",
    "public_path": "",
    "original_path": "/home/erwan/Desktop/clients/ScienceInfuse/server/ftp-data/2015 - Dinosaures.pdf",
    "s3_object_name": "pdf/70cc4c0c-7d44-4ac7-a241-a0fb8820e380.pdf",
    "media_name": "2015 - Dinosaures"
}
properties = {
    "uuid": "50ebbdfc-0342-475a-9c2b-1a169985e8b6",
    "text": "Dessin d'un dinosaure. Le dinosaure est brun. Il y a beaucoup de trous sur la tête du dinosaure. Il y a une longue queue sur le dinosaure.",
    "title": "",
    "media_type": "pdf_image",
    "metadata": {
        "s3_object_name": "pdf/70cc4c0c-7d44-4ac7-a241-a0fb8820e380.pdf/images/35a0e77e-5c40-49fa-b9df-ad3bcb7bf50f.png",
        "page_number": 10,
        "bbox": {
            "x1": 0.0,
            "y1": 0.0,
            "x2": 1.0,
            "y2": 1.0
        }
    },
    "score": 0.44338053464889526
}

chunk_with_score = ChunkWithScore.model_validate({**properties, "score": 1, "document": document, "uuid": "uuid"})
print(f"Final metadata type: {type(chunk_with_score.metadata)}")
print(f"Final metadata: {chunk_with_score.metadata}")


Final metadata type: <class '__main__.PdfImageMetadata'>
Final metadata: s3_object_name='pdf/70cc4c0c-7d44-4ac7-a241-a0fb8820e380.pdf/images/35a0e77e-5c40-49fa-b9df-ad3bcb7bf50f.png' page_number=10 bbox=BoundingBox(x1=0.0, y1=0.0, x2=1.0, y2=1.0)


In [9]:
document =  {
    "uuid": "32782090-acb0-44b6-9604-b6239a561aa2",
    "document_id": "70cc4c0c-7d44-4ac7-a241-a0fb8820e380",
    "public_path": "",
    "original_path": "/home/erwan/Desktop/clients/ScienceInfuse/server/ftp-data/2015 - Dinosaures.pdf",
    "s3_object_name": "pdf/70cc4c0c-7d44-4ac7-a241-a0fb8820e380.pdf",
    "media_name": "2015 - Dinosaures"
}
properties = {
    "uuid": "50ebbdfc-0342-475a-9c2b-1a169985e8b6",
    "text": "Dessin d'un dinosaure. Le dinosaure est brun. Il y a beaucoup de trous sur la tête du dinosaure. Il y a une longue queue sur le dinosaure.",
    "title": "",
    "media_type": "pdf_image",
    "metadata": {
        "s3_object_name": "pdf/70cc4c0c-7d44-4ac7-a241-a0fb8820e380.pdf/images/35a0e77e-5c40-49fa-b9df-ad3bcb7bf50f.png",
        "page_number": 10,
        "bbox": {
            "x1": 0.0,
            "y1": 0.0,
            "x2": 1.0,
            "y2": 1.0
        }
    },
    "score": 0.44338053464889526
}
chunk_with_score = ChunkWithScore.model_validate({**properties, "score": 1, "document": document, "uuid": "uuid"})
chunk_with_score.metadata.page_number


AttributeError: 'RawImageMetadata' object has no attribute 'page_number'