# Ordered SGD for classification with PyTorch
by Dizhi Ma

## Utility functions
Helper functions, data loaders, the small CNN, and the trainer classes used by the experiments below.

In [None]:
import torchvision.transforms as transforms
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import time
from tqdm import tqdm
import numpy as np

In [None]:
class AverageMeter:
    """Keep the most recent value of a metric together with its running mean.

    ``update(val, n)`` treats ``val`` as observed ``n`` times, so ``avg`` is a
    count-weighted mean of everything recorded since the last ``reset``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0
        self.memory = 100  # NOTE(review): not read anywhere visible in this file

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def get_data(type='MNIST', data_path='/data', batch_size_train=128, batch_size_test=1000):
    """Build train/test DataLoaders for one of the supported image datasets.

    Args:
        type: 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST' or 'SVHN'.
            The spellings 'MINIST' and 'FashionMINST' are still accepted as
            aliases of 'MNIST' and 'FashionMNIST'.
        data_path: directory where torchvision stores/downloads the data.
        batch_size_train: batch size of the (shuffled) training loader.
        batch_size_test: batch size of the (unshuffled) test loader.

    Returns:
        (train_loader, test_loader)

    Raises:
        ValueError: if ``type`` is not a supported dataset name.
    """
    aliases = {'MINIST': 'MNIST', 'FashionMINST': 'FashionMNIST'}
    name = aliases.get(type, type)
    supported = ('CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST', 'SVHN')
    # Validate up front: previously an unknown name fell through every branch
    # and failed later with an UnboundLocalError on `train_data`.
    if name not in supported:
        raise ValueError(f'unsupported dataset {type!r}; expected one of {supported}')

    if name in ('CIFAR10', 'CIFAR100'):
        # NOTE(review): CIFAR100 reuses these CIFAR10 channel statistics.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        # Images are converted to grayscale and replicated to 3 channels so the
        # same 3-channel networks can be used (this also drops SVHN's colour).
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
        ])

    # The same transform is used for train and test (no augmentation).
    if name == 'SVHN':
        train_data = torchvision.datasets.SVHN(data_path, split='train', download=True, transform=transform)
        test_data = torchvision.datasets.SVHN(data_path, split='test', download=True, transform=transform)
    else:
        dataset_cls = getattr(torchvision.datasets, name)
        train_data = dataset_cls(data_path, train=True, download=True, transform=transform)
        test_data = dataset_cls(data_path, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size_train, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size_test, shuffle=False)

    return train_loader, test_loader

class OurCNN(nn.Module):
    """Small two-conv CNN for square 3-channel images.

    conv3x3(3->8) -> maxpool2 -> ReLU -> conv3x3(8->16) -> maxpool2 -> ReLU
    -> flatten -> linear(num_classes).

    Args:
        num_classes: number of output logits (default 10).
        input_size: height/width of the square input. The default 32 gives the
            original 576-feature classifier; use e.g. 28 for MNIST-sized images.
    """

    def __init__(self, num_classes=10, input_size=32):
        super(OurCNN, self).__init__()

        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        # Each unpadded 3x3 conv trims 2 pixels and each 2x2 max-pool halves the
        # side (floor), so the flattened size follows from the input size.
        side = ((input_size - 2) // 2 - 2) // 2
        self.fc = nn.Linear(16 * side * side, num_classes)

    def forward(self, x):
        b = x.size(0)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(b, -1)
        return self.fc(x)

class trainer:
    """Train a classifier with plain SGD or ordered SGD (OSGD).

    OSGD back-propagates only the mean of the ``q`` largest per-sample losses of
    each mini-batch. The q-rule used in an epoch is ``adaptive_q(best_train_acc)``
    when ``adaptive`` is True, otherwise ``q`` (``None`` means plain SGD on the
    whole batch).
    """

    def __init__(self, model='resnet18', dataset='CIFAR10', q=None, adaptive=False, base_lr=0.01):
        """Build the network, the SGD optimizer and the data loaders.

        Args:
            model: 'resnet18', 'mobilenet_v2', 'mobilenet_v3_large' (pretrained
                torchvision models), 'OurNN' (``OurCNN``), or an already built
                ``nn.Module``, which is used as-is.
            dataset: dataset name passed to ``get_data``.
            q: optional callable ``q(batch_size, best_acc) -> int`` giving how
                many of the highest-loss samples of a batch to train on.
            adaptive: if True, ``q`` is ignored and ``adaptive_q`` picks the rule
                at the start of every epoch.
            base_lr: initial learning rate; ``fit`` divides it by 10 once.

        Raises:
            ValueError: if ``model`` is an unknown name.
        """
        # NOTE(review): the pretrained torchvision classification heads are not
        # replaced, so their output width is torchvision's default — confirm
        # that this is intended for the dataset's number of classes.
        self.classifier = self._build_model(model)
        self.base_lr = base_lr
        self.optimizer = optim.SGD(self.classifier.parameters(), lr=self.base_lr, momentum=0.8)
        self.dataset = dataset
        self.train_loader, self.test_loader = get_data(dataset)
        self.q = q
        self.adaptive = adaptive

    @staticmethod
    def _build_model(name):
        """Instantiate the named network on the GPU (an ``nn.Module`` passes through)."""
        if isinstance(name, nn.Module):
            return name
        if name == 'resnet18':
            net = models.resnet18(pretrained=True)
        elif name == 'mobilenet_v2':
            net = models.mobilenet_v2(pretrained=True)
        elif name == 'mobilenet_v3_large':
            net = models.mobilenet_v3_large(pretrained=True)
        elif name == 'OurNN':
            net = OurCNN()
        else:
            # Previously the string itself became the "classifier" and the
            # failure surfaced later as an AttributeError on `.parameters()`.
            raise ValueError(f'unknown model {name!r}; expected resnet18, mobilenet_v2, mobilenet_v3_large or OurNN')
        return net.cuda()

    @staticmethod
    def _ordered_loss(per_sample_loss, ssize):
        """Mean of the ``ssize`` largest per-sample losses (the OSGD objective).

        ``ssize`` is clamped to ``[1, batch]``: an oversized q keeps the whole
        batch and q == 0 can no longer produce the NaN mean of an empty slice.
        """
        k = max(1, min(int(ssize), per_sample_loss.size(0)))
        return torch.topk(per_sample_loss, k).values.mean()

    @staticmethod
    def _set_lr(optimizer, lr):
        """Change the learning rate in place, keeping the momentum buffers."""
        for group in optimizer.param_groups:
            group['lr'] = lr

    def train(self, epoch, q, best_acc):
        """Run one training epoch.

        Args:
            epoch: epoch number, only shown in the progress bar.
            q: ``None`` for plain SGD, else a callable ``q(batch_size, best_acc)``.
            best_acc: best training accuracy so far, forwarded to ``q``.

        Returns:
            (mean batch loss, training accuracy as a fraction in [0, 1] (tensor)).
        """
        self.classifier.train()
        # Per-sample losses are needed so they can be ordered.
        loss_func = nn.CrossEntropyLoss(reduction='none')

        n_samples = len(self.train_loader.dataset)
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
        count = 0
        correct = 0
        train_losses = AverageMeter()
        b_size = None

        for batch_idx, (images, targets) in pbar:
            images, targets = images.cuda(), targets.cuda()
            self.optimizer.zero_grad()
            output = self.classifier(images)

            bs = images.size(0)
            count += bs
            ssize = bs if not q else q(bs, best_acc)

            loss = self._ordered_loss(loss_func(output, targets), ssize)
            loss.backward()
            self.optimizer.step()

            pred = output.argmax(dim=1)
            correct += pred.eq(targets).sum()
            train_losses.update(loss.item())

            if batch_idx == 0:
                b_size = ssize  # q chosen for the first batch, shown in the progress bar
            pbar.set_description(f'Epoch {epoch} [{count}/{n_samples}]: bs:{b_size} Loss: {train_losses.avg:.2f} Acc: {100.*correct/count:.2f}%')

        return train_losses.avg, correct / count

    def test(self, epoch):
        """Evaluate on the test set.

        Args:
            epoch: epoch number, only shown in the progress bar.

        Returns:
            Test accuracy in percent (a tensor).
        """
        self.classifier.eval()
        loss_func = nn.CrossEntropyLoss()

        n_samples = len(self.test_loader.dataset)
        correct = 0
        count = 0
        test_losses = AverageMeter()
        pbar = tqdm(self.test_loader)
        with torch.no_grad():
            for images, targets in pbar:
                images, targets = images.cuda(), targets.cuda()
                output = self.classifier(images)
                test_losses.update(loss_func(output, targets).item())
                pred = output.argmax(dim=1)
                correct += pred.eq(targets).sum()
                count += images.size(0)
                pbar.set_description(f'Epoch {epoch} [{count}/{n_samples}]:Loss: {test_losses.avg:.2f} Acc: {100.*correct/count:.2f}%')

        return 100. * correct / n_samples

    def fit(self, max_epoch=20):
        """Train for ``max_epoch`` epochs, evaluating after each one.

        The learning rate is divided by 10 after epoch ``max_epoch - 10``.

        Returns:
            (train accuracies as fractions, test accuracies as percentages),
            one entry per epoch.
        """
        train_acces = []
        test_accs = []

        start = time.time()
        best_acc = 0
        best_test = 0
        for epoch in range(1, max_epoch + 1):
            q = adaptive_q(best_acc) if self.adaptive else self.q

            _, acc = self.train(epoch, q, best_acc)
            test_acc = self.test(epoch)

            best_acc = max(acc, best_acc)
            best_test = max(best_test, test_acc)
            train_acces.append(acc)
            test_accs.append(test_acc)
            if epoch == max_epoch - 10:
                self.base_lr /= 10
                # Update in place rather than rebuilding the optimizer, which
                # would silently discard the momentum buffers.
                self._set_lr(self.optimizer, self.base_lr)
        print(f'\nbest acc {float(best_test):.2f}% time used {time.time()-start:.2f}')
        return train_acces, test_accs

class trainerPlus:
    """Ordered-SGD trainer whose learning rate is scaled by sqrt(q / batch size).

    Identical to ``trainer`` except that at the start of every epoch the
    learning rate is set to ``base_lr * sqrt(q / batch_size)``, using the q of
    the epoch's first batch.
    """

    def __init__(self, model='resnet18', dataset='CIFAR10', q=None, adaptive=False, base_lr=0.01):
        """Build the network, the SGD optimizer and the data loaders.

        Args:
            model: 'resnet18', 'mobilenet_v2', 'mobilenet_v3_large' (pretrained
                torchvision models), 'OurNN' (``OurCNN``), or an already built
                ``nn.Module``, which is used as-is.
            dataset: dataset name passed to ``get_data``.
            q: optional callable ``q(batch_size, best_acc) -> int`` giving how
                many of the highest-loss samples of a batch to train on.
            adaptive: if True, ``q`` is ignored and ``adaptive_q`` picks the rule
                at the start of every epoch.
            base_lr: unscaled learning rate; ``fit`` divides it by 10 once.

        Raises:
            ValueError: if ``model`` is an unknown name.
        """
        self.classifier = self._build_model(model)
        self.base_lr = base_lr
        self.optimizer = optim.SGD(self.classifier.parameters(), lr=self.base_lr, momentum=0.8)
        self.dataset = dataset
        self.train_loader, self.test_loader = get_data(dataset)
        self.q = q
        self.adaptive = adaptive

    @staticmethod
    def _build_model(name):
        """Instantiate the named network on the GPU (an ``nn.Module`` passes through)."""
        if isinstance(name, nn.Module):
            return name
        if name == 'resnet18':
            net = models.resnet18(pretrained=True)
        elif name == 'mobilenet_v2':
            net = models.mobilenet_v2(pretrained=True)
        elif name == 'mobilenet_v3_large':
            net = models.mobilenet_v3_large(pretrained=True)
        elif name == 'OurNN':
            net = OurCNN()
        else:
            raise ValueError(f'unknown model {name!r}; expected resnet18, mobilenet_v2, mobilenet_v3_large or OurNN')
        return net.cuda()

    @staticmethod
    def _ordered_loss(per_sample_loss, ssize):
        """Mean of the ``ssize`` largest per-sample losses, ``ssize`` clamped to ``[1, batch]``."""
        k = max(1, min(int(ssize), per_sample_loss.size(0)))
        return torch.topk(per_sample_loss, k).values.mean()

    @staticmethod
    def _scaled_lr(base_lr, ssize, bs):
        """Return ``base_lr * sqrt(k / bs)`` with ``k`` = samples actually used (q clamped to ``[1, bs]``)."""
        k = max(1, min(int(ssize), bs))
        return float(base_lr * np.sqrt(k / bs))

    @staticmethod
    def _set_lr(optimizer, lr):
        """Change the learning rate in place, keeping the momentum buffers."""
        for group in optimizer.param_groups:
            group['lr'] = lr

    def train(self, epoch, q, best_acc):
        """Run one training epoch with a q-scaled learning rate.

        Args:
            epoch: epoch number, only shown in the progress bar.
            q: ``None`` for plain SGD, else a callable ``q(batch_size, best_acc)``.
            best_acc: best training accuracy so far, forwarded to ``q``.

        Returns:
            (mean batch loss, training accuracy as a fraction in [0, 1] (tensor)).
        """
        self.classifier.train()
        loss_func = nn.CrossEntropyLoss(reduction='none')

        n_samples = len(self.train_loader.dataset)
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
        count = 0
        correct = 0
        train_losses = AverageMeter()
        b_size = None

        for batch_idx, (images, targets) in pbar:
            images, targets = images.cuda(), targets.cuda()
            self.optimizer.zero_grad()
            output = self.classifier(images)

            bs = images.size(0)
            count += bs
            ssize = bs if not q else q(bs, best_acc)

            if batch_idx == 0:
                b_size = ssize
                # Install the epoch's scaled learning rate *before* the first
                # optimizer step so every step of the epoch uses it. Updating
                # param_groups in place keeps the momentum buffers, which
                # rebuilding the optimizer every epoch would throw away.
                self._set_lr(self.optimizer, self._scaled_lr(self.base_lr, ssize, bs))

            loss = self._ordered_loss(loss_func(output, targets), ssize)
            loss.backward()
            self.optimizer.step()

            pred = output.argmax(dim=1)
            correct += pred.eq(targets).sum()
            train_losses.update(loss.item())

            pbar.set_description(f'Epoch {epoch} [{count}/{n_samples}]: bs:{b_size} Loss: {train_losses.avg:.2f} Acc: {100.*correct/count:.2f}%')

        return train_losses.avg, correct / count

    def test(self, epoch):
        """Evaluate on the test set and return the accuracy in percent (a tensor)."""
        self.classifier.eval()
        loss_func = nn.CrossEntropyLoss()

        n_samples = len(self.test_loader.dataset)
        correct = 0
        count = 0
        test_losses = AverageMeter()
        pbar = tqdm(self.test_loader)
        with torch.no_grad():
            for images, targets in pbar:
                images, targets = images.cuda(), targets.cuda()
                output = self.classifier(images)
                test_losses.update(loss_func(output, targets).item())
                pred = output.argmax(dim=1)
                correct += pred.eq(targets).sum()
                count += images.size(0)
                pbar.set_description(f'Epoch {epoch} [{count}/{n_samples}]:Loss: {test_losses.avg:.2f} Acc: {100.*correct/count:.2f}%')

        return 100. * correct / n_samples

    def fit(self, max_epoch=20):
        """Train for ``max_epoch`` epochs, evaluating after each one.

        ``base_lr`` is divided by 10 after epoch ``max_epoch - 10`` (consistent
        with ``trainer``); ``train`` applies the scaled value from the next epoch.

        Returns:
            (train accuracies as fractions, test accuracies as percentages),
            one entry per epoch.
        """
        train_acces = []
        test_accs = []

        start = time.time()
        best_acc = 0
        best_test = 0
        for epoch in range(1, max_epoch + 1):
            q = adaptive_q(best_acc) if self.adaptive else self.q

            _, acc = self.train(epoch, q, best_acc)
            test_acc = self.test(epoch)

            best_acc = max(acc, best_acc)
            best_test = max(best_test, test_acc)
            train_acces.append(acc)
            test_accs.append(test_acc)
            if epoch == max_epoch - 10:
                # Previously hard-coded to 0.001, which only matched the default base_lr.
                self.base_lr /= 10
        print(f'\nbest acc {float(best_test):.2f}% time used {time.time()-start:.2f}')
        return train_acces, test_accs

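## q-value strategies (fixed and adaptive)
Here we define the q-value strategies used in the experiments. A strategy returns a callable `q(batch_size, ...)` giving how many of the highest-loss samples in each mini-batch are used for the update: `fix(num)` always keeps `num` samples, while `adaptive_q(acc)` keeps a smaller share of the batch as the best training accuracy so far rises.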

In [None]:
def fix(num):
    """Return a q-rule that keeps exactly ``num`` samples per batch, whatever the batch size."""
    def constant_q(bs, epoch):
        return num
    return constant_q

def adaptive_q(acc):
    """Choose a q-rule from the best training accuracy so far.

    ``acc`` is a fraction in [0, 1]. The returned ``q(bs, epoch)`` keeps a
    shrinking share of each batch as accuracy rises, with a floor:

        accuracy >= 99.5%  ->  max(bs // 16, 4)
        accuracy >= 95%    ->  max(bs // 8, 8)
        accuracy >= 90%    ->  max(bs // 4, 16)
        accuracy >= 80%    ->  max(bs // 2, 32)
        otherwise          ->  bs (the whole batch)

    The ``epoch`` argument of the returned rule is accepted but ignored.
    """
    percent = acc * 100
    schedule = (
        (99.5, 16, 4),
        (95, 8, 8),
        (90, 4, 16),
        (80, 2, 32),
    )
    for threshold, divisor, minimum in schedule:
        if percent >= threshold:
            # Returned immediately, so the closure sees this iteration's values.
            def q(bs, epoch):
                return max(bs // divisor, minimum)
            return q

    def q(bs, epoch):
        return bs
    return q

## Experiments

### SGD baselines and OSGD with adaptive q

In [None]:
# Plain-SGD baselines with the default ResNet18 backbone (q=None, adaptive=False).
sgd_cifar10 = trainer(model='resnet18', dataset='CIFAR10').fit()
sgd_cifar100 = trainer(model='resnet18', dataset='CIFAR100').fit()
sgd_MNIST = trainer(model='resnet18', dataset='MNIST').fit()
sgd_svhn = trainer(model='resnet18', dataset='SVHN').fit()

In [None]:
# Ordered SGD with the adaptive q schedule, ResNet18 backbone.
qsgd_cifar10 = trainer(model='resnet18', dataset='CIFAR10', adaptive=True).fit()
qsgd_cifar100 = trainer(model='resnet18', dataset='CIFAR100', adaptive=True).fit()
qsgd_MNIST = trainer(model='resnet18', dataset='MNIST', adaptive=True).fit()
qsgd_svhn = trainer(model='resnet18', dataset='SVHN', adaptive=True).fit()

In [None]:
# Plain-SGD baselines with the MobileNetV2 backbone.
sgd_cifar10_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='CIFAR10', adaptive=False).fit()
sgd_cifar100_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='CIFAR100', adaptive=False).fit()
sgd_MNIST_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='MNIST', adaptive=False).fit()
sgd_svhn_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='SVHN', adaptive=False).fit()

In [None]:
# Ordered SGD with the adaptive q schedule, MobileNetV2 backbone.
qsgd_cifar10_mobilenet_v2 = trainer(dataset='CIFAR10', model='mobilenet_v2', adaptive=True).fit()
qsgd_cifar100_mobilenet_v2 = trainer(dataset='CIFAR100', model='mobilenet_v2', adaptive=True).fit()
qsgd_MNIST_mobilenet_v2 = trainer(dataset='MNIST', model='mobilenet_v2', adaptive=True).fit()
qsgd_svhn_mobilenet_v2 = trainer(dataset='SVHN', model='mobilenet_v2', adaptive=True).fit()

### Ablation study on fixed q

In [None]:
# Fixed-q ablation on CIFAR10 with ResNet18 (batch size 128).
qsgd_96 = trainer(model='resnet18', dataset='CIFAR10', q=fix(96)).fit()
qsgd_64 = trainer(model='resnet18', dataset='CIFAR10', q=fix(64)).fit()
qsgd_32 = trainer(model='resnet18', dataset='CIFAR10', q=fix(32)).fit()

In [None]:
# Fixed-q ablation on CIFAR10 with MobileNetV2; the smaller q values use a lower base_lr.
qsgd_96_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='CIFAR10', q=fix(96), base_lr=0.01).fit()
qsgd_64_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='CIFAR10', q=fix(64), base_lr=0.003).fit()
qsgd_32_mobilenet_v2 = trainer(model='mobilenet_v2', dataset='CIFAR10', q=fix(32), base_lr=0.003).fit()

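### Improvement with a learning-rate scaling scheme
`trainerPlus` additionally scales the learning rate by $\sqrt{q / \text{batch size}}$, set once per epoch from the q of the first batch.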

In [None]:
# Same fixed-q runs as above, but with the sqrt(q / batch)-scaled learning rate of trainerPlus.
qsgd_plus_96_mobilenet_v2 = trainerPlus(model='mobilenet_v2', dataset='CIFAR10', q=fix(96), base_lr=0.01).fit()
qsgd_plus_64_mobilenet_v2 = trainerPlus(model='mobilenet_v2', dataset='CIFAR10', q=fix(64), base_lr=0.01).fit()
qsgd_plus_32_mobilenet_v2 = trainerPlus(model='mobilenet_v2', dataset='CIFAR10', q=fix(32), base_lr=0.01).fit()

## Visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def get_cmap(n, name='hsv'):
    '''Return a colormap resampled to n entries, so that each index 0, 1, ..., n-1
    maps to a distinct RGB color; name must be a standard matplotlib colormap name.'''
    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `plt.get_cmap(name, lut)` is the supported equivalent.
    return plt.get_cmap(name, n)

cmap = get_cmap(6)
# (legend label, color, (train accs, test accs)) for every CIFAR10 / ResNet18 run.
runs = [
    ('SGD', 'k', sgd_cifar10),
    ('OSGD', 'r', qsgd_cifar10),
    ('OSGD q=96', cmap(2), qsgd_96),
    ('OSGD q=64', cmap(3), qsgd_64),
    ('OSGD q=32', cmap(4), qsgd_32),
]

fig = plt.figure(figsize=(12, 12))
ax_train = fig.add_subplot(2, 1, 1)
ax_test = fig.add_subplot(2, 1, 2)
for label, color, (train_accs, test_accs) in runs:
    # Train accuracies are fractions (hence * 100); test accuracies are already percentages.
    ax_train.plot(range(len(train_accs)), [d.item() * 100 for d in train_accs],
                  linewidth=2, color=color, label=label)
    ax_test.plot(range(len(test_accs)), [d.item() for d in test_accs],
                 linewidth=2, color=color, label=label)

ax_train.set_ylabel('train accuracy(%)', fontsize=20)
ax_test.set_xlabel('epoch', fontsize=20)
ax_test.set_ylabel('test accuracy(%)', fontsize=20)
ax_test.legend(prop={'size': 18})

fig.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def get_cmap(n, name='hsv'):
    '''Return a colormap resampled to n entries, so that each index 0, 1, ..., n-1
    maps to a distinct RGB color; name must be a standard matplotlib colormap name.'''
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` replaces it.
    return plt.get_cmap(name, n)

# Fixed-q OSGD on CIFAR10 / MobileNetV2: plain `trainer` (dashed) versus
# `trainerPlus` with its sqrt(q / batch)-scaled learning rate (solid, '*' label).
runs = [
    (96, 'k', qsgd_96_mobilenet_v2, qsgd_plus_96_mobilenet_v2),
    (64, 'r', qsgd_64_mobilenet_v2, qsgd_plus_64_mobilenet_v2),
    (32, 'g', qsgd_32_mobilenet_v2, qsgd_plus_32_mobilenet_v2),
]

fig = plt.figure(figsize=(12, 12))
ax_train = fig.add_subplot(2, 1, 1)
ax_test = fig.add_subplot(2, 1, 2)
for q, color, baseline, plus in runs:
    for (train_accs, test_accs), style, suffix in ((baseline, '--', ''), (plus, '-', '*')):
        label = f'OSGD q={q}{suffix}'
        # Train accuracies are fractions (hence * 100); test accuracies are already percentages.
        ax_train.plot(range(len(train_accs)), [d.item() * 100 for d in train_accs], style,
                      linewidth=2, color=color, label=label)
        ax_test.plot(range(len(test_accs)), [d.item() for d in test_accs], style,
                     linewidth=2, color=color, label=label)

ax_train.set_ylabel('train accuracy(%)', fontsize=20)
ax_test.set_xlabel('epoch', fontsize=20)
ax_test.set_ylabel('test accuracy(%)', fontsize=20)
ax_test.legend(prop={'size': 18})
fig.show()

In [None]:
# Train (dashed) and test (solid) accuracy of SGD (black) vs. adaptive OSGD (red):
# one column per dataset, ResNet18 on the top row and MobileNetV2 on the bottom row.
f_size = 30
datasets = ['CIFAR10', 'CIFAR100', 'MNIST', 'SVHN']
backbones = ['ResNet18', 'MobileNetV2']
# (dataset, backbone) -> (SGD result, OSGD result); each result is (train accs, test accs).
results = {
    ('CIFAR10', 'ResNet18'): (sgd_cifar10, qsgd_cifar10),
    ('CIFAR100', 'ResNet18'): (sgd_cifar100, qsgd_cifar100),
    ('MNIST', 'ResNet18'): (sgd_MNIST, qsgd_MNIST),
    ('SVHN', 'ResNet18'): (sgd_svhn, qsgd_svhn),
    ('CIFAR10', 'MobileNetV2'): (sgd_cifar10_mobilenet_v2, qsgd_cifar10_mobilenet_v2),
    ('CIFAR100', 'MobileNetV2'): (sgd_cifar100_mobilenet_v2, qsgd_cifar100_mobilenet_v2),
    ('MNIST', 'MobileNetV2'): (sgd_MNIST_mobilenet_v2, qsgd_MNIST_mobilenet_v2),
    ('SVHN', 'MobileNetV2'): (sgd_svhn_mobilenet_v2, qsgd_svhn_mobilenet_v2),
}

fig = plt.figure(figsize=(48, 15))
for row, backbone in enumerate(backbones):
    for col, dataset in enumerate(datasets):
        ax = fig.add_subplot(len(backbones), len(datasets), row * len(datasets) + col + 1)
        for method, color, (train_accs, test_accs) in zip(('SGD', 'OSGD'), ('k', 'r'),
                                                          results[dataset, backbone]):
            # Test accuracies are already percentages; train accuracies are fractions.
            ax.plot(range(len(test_accs)), [d.item() for d in test_accs],
                    linewidth=2, color=color, label=f'{method} test')
            ax.plot(range(len(train_accs)), [d.item() * 100 for d in train_accs], '--',
                    linewidth=2, color=color, label=f'{method} train')
        ax.set_title(f'{dataset} & {backbone}', fontsize=f_size)
        if row == len(backbones) - 1:
            ax.set_xlabel('epoch', fontsize=f_size)
        if col == 0:
            ax.set_ylabel('accuracy(%)', fontsize=f_size)

# A single legend on the last panel serves all subplots.
ax.legend(prop={'size': f_size})

fig.show()

## ReLU Linear Regression via Ordered SGD (Preliminary Implementation)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from numpy.core.fromnumeric import mean
import random

In [None]:
class AverageMeter:
    """Hold the latest observation of a metric and its count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0
        self.memory = 100  # NOTE(review): not read anywhere visible in this file

    def update(self, val, n=1):
        """Add ``val`` observed ``n`` times and recompute ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def ReLU(x):
    """Rectified linear unit: keep positive entries and zero out the rest (scalars or arrays)."""
    mask = x > 0
    return x * mask

def predict(x,w):
    """Model output ``ReLU(w . x)`` for one sample ``x`` under weights ``w``."""
    score = w.dot(x)
    return ReLU(score)

def full_q(e, epochs, batch_size):
    """q-strategy for TrainOSGD that keeps the whole batch (plain SGD); ``e`` and ``epochs`` are ignored."""
    keep_all = batch_size
    return keep_all

def naive_q(e, epochs, batch_size):
    """q-strategy for TrainOSGD that keeps the harder half of each batch.

    ``e`` and ``epochs`` are ignored. At least one sample is always kept:
    ``batch_size // 2`` is 0 for a batch of one, and a q of 0 turns the
    ``[-q:]`` selection into ``[0:]``, i.e. silently the whole batch.
    """
    return max(1, batch_size // 2)

def Train(data,w_star,epochs, lr, batch_size=64, mode='0'):
    '''
    Regular mini-batch SGD for the ReLU regression model ``ReLU(w . x)``.

    Targets are ``ReLU(w_star . x_i) + noise_i``. Each of the ``epochs``
    iterations draws ``batch_size`` samples, averages their gradient steps and
    applies a single update to ``w``.

    input:
        data      : nparray N x d, one sample per row
        w_star    : nparray d, ground-truth weights that generate the targets
        epochs    : int, number of updates
        lr        : float, learning rate
        batch_size: int, samples drawn per update
        mode      : '0' draws samples with replacement; '1' draws without
                    replacement, walking through successive random
                    permutations of the data
    output:
        [train_error, smooth_error] -- per update, the running mean squared
        error over every sample drawn so far, and the mean squared error of
        that update's batch
    '''
    n = len(data)
    without_replacement = str(mode) == '1'
    train_error = []
    smooth_error = []
    # initialize
    w = np.random.randn(data[0].shape[0])
    # One noise value per sample, sized from the data (was a fixed 2000, which
    # raised IndexError for larger data sets).
    # NOTE(review): multiplying by 0.05**2 gives std 0.0025 — confirm whether
    # std 0.05 was intended.
    noise = np.random.randn(n) * (0.05) ** 2

    def sample_indices():
        # Endless, lazily generated stream of sample indices according to `mode`.
        while True:
            if without_replacement:
                yield from random.sample(range(n), n)
            else:
                yield random.randrange(n)

    indices = sample_indices()
    train_losses = AverageMeter()
    pbar = tqdm(range(epochs))
    for e in pbar:
        batch_losses = AverageMeter()
        dw = []  # per-sample update steps of this batch
        for _ in range(batch_size):
            i = next(indices)
            a = predict(data[i], w)
            err = (predict(data[i], w_star) + noise[i]) - a
            # Negative gradient of err**2 / 2; the ReLU passes no gradient where a <= 0.
            dw.append(lr * err * data[i] * (a > 0))
            train_losses.update(err ** 2)
            batch_losses.update(err ** 2)
        w += np.mean(dw, axis=0)
        train_error.append(train_losses.avg)
        smooth_error.append(batch_losses.avg)

        # With batch_size == 1 the running average is shown instead of the single-sample loss.
        if batch_size == 1:
            pbar.set_description(f"epoch:{e} loss:{train_losses.avg}")
        else:
            pbar.set_description(f"epoch:{e} loss:{batch_losses.avg}")

    return [train_error, smooth_error]

def TrainOSGD(data,w_star,epochs, lr, batch_size=64, strategy=naive_q):
    '''
    Ordered mini-batch SGD for the ReLU regression model ``ReLU(w . x)``.

    Like ``Train`` (samples drawn with replacement), but each update averages
    only the steps of the ``q`` samples with the largest squared error, where
    ``q = strategy(e, epochs, batch_size)``.

    input:
        data      : nparray N x d, one sample per row
        w_star    : nparray d, ground-truth weights that generate the targets
        epochs    : int, number of updates
        lr        : float, learning rate
        batch_size: int, samples drawn per update
        strategy  : callable (e, epochs, batch_size) -> q; q is clamped to
                    [1, batch_size]
    output:
        [train_error, smooth_error] -- per update, the running mean squared
        error over every sample drawn so far, and the mean squared error of
        that update's whole batch
    '''
    n = len(data)
    train_error = []
    smooth_error = []
    # initialize
    w = np.random.randn(data[0].shape[0])
    # One noise value per sample, sized from the data (was a fixed 2000).
    # NOTE(review): multiplying by 0.05**2 gives std 0.0025 — confirm whether
    # std 0.05 was intended.
    noise = np.random.randn(n) * (0.05) ** 2

    train_losses = AverageMeter()
    pbar = tqdm(range(epochs))
    for e in pbar:
        batch_losses = AverageMeter()
        dw = []          # per-sample update steps of this batch
        loss_batch = []  # matching squared errors, used for ordering
        for _ in range(batch_size):
            i = random.randrange(n)
            a = predict(data[i], w)
            err = (predict(data[i], w_star) + noise[i]) - a
            loss_batch.append(float(err) ** 2)
            # Negative gradient of err**2 / 2; the ReLU passes no gradient where a <= 0.
            dw.append(lr * err * data[i] * (a > 0))
            train_losses.update(err ** 2)
            batch_losses.update(err ** 2)

        q = strategy(e, epochs, len(loss_batch))
        # Clamp to [1, batch]: q == 0 would turn the slice [-q:] into [0:] and
        # silently use every sample; a negative q would skip the hardest ones.
        q = max(1, min(int(q), len(loss_batch)))

        hardest = np.argsort(loss_batch)[-q:]  # indices of the q largest errors
        w += np.mean(np.array(dw)[hardest], axis=0)
        train_error.append(train_losses.avg)
        smooth_error.append(batch_losses.avg)
        pbar.set_description(f"epoch:{e} qvalue: {q} loss:{batch_losses.avg:.6f}")

    return [train_error, smooth_error]


def plot_mistake(trainError,filename='1', labels=('SGD training data error', 'Order-SGD training data error')):
    '''
    Plot loss curves against iteration.

    trainError: sequence of (up to) two loss sequences, drawn in blue and green;
                each is plotted against its own length
    filename  : figure title
    labels    : legend labels for the two curves
    returns   : the matplotlib Figure
    '''
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(1, 1, 1)
    # Lowercase 'g': single-letter colour codes are case-sensitive in current
    # Matplotlib, and the previous 'G' is rejected as an invalid colour.
    for series, color, label in zip(trainError, ('b', 'g'), labels):
        ax.plot(range(len(series)), series, linewidth=2, color=color, label=label)
    ax.set_xlabel('iter', fontsize=20)
    ax.set_ylabel('loss', fontsize=20)

    ax.set_title(f'{filename}', fontsize=15)
    ax.legend(prop={'size': 15})
    return fig

In [None]:
# Synthetic ReLU regression: plain mini-batch SGD (TrainOSGD with full_q keeps
# the whole batch) versus ordered SGD keeping the harder half (naive_q).
d, N, T, lr, b_size = 500, 2000, 2000, 5e-2, 100
# Seed both RNGs (NumPy for data/init/noise, `random` for sampling) so the
# comparison is reproducible. The previously generated noise array `e` was
# never used (the trainers draw their own noise) and has been dropped.
np.random.seed(0)
random.seed(0)
x = np.random.randn(N, d)
w_star = np.random.randn(d)

# Each call returns [running-mean error, per-batch error]; the running mean is plotted.
error, _ = TrainOSGD(x, w_star, T, lr, b_size, full_q)
OSGD_error, _ = TrainOSGD(x, w_star, T, lr, b_size)

plot_mistake([error, OSGD_error], 'loss vs iter')