In [97]:
import json
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, AnyHttpUrl, UUID4, model_validator
from datetime import date

# Data definition

## Data section

In [98]:
class DemographicFactors(BaseModel):
    # Demographic description of a dataset; used as the type of
    # Data.demographic_factors. Every field is a plain string, so no
    # vocabulary or numeric range is enforced by this model.
    gender: str
    sex: str
    # NOTE(review): age is a string, not a number — presumably a range or
    # free-text summary of the cohort; confirm the expected format.
    age: str
    demographic_group: str
    location: str
    socioeconomics: str

In [99]:
class CatalogColumn(BaseModel):
    # One entry of a dataset's column catalog (see Data.catalog).
    # NOTE(review): the Optional[...] fields have no default value, so under
    # pydantic v2 the key is presumably still required and only its value may
    # be null — confirm this is the intended contract.
    column_name: str
    description: Optional[str]
    # Column data type as free text (not restricted to an enumeration here).
    type: str
    number_missing_values: int
    # presumably the allowed labels of a categorical column; null otherwise — verify
    categories: Optional[List[str]]

In [100]:
class DQMetric(BaseModel):
    # A single data-quality measurement (see Data.data_quality_assessment).
    name: str
    description: str
    # The measured value is carried as a string, whatever its underlying type.
    value: str
    # Nullable HTTP(S) URL; no default is declared for it.
    reference: Optional[AnyHttpUrl]

In [101]:
class Data(BaseModel):
    # Description of one dataset; AIPassport.data holds a list of these.
    id: str
    description: str
    # presumably the number of records/samples in the dataset — TODO confirm
    n: int
    demographic_factors: DemographicFactors
    bias_declaration: str
    # Per-column metadata.
    catalog: List[CatalogColumn]
    # Data-quality metrics computed for this dataset.
    data_quality_assessment: List[DQMetric]
    # NOTE(review): lists of strings, presumably identifiers of the clinicians
    # and patients involved — the meaning is not visible here; verify.
    clinicians: Optional[List[str]]
    patients: Optional[List[str]]

## Training algorithm

In [102]:
class PreprocessSteps(BaseModel):
    # A single preprocessing step (see ImplementationDetails.preprocess_steps).
    name: str
    description: str
    # Unlike most reference fields in this file, this URL is not Optional.
    reference: AnyHttpUrl

In [103]:
class NameValue(BaseModel):
    # Generic name/value pair; used for ImplementationDetails.hyperparameters.
    name: str
    # The value is always carried as a string.
    value: str

In [104]:
class NameVersion(BaseModel):
    # Generic name/version pair; used for ImplementationDetails.programming_lang.
    name: str
    # Free-form version string (no version syntax is enforced).
    version: str

In [105]:
class ImplementationDetails(BaseModel):
    # Implementation-level description of a training algorithm
    # (see TrainingAlgorithm.implementation_details).
    # presumably a link to the source code — verify
    code: Optional[AnyHttpUrl]
    type_of_architecture: str
    preprocess_steps: Optional[List[PreprocessSteps]]
    # Hyperparameters as name/value string pairs.
    hyperparameters: List[NameValue]
    # Programming language name and version.
    programming_lang: NameVersion

In [106]:
class TrainingAlgorithm(BaseModel):
    # Training algorithm section of the passport (AIPassport.training_algorithm).
    # presumably the identifier that Training.training_algorithm refers to — verify
    id: str
    name: str
    description: str
    reference: Optional[AnyHttpUrl]
    implementation_details: ImplementationDetails

## Evaluation strategy

In [107]:
class RolesAndResponsibilities(BaseModel):
    # Assignment of a role and its responsibilities to one user
    # (see EvaluationStrategy.roles_and_responsibilities).
    user_id: str
    role_user: str
    # Responsibilities are a single free-text string, not a list.
    responsibilities: str

In [108]:
class EvaluationTypeEnum(str, Enum):
    # Closed set of evaluation kinds accepted by
    # EvaluationStrategy.evaluation_type. Mixing in `str` makes each member
    # compare equal to (and behave as) its plain string value.
    internal = "internal"
    clinical = "clinical"
    continual = "continual"

In [109]:
class MeasurementTypeEnum(str, Enum):
    # Closed set of measurement kinds accepted by
    # EvaluationStrategy.measurement_type. The serialized values are the
    # human-readable, capitalised labels, not the member names.
    ai_performance = "AI Performance"
    perceived_utility = "Perceived Utility"
    perceived_usability = "Perceived Usability"

In [110]:
class Metric(BaseModel):
    # Definition of one evaluation metric (see EvaluationStrategy.metrics).
    # Identifier that Step.metrics entries refer to; EvaluationStrategy's
    # validator requires it to be unique within one strategy.
    id: str
    name: str
    description: str
    # presumably a link to the measuring instrument (e.g. a questionnaire) — verify
    instrument: Optional[AnyHttpUrl]
    reference: Optional[AnyHttpUrl]

In [111]:
class Step(BaseModel):
    # One step of an evaluation strategy (see EvaluationStrategy.steps).
    name: str
    description: str
    reference: Optional[AnyHttpUrl]
    # List of Metric.id values used by this step. EvaluationStrategy's
    # validator rejects ids that are not declared in its own `metrics` list.
    metrics: Optional[List[str]]

In [112]:
class EvaluationStrategy(BaseModel):
    # Evaluation protocol: who is involved, what kind of evaluation and
    # measurement it is, which metrics are defined and which steps use them.
    id: str
    roles_and_responsibilities: Optional[List[RolesAndResponsibilities]]
    evaluation_type: EvaluationTypeEnum
    measurement_type: MeasurementTypeEnum
    # Metric definitions; their ids must be unique within this strategy.
    metrics: List[Metric]
    # Steps may reference the metrics above through Step.metrics (by id).
    steps: List[Step]
    code: Optional[AnyHttpUrl]
    reference: Optional[AnyHttpUrl]

    @model_validator(mode='after')
    def validate_metrics_fk(self):
        """Check Metric.id uniqueness and the Step.metrics -> Metric.id references.

        Returns:
            The validated instance itself. An 'after' model validator must
            return the instance: pydantic uses the return value as the
            validated object, so returning None would replace every
            EvaluationStrategy nested in another model (for example in
            AIPassport.evaluation_strategies) with None.

        Raises:
            ValueError: if two metrics share an id, or if a step references
                a metric id that is not declared in ``metrics``.
        """
        # NOTE(review): as in the original logic, the reference check is
        # skipped entirely when no metrics are declared, so steps pointing at
        # metric ids are accepted in that case — confirm this is intended.
        if not self.metrics:
            return self

        metric_ids = [metric.id for metric in self.metrics]
        known_ids = set(metric_ids)

        # A smaller set means at least one id appears more than once.
        if len(metric_ids) > len(known_ids):
            raise ValueError('Non unique IDs detected for Metric.id')

        # Foreign-key check: every id used by a step must be a declared metric.
        for step in self.steps or []:
            for metric_id in step.metrics or []:
                if metric_id not in known_ids:
                    raise ValueError(f'Step.metrics ID {metric_id} not defined before')

        return self


## AI Entity

In [113]:
class AIEntity(BaseModel):
    # Descriptive and regulatory metadata of one AI entity
    # (see AIPassport.composition_ai_entities).
    # presumably the identifier that the `ai_entity` fields of SaMD,
    # FoundationalModel, AIModel and Evaluation refer to — verify
    id: str
    manufacturer: str
    purpose: str
    # Parsed as a calendar date (datetime.date), without a time component.
    release_date: date
    regulation_check: str
    ethical_declarations: List[str]
    standards: List[str]
    certifications: List[str]
    encryption: str
    field_tested_libraries: bool

In [114]:
class SaMD(BaseModel):
    # Software-as-a-Medical-Device section of the passport (AIPassport.SaMD).
    # presumably an AIEntity.id; no cross-reference check is performed here — verify
    ai_entity: str
    samd_name: str
    samd_clinical_purpose: str
    # presumably identifiers of the AI models bundled in the device — verify
    models: List[str]

In [115]:
class FoundationalModel(BaseModel):
    # Foundational model section of the passport (AIPassport.foundational_model).
    # presumably an AIEntity.id; no cross-reference check is performed here — verify
    ai_entity: str
    name: str
    description: str
    # Typed as an HTTP(S) URL, so a bare DOI string (without a URL scheme)
    # would not be accepted.
    doi: AnyHttpUrl

In [116]:
class XAI(BaseModel):
    # One explainability (XAI) mechanism attached to a model
    # (see AIModel.xai_mechanism).
    name: str
    description: str
    # Mandatory reference URL (not Optional).
    reference: AnyHttpUrl

In [117]:
class ExternalDocument(BaseModel):
    # Pointer to a document kept outside the passport
    # (used by Training.quality_roles_and_responsibilities).
    document_id: str
    document_name: str
    # HTTP(S) URL of the document.
    document_reference: AnyHttpUrl

In [118]:
class Training(BaseModel):
    # Training run description of an AI model (see AIModel.training).
    # presumably a TrainingAlgorithm.id — verify
    training_algorithm: str
    # presumably a list of Data.id values — verify
    training_datasets: List[str]
    quality_roles_and_responsibilities: Optional[ExternalDocument]
    # TODO: the date handling below still needs testing (original note:
    # "test this"). Nothing in this model enforces
    # training_start_date <= training_end_date.
    training_start_date: date
    training_end_date: date

In [119]:
class AIModel(BaseModel):
    # One AI model described by the passport (see AIPassport.ai_models).
    # presumably an AIEntity.id — verify
    ai_entity: str
    # A string, not a FoundationalModel object — presumably a reference to
    # the foundational model this model builds on; verify.
    foundational_model: Optional[str]
    xai_mechanism: List[XAI]
    training: Training

## Evaluations

In [120]:
class IDValue(BaseModel):
    # Generic id/value pair; used for Evaluation.values.
    # presumably a Metric.id of the referenced evaluation strategy — verify
    id: str
    # The value is always carried as a string.
    value: str

In [121]:
class Evaluation(BaseModel):
    # Record of one performed evaluation (see AIPassport.evaluations).
    # presumably an AIEntity.id — verify
    ai_entity: str
    # Calendar dates; nothing in this model enforces date_start <= date_end.
    date_start: date
    date_end: date
    description: str
    # presumably an EvaluationStrategy.id — verify
    evaluation_strategy: str
    # presumably Data.id values of the datasets used — verify
    datasets: Optional[List[str]]
    participants: Optional[List[str]]
    # Results as id/value string pairs.
    values: List[IDValue]
    reference: Optional[AnyHttpUrl]

## Complete Passport

In [122]:
class AIPassport(BaseModel):
    # Root model of the passport document; it aggregates every section
    # defined in this file and is the model the JSON Schema is generated from.
    # Must be a version-4 UUID.
    ai_entity_passport_uuid: UUID4
    # NOTE(review): the Optional[...] sections below declare no default, so
    # under pydantic v2 each key is presumably still required (its value may
    # be null) — confirm this matches the intended document format.
    data: Optional[List[Data]]
    training_algorithm: Optional[TrainingAlgorithm]
    evaluation_strategies: Optional[List[EvaluationStrategy]]
    composition_ai_entities: Optional[List[AIEntity]]
    # The field name is identical to the SaMD class name and is therefore the
    # key expected in the input document.
    SaMD: Optional[SaMD]
    foundational_model: Optional[FoundationalModel]
    ai_models: Optional[List[AIModel]]
    # The only list section that is not Optional.
    evaluations: List[Evaluation]

In [123]:
# Generate the JSON Schema of the AIPassport model and store it next to the
# notebook's parent directory, pretty-printed with a 2-space indent.
passport = AIPassport.model_json_schema()
with open('../ai_passport_validator.json', mode='w') as file:
    file.write(json.dumps(passport, indent=2))

# Testing

## SaMD example

In [124]:
# Load the SaMD example passport and validate it against AIPassport.
# The file is opened explicitly as UTF-8 so that parsing does not depend on
# the platform's default text encoding.
with open('../palliative_care_example/aleph_pc.json', 'r', encoding='utf-8') as file:
    samd_json = json.load(file)
    # Instantiating the model runs validation; a non-conforming document raises.
    AIPassport(**samd_json)

## Model example

In [125]:
# Load the one-year-mortality example passport and validate it against AIPassport.
# The file is opened explicitly as UTF-8 so that parsing does not depend on
# the platform's default text encoding.
with open('../palliative_care_example/one_year_mortality.json', 'r', encoding='utf-8') as file:
    oym_json = json.load(file)
    # Instantiating the model runs validation; a non-conforming document raises.
    AIPassport(**oym_json)

In [126]:
# Load the regression example passport and validate it against AIPassport.
# The file is opened explicitly as UTF-8 so that parsing does not depend on
# the platform's default text encoding.
with open('../palliative_care_example/regression.json', 'r', encoding='utf-8') as file:
    regression_json = json.load(file)
    # Instantiating the model runs validation; a non-conforming document raises.
    AIPassport(**regression_json)

In [127]:
# Load the one-year-frailty example passport and validate it against AIPassport.
# The file is opened explicitly as UTF-8 so that parsing does not depend on
# the platform's default text encoding.
with open('../palliative_care_example/one_year_frailty.json', 'r', encoding='utf-8') as file:
    frailty_json = json.load(file)
    # Instantiating the model runs validation; a non-conforming document raises.
    AIPassport(**frailty_json)