In [2]:
import torch
import numpy as np

from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data.dataloader import DataLoader

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# MNIST training split (train=True), converted with ToTensor; download=True
# lets torchvision fetch the files into ./data/MNIST when needed.
mnist_transform = transforms.Compose([transforms.ToTensor()])
dataset = MNIST(
    root='./data/MNIST',
    train=True,
    download=True,
    transform=mnist_transform,
)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

In [2]:
import torch.nn as nn
import Lipschitz.nn as Ln

from tqdm import tqdm

def train(model, optmz, crtrn, dl):
    """Train ``model`` for one pass over ``dl`` and return the mean batch loss.

    Args:
        model: network to optimise; put into training mode before the loop.
        optmz: optimizer stepped once per batch.
        crtrn: loss callable invoked as ``crtrn(logits, labels)``.
        dl: iterable of ``(images, labels)`` batches; both are moved to the
            module-level ``device`` before the forward pass.

    Returns:
        Mean of the per-batch loss values (0.0 when ``dl`` yields no batches).
        The result is averaged, not summed, because the training loop prints
        it as "average loss".
    """
    model.train()

    total_loss = 0.0
    n_batches = 0
    for images, labels in tqdm(dl):
        images = images.to(device)
        labels = labels.to(device)

        output = model(images)
        # Drop singleton non-batch dims only. A bare .squeeze() also removes the
        # batch dim when a batch holds a single sample, which breaks the loss.
        output = output.reshape(output.size(0), *(s for s in output.shape[1:] if s != 1))
        b_loss = crtrn(output, labels)

        optmz.zero_grad()
        b_loss.backward()
        optmz.step()

        total_loss += b_loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else 0.0

@torch.no_grad()
def eval(model, eps, dl):
    """Measure clean and margin-certified accuracy of ``model`` over ``dl``.

    A sample counts as certified when it is classified correctly and the gap
    between its top logit and runner-up logit exceeds ``eps``.

    Args:
        model: classifier producing one logit per class; put into eval mode.
        eps: margin threshold used for certification.
        dl: iterable of ``(images, labels)`` batches; both are moved to the
            module-level ``device``.

    Returns:
        ``(clean_acc, certified_acc, robust_acc)`` as fractions of all samples.
        ``robust_acc`` is currently always 0.0 because no adversarial attack is
        run (see TODO below). All three are 0.0 if ``dl`` yields no samples.
    """
    model.eval()

    # TODO: empirical robust accuracy (e.g. AutoAttack / APGD with norm='Linf',
    # eps=eps) is not implemented; n_robustd stays 0. Gradient-based attacks
    # may need to run outside this function's @torch.no_grad() context.
    total, n_correct, n_certify, n_robustd = 0, 0, 0, 0
    for images, labels in tqdm(dl):
        images = images.to(device)
        labels = labels.to(device)

        output = model(images)
        # Drop singleton non-batch dims only; a bare .squeeze() would also drop
        # the batch dim for a single-sample batch and break dim=1 reductions.
        output = output.reshape(output.size(0), *(s for s in output.shape[1:] if s != 1))

        predic = torch.max(output, dim=1)
        # Runner-up logit for any number of classes (kthvalue(output, 9)
        # only worked with exactly 10 classes).
        second = output.topk(2, dim=1).values[:, 1]
        correct = predic.indices == labels
        # NOTE(review): `margin > eps` certifies only if each logit moves by at
        # most eps/2 under an eps-perturbation. For a network that is 1-Lipschitz
        # per logit w.r.t. the attack norm the usual condition is margin > 2*eps.
        # Confirm against the guarantee provided by the Ln layers.
        certify = correct.logical_and(predic.values - second > eps)

        total += labels.size(0)
        n_correct += correct.sum().item()
        n_certify += certify.sum().item()

    if total == 0:
        return 0.0, 0.0, 0.0
    return n_correct / total, n_certify / total, n_robustd / total

In [None]:
# Seed the global torch RNG so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(0)

# Conv/pool stack sized for 1x28x28 MNIST inputs: 28 -> pool 14 -> conv 10 -> pool 5,
# giving 16 * 5 * 5 = 400 features for the first Linear layer.
# NOTE(review): no explicit activations between layers - presumably the Ln layers
# supply the nonlinearity; otherwise the three Linear layers collapse to one affine map.
net = nn.Sequential(
    Ln.Conv2d(1, 6, 5, 1, 2),
    nn.MaxPool2d(2, 2),

    Ln.Conv2d(6, 16, 5, 1, 0),
    nn.MaxPool2d(2, 2),

    nn.Flatten(),

    Ln.Linear(400, 120),
    Ln.Linear(120, 84),
    Ln.Linear(84, 10),
).to(device)  # train/eval move batches to `device`, so the model must live there too

optmz = torch.optim.Adam(net.parameters(), 1e-3)
crtrn = torch.nn.CrossEntropyLoss()

N_EPOCHS = 10
EVAL_EVERY = 5
EPS = 8 / 255  # NOTE(review): common CIFAR-10 budget; MNIST work often uses larger eps - confirm

for epoch in range(N_EPOCHS):
    loss = train(net, optmz, crtrn, dataloader)
    if (epoch + 1) % EVAL_EVERY == 0:
        # NOTE(review): this evaluates on the training loader; no test split is loaded.
        accu = eval(net, EPS, dataloader)
        print(f'Epoch {epoch+1} average loss: {loss}, accuracy: {accu}')
    else:
        print(f'Epoch {epoch+1} average loss: {loss}')

In [24]:
# Demo: F.unfold on a 1x1x3x4 input with a 3x3 kernel and padding 1.
# Use torch.tensor with an explicit dtype rather than the legacy torch.Tensor constructor.
a = torch.tensor([[[[1, 2, 3, 4], [4, 5, 6, 4], [7, 8, 9, 4]]]], dtype=torch.float32)
print(a)

import torch.nn.functional as F
# F.unfold's positional order is (input, kernel_size, dilation, padding, stride);
# keywords make it clear the ones below are dilation, padding and stride.
b = F.unfold(a, kernel_size=3, dilation=1, padding=1, stride=1)
print(b, b.shape)

tensor([[[[1., 2., 3., 4.],
          [4., 5., 6., 4.],
          [7., 8., 9., 4.]]]])
tensor([[[0., 0., 0., 0., 0., 1., 2., 3., 0., 4., 5., 6.],
         [0., 0., 0., 0., 1., 2., 3., 4., 4., 5., 6., 4.],
         [0., 0., 0., 0., 2., 3., 4., 0., 5., 6., 4., 0.],
         [0., 1., 2., 3., 0., 4., 5., 6., 0., 7., 8., 9.],
         [1., 2., 3., 4., 4., 5., 6., 4., 7., 8., 9., 4.],
         [2., 3., 4., 0., 5., 6., 4., 0., 8., 9., 4., 0.],
         [0., 4., 5., 6., 0., 7., 8., 9., 0., 0., 0., 0.],
         [4., 5., 6., 4., 7., 8., 9., 4., 0., 0., 0., 0.],
         [5., 6., 4., 0., 8., 9., 4., 0., 0., 0., 0., 0.]]]) torch.Size([1, 9, 12])


In [8]:
torch.empty(100, 4, 3, 3).flatten(start_dim=-3).unsqueeze(0).unsqueeze(2).shape

torch.Size([1, 100, 1, 36])