In [None]:
import numpy as np
from grblas import *

In [None]:
# Load the graph's adjacency matrix from Matrix Market format and store its values as int
A = io.mmread('geolocation.mtx')
A = A.dup(int)
A

In [None]:
# Check whether A equals its own transpose (i.e. the graph is undirected)
is_symmetric = A.isequal(A.T)
is_symmetric

In [None]:
# sample.labels
# Rows are [node id (1-based), latitude, longitude] in degrees.
# NOTE: the next cell (locations.labels) reassigns `label_data`, so under
# Restart & Run All this sample data is not the data that gets used.
label_data = [
    [1, 37.7449063493, -160.009432884],
    [2, 37.8668048274, -130.797325311],
    [3, 63.6431858915, -112.816156983],
    [4, 51.6431858915, -172.816156983],
    [5, 21.8691125061, -133.259106041],
    [6, 55.6431858915, -177.816156983],
    [7, 63.8652346572, -131.250634008],
    [8, 61.2043433677, -114.300341275],
    [13, 37.8691125062, -122.259106043],
    [14, 30.8691125062, -100.259106043],
    [15, 17.1211125062, -112.259106043],
    [17, -37.8691125062, 122.259106043],
    [18, 60.2043433677, -110.300341275],
    [21, 86.8804435732, -170.640705659],
    [22, 37.8804435732, -110.640705659],
    [25, 38.2043433677, -114.300341275],
    [28, 61.2043433677, -114.300341275],
    [30, 37.8691125062, -122.259106043],
    [31, 37.6431858915, -121.816156983],
    [32, 37.8652346572, -122.250634008],
    [33, 38.2043433677, -114.300341275],
    [34, 36.7582225593, -118.167916598],
    [35, 33.9774659389, -114.886512278],
    [36, 39.2598884729, -106.804662071],
    [37, 37.8804435732, -122.230147039],
    [39, 9.42761644853, -110.640705659],
]

In [None]:
# locations.labels
# Known node locations: rows are [node id (1-based), latitude, longitude] in
# degrees.  Node ids are converted to 0-based indices when `lat`/`lon` are built.
label_data = [
    [1, 37.7449063493, -122.009432884],
    [2, 37.8668048274, -122.257973253],
    [4, 37.869112506, -122.25910604],
    [6, 37.6431858915, -121.816156983],
    [11, 37.8652346572, -122.250634008],
    [19, 38.2043433677, -114.300341275],
    [21, 36.7582225593, -118.167916598],
    [22, 33.9774659389, -114.886512278],
    [30, 39.2598884729, -106.804662071],
    [31, 37.880443573, -122.230147039],
    [39, 9.4276164485, -110.640705659],
]

In [None]:
# Spatial-median settings:
#   max_iter -- cap on refinement iterations
#   eps      -- stop once every estimate moves less than this (degrees) in one step
max_iter, eps = 1000, 0.001
# Estimates whose deviation (km) exceeds this are discarded.  The sample data
# is probably pretty spread out, so we're using a very large value here
max_mad = 1500.0

In [None]:
def haversine_distance(many_lat, many_lon, single_lat, single_lon, *, radius=6371.0, to_radians=True):
    """Compute great-circle (haversine) distances between many_{lat,lon} and single_{lat,lon}.

    Parameters
    ----------
    many_lat, many_lon : Vector or Matrix
        Latitudes / longitudes of the "many" points.  When these are Matrices,
        every stored entry in row ``i`` is compared against element ``i`` of
        ``single_{lat,lon}``.
    single_lat, single_lon : Vector
        Latitudes / longitudes of the reference points.
    radius : float, keyword-only
        Sphere radius; distances come out in the same unit.  The default,
        6371.0, is the Earth's mean radius in kilometres.
    to_radians : bool, keyword-only
        If True (default), all inputs are in degrees and are converted to
        radians first.  Pass False when the inputs are already in radians.

    Returns
    -------
    Vector or Matrix
        A new object of the same kind as ``many_lat`` holding the distances,
        with entries only where all inputs have values.
    """
    if to_radians:
        # Materialize with .new(): in the Matrix branch below `single_*` are
        # handed to ss.diag, which may not accept an unevaluated expression.
        many_lat = op.numpy.radians(many_lat).new()
        many_lon = op.numpy.radians(many_lon).new()
        single_lat = op.numpy.radians(single_lat).new()
        single_lon = op.numpy.radians(single_lon).new()
    if utils.output_type(many_lat) is Vector:
        # Element-wise: pair many_*[i] with single_*[i]
        diff_lat = op.minus(single_lat & many_lat)
        diff_lon = op.minus(single_lon & many_lon)
        cos_terms = op.times(op.cos(single_lat) & op.cos(many_lat))
    else:
        # Row-wise broadcast: diag(single) @ many pairs single_*[i] with every
        # stored entry in row i of many_*.
        single_lat = ss.diag(single_lat)
        single_lon = ss.diag(single_lon)
        diff_lat = op.any_minus(single_lat @ many_lat)
        diff_lon = op.any_minus(single_lon @ many_lon)
        cos_terms = op.any_times(op.cos(single_lat) @ op.cos(many_lat))
    # Haversine: a = sin^2(dlat / 2) + cos(lat1) * cos(lat2) * sin^2(dlon / 2)
    a = op.plus(
        op.sin(0.5 * diff_lat)**2
        & op.times(cos_terms & op.sin(0.5 * diff_lon)**2)
    )
    # d = 2 * R * asin(sqrt(a))
    return (2 * radius * op.asin(op.sqrt(a))).new()

In [None]:
# Sanity check against the worked example at
# https://www.igismap.com/haversine-formula-calculate-geographic-distance-earth/
# which gives about 347.3 km between these two points.
# Nebraska (lat, lon)
v1, w1 = Vector.from_values([0], [41.507483]), Vector.from_values([0], [-99.436554])
# Kansas (lat, lon)
v2, w2 = Vector.from_values([0], [38.504048]), Vector.from_values([0], [-98.315949])

nebraska_to_kansas = haversine_distance(v1, w1, v2, w2)
nebraska_to_kansas[0].new().isclose(347.3, abs_tol=0.1)

In [None]:
# label_data rows are [node id (1-based), latitude, longitude]
label_ids, label_lats, label_lons = zip(*label_data)
indices = [label_id - 1 for label_id in label_ids]  # 0-based node indices
lat = Vector.from_values(indices, label_lats, size=A.nrows)
lon = Vector.from_values(indices, label_lons, size=A.nrows)
# Known coordinates must be valid latitudes / longitudes in degrees
assert (op.abs(lat) <= 90).reduce(op.land)
assert (op.abs(lon) <= 180).reduce(op.land)
lat

In [None]:
# Nodes with no stored latitude get the value 1: these are the nodes whose
# locations we need to try to compute.
unknown = Vector.new(int, size=lat.size)
unknown(mask=~lat.S) << 1
unknown

In [None]:
# Keep only the rows of A that belong to unknown nodes (`second` keeps A's values)
unknown_diag = ss.diag(unknown)
U = op.any_second(unknown_diag @ A).new()

## Outer loop
By default, run about three iterations of this outer loop; to start a new iteration, re-run the cells from this point on.

**Partition the unknown-location nodes by how many of their neighbors have known locations: 1, 2, or more than 2**

In [None]:
# For each unknown node's row, keep only neighbors with a known location and
# store that neighbor's latitude (resp. longitude) as the value.
Ulat, Ulon = (op.any_second(U @ ss.diag(coord)).new() for coord in (lat, lon))

In [None]:
# degrees[i]: how many of unknown node i's neighbors have a known location
neighbor_counts = Ulat.reduce_rowwise(agg.count)
degrees = neighbor_counts.new(dtype=int)
# Smallest such count (empty if no unknown node has a located neighbor)
min_degrees = degrees.reduce(op.min).new()

In [None]:
# Nodes with exactly one located neighbor (stored value 1).  Compare the
# Python value, like the cells below; an empty `min_degrees` has value None,
# which is not equal to 1, so we fall through to an empty vector.
if min_degrees.value == 1:
    # Could also be written as select(degrees, 'EQ_THUNK', 1)
    one_neighbor = (degrees == 1).new()
    # Keep only the True entries, cast to int
    one_neighbor = one_neighbor.dup(int, mask=one_neighbor.V)
else:
    one_neighbor = Vector.new(int, size=degrees.size)
one_neighbor

In [None]:
# Nodes with exactly two located neighbors (stored value 1).  `min_degrees`
# is empty (value None) when no unknown node has a located neighbor; guard
# against that so the comparison does not raise TypeError.
if min_degrees.value is not None and min_degrees.value <= 2:
    # Could also be written as select(degrees, 'EQ_THUNK', 2)
    two_neighbors = (degrees == 2).new()
    # Keep only the True entries, cast to int
    two_neighbors = two_neighbors.dup(int, mask=two_neighbors.V)
else:
    two_neighbors = Vector.new(int, size=degrees.size)
two_neighbors

In [None]:
# Nodes with more than two located neighbors (stored value 1).  Guard against
# an empty `min_degrees` (value None), where `> 2` would raise TypeError; the
# else-branch then produces an empty vector.
if min_degrees.value is not None and min_degrees.value > 2:
    # Unused.  We can use Ulat and Ulon directly
    many_neighbors = unknown
else:
    # Could also be written as select(degrees, 'GT_THUNK', 2)
    many_neighbors = (degrees > 2).new()
    # Keep only the True entries, cast to int
    many_neighbors = many_neighbors.dup(int, mask=many_neighbors.V)
many_neighbors

**Exactly one located neighbor: copy that neighbor's location**

In [None]:
# Deviation ("mad", km) of each location estimated in this iteration.  Nodes
# with a single located neighbor copy that neighbor's coordinates exactly, so
# their deviation is 0.  Force a float dtype (`one_neighbor` is int) because
# later cells accumulate float distances into `mad`.
mad = op.times(one_neighbor, 0.0).new(dtype=float)
if one_neighbor.nvals > 0:
    # Keep the diagonal matrix under its own name so `one_neighbor` stays a
    # Vector and this cell can be re-run safely.
    one_diag = ss.diag(one_neighbor)
    # Each selected row has exactly one entry: that neighbor's lat / lon
    lat(op.second) << op.any_second(one_diag @ Ulat).reduce_rowwise(op.any)
    lon(op.second) << op.any_second(one_diag @ Ulon).reduce_rowwise(op.any)

**Exactly two located neighbors: use the great-circle midpoint of the two**

In [None]:
# Place nodes that have exactly two located neighbors at the great-circle
# midpoint of those two neighbors.
if two_neighbors.nvals > 0:
    # Keep the diagonal matrix under its own name so `two_neighbors` stays a
    # Vector and this cell can be re-run safely.
    two_diag = ss.diag(two_neighbors)
    two_lat = op.any_second(two_diag @ Ulat).new()
    two_lon = op.any_second(two_diag @ Ulon).new()

    # The two neighbors' coordinates (degrees) for each node
    lat1 = two_lat.reduce_rowwise(agg.first).new()
    lat2 = two_lat.reduce_rowwise(agg.last).new()
    lon1 = two_lon.reduce_rowwise(agg.first).new()
    lon2 = two_lon.reduce_rowwise(agg.last).new()

    lat1 = op.numpy.radians(lat1).new()
    lat2 = op.numpy.radians(lat2).new()
    lon1 = op.numpy.radians(lon1).new()
    lon2 = op.numpy.radians(lon2).new()

    # Spherical midpoint (radians):
    #   Bx = cos(lat2) cos(dlon),  By = cos(lat2) sin(dlon)
    #   lat3 = atan2(sin(lat1) + sin(lat2), hypot(cos(lat1) + Bx, By))
    #   lon3 = lon1 + atan2(By, cos(lat1) + Bx)
    cos_lat2 = op.cos(lat2).new()
    diff_lon = op.minus(lon2 & lon1).new()
    bx = op.times(cos_lat2 & op.cos(diff_lon)).new()
    by = op.times(cos_lat2 & op.sin(diff_lon)).new()
    cos_lat1_bx = op.plus(op.cos(lat1) & bx).new()
    lat3 = op.atan2(
        op.plus(op.sin(lat1) & op.sin(lat2))
        & op.hypot(cos_lat1_bx & by)
    ).new()
    lon3 = op.plus(lon1 & op.atan2(by & cos_lat1_bx)).new()

    # lat3 is atan2(y, x) with x = hypot(...) >= 0, so it already lies in
    # [-pi/2, pi/2].  lon3 = lon1 + atan2(...) can land anywhere in
    # [-2*pi, 2*pi] (e.g. when the two neighbors straddle the antimeridian),
    # so wrap it back into [-pi, pi).
    lon3_idx, lon3_vals = lon3.to_values()
    lon3 = Vector.from_values(
        lon3_idx, (lon3_vals + np.pi) % (2 * np.pi) - np.pi, size=lon3.size
    )

    lat(op.second) << op.numpy.degrees(lat3)
    lon(op.second) << op.numpy.degrees(lon3)

    assert (op.abs(lat3) <= np.pi / 2).reduce(op.land)
    assert (op.abs(lon3) <= np.pi).reduce(op.land)

    # Deviation: distance from the midpoint to either neighbor
    mad(op.any) << haversine_distance(lat1, lon1, lat3, lon3, to_radians=False)

In [None]:
if two_neighbors.nvals > 0:
    # Sanity check: the midpoint is equidistant from both neighbors, and the
    # two halves add up to the full neighbor-to-neighbor distance.
    d1, d2, d12 = (
        haversine_distance(lat1, lon1, lat3, lon3, to_radians=False),
        haversine_distance(lat2, lon2, lat3, lon3, to_radians=False),
        haversine_distance(lat1, lon1, lat2, lon2, to_radians=False),
    )
    assert d1.isclose(d2)
    assert d12.isclose(d1 + d2)

**More than two located neighbors: compute the spatial median of their locations**

In [None]:
# No node has more than two located neighbors, so the spatial-median step has nothing to do
if not many_neighbors.nvals:
    print('STOP!  No need to continue.  Start new iteration above.')

In [None]:
# Select the rows (nodes) that need a spatial-median estimate.
if one_neighbor.nvals == 0 and two_neighbors.nvals == 0:
    # Every row with a located neighbor has more than two of them, so
    # Ulat / Ulon already hold exactly the rows we need.
    many_lat = Ulat
    many_lon = Ulon
else:
    # Keep the diagonal matrix under its own name so `many_neighbors` stays a
    # Vector and this cell can be re-run safely.
    many_diag = ss.diag(many_neighbors)
    many_lat = op.any_second(many_diag @ Ulat).new()
    many_lon = op.any_second(many_diag @ Ulon).new()

In [None]:
# Starting point for the spatial median: the plain mean of each row's neighbor coordinates
cur_lat, cur_lon = (coords.reduce_rowwise(agg.mean).new() for coords in (many_lat, many_lon))

In [None]:
# Refine cur_{lat,lon} toward the spatial median of each row's located
# neighbors.  Each step re-weights neighbors by 1 / distance (Weiszfeld-style),
# with a correction for estimates that coincide with a neighbor.
i = 0
while True:
    # Great-circle distances from the current estimate to each neighbor
    D = haversine_distance(many_lat, many_lon, cur_lat, cur_lon)
    Dinv = op.minv(D).new(mask=D.V)  # drop 0s
    Dinvs = Dinv.reduce_rowwise(op.plus).new()
    # Row-normalized weights: W[i, j] = Dinv[i, j] / Dinvs[i]
    W = op.any_rdiv(ss.diag(Dinvs) @ Dinv).new()

    # Weighted mean of the neighbors that do not coincide with the estimate
    Tlat = op.times(many_lat & W).reduce_rowwise(op.plus).new()
    Tlon = op.times(many_lon & W).reduce_rowwise(op.plus).new()

    Dcounts = D.reduce_rowwise(agg.count[int]).new()
    Dinv_counts = Dinv.reduce_rowwise(agg.count[int]).new()

    # Number of neighbors at distance 0 from the current estimate
    num_zeros = op.minus(Dcounts & Dinv_counts).new()
    # Other implementations partition the following calculations into three groups:
    # if num_zeros == 0: cur_lat, cur_lon = Tlat, Tlon
    # elif num_zeros == Dcounts: break
    # else: r = ... ; alpha = ... ; next_lat = ... ;
    #
    # Let's simplify the calculation at the cost of doing a little more work

    Rlat = op.times(op.minus(Tlat & cur_lat) & Dinvs).new()
    Rlon = op.times(op.minus(Tlon & cur_lon) & Dinvs).new()
    # r = op.sqrt(op.plus(Rlat**2 & Rlon**2)).new()
    r = op.hypot(Rlat & Rlon).new()

    # rinv = num_zeros / r, set to 0 wherever r == 0.  That division gives inf
    # when num_zeros > 0 and NaN (0 / 0) when num_zeros == 0.  Because
    # Dinvs > 0, r == 0 implies Tlat == cur_lat and Tlon == cur_lon, so
    # rinv = 0 (next = T) is exact in both cases.
    rinv = op.truediv(num_zeros & r).new()
    rinv(op.isinf(rinv).V) << 0.0
    rinv(op.isnan(rinv).V) << 0.0
    alpha = op.max(0.0, op.minus(1, rinv)).new()
    beta = op.min(1.0, rinv).new()
    next_lat = op.plus(
        op.times(alpha & Tlat)
        & op.times(beta & cur_lat)
    ).new()
    next_lon = op.plus(
        op.times(alpha & Tlon)
        & op.times(beta & cur_lon)
    ).new()

    # Rows with no usable weights (every neighbor at distance 0) produce no
    # next value; keep their current estimate.
    if next_lat.nvals != cur_lat.nvals:
        next_lat(op.first) << cur_lat
        next_lon(op.first) << cur_lon

    if i >= max_iter:
        cur_lat = next_lat
        cur_lon = next_lon
        break

    diff_lat = op.minus(cur_lat & next_lat).new()
    diff_lon = op.minus(cur_lon & next_lon).new()
    converged = (op.hypot(diff_lat & diff_lon) < eps).reduce(op.land).new()
    # An empty result means there are no rows to refine, which counts as
    # converged (otherwise the loop would spin until max_iter).
    if converged.value is None or converged.value:
        # Once a node converges (either here or where num_zeros == Dcounts), we could
        # remove it from future iterations.  Let's leave this as a future optimization.
        cur_lat = next_lat
        cur_lon = next_lon
        break

    cur_lat = next_lat
    cur_lon = next_lon
    i += 1

In [None]:
# Spatial-median estimates must be valid latitudes / longitudes in degrees
assert all(
    (op.abs(coords) <= limit).reduce(op.land)
    for coords, limit in ((cur_lat, 90), (cur_lon, 180))
)

In [None]:
# Write the spatial-median estimates into lat / lon (`second`: new values win on overlap)
for target, estimate in ((lat, cur_lat), (lon, cur_lon)):
    target(op.second) << estimate

In [None]:
# Median absolute deviation for the spatial-median estimates: the median
# great-circle distance (km) from each node's final estimate to its located
# neighbors.  Distances are recomputed here because the loop's last `D` was
# measured at the previous iterate, not at the final cur_{lat,lon}.
D_final = haversine_distance(many_lat, many_lon, cur_lat, cur_lon)
if D_final.nvals > 0:
    mad_rows, _mad_cols, mad_dists = D_final.to_values()
    order = np.argsort(mad_rows, kind="stable")
    mad_rows, mad_dists = mad_rows[order], mad_dists[order]
    # Group the row-sorted distances by row and take each group's median
    mad_nodes, group_starts = np.unique(mad_rows, return_index=True)
    mad_medians = np.array(
        [np.median(group) for group in np.split(mad_dists, group_starts[1:])],
        dtype=float,
    )
    mad(op.any) << Vector.from_values(mad_nodes, mad_medians, size=mad.size)

In [None]:
# Discard estimates whose deviation exceeds max_mad (entries outside `mad` are kept)
mask = ~(mad > max_mad).V
lat, lon = lat.dup(mask=mask), lon.dup(mask=mask)

In [None]:
# Every node now has a latitude, so further outer iterations are unnecessary
if lat.size - lat.nvals == 0:
    print("All done!  We can stop early")

In [None]:
# Display the latitudes known so far (nodes without an entry are still unlocated)
lat

In [None]:
# Display the per-node deviation (km) recorded for locations estimated in this iteration
mad