In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
from sympy import sieve

In [25]:
from progress_bar import log_progress

In [2]:
cities = pd.read_csv('cities.csv')
sample_submission = pd.read_csv('sample_submission.csv')

In [3]:
def is_prime(n):
    """Determines if a positive integer is prime."""

    if n > 2:
        i = 2
        while i ** 2 <= n:
            if n % i:
                i += 1
            else:
                return False
    elif n != 2:
        return False
    return True

In [4]:
cities['is_prime'] = cities.CityId.apply(is_prime)

In [5]:
cities.is_prime.sum()

17802

In [27]:
#fig = plt.figure(figsize=(10,10))
#plt.scatter(cities.X, cities.Y, c=cities['is_prime'], marker=".", alpha=.5);

In [8]:
import math

from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2

def euclid_distance(x1, y1, x2, y2):
  # Euclidean distance between points.
    dist = math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
    return dist

In [9]:
def create_distance_matrix(locations):
    size = len(locations)
    dist_matrix = {}
    
    for from_node in range(size):
        dist_matrix[from_node] = {}
        for to_node in range(size):
            x1 = locations[from_node][0]
            y1 = locations[from_node][1]
            x2 = locations[to_node][0]
            y2 = locations[to_node][1]
            dist_matrix[from_node][to_node] = euclid_distance(x1, y1, x2, y2)
    return dist_matrix

In [10]:
def create_distance_callback(dist_matrix):
  # Create the distance callback.

  def distance_callback(from_node, to_node):
    return int(dist_matrix[from_node][to_node])

  return distance_callback

In [135]:
def main():
    locations = create_data_array()
    dist_matrix = create_distance_matrix(locations)
    dist_callback = create_distance_callback(dist_matrix)
    tsp_size = len(locations)
    num_routes = 1
    depot = 0
    node_lst = []
    
    if tsp_size > 0:
        routing = pywrapcp.RoutingModel(tsp_size, num_routes, depot)
        search_parameters = pywrapcp.RoutingModel.DefaultSearchParameters()
        routing.SetArcCostEvaluatorOfAllVehicles(dist_callback)
        # Solve the problem.
        assignment = routing.SolveWithParameters(search_parameters)
        if assignment:
            # Solution cost.
            print("Total distance: " + str(assignment.ObjectiveValue()) + "\n")
            # Inspect solution.
            # Only one route here; otherwise iterate from 0 to routing.vehicles() - 1.
            route_number = 0
            node = routing.Start(route_number)
            start_node = node
            route = ''

            while not routing.IsEnd(node):
                route += str(node) + ' -> '
                node_lst.append(node)
                node = assignment.Value(routing.NextVar(node))
            route += '0'
            print("Route:\n\n" + route)
        else:
            print('No solution found.')
    else:
        print('Specify an instance greater than 0.')
        
    return(node_lst)
        

#if __name__ == '__main__':
#    main()

In [133]:
def create_data_array():
    locations = []
    for i in cities_subset.CityId:
        locations.append([cities_subset.X[i], cities_subset.Y[i]])
            
    return(locations)

In [136]:
for i in log_progress(range(1, 4)):
    cities_subset = cities.loc[cities['bins'] == i]
    main()

VBox(children=(HTML(value=''), IntProgress(value=0, max=3)))

Total distance: 89668

Route:

0 -> 154 -> 836 -> 9 -> 774 -> 439 -> 938 -> 715 -> 148 -> 343 -> 164 -> 757 -> 572 -> 936 -> 934 -> 307 -> 390 -> 578 -> 12 -> 82 -> 863 -> 985 -> 693 -> 138 -> 359 -> 14 -> 165 -> 682 -> 907 -> 213 -> 770 -> 357 -> 789 -> 555 -> 553 -> 166 -> 634 -> 412 -> 487 -> 372 -> 959 -> 423 -> 971 -> 101 -> 871 -> 670 -> 533 -> 514 -> 994 -> 921 -> 958 -> 204 -> 145 -> 211 -> 120 -> 163 -> 860 -> 119 -> 455 -> 778 -> 202 -> 845 -> 382 -> 276 -> 523 -> 497 -> 908 -> 846 -> 852 -> 706 -> 607 -> 554 -> 557 -> 597 -> 524 -> 949 -> 965 -> 335 -> 384 -> 378 -> 81 -> 110 -> 46 -> 88 -> 536 -> 341 -> 193 -> 23 -> 402 -> 477 -> 454 -> 222 -> 387 -> 892 -> 899 -> 978 -> 476 -> 5 -> 488 -> 485 -> 92 -> 362 -> 105 -> 525 -> 329 -> 325 -> 215 -> 888 -> 930 -> 326 -> 641 -> 68 -> 45 -> 945 -> 618 -> 319 -> 653 -> 121 -> 377 -> 619 -> 152 -> 50 -> 459 -> 995 -> 189 -> 74 -> 588 -> 545 -> 19 -> 194 -> 797 -> 796 -> 704 -> 414 -> 566 -> 479 -> 610 -> 330 -> 278 -> 647 -> 575 -> 1

Total distance: 90029

Route:

0 -> 971 -> 159 -> 24 -> 410 -> 12 -> 860 -> 78 -> 351 -> 382 -> 19 -> 939 -> 188 -> 631 -> 954 -> 395 -> 923 -> 201 -> 809 -> 489 -> 915 -> 294 -> 481 -> 467 -> 787 -> 894 -> 898 -> 131 -> 185 -> 61 -> 31 -> 334 -> 409 -> 865 -> 588 -> 667 -> 804 -> 789 -> 80 -> 726 -> 583 -> 453 -> 526 -> 50 -> 867 -> 427 -> 650 -> 22 -> 773 -> 44 -> 849 -> 365 -> 845 -> 401 -> 788 -> 56 -> 423 -> 311 -> 262 -> 888 -> 329 -> 29 -> 540 -> 496 -> 226 -> 581 -> 133 -> 736 -> 256 -> 863 -> 681 -> 324 -> 403 -> 515 -> 1 -> 462 -> 958 -> 737 -> 529 -> 569 -> 28 -> 344 -> 250 -> 65 -> 296 -> 940 -> 398 -> 851 -> 313 -> 335 -> 319 -> 905 -> 585 -> 107 -> 772 -> 69 -> 62 -> 837 -> 562 -> 192 -> 134 -> 181 -> 806 -> 263 -> 655 -> 734 -> 115 -> 394 -> 187 -> 598 -> 749 -> 196 -> 785 -> 972 -> 183 -> 719 -> 586 -> 284 -> 217 -> 446 -> 145 -> 548 -> 934 -> 209 -> 761 -> 893 -> 766 -> 404 -> 706 -> 237 -> 956 -> 730 -> 384 -> 87 -> 856 -> 283 -> 672 -> 435 -> 738 -> 156 -> 816 -> 899

In [90]:
cities['bins'] = pd.cut(cities['CityId'], list(range(0, len(cities), 1000)), 
       labels=list(range(0, len(list(range(0, len(cities), 1000)))-1)), include_lowest=True)

In [111]:
cities['bins'].fillna(0, inplace=True)
cities['bins'] = cities['bins'].astype(int)

In [113]:
cities

Unnamed: 0,CityId,X,Y,is_prime,bins
0,0,316.836739,2202.340707,False,0
1,1,4377.405972,336.602082,False,0
2,2,3454.158198,2820.053011,True,0
3,3,4688.099298,2935.898056,True,0
4,4,1010.696952,3236.750989,False,0
5,5,2474.230877,1435.514651,True,0
6,6,1029.277795,2721.800952,False,0
7,7,3408.887685,199.585793,True,0
8,8,1304.006125,2657.427246,False,0
9,9,4211.525725,2294.595208,False,0
