In [None]:
%matplotlib inline
import torch
import math
from math import sin, cos
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Unicode shorthand for pi; later cells use it to build the φ grid and the polar angle.
π = math.pi

In [None]:
# NOTE This is hacktastic; expects tensors for r, φ and scalars for t, θ
def build_schwarzschild_t (t, r, θ, φ, rS = 1.0):
    """Build the diagonal Schwarzschild metric tensor on a (φ, r) grid.

    The components filled in are, with signature (+, -, -, -):

        g[0,0] =  1 - rS / r
        g[1,1] = -1 / (1 - rS / r)
        g[2,2] = -r²
        g[3,3] = -r² · sin²(θ)

    All off-diagonal entries stay zero.

    Args:
        t: time coordinate; accepted but unused, since no component depends
            on it.
        r: 1-D tensor of radial coordinates.
        θ: scalar polar angle in radians (it is passed to ``math.sin``).
        φ: 1-D tensor of azimuthal angles; only its length is used, because
            no component depends on φ.
        rS: Schwarzschild radius, in the same units as ``r``.

    Returns:
        Tensor of shape ``(len(φ), len(r), 4, 4)`` with one metric per grid
        point. It is allocated with the dtype and device of ``rS / r``, so
        float64 input keeps full precision and the result lives next to ``r``.
    """
    # TODO protect against r == rS: 1 - rS / r is zero there, so g[1,1]
    #      divides by zero - nudge r to e.g. 1.000001 * rS?
    rS_r = rS / r
    r2 = r * r
    sθ = sin(θ)

    # Allocate from rS_r rather than torch.zeros so dtype/device follow the input.
    m = rS_r.new_zeros((len(φ), len(r), 4, 4))

    # The 1-D radial profiles broadcast across the leading φ axis.
    m[:, :, 0, 0] = 1 - rS_r
    m[:, :, 1, 1] = -(1 / (1 - rS_r))
    m[:, :, 2, 2] = -r2
    m[:, :, 3, 3] = -r2 * (sθ * sθ)
    return m


def build_sch_map2(distance_increment = 0.001, angular_slices = 64):
    """Sample the Schwarzschild metric on a polar grid in the equatorial plane.

    Time is held at t = 0 and the polar angle at θ = π/2. The metric is built
    with ``build_schwarzschild_t``'s default rS = 1, so every radius is a
    multiple of the Schwarzschild radius.

    Args:
        distance_increment: radial spacing δ between samples. ``int(20 / δ)``
            radii are generated, starting at r = 1.5.
        angular_slices: number of equal φ steps around the circle.
            ``angular_slices + 1`` angles are sampled, so both φ = 0 and
            φ = 2π are present (presumably so the plotted surface closes).

    Returns:
        Tuple ``(x_grid, y_grid, ws_grid, wt_grid, radial_curves)``:

        * ``x_grid``, ``y_grid`` - Cartesian coordinates of the grid points,
          shape ``(angular_slices + 1, number_of_radii)``.
        * ``ws_grid`` - the g[1,1] (radial) metric component at each point.
        * ``wt_grid`` - the g[0,0] (time) metric component at each point.
        * ``radial_curves`` - currently always an empty list (see TODO).
    """
    # distance values to check
    # NOTE start a bit outside the Schwarzschild radius, or Z blows up
    min_r = 1.5
    δ = distance_increment
    distances = int(20. / δ)
    r_vals = torch.tensor([min_r + i * δ for i in range(distances)])

    # angular values to check
    # NOTE: some redundant computation here, because spherical symmetry
    #       will keep for more complex cases, though
    angles = angular_slices
    φ_vals = torch.tensor([2. * π * i / angles for i in range(angles + 1)])

    # build surfaces to represent space curvature and time dilation
    g = build_schwarzschild_t(0, r_vals, π / 2., φ_vals)
    ws_grid = g[:, :, 1, 1]
    wt_grid = g[:, :, 0, 0]

    # Polar -> Cartesian. The trig terms are evaluated once per angle and
    # broadcast over the radii instead of materialising full r/φ grids first.
    cos_φ = torch.cos(φ_vals).unsqueeze(1)
    sin_φ = torch.sin(φ_vals).unsqueeze(1)
    x_grid = r_vals * cos_φ
    y_grid = r_vals * sin_φ
    assert x_grid.shape == y_grid.shape == ws_grid.shape

    # TODO also build constant-φ radial curves - just for space, skip time.
    #      Each curve should be an (xs, ys, zs) tuple of tensors (the form
    #      plot_scurv unpacks), with zs taken from g[1,1] along r.
    radial_curves = []

    return x_grid, y_grid, ws_grid, wt_grid, radial_curves # or maybe should convert to x, y?


def teq(t1, t2):
    """Return True when ``t1`` and ``t2`` compare equal at every element.

    The comparison is ``torch.eq``, so the usual broadcasting rules apply.
    """
    matches = torch.eq(t1, t2)
    return bool(matches.all().item())

# Build the grid once with the default resolution; the plotting cells below
# read these globals. scurv is the g[1,1] surface and tdil the g[0,0] surface.
x_grid, y_grid, scurv, tdil, radial_curves = build_sch_map2()


In [None]:
def plot_scurv(x, y, s, curves):
    """Render the surface ``s`` over the (x, y) grid in 3-D, plus overlay curves.

    Args:
        x, y, s: 2-D tensors of matching shape; each is converted with
            ``.numpy()`` and handed to ``plot_surface``.
        curves: iterable of ``(xs, ys, zs)`` tensor triples, each drawn as a
            line on the same axes. May be empty.
    """
    fig = plt.figure(figsize=(12, 8))
    # Figure.gca() no longer accepts keyword arguments (removed in
    # Matplotlib 3.6); add_subplot is the supported way to get 3-D axes.
    ax = fig.add_subplot(111, projection='3d')
    ax.grid(False)

    surf = ax.plot_surface(x.numpy(), y.numpy(), s.numpy(), cmap=cm.spring, linewidth=1, antialiased=True, alpha=0.8)

    for curve in curves:
        xs, ys, zs = curve
        ax.plot(xs.numpy(), ys.numpy(), zs.numpy(), color="#333399")

    ax.set_xlim(-22, 22)
    ax.set_ylim(-22, 22)
    ax.set_zlim(-4, 1)
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Matplotlib 3.1-3.5 cannot set the aspect on 3-D axes; fall back to
        # the default aspect rather than aborting the plot.
        pass

    # fig.colorbar(surf, shrink=0.5, aspect=5)

    plt.show()

# Plot the g[1,1] surface; radial_curves is currently empty, so no overlay lines are drawn.
plot_scurv(x_grid, y_grid, scurv, radial_curves)

In [None]:
def plot_tdil(x, y, t):
    """Render the reciprocal of ``t`` over the (x, y) grid as a 3-D surface.

    Args:
        x, y: 2-D tensors with the grid coordinates.
        t: 2-D tensor of the same shape; the plotted height is ``1 / t``.
    """
    fig = plt.figure(figsize=(12, 8))
    # Figure.gca() no longer accepts keyword arguments (removed in
    # Matplotlib 3.6); add_subplot is the supported way to get 3-D axes.
    ax = fig.add_subplot(111, projection='3d')
    ax.grid(False)

    surf = ax.plot_surface(x.numpy(), y.numpy(), (1. / t).numpy(), cmap=cm.summer, linewidth=1, antialiased=True)

    ax.set_xlim(-22, 22)
    ax.set_ylim(-22, 22)
    ax.set_zlim(0, 4)
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Matplotlib 3.1-3.5 cannot set the aspect on 3-D axes; fall back to
        # the default aspect rather than aborting the plot.
        pass

    # fig.colorbar(surf, shrink=0.5, aspect=5)

    plt.show()

# Plot 1 / g[0,0] (plot_tdil inverts its third argument), then report the grid shape.
plot_tdil(x_grid, y_grid, tdil)
print(x_grid.shape)