In [None]:
# OR-Tools, PyTorch 설치
!pip install torch
!pip install ortools

In [None]:
import torch
from model_search import TSP_net
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from utils import compute_tour_length
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [None]:
N = 50
SEED = 0  # fixed so that Restart & Run All regenerates the same instance; change to draw another one

# Generate a random travelling-salesman instance with N cities
# (each row of `data` is one city's 2-D coordinate).
torch.manual_seed(SEED)
data = torch.rand(N, 2).to(device)

# Convert once to NumPy for printing and plotting.
points = data.cpu().numpy()
print(points)
plt.scatter(points[:, 0], points[:, 1])
plt.show()

In [None]:
# Build the Transformer model. No checkpoint has been loaded at this point.
# NOTE(review): the positional arguments follow TSP_net's signature in model_search,
# which is not visible here — confirm their meaning there before changing any value.
model = TSP_net('linear', None, None, 2, 128, 512, 6, 2, 8, 1000).to(device)

# Route search (Transformer + greedy algorithm).
# data[None] adds a leading batch dimension. This is inference only, so autograd
# tracking is switched off to avoid building a computation graph.
with torch.no_grad():
    tour, _, _, _ = model(data[None], 1, greedy=True, beamsearch=False)

# Print the visiting order of the first (only) batch element, ten cities per line.
route = tour[0].tolist()
print('=<탐색 결과 (경로)>=')
for position, city in enumerate(route[:-1], start=1):
    print(f'{city} -> ', end='')
    if position % 10 == 0:
        print()
# The last city is printed without a trailing arrow; indexing the list directly
# (instead of reusing the loop variable) also works for a single-city tour.
print(f'{route[-1]}')
print()

print('=<경로 길이>=')
tour_length = compute_tour_length(data[None], tour)
print(tour_length.item())

In [None]:
# Plot the route: every city as a dot, the visiting order as a black polyline.
points = data.cpu()
plt.scatter(points[:, 0], points[:, 1])
sorted_data = points[tour[0].cpu()]
plt.plot(sorted_data[:, 0], sorted_data[:, 1], color='black')
# Close the tour by joining the last visited city back to the first one.
plt.plot((sorted_data[-1, 0], sorted_data[0, 0]), (sorted_data[-1, 1], sorted_data[0, 1]), color='black')
plt.show()

In [None]:
# Load the pre-trained Transformer weights.
# map_location=device remaps the stored tensors onto the device selected earlier,
# so the checkpoint also loads on machines without CUDA.
checkpoint = torch.load('checkpoint/transformer_tsp50_demo.pt', map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)

In [None]:
# Route search (Transformer + greedy algorithm).
# data[None] adds a leading batch dimension. This is inference only, so autograd
# tracking is switched off to avoid building a computation graph.
# NOTE(review): model.eval() is not called in the visible cells; if TSP_net contains
# dropout or batch-norm layers it should be switched to eval mode first — confirm
# against model_search.
with torch.no_grad():
    tour, _, _, _ = model(data[None], 1, greedy=True, beamsearch=False)

# Print the visiting order of the first (only) batch element, ten cities per line.
route = tour[0].tolist()
print('=<탐색 결과 (경로)>=')
for position, city in enumerate(route[:-1], start=1):
    print(f'{city} -> ', end='')
    if position % 10 == 0:
        print()
# The last city is printed without a trailing arrow; indexing the list directly
# (instead of reusing the loop variable) also works for a single-city tour.
print(f'{route[-1]}')
print()

print('=<경로 길이>=')
tour_length = compute_tour_length(data[None], tour)
print(tour_length.item())

In [None]:
# Plot the route: every city as a dot, the visiting order as a black polyline.
cities = data.cpu()
visit_order = tour[0].cpu()
plt.scatter(cities[:, 0], cities[:, 1])
sorted_data = cities[visit_order]
route_x, route_y = sorted_data[:, 0], sorted_data[:, 1]
plt.plot(route_x, route_y, color='black')
# Closing segment from the last visited city back to the first one.
plt.plot((route_x[-1], route_x[0]), (route_y[-1], route_y[0]), color='black')
plt.show()

In [None]:
from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2
# Scale the coordinates by 100 and truncate them to integers for the OR-Tools model.
# NOTE(review): this snaps every city to an integer grid, so OR-Tools optimises a
# coarser instance than the one the Transformer sees — confirm the precision is intended.
ortools_data = torch.mul(data, 100).to(torch.long)

def create_data_model():
    """Assemble the problem description used to set up the OR-Tools routing model.

    Reads the module-level ``ortools_data`` coordinates and returns a dict with:
      - "distance_matrix": pairwise Euclidean distances between all rows,
        truncated to integers and converted to a NumPy array,
      - "num_vehicles": 1,
      - "depot": 0.
    """
    coords = ortools_data.float()
    pairwise = torch.cdist(coords, coords)
    return {
        "distance_matrix": pairwise.long().cpu().numpy(),
        "num_vehicles": 1,
        "depot": 0,
    }

# Build the problem data, then the index manager and the routing model on top of it.
data_model = create_data_model()
num_nodes = len(data_model["distance_matrix"])
manager = pywrapcp.RoutingIndexManager(num_nodes, data_model["num_vehicles"], data_model["depot"])
routing = pywrapcp.RoutingModel(manager)

def distance_callback(from_index, to_index):
    """Return the distance-matrix entry for a pair of routing indices.

    Both arguments are routing variable indices; they are translated to node
    numbers through the module-level ``manager`` before the lookup in
    ``data_model["distance_matrix"]``.
    """
    # The matrix is keyed by node number, not by routing index.
    origin, destination = (manager.IndexToNode(idx) for idx in (from_index, to_index))
    distances = data_model["distance_matrix"]
    return distances[origin][destination]

# Register the distance lookup with the routing model and use it as the arc cost
# for all vehicles.
transit_callback_index = routing.RegisterTransitCallback(distance_callback)
routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

# NOTE(review): `time` is not used anywhere in the visible cells.
import time

# Start from the default search parameters and select PATH_CHEAPEST_ARC as the
# first-solution strategy.
search_parameters = pywrapcp.DefaultRoutingSearchParameters()
first_solution = routing_enums_pb2.FirstSolutionStrategy
search_parameters.first_solution_strategy = first_solution.PATH_CHEAPEST_ARC
def print_solution(manager, routing, solution):
    """Extract the route of vehicle 0 from an OR-Tools solution as a tensor.

    Despite its name this function prints nothing; it only walks the route.

    Args:
        manager: index manager used to translate routing indices to node numbers.
        routing: routing model the solution belongs to.
        solution: assignment returned by the solver.

    Returns:
        A LongTensor of shape (1, number_of_visited_nodes) on the module-level
        ``device``, listing node numbers in visiting order. The route starts at
        the vehicle's start node; the final return to the end node is not included.
    """
    index = routing.Start(0)
    visited_nodes = []
    while not routing.IsEnd(index):
        # Store the node number rather than the routing index: the two only
        # coincide for particular depot layouts, and the tour is later used to
        # index the city coordinates.
        visited_nodes.append(manager.IndexToNode(index))
        index = solution.Value(routing.NextVar(index))
    return torch.LongTensor(visited_nodes).to(device)[None]

# Run the solver. Stop here if nothing was found: otherwise `tour` would silently
# keep the value from an earlier cell and the output below would be mislabelled.
solution = routing.SolveWithParameters(search_parameters)
if not solution:
    raise RuntimeError('OR-Tools did not find a solution for this instance.')
tour = print_solution(manager, routing, solution)

# Print the visiting order of the first (only) batch element, ten cities per line.
route = tour[0].tolist()
print('=<탐색 결과 (경로)>=')
for position, city in enumerate(route[:-1], start=1):
    print(f'{city} -> ', end='')
    if position % 10 == 0:
        print()
# The last city is printed without a trailing arrow; indexing the list directly
# (instead of reusing the loop variable) also works for a single-city tour.
print(f'{route[-1]}')
print()

print('=<경로 길이>=')
tour_length = compute_tour_length(data[None], tour)
print(tour_length.item())

# Plot the route: every city as a dot, the visiting order as a black polyline.
points = data.cpu()
plt.scatter(points[:, 0], points[:, 1])
# `tour` lives on `device`; move the indices to the CPU before indexing the CPU
# copy of the coordinates (mixing devices raises an error when running on CUDA).
sorted_data = points[tour[0].cpu()]
plt.plot(sorted_data[:, 0], sorted_data[:, 1], color='black')
# Closing segment from the last visited city back to the first one.
plt.plot((sorted_data[-1, 0], sorted_data[0, 0]), (sorted_data[-1, 1], sorted_data[0, 1]), color='black')
plt.show()