In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

from ori_layers import SandwichFc, SandwichConv, SandwichLin, LinearNormalized, MultiMargin, LipResFC, LipResConv

In [2]:
# Notebook-wide settings; later cells read these names directly.
w = 1            # network size multiplier: can be 1, 2 or 4
batch_size = 64  # mini-batch size handed to the train and test DataLoaders
device = 'cpu'   # target for every `.to(device)` call on models and batches


In [3]:
# Per-channel mean / std handed to Normalize (applied after ToTensor).
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)

# Strength of the colour jitter used for training-time augmentation.
brightness = 0.1
contrast = (.5, 2.)
saturation = (.3, 2.)
hue = 0.02

# Training pipeline: reflect-padded random crop, random flip, colour jitter,
# then tensor conversion and normalisation.
transforms_list = [
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
        hue=hue,
    ),
    transforms.ToTensor(),
    normalize,
]
train_transform = transforms.Compose(transforms_list)

# Test pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

# CIFAR-10 is downloaded into ./data on first use.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Sanity check: show the value range and the shapes of the first test batch only.
for x, y in testloader:
    lo, hi = x.min(), x.max()
    print(lo, hi)
    for shape in (x.shape, y.shape):
        print(shape)
    break  # one batch is enough

tensor(-1.9895) tensor(2.1265)
torch.Size([64, 3, 32, 32])
torch.Size([64])


In [6]:
import torch.optim as optim


def train_net(model, epochs=100, lr=1e-2):
    """Train ``model`` in place with Adam and the ``MultiMargin`` loss.

    Batches are read from the notebook-level ``trainloader`` and moved to the
    notebook-level ``device``; the model itself is not moved by this function.

    The learning rate follows a piecewise-linear schedule that is re-evaluated
    before every mini-batch: it rises from 0 to ``lr`` over the first
    ``epochs * 2 // 5`` epochs, falls to ``lr / 20`` at ``epochs * 4 // 5``,
    and reaches 0 at the end of the last epoch.

    Args:
        model: the ``nn.Module`` whose parameters are optimised.
        epochs: number of passes over ``trainloader`` (default 100).
        lr: peak learning rate of the schedule (default 1e-2).

    Returns:
        None. Prints ``epoch, batch index, loss`` every 10 mini-batches and
        ``'Finished Training'`` at the end.
    """
    criterion = MultiMargin()
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    # Fix the schedule's control points once, up front. The previous version
    # used a lambda that closed over the local `lr`, which was then overwritten
    # with the scheduled value on every batch, so the peak shrank each step and
    # the learning rate collapsed towards zero.
    schedule_epochs = [0, epochs * 2 // 5, epochs * 4 // 5, epochs]
    schedule_lrs = [0, lr, lr / 20.0, 0]

    def lr_schedule(t):
        return np.interp([t], schedule_epochs, schedule_lrs)[0]

    batches_per_epoch = len(trainloader)

    for epoch in range(epochs):  # loop over the dataset multiple times
        for i, data in enumerate(trainloader, 0):
            # Fractional epoch position of this step drives the schedule.
            current_lr = float(lr_schedule(epoch + (i + 1) / batches_per_epoch))
            for group in opt.param_groups:
                group['lr'] = current_lr

            # data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            opt.zero_grad()

            # Use the argument, not the notebook-global `net`, so the model
            # that was passed in is the one that actually gets trained.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()

            if i % 10 == 0:    # print every 10 mini-batches
                print(epoch, i, loss.item())

    print('Finished Training')

In [7]:
# Target values for the three networks trained below.
lips = np.array([1, 10, 100])
# Per-layer constant is the 7th root of the target (lips ** (1/7)); presumably
# chosen so the product over the 7 layers matches the target -- TODO confirm
# against the LipRes* layer definitions.
lips_layer = np.exp(np.log(lips) / 7)

for l, layer_lip in zip(lips, lips_layer):
    # Disabled alternative architecture built from Sandwich layers:
    # net = nn.Sequential(
    #     SandwichConv(3, 32 * w, 3, scale=np.sqrt(l)),
    #     SandwichConv(32 * w, 32 * w, 3, stride=2),
    #     SandwichConv(32 * w, 64 * w, 3),
    #     SandwichConv(64 * w, 64 * w, 3, stride=2),
    #     nn.Flatten(),
    #     SandwichFc(64 * 8 * 8 * w, 512 * w),
    #     SandwichFc(512 * w, 512 * w),
    #     SandwichLin(512 * w,10, scale=np.sqrt(l))
    # ).to(device)

    # Four conv blocks (two of them strided), then three fully connected
    # blocks; every block receives the same per-layer constant.
    bound = {'L': layer_lip}
    layers = [
        LipResConv(3, 32 * w, 3, strided=False, **bound),
        LipResConv(32 * w, 32 * w, 3, strided=True, **bound),
        LipResConv(32 * w, 64 * w, 3, strided=False, **bound),
        LipResConv(64 * w, 64 * w, 3, strided=True, **bound),
        nn.Flatten(),
        LipResFC(64 * 8 * 8 * w, 512 * w, **bound),
        LipResFC(512 * w, 512 * w, **bound),
        LipResFC(512 * w, 10, **bound),
    ]
    net = nn.Sequential(*layers).to(device)

    train_net(net)
    # One checkpoint per target value, written to the working directory.
    torch.save(net.state_dict(), f'./lip_{l}.pth')


0 0 0.4470723867416382
0 10 0.4483972489833832
0 20 0.4494003355503082
0 30 0.44101962447166443
0 40 0.4514889717102051
0 50 0.447468101978302
0 60 0.44915860891342163
0 70 0.44953662157058716
0 80 0.45385026931762695


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x127809280>
Traceback (most recent call last):
  File "/Users/wangjiarui/miniconda3/envs/ml/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Users/wangjiarui/miniconda3/envs/ml/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/wangjiarui/miniconda3/envs/ml/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/wangjiarui/miniconda3/envs/ml/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/wangjiarui/miniconda3/envs/ml/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/wangjiarui/miniconda3/envs/ml/lib/python3.8/selectors.py", line 415, in selec

KeyboardInterrupt: 

In [9]:
# Evaluate the saved checkpoints on the first few test batches.
lips = np.array([1, 10, 100])
# Per-layer constant is the 7th root of the target value (lips ** (1/7)).
lips_layer = np.exp(np.log(lips) / 7)
act = nn.ReLU  # not used in this cell; kept in case other cells reference it
max_eval_batches = 10  # accuracy is estimated on this many test batches only

for i, l in enumerate(lips):
    # Must match the architecture the checkpoint was saved from, otherwise
    # load_state_dict fails on mismatching keys/shapes.
    net = nn.Sequential(
        LipResConv(3, 32 * w, 3, strided=False, L = lips_layer[i]),
        LipResConv(32 * w, 32 * w, 3, strided=True, L = lips_layer[i]),
        LipResConv(32 * w, 64 * w, 3, strided=False, L = lips_layer[i]),
        LipResConv(64 * w, 64 * w, 3, strided=True, L = lips_layer[i]),
        nn.Flatten(),
        LipResFC(64 * 8 * 8 * w, 512 * w, L = lips_layer[i]),
        LipResFC(512 * w, 512 * w, L = lips_layer[i]),
        LipResFC(512 * w,10, L = lips_layer[i])
    ).to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # produced yourself or otherwise trust.
    net.load_state_dict(torch.load(f'./ckpts/lip_{l}_small.pth', map_location=device))
    net.eval()  # inference mode for any layer that behaves differently in training

    correct = 0
    total = 0
    # `batch_idx` replaces the old counter named `iter`, which rebound the
    # builtin iter() for the rest of the kernel session.
    for batch_idx, (x, y) in enumerate(testloader, start=1):
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            outputs = net(x)
            predicted = outputs.argmax(dim=1)

        total += y.size(0)
        correct += (predicted == y).sum().item()

        if batch_idx >= max_eval_batches:
            break

    print(l, correct / total)

1 0.709375
10 0.76875
100 0.7984375
