In [None]:
import json
import os

from pathlib import Path
from typing import Union

from gretel_client.helpers import poll
from gretel_client.projects import get_project

# Every fixture path below (manifest, configs, CSV) is resolved against the
# kernel's working directory, so start this notebook from the fixtures folder.
# Path.cwd() already returns an absolute path; no extra .absolute() is needed.
fixtures = Path.cwd()

In [None]:
class ModelFixtures:
    """Train Gretel models and record their ids in a JSON manifest file.

    The manifest is a flat JSON object. ``set_project`` stores the selected
    project's name under the ``"_project"`` key, and each ``update_model`` call
    stores the trained model's id under the caller-supplied model name. The
    whole manifest is rewritten to disk after every change.
    """

    def __init__(self, manifest_path: Path):
        """Load the existing manifest.

        Args:
            manifest_path: Path to an existing JSON manifest file. A missing
                file raises ``FileNotFoundError`` (deliberately not created
                here, so running from the wrong directory fails fast).
        """
        self._manifest_file = manifest_path
        self._manifest = json.loads(self._manifest_file.read_text(encoding="utf-8"))
        # Set by set_project(); update_model() refuses to run until then.
        self._project = None

    def _write_manifest(self, key: str, value: str):
        """Set ``key`` to ``value``, echo the manifest, and persist it to disk."""
        self._manifest[key] = value
        print(self._manifest)
        self._manifest_file.write_text(json.dumps(self._manifest), encoding="utf-8")

    def update_model(
        self, model_name: str, config: Union[Path, str], data_source: Path
    ):
        """Train a model in the selected project and record its id.

        Args:
            model_name: Manifest key under which the new model id is stored.
            config: Model config passed to ``create_model_obj`` (a path to a
                config file or a config reference string).
            data_source: Training data file; passed on as a string.

        Raises:
            RuntimeError: If ``set_project`` has not been called yet.
        """
        if self._project is None:
            raise RuntimeError(
                "No project selected; call set_project() before update_model()."
            )
        model = self._project.create_model_obj(config, data_source=str(data_source))
        model.submit_cloud()
        # Block until the cloud job finishes before recording the model id.
        poll(model)
        self._write_manifest(model_name, model.model_id)

    def set_project(self, project_name: str):
        """Select the Gretel project by name and record it under ``"_project"``."""
        self._project = get_project(name=project_name)
        self._write_manifest("_project", self._project.name)


In [None]:
# Load the fixture manifest, then select the project; the GRETEL_PROJECT
# environment variable overrides the default project name.
manifest = ModelFixtures(manifest_path=fixtures / "model_fixtures.json")
manifest.set_project(
    project_name=os.getenv("GRETEL_PROJECT", "gretel-client-project-pretrained")
)

In [None]:
# Train the default synthetics fixture from the "synthetics/default" config reference.
manifest.update_model(
    model_name="synthetics_default",
    config="synthetics/default",
    data_source=fixtures / "account-balances.csv",
)

In [None]:
# Train the transforms fixture from the local transforms_config.yml.
manifest.update_model(
    model_name="transforms_default",
    config=fixtures / "transforms_config.yml",
    data_source=fixtures / "account-balances.csv",
)

In [None]:
# Train the classify fixture from the local classify_config.yml.
manifest.update_model(
    model_name="classify_default",
    config=fixtures / "classify_config.yml",
    data_source=fixtures / "account-balances.csv",
)