In [None]:
from google.colab import drive
drive.mount('/content/drive')

FOLDERNAME = 'CVX_Robust_NN/'

import sys
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

%load_ext autoreload
%autoreload 2

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Import dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, RandomSampler

import os

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from prepare_data import *
from models.mobilenetv2 import *
from models.vgg import *
from models.praresnet import *
from cvx_scripts.losses import *
from cvx_scripts.cvx_nn import *
from cvx_scripts.cvx_training import *
from sam import *
from pgd import *

Set device

In [None]:
# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


Load the CIFAR-10 Dataset

In [None]:
# Data
print('==> Preparing data..')

# Per-channel normalisation constants, shared by the train and test pipelines.
# NOTE(review): these std values differ from the commonly quoted CIFAR-10
# per-channel pixel std (~0.247, 0.243, 0.261). They are kept unchanged
# because changing them alters the network's inputs. Confirm this is intended.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop of the 4-px padded image and a random
# horizontal flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Evaluation pipeline: normalisation only, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                            transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True,
                           transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

# Human-readable label names, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12606128.61it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Set hyperparameters, then initialize the model and optimizer.

In [None]:
lr = 0.1          # initial learning rate handed to the optimizer
best_acc = 0      # best test accuracy (%) observed so far
start_epoch = 0   # first epoch index of the training loop

# Seed torch and numpy so that subsequent random draws (e.g. weight
# initialisation and shuffling) repeat across runs. GPU kernels may still
# be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# PreActResNet18 from the project's models package. The argument 10
# presumably sets the number of output classes (CIFAR-10); confirm in
# models/praresnet.py.
net = PreActResNet18(10).to(device)
criterion = nn.CrossEntropyLoss()

# SAM is constructed with the base optimizer *class* (SGD) plus its kwargs.
# NOTE(review): rho is presumably SAM's neighbourhood radius; see sam.py.
base_opt = torch.optim.SGD
optimizer = SAM(
    net.parameters(),
    base_opt,
    lr=lr,
    momentum=0.9,
    weight_decay=5e-4,
    rho=0.6,
)

In [None]:
# Training
def train(epoch, sam=False):
    """Run one training epoch over ``trainloader`` and report loss/accuracy.

    Uses the notebook-level ``net``, ``trainloader``, ``criterion``,
    ``optimizer`` and ``device``.

    Args:
        epoch: Epoch index, used only for the progress header.
        sam: If True, perform the two-pass SAM update
            (``first_step`` / ``second_step``). Otherwise perform a plain
            ``zero_grad`` / ``backward`` / ``step`` update.

    Returns:
        ``(mean_loss, accuracy)``: the batch loss averaged over the epoch and
        the training accuracy in percent. Both come from the first forward
        pass of each batch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = net(inputs)
        loss = criterion(outputs, targets)

        if sam:
            # Two-pass update. first_step consumes the gradients of this
            # batch's loss. The loss is then recomputed with a fresh forward
            # pass, and its gradients are consumed by second_step. See sam.py
            # for what each step does to the weights.
            loss.backward()
            optimizer.first_step(zero_grad=True)

            criterion(net(inputs), targets).backward()
            optimizer.second_step(zero_grad=True)
        else:
            # NOTE(review): `optimizer` is built as a SAM instance in this
            # notebook. Confirm its step() works without a closure before
            # training with sam=False.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if num_batches == 0:
        raise ValueError('trainloader yielded no batches')
    # Report the mean per-batch loss rather than the raw sum, so the number
    # is comparable across loaders with different batch counts.
    mean_loss = train_loss / num_batches
    acc = 100. * correct / total
    print('train loss: %.4f | train acc: %.2f%%' % (mean_loss, acc))
    return mean_loss, acc

def test(epoch, best_acc):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        best_acc = acc
    print(best_acc)

In [None]:
def lr_schedule(epoch, total_epochs, initial_lr):
    """Step-decay learning rate for ``epoch`` out of ``total_epochs``.

    Returns ``initial_lr`` for epochs below 75% of ``total_epochs``,
    ``initial_lr * 0.1`` below 90%, and ``initial_lr * 0.01`` afterwards.
    """
    if epoch < total_epochs * 0.75:
        return initial_lr
    # Past the first milestone: choose the 10x or 100x decay factor.
    decay = 0.1 if epoch < total_epochs * 0.9 else 0.01
    return initial_lr * decay

# Length of the step schedule (in epochs) and its starting learning rate.
# Both are kept local so that re-running this cell restarts the same schedule.
NUM_EPOCHS = 100
INITIAL_LR = 0.1

# Run up to the schedule's horizon. When resuming (start_epoch > 0), only the
# remaining epochs of the same schedule are trained.
for epoch in range(start_epoch, NUM_EPOCHS):
    lr = lr_schedule(epoch, NUM_EPOCHS, INITIAL_LR)
    # Update every param group, not just the first, so none keeps a stale LR.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    train(epoch, sam=True)
    test(epoch, best_acc)

# Save into the project folder explicitly rather than relying on
# sys.path[-1] still being that folder.
torch.save(net.state_dict(),
           os.path.join('/content/drive/My Drive', FOLDERNAME, 'praresnet.pth'))