In [None]:
%matplotlib widget
import heapq
from itertools import product
from random import randint

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from python_toolbox import caching

In [None]:
def construct_stations(r):
    """Return the candidate stations for a trip from the north pole to the south pole.

    Stations are the integer points (x, y, z) with x**2 + y**2 + z**2 == r**2.
    Approach: z must be an integer in [-r, r], so for each integer height the
    problem reduces to finding integer points on the circle
    x**2 + y**2 == r**2 - z**2 (delegated to `_find_latticePoints`).

    Only points with x >= 0 and y >= 0 are generated (2 of the 8 octants).
    Folding a point into that quadrant via (|x|, |y|, z) never decreases the dot
    product between two points (|x1||x2| >= x1*x2), so it never lengthens a
    great-circle hop. Hence a minimum-risk route can be assumed to stay inside it.

    Args:
        r: sphere radius; must be an int.

    Returns:
        List of (x, y, z) tuples sorted by z descending. The first element is the
        north pole (0, 0, r) and the last is the south pole (0, 0, -r).
    """
    assert isinstance(r, int)
    station_list = [(0, 0, r)]  # north pole

    # Mirrored layers above/below the equator: z = r-1, ..., 1 and their negatives.
    for z in range(r - 1, 0, -1):
        for x, y in _find_latticePoints(r * r - z * z):
            station_list.append((x, y, z))
            station_list.append((x, y, -z))

    # Equator (z == 0). The mirrored loop above cannot produce it, because
    # +0 and -0 would duplicate every point. Skipped for r == 0, where the
    # "equator" is just the pole itself.
    if r > 0:
        station_list.extend((x, y, 0) for x, y in _find_latticePoints(r * r))

    station_list.append((0, 0, -r))  # south pole

    # Stable sort by height, highest first.
    station_list.sort(key=lambda station: station[2], reverse=True)
    return station_list


@caching.cache()
def _find_latticePoints(r2):
    """Return the integer points (x, y) with x >= 0, y >= 0 and x**2 + y**2 == r2.

    r2 is the squared radius of the circle. Only the first quadrant is searched.
    In that quadrant the circle starts at (0, sqrt(r2)) and y decreases
    monotonically as x grows, so the boundary can be traced with a staircase walk:
      * a point outside the circle steps down,
      * a point inside the circle steps right,
      * a point on the circle is recorded and then steps down.
    Points are returned in order of increasing x.

    Memoised by python_toolbox's `caching.cache()`. Presumably repeat calls
    return the same list object, so callers should not mutate it.
    """
    limit = r2**0.5
    found = []
    col, row = 0, int(limit)
    while row >= 0 and col <= limit:
        gap = col * col + row * row - r2
        if gap < 0:
            # inside the circle: move right
            col += 1
            continue
        if gap == 0:
            found.append((col, row))
        # on or outside the circle: move down
        row -= 1
    return found


def _basic(r):
    station_list = []
    station_list.append((0, 0, r))
    station_list.append((0, r, 0))
    station_list.append((0, 0, -r))
    return station_list

In [None]:
def plot(radius, station_list, visited_list):
    """Draw the sphere, every candidate station, and the chosen route in 3-D.

    Args:
        radius: sphere radius, used for the translucent wireframe.
        station_list: all (x, y, z) stations; drawn as faint blue dots.
        visited_list: the route's stations in order (either direction).
            Consecutive stations are joined by red surface arcs from `arc_plot`,
            and the stations themselves are drawn more opaquely.
    """
    fig = plt.figure(figsize=(9, 9))
    # add_subplot(projection="3d") replaces Axes3D(fig, auto_add_to_figure=False);
    # that keyword was deprecated in Matplotlib 3.4 and later removed.
    ax = fig.add_subplot(projection="3d")

    # Translucent wireframe of the sphere.
    polar = np.linspace(0, np.pi, 40)
    azimuth = np.linspace(0, 2 * np.pi, 40)
    x = radius * np.outer(np.sin(polar), np.sin(azimuth))
    y = radius * np.outer(np.sin(polar), np.cos(azimuth))
    z = radius * np.outer(np.cos(polar), np.ones_like(azimuth))
    ax.plot_wireframe(x, y, z, color="gray", alpha=0.1)

    # All candidate stations (scatter needs at least one point).
    if station_list:
        ax.scatter(*zip(*station_list), color="b", alpha=0.2)

    # The route: arcs between consecutive stations, then the stations themselves.
    if visited_list:
        arc_points = []
        for start, end in zip(visited_list, visited_list[1:]):
            arc_points += arc_plot(start, end)
        if arc_points:  # a single-station route has no arcs to draw
            ax.plot(*zip(*arc_points), color="r")
        ax.scatter(*zip(*visited_list), color="b", alpha=0.7)

    ax.set_title(f"Minimum-risk route on a sphere of radius {radius}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.set_box_aspect((1, 1, 1))

In [None]:
def arc_plot(c1, c2, N=10):
    """Given 2 Cartesian coordinates, return N points along the great-circle arc between them.

    The points come from spherical linear interpolation ("slerp") of the two
    directions, scaled to the radius of c1. The drawn arc is therefore the same
    geodesic whose length `arc_length` measures. Linearly interpolating the
    polar/azimuthal angles instead does not follow a great circle and can wrap
    the long way round in azimuth.

    Args:
        c1, c2: Cartesian (x, y, z) end points; neither may be the origin.
        N: number of points returned, including both end points.

    Returns:
        List of N (x, y, z) tuples. It starts at c1 and ends at the point of c2's
        direction on c1's sphere. For antipodal end points the great circle is not
        unique, so an arbitrary one is used.
    """
    p1 = np.asarray(c1, dtype=float)
    p2 = np.asarray(c2, dtype=float)
    radius = np.linalg.norm(p1)
    u1 = p1 / radius
    u2 = p2 / np.linalg.norm(p2)

    # Angle between the directions (atan2 form is well-conditioned at every angle).
    cos_omega = np.dot(u1, u2)
    omega = np.arctan2(np.linalg.norm(np.cross(u1, u2)), cos_omega)

    # Unit vector perpendicular to u1 within the plane of the great circle.
    perp = u2 - cos_omega * u1
    perp_norm = np.linalg.norm(perp)
    if perp_norm < 1e-12:
        # u1 and u2 are (anti)parallel: any direction perpendicular to u1 works.
        axis = np.zeros(3)
        axis[np.argmin(np.abs(u1))] = 1.0
        perp = axis - np.dot(axis, u1) * u1
        perp_norm = np.linalg.norm(perp)
    w = perp / perp_norm

    # Walk the circle cos(a) * u1 + sin(a) * w for a in [0, omega]; at a == omega it hits u2.
    angles = np.linspace(0.0, omega, N)
    points = radius * (np.outer(np.cos(angles), u1) + np.outer(np.sin(angles), w))
    return [tuple(point) for point in points]


def cartesianToSpherical(c):
    """Convert a Cartesian point c = (x, y, z) to spherical coordinates (r, theta, phi).

    r is the distance from the origin. theta is the polar angle from the +z axis,
    in [0, pi]. phi is the azimuth measured from +x towards +y, in (-pi, pi].
    On the z axis (x == y == 0), where the azimuth is undefined, phi is 0.
    At the origin theta is 0 as well.
    """
    x, y, z = c[0], c[1], c[2]
    r = np.sqrt(x**2 + y**2 + z**2)
    # The clip guards against |z / r| creeping past 1 through rounding (arccos -> nan).
    theta = np.arccos(np.clip(z / r, -1.0, 1.0)) if r > 0 else 0.0
    # arctan2 resolves the quadrant itself. The old arctan-based branches mapped
    # every point with x == 0 to phi = pi/2, even when y < 0.
    phi = np.arctan2(y, x)
    return (r, theta, phi)


def arc_length(c1, c2, radius):
    """Great-circle distance between the directions of c1 and c2 on a sphere of `radius`.

    Only the directions of c1 and c2 matter; their lengths need not equal `radius`.

    Uses the vector form  d_sigma = atan2(|c1 x c2|, c1 . c2)
    (https://en.wikipedia.org/wiki/Great-circle_distance#Formulas, vector version).
    This form is well-conditioned for every angle. The plain arccos(c1 . c2) form
    loses most significant digits for nearby points, which is exactly the regime
    of the many short hops summed by the risk search. It also returns nan when
    rounding pushes the normalised dot product past +-1.
    """
    a = np.asarray(c1, dtype=float)
    b = np.asarray(c2, dtype=float)
    # atan2 is scale-invariant, so no normalisation is needed.
    d_sigma = np.arctan2(np.linalg.norm(np.cross(a, b)), np.dot(a, b))
    return radius * d_sigma

In [None]:
def dijkstra_path(station_list, radius, delta_z_cap=8):
    """Find the minimum-risk route from station_list[0] to station_list[-1].

    The risk of one hop is (d / (pi * radius))**2, where d is the great-circle
    distance from `arc_length`. A route's risk is the sum over its hops.

    Neighbour heuristic (kept from the original implementation, but made
    independent of visiting order): a hop may stay at the same height or go down,
    and may span at most `delta_z_cap` distinct heights counting the current one.
    From a station you can reach every other station on its own height and on the
    next ``delta_z_cap - 1`` lower heights that contain stations. The result is
    exact only if an optimal route satisfies this restriction.

    Args:
        station_list: distinct (x, y, z) stations. The first is the start (north
            pole) and the last is the goal (south pole), e.g. as returned by
            `construct_stations`.
        radius: sphere radius, used to normalise hop distances.
        delta_z_cap: size of the height window described above (default 8).

    Returns:
        (path, risk). path lists the route's stations from the goal back to the
        start, i.e. in reverse travel order. risk is the total minimum risk, or
        np.inf (with path == [goal]) if the goal is unreachable.

    Raises:
        ValueError: if station_list is empty.
    """
    if not station_list:
        raise ValueError("station_list must contain at least one station")
    start, goal = station_list[0], station_list[-1]

    # Bucket stations by height, highest first, so each hop window is a contiguous slice.
    heights = sorted({station[2] for station in station_list}, reverse=True)
    layer_index = {height: k for k, height in enumerate(heights)}
    layers = [[] for _ in heights]
    for station in station_list:
        layers[layer_index[station[2]]].append(station)

    distance_map = dict.fromkeys(station_list, np.inf)
    distance_map[start] = 0
    previous_map = dict.fromkeys(station_list, None)
    visited = set()

    # Min-heap of (tentative distance, station). Stale entries are skipped when
    # popped. This replaces the O(n) scan for the closest unvisited station.
    frontier = [(0, start)]
    while frontier:
        current_distance, current = heapq.heappop(frontier)
        if current in visited:
            continue
        visited.add(current)
        if current == goal:
            break  # the goal's distance is final once it is popped

        first = layer_index[current[2]]
        for layer in layers[first : first + delta_z_cap]:
            for neighbor in layer:
                if neighbor in visited:  # also skips `current` itself
                    continue
                hop_risk = (arc_length(current, neighbor, radius) / (np.pi * radius)) ** 2
                candidate = current_distance + hop_risk
                if candidate < distance_map[neighbor]:
                    distance_map[neighbor] = candidate
                    previous_map[neighbor] = current
                    heapq.heappush(frontier, (candidate, neighbor))

    # Walk the predecessor chain back from the goal.
    shortest_path = []
    node = goal
    while node is not None:
        shortest_path.append(node)
        node = previous_map[node]

    return shortest_path, distance_map[goal]

In [None]:
# Solve and draw a single instance: r = 2**11 - 1 = 2047.
r = (1 << 11) - 1
station_list = construct_stations(r)
visited_list, total_risk = dijkstra_path(station_list, r)
print(total_risk)
plot(r, station_list, visited_list)

In [None]:
# Small-radius companion to the r = 2**11 - 1 run above. It finishes instantly and
# keeps the route legible in the 3-D figure; re-running the identical large case
# here would only double the cost and redraw the same picture.
r = 2**3 - 1
station_list = construct_stations(r)
visited_list, total_risk = dijkstra_path(station_list, r)
print(total_risk)
plot(r, station_list, visited_list)

In [None]:
# Total of the minimum risks M(r) for r = 2**n - 1, n = 1..15.
# The accumulator is deliberately not named `sum`: that would shadow the builtin
# for the rest of the kernel session.
risk_total = 0
for r in [2**n - 1 for n in range(1, 16)]:
    print(f"Calculating for r={r}...")
    station_list = construct_stations(r)
    visited_list, minimum_risk = dijkstra_path(station_list, r)
    print(f"minimum risk {minimum_risk}")
    risk_total += minimum_risk
print(f"sum {risk_total}")