In [4]:
#misc
import os
import pickle as pkl
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
import seaborn as sns
import torch
import torchvision
from torch.utils import data
from torchvision import transforms

optim_dict = {
		'sgd': {
			'label': 'SGD',
			'lr': 1e-3
		},
		'sgd_momentum': {
			'label': 'SGD w/ momentum',
			'lr': 1e-3,
			'mu': 0.99
		},
		'sgd_nesterov': {
			'label': 'SGD w/ Nesterov momentum',
			'lr': 1e-3,
			'mu': 0.99,
			'nesterov': True
		},
		'sgd_weight_decay': {
			'label': 'SGDW',
			'lr': 1e-3,
			'mu': 0.99,
			'weight_decay': 1e-6
		},
		'sgd_lrd': {
			'label': 'SGD w/ momentum + LRD',
			'lr': 1e-3,
			'mu': 0.99,
			'lrd': 0.5
		},
		'adam': {
			'label': 'Adam',
			'lr': 1e-3
		},
		'adamW':{
			'label': 'AdamW',
			'lr': 1e-3,
			'weight_decay': 1e-4
		},
		'adam_l2':{
			'label': 'AdamL2',
			'lr': 1e-3,
			'l2_reg': 1e-4
		},
		'adam_lrd': {
			'label': 'Adam w/ LRD',
			'lr': 1e-3,
			'lrd': 0.5
		},
		'Radam': {
			'label': 'RAdam',
			'lr': 1e-3,
			'rectified': True
		},
		'RadamW': {
			'label': 'RAdamW',
			'lr': 1e-3,
			'rectified': True,
			'weight_decay': 1e-4
		},
		'Radam_lrd': {
			'label': 'RAdam w/ LRD',
			'lr': 1e-3,
			'rectified': True,
			'lrd': 0.5
		},
		'nadam': {
			'label': 'Nadam',
			'lr': 1e-3,
			'nesterov': True
		},
		'rmsprop': {
			'label': 'RMSprop',
			'lr': 1e-3,
			'beta2': 0.9,
		},
		'lookahead_sgd': {
			'label': 'Lookahead (SGD)',
			'lr': 1e-3,
			'mu': 0.99
		},
		'lookahead_adam': {
			'label': 'Lookahead (Adam)',
			'lr': 1e-3
		},
		'gradnoise_adam': {
			'label': 'Gradient Noise (Adam)',
			'lr': 1e-3
		},
		'graddropout_adam': {
			'label': 'Gradient Dropout (Adam)',
			'lr': 1e-3
		}
	}


def split_optim_dict(d:dict) -> tuple:
	"""
	Splits an optimization dict into label and dict.
	"""
	temp_d = deepcopy(d)
	label = temp_d['label']
	del temp_d['label']

	return label, temp_d


def load_cifar(num_train=50000, num_val=2048):
	"""
	Loads a subset of the CIFAR dataset and returns it as a tuple.
	"""
	transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])

	train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
	val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

	train_dataset, _ = torch.utils.data.random_split(train_dataset, lengths=[num_train, len(train_dataset)-num_train])
	val_dataset, _ = torch.utils.data.random_split(val_dataset, lengths=[num_val, len(val_dataset)-num_val])

	return train_dataset, val_dataset


def load_mnist(filename='data/mnist.npz', num_train=4096, num_val=512):
	"""
	Loads a subset of the grayscale MNIST dataset and returns it as a tuple.
	"""
	data = np.load(filename)

	x_train = data['X_train'][:num_train].astype('float32')
	y_train = data['y_train'][:num_train].astype('int32')

	x_valid = data['X_valid'][:num_val].astype('float32')
	y_valid = data['y_valid'][:num_val].astype('int32')

	train_dataset = Dataset(x_train, y_train)
	val_dataset = Dataset(x_valid, y_valid)

	return train_dataset, val_dataset


def task_to_optimizer(task:str) -> torch.optim.Optimizer:
	"""
	Takes a task as string and returns its respective optimizer class.
	"""
	optimizer = None

	if 'sgd' in task.lower():
		optimizer = getattr(optimizers, 'SGD')
	if 'adam' in task.lower():
		optimizer = getattr(optimizers, 'Adam')
	if 'rmsprop' in task.lower():
		optimizer = getattr(optimizers, 'RMSProp')
	
	if optimizer is None:
		raise ValueError(f'Optimizer for task \'{task}\' was not recognized!')

	return optimizer


def wrap_optimizer(task:str, optimizer):
	"""
	Wraps an instantiated optimizer according to its task specified as a string.
	"""
	if 'gradnoise' in task.lower():
		optimizer = optimizers.GradientNoise(optimizer, eta=0.3, gamma=0.55)

	if 'graddropout' in task.lower():
		optimizer = optimizers.GradientDropout(optimizer, grad_retain=0.9)

	if 'lookahead' in task.lower():
		optimizer = optimizers.Lookahead(optimizer, k=5, alpha=0.5)

	return optimizer


class AvgLoss():
	"""
	Utility class that tracks the average loss.
	"""
	def __init__(self):
		self.sum, self.avg, self.n = 0, 0, 0
		self.losses = []

	def __iadd__(self, other):
		try:
			loss = other.data.numpy()
		except:
			loss = other
		
		if isinstance(other, list):
			self.losses.extend(other)
			self.sum += np.sum(other)
			self.n += len(other)
		else:
			self.losses.append(float(loss))
			self.sum += loss
			self.n += 1

		self.avg = self.sum / self.n

		return self

	def __str__(self):
		return '{0:.4f}'.format(round(self.avg, 4))

	def __len__(self):
		return len(self.losses)


class Dataset(data.Dataset):
	def __init__(self, X, y):
		self.X = X
		self.y = y

	def __len__(self):
		return len(self.X)

	def __getitem__(self, idx):
		return self.X[idx], self.y[idx]


def save_losses(losses, dataset:str, filename:str):
	if not os.path.exists(f'losses_{dataset}/'): os.makedirs(f'losses_{dataset}/')
	with open(f'losses_{dataset}/{filename}.pkl', 'wb') as f:
		pkl.dump(losses, f, protocol=pkl.HIGHEST_PROTOCOL)


def load_losses(dataset:str, filename:str):
	try:
		with open(f'losses_{dataset}/{filename}.pkl', 'rb') as f:
			return pkl.load(f)
	except:
		return None


def plot_mnist(X):
	idx, dim, classes = 0, 28, 10
	canvas = np.zeros((dim*classes, classes*dim))

	for i in range(classes):
		for j in range(classes):
			canvas[i*dim:(i+1)*dim, j*dim:(j+1)*dim] = X[idx].reshape((dim, dim))
			idx += 1

	sns.set(style='darkgrid')
	plt.figure(figsize=(9, 9))
	plt.axis('off')
	plt.tight_layout(pad=0)
	plt.imshow(canvas, cmap='gray')
	plt.savefig('mnist_examples.png')
	plt.clf()


def plot_loss(losses, val_losses, num_epochs):
	sns.set(style='darkgrid')
	plt.figure(figsize=(12, 6))
	plt.plot(np.linspace(0, num_epochs, num=len(losses)), losses.losses, label='Training loss')
	plt.plot(np.linspace(0, num_epochs, num=len(val_losses)), val_losses.losses, label='Validation loss')
	plt.tight_layout(pad=2)
	plt.xlabel('Epoch')
	plt.ylabel('Negative log likelihood')
	plt.savefig('loss.png')
	plt.clf()


def plot_losses(losses, val_losses, labels, num_epochs, title, plot_val=False, yscale_log=False, max_epochs=None):
	sns.set(style='darkgrid')
	plt.figure(figsize=(12, 6))
	colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:cyan', 'tab:olive']

	for i in range(len(losses)):
		plt.plot(np.linspace(0, num_epochs, num=len(losses[i])), smooth(losses[i].losses, 81), label=labels[i], alpha=1, c=colors[i])
		plt.plot(np.linspace(0, num_epochs, num=len(losses[i])), smooth(losses[i].losses, 21), alpha=0.25, c=colors[i])
		if plot_val:
			plt.plot(np.linspace(0, num_epochs, num=len(val_losses[i])), smooth(val_losses[i].losses, 81), alpha=1, linestyle='--', c=colors[i])

	plt.tight_layout(pad=2)
	plt.xlabel('Epoch')
	plt.ylabel('Cross-entropy')
	if yscale_log:
		plt.yscale('log')
	if max_epochs is not None:
		plt.xlim(-1, max_epochs)
	plt.ylim(0, 3)
	plt.title('CNN benchmark on CIFAR-10' if title == 'cifar' else 'MLP benchmark on MNIST')
	plt.legend(loc='upper right')
	plt.savefig(f'loss_{title}.png')
	plt.clf()


def smooth(signal, kernel_size, polyorder=3):
	return savgol_filter(signal, kernel_size, polyorder)

##################################################
#networks
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import utils
from torch.nn import Parameter


class MLP(nn.Module):
	"""
	A small multilayer perceptron with parameters that we can optimize for the task.
	"""
	def __init__(self, num_features=784, num_hidden=64, num_outputs=10):
		super(MLP, self).__init__()

		self.W_1 = Parameter(init.xavier_normal_(torch.Tensor(num_hidden, num_features)))
		self.b_1 = Parameter(init.constant_(torch.Tensor(num_hidden), 0))

		self.W_2 = Parameter(init.xavier_normal_(torch.Tensor(num_outputs, num_hidden)))
		self.b_2 = Parameter(init.constant_(torch.Tensor(num_outputs), 0))

	def forward(self, x):
		x = F.relu(F.linear(x, self.W_1, self.b_1))
		x = F.linear(x, self.W_2, self.b_2)

		return x


class CNN(nn.Module):
	"""
	A small convolutional neural network with parameters that we can optimize for the task.
	"""
	def __init__(self, num_layers=4, num_filters=64, num_classes=10, input_size=(3, 32, 32)):
		super(CNN, self).__init__()

		self.channels = input_size[0]
		self.height = input_size[1]
		self.width = input_size[2]
		self.num_filters = num_filters

		self.conv_in = nn.Conv2d(self.channels, self.num_filters, kernel_size=5, padding=2)
		cnn = []
		for _ in range(num_layers):
			cnn.append(nn.Conv2d(self.num_filters, self.num_filters, kernel_size=3, padding=1))
			cnn.append(nn.BatchNorm2d(self.num_filters))
			cnn.append(nn.ReLU())
		self.cnn = nn.Sequential(*cnn)

		self.out_lin = nn.Linear(self.num_filters*self.width*self.height, num_classes)

		if torch.cuda.is_available():
			self.cuda()


	def forward(self, x):
		if torch.cuda.is_available():
			x = x.cuda()

		x = F.relu(self.conv_in(x))
		x = self.cnn(x)
		x = x.reshape(x.size(0), -1)

		return self.out_lin(x)


def fit(net, data, optimizer, batch_size=128, num_epochs=250, lr_schedule=False):
	"""
	Fits parameters of a network `net` using `data` as training data and a given `optimizer`.
	"""
	train_generator = utils.data.DataLoader(data[0], batch_size=batch_size)
	val_generator = utils.data.DataLoader(data[1], batch_size=batch_size)

	losses = AvgLoss()
	val_losses = AvgLoss()

	if lr_schedule:
		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

	for epoch in range(num_epochs+1):

		epoch_loss = AvgLoss()
		epoch_val_loss = AvgLoss()

		for x, y in val_generator:
			y = y.type(torch.LongTensor)
			if torch.cuda.is_available(): y = y.cuda()
			epoch_val_loss += F.cross_entropy(net(x), y).cpu()

		for x, y in train_generator:
			y = y.type(torch.LongTensor)
			if torch.cuda.is_available(): y = y.cuda()
			loss = F.cross_entropy(net(x), y).cpu()
			epoch_loss += loss

			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

		if lr_schedule:
			scheduler.step(epoch_loss.avg)

		if epoch % 2 == 0:
			print(f'Epoch {epoch}/{num_epochs}, loss: {epoch_loss}, val loss: {epoch_val_loss}')

		losses += epoch_loss.losses
		val_losses += epoch_val_loss.losses

	return losses, val_losses

##################################################
#optimizers

import torch
import numpy as np
from torch.optim import Optimizer
from torch.distributions import Bernoulli, Normal


class SGD(Optimizer):
    """
    Stochastic gradient descent. Also includes implementations of momentum,
    Nesterov's momentum, L2 regularization, SGDW and Learning Rate Dropout.
    """
    def __init__(self, params, lr, mu=0, nesterov=False, weight_decay=0, lrd=1):
        defaults = {'lr': lr, 'mu': mu, 'nesterov': nesterov, 'weight_decay': weight_decay, 'lrd': lrd}
        super(SGD, self).__init__(params, defaults)

    def step(self):
        """
        Performs a single optimization step.
        """
        for group in self.param_groups:

            lr = group['lr']
            mu = group['mu']
            nesterov = group['nesterov']
            weight_decay = group['weight_decay']
            lrd_bernoulli = Bernoulli(probs=group['lrd'])

            if mu != 0 and 'v' not in group:
                group['v'] = []
                if nesterov:
                    group['theta'] = []
                for param in group['params']:
                    group['v'].append(torch.zeros_like(param))
                    if nesterov:
                        theta_param = torch.ones_like(param).mul_(param.data)
                        group['theta'].append(theta_param)

            for idx, param in enumerate(group['params']):
                param.grad.data -= weight_decay * param.data
                lrd_mask = lrd_bernoulli.sample(param.size()).to(param.device)

                if mu != 0:
                    v = group['v'][idx]
                    v = mu * v - lr * param.grad.data
                    group['v'][idx] = v

                    if nesterov:
                        group['theta'][idx] += lrd_mask * v
                        param.data = group['theta'][idx] + mu * v

                    else:
                        param.data += lrd_mask * v

                else:
                    param.data -= lrd_mask * lr * param.grad.data


class Adam(Optimizer):
    """
    Adam as proposed by https://arxiv.org/abs/1412.6980.
    Also includes a number of proposed extensions to the the Adam algorithm,
    such as Nadam, L2 regularization, AdamW, RAdam and Learning Rate Dropout.
    """
    def __init__(self, params, lr, beta1=0.9, beta2=0.999, nesterov=False, l2_reg=0, weight_decay=0, rectified=False, lrd=1, eps=1e-8):
        defaults = {'lr': lr, 'beta1': beta1, 'beta2': beta2, 'nesterov': nesterov, 'l2_reg': l2_reg,
                    'weight_decay': weight_decay, 'rectified': rectified, 'lrd': lrd, 'eps': eps}
        super(Adam, self).__init__(params, defaults)

    def step(self):
        """
        Performs a single optimization step.
        """
        for group in self.param_groups:

            lr = group['lr']
            beta1 = group['beta1']
            beta2 = group['beta2']
            nesterov = group['nesterov']
            l2_reg = group['l2_reg']
            weight_decay = group['weight_decay']
            rectified = group['rectified']
            lrd_bernoulli = Bernoulli(probs=group['lrd'])
            eps = group['eps']

            if 'm' not in group and 'v' not in group:
                group['m'] = []
                group['v'] = []
                group['t'] = 1
                if nesterov:
                    group['prev_grad'] = []
                for param in group['params']:
                    group['m'].append(torch.zeros_like(param))
                    group['v'].append(torch.zeros_like(param))
                    if nesterov:
                        group['prev_grad'].append(torch.zeros_like(param))

            for idx, param in enumerate(group['params']):
                if l2_reg:
                    param.grad.data += l2_reg * param.data

                if nesterov:
                    grad = group['prev_grad'][idx]
                else:
                    grad = param.grad.data

                lrd_mask = lrd_bernoulli.sample(param.size()).to(param.device)

                m = group['m'][idx]
                v = group['v'][idx]
                t = group['t']
                m = beta1 * m + (1 - beta1) * grad
                v = beta2 * v + (1 - beta2) * grad**2
                m_hat = m / (1 - beta1**t)
                v_hat = v / (1 - beta2**t)

                if nesterov:
                    group['prev_grad'][idx] = param.grad.data

                if rectified:
                    rho_inf = 2 / (1 - beta2) - 1
                    rho = rho_inf - 2 * t * beta2**t / (1 - beta2**t)
                    if rho >= 5:
                        numerator = (1 - beta2**t) * (rho - 4) * (rho - 2) * rho_inf
                        denominator = (rho_inf - 4) * (rho_inf - 2) * rho
                        r = np.sqrt(numerator / denominator)
                        param.data += - lrd_mask * lr * r * m_hat / (torch.sqrt(v) + eps)
                    else:
                        param.data += - lrd_mask * lr * m_hat
                else:
                    param.data += - lrd_mask * lr * m_hat / (torch.sqrt(v_hat) + eps)

                if weight_decay:
                    param.data -= weight_decay * param.data

                group['m'][idx] = m
                group['v'][idx] = v

            group['t'] += 1


class RMSProp(Adam):
    """
    RMSprop as proposed by http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf.
    Note that this implementation, unlike the original RMSprop, uses bias-corrected moments.
    """
    def __init__(self, params, lr, beta2):
        super(RMSProp, self).__init__(params, lr, beta2=beta2, beta1=0)


class Lookahead(Optimizer):
    """
    Lookahead Optimization as proposed by https://arxiv.org/abs/1907.08610.
    This is a wrapper class that can be applied to an instantiated optimizer.
    """
    def __init__(self, optimizer, k=5, alpha=0.5):
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = optimizer.param_groups

        self.counter = 0
        for group in optimizer.param_groups:
            group['phi'] = []
            for param in group['params']:
                phi_param = torch.ones_like(param).mul_(param.data)
                group['phi'].append(phi_param)

    def step(self):
        if self.counter == self.k:
            for group_idx, group in enumerate(self.param_groups):
                for idx, _ in enumerate(group['phi']):
                    theta = self.optimizer.param_groups[group_idx]['params'][idx].data
                    group['phi'][idx] = group['phi'][idx] + self.alpha * (theta - group['phi'][idx])
            self.counter = 0
        else:
            self.counter += 1
            self.optimizer.step()


class GradientNoise(Optimizer):
    """
    Gradient Noise as proposed by https://arxiv.org/abs/1511.06807.
    This is a wrapper class that can be applied to an instantiated optimizer.
    """
    def __init__(self, optimizer, eta=0.3, gamma=0.55):
        self.optimizer = optimizer
        self.eta = eta
        self.gamma = gamma
        self.t = 0
        self.param_groups = optimizer.param_groups

    def step(self):
        normal = torch.empty(1).normal_(mean=0, std=np.sqrt(self.eta/((1+self.t)**self.gamma)))\
            .to(self.optimizer.param_groups[0]['params'][0].device)
        for group_idx, group in enumerate(self.param_groups):
            for idx, param in enumerate(group['params']):
                self.optimizer.param_groups[group_idx]['params'][idx].grad.data += normal
                self.optimizer.step()
                self.t += 1


class GradientDropout(Optimizer):
    """
    Gradient dropout as proposed by https://arxiv.org/abs/1912.00144.
    This is a wrapper class that can be applied to an instantiated optimizer.
    Note that this method does not improve optimization significantly and
    is only here for comparison to Learning Rate Dropout.
    """
    def __init__(self, optimizer, grad_retain=0.9):
        self.optimizer = optimizer
        self.grad_retain = grad_retain
        self.grad_bernoulli = Bernoulli(probs=grad_retain)
        self.param_groups = optimizer.param_groups

    def step(self):
        for group_idx, group in enumerate(self.param_groups):
            for idx, param in enumerate(group['params']):
                grad_mask = self.grad_bernoulli.sample(param.size()).to(param.device)
                self.optimizer.param_groups[group_idx]['params'][idx].grad.data *= grad_mask
                self.optimizer.step()
    

##################################################
import argparse
from copy import deepcopy

import torch

if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	parser.add_argument('-num_epochs', type=int, default=30)
	parser.add_argument('-dataset', type=str, default='cifar')
	parser.add_argument('-num_train', type=int, default=50000)
	parser.add_argument('-num_val', type=int, default=2048)
	parser.add_argument('-lr_schedule', type=bool, default=True)
	parser.add_argument('-only_plot', type=bool, default=True)
	args = parser.parse_args()

	data = getattr(misc, 'load_'+args.dataset)(
		num_train=args.num_train,
		num_val=args.num_val
	)

	print(f'Loaded data partitions: ({len(data[0])}), ({len(data[1])})')
	
	opt_tasks = [
		'sgd',
		'sgd_momentum',
		'sgd_nesterov',
		'sgd_weight_decay',
		'sgd_lrd',
		'rmsprop',
		'adam',
		'adam_l2',
		'adamW',
		'adam_lrd',
		'Radam',
		'RadamW',
		'Radam_lrd',
		'nadam',
		'lookahead_sgd',
		'lookahead_adam',
		'gradnoise_adam',
		'graddropout_adam'
	]
	opt_losses, opt_val_losses, opt_labels = [], [], []

	def do_stuff(opt):
		print(f'\nTraining {opt} for {args.num_epochs} epochs...')
		net = CNN() if args.dataset == 'cifar' else MLP()
		_, kwargs = split_optim_dict(optim_dict[opt])
		optimizer = task_to_optimizer(opt)(
			params=net.parameters(),
			**kwargs
		)
		optimizer = wrap_optimizer(opt, optimizer)

		return fit(net, data, optimizer, num_epochs=args.num_epochs, lr_schedule=True)

	for opt in opt_tasks:
		if args.only_plot:
			print("1")
			losses = load_losses(dataset=args.dataset, filename=opt)
			val_losses = load_losses(dataset=args.dataset, filename=opt+'_val')
		else:
			print("2")
			losses, val_losses = do_stuff(opt)
			save_losses(losses, dataset=args.dataset, filename=opt)
			save_losses(val_losses, dataset=args.dataset, filename=opt+'_val')

		if losses is not None:
			opt_losses.append(losses)
			opt_val_losses.append(val_losses)
			opt_labels.append(split_optim_dict(optim_dict[opt])[0])

	if torch.cuda.is_available():
		print(3)
		assert len(opt_losses) == len(opt_val_losses)
		plot_losses(
			losses=opt_losses,
			val_losses=opt_val_losses,
			labels=opt_labels,
			num_epochs=args.num_epochs,
			title=args.dataset,
			plot_val=False,
			yscale_log=False,
			max_epochs=30
		)


usage: ipykernel_launcher.py [-h] [-num_epochs NUM_EPOCHS] [-dataset DATASET] [-num_train NUM_TRAIN]
                             [-num_val NUM_VAL] [-lr_schedule LR_SCHEDULE] [-only_plot ONLY_PLOT]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\frafa\AppData\Roaming\jupyter\runtime\kernel-7e691e1b-ce2b-4f47-93db-6969d27feaa7.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
