In [1]:
from torch.nn import Module
from torch import nn


In [23]:
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=(2,2))
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=(2,2))
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(3136, 1000)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = self.dropout1(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.dropout2(y)
        y = self.fc2(y)
        return y

In [24]:
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import *
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import time

In [25]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
torch.cuda.device_count()    

2

In [26]:
import torch
from torch.optim.optimizer import Optimizer


class wame(Optimizer):


    def __init__(self, params,lr=1e-3, alpha=0.9, etas=(0.1, 1.2), zetas=(0.01, 100), epsilon=1e-10, debug=False ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= alpha:
            raise ValueError("Invalid learning rate: {}".format(alpha))
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))

        defaults = dict(alpha=alpha, etas=etas, zetas=zetas, lr=lr, epsilon=epsilon, debug=debug)
        super(wame, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
   
        for group in self.param_groups:    
            etaminus, etaplus = group['etas']
            zeta_min, zeta_max = group['zetas']
            alpha = group['alpha']
            lr = group['lr']
            epsilon = group['epsilon']
            debug = group['debug']
            _p = 0

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('wame does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['prev'] = torch.ones_like(p.data, memory_format=torch.preserve_format)
                    state['prev'] = grad
                    state['Theta'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    state['Z'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    state['zeta'] = torch.ones_like(p.data, memory_format=torch.preserve_format)
                    state['gradmult'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    
                    state['step_size'] = grad.new().resize_as_(grad).fill_(group['lr']) ### check on initialisation
                    state['prev'] = grad
                    
                if (_p == 0) & debug:
                    print("Weights at start: ",p.data.sum(), p.data.size())
                step_size = state['step_size']
                Theta = state['Theta']
                Z = state['Z']
                zeta = state['zeta']
                gradmult = state['gradmult']                    
                
                state['step'] += 1

                gradmult = grad.mul(state['prev'])
                zeta[gradmult.gt(0.)] = zeta[gradmult.gt(0.)].mul(etaplus).clamp(zeta_min, zeta_max)
                zeta[gradmult.lt(0.)] = zeta[gradmult.lt(0.)].mul(etaminus).clamp(zeta_min, zeta_max)
                zeta[gradmult.eq(0.)] = 1
                
                Z = Z.mul(alpha).add(zeta.mul(1 - alpha))
                Theta = Theta.mul(alpha).add(grad.mul(grad).mul(1 - alpha))

                step_size = Z.mul(-lr).mul(grad).div(Theta.add(epsilon))

                grad = grad.clone(memory_format=torch.preserve_format)
                
                # update parameters
                p.data = p.data.add(step_size)
                
                if (_p == 0) & debug:
                    print("Step number ", state['step'])
                    print("Weights after update: ",p.data.sum(), p.data.size())
                    print("Updates applied: ",step_size.sum(), step_size.size())
                    print("Theta: ",Theta.sum())
                    print("Z: ",Z.sum())
                    print("zeta: ",zeta.sum())
                    print("grad: ",grad.sum())
                    print("grad_sqd", grad_sqd.sum())
                    print("gradmult", gradmult.sum())
                state['prev'] = grad
                
                _p = _p + 1

        return loss


In [27]:
from torchsummary import summary

batch_size = 1024
path = '/home/bernard/Documents/BBK/CourseWork/ML/LeNet5-MNIST-PyTorch/'
train_dataset = mnist.MNIST(root=path + 'train', train=True, transform=ToTensor())
test_dataset = mnist.MNIST(root=path + 'test', train=False, transform=ToTensor())

train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)

model = Model()
model = torch.nn.DataParallel(model).cuda()
summary(model, (1, 28, 28))
print(model)


opt = wame(model.parameters(), lr=1e-7)
loss = CrossEntropyLoss()
epoch = 20

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             832
            Conv2d-2           [-1, 32, 28, 28]             832
              ReLU-3           [-1, 32, 28, 28]               0
              ReLU-4           [-1, 32, 28, 28]               0
         MaxPool2d-5           [-1, 32, 14, 14]               0
         MaxPool2d-6           [-1, 32, 14, 14]               0
            Conv2d-7           [-1, 64, 14, 14]          51,264
            Conv2d-8           [-1, 64, 14, 14]          51,264
              ReLU-9           [-1, 64, 14, 14]               0
        MaxPool2d-10             [-1, 64, 7, 7]               0
             ReLU-11           [-1, 64, 14, 14]               0
          Dropout-12             [-1, 64, 7, 7]               0
        MaxPool2d-13             [-1, 64, 7, 7]               0
          Dropout-14             [-1, 6

In [28]:
time0 = time.perf_counter()
torch.manual_seed(7)
for _epoch in range(epoch):
    batch = 0
    for idx, (train_x, train_label) in enumerate(train_loader):
        train_x = train_x.cuda()
        train_label = train_label.cuda()
        label_np = np.zeros((train_label.shape[0], 10))
        opt.zero_grad()
        predict_y = model(train_x.float())
        _error = loss(predict_y, train_label.long())
        _error.backward()
        opt.step()

        batch = batch + 1
    
    correct = 0
    _sum = 0

    for idx, (test_x, test_label) in enumerate(test_loader):
        test_x = test_x.cuda()
        test_label = test_label
        predict_y = model(test_x.float()).detach().cpu()
        predict_ys = np.argmax(predict_y, axis=-1)
        label_np = test_label.numpy()
        _ = predict_ys == test_label
        correct += np.sum(_.numpy(), axis=-1)
        _sum += _.shape[0]

    time1 = time.perf_counter()
    print('epoch: {}  accuracy: {:.6f} time {:.3f} seconds'.format(_epoch + 1, correct / _sum, time1 - time0))
    time0 = time1
    #torch.save(model, path + 'models/mnist_{:.2f}.pkl'.format(correct / _sum))

epoch: 1  accuracy: 0.912500 time 1.492 seconds
epoch: 2  accuracy: 0.935200 time 1.447 seconds
epoch: 3  accuracy: 0.937800 time 1.457 seconds
epoch: 4  accuracy: 0.942500 time 1.450 seconds
epoch: 5  accuracy: 0.945800 time 1.449 seconds
epoch: 6  accuracy: 0.944400 time 1.484 seconds
epoch: 7  accuracy: 0.947200 time 1.464 seconds
epoch: 8  accuracy: 0.949500 time 1.459 seconds
epoch: 9  accuracy: 0.952500 time 1.525 seconds
epoch: 10  accuracy: 0.954100 time 1.512 seconds
epoch: 11  accuracy: 0.955700 time 1.534 seconds
epoch: 12  accuracy: 0.954100 time 1.500 seconds
epoch: 13  accuracy: 0.956700 time 1.483 seconds
epoch: 14  accuracy: 0.958700 time 1.444 seconds
epoch: 15  accuracy: 0.959900 time 1.445 seconds
epoch: 16  accuracy: 0.958600 time 1.480 seconds
epoch: 17  accuracy: 0.960600 time 1.478 seconds
epoch: 18  accuracy: 0.959600 time 1.453 seconds
epoch: 19  accuracy: 0.961300 time 1.458 seconds
epoch: 20  accuracy: 0.962400 time 1.439 seconds


In [9]:
print(torch.__version__)
print(torch.version.cuda)

1.5.0
10.2
