In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import math



In [8]:
# loading and preprocessing MNIST

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
                                     std=(0.5, 0.5, 0.5))])

# MNIST dataset
train_set = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transform, download=True)
test_set = torchvision.datasets.MNIST(root='../../data/', train=False, transform=transform, download=True)

# Data loader
trainloader = torch.utils.data.DataLoader(dataset=train_set,
                                          batch_size=32, 
                                          shuffle=True)

testloader = torch.utils.data.DataLoader(dataset=test_set,
                                          batch_size=32, 
                                          shuffle=False)

In [9]:
## network
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
net = MLP()

In [10]:
## creating a function that initializes weights
# new init weight        
def init_weight(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight) #he initialize, can use xavier instead
        #nn.init.constant_(m.bias, 0.001) # optional bias
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight) #he initialize, can use xavier instead
        #nn.init.constant_(m.bias, 0.001) # optional bias
    elif type(m) == nn.BatchNorm2d:
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 1)

In [11]:
# apply initializers
net.apply(init_weight)


MLP(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=10, bias=True)
)

In [16]:
# define loss and optimizer
import torch.optim as optim
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
#optimizer = optim.Adam(net.parameters(), lr=learning_rate)
optimizer = optim.SGD(net.parameters(), lr = learning_rate, momentum=0.9, nesterov= True, weight_decay= 0.01)


In [17]:
# LR scheduler
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingLR_with_Restart(_LRScheduler):
    """Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. The original pytorch
    implementation only implements the cosine annealing part of SGDR,
    I added my own implementation of the restarts part.
    
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations. (LENGTH OF 1 CYCLE)
        T_mult (float): Increase T_max by a factor of T_mult
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        model (pytorch model): The model to save.
        out_dir (str): Directory to save snapshots
        take_snapshot (bool): Whether to save snapshots at every restart

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, T_mult, model, out_dir, take_snapshot, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.T_mult = T_mult
        self.Te = self.T_max
        self.eta_min = eta_min
        self.current_epoch = last_epoch
        
        self.model = model
        self.out_dir = out_dir
        self.take_snapshot = take_snapshot
        
        self.lr_history = []
        
        super(CosineAnnealingLR_with_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        new_lrs = [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.current_epoch / self.Te)) / 2

                for base_lr in self.base_lrs]
        
        self.lr_history.append(new_lrs)
        return new_lrs
    
    def step(self, epoch=None):
        if epoch is None:
        
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.current_epoch += 1
        
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        
        ## restart
        if self.current_epoch == self.Te:
            print("restart at epoch {:03d}".format(self.last_epoch + 1))
            
            if self.take_snapshot:
                torch.save({
                    'epoch': self.T_max,
                    'state_dict': self.model.state_dict()
                }, self.out_dir + "/" + 'snapshot_e_{:03d}.pth.tar'.format(self.T_max))
            
            ## reset epochs since the last reset
            self.current_epoch = 0
            
            ## reset the next goal
            self.Te = int(self.Te * self.T_mult)
            self.T_max = self.T_max + self.Te

In [18]:
# T_max = how many Epochs before restarting learning rate
# T_mult = increase cycle length after restart 

scheduler = CosineAnnealingLR_with_Restart(optimizer, T_max=1, T_mult=1, model = net,  out_dir='blank', take_snapshot=False)

In [19]:
# modified model training to keep track of train/val loss
n_epochs = 5

for epoch in range(n_epochs):
    scheduler.step()
    running_loss = 0.0
    total_train_loss = 0.0
    for i, train_data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = train_data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print loss per n minibatches
        running_loss += loss.item()
        total_train_loss += loss.item()
        if i % 500 == 499:    # print every 500 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 500))
            running_loss = 0.0
    
    # keep track of loss in test dataset 
    correct = 0
    total = 0
    total_test_loss = 0.0
    with torch.no_grad():
        for test_data in testloader:
            test_images, test_labels = test_data
            test_outputs = net(test_images)
            test_loss = criterion(test_outputs, test_labels)
            total_test_loss += test_loss.item()
            _, predicted = torch.max(test_outputs.data, 1)
            total += test_labels.size(0)
            correct += (predicted == test_labels).sum().item()


    
    
    # for printing average loss every epoch
    print("===> Epoch {} Complete: Train Avg. Loss: {:.4f}".format(epoch+1, total_train_loss / len(trainloader)))
    print("===> Epoch {} Complete: Test Avg. Loss: {:.4f}".format(epoch+1, total_test_loss / len(testloader)))
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
print('Finished Training')

restart at epoch 001
[1,   500] loss: 2.660
[1,  1000] loss: 2.651
[1,  1500] loss: 2.652
===> Epoch 1 Complete: Train Avg. Loss: 2.6504
===> Epoch 1 Complete: Test Avg. Loss: 2.6314
Accuracy of the network on the 10000 test images: 15 %
restart at epoch 002
[2,   500] loss: 2.642
[2,  1000] loss: 2.650
[2,  1500] loss: 2.659
===> Epoch 2 Complete: Train Avg. Loss: 2.6504
===> Epoch 2 Complete: Test Avg. Loss: 2.6314
Accuracy of the network on the 10000 test images: 15 %
restart at epoch 003
[3,   500] loss: 2.646
[3,  1000] loss: 2.649
[3,  1500] loss: 2.655
===> Epoch 3 Complete: Train Avg. Loss: 2.6504
===> Epoch 3 Complete: Test Avg. Loss: 2.6314
Accuracy of the network on the 10000 test images: 15 %
restart at epoch 004
[4,   500] loss: 2.641
[4,  1000] loss: 2.658
[4,  1500] loss: 2.641
===> Epoch 4 Complete: Train Avg. Loss: 2.6504
===> Epoch 4 Complete: Test Avg. Loss: 2.6314
Accuracy of the network on the 10000 test images: 15 %
restart at epoch 005
[5,   500] loss: 2.658
[5, 