In [0]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import optim
from torch import nn
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from torchvision.models import resnet18
from LookAhead import Lookahead

In [0]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [0]:
writer = SummaryWriter('tensorboard')
save_file = 'state.pt'
device = torch.device('cuda:0')

In [0]:
net_sgd = resnet18().to(device)
net_sgd_la = resnet18().to(device)
net_adam = resnet18().to(device)
net_adam_la = resnet18().to(device)

In [0]:
criterion_sgd = nn.CrossEntropyLoss()
criterion_sgd_la = nn.CrossEntropyLoss()
criterion_adam = nn.CrossEntropyLoss()
criterion_adam_la = nn.CrossEntropyLoss()
opt_sgd = optim.SGD(net_sgd.parameters(), lr=0.001, momentum=0.9)
opt_sgd_la = Lookahead(optim.SGD(net_sgd_la.parameters(), lr=0.001, momentum=0.9))
opt_adam = optim.Adam(net_adam.parameters())
opt_adam_la = Lookahead(optim.Adam(net_adam_la.parameters()))

In [0]:
for epoch in range(15):
    running_loss = np.zeros(4)
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = Variable(inputs), Variable(labels)
        
        opt_sgd.zero_grad()
        opt_sgd_la.zero_grad()
        opt_adam.zero_grad()
        opt_adam_la.zero_grad()
        
        out_sgd = net_sgd(inputs)
        out_sgd_la = net_sgd_la(inputs)
        out_adam = net_adam(inputs)
        out_adam_la = net_adam_la(inputs)
        
        loss_sgd = criterion_sgd(out_sgd, labels)
        loss_sgd_la = criterion_sgd_la(out_sgd_la, labels)
        loss_adam = criterion_adam(out_adam, labels)
        loss_adam_la = criterion_adam_la(out_adam_la, labels)
        
        loss_sgd.backward()
        opt_sgd.step()
        
        loss_sgd_la.backward()
        opt_sgd_la.step()
        
        loss_adam.backward()
        opt_adam.step()
        
        loss_adam_la.backward()
        opt_adam_la.step()
        
        running_loss[0] += loss_sgd.item()
        running_loss[1] += loss_sgd_la.item()
        running_loss[2] += loss_adam.item()
        running_loss[3] += loss_adam_la.item()
        
        if i % 2000 == 1999:
            print('[%d, %5d] loss: [%.3f, %.3f, %.3f, %.3f]' %
                  (epoch + 1, i + 1, running_loss[0] / 2000, running_loss[1] / 2000,
                  running_loss[2] / 2000, running_loss[3] / 2000))
            writer.add_scalars('Loss', {'SGD': running_loss[0] / 2000,
                                       'Lookahead(SGD)': running_loss[1] / 2000,
                                       'Adam': running_loss[2] / 2000,
                                       'Lookahead(Adam)': running_loss[3] / 2000})
            running_loss = np.zeros(4)
    torch.save({'epoch': epoch+1,
               'sgd_state':[net_sgd.state_dict(), opt_sgd.state_dict()],
               'sgd_la_state':[net_sgd_la.state_dict(), opt_sgd_la.state_dict()],
               'adam_state':[net_adam.state_dict(), opt_adam.state_dict()],
               'adam_la_state':[net_adam_la.state_dict(), opt_adam_la.state_dict()]},
              save_file)

[1,  2000] loss: [2.397, 2.339, 2.291, 2.238]
[1,  4000] loss: [2.098, 2.011, 2.075, 2.062]
[1,  6000] loss: [1.989, 1.918, 2.038, 2.055]
[1,  8000] loss: [1.853, 1.798, 1.939, 1.896]
[1, 10000] loss: [1.793, 1.746, 1.818, 1.811]
[1, 12000] loss: [1.708, 1.664, 1.736, 1.731]
[2,  2000] loss: [1.635, 1.608, 1.640, 1.648]
[2,  4000] loss: [1.600, 1.551, 1.638, 1.631]
[2,  6000] loss: [1.534, 1.485, 1.555, 1.539]
[2,  8000] loss: [1.499, 1.479, 1.531, 1.538]
[2, 10000] loss: [1.437, 1.411, 1.447, 1.467]
[2, 12000] loss: [1.434, 1.399, 1.459, 1.430]
[3,  2000] loss: [1.368, 1.337, 1.404, 1.388]
[3,  4000] loss: [1.336, 1.328, 1.404, 1.357]
[3,  6000] loss: [1.276, 1.258, 1.358, 1.315]
[3,  8000] loss: [1.247, 1.246, 1.312, 1.284]
[3, 10000] loss: [1.238, 1.233, 1.295, 1.253]
[3, 12000] loss: [1.219, 1.209, 1.285, 1.244]
[4,  2000] loss: [1.154, 1.131, 1.227, 1.161]
[4,  4000] loss: [1.132, 1.127, 1.237, 1.200]
[4,  6000] loss: [1.134, 1.120, 1.242, 1.161]
[4,  8000] loss: [1.128, 1.118, 1.

In [0]:
for epoch in range(15, 30):
    running_loss = np.zeros(4)
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = Variable(inputs), Variable(labels)
        
        opt_sgd.zero_grad()
        opt_sgd_la.zero_grad()
        opt_adam.zero_grad()
        opt_adam_la.zero_grad()
        
        out_sgd = net_sgd(inputs)
        out_sgd_la = net_sgd_la(inputs)
        out_adam = net_adam(inputs)
        out_adam_la = net_adam_la(inputs)
        
        loss_sgd = criterion_sgd(out_sgd, labels)
        loss_sgd_la = criterion_sgd_la(out_sgd_la, labels)
        loss_adam = criterion_adam(out_adam, labels)
        loss_adam_la = criterion_adam_la(out_adam_la, labels)
        
        loss_sgd.backward()
        opt_sgd.step()
        
        loss_sgd_la.backward()
        opt_sgd_la.step()
        
        loss_adam.backward()
        opt_adam.step()
        
        loss_adam_la.backward()
        opt_adam_la.step()
        
        running_loss[0] += loss_sgd.item()
        running_loss[1] += loss_sgd_la.item()
        running_loss[2] += loss_adam.item()
        running_loss[3] += loss_adam_la.item()
        
        if i % 2000 == 1999:
            print('[%d, %5d] loss: [%.3f, %.3f, %.3f, %.3f]' %
                  (epoch + 1, i + 1, running_loss[0] / 2000, running_loss[1] / 2000,
                  running_loss[2] / 2000, running_loss[3] / 2000))
            writer.add_scalars('Loss', {'SGD': running_loss[0] / 2000,
                                       'Lookahead(SGD)': running_loss[1] / 2000,
                                       'Adam': running_loss[2] / 2000,
                                       'Lookahead(Adam)': running_loss[3] / 2000})
            running_loss = np.zeros(4)
    torch.save({'epoch': epoch+1,
               'sgd_state':[net_sgd.state_dict(), opt_sgd.state_dict()],
               'sgd_la_state':[net_sgd_la.state_dict(), opt_sgd_la.state_dict()],
               'adam_state':[net_adam.state_dict(), opt_adam.state_dict()],
               'adam_la_state':[net_adam_la.state_dict(), opt_adam_la.state_dict()]},
              save_file)

[16,  2000] loss: [0.344, 0.333, 0.595, 0.642]
[16,  4000] loss: [0.362, 0.368, 0.607, 0.630]
[16,  6000] loss: [0.356, 0.365, 0.608, 0.644]
[16,  8000] loss: [0.371, 0.372, 0.614, 0.618]
[16, 10000] loss: [0.436, 0.396, 0.636, 0.626]
[16, 12000] loss: [0.403, 0.394, 0.613, 0.643]
[17,  2000] loss: [0.313, 0.310, 0.582, 0.573]
[17,  4000] loss: [0.316, 0.311, 0.581, 0.577]
[17,  6000] loss: [0.336, 0.319, 0.575, 0.597]
[17,  8000] loss: [0.352, 0.344, 0.597, 0.586]
[17, 10000] loss: [0.351, 0.338, 0.596, 0.610]
[17, 12000] loss: [0.379, 0.357, 0.588, 0.592]
[18,  2000] loss: [0.281, 0.266, 0.549, 0.574]
[18,  4000] loss: [0.286, 0.279, 0.525, 0.549]
[18,  6000] loss: [0.298, 0.279, 0.529, 0.545]
[18,  8000] loss: [0.305, 0.306, 0.534, 0.566]
[18, 10000] loss: [0.310, 0.310, 0.557, 0.567]
[18, 12000] loss: [0.330, 0.335, 0.598, 0.580]
[19,  2000] loss: [0.238, 0.236, 0.538, 0.499]
[19,  4000] loss: [0.246, 0.245, 0.496, 0.501]
[19,  6000] loss: [0.256, 0.271, 0.523, 0.510]
[19,  8000] l

In [0]:
for epoch in range(30, 40):
    running_loss = np.zeros(4)
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = Variable(inputs), Variable(labels)
        
        opt_sgd.zero_grad()
        opt_sgd_la.zero_grad()
        opt_adam.zero_grad()
        opt_adam_la.zero_grad()
        
        out_sgd = net_sgd(inputs)
        out_sgd_la = net_sgd_la(inputs)
        out_adam = net_adam(inputs)
        out_adam_la = net_adam_la(inputs)
        
        loss_sgd = criterion_sgd(out_sgd, labels)
        loss_sgd_la = criterion_sgd_la(out_sgd_la, labels)
        loss_adam = criterion_adam(out_adam, labels)
        loss_adam_la = criterion_adam_la(out_adam_la, labels)
        
        loss_sgd.backward()
        opt_sgd.step()
        
        loss_sgd_la.backward()
        opt_sgd_la.step()
        
        loss_adam.backward()
        opt_adam.step()
        
        loss_adam_la.backward()
        opt_adam_la.step()
        
        running_loss[0] += loss_sgd.item()
        running_loss[1] += loss_sgd_la.item()
        running_loss[2] += loss_adam.item()
        running_loss[3] += loss_adam_la.item()
        
        if i % 2000 == 1999:
            print('[%d, %5d] loss: [%.3f, %.3f, %.3f, %.3f]' %
                  (epoch + 1, i + 1, running_loss[0] / 2000, running_loss[1] / 2000,
                  running_loss[2] / 2000, running_loss[3] / 2000))
            writer.add_scalars('Loss', {'SGD': running_loss[0] / 2000,
                                       'Lookahead(SGD)': running_loss[1] / 2000,
                                       'Adam': running_loss[2] / 2000,
                                       'Lookahead(Adam)': running_loss[3] / 2000})
            running_loss = np.zeros(4)
    torch.save({'epoch': epoch+1,
               'sgd_state':[net_sgd.state_dict(), opt_sgd.state_dict()],
               'sgd_la_state':[net_sgd_la.state_dict(), opt_sgd_la.state_dict()],
               'adam_state':[net_adam.state_dict(), opt_adam.state_dict()],
               'adam_la_state':[net_adam_la.state_dict(), opt_adam_la.state_dict()]},
              save_file)

[31,  2000] loss: [0.077, 0.120, 0.308, 0.277]
[31,  4000] loss: [0.087, 0.109, 0.327, 0.323]
[31,  6000] loss: [0.099, 0.104, 0.318, 0.295]
[31,  8000] loss: [0.100, 0.106, 0.336, 0.310]
[31, 10000] loss: [0.112, 0.121, 0.346, 0.316]
[31, 12000] loss: [0.108, 0.120, 0.348, 0.331]
[32,  2000] loss: [0.076, 0.069, 0.306, 0.270]
[32,  4000] loss: [0.082, 0.088, 0.280, 0.286]
[32,  6000] loss: [0.083, 0.087, 0.304, 0.321]
[32,  8000] loss: [0.080, 0.113, 0.311, 0.285]
[32, 10000] loss: [0.096, 0.105, 0.302, 0.296]
[32, 12000] loss: [0.079, 0.111, 0.309, 0.313]
[33,  2000] loss: [0.064, 0.066, 0.250, 0.242]
[33,  4000] loss: [0.066, 0.076, 0.272, 0.267]
[33,  6000] loss: [0.073, 0.077, 0.302, 0.265]
[33,  8000] loss: [0.080, 0.096, 0.289, 0.276]
[33, 10000] loss: [0.098, 0.083, 0.287, 0.282]
[33, 12000] loss: [0.096, 0.098, 0.292, 0.288]
[34,  2000] loss: [0.069, 0.062, 0.244, 0.224]
[34,  4000] loss: [0.072, 0.064, 0.283, 0.250]
[34,  6000] loss: [0.070, 0.082, 0.254, 0.254]
[34,  8000] l

In [0]:
for epoch in range(40, 50):
    running_loss = np.zeros(4)
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = Variable(inputs), Variable(labels)
        
        opt_sgd.zero_grad()
        opt_sgd_la.zero_grad()
        opt_adam.zero_grad()
        opt_adam_la.zero_grad()
        
        out_sgd = net_sgd(inputs)
        out_sgd_la = net_sgd_la(inputs)
        out_adam = net_adam(inputs)
        out_adam_la = net_adam_la(inputs)
        
        loss_sgd = criterion_sgd(out_sgd, labels)
        loss_sgd_la = criterion_sgd_la(out_sgd_la, labels)
        loss_adam = criterion_adam(out_adam, labels)
        loss_adam_la = criterion_adam_la(out_adam_la, labels)
        
        loss_sgd.backward()
        opt_sgd.step()
        
        loss_sgd_la.backward()
        opt_sgd_la.step()
        
        loss_adam.backward()
        opt_adam.step()
        
        loss_adam_la.backward()
        opt_adam_la.step()
        
        running_loss[0] += loss_sgd.item()
        running_loss[1] += loss_sgd_la.item()
        running_loss[2] += loss_adam.item()
        running_loss[3] += loss_adam_la.item()
        
        if i % 2000 == 1999:
            print('[%d, %5d] loss: [%.3f, %.3f, %.3f, %.3f]' %
                  (epoch + 1, i + 1, running_loss[0] / 2000, running_loss[1] / 2000,
                  running_loss[2] / 2000, running_loss[3] / 2000))
            writer.add_scalars('Loss', {'SGD': running_loss[0] / 2000,
                                       'Lookahead(SGD)': running_loss[1] / 2000,
                                       'Adam': running_loss[2] / 2000,
                                       'Lookahead(Adam)': running_loss[3] / 2000})
            running_loss = np.zeros(4)
    torch.save({'epoch': epoch+1,
               'sgd_state':[net_sgd.state_dict(), opt_sgd.state_dict()],
               'sgd_la_state':[net_sgd_la.state_dict(), opt_sgd_la.state_dict()],
               'adam_state':[net_adam.state_dict(), opt_adam.state_dict()],
               'adam_la_state':[net_adam_la.state_dict(), opt_adam_la.state_dict()]},
              save_file)

[41,  2000] loss: [0.047, 0.036, 0.194, 0.166]
[41,  4000] loss: [0.054, 0.051, 0.204, 0.160]
[41,  6000] loss: [0.059, 0.050, 0.209, 0.221]
[41,  8000] loss: [0.061, 0.050, 0.194, 0.263]
[41, 10000] loss: [0.063, 0.057, 0.204, 0.235]
[41, 12000] loss: [0.079, 0.063, 0.215, 0.210]
[42,  2000] loss: [0.059, 0.048, 0.171, 0.173]
[42,  4000] loss: [0.044, 0.047, 0.181, 0.177]
[42,  6000] loss: [0.052, 0.046, 0.183, 0.193]
[42,  8000] loss: [0.052, 0.060, 0.208, 0.196]
[42, 10000] loss: [0.053, 0.060, 0.192, 0.197]
[42, 12000] loss: [0.058, 0.051, 0.188, 0.185]
[43,  2000] loss: [0.039, 0.038, 0.172, 0.151]
[43,  4000] loss: [0.038, 0.038, 0.172, 0.171]
[43,  6000] loss: [0.045, 0.050, 0.207, 0.180]
[43,  8000] loss: [0.056, 0.048, 0.197, 0.225]
[43, 10000] loss: [0.058, 0.042, 0.188, 0.209]
[43, 12000] loss: [0.053, 0.055, 0.195, 0.200]
[44,  2000] loss: [0.041, 0.045, 0.154, 0.161]
[44,  4000] loss: [0.043, 0.041, 0.160, 0.198]
[44,  6000] loss: [0.032, 0.047, 0.165, 0.193]
[44,  8000] l