In [6]:
from typing import cast

from pandas import Categorical, DataFrame, Series
from seaborn import load_dataset
from torch import float32, tensor

from modugant.matrix import Matrix

iris = DataFrame(load_dataset("iris"))
species = cast('Series[str]', iris.pop("species"))
iris['species'] = Categorical(species).codes
data = Matrix.load(tensor(iris.values, dtype = float32), (150, 5))

In [7]:
## There are 4 sizes that are used as type parameters for the needed protocols.
## using the `Dim` class, we can ensure that our objects align in their dimensionalities.

## C: The number of conditional variables

from modugant.matrix import Dim

## C: The number of conditional variables
conditions = Dim[3](3)
## L: The number of latent variables
latent = Dim[10](10)
## G: The number of generated features
generated = Dim[7](7)
## D: The number of real features to discriminate
dim = Dim[7](7)
## The batch size
batch = Dim[512](512)

category = Dim[3](3)
normed = Dim[4](4)

In [8]:
## create a simple Regimen class

from typing import Tuple, override

from modugant.protocols import Action, Regimen


class SimpleRegimen(Regimen):
    def __init__(self, batch: int, iterations: int) -> None:
        self._batch = batch
        self._iterations = iterations
    @override
    def command(self, iteration: int, loss: Tuple[float, float]) -> Tuple[Action, str]:
        if iteration >= self._iterations:
            return ('stop', 'Max iterations reached.')
        else:
            return ('continue', 'Continue training.')
    @override
    def report(self, iteration: int, action: Action, message: str, d_loss: float, g_loss: float) -> None:
        if action == 'continue' and iteration % 10 == 0:
            print(f'Iteration: {iteration}, D Loss: {d_loss}, G Loss: {g_loss}')
        elif action != 'continue':
            print(f'Iteration: {iteration}, D Loss: {d_loss}, G Loss: {g_loss}')
            print(f'\t{action}: {message}')
    @override
    def reset(self) -> None:
        pass
    @property
    @override
    def batch(self) -> int:
        return self._batch
    @property
    @override
    def k(self) -> int:
        return 1
    @property
    @override
    def d_factor(self) -> float:
        return 1.0
    @property
    @override
    def g_factor(self) -> float:
        return 1.0

In [9]:
from modugant.connectors import JointConnector
from modugant.discriminators import FoldedDiscriminator, SmoothedDiscriminator
from modugant.generators import ResidualGenerator
from modugant.samplers.random import RandomSampler
from modugant.transformers import CategoryTransformer, StandardizeTransformer

generator = ResidualGenerator(
    conditions,
    latent,
    generated,
    steps = [256, 256],
    learning = 0.0001,
    decay = 0.001
)

connector = JointConnector(
    conditions,
    generated,
    dim,
    transformers = (
        StandardizeTransformer(
            normed,
            data = data,
            index = [0, 1, 2, 3]
        ),
        CategoryTransformer(index = (4, category))
    ),
    sampler = RandomSampler(
        data.shape[1],
        data = data,
        split = 0.8
    )
)

discriminator = SmoothedDiscriminator(
    FoldedDiscriminator(
        conditions,
        dim,
        group = 16,
        steps = [256, 256],
        lr = 0.0001
    ),
    factor = 10
)

In [10]:
from modugant.trainer import Trainer

trainer = Trainer(generator, discriminator, connector)

trainer.train(SimpleRegimen(batch, 500))

Iteration: 0, D Loss: 0.015230022370815277, G Loss: -0.08300692588090897
Iteration: 10, D Loss: 0.0034142471849918365, G Loss: -0.11145349591970444
Iteration: 20, D Loss: 0.006651170551776886, G Loss: -0.20655685663223267
Iteration: 30, D Loss: 0.05562207102775574, G Loss: -0.2730565071105957
Iteration: 40, D Loss: -0.022122129797935486, G Loss: -0.3504694402217865
Iteration: 50, D Loss: 0.021930888295173645, G Loss: -0.36635181307792664
Iteration: 60, D Loss: 0.012780040502548218, G Loss: -0.473486065864563
Iteration: 70, D Loss: -0.08872607350349426, G Loss: -0.43231406807899475
Iteration: 80, D Loss: -0.08668327331542969, G Loss: -0.4800696074962616
Iteration: 90, D Loss: -0.0826062560081482, G Loss: -0.46301180124282837
Iteration: 100, D Loss: -0.1896798312664032, G Loss: -0.6601606011390686
Iteration: 110, D Loss: -0.07319486141204834, G Loss: -0.5775861144065857
Iteration: 120, D Loss: -0.096417635679245, G Loss: -0.6934158205986023
Iteration: 130, D Loss: -0.06752997636795044, G

In [19]:
seeds = Matrix.load(
    tensor([
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0]
    ]),
    shape = (3, Dim[3](3))
)

with trainer.test():
    fake = generator.sample(seeds)
    prepared = connector.prepare(seeds, fake)
    print(
        seeds,
        fake,
        prepared,
        sep = '\n'
    )

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
tensor([[-1.3867, -0.9796, -0.5780, -0.0083,  3.2802, -2.8989, -3.1806],
        [ 0.2897, -1.5002,  1.8807,  0.9896, -2.8921,  4.2539, -2.8045],
        [ 0.9558, -0.5924,  1.6137,  1.1633, -2.6836, -2.5619,  3.5942]],
       grad_fn=<AddmmBackward0>)
tensor([[-1.3867e+00, -9.7960e-01, -5.7801e-01, -8.2732e-03,  9.9638e-01,
          2.0648e-03,  1.5580e-03],
        [ 2.8967e-01, -1.5002e+00,  1.8807e+00,  9.8958e-01,  7.8672e-04,
          9.9835e-01,  8.5874e-04],
        [ 9.5575e-01, -5.9236e-01,  1.6137e+00,  1.1633e+00,  1.8701e-03,
          2.1121e-03,  9.9602e-01]], grad_fn=<CatBackward0>)


In [20]:
seeds = Matrix.load(
    tensor([
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0]
    ]),
    shape = (3, Dim[3](3))
)

with trainer.test():
    fake = generator.sample(seeds)
    prepared = connector.prepare(seeds, fake)
    print(
        seeds,
        fake,
        prepared,
        sep = '\n'
    )

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[-0.4677, -0.5061,  0.8408,  1.8266, -0.4324, -0.2734, -0.8410],
        [-1.1495,  1.1017, -0.8824, -1.3620, -0.1115, -1.1865, -0.6843],
        [-0.2840, -0.1743, -0.2375, -0.0452,  0.2836, -0.4731, -0.6929]],
       grad_fn=<AddmmBackward0>)
tensor([[-0.4677, -0.5061,  0.8408,  1.8266,  0.3525,  0.4133,  0.2343],
        [-1.1495,  1.1017, -0.8824, -1.3620,  0.5249,  0.1791,  0.2960],
        [-0.2840, -0.1743, -0.2375, -0.0452,  0.5418,  0.2542,  0.2040]],
       grad_fn=<CatBackward0>)
