In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.autograd import Variable
from ada_hessian import AdaHessian
import torch.optim.lr_scheduler as lr_scheduler
import time
import pandas as pd

import numpy as np
import math

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 256


def getNet(device):
    net = models.resnet18()
    net.to(device)
    return net

def generateExperiment(net, optimizer, trainloader, testloader, 
                       isHessian, csv_name, device, criterion = nn.CrossEntropyLoss(), total_epochs = 160):
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        [80, 120],
        gamma=0.1,
        last_epoch=-1)
    train_losses = []
    train_acc = []
    train_times = []
    val_loss = []
    val_acc = []
    epochs = []

    for epoch in range(total_epochs):  

        train_loss = 0.0
        train_step = 0
        train_total = 0
        train_correct = 0
        opt_time = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            train_step = train_step + 1
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)

            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            if isHessian:
                loss.backward(create_graph=True)
            else:
                loss.backward()

            t = time.process_time()
            optimizer.step()
            opt_time += time.process_time() - t

            scheduler.step()

            # print statistics
            train_loss += loss.item()

        test_loss = 0.0
        test_step = 0
        test_total = 0
        test_correct = 0

        for i, data in enumerate(testloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            test_step = test_step + 1
            inputs, labels = data[0].to(device), data[1].to(device)


            # forward + backward + optimize
            outputs = net(inputs)

            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            # print statistics
            test_loss += loss.item()
        train_losses.append(train_loss / train_step)
        train_acc.append(train_correct / train_total)
        train_times.append(opt_time / train_step)
        val_loss.append(test_loss / test_step)
        val_acc.append(test_correct / test_total)
        epochs.append(epoch)
        print("Epoch: " + str(epoch) + " finished")
    extract_dat = pd.DataFrame({
        "epoch": epochs,
        "loss": train_losses,
        "accuracy": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "opt_time": train_times,
    })
    extract_dat.to_csv(csv_name, index=False)

## Experiment 1: Computer Vision

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')



In [None]:
net = getNet(device)

optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

generateExperiment(net, optimizer, trainloader, testloader, False, "SGD_Moment_torch.csv", device)

In [None]:
net = getNet(device)

optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0, weight_decay=5e-4)

generateExperiment(net, optimizer, trainloader, testloader, False, "SGD_torch.csv", device)

In [None]:
net = getNet(device)
optimizer_ada = AdaHessian(net.parameters(), lr=0.15, 
                           average_conv_kernel=True, hessian_power=1, 
                           n_samples=1, weight_decay=5e-4)

generateExperiment(net, optimizer_ada, trainloader, testloader, True, "AdaHess_torch.csv", device)

In [None]:
net = getNet(device)
optimizer_adam = optim.Adam (net.parameters(), lr=0.001, weight_decay=5e-4)

generateExperiment(net, optimizer_adam, trainloader, testloader, False, "Adam_torch.csv", device)

In [None]:
net = getNet(device)
optimizer_adamw = optim.AdamW (net.parameters(), lr=0.01, weight_decay=5e-4)

generateExperiment(net, optimizer_adamw, trainloader, testloader, False, "AdamW_torch.csv", device)

## Experiment 2: DNN

In [4]:
n = 100000
batch_size=2000
def gen_egg_pts(n):
    x1 = np.random.uniform(-512, 512, n)
    x2 = np.random.uniform(-512, 512, n)
    f_x = -(x2 + 47) * np.sin(np.sqrt(np.abs(x1 / 2 + (x2 + 47)))) \
        - x1 * np.sin(np.abs(x1 - (x2 + 47))) 
    noise = np.random.normal(0, math.sqrt(0.3), n) 
    X = np.transpose(np.array([x1, x2]))
    return X, f_x + noise

x, y = gen_egg_pts(n)

In [5]:
train_size = int(n * 0.8)
test_size = n - train_size

criterion = nn.MSELoss()

tensor_x = torch.Tensor(x)
tensor_y = torch.Tensor(y)
my_dataset = torch.utils.data.TensorDataset(tensor_x,tensor_y)

train_reg, test_reg = torch.utils.data.random_split(my_dataset, (train_size, test_size))

train_reg_loader = torch.utils.data.DataLoader(train_reg, batch_size = batch_size)
test_reg_loader = torch.utils.data.DataLoader(test_reg, batch_size = batch_size)

In [17]:
class Reg_Net(nn.Module):
    def __init__(self):
        super(Reg_Net, self).__init__()
        self.fc1 = nn.Linear(2, 120)
        self.fc2 = nn.Linear(120, 120)
        self.fc3 = nn.Linear(120, 120)
        self.fc4 = nn.Linear(120, 120)
        self.fc5 = nn.Linear(120, 1)

    def forward(self, x):
        x = F.(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x

def getRegNet(device):
    net = Reg_Net()
    net.to(device)
    return net

In [7]:
def generateExperimentReg(net, optimizer, trainloader, testloader, 
                       isHessian, csv_name, device, criterion = nn.MSELoss(), total_epochs = 2000):
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        [800, 1200],
        gamma=0.1,
        last_epoch=-1)
    train_losses = []
    train_times = []
    val_loss = []
    epochs = []

    for epoch in range(total_epochs):  

        train_loss = 0.0
        train_step = 0
        train_total = 0
        opt_time = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            train_step = train_step + 1
            inputs, y = data[0].to(device), data[1].to(device)
            y = torch.unsqueeze(y, 1)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)

            train_total += y.size(0)

            loss = criterion(outputs, y)
            if isHessian:
                loss.backward(create_graph=True)
            else:
                loss.backward()

            t = time.process_time()
            optimizer.step()
            opt_time += time.process_time() - t

            scheduler.step()

            # print statistics
            train_loss += loss.item()

        test_loss = 0.0
        test_step = 0
        test_total = 0

        for i, data in enumerate(testloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            test_step = test_step + 1
            inputs, y = data[0].to(device), data[1].to(device)
            y = torch.unsqueeze(y, 1)

            # forward + backward + optimize
            outputs = net(inputs)
            test_total += y.size(0)

            loss = criterion(outputs, y)
            # print statistics
            test_loss += loss.item()
        train_losses.append(train_loss / train_step)
        train_times.append(opt_time / train_step)
        val_loss.append(test_loss / test_step)
        epochs.append(epoch)
        print("Epoch: " + str(epoch) + " finished with training loss " + str(train_loss / train_step))
    extract_dat = pd.DataFrame({
        "epoch": epochs,
        "loss": train_losses,
        "val_loss": val_loss,
        "opt_time": train_times,
    })
    extract_dat.to_csv(csv_name, index=False)

In [None]:
net = getRegNet(device)


optimizer = optim.SGD(net.parameters(), lr=10e-6, momentum=10e-3, weight_decay=5e-4)



generateExperimentReg(net, optimizer, train_reg_loader, test_reg_loader, False, 
                      "SGD_Moment_Reg_torch.csv", device)

In [None]:
net = getRegNet(device)


optimizer = optim.SGD(net.parameters(), lr=10e-6, momentum=0, weight_decay=5e-4)



generateExperimentReg(net, optimizer, train_reg_loader, test_reg_loader, False, 
                      "SGD_Reg_torch.csv", device)

In [None]:
net = getRegNet(device)


optimizer = optim.Adam(net.parameters(), weight_decay=5e-4)



generateExperimentReg(net, optimizer, train_reg_loader, test_reg_loader, False, 
                      "Adam_Reg_torch.csv", device)

In [None]:
net = getRegNet(device)


optimizer = optim.AdamW(net.parameters(), weight_decay=5e-4)



generateExperimentReg(net, optimizer, train_reg_loader, test_reg_loader, False, 
                      "AdamW_Reg_torch.csv", device)

In [27]:
net = getRegNet(device)


optimizer = AdaHessian(net.parameters(), lr=0.1, weight_decay=5e-4)



generateExperimentReg(net, optimizer, train_reg_loader, test_reg_loader, True, 
                      "AdamHess_Reg_torch.csv", device)

Epoch: 0 finished with training loss 88842.07421875
Epoch: 1 finished with training loss 85822.5548828125
Epoch: 2 finished with training loss 84843.801953125
Epoch: 3 finished with training loss 84437.1466796875
Epoch: 4 finished with training loss 84201.66953125
Epoch: 5 finished with training loss 84061.5443359375
Epoch: 6 finished with training loss 83863.1912109375
Epoch: 7 finished with training loss 83721.21171875
Epoch: 8 finished with training loss 83549.837109375
Epoch: 9 finished with training loss 83391.902734375
Epoch: 10 finished with training loss 83292.856640625
Epoch: 11 finished with training loss 83203.823828125
Epoch: 12 finished with training loss 83167.7421875
Epoch: 13 finished with training loss 83112.983203125
Epoch: 14 finished with training loss 83049.5509765625
Epoch: 15 finished with training loss 82906.5259765625
Epoch: 16 finished with training loss 82696.4869140625
Epoch: 17 finished with training loss 82642.173828125
Epoch: 18 finished with training los

Epoch: 152 finished with training loss 81627.7197265625
Epoch: 153 finished with training loss 81626.724609375
Epoch: 154 finished with training loss 81625.5734375
Epoch: 155 finished with training loss 81624.608203125
Epoch: 156 finished with training loss 81623.91015625
Epoch: 157 finished with training loss 81622.8546875
Epoch: 158 finished with training loss 81621.6400390625
Epoch: 159 finished with training loss 81620.39296875
Epoch: 160 finished with training loss 81619.2349609375
Epoch: 161 finished with training loss 81618.3240234375
Epoch: 162 finished with training loss 81617.325390625
Epoch: 163 finished with training loss 81616.227734375
Epoch: 164 finished with training loss 81615.2412109375
Epoch: 165 finished with training loss 81614.18984375
Epoch: 166 finished with training loss 81613.0525390625
Epoch: 167 finished with training loss 81612.0439453125
Epoch: 168 finished with training loss 81611.225390625
Epoch: 169 finished with training loss 81610.3236328125
Epoch: 17

Epoch: 301 finished with training loss 81497.5220703125
Epoch: 302 finished with training loss 81496.543359375
Epoch: 303 finished with training loss 81496.028515625
Epoch: 304 finished with training loss 81495.51328125
Epoch: 305 finished with training loss 81494.8802734375
Epoch: 306 finished with training loss 81494.2802734375
Epoch: 307 finished with training loss 81493.5431640625
Epoch: 308 finished with training loss 81492.92109375
Epoch: 309 finished with training loss 81492.0115234375
Epoch: 310 finished with training loss 81490.846875
Epoch: 311 finished with training loss 81490.232421875
Epoch: 312 finished with training loss 81489.67421875
Epoch: 313 finished with training loss 81489.0880859375
Epoch: 314 finished with training loss 81488.441796875
Epoch: 315 finished with training loss 81487.9869140625
Epoch: 316 finished with training loss 81487.2419921875
Epoch: 317 finished with training loss 81486.89140625
Epoch: 318 finished with training loss 81486.4380859375
Epoch: 3

Epoch: 450 finished with training loss 81425.0037109375
Epoch: 451 finished with training loss 81424.7373046875
Epoch: 452 finished with training loss 81424.39375
Epoch: 453 finished with training loss 81424.0369140625
Epoch: 454 finished with training loss 81423.6953125
Epoch: 455 finished with training loss 81423.1986328125
Epoch: 456 finished with training loss 81422.838671875
Epoch: 457 finished with training loss 81422.4423828125
Epoch: 458 finished with training loss 81422.3087890625
Epoch: 459 finished with training loss 81421.790234375
Epoch: 460 finished with training loss 81421.5095703125
Epoch: 461 finished with training loss 81421.2388671875
Epoch: 462 finished with training loss 81420.95
Epoch: 463 finished with training loss 81420.5796875
Epoch: 464 finished with training loss 81420.150390625
Epoch: 465 finished with training loss 81419.9240234375
Epoch: 466 finished with training loss 81419.4720703125
Epoch: 467 finished with training loss 81419.322265625
Epoch: 468 fini

Epoch: 599 finished with training loss 81373.956640625
Epoch: 600 finished with training loss 81374.0306640625
Epoch: 601 finished with training loss 81373.800390625
Epoch: 602 finished with training loss 81373.412890625
Epoch: 603 finished with training loss 81373.3458984375
Epoch: 604 finished with training loss 81372.937109375
Epoch: 605 finished with training loss 81372.5498046875
Epoch: 606 finished with training loss 81372.4931640625
Epoch: 607 finished with training loss 81372.18125
Epoch: 608 finished with training loss 81371.8869140625
Epoch: 609 finished with training loss 81371.48828125
Epoch: 610 finished with training loss 81371.3447265625
Epoch: 611 finished with training loss 81371.230859375
Epoch: 612 finished with training loss 81371.0005859375
Epoch: 613 finished with training loss 81370.467578125
Epoch: 614 finished with training loss 81370.507421875
Epoch: 615 finished with training loss 81370.194921875
Epoch: 616 finished with training loss 81369.7462890625
Epoch: 

Epoch: 748 finished with training loss 81333.9080078125
Epoch: 749 finished with training loss 81333.530859375
Epoch: 750 finished with training loss 81333.2822265625
Epoch: 751 finished with training loss 81332.916015625
Epoch: 752 finished with training loss 81332.8796875
Epoch: 753 finished with training loss 81332.96640625
Epoch: 754 finished with training loss 81332.748828125
Epoch: 755 finished with training loss 81332.771875
Epoch: 756 finished with training loss 81332.5513671875
Epoch: 757 finished with training loss 81332.080078125
Epoch: 758 finished with training loss 81331.4904296875
Epoch: 759 finished with training loss 81331.2076171875
Epoch: 760 finished with training loss 81330.88203125
Epoch: 761 finished with training loss 81330.6240234375
Epoch: 762 finished with training loss 81330.315625
Epoch: 763 finished with training loss 81329.9921875
Epoch: 764 finished with training loss 81329.8052734375
Epoch: 765 finished with training loss 81329.4634765625
Epoch: 766 fin

Epoch: 898 finished with training loss 81289.1564453125
Epoch: 899 finished with training loss 81288.8626953125
Epoch: 900 finished with training loss 81288.6060546875
Epoch: 901 finished with training loss 81288.25
Epoch: 902 finished with training loss 81287.82109375
Epoch: 903 finished with training loss 81287.6591796875
Epoch: 904 finished with training loss 81287.4779296875
Epoch: 905 finished with training loss 81287.339453125
Epoch: 906 finished with training loss 81286.946875
Epoch: 907 finished with training loss 81286.8419921875
Epoch: 908 finished with training loss 81286.92734375
Epoch: 909 finished with training loss 81286.4345703125
Epoch: 910 finished with training loss 81286.0412109375
Epoch: 911 finished with training loss 81285.2498046875
Epoch: 912 finished with training loss 81284.8685546875
Epoch: 913 finished with training loss 81284.2970703125
Epoch: 914 finished with training loss 81284.1798828125
Epoch: 915 finished with training loss 81283.7392578125
Epoch: 91

Epoch: 1046 finished with training loss 81247.220703125
Epoch: 1047 finished with training loss 81246.8505859375
Epoch: 1048 finished with training loss 81246.1396484375
Epoch: 1049 finished with training loss 81245.5814453125
Epoch: 1050 finished with training loss 81244.96875
Epoch: 1051 finished with training loss 81244.251171875
Epoch: 1052 finished with training loss 81243.744140625
Epoch: 1053 finished with training loss 81243.0357421875
Epoch: 1054 finished with training loss 81242.26015625
Epoch: 1055 finished with training loss 81241.9404296875
Epoch: 1056 finished with training loss 81241.6912109375
Epoch: 1057 finished with training loss 81241.28359375
Epoch: 1058 finished with training loss 81240.61796875
Epoch: 1059 finished with training loss 81240.205078125
Epoch: 1060 finished with training loss 81239.5146484375
Epoch: 1061 finished with training loss 81239.3515625
Epoch: 1062 finished with training loss 81238.99296875
Epoch: 1063 finished with training loss 81238.68417

Epoch: 1192 finished with training loss 81197.49765625
Epoch: 1193 finished with training loss 81196.7578125
Epoch: 1194 finished with training loss 81196.494921875
Epoch: 1195 finished with training loss 81195.840625
Epoch: 1196 finished with training loss 81195.7853515625
Epoch: 1197 finished with training loss 81195.5734375
Epoch: 1198 finished with training loss 81195.4244140625
Epoch: 1199 finished with training loss 81195.057421875
Epoch: 1200 finished with training loss 81194.778125
Epoch: 1201 finished with training loss 81194.413671875
Epoch: 1202 finished with training loss 81194.151953125
Epoch: 1203 finished with training loss 81193.9439453125
Epoch: 1204 finished with training loss 81193.8599609375
Epoch: 1205 finished with training loss 81193.5841796875
Epoch: 1206 finished with training loss 81193.325390625
Epoch: 1207 finished with training loss 81192.8236328125
Epoch: 1208 finished with training loss 81192.4162109375
Epoch: 1209 finished with training loss 81192.078906

Epoch: 1339 finished with training loss 81150.595703125
Epoch: 1340 finished with training loss 81150.1150390625
Epoch: 1341 finished with training loss 81150.0896484375
Epoch: 1342 finished with training loss 81149.6279296875
Epoch: 1343 finished with training loss 81149.15859375
Epoch: 1344 finished with training loss 81148.931640625
Epoch: 1345 finished with training loss 81148.6234375
Epoch: 1346 finished with training loss 81148.30390625
Epoch: 1347 finished with training loss 81148.2666015625
Epoch: 1348 finished with training loss 81147.604296875
Epoch: 1349 finished with training loss 81147.47265625
Epoch: 1350 finished with training loss 81147.321875
Epoch: 1351 finished with training loss 81146.987890625
Epoch: 1352 finished with training loss 81146.480078125
Epoch: 1353 finished with training loss 81146.123046875
Epoch: 1354 finished with training loss 81145.556640625
Epoch: 1355 finished with training loss 81145.148046875
Epoch: 1356 finished with training loss 81144.791796

Epoch: 1486 finished with training loss 81098.59765625
Epoch: 1487 finished with training loss 81098.1392578125
Epoch: 1488 finished with training loss 81097.9740234375
Epoch: 1489 finished with training loss 81097.5830078125
Epoch: 1490 finished with training loss 81097.350390625
Epoch: 1491 finished with training loss 81097.0794921875
Epoch: 1492 finished with training loss 81096.823828125
Epoch: 1493 finished with training loss 81096.51953125
Epoch: 1494 finished with training loss 81096.0380859375
Epoch: 1495 finished with training loss 81095.9220703125
Epoch: 1496 finished with training loss 81095.8646484375
Epoch: 1497 finished with training loss 81095.5560546875
Epoch: 1498 finished with training loss 81094.780078125
Epoch: 1499 finished with training loss 81094.3009765625
Epoch: 1500 finished with training loss 81093.7728515625
Epoch: 1501 finished with training loss 81093.730078125
Epoch: 1502 finished with training loss 81093.0861328125
Epoch: 1503 finished with training loss

Epoch: 1632 finished with training loss 81046.19296875
Epoch: 1633 finished with training loss 81046.209765625
Epoch: 1634 finished with training loss 81045.37578125
Epoch: 1635 finished with training loss 81044.987890625
Epoch: 1636 finished with training loss 81044.8986328125
Epoch: 1637 finished with training loss 81044.429296875
Epoch: 1638 finished with training loss 81043.9466796875
Epoch: 1639 finished with training loss 81043.419921875
Epoch: 1640 finished with training loss 81043.0169921875
Epoch: 1641 finished with training loss 81042.7267578125
Epoch: 1642 finished with training loss 81042.405078125
Epoch: 1643 finished with training loss 81041.9810546875
Epoch: 1644 finished with training loss 81041.6759765625
Epoch: 1645 finished with training loss 81041.4208984375
Epoch: 1646 finished with training loss 81040.9623046875
Epoch: 1647 finished with training loss 81040.6728515625
Epoch: 1648 finished with training loss 81040.306640625
Epoch: 1649 finished with training loss 8

Epoch: 1779 finished with training loss 80989.3734375
Epoch: 1780 finished with training loss 80988.8166015625
Epoch: 1781 finished with training loss 80988.5244140625
Epoch: 1782 finished with training loss 80988.21796875
Epoch: 1783 finished with training loss 80987.29296875
Epoch: 1784 finished with training loss 80987.27109375
Epoch: 1785 finished with training loss 80986.743359375
Epoch: 1786 finished with training loss 80986.0595703125
Epoch: 1787 finished with training loss 80985.6134765625
Epoch: 1788 finished with training loss 80985.09140625
Epoch: 1789 finished with training loss 80984.7326171875
Epoch: 1790 finished with training loss 80984.1818359375
Epoch: 1791 finished with training loss 80984.0736328125
Epoch: 1792 finished with training loss 80983.6251953125
Epoch: 1793 finished with training loss 80983.1533203125
Epoch: 1794 finished with training loss 80982.6056640625
Epoch: 1795 finished with training loss 80982.3544921875
Epoch: 1796 finished with training loss 809

Epoch: 1926 finished with training loss 80923.208984375
Epoch: 1927 finished with training loss 80922.496484375
Epoch: 1928 finished with training loss 80921.9416015625
Epoch: 1929 finished with training loss 80921.4419921875
Epoch: 1930 finished with training loss 80920.9853515625
Epoch: 1931 finished with training loss 80920.4015625
Epoch: 1932 finished with training loss 80920.14140625
Epoch: 1933 finished with training loss 80920.0125
Epoch: 1934 finished with training loss 80919.38046875
Epoch: 1935 finished with training loss 80918.810546875
Epoch: 1936 finished with training loss 80918.540625
Epoch: 1937 finished with training loss 80917.9134765625
Epoch: 1938 finished with training loss 80917.5197265625
Epoch: 1939 finished with training loss 80916.81015625
Epoch: 1940 finished with training loss 80916.6607421875
Epoch: 1941 finished with training loss 80916.3958984375
Epoch: 1942 finished with training loss 80915.9763671875
Epoch: 1943 finished with training loss 80915.5966796