Skip to content

Commit

Permalink
extrac_as_obj fields
Browse files Browse the repository at this point in the history
  • Loading branch information
mike0sv committed May 24, 2024
1 parent 922ffb7 commit 5b2f875
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/evidently/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import FrozenSet
from typing import Iterable
from typing import List
from typing import Optional
Expand Down Expand Up @@ -255,7 +256,7 @@ class Config:
frozen = True

path: str
tags: Set["IncludeTags"]
tags: FrozenSet[FieldTags]
classpath: str

def __lt__(self, other):
Expand Down Expand Up @@ -337,6 +338,10 @@ def list_nested_fields(self, exclude: Set["IncludeTags"] = None) -> List[str]:
def _list_with_tags(self, current_tags: Set["IncludeTags"]) -> List[Tuple[str, Set["IncludeTags"]]]:
if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
return [(repr(self), current_tags)]
from evidently.core import BaseResult

if issubclass(self._cls, BaseResult) and self._cls.__config__.extract_as_obj:
return [(repr(self), current_tags)]
res = []
for name, field in self._cls.__fields__.items():
field_value = field.type_
Expand Down Expand Up @@ -368,15 +373,17 @@ def list_nested_fields_with_tags(self) -> List[Tuple[str, Set["IncludeTags"]]]:

def list_nested_field_infos(self) -> List[FieldInfo]:
return [
FieldInfo(path=path, tags=tags, classpath=get_classpath(self._get_field_info(path.split(".")).type_))
FieldInfo(
path=path, tags=frozenset(tags), classpath=get_classpath(self._get_field_info(path.split(".")).type_)
)
for path, tags in self.list_nested_fields_with_tags()
]

def _get_field_info(self, path: List[str]) -> ModelField:
if len(path) == 0:
raise ValueError("Empty path provided")
if len(path) == 1:
if isinstance(self._cls, BaseModel):
if isinstance(self._cls, type) and issubclass(self._cls, BaseModel):
return self._cls.__fields__[path[0]]
raise NotImplementedError(f"Not implemented for {self._cls.__name__}")
child, *path = path
Expand Down

0 comments on commit 5b2f875

Please sign in to comment.