Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lift gretel model compatibility to separate module #30

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/gretel_trainer/benchmark/gretel/compatibility.py
@@ -0,0 +1,34 @@
from typing import Optional

from gretel_trainer.benchmark.core import DataSource, Datatype


def is_runnable(model_key: Optional[str], source: DataSource) -> bool:
    """Report whether the model identified by ``model_key`` can run on ``source``.

    Args:
        model_key: Key naming the model type, or ``None`` when no key is
            available.
        source: Data source whose properties are checked against the
            model-specific constraints.

    Returns:
        The verdict of the model-specific check when ``model_key`` is one of
        the recognized keys; ``True`` when ``model_key`` is ``None`` or is not
        a recognized key.
    """
    if model_key is None:
        return True
    if model_key in ("lstm", "synthetics"):
        return _lstm(source)
    if model_key in ("ctgan", "actgan"):
        return _ctgan(source)
    # Single keys are compared with ==. Writing `in ("gpt_x")` is a substring
    # test against a plain string (parentheses without a trailing comma do not
    # make a tuple), which would also match keys such as "gpt" or "".
    if model_key == "gpt_x":
        return _gptx(source)
    if model_key == "amplify":
        return _amplify(source)
    return True


def _lstm(source: DataSource) -> bool:
    """Return whether ``source`` has at most 150 columns."""
    max_columns = 150
    return source.column_count <= max_columns


def _ctgan(source: DataSource) -> bool:
    """Return ``True`` unconditionally; ``source`` is not inspected."""
    return True


def _gptx(source: DataSource) -> bool:
    """Return whether ``source`` is a single natural-language column.

    The result is true only when ``source.column_count`` equals 1 and
    ``source.datatype`` equals ``Datatype.NATURAL_LANGUAGE``.
    """
    # The datatype is only consulted once the column count has passed.
    if source.column_count != 1:
        return False
    return source.datatype == Datatype.NATURAL_LANGUAGE


def _amplify(source: DataSource) -> bool:
    """Return ``True`` unconditionally; ``source`` is not inspected."""
    return True
7 changes: 2 additions & 5 deletions src/gretel_trainer/benchmark/gretel/sdk.py
Expand Up @@ -8,6 +8,7 @@
from gretel_trainer.benchmark.gretel.models import GretelModel, GretelModelConfig

import gretel_client.helpers
import gretel_trainer.benchmark.gretel.compatibility as compatibility

from gretel_client.evaluation.quality_report import QualityReport
from gretel_client.projects.projects import create_or_get_unique_project, search_projects
Expand Down Expand Up @@ -92,11 +93,7 @@ def model_name(self) -> str:
return self.model.name

def runnable(self, source: DataSource) -> bool:
if self.model_key == "gpt_x":
if source.column_count > 1 or source.datatype != Datatype.NATURAL_LANGUAGE:
return False

return True
return compatibility.is_runnable(self.model_key, source)

def train(self, source: str, **kwargs) -> None:
project = self.sdk.create_project(self.project_name)
Expand Down
9 changes: 3 additions & 6 deletions src/gretel_trainer/benchmark/gretel/trainer.py
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import Protocol

import gretel_trainer
import gretel_trainer.benchmark.gretel.compatibility as compatibility
import pandas as pd

from gretel_trainer.benchmark.core import DataSource
Expand All @@ -26,7 +27,7 @@ def _get_trainer_model_type(
elif model_name == "ctgan":
model_class = models.GretelCTGAN
else:
raise Exception(f"Unexpected model name 'f{model_name}' in config")
raise Exception(f"Unexpected model name '{model_name}' in config")

return model_class(config=config_dict)

Expand Down Expand Up @@ -62,11 +63,7 @@ def model_name(self) -> str:
return self.model.name

def runnable(self, source: DataSource) -> bool:
if self.model_key is not None and self.model_key in ("lstm", "synthetics"):
if source.column_count > 150:
return False

return True
return compatibility.is_runnable(self.model_key, source)

def train(self, source: str, **kwargs) -> None:
Path(self.benchmark_dir).mkdir(exist_ok=True)
Expand Down