In [1]:
## The four sizes C, L, G and D are the type parameters of the modugant protocols;
## `batch` is an extra size used only by the demo cells. Wrapping each size in `Dim`
## lets the type checker confirm that objects agree on their dimensionalities.

from modugant.matrix.dim import Dim

conditions = Dim[0](0)  ## C: number of conditional variables
latent = Dim[10](10)    ## L: number of latent variables
generated = Dim[5](5)   ## G: number of generated features
dim = Dim[5](5)         ## D: number of real features to discriminate
batch = Dim[8](8)       ## rows per demo batch

In [2]:
# Implement the `Discriminator` protocol

from typing import Self, cast, override

from torch import Tensor
from torch import cat as t_cat
from torch.nn import Dropout, Linear, Module, ReLU, Sequential, Sigmoid
from torch.optim import Adam

from modugant.device import Device
from modugant.matrix import Matrix
from modugant.matrix.dim import One
from modugant.matrix.ops import rand, randn, zeros
from modugant.protocols import Discriminator


class IrisDiscriminator[C: int, D: int](Module, Discriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        lr: float = 0.001
    ) -> None:
        super().__init__()
        self._conditions = conditions
        self._outputs = outputs
        self.__lr = lr
        self._model = Sequential(
            *[
                Sequential(
                    Linear(
                        steps[i - 1] if i else outputs + conditions,
                        steps[i]
                    ),
                    ReLU(),
                    Dropout(0.2)
                )
                for i in range(len(steps))
            ],
            Linear(
                steps[-1] if len(steps) else outputs + conditions,
                1
            ),
            Sigmoid()
        )
        self._optimizer = Adam(self.parameters(), lr = lr)
    @override
    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)
    @override
    def predict[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Matrix[N, One]:
        return cast(Matrix[N, One], self.forward(t_cat([condition, data], dim = 1)))
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        predicted = self.predict(condition, data)
        loss = - (target.t() @ predicted.log() + (1 - target.t()) @ (1 - predicted).log()) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def step[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        self.zero_grad()
        loss = self.loss(condition, data, target)
        _ = loss.backward()
        cast(None, self._optimizer.step())
        return loss
    @override
    def reset(self) -> None:
        for module in self.modules():
            if isinstance(module, Linear):
                module.reset_parameters()
    @override
    def restart(self) -> None:
        self._optimizer = Adam(self.parameters(), lr = self.__lr)
    @override
    def move(self, device: Device) -> Self:
        return self.to(device)
    @override
    def train(self, mode: bool = True) -> Self:
        return super().train(mode)
    @property
    @override
    def rate(self) -> float:
        return self._optimizer.param_groups[0]['lr']

# Build the hand-written discriminator with three hidden layers
discriminator = IrisDiscriminator(conditions, dim, steps = [10, 10, 5], lr = 0.01)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

# Score the batch, then compute the loss (which calls `predict` again internally)
predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.5201],
        [0.5213],
        [0.5172],
        [0.5458],
        [0.5041],
        [0.5026],
        [0.4996],
        [0.4966]], grad_fn=<SigmoidBackward0>)
tensor([[0.7095]], grad_fn=<DivBackward0>)


In [3]:
# Use a `BasicDiscriminator` to implement `IrisDiscriminator` with the same functionality

from modugant.discriminators import BasicDiscriminator

# assign a ._model and ._optimizer, and implement loss() and restart()
# our previous implementations of the other methods are identical in `BasicDiscriminator`

class IrisDiscriminator[C: int, D: int](BasicDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        lr: float = 0.001
    ) -> None:
        super().__init__(conditions, outputs)
        self.__lr = lr
        self._model = Sequential(
            *[
                Sequential(
                    Linear(
                        steps[i - 1] if i else outputs + conditions,
                        steps[i]
                    ),
                    ReLU(),
                    Dropout(0.2)
                )
                for i in range(len(steps))
            ],
            Linear(
                steps[-1] if len(steps) else outputs + conditions,
                1
            ),
            Sigmoid()
        )
        self._optimizer = Adam(self.parameters(), lr = lr)
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        predicted = self.predict(condition, data)
        loss = - (target.t() @ predicted.log() + (1 - target.t()) @ (1 - predicted).log()) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def restart(self) -> None:
        self._optimizer = Adam(self.parameters(), lr = self.__lr)

# Same configuration as before, now backed by `BasicDiscriminator`
discriminator = IrisDiscriminator(conditions, dim, steps = [10, 10, 5], lr = 0.01)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.5127],
        [0.4763],
        [0.5033],
        [0.5000],
        [0.5102],
        [0.4621],
        [0.5168],
        [0.4551]], grad_fn=<SigmoidBackward0>)
tensor([[0.6818]], grad_fn=<DivBackward0>)


In [4]:
# Use an `ExtendedDiscriminator` to override any method of an existing discriminator class

from torch.optim.lr_scheduler import StepLR

from modugant.discriminators.extended import ExtendedDiscriminator

# Create a container class that adds a scheduler to a discriminator
# Note that this could have been done directly in the `IrisDiscriminator` class
# but wrapping lets us build classes with and without the scheduler that are otherwise identical

class ScheduledDiscriminator[C: int, D: int](ExtendedDiscriminator[C, D]):
    ## Narrow the expected type to the BasicDiscriminator protocol
    ## This ensures that there is a .optimizer for us to schedule
    _discriminator: BasicDiscriminator[C, D]
    def __init__(
        self,
        discriminator: BasicDiscriminator[C, D], ## Narrow the expected type to the BasicDiscriminator protocol
        step: int,
        gamma: float
    ) -> None:
        super().__init__(discriminator)
        self.__step = step
        self.__gamma = gamma
        self.__scheduler = StepLR(self._discriminator.optimizer, step_size = step, gamma = gamma)
    @override
    def step[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        loss = super().step(condition, data, target)
        self.__scheduler.step()
        return loss
    @override
    def restart(self) -> None:
        super().restart()
        self.__scheduler = StepLR(self._discriminator.optimizer, step_size = self.__step, gamma = self.__gamma)

# Wrap a BasicDiscriminator-based IrisDiscriminator with a StepLR schedule
discriminator = ScheduledDiscriminator(
    IrisDiscriminator(conditions, dim, steps = [10, 10, 5], lr = 0.01),
    step = 100,
    gamma = 0.5
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.5742],
        [0.5887],
        [0.5981],
        [0.5853],
        [0.5882],
        [0.5957],
        [0.5885],
        [0.6146]], grad_fn=<SigmoidBackward0>)
tensor([[0.6514]], grad_fn=<DivBackward0>)


In [5]:
## Use a `StandardDiscriminator` to ease model creation

from torch.nn import LeakyReLU

from modugant.discriminators.standard import StandardDiscriminator


class ReLUDiscriminator[C: int, D: int](StandardDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        dropout: float = 0.2,
        lr: float = 0.001
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, _: Sequential(
                Linear(inputs, outputs),
                ReLU(),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1), Sigmoid()),
        )
        self.__lr = lr
        self._optimizer = Adam(self.parameters(), lr = lr)
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        predicted = self.predict(condition, data)
        loss = - (target.t() @ predicted.log() + (1 - target.t()) @ (1 - predicted).log()) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def restart(self) -> None:
        self._optimizer = Adam(self.parameters(), lr = self.__lr)


class LeakyDiscriminator[C: int, D: int](StandardDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        slope: float = 0.01,
        dropout: float = 0.2,
        lr: float = 0.001,
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, _: Sequential(
                Linear(inputs, outputs),
                LeakyReLU(slope),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1), Sigmoid()),
        )
        self.__lr = lr
        self._optimizer = Adam(self.parameters(), lr = lr)
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        predicted = self.predict(condition, data)
        loss = - (target.t() @ predicted.log() + (1 - target.t()) @ (1 - predicted).log()) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def restart(self) -> None:
        self._optimizer = Adam(self.parameters(), lr = self.__lr)

# Leaky-ReLU discriminator with a steeper negative slope and heavier dropout than its defaults
discriminator = LeakyDiscriminator(
    conditions = conditions,
    outputs = dim,
    steps = [10, 10, 5],
    slope = 0.1,
    dropout = 0.5,
    lr = 0.001
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.5282],
        [0.5740],
        [0.6081],
        [0.6148],
        [0.5783],
        [0.5783],
        [0.6172],
        [0.5294]], grad_fn=<SigmoidBackward0>)
tensor([[0.5523]], grad_fn=<DivBackward0>)


In [6]:
## Include some reused methods and properties into a new class to inherit from

from typing import Callable


class EntropyDiscriminator[C: int, D: int](StandardDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        layer: Callable[[int, int, int], Module],
        finish: Callable[[int], Module],
        lr: float = 0.001
    ) -> None:
        super().__init__(conditions, outputs, steps, layer, finish)
        self.__lr = lr
        self._optimizer = Adam(self.parameters(), lr = lr)
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        predicted = self.predict(condition, data)
        loss = - (target.t() @ predicted.log() + (1 - target.t()) @ (1 - predicted).log()) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def restart(self) -> None:
        self._optimizer = Adam(self.parameters(), lr = self.__lr)

# This allows for simpler extensions
class ReLUDiscriminator[C: int, D: int](EntropyDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        dropout: float = 0.2,
        lr: float = 0.001
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, _: Sequential(
                Linear(inputs, outputs),
                ReLU(),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1), Sigmoid()),
            lr = lr
        )

class LeakyDiscriminator[C: int, D: int](EntropyDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        slope: float = 0.01,
        dropout: float = 0.2,
        lr: float = 0.001,
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, _: Sequential(
                Linear(inputs, outputs),
                LeakyReLU(slope),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1), Sigmoid()),
            lr = lr
        )

# Same leaky configuration, now inheriting loss/restart from EntropyDiscriminator
discriminator = LeakyDiscriminator(
    conditions = conditions,
    outputs = dim,
    steps = [10, 10, 5],
    slope = 0.1,
    dropout = 0.5,
    lr = 0.001
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.5967],
        [0.5580],
        [0.5660],
        [0.5840],
        [0.5643],
        [0.5643],
        [0.5647],
        [0.6286]], grad_fn=<SigmoidBackward0>)
tensor([[0.6391]], grad_fn=<DivBackward0>)


In [7]:
# Create a pivoting discriminator that pivots long to wide before modeling
# replicate predictions for each group to match one-to-one with the data

class PivotDiscriminator[C: int, D: int](EntropyDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        group: int,
        steps: list[int],
        dropout: float = 0.2,
        lr: float = 0.001
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, layer: Sequential(
                Linear(
                    inputs if layer else group * inputs, # initial layer takes wide data
                    outputs
                ),
                LeakyReLU(0.01),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1), Sigmoid()),
            lr = lr
        )
        self.__group = group
    def _pivot(self, data: Tensor) -> Tensor:
        assert data.shape[0] % self.__group == 0, "Data must be divisible by the group size"
        return data.view(-1, self.__group * (self._conditions + self._outputs))
    @override
    def predict[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Matrix[N, One]:
        folded = self._pivot(t_cat([condition, data], dim = 1))
        predicted = self.forward(folded)
        replicated = predicted.expand(-1, self.__group).reshape(-1, 1)
        return Matrix.load(replicated, (condition.shape[0], Dim.one()))

# Fold the 8-row batch into 2 groups of 4 rows each
discriminator = PivotDiscriminator(
    conditions = conditions,
    outputs = dim,
    group = 4,
    steps = [10, 10, 5]
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.6266],
        [0.6266],
        [0.6266],
        [0.6266],
        [0.6442],
        [0.6442],
        [0.6442],
        [0.6442]], grad_fn=<UnsafeViewBackward0>)
tensor([[0.6709]], grad_fn=<DivBackward0>)


In [8]:
## Create a Smoothness-Regularized Discriminator

from torch.autograd import grad

from modugant.matrix.ops import cat as m_cat
from modugant.matrix.ops import one_hot, ones


class SmoothDiscriminator[C: int, D: int](ExtendedDiscriminator[C, D]):
    _discriminator: BasicDiscriminator[C, D]
    '''
    A discriminator that includes a smoothness penalty on the norm of the gradients
    in the interpolated space between example cases. This ensures that the boundary
    between real and fake data is smooth and regular.

    Args:
        discriminator: The discriminator to regularize
        factor: The factor to multiply the smoothness penalty's contribution to loss

    '''
    def __init__(
        self,
        discriminator: BasicDiscriminator[C, D],
        factor: float
    ) -> None:
        super().__init__(discriminator)
        self.__factor = factor
    def _blend[N: int](self, data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[N, N]:
        # Generate interpolation points between the true and false data
        # sample probability of assigning each original data row to each blend row
        sample = rand((data.shape[0], data.shape[0]))
        # sample percentages of the true cases in the blend per row
        alpha = rand((data.shape[0], Dim.one()))
        # sample from the true cases, use arg-max of probability to select a true into each blend row
        # use one_hot to convert the arg-max index back into a one_hot representing rows
        # `trues` is a [size, data.shape[0]] matrix mapping with 0/1 the sampled true data into the blend
        # the true cases can be selected into the blend with `trues @ data`
        trues = one_hot(
            (sample * target.T).argmax(dim = 1, keepdim = True),
            num_classes = target.shape[0]
        )
        # repeat with false cases
        falses = one_hot(
            (sample * (1 - target).T).argmax(dim = 1, keepdim = True),
            num_classes = target.shape[0]
        )
        # use alpha/(1 - alpha) to weight the true and false cases into the blend
        # the resulting matrix can be used as a sampler with `blend @ data`
        return (alpha * trues) + ((1 - alpha) * falses)
    def penalty[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        # create an interpolated data set
        blend = self._blend(data, target)
        ## We have to make the assumption that the discriminator is still making this transformation in its
        ## predict method before calling .forward. See the next example below for a cleanup of this assumption.
        inputs = m_cat(
            (blend @ condition, blend @ data),
            dim = 1,
            shape = (condition.shape[0], self._conditions + self._outputs)
        )
        inputs.requires_grad = True
        # feed forward the blend data
        outputs = self._discriminator.forward(inputs)
        # get the gradients of the output with respect to the blend data
        gradient = grad(
            outputs = outputs,
            inputs = inputs,
            grad_outputs = ones((outputs.shape[0], outputs.shape[1])),
            create_graph = True,
            retain_graph = True
        )[0]
        # calculate the norm of the gradients per row
        norm = cast(Tensor, gradient.norm(dim = 1, keepdim = True))
        # penalize the distance of the norm from 1
        penalty = self.__factor * ((norm - 1) ** 2).mean(dim = None, keepdim = True)
        return Matrix.load(penalty, shape = (Dim.one(), Dim.one()))
    @override
    def step[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        self._discriminator.zero_grad()
        loss = self._discriminator.loss(condition, data, target)
        _ = loss.backward()
        penalty = self.penalty(condition, data, target)
        _ = penalty.backward(retain_graph = True)
        self._discriminator.optimizer.step()
        return loss

# Regularize a leaky discriminator with the smoothness penalty at weight 1.0
discriminator = SmoothDiscriminator(
    LeakyDiscriminator(
        conditions = conditions,
        outputs = dim,
        steps = [10, 10, 5],
        slope = 0.1,
        dropout = 0.5,
        lr = 0.001
    ),
    factor = 1.0
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

# `step` back-propagates both loss and penalty; here they are only computed and printed
predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)
penalty = discriminator.penalty(condition, data, target)

print(predicted, loss, penalty, sep = '\n')

tensor([[0.4559],
        [0.4409],
        [0.4469],
        [0.4549],
        [0.4519],
        [0.4556],
        [0.4596],
        [0.4600]], grad_fn=<SigmoidBackward0>)
tensor([[0.6706]], grad_fn=<DivBackward0>)
tensor([[0.9437]], grad_fn=<MulBackward0>)


In [9]:
# Add the penalty to our `PivotDiscriminator` (the library version is called `FoldedDiscriminator`)
# To penalize the gradient along the folded dimension, we need to add a method to the interface

# To reach the reshaping done by the PivotDiscriminator, we narrow the wrapped
# discriminator's type to one that exposes a `reshape` method.
# This also removes the assumption that the discriminator simply concatenates the condition and data:
# for non-pivoting cases, `reshape` can be that concatenation or whatever the implementation needs
class ReshapeDiscriminator[C: int, D: int](BasicDiscriminator[C, D]):
    '''Discriminator protocol class which can be smoothness regularized.'''

    def __init__(self, conditions: C, outputs: D) -> None:
        super().__init__(conditions, outputs)

    def reshape[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Tensor:
        '''
        Reshape the incoming condition data to fit the discriminator's first layer.

        Args:
            condition (torch.Tensor): The conditional data.
            data (torch.Tensor): The data to reshape.

        Returns:
            torch.Tensor: The reshaped data.

        '''
        ...

# rename and re-parameterize our previously protected _pivot as a public reshape method, matching the `ReshapeDiscriminator` interface above
class PivotDiscriminator[C: int, D: int](EntropyDiscriminator[C, D], ReshapeDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        group: int,
        steps: list[int],
        dropout: float = 0.2,
        lr: float = 0.001
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, layer: Sequential(
                Linear(
                    inputs if layer else group * inputs,
                    outputs
                ),
                LeakyReLU(0.01),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1), Sigmoid()),
            lr = lr
        )
        self.__group = group
    @override
    def reshape[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Tensor:
        assert data.shape[0] % self.__group == 0, "Data must be divisible by the group size"
        joined = t_cat([condition, data], dim = 1) # join the condition and data
        return joined.view(-1, self.__group * (self._conditions + self._outputs))
    @override
    def predict[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Matrix[N, One]:
        folded = self.reshape(condition, data) # call the new reshape method
        predicted = self.forward(folded)
        replicated = predicted.expand(-1, self.__group).reshape(-1, 1)
        return Matrix.load(replicated, (condition.shape[0], Dim.one()))

# expect the underlying discriminator to be a ReshapeDiscriminator
class SmoothDiscriminator[C: int, D: int](ExtendedDiscriminator[C, D]):
    _discriminator: ReshapeDiscriminator[C, D]
    '''
    A discriminator that includes a smoothness penalty on the norm of the gradients
    in the interpolated space between example cases. This ensures that the boundary
    between real and fake data is smooth and regular.

    Args:
        discriminator: The discriminator to regularize
        factor: The factor to multiply the smoothness penalty's contribution to loss

    '''
    def __init__(
        self,
        discriminator: ReshapeDiscriminator[C, D],
        factor: float
    ) -> None:
        super().__init__(discriminator)
        self.__factor = factor
    def _blend[N: int](self, data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[N, N]:
        sample = rand((data.shape[0], data.shape[0]))
        alpha = rand((data.shape[0], Dim.one()))
        trues = one_hot(
            (sample * target.T).argmax(dim = 1, keepdim = True),
            num_classes = target.shape[0]
        )
        falses = one_hot(
            (sample * (1 - target).T).argmax(dim = 1, keepdim = True),
            num_classes = target.shape[0]
        )
        return (alpha * trues) + ((1 - alpha) * falses)
    def penalty[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        blend = self._blend(data, target)
        inputs = self._discriminator.reshape(
            blend @ condition,
            blend @ data
        )
        inputs.requires_grad = True
        outputs = self._discriminator.forward(inputs)
        gradient = grad(
            outputs = outputs,
            inputs = inputs,
            grad_outputs = ones((outputs.shape[0], outputs.shape[1])),
            create_graph = True,
            retain_graph = True
        )[0]
        # pivot the gradient to the same dimensionality as the data would be
        norm = cast(Tensor, gradient.norm(dim = 1, keepdim = True))
        penalty = self.__factor * ((norm - 1) ** 2).mean(dim = None, keepdim = True)
        return Matrix.load(penalty, shape = (Dim.one(), Dim.one()))
    @override
    def step[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        self._discriminator.zero_grad()
        loss = self._discriminator.loss(condition, data, target)
        _ = loss.backward()
        penalty = self.penalty(condition, data, target)
        _ = penalty.backward(retain_graph = True)
        self._discriminator.optimizer.step()
        return loss

# Smoothness-regularize a pivoting discriminator (groups of 4 rows)
discriminator = SmoothDiscriminator(
    PivotDiscriminator(
        conditions = conditions,
        outputs = dim,
        group = 4,
        steps = [10, 10, 5]
    ),
    factor = 1.0
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)
penalty = discriminator.penalty(condition, data, target)

print(predicted, loss, penalty, sep = '\n')

tensor([[0.5525],
        [0.5525],
        [0.5525],
        [0.5525],
        [0.5482],
        [0.5482],
        [0.5482],
        [0.5482]], grad_fn=<UnsafeViewBackward0>)
tensor([[0.7508]], grad_fn=<DivBackward0>)
tensor([[0.9783]], grad_fn=<MulBackward0>)


In [10]:
# Remove the final Sigmoid, and the log from the loss calculation

# remove the inheritance from the EntropyDiscriminator
class PivotDiscriminator[C: int, D: int](StandardDiscriminator[C, D], ReshapeDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        group: int,
        steps: list[int],
        dropout: float = 0.2,
        lr: float = 0.001
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda inputs, outputs, layer: Sequential(
                Linear(
                    inputs if layer else group * inputs,
                    outputs
                ),
                LeakyReLU(0.01),
                Dropout(dropout)
            ),
            finish = lambda inputs: Sequential(Linear(inputs, 1)), # remove the Sigmoid layer
        )
        self.__group = group
        self.__lr = lr
        self._optimizer = Adam(self.parameters(), lr = lr)
    @override
    def reshape[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Tensor:
        assert data.shape[0] % self.__group == 0, "Data must be divisible by the group size"
        joined = t_cat([condition, data], dim = 1)
        return joined.view(-1, self.__group * (self._conditions + self._outputs))
    @override
    def predict[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Matrix[N, One]:
        folded = self.reshape(condition, data)
        predicted = self.forward(folded)
        replicated = predicted.expand(-1, self.__group).reshape(-1, 1)
        return Matrix.load(replicated, (condition.shape[0], Dim.one()))
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        predicted = self.predict(condition, data)
        # remove the log from the loss calculation
        loss = - (target.t() @ predicted + (1 - target.t()) @ (1 - predicted)) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def restart(self) -> None:
        self._optimizer = Adam(self.parameters(), lr = self.__lr)

# Smoothness-regularize the sigmoid-free pivoting discriminator
discriminator = SmoothDiscriminator(
    PivotDiscriminator(
        conditions = conditions,
        outputs = dim,
        group = 4,
        steps = [10, 10, 5]
    ),
    factor = 1.0
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)
penalty = discriminator.penalty(condition, data, target)

print(predicted, loss, penalty, sep = '\n')

tensor([[0.1352],
        [0.1352],
        [0.1352],
        [0.1352],
        [0.0689],
        [0.0689],
        [0.0689],
        [0.0689]], grad_fn=<UnsafeViewBackward0>)
tensor([[-0.4749]], grad_fn=<DivBackward0>)
tensor([[0.9480]], grad_fn=<MulBackward0>)


In [12]:
# We have recreated the library-included `FoldedDiscriminator` and `SmoothedDiscriminator`

from modugant.discriminators import FoldedDiscriminator, SmoothedDiscriminator

# Library equivalents of the classes built above: a folded (pivoting) discriminator
# wrapped in the smoothness regularizer
discriminator = SmoothedDiscriminator(
    FoldedDiscriminator(conditions, dim, group = 4, steps = [10, 10, 5]),
    factor = 1.0
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)
penalty = discriminator.penalty(condition, data, target)

print(predicted, loss, penalty, sep = '\n')

tensor([[-0.2857],
        [-0.2857],
        [-0.2857],
        [-0.2857],
        [-0.4035],
        [-0.4035],
        [-0.4035],
        [-0.4035]], grad_fn=<UnsafeViewBackward0>)
tensor([[0.]], grad_fn=<DivBackward0>)
tensor([[0.6550]], grad_fn=<MulBackward0>)


In [13]:
# Recreate SphereDiscriminators

# Create a Module where input weights are maintained on a unit hyper-sphere

from torch import no_grad


class SphLinear(Linear):
    @staticmethod
    def project(network: Linear) -> None:
        # project the weights of the network onto the unit sphere
        with no_grad():
            norm = cast(Tensor, network.weight.norm(2, dim = 1, keepdim = True))
            _ = network.weight.div_(norm)
    def __init__(self, inputs: int, outputs: int):
        super().__init__(inputs, outputs)
        SphLinear.project(self)
    def reproject(self) -> None:
        SphLinear.project(self)
    @override
    def reset_parameters(self) -> None:
        super().reset_parameters()
        SphLinear.project(self)

In [14]:
# Use this type of layer in a discriminator

from typing import Tuple

from modugant.discriminators import ReshapingDiscriminator
from modugant.matrix import Matrix
from modugant.matrix.dim import One


## Make this a `ReshapingDiscriminator` so that it can be wrapped by the smoothness regularizer
## The reshape is just a concatenation (identity-like)
class SphereDiscriminator[C: int, D: int](StandardDiscriminator[C, D], ReshapingDiscriminator[C, D]):
    def __init__(
        self,
        conditions: C,
        outputs: D,
        steps: list[int],
        dropout: float = 0.2,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.5, 0.9),
        decay: float = 0.1,
    ) -> None:
        super().__init__(
            conditions,
            outputs,
            steps,
            layer = lambda ins, outs, _: Sequential(
                SphLinear(ins, outs), # use the SphLinear layer
                ReLU(),
                Dropout(dropout)
            ),
            finish = lambda ins: Sequential(SphLinear(ins, 1), Sigmoid())
        )
        self.__lr = lr
        self.__decay = decay
        self.__betas = betas
        self._optimizer = Adam(
            self.parameters(),
            lr = lr,
            betas = betas,
            weight_decay = decay,
        )
    @override
    def reshape[N: int](self, condition: Matrix[N, C], data: Matrix[N, D]) -> Tensor:
        # no transformation, just a concatenation
        joined = t_cat([condition, data], dim = 1)
        return joined
    @override
    def unshape[N: int](self, data: Tensor, n: N) -> Matrix[N, One]:
        # identity
        return Matrix.load(data, (n, Dim.one()))
    @override
    def loss[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        # binary cross entropy loss
        predicted = self.predict(condition, data)
        loss = - (target.t() @ predicted.log() + (1 - target.t()) @ (1 - predicted).log()) / len(target)
        return Matrix.load(loss, (Dim.one(), Dim.one()))
    @override
    def step[N: int](self, condition: Matrix[N, C], data: Matrix[N, D], target: Matrix[N, One]) -> Matrix[One, One]:
        loss = super().step(condition, data, target)
        # project all linear weight vectors back onto unit sphere after updates
        for module in self.modules():
            if isinstance(module, SphLinear):
                module.reproject()
        return loss
    @override
    def restart(self) -> None:
        self._optimizer = Adam(
            self.parameters(),
            lr = self.__lr,
            betas = self.__betas,
            weight_decay = self.__decay,
        )

# Sphere-projected discriminator with explicit Adam settings
discriminator = SphereDiscriminator(
    conditions = conditions,
    outputs = dim,
    steps = [10, 10, 5],
    dropout = 0.2,
    lr = 0.001,
    betas = (0.5, 0.9),
    decay = 0.1
)

# Demo batch: zero conditions, `randn` data, rounded `rand` targets
condition = zeros((batch, conditions))
data = randn((batch, dim))
target = rand((batch, Dim.one())).round()

predicted = discriminator.predict(condition, data)
loss = discriminator.loss(condition, data, target)

print(predicted, loss, sep = '\n')

tensor([[0.5572],
        [0.5528],
        [0.5538],
        [0.6612],
        [0.6343],
        [0.5630],
        [0.5638],
        [0.5526]], grad_fn=<SigmoidBackward0>)
tensor([[0.6521]], grad_fn=<DivBackward0>)
