In [None]:
import time

import pandas as pd

import gretel_trainer.b2 as b2

# Dataset fetched by name from the Gretel dataset repo
repo = b2.GretelDatasetRepo()
iris = repo.get_dataset("iris")

# Custom dataset built from an in-memory DataFrame
bikes_df = pd.read_csv("~/Downloads/bikebuying.csv")
bikes = b2.make_dataset(
    bikes_df,
    datatype="tabular",
    name="bikes",
)

# Custom dataset built directly from a CSV path
tiny = b2.make_dataset(
    "~/Downloads/tiny.csv",
    datatype=b2.Datatype.tabular,
    name="tiny",
)

# Custom model
class EchoModel:
    """Toy custom model that "generates" data by echoing its training CSV.

    ``train`` loads the source CSV into memory and ``generate`` writes that
    same frame back out. Both steps sleep first to simulate work; the
    durations are class attributes so they can be tuned (for example set to 0)
    without changing the no-argument constructor.
    """

    # Simulated durations, in seconds.
    train_delay = 4
    generate_delay = 2

    def train(self, source, **kwargs):
        """Load the CSV at ``source.data_source`` into ``self.df``.

        Args:
            source: object exposing a ``data_source`` attribute that
                ``pandas.read_csv`` can read.
            **kwargs: accepted and ignored.
        """
        time.sleep(self.train_delay)
        self.df = pd.read_csv(source.data_source)

    def generate(self, **kwargs):
        """Write the trained frame to ``kwargs["preferred_out_path"]``.

        Returns:
            The same ``preferred_out_path`` value that was passed in.

        Raises:
            KeyError: if ``preferred_out_path`` is not supplied.
            AttributeError: if ``train`` has not been called first.
        """
        time.sleep(self.generate_delay)
        out_path = kwargs["preferred_out_path"]
        # index=False keeps the echoed file's columns identical to the source;
        # the default would prepend the RangeIndex as an extra unnamed column.
        self.df.to_csv(out_path, index=False)
        return out_path

In [None]:
# SDK
# Compare all three datasets against both the Gretel Amplify model and the
# local EchoModel, with trainer=False (the "SDK" variant of the run).
# NOTE(review): presumably trainer=False runs Gretel models through the SDK
# directly instead of through Gretel Trainer — confirm against the b2 docs.

sdk_comp = b2.compare(
    datasets=[iris, bikes, tiny], 
    models=[b2.GretelAmplify, EchoModel], 
    project_display_name="b2sdk",
    working_dir="b2sdk",
    trainer=False,
)

In [None]:
# Last expression in the cell, so the notebook renders the SDK comparison's results.
sdk_comp.results

In [None]:
# Trainer
# Same three datasets with trainer=True. Only the Gretel Amplify model is
# passed here; the custom EchoModel is left out of this run.
# NOTE(review): presumably custom models do not apply to the Trainer path — confirm.

trainer_comp = b2.compare(
    datasets=[iris, bikes, tiny], 
    models=[b2.GretelAmplify], 
    project_display_name="b2trainer",
    working_dir="b2trainer",
    trainer=True,
)

In [None]:
# Last expression in the cell, so the notebook renders the Trainer comparison's results.
trainer_comp.results

Below: an experiment to see whether model configs can be built that significantly reduce the amount of preview data generated and skip the Synthetic Quality Score (SQS) evaluation altogether.

In [9]:
import gretel_client as gc

def config_dict(blueprint):
    """Load the model configuration for a blueprint name.

    Args:
        blueprint: blueprint identifier such as ``"synthetics/amplify"``.

    Returns:
        Whatever ``gc.projects.models.read_model_config`` returns for that
        blueprint; in this notebook that is a dict with ``schema_version``,
        ``name`` and ``models`` keys.
    """
    loaded_config = gc.projects.models.read_model_config(blueprint)
    return loaded_config

def model_params(config_dict, model_name):
    """Return the settings stored under ``model_name`` in the first model entry.

    Args:
        config_dict: config mapping with a ``"models"`` list. The parameter
            name shadows the module-level ``config_dict`` function inside
            this body only.
        model_name: key of the model section, e.g. ``"amplify"``.

    Returns:
        The nested object itself, not a copy, so edits made through the
        return value change ``config_dict`` in place.

    Raises:
        KeyError / IndexError: if ``"models"`` is missing or empty, or
            ``model_name`` is not a key of its first entry.
    """
    first_model_entry = config_dict["models"][0]
    return first_model_entry[model_name]

# Record count requested from every model; kept tiny so the jobs stay small.
NUM_PREVIEW_RECORDS = 10
MINIMAL_GENERATE = {"num_records": NUM_PREVIEW_RECORDS}
SKIP_EVALUATE = {"skip": True}


# Each config receives its own copy of the override dicts. Assigning the
# constants directly would make every config share one dict object, so an
# in-place tweak to one config's section would silently change the others.

# Amplify keeps its record count under "params", so it is set there directly.
amplify = config_dict("synthetics/amplify")
model_params(amplify, "amplify")["params"]["num_records"] = NUM_PREVIEW_RECORDS
model_params(amplify, "amplify")["evaluate"] = dict(SKIP_EVALUATE)

actgan = config_dict("synthetics/tabular-actgan")
model_params(actgan, "actgan")["generate"] = dict(MINIMAL_GENERATE)
model_params(actgan, "actgan")["evaluate"] = dict(SKIP_EVALUATE)

# The LSTM blueprint stores its settings under the "synthetics" model key.
lstm = config_dict("synthetics/tabular-lstm")
model_params(lstm, "synthetics")["generate"] = dict(MINIMAL_GENERATE)
model_params(lstm, "synthetics")["evaluate"] = dict(SKIP_EVALUATE)


# Print each edited config under its label, with a blank line after each.
for label, edited_config in (("AMPLIFY", amplify), ("ACTGAN", actgan), ("LSTM", lstm)):
    print(label)
    print(edited_config)
    print()

AMPLIFY
{'schema_version': '1.0', 'name': 'data-amplification-model', 'models': [{'amplify': {'data_source': '__tmp__', 'params': {'num_records': 10, 'target_size_mb': None}, 'evaluate': {'skip': True}}}]}

ACTGAN
{'schema_version': '1.0', 'name': 'tabular-actgan', 'models': [{'actgan': {'data_source': '__tmp__', 'params': {'epochs': 'auto', 'generator_dim': [1024, 1024], 'discriminator_dim': [1024, 1024], 'generator_lr': 0.0001, 'discriminator_lr': 0.00033, 'batch_size': 'auto'}, 'generate': {'num_records': 10}, 'privacy_filters': {'outliers': 'auto', 'similarity': 'auto'}, 'evaluate': {'skip': True}}}]}

LSTM
{'schema_version': '1.0', 'name': 'tabular-lstm', 'models': [{'synthetics': {'data_source': '__tmp__', 'params': {'epochs': 'auto', 'vocab_size': 'auto', 'learning_rate': 'auto', 'batch_size': 'auto', 'rnn_units': 'auto'}, 'generate': {'num_records': 10}, 'privacy_filters': {'outliers': 'auto', 'similarity': 'auto'}, 'evaluate': {'skip': True}}}]}



In [None]:
# Obtain the "mktest" project via create_or_get_unique_project, then build one
# model object per edited config against the bike-buying CSV and submit each
# with submit_cloud().
# NOTE: `m` is rebound on every iteration, so only the last submitted model's
# handle is still reachable after the loop.
# NOTE(review): the data source uses "~"; confirm the client expands it.
p = gc.create_or_get_unique_project(name="mktest")
for config in [amplify, actgan, lstm]:
    m = p.create_model_obj(model_config=config, data_source="~/Downloads/bikebuying.csv")
    m.submit_cloud()