# Add a new plugin

By default, the library imports every file matching the pattern "plugin\_\*.py" from src/synthcity/plugins, and loads all the classes that implement the [Plugin interface](src/synthcity/plugins/core/plugin.py).

Each plugin must implement the following methods:
- hyperparameter_space() - a static method that returns the hyperparameters that can be tuned during AutoML.
- type() - a static method that returns the type of the plugin. e.g., debug, generative, bayesian, etc.
- name() - a static method that returns the name of the plugin. e.g., ctgan, random_noise, etc.
- _fit() - internal method, called by `fit` on each training set.
- _generate() - internal method, called by `generate`.

In [None]:
!pip install synthcity
!pip uninstall -y torchaudio torchdata

## Existing plugins

In [2]:
# synthcity absolute
from synthcity.plugins import Plugins

# Build the plugin collection used by the rest of this notebook.
generators = Plugins()

# Show the names of the plugins that are currently available.
generators.list()

  from .autonotebook import tqdm as notebook_tqdm


                  variable OMP_PATH to the location of the header before importing keopscore or pykeops,
                  e.g. using os.environ: import os; os.environ['OMP_PATH'] = '/path/to/omp/header'


[2025-02-17T20:10:17.504171+0900][2428][CRITICAL] module disabled: /Users/minkeychang/synthcity/src/synthcity/plugins/generic/plugin_goggle.py


['survival_nflow',
 'nflow',
 'syn_seq',
 'radialgan',
 'image_adsgan',
 'marginal_distributions',
 'survival_ctgan',
 'great',
 'image_cgan',
 'survival_gan',
 'pategan',
 'ddpm',
 'dummy_sampler',
 'tvae',
 'timevae',
 'decaf',
 'dpgan',
 'fflows',
 'uniform_sampler',
 'privbayes',
 'arf',
 'rtvae',
 'aim',
 'timegan',
 'ctgan',
 'adsgan',
 'bayesian_network',
 'survae']

## Example plugin: Generate 0-1

In [3]:
# stdlib
from typing import Any, List

# third party
import numpy as np
import pandas as pd

# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema


class ZeroOnePlugin(Plugin):
    """Toy debugging plugin whose samples are random 0/1 integers.

    It remembers only the number of features seen during fitting and ignores
    the values of the training data. The tutorial uses it to show that the
    Plugin interface rejects generated data that violates the training-set
    constraints.
    """

    def __init__(self, **kwargs: Any) -> None:
        # No plugin-specific settings: forward everything to the base class.
        super().__init__(**kwargs)

    @staticmethod
    def name() -> str:
        """Return the name under which the plugin is identified."""
        return "zero_one"

    @staticmethod
    def type() -> str:
        """Return the plugin category."""
        return "debug"

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
        """Return the tunable hyperparameters; this plugin has none."""
        return []

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "ZeroOnePlugin":
        """Record the size of the second dimension of ``X`` and return ``self``."""
        self.features_count = X.shape[1]
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any):
        """Return ``count`` rows of random 0/1 values wrapped in a loader.

        ``syn_schema`` is accepted for interface compatibility but not used.
        """
        output_shape = (count, self.features_count)
        # The upper bound is exclusive, so the drawn values are 0 or 1.
        samples = np.random.randint(0, 2, size=output_shape)
        return GenericDataLoader(samples)

In [4]:
# Register the ZeroOnePlugin class in the collection under the name "zero_one"

generators.add("zero_one", ZeroOnePlugin)

<synthcity.plugins.Plugins at 0x1056044f0>

In [None]:
# Check the updated plugin list after the registration above
generators.list()

In [5]:
# Load reference data

# third party
from sklearn.datasets import load_breast_cancer

# Load the breast-cancer dataset as pandas objects (as_frame=True).
# NOTE(review): the target `y` is not used in the cells visible here.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# Wrap the features in a GenericDataLoader, the input passed to `fit` below.
loader = GenericDataLoader(X)

# Display the wrapped data as a dataframe.
loader.dataframe()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension
0,17.99,10.38,122.80,1001.0,0.11840,0.27760,0.30010,0.14710,0.2419,0.07871,...,25.380,17.33,184.60,2019.0,0.16220,0.66560,0.7119,0.2654,0.4601,0.11890
1,20.57,17.77,132.90,1326.0,0.08474,0.07864,0.08690,0.07017,0.1812,0.05667,...,24.990,23.41,158.80,1956.0,0.12380,0.18660,0.2416,0.1860,0.2750,0.08902
2,19.69,21.25,130.00,1203.0,0.10960,0.15990,0.19740,0.12790,0.2069,0.05999,...,23.570,25.53,152.50,1709.0,0.14440,0.42450,0.4504,0.2430,0.3613,0.08758
3,11.42,20.38,77.58,386.1,0.14250,0.28390,0.24140,0.10520,0.2597,0.09744,...,14.910,26.50,98.87,567.7,0.20980,0.86630,0.6869,0.2575,0.6638,0.17300
4,20.29,14.34,135.10,1297.0,0.10030,0.13280,0.19800,0.10430,0.1809,0.05883,...,22.540,16.67,152.20,1575.0,0.13740,0.20500,0.4000,0.1625,0.2364,0.07678
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
564,21.56,22.39,142.00,1479.0,0.11100,0.11590,0.24390,0.13890,0.1726,0.05623,...,25.450,26.40,166.10,2027.0,0.14100,0.21130,0.4107,0.2216,0.2060,0.07115
565,20.13,28.25,131.20,1261.0,0.09780,0.10340,0.14400,0.09791,0.1752,0.05533,...,23.690,38.25,155.00,1731.0,0.11660,0.19220,0.3215,0.1628,0.2572,0.06637
566,16.60,28.08,108.30,858.1,0.08455,0.10230,0.09251,0.05302,0.1590,0.05648,...,18.980,34.12,126.70,1124.0,0.11390,0.30940,0.3403,0.1418,0.2218,0.07820
567,20.60,29.33,140.10,1265.0,0.11780,0.27700,0.35140,0.15200,0.2397,0.07016,...,25.740,39.42,184.60,1821.0,0.16500,0.86810,0.9387,0.2650,0.4087,0.12400


In [None]:
# Train the new plugin

# Look up the plugin registered under "zero_one".
gen = generators.get("zero_one")

# Fit it on the reference data loaded above.
gen.fit(loader)

In [None]:
# Generate some new data
# NOTE: this call is expected to fail. As the text below explains, the Plugin
# interface raises an exception when generated data violates the constraints
# of the training set.

gen.generate(count=10)

### Oops, this didn't work.

__The Plugin interface requires the newly generated data to:__
 - satisfy the same constraints as the training set,
 - or satisfy the constraints provided at inference time (if any).
 
 
 If the generated dataframe fails to comply, an exception will be raised.

Let's try again

## A functional plugin: Integrate SDV CTGAN

In this example, we will show how to integrate the CTGAN implementation from the [SDV library](https://github.com/sdv-dev/SDV).

__Note__ : `synthcity` also includes a dedicated CTGAN re-implementation.

In [None]:
!pip install ctgan

In [None]:
# stdlib
from typing import Any, List

# third party
import numpy as np
import pandas as pd
from ctgan import CTGAN

# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.core.distribution import (
    Distribution,
    IntegerDistribution,
)


class sdv_ctgan_plugin(Plugin):
    """SDV CTGAN integration in synthcity.

    Wraps a ``ctgan.CTGAN`` model behind the Plugin interface.

    Args:
        embedding_n_units: Passed to CTGAN as ``embedding_dim``.
        n_iter: Passed to CTGAN as ``epochs``.
        batch_size: Passed to CTGAN as ``batch_size``.
        cat_limit: Columns with fewer than this many unique values are
            treated as discrete when fitting.
        **kwargs: Forwarded to the ``Plugin`` base class.
    """

    def __init__(
        self,
        embedding_n_units: int = 128,
        n_iter: int = 2000,
        batch_size: int = 100,
        cat_limit: int = 15,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.cat_limit = cat_limit
        self.model = CTGAN(
            embedding_dim=embedding_n_units,
            batch_size=batch_size,
            epochs=n_iter,
            verbose=False,
        )

    @staticmethod
    def name() -> str:
        """Return the name under which the plugin is identified."""
        return "sdv_ctgan"

    @staticmethod
    def type() -> str:
        """Return the plugin category."""
        return "debug"

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
        """Return the search space for the tunable constructor arguments.

        We can customize the hyperparameter space, and use it in AutoML
        benchmarks. Positional arguments are accepted (and ignored) so the
        signature matches the other plugin in this tutorial.
        """
        return [
            IntegerDistribution(name="embedding_n_units", low=100, high=500, step=50),
            IntegerDistribution(name="batch_size", low=100, high=300, step=50),
            IntegerDistribution(name="n_iter", low=100, high=500, step=50),
        ]

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "sdv_ctgan_plugin":
        """Select the discrete columns and train the CTGAN model.

        A column is considered discrete when its number of unique values is
        below ``self.cat_limit``.

        Returns:
            The fitted plugin (``self``).
        """
        discrete_columns = [
            col for col in X.columns if len(X[col].unique()) < self.cat_limit
        ]

        self.model.fit(X.dataframe(), discrete_columns=discrete_columns)
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader:
        """Sample ``count`` rows from the trained model.

        Sampling is delegated to the base-class ``_safe_generate`` helper,
        which receives the CTGAN ``sample`` callback together with the
        requested count and schema.
        """
        return self._safe_generate(self.model.sample, count, syn_schema)

In [None]:
# Register the SDV CTGAN wrapper under the name "sdv_ctgan".
generators.add("sdv_ctgan", sdv_ctgan_plugin)

# Show the updated plugin list.
generators.list()

In [None]:
# Train the new plugin

# n_iter=100 is far below the constructor default of 2000.
# NOTE(review): presumably `get` forwards keyword arguments to the plugin
# constructor - confirm against the Plugins API.
gen = generators.get("sdv_ctgan", n_iter=100)

gen.fit(loader)

In [None]:
# Generate some new data and display it as a dataframe

gen.generate(count=10).dataframe()

In [None]:
# Custom generation constraints

# synthcity absolute
from synthcity.plugins.core.constraints import Constraints

# Only accept rows whose "worst radius" is greater than 15.
constraints = Constraints(rules=[("worst radius", ">", 15)])

generated = gen.generate(count=10, constraints=constraints)

# The rule applies to every generated row, so check all of them: `.any()`
# would pass even if only a single row satisfied the constraint.
assert (generated["worst radius"] > 15).all()

generated.dataframe()

## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the Repos! This helps raise awareness of the tools we're building.


### Check out other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
