In [1]:
import polars as pl
from scipy.interpolate import Rbf
import numpy as np
from PIL import Image, ImageDraw
from tps import ThinPlateSpline


In [2]:
# Station table: GTFS id/name, geographic position, and the station's pixel
# position on the diagram. "Image coords" holds an "x,y" string that is split
# into two float columns.
coords_df = pl.read_csv("MTA_Subway_Stations_filled.csv").select(
    id=pl.col("GTFS Stop ID"),
    name=pl.col("Stop Name"),
    lat=pl.col("GTFS Latitude"),
    long=pl.col("GTFS Longitude"),
    x=pl.col("Image coords").str.split(",").list.get(0).cast(pl.Float64),
    y=pl.col("Image coords").str.split(",").list.get(1).cast(pl.Float64),
)
coords_df

id,name,lat,long,x,y
str,str,f64,f64,f64,f64
"""R31""","""Atlantic Av-Barclays Ctr""",40.683666,-73.97881,3225.0,6172.0
"""R32""","""Union St""",40.677316,-73.98311,3225.0,6386.0
"""R33""","""4 Av-9 St""",40.670847,-73.988302,3221.0,6568.0
"""R34""","""Prospect Av""",40.665414,-73.992872,3224.0,6659.0
"""R35""","""25 St""",40.660397,-73.998091,3219.0,6755.0
…,…,…,…,…,…
"""S15""","""Prince's Bay""",40.525507,-74.200064,967.0,7744.0
"""S14""","""Pleasant Plains""",40.52241,-74.217847,907.0,7814.0
"""S13""","""Richmond Valley""",40.519631,-74.229141,844.0,7879.0
"""S11""","""Arthur Kill""",40.516578,-74.242096,784.0,7939.0


In [3]:
# Geographic bounding box of all stations; used as the extent of the grid lines.
lat_min, lat_max = coords_df["lat"].min(), coords_df["lat"].max()
long_min, long_max = coords_df["long"].min(), coords_df["long"].max()

In [4]:
# Fit a thin-plate spline that maps each station's geographic (long, lat)
# position onto its (x, y) pixel position taken from the "Image coords" column.
X_train = coords_df.select("long", "lat").to_numpy()  # source points: geographic
Y_train = coords_df.select("x", "y").to_numpy()  # target points: diagram pixels

tps = ThinPlateSpline(
    alpha=1e-5,  # small regularization; per the tps docs, alpha=0 is exact interpolation
    enforce_tps_kernel=True,
)
tps.fit(X_train, Y_train)

<tps.thin_plate_spline.ThinPlateSpline at 0x115485be0>

In [5]:
# Load the base diagram; `draw` paints directly onto `img`.
img = Image.open("Subway diagram.png")
draw = ImageDraw.Draw(im=img)

In [6]:
# Overlay a warped lon/lat grid on the diagram. Each grid line is sampled
# densely in geographic space, pushed through the fitted TPS, and drawn as a
# polyline in pixel space. Blue lines hold longitude constant; purple lines
# hold latitude constant.
num_lines = 150  # grid lines drawn in each direction
samples_per_line = 1000  # vertices per polyline; more = smoother curves
line_width = 5  # stroke width in pixels


def draw_warped_polyline(drawer, transform, lons, lats, fill, width=5):
    """Map (lon, lat) samples to pixels with ``transform`` and draw one polyline.

    Args:
        drawer: a PIL ``ImageDraw`` context to draw on.
        transform: callable taking an (n, 2) array of (lon, lat) rows and
            returning the corresponding (n, 2) pixel (x, y) rows.
        lons, lats: equal-length 1-D arrays of geographic coordinates.
        fill: line colour passed to ``drawer.line``.
        width: stroke width in pixels.
    """
    geo = np.column_stack((lons, lats))
    # One batched call per line instead of one transform() call per vertex.
    pixels = np.asarray(transform(geo), dtype=float)
    drawer.line([(x, y) for x, y in pixels.tolist()], fill=fill, width=width)


# The sample positions along each line are the same for every line, so build them once.
lon_samples = np.linspace(long_min, long_max, samples_per_line)
lat_samples = np.linspace(lat_min, lat_max, samples_per_line)

for lon in np.linspace(long_min, long_max, num_lines):
    draw_warped_polyline(
        draw, tps.transform, np.full(samples_per_line, lon), lat_samples,
        fill="blue", width=line_width,
    )

for lat in np.linspace(lat_min, lat_max, num_lines):
    draw_warped_polyline(
        draw, tps.transform, lon_samples, np.full(samples_per_line, lat),
        fill="purple", width=line_width,
    )


In [7]:
# Write the diagram with the warped grid drawn on it.
img.save(fp="distorted_grid_map.png")