Skip to content

Commit

Permalink
add model schema
Browse files Browse the repository at this point in the history
* wip added file

* define model schema

* remove runnable from model schema add models to pipeline
  • Loading branch information
bjornaer authored Jan 11, 2022
1 parent 1ab1da1 commit 4e87224
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
51 changes: 51 additions & 0 deletions pipeline/schemas/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional

from pydantic import root_validator

from pipeline.schemas.base import BaseModel
from pipeline.schemas.file import FileCreate, FileGet


class ModelBase(BaseModel):
id: Optional[str]
name: str


class ModelGet(ModelBase):
id: str
hex_file: FileGet

source_sample: str

class Config:
orm_mode = True


class ModelGetDetailed(ModelGet):
...


class ModelCreate(BaseModel):
# The local ID is assigned when a new model is used as part of a new
# pipeline; the server uses the local ID to associated a model to a
# Pipeline before replacing the local ID with the server-generated one
local_id: Optional[str]

model_source: str

name: str

file_id: Optional[str]
file: Optional[FileCreate]

@root_validator
def file_or_id_validation(cls, values):
file, file_id = values.get("file"), values.get("file_id")

file_defined = file is not None
file_id_defined = file_id is not None

if file_defined == file_id_defined:
raise ValueError("You must define either the file OR file_id of a model.")

return values
3 changes: 3 additions & 0 deletions pipeline/schemas/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pipeline.schemas.base import BaseModel
from pipeline.schemas.file import FileGet
from pipeline.schemas.function import FunctionGet
from pipeline.schemas.model import ModelGet
from pipeline.schemas.runnable import RunnableGet, RunnableType


Expand Down Expand Up @@ -46,6 +47,7 @@ class PipelineGet(RunnableGet):
type: RunnableType = Field(RunnableType.pipeline, const=True)
variables: List[PipelineVariableGet]
functions: List[FunctionGet]
models: List[ModelGet]
graph_nodes: List[PipelineGraphNode]
outputs: List[str]

Expand All @@ -61,6 +63,7 @@ class PipelineCreate(BaseModel):
name: str
variables: List[PipelineVariableGet]
functions: List[FunctionGet]
models: List[ModelGet]
graph_nodes: List[PipelineGraphNode]
outputs: List[str]
# models: Optional[dict]
1 change: 1 addition & 0 deletions pipeline/schemas/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class RunnableGetDetailed(RunnableGet):
last_runs = []


# NOTE QUESTION: do we use these classes?
class FunctionGet(RunnableGet):
type: RunnableType = Field(RunnableType.function, const=True)

Expand Down

0 comments on commit 4e87224

Please sign in to comment.