In [32]:
import numpy as np
import math
import multiprocessing as mp
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable

from collections import namedtuple
from NetworkModel import *
from es import SimpleGA, CMAES, PEPG, OpenES, PEPGVariant 
from es import compute_ranks, compute_centered_ranks,compute_weight_decay
import copy 

In [33]:
def config():
    """Set up the run-wide settings and seed the random number generators.

    Publishes these module-level globals:
        args              -- namedtuple of run settings (batch sizes, epochs,
                             lr, cuda flag, seed, log interval)
        NPOPULATION       -- population size handed to the solvers
        weight_decay_coef -- coefficient passed to compute_weight_decay

    NPARAMS and model_shapes are declared global here but not assigned; they
    are filled in at module level from cal_nparams(model).
    """
    global NPARAMS, NPOPULATION, args, model_shapes, weight_decay_coef

    Args = namedtuple('Args', ['batch_size', 'test_batch_size', 'epochs', 'lr', 'cuda', 'seed', 'log_interval'])
    args = Args(batch_size=100, test_batch_size=1000, epochs=3, lr=0.001, cuda=False, seed=0, log_interval=10)

    # Seed both libraries from the single configured seed so that the value
    # stored in `args` is the one that is actually in effect.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    NPOPULATION = 101
    weight_decay_coef = 0.1
    

In [34]:
def dataFeed():
    """Build the MNIST data loaders and publish them as module-level globals.

    Sets:
        train_loader -- shuffled loader over the MNIST training split
                        (downloaded into 'MNIST_data' if missing), batched
                        with args.batch_size.
        valid_loader -- alias of train_loader. There is no held-out
                        validation split, so "validation" accuracy is
                        measured on the training data.
        test_loader  -- shuffled loader over the MNIST test split, batched
                        with args.test_batch_size.

    Requires config() to have been called first so that `args` exists.
    """
    global train_loader, valid_loader, test_loader

    # Extra DataLoader options are only supplied when CUDA is requested.
    kwargs = {'num_workers': 1, 'pin_memory': False} if args.cuda else {}

    # Identical preprocessing for both splits: convert to tensor, then
    # normalise with the constants used by the PyTorch MNIST example.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('MNIST_data', train=True, download=True, transform=mnist_transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    valid_loader = train_loader

    # Use the dedicated test batch size from the configuration; previously
    # args.test_batch_size was defined but never used.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('MNIST_data', train=False, transform=mnist_transform),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [35]:
def debugSolver(es, printLog=True):
    """Train the global ``model`` with an evolution-strategy solver.

    For every mini-batch of ``train_loader`` the solver proposes
    ``es.popsize`` candidate parameter vectors. Each candidate is loaded into
    ``model`` with update_model() and scored by the negative NLL loss on that
    batch. The scores are passed through compute_centered_ranks(), the value
    returned by compute_weight_decay() is added, and the result is reported
    back with ``es.tell()``.

    After each epoch the solver's current parameters are loaded into
    ``model`` and evaluated on ``valid_loader``.

    Relies on module-level names: model, model_shapes, args, train_loader,
    valid_loader, weight_decay_coef, and on F / update_model / evaluate,
    which are expected to come from ``from NetworkModel import *``.

    Args:
        es: solver exposing ask(), tell(), result(), current_param() and a
            ``popsize`` attribute.
        printLog: when True, print a progress line every 50 batches.

    Returns:
        (training_log, best_model): the list of per-epoch validation
        accuracies, and a deep copy of ``model`` at the epoch with the highest
        validation accuracy (ties go to the later epoch). ``best_model`` is
        None if no epoch was run.
    """
    best_valid_acc = 0
    best_model = None  # stays None only when the epoch loop does not execute
    training_log = []
    for epoch in range(1, 10 * args.epochs + 1):
        # Candidates are only scored by forward passes, so keep the network
        # in evaluation mode.
        model.eval()
        for batch_idx, (data, target) in enumerate(train_loader):
            solutions = es.ask()
            reward = np.zeros(es.popsize)
            for i in range(es.popsize):
                update_model(solutions[i], model, model_shapes)
                # No gradients are needed for ES; skip building the autograd graph.
                with torch.no_grad():
                    output = model(data)
                    loss = F.nll_loss(output, target)
                # .item() replaces the legacy loss.data[0], which fails on
                # 0-dim tensors in current PyTorch.
                reward[i] = -loss.item()
            best_raw_reward = reward.max()
            reward = compute_centered_ranks(reward)
            l2_decay = compute_weight_decay(weight_decay_coef, solutions)
            reward += l2_decay
            es.tell(reward)
            result = es.result()
            if (batch_idx % 50 == 0) and printLog:
                # result[1] is presumably the solver's best reward so far --
                # confirm against the es module.
                print(epoch, batch_idx, best_raw_reward,result[1])
        curr_solution = es.current_param()
        update_model(curr_solution, model, model_shapes)
        valid_acc = evaluate(model, valid_loader, print_mode=False)
        training_log.append(valid_acc)
        print('valid_acc', valid_acc * 100.)
        if valid_acc >= best_valid_acc:
            best_valid_acc = valid_acc
            best_model = copy.deepcopy(model)
            print('best valid_acc', best_valid_acc * 100.)
    return training_log, best_model

In [None]:
# Set seeds and global settings first, so they are in place before the
# network is constructed.
config()
# Network to be trained; Net is expected to come from the NetworkModel star import.
model=Net()
# NPARAMS is passed to the solvers as the number of model parameters;
# model_shapes is handed to update_model() (presumably the per-tensor shapes
# needed to unflatten a solution vector -- confirm in NetworkModel).
NPARAMS,model_shapes=cal_nparams(model)
# Create the global train_loader / valid_loader / test_loader.
dataFeed()

In [None]:
# PEPG solver over the network's parameter vector. debugSolver rank-transforms
# the rewards and adds the weight-decay term itself before calling es.tell();
# the solver-side equivalents are switched off here.
pepg = PEPG(
    NPARAMS,                  # number of model parameters
    sigma_init=0.01,          # initial standard deviation
    learning_rate=0.1,        # solver learning rate (TODO confirm in es.PEPG which quantity it scales)
    learning_rate_decay=1.0,  # 1.0 = learning rate is not annealed
    popsize=NPOPULATION,      # population size
    average_baseline=False,   # do not use the batch average as baseline
    weight_decay=0.00,        # solver-side weight decay coefficient (off)
    rank_fitness=False,       # solver-side rank transform of fitness (off)
    forget_best=False,        # keep the historical best solution
)

In [None]:
# print(pepg)

In [None]:
# PEPGVariant solver built with the same settings as the PEPG solver so the
# two can be compared. debugSolver rank-transforms the rewards and adds the
# weight-decay term itself; the solver-side equivalents are switched off here.
pepgv = PEPGVariant(
    NPARAMS,                  # number of model parameters
    sigma_init=0.01,          # initial standard deviation
    learning_rate=0.1,        # solver learning rate (TODO confirm in es.PEPGVariant which quantity it scales)
    learning_rate_decay=1.0,  # 1.0 = learning rate is not annealed
    popsize=NPOPULATION,      # population size
    average_baseline=False,   # do not use the batch average as baseline
    weight_decay=0.00,        # solver-side weight decay coefficient (off)
    rank_fitness=False,       # solver-side rank transform of fitness (off)
    forget_best=False,        # keep the historical best solution
)

In [None]:
# Run both solvers through the same training loop. Each run's best model is
# kept under its own name; previously the second call overwrote the first
# run's `best_model`. Both runs share the global `model` and data loaders, and
# the random generators are not re-seeded between them.
PEPG_history, PEPG_best_model = debugSolver(pepg, printLog=True)
PEPGV_history, PEPGV_best_model = debugSolver(pepgv, printLog=True)
# Backward-compatible alias: `best_model` still refers to the last run's result.
best_model = PEPGV_best_model