In [11]:
from problem_solver import problem_solver
import torch
from utility import path_cost
import numpy as np
from typing import Union, Callable

graph = torch.load('generated_20_eval/generated_20_eval_0.mio')

In [96]:
def perturb(tour: np.array, n_permutations: int = 2) -> np.array:
    for i in range(n_permutations):
        x, y = None, None
        while (x is None or y is None) or (x == y):
            x, y = np.random.randint(1, tour.shape[0]-1), np.random.randint(1, tour.shape[0]-1)
        _ = tour[x]
        tour[x] = tour[y]
        tour[y] = _
    return tour

def hillclimbing(objective: Callable, graph, start_pt: np.array, n_iterations: int) -> Union[np.array, float]:
    best = start_pt
    best_eval = objective(best, graph.weights)
    for i in range(n_iterations):
        start_pt = None
        while start_pt is None:
            start_pt = perturb(best, 7)
            proposed_tour, proposed_eval = start_pt, objective(start_pt, graph.weights)
        if proposed_eval < best_eval:
            best, best_eval = proposed_tour, proposed_eval
    return best, best_eval

def iterated_local_search(objective: Callable, graph, n_restarts: int, n_iterations: int, start_pt: np.array) -> Union[np.array, float]:
    best = start_pt
    best_eval = objective(best, graph.weights)
    for i in range(n_restarts):
        start_pt = None
        while start_pt is None:
            start_pt = perturb(best, 15)
        proposed_tour, proposed_eval = hillclimbing(objective, graph, start_pt, n_iterations)
        if proposed_eval < best_eval:
            best, best_eval = proposed_tour, proposed_eval
    return best, best_eval

In [100]:
start_pt = np.arange(21) + 1
start_pt[-1] = start_pt[0]
eval_pt = path_cost(start_pt, graph.weights)
n_restarts = 100
n_iterations = 500

print(iterated_local_search(path_cost, graph, n_restarts, n_iterations, start_pt))
print('start_pt:', start_pt, eval_pt)
print('optimal_pt:', graph.sub_opt, graph.sub_opt_cost)

(array([ 1,  4,  9, 13, 16,  7, 14, 11, 20, 10,  8, 18,  2, 15, 17, 19, 12,
        6,  3,  5,  1]), 6.842883885082985)
start_pt: [ 1  4  9 13 16  7 14 11 20 10  8 18  2 15 17 19 12  6  3  5  1] 11.729307031820374
optimal_pt: [ 1 11  7 20 17 12 15  9 16  8  6  5 13 19 10  2 18  3 14  4  1] 4.187348527395455
