In [None]:
!pip install -U gretel-trainer

If you are running in Colab, the `pip install` command above upgrades the `matplotlib` library as a dependency, but Colab has already imported its preinstalled version. As `pip`'s log output indicates, you must restart the Colab runtime to pick up the new version before you can import and use Benchmark.

In [None]:
import gretel_trainer.benchmark as b

## Datasets

### From your own data

In [None]:
# Local CSV to benchmark; point this at your own file.
demo_csv_path = "~/Downloads/demo.csv"
my_demo_data = b.make_dataset([demo_csv_path], datatype="tabular_mixed")

### From Gretel

In [None]:
# List every Gretel-hosted dataset. Narrow the listing with filters, e.g.:
#   b.list_gretel_datasets(datatype="time_series")
#   b.list_gretel_datasets(datatype="tabular_mixed", tags=["small", "marketing"])
datasets = b.list_gretel_datasets()

[dataset.name for dataset in datasets]

In [None]:
# Available dataset tags (presumably the values accepted by `tags=` in
# `b.list_gretel_datasets`).
dataset_tags = b.list_gretel_dataset_tags()
dataset_tags

In [None]:
# Fetch a single Gretel-hosted dataset by its name.
iris_name = "iris"
iris = b.get_gretel_dataset(iris_name)

## Models

### Gretel defaults

These models are preconfigured from Gretel's [public blueprints](https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics).

In [None]:
from gretel_trainer.benchmark import (
    GretelAmplify,
    GretelAuto,
    GretelCTGAN,
    GretelGPTX,
    GretelLSTM,
)

### Customized Gretel models

In [None]:
from gretel_trainer.benchmark import GretelModel


class TunedLSTM(GretelModel):
    """Gretel model configured via a path to a YAML config file.

    Replace the placeholder path with the location of your own config.
    """

    config = "/path/to/my_config.yml"


class TweakedCtgan(GretelModel):
    """Gretel CTGAN model configured inline with a Python dict.

    NOTE(review): the structure below mirrors the public Gretel CTGAN
    blueprint (schema_version / models / ctgan / params); the parameter
    values are illustrative only -- replace them with your own tweaks.
    """

    # Must be a dict: a bare `{...}` literal is a *set* containing Ellipsis.
    config = {
        "schema_version": "1.0",
        "models": [
            {
                "ctgan": {
                    "data_source": "__tmp__",
                    "params": {"epochs": 100},
                }
            }
        ],
    }

### Completely custom, non-Gretel models

In [None]:
import time

import pandas as pd


class MyCustomModel:
    """Toy, non-Gretel model that "synthesizes" data by resampling its source.

    Sketch of the duck-typed ``train``/``generate`` interface a custom model
    presumably needs for Benchmark. The sleeps stand in for real
    training/generation time.
    """

    # Simulated work durations in seconds; set to 0 for quick experiments.
    train_delay_seconds = 8
    generate_delay_seconds = 3
    # Fraction of the source rows returned by generate().
    sample_fraction = 0.6
    # Seed for row sampling so repeated runs return the same rows.
    random_state = 42

    def train(self, source: str, **kwargs) -> None:
        """Load the training CSV at ``source`` and simulate training time."""
        self.source_df = pd.read_csv(source)
        time.sleep(self.train_delay_seconds)

    def generate(self, **kwargs) -> pd.DataFrame:
        """Return a reproducible random sample of the training rows.

        Raises:
            RuntimeError: if called before ``train``.
        """
        # Fail fast with a clear message instead of an AttributeError.
        if getattr(self, "source_df", None) is None:
            raise RuntimeError("MyCustomModel.generate() called before train()")
        time.sleep(self.generate_delay_seconds)
        return self.source_df.sample(
            frac=self.sample_fraction, random_state=self.random_state
        )

## Launch a Benchmark Comparison!

In [None]:
# Every dataset below is trained and evaluated against every model below.
benchmark_datasets = [my_demo_data, iris]
benchmark_models = [GretelLSTM, GretelAmplify]
comparison = b.compare(datasets=benchmark_datasets, models=benchmark_models)

In [None]:
# Inspect the results without waiting for the comparison to finish
# (the next cell calls `wait()`); presumably they may still be partial.
current_results = comparison.results
current_results

In [None]:
# Wait for the comparison (presumably blocks until all runs finish),
# then write its results to a local CSV.
RESULTS_CSV = "./results.csv"
comparison.wait()
comparison.export_results(RESULTS_CSV)