In [1]:
import datetime
import logging
import os

import numpy as np

In [4]:
class NelderMead:
    """Derivative-free minimiser using the Nelder-Mead simplex method.

    A simplex is stored as a 2-D array whose rows are the vertices. After
    ordering, row 0 is the vertex with the smallest objective value and the
    last row is the one with the largest.
    """

    def __init__(self):

        # Configuration
        self.alpha = 1.   # reflection coefficient
        self.gamma = 2.   # expansion coefficient
        self.rho = .5     # contraction coefficient
        self.sigma = .5   # shrink coefficient

    def order(self, x):
        """Return a new array with the rows of ``x`` sorted by ascending
        objective value.

        Uses ``self.obj_func``, which is only set by :meth:`run`.
        """
        return np.array(sorted(list(x), key=lambda t: self.obj_func(t)))

    def centroid(self, x):
        """Return the mean of the rows of ``x``."""
        return x.mean(axis=0)

    def reflection(self, x_np, x_o):
        """Reflect the point ``x_np`` through the centroid ``x_o``."""
        return x_o + self.alpha*(x_o - x_np)

    def expansion(self, x_r, x_o):
        """Push the reflected point ``x_r`` further away from ``x_o``."""
        return x_o + self.gamma*(x_r - x_o)

    def contraction(self, x_np, x_o):
        """Pull the point ``x_np`` towards the centroid ``x_o``."""
        return x_o + self.rho*(x_np - x_o)

    def shrink(self, x):
        """Move every row except row 0 towards row 0.

        The array is modified in place and also returned.
        """
        x[1:] = x[0] + self.sigma*(x[1:] - x[0])
        return x

    def run(self, obj_func, x0, max_iter=1000000, tol=1e-20, log_dir="./log"):
        """Minimise ``obj_func`` starting from the simplex ``x0``.

        Parameters
        ----------
        obj_func : callable
            Maps one vertex (a 1-D array) to a scalar value to be minimised.
        x0 : array-like, 2-D
            Initial simplex, one vertex per row. At least two rows are
            required. The input is copied and never modified.
        max_iter : int
            Upper bound on the number of iterations.
        tol : float
            Iteration stops once the standard deviation of the objective
            values at the vertices drops below this threshold.
        log_dir : str or None
            Directory that receives a timestamped ``nelder_mead_*.log`` file;
            it is created when missing. ``None`` disables the file log.

        Returns
        -------
        numpy.ndarray
            The final simplex as floats, sorted so that row 0 is the best
            vertex found.

        Raises
        ------
        ValueError
            If ``x0`` is not 2-D or has fewer than two rows.
        """
        # Copy as float: an integer array would silently truncate every
        # reflected/contracted point written back into the simplex.
        x = np.array(x0, dtype=float)
        if x.ndim != 2 or x.shape[0] < 2:
            raise ValueError(
                "x0 must be a 2-D array with at least two rows (vertices)")

        self.obj_func = obj_func  # Objective Function (also used by order())

        # Logging
        logger = logging.getLogger("nelder-mead")
        logger.setLevel(logging.INFO)
        file_handler = None
        if log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
            stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            file_handler = logging.FileHandler(
                os.path.join(log_dir, "nelder_mead_{}.log".format(stamp)))
            logger.addHandler(file_handler)

        try:
            return self._minimise(x, max_iter, tol, logger)
        finally:
            # The logger is a process-wide singleton: detach and close the
            # handler so repeated runs neither leak file descriptors nor
            # duplicate messages into earlier log files.
            if file_handler is not None:
                logger.removeHandler(file_handler)
                file_handler.close()

    def _minimise(self, x, max_iter, tol, logger):
        """Run the simplex iterations on the float array ``x`` (owned by the
        caller, modified freely) and return the final simplex best-first."""
        obj_func = self.obj_func
        # Objective value of every vertex, kept in sync with ``x`` so each
        # point is evaluated only once.
        f = np.array([obj_func(p) for p in x], dtype=float)

        for i in range(max_iter):

            #1. Order (stable, so ties keep their current relative order)
            rank = np.argsort(f, kind="stable")
            x, f = x[rank], f[rank]
            #2. Centroid of all vertices except the worst
            x_o = self.centroid(x[:-1])
            #3. Reflection
            x_r = self.reflection(x[-1], x_o)
            f_r = obj_func(x_r)
            if f[0] <= f_r < f[-2]:
                x[-1], f[-1] = x_r, f_r
                step = "Reflection"
            elif f_r < f[0]:
                #4. Expansion
                x_e = self.expansion(x_r, x_o)
                f_e = obj_func(x_e)
                if f_e < f_r:
                    x[-1], f[-1] = x_e, f_e
                    step = "Expansion"
                else:
                    x[-1], f[-1] = x_r, f_r
                    step = "Refelection after Expansion"
            else:
                #5. Contraction
                x_c = self.contraction(x[-1], x_o)
                f_c = obj_func(x_c)
                if f_c < f[-1]:
                    x[-1], f[-1] = x_c, f_c
                    step = "Contraction"
                else:
                    #6. Shrink: every vertex but the best one moves, so
                    # their cached objective values must be refreshed.
                    x = self.shrink(x)
                    f[1:] = [obj_func(p) for p in x[1:]]
                    step = "Shrink"

            # x[0] is the best vertex of the simplex as ordered at the start
            # of this iteration.
            logger.info("i: {}, x: {}, f(x): {}, {}".format(i, x[0], f[0], step))

            # Termination Criterion
            criterion = f.std()
            if criterion < tol:
                logger.info("Standard deviation: {}".format(criterion))
                break

        # The last update may have put a new best point in the last row;
        # re-order so callers can rely on row 0 being the best vertex.
        rank = np.argsort(f, kind="stable")
        return x[rank]

In [5]:
if __name__ == "__main__":

    # Fixed seed so the random starting simplex is the same on every run.
    np.random.seed(20190605)
    x0 = np.random.random([3, 2])  # 3 rows x 2 columns, one vertex per row

    def rosenbrock(x):
        """Rosenbrock function: sum over consecutive coordinate pairs."""
        head, tail = x[:-1], x[1:]
        return np.sum(100*(tail - head**2)**2 + (1 - head)**2)

    nm = NelderMead()
    x = nm.run(rosenbrock, x0)