# Example 3: HKR classifier on MNIST dataset

This notebook demonstrates how to train a 1-Lipschitz binary classifier with the HKR (hinge + Kantorovich-Rubinstein) loss on the MNIST0-8 dataset, i.e. MNIST restricted to the digits 0 and 8.

## 1. Data preparation

For this task we select two classes, 0 and 8. Labels are changed to {-1,1} (digit 0 becomes 1 and digit 8 becomes -1), which is the encoding expected by the hinge term of the loss.
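With labels $y \in \{-1, 1\}$, a sample is correctly classified with margin $m$ when $y\,f(x) \ge m$, so a hinge term can be written as the mean of $\max(0, m - y\,f(x))$. The NumPy sketch below illustrates this standard formulation; it is not taken from `torchlip`, whose `hinge_margin_loss` may differ in details, and the helper name `hinge_term` and the toy outputs are hypothetical.

```python
import numpy as np


def hinge_term(outputs, labels, min_margin=1.0):
    """Standard hinge term for labels in {-1, 1} (hypothetical helper).

    A sample contributes 0 when labels * outputs >= min_margin, i.e. when it
    lies on the correct side of the decision boundary with enough margin.
    """
    outputs = np.asarray(outputs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return np.mean(np.maximum(0.0, min_margin - labels * outputs))


# Digit 0 -> label 1, digit 8 -> label -1
labels = np.array([1.0, 1.0, -1.0, -1.0])
outputs = np.array([2.0, 0.5, -3.0, 0.2])  # made-up network outputs
print(hinge_term(outputs, labels))
```

Only the second sample (correct side, margin too small) and the fourth sample (wrong side) contribute to the loss.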


In [1]:
import torch
from torchvision import datasets

# First we select the two classes
selected_classes = [0, 8]  # must be two classes as we perform binary classification


def prepare_data(dataset, class_a=0, class_b=8):
    """
    Convert the MNIST data to our binary classification setup: keep only
    class_a and class_b, scale pixels to [0, 1] and relabel to {1, -1}.
    """
    x = dataset.data
    y = dataset.targets
    # mask to select only items from class_a or class_b
    mask = (y == class_a) | (y == class_b)
    x = x[mask]
    y = y[mask]

    # convert from range int[0,255] to float32[0,1]
    x = x.float() / 255
    # the model flattens its input, so the channel position does not matter here
    x = x.reshape((-1, 28, 28, 1))

    # change label to binary classification {-1,1}: class_a -> 1, class_b -> -1
    y_ = torch.zeros_like(y).float()
    y_[y == class_a] = 1.0
    y_[y == class_b] = -1.0
    return torch.utils.data.TensorDataset(x, y_)


train = datasets.MNIST("./data", train=True, download=True)
test = datasets.MNIST("./data", train=False, download=True)

# Prepare the data
train = prepare_data(train, selected_classes[0], selected_classes[1])
test = prepare_data(test, selected_classes[0], selected_classes[1])

# Display information about the datasets: size and percentage of samples labeled 1 (digit 0)
print(
    f"Train set size: {len(train)} samples, classes proportions: "
    f"{100 * (train.tensors[1] == 1).numpy().mean():.2f} %"
)
print(
    f"Test set size: {len(test)} samples, classes proportions: "
    f"{100 * (test.tensors[1] == 1).numpy().mean():.2f} %"
)



Train set size: 11774 samples, classes proportions: 50.31 %
Test set size: 1954 samples, classes proportions: 50.15 %


## 2. Build Lipschitz model

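Here, the experiments use a model made only of fully-connected layers: three `SpectralLinear` layers with `FullSort` activations, followed by a `FrobeniusLinear` output layer. However, `torchlip` also provides state-of-the-art 1-Lipschitz convolutional layers. The warning printed about `Flatten` is harmless: flattening only reshapes the input, so it does not change Euclidean distances and the model stays 1-Lipschitz.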


In [2]:
import torch
from deel import torchlip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ninputs = 28 * 28
wass = torchlip.Sequential(
    torch.nn.Flatten(),
    torchlip.SpectralLinear(ninputs, 128),
    torchlip.FullSort(),
    torchlip.SpectralLinear(128, 64),
    torchlip.FullSort(),
    torchlip.SpectralLinear(64, 32),
    torchlip.FullSort(),
    torchlip.FrobeniusLinear(32, 1),
).to(device)

wass


Sequential model contains a layer which is not a Lipschitz layer: Flatten(start_dim=1, end_dim=-1)


Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): SpectralLinear(in_features=784, out_features=128, bias=True)
  (2): FullSort()
  (3): SpectralLinear(in_features=128, out_features=64, bias=True)
  (4): FullSort()
  (5): SpectralLinear(in_features=64, out_features=32, bias=True)
  (6): FullSort()
  (7): FrobeniusLinear(in_features=32, out_features=1, bias=True)
)
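`FullSort` sorts the whole feature vector of each sample. Sorting only permutes coordinates, and sorting two vectors never moves them farther apart, so this activation is 1-Lipschitz; its Jacobian is a permutation matrix, so unlike ReLU it also preserves the norm of the gradient. The NumPy sketch below checks both properties numerically; `full_sort` is a hypothetical stand-in, and sorting in ascending order is an assumption about `torchlip.FullSort` (the direction does not affect these properties).

```python
import numpy as np


def full_sort(x):
    """Hypothetical stand-in for FullSort: sort the features of each sample
    (last axis) in ascending order."""
    return np.sort(x, axis=-1)


rng = np.random.default_rng(0)
a = rng.normal(size=(1000, 128))  # 1000 random 128-dim "activations"
b = rng.normal(size=(1000, 128))

# Per-sample ratio ||sort(a) - sort(b)|| / ||a - b||
ratios = np.linalg.norm(full_sort(a) - full_sort(b), axis=1) / np.linalg.norm(
    a - b, axis=1
)
print(ratios.max())  # never exceeds 1 (up to rounding)

# Sorting also preserves the norm of each vector: it only permutes coordinates
print(np.allclose(np.linalg.norm(full_sort(a), axis=1), np.linalg.norm(a, axis=1)))
```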

## 3. Learn classification on MNIST

In [3]:
from deel.torchlip.functional import kr_loss, hkr_loss, hinge_margin_loss
from tqdm import tqdm

# training parameters
epochs = 10
batch_size = 128

# loss parameters
min_margin = 1
alpha = 10

optimizer = torch.optim.Adam(lr=0.001, params=wass.parameters())

train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=32, shuffle=False)

for epoch in range(epochs):

    print(f"Epoch {epoch + 1}/{epochs}")

    m_kr, m_hm, m_acc = 0, 0, 0
    wass.train()

    with tqdm(total=len(train_loader)) as tsteps:
        for step, (data, target) in enumerate(train_loader):
            tsteps.update()

            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = wass(data)
            loss = hkr_loss(output, target, alpha=alpha, min_margin=min_margin)
            loss.backward()
            optimizer.step()

            # Compute metrics on batch (.item() avoids keeping the autograd graph across steps)
            with torch.no_grad():
                m_kr += kr_loss(output, target, (1, -1)).item()
                m_hm += hinge_margin_loss(output, target, min_margin).item()
                m_acc += (
                    (torch.sign(output).flatten() == torch.sign(target))
                    .float()
                    .mean()
                    .item()
                )

            # Print metrics of current batch
            postfix = {
                k: "{:.04f}".format(v)
                for k, v in {
                    "loss": loss,
                    "kr": m_kr / (step + 1),
                    "hinge": m_hm / (step + 1),
                    "acc": m_acc / (step + 1),
                }.items()
            }
            tsteps.set_postfix(postfix)

        # Compute test loss for the current epoch
        wass.eval()
        testo = []
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                testo.append(wass(data).cpu())
        testo = torch.cat(testo).flatten()

        # Print metrics for the current epoch (train and validation metrics)
        postfix.update(
            {
                f"val_{k}": "{:.04f}".format(v)
                for k, v in {
                    "loss": hkr_loss(
                        testo, test.tensors[1], alpha=alpha, min_margin=min_margin
                    ),
                    "kr": kr_loss(testo.flatten(), test.tensors[1], (1, -1)),
                    "hinge": hinge_margin_loss(
                        testo.flatten(), test.tensors[1], min_margin
                    ),
                    "acc": (torch.sign(testo).flatten() == torch.sign(test.tensors[1]))
                    .float()
                    .mean(),
                }.items()
            }
        )
        tsteps.set_postfix(postfix)


Epoch 1/10


100%|██████████| 92/92 [00:01<00:00, 60.63it/s, loss=-3.0010, kr=1.9975, hinge=0.2509, acc=0.9332, val_loss=-3.0327, val_kr=3.2924, val_hinge=0.0260, val_acc=0.9928]


Epoch 2/10


100%|██████████| 92/92 [00:01<00:00, 60.22it/s, loss=-4.4525, kr=4.1030, hinge=0.0254, acc=0.9938, val_loss=-4.8342, val_kr=5.0409, val_hinge=0.0207, val_acc=0.9923]


Epoch 3/10


100%|██████████| 92/92 [00:01<00:00, 60.08it/s, loss=-5.6243, kr=5.6371, hinge=0.0241, acc=0.9926, val_loss=-5.9461, val_kr=6.1530, val_hinge=0.0207, val_acc=0.9918]


Epoch 4/10


100%|██████████| 92/92 [00:01<00:00, 60.13it/s, loss=-6.3080, kr=6.3542, hinge=0.0227, acc=0.9923, val_loss=-6.3689, val_kr=6.5977, val_hinge=0.0229, val_acc=0.9918]


Epoch 5/10


100%|██████████| 92/92 [00:01<00:00, 60.48it/s, loss=-6.4896, kr=6.6610, hinge=0.0216, acc=0.9930, val_loss=-6.5609, val_kr=6.7654, val_hinge=0.0204, val_acc=0.9923]


Epoch 6/10


100%|██████████| 92/92 [00:01<00:00, 60.28it/s, loss=-6.5060, kr=6.8111, hinge=0.0202, acc=0.9930, val_loss=-6.6961, val_kr=6.9176, val_hinge=0.0221, val_acc=0.9918]


Epoch 7/10


100%|██████████| 92/92 [00:01<00:00, 60.94it/s, loss=-6.7015, kr=6.9254, hinge=0.0202, acc=0.9935, val_loss=-6.7808, val_kr=6.9808, val_hinge=0.0200, val_acc=0.9928]


Epoch 8/10


100%|██████████| 92/92 [00:01<00:00, 60.25it/s, loss=-6.7822, kr=6.9913, hinge=0.0187, acc=0.9935, val_loss=-6.8321, val_kr=7.0534, val_hinge=0.0221, val_acc=0.9923]


Epoch 9/10


100%|██████████| 92/92 [00:01<00:00, 59.73it/s, loss=-6.6593, kr=7.0458, hinge=0.0190, acc=0.9937, val_loss=-6.8656, val_kr=7.0789, val_hinge=0.0213, val_acc=0.9928]


Epoch 10/10


100%|██████████| 92/92 [00:01<00:00, 60.08it/s, loss=-7.0013, kr=7.0847, hinge=0.0182, acc=0.9941, val_loss=-6.9047, val_kr=7.1115, val_hinge=0.0207, val_acc=0.9928]
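The logged values are consistent with the HKR loss being the hinge term weighted by `alpha` minus the KR term, $\mathcal{L}_{HKR} = \alpha\,\mathcal{L}_{hinge} - \mathcal{L}_{KR}$, where the KR term (called here with `(1, -1)`) is assumed to be the mean output on samples labeled 1 minus the mean output on samples labeled -1. The NumPy sketch below re-implements these terms under that assumption (the helper names are hypothetical, not `torchlip`'s API) and checks the relation on the validation metrics of the first and last epochs; since the logs are rounded to 4 decimals, the match holds to about $10^{-3}$.

```python
import numpy as np


def kr_term(outputs, labels):
    """Assumed KR term: mean output on label-1 samples minus mean output on
    label -1 samples (an estimate of the Wasserstein-1 distance)."""
    outputs = np.asarray(outputs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return outputs[labels == 1].mean() - outputs[labels == -1].mean()


def hinge_term(outputs, labels, min_margin=1.0):
    """Standard hinge term for labels in {-1, 1}."""
    outputs = np.asarray(outputs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return np.mean(np.maximum(0.0, min_margin - labels * outputs))


def hkr_term(outputs, labels, alpha=10.0, min_margin=1.0):
    """Assumed HKR combination: alpha * hinge - KR (the KR term is maximized)."""
    return alpha * hinge_term(outputs, labels, min_margin) - kr_term(outputs, labels)


# Check the assumed relation on the validation metrics logged above (alpha = 10)
alpha = 10
logged = {  # epoch: (val_loss, val_kr, val_hinge)
    1: (-3.0327, 3.2924, 0.0260),
    10: (-6.9047, 7.1115, 0.0207),
}
for epoch, (val_loss, val_kr, val_hinge) in logged.items():
    print(epoch, round(alpha * val_hinge - val_kr, 4), val_loss)
```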


## 4. Evaluate the Lipschitz constant of our networks

### 4.1. Empirical evaluation

We can estimate the Lipschitz constant of the network $F$ empirically by evaluating

$$
    \frac{\Vert{}F(x_2) - F(x_1)\Vert{}}{\Vert{}x_2 - x_1\Vert{}} \quad\text{or}\quad 
    \frac{\Vert{}F(x + \epsilon) - F(x)\Vert{}}{\Vert{}\epsilon\Vert{}}
$$

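for various inputs $x_1, x_2$ (or an input $x$ and a small perturbation $\epsilon$). Since the Lipschitz constant is the supremum of this ratio over all pairs of inputs, the largest value observed on a finite sample is only a lower bound on it.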
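To see why such ratios only give a lower bound, consider a linear map $F(x) = Wx$, whose Lipschitz constant for the Euclidean norm is exactly the largest singular value of $W$. The NumPy sketch below uses a random matrix chosen purely for illustration and compares that value with the largest ratio observed over random pairs of inputs; the ratio reaches the true constant only for pairs whose difference is aligned with the top right-singular vector.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(10, 50))  # illustrative linear map F(x) = W x

# Exact Lipschitz constant of F: the largest singular value of W
U, S, Vt = np.linalg.svd(W)
lip = S[0]

# Empirical estimate: largest ratio ||F(x2) - F(x1)|| / ||x2 - x1|| over random pairs
x1 = rng.normal(size=(2000, 50))
x2 = rng.normal(size=(2000, 50))
ratios = np.linalg.norm((x2 - x1) @ W.T, axis=1) / np.linalg.norm(x2 - x1, axis=1)
print(f"true constant: {lip:.3f}, empirical estimate: {ratios.max():.3f}")

# A pair aligned with the top right-singular vector attains the constant
x_top = x1[0] + Vt[0]
ratio_top = np.linalg.norm(W @ x_top - W @ x1[0]) / np.linalg.norm(x_top - x1[0])
print(f"ratio along top singular direction: {ratio_top:.3f}")
```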

In [4]:
wass.eval()

eps = 1e-3  # amplitude of the uniform input perturbation
p = []
with torch.no_grad():
    for _ in range(64):
        # draw a new random batch from the (shuffled) training loader
        batch = next(iter(train_loader))[0]
        dist = torch.distributions.Uniform(-eps, +eps).sample(batch.shape)
        y1 = wass(batch.to(device)).cpu()
        y2 = wass((batch + dist).to(device)).cpu()

        # ratio ||F(x + eps) - F(x)|| / ||eps|| per sample, keeping the batch maximum
        p.append(
            torch.max(
                torch.norm(y2 - y1, dim=1)
                / torch.norm(dist.reshape(dist.shape[0], -1), dim=1)
            )
        )
print(torch.tensor(p).max())

tensor(0.1517)


In [5]:
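from scipy.spatial.distance import pdist

wass.eval()

p = []
with torch.no_grad():
    for batch, _ in tqdm(train_loader):
        x = batch.numpy()
        y = wass(batch.to(device)).cpu().numpy()
        # pairwise Euclidean distances between all samples of the batch,
        # in input space (flattened images) and in output space
        xd = pdist(x.reshape(batch.shape[0], -1))
        yd = pdist(y.reshape(batch.shape[0], -1))

        # largest ratio ||F(x2) - F(x1)|| / ||x2 - x1|| over all pairs of the batch
        p.append((yd / xd).max())
print(torch.tensor(p).max())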

100%|██████████| 92/92 [00:00<00:00, 168.02it/s]

tensor(0.9063, dtype=torch.float64)





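As we can see, the $\epsilon$-version greatly underestimates the Lipschitz constant: in high dimension (here $28 \times 28 = 784$), a random perturbation is almost orthogonal to the direction in which $F$ varies the most.
Using pairs of training samples, we find a ratio close to 0.9, which is a tighter estimate. Both values are only lower bounds: they are consistent with our network being 1-Lipschitz, but they do not certify it.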
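The underestimation of the $\epsilon$-version can be reproduced on the simplest 1-Lipschitz function, $F(x) = w^\top x$ with $\Vert w \Vert = 1$, in dimension $d = 784$ as for flattened MNIST images. For a perturbation $\epsilon$, the ratio is $|w^\top \epsilon| / \Vert \epsilon \Vert$, the absolute cosine between $w$ and $\epsilon$, which for random $\epsilon$ concentrates around 0 with a spread of order $1/\sqrt{d}$. This NumPy sketch draws 8192 uniform perturbations (64 batches of 128, as in the cell above); it illustrates the phenomenon and is not a model of our network.

```python
import numpy as np

d = 784  # input dimension of flattened 28x28 images
rng = np.random.default_rng(0)

# A 1-Lipschitz linear function F(x) = w . x with ||w|| = 1
w = rng.normal(size=d)
w /= np.linalg.norm(w)

# Random uniform perturbations, as in the epsilon-version above;
# F(x + eps) - F(x) = w . eps for a linear F
eps = rng.uniform(-1e-3, 1e-3, size=(8192, d))
ratios = np.abs(eps @ w) / np.linalg.norm(eps, axis=1)
print(f"max ratio over random perturbations: {ratios.max():.3f}")
print(f"1 / sqrt(d) = {1 / np.sqrt(d):.3f}")

# A perturbation along w recovers the true constant 1
e_w = 1e-3 * w
print(f"ratio along w: {abs(e_w @ w) / np.linalg.norm(e_w):.3f}")
```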

### 4.2. Singular-Value Decomposition

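Since our network is only made of linear layers and `FullSort` activations, which are 1-Lipschitz, its Lipschitz constant is bounded by the product of the largest singular values of its weight matrices. We can therefore compute the *Singular-Value Decomposition* (SVD) of each weight matrix and check that, for each linear layer, all singular values are 1.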

In [6]:
print("=== Before export ===")
layers = list(wass.children())
for layer in layers:
    if hasattr(layer, "weight"):
        # singular values of the (normalized) weight matrix
        s = torch.linalg.svdvals(layer.weight.detach())
        print(f"{layer}, min={s.min()}, max={s.max()}")

=== Before export ===
SpectralLinear(in_features=784, out_features=128, bias=True), min=0.9999998807907104, max=1.0
SpectralLinear(in_features=128, out_features=64, bias=True), min=0.9999998807907104, max=1.0000001192092896
SpectralLinear(in_features=64, out_features=32, bias=True), min=0.9999998807907104, max=1.0
FrobeniusLinear(in_features=32, out_features=1, bias=True), min=0.9999999403953552, max=0.9999999403953552


In [7]:
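# Export to a model made of standard PyTorch layers (e.g. torch.nn.Linear)
wexport = wass.vanilla_export()

print("=== After export ===")
layers = list(wexport.children())
for layer in layers:
    if hasattr(layer, "weight"):
        # singular values of the exported weight matrix
        s = torch.linalg.svdvals(layer.weight.detach())
        print(f"{layer}, min={s.min()}, max={s.max()}")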

=== After export ===
Linear(in_features=784, out_features=128, bias=True), min=0.9999998807907104, max=1.0
Linear(in_features=128, out_features=64, bias=True), min=0.9999998807907104, max=1.0000001192092896
Linear(in_features=64, out_features=32, bias=True), min=0.9999998807907104, max=1.0
Linear(in_features=32, out_features=1, bias=True), min=0.9999999403953552, max=0.9999999403953552


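As we can see, all singular values are equal to 1 up to float32 rounding, both before and after the export. The exported model, made of standard `torch.nn.Linear` layers, therefore keeps the 1-Lipschitz guarantee.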
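The SVD check certifies the whole network because the Lipschitz constant of a composition is bounded by the product of the constants of its parts: if each linear layer has largest singular value at most 1 and the activations are 1-Lipschitz, the network is 1-Lipschitz (biases and `Flatten` do not change this). The NumPy sketch below builds a toy network with the same layer sizes, using random matrices with all singular values equal to 1 (obtained from a QR decomposition, an illustrative choice rather than how `torchlip` normalizes its weights) and sorting as activation, then checks that no pair of random inputs gives a ratio above 1.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [784, 128, 64, 32, 1]  # same layer sizes as the model above


def orthonormal_rows(n_out, n_in):
    """Random (n_out, n_in) matrix with orthonormal rows (n_out <= n_in),
    so all its singular values equal 1."""
    q, _ = np.linalg.qr(rng.normal(size=(n_in, n_out)))
    return q.T


weights = [orthonormal_rows(n_out, n_in) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=n_out) for n_out in sizes[1:]]

for w in weights:
    s = np.linalg.svd(w, compute_uv=False)
    print(w.shape, s.min().round(6), s.max().round(6))


def toy_network(x):
    """Linear layers with sorting activations in between (none after the last layer)."""
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w.T + b
        if i < len(weights) - 1:
            x = np.sort(x, axis=-1)
    return x


x1 = rng.uniform(0, 1, size=(1000, 784))
x2 = rng.uniform(0, 1, size=(1000, 784))
ratios = np.linalg.norm(toy_network(x2) - toy_network(x1), axis=1) / np.linalg.norm(
    x2 - x1, axis=1
)
print(f"largest observed ratio: {ratios.max():.4f}")  # at most 1 by construction
```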
