In [7]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.nn.functional import binary_cross_entropy_with_logits as bce_loss

import cooper
import tutorial_utils
import style_utils

# Seed both the PyTorch and NumPy global RNGs so the generated data and the
# initial weights of the linear predictor are reproducible across runs.
torch.manual_seed(seed=0)
np.random.seed(seed=0)

# Mixture of Gaussians

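This example is inspired by Fig. 2 in Cotter et al. (2019). A linear classifier separates two Gaussian classes, optionally subject to the constraint that at least 55% of the points be predicted as class 0.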

In [None]:
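# Generate dataset
# TODO: the data-generation code for this cell is missing; it should build the
# `inputs` and `targets` tensors later passed to `MixtureSeparation.closure`.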

In [None]:
class MixtureSeparation(cooper.ConstrainedMinimizationProblem):
    """Linear binary classifier for a two-class mixture, optionally rate-constrained.

    The constrained variant requires that at least ``min_prop_0`` of the inputs
    be predicted as class 0 (blue), expressed as the inequality constraint
    ``min_prop_0 - prop_0 <= 0``.

    Args:
        is_constrained: If False, only the binary cross-entropy loss is
            minimized. If True, the class-0 rate constraint is added.
        use_proxy: Only relevant when ``is_constrained`` is True. If False, the
            differentiable hinge relaxation itself is used as the constraint.
            If True, the true (non-differentiable) rate defect is reported as
            ``ineq_defect`` and the hinge relaxation as ``proxy_ineq_defect``.
        min_prop_0: Minimum required proportion of class-0 predictions.
            Defaults to 0.55.
    """

    def __init__(self, is_constrained, use_proxy=False, min_prop_0=0.55):

        super().__init__(is_constrained=is_constrained)

        self.use_proxy = use_proxy
        self.min_prop_0 = min_prop_0

        # Linear predictor: 2 input features -> a single logit.
        self.linear = torch.nn.Linear(2, 1)

    def closure(self, inputs, targets):
        """Evaluate the loss and, if constrained, the constraint defect(s).

        Args:
            inputs: Features fed to ``torch.nn.Linear(2, 1)``; presumably of
                shape ``(N, 2)`` -- TODO confirm against the data-generation cell.
            targets: Float 0/1 labels with ``N`` elements, shaped ``(N,)`` or
                ``(N, 1)``.

        Returns:
            A ``cooper.CMPState`` holding the BCE ``loss`` and, when
            constrained, the inequality defect(s).
        """
        # Linear(2, 1) returns (N, 1); match the targets' shape because
        # binary_cross_entropy_with_logits rejects mismatched input/target sizes.
        logits = self.linear(inputs).reshape_as(targets)
        # NOTE: no trailing comma here -- that would turn `loss` into a tuple.
        loss = bce_loss(logits, targets)

        if not self.is_constrained:
            # Unconstrained problem of separating two classes
            return cooper.CMPState(loss=loss)

        # Separating classes s.t. predicting at least `min_prop_0` as class 0 (blue).

        # Hinge relaxation of the class-0 rate. For every logit z,
        # 1[z < 0] >= 1 - relu(1 + z), so the mean of the right-hand side is a
        # differentiable lower bound on the fraction predicted as class 0.
        # Consequently hinge_defect <= 0 implies the true constraint holds.
        hinge_rate_0 = 1.0 - torch.relu(1.0 + logits).mean()
        hinge_defect = self.min_prop_0 - hinge_rate_0

        if not self.use_proxy:
            # The hinge relaxation itself serves as the constraint.
            return cooper.CMPState(loss=loss, ineq_defect=hinge_defect)

        # Use the non-proxy constraint defect to update the Lagrange
        # multipliers, and the differentiable hinge as the proxy defect.
        # The proportion of elements predicted as class 0 (probability rounds
        # to 0) is the true, non-differentiable rate.
        classes = torch.round(torch.sigmoid(logits))
        prop_0 = (classes == 0).float().mean()

        return cooper.CMPState(
            loss=loss,
            ineq_defect=self.min_prop_0 - prop_0,  # min_prop_0 - prop_0 <= 0
            proxy_ineq_defect=hinge_defect,
        )


## References

- A. Cotter, H. Jiang, M. Gupta, S. Wang, T. Narayan, S. You, and K. Sridharan. Optimization with Non-Differentiable Constraints with Applications to Fairness, Recall, Churn, and Other Goals. *Journal of Machine Learning Research (JMLR)*, 2019.