In [1]:
import torch
from scipy.spatial import transform
import numpy as np
from orbit_gen import get_random_orbit, get_Rz, get_nadir_attitude, get_nadir_attitude_vectors

In [2]:
def get_eci_orbit(tf=None):
    """Generate a random orbit in the ECI frame.

    ``tf`` is forwarded to ``get_random_orbit`` only when it is not None, so
    that function's own default applies otherwise.

    Returns:
        (orbit_eci, tsamp, orbit_eci_q) as produced by ``get_random_orbit``.
    """
    extra_kwargs = {} if tf is None else {'tf': tf}
    orbit_eci, tsamp, orbit_eci_q = get_random_orbit(**extra_kwargs)
    return orbit_eci, tsamp, orbit_eci_q

def get_orbit(tf=None):
    """Generate a random orbit and express the camera pose in the ECEF frame.

    Args:
        tf: optional value forwarded (via ``get_eci_orbit``) to
            ``get_random_orbit``; when None that function's default is used.

    Returns:
        orbit_ecef: array whose columns 0-2 are the ECEF position multiplied by
            1000 (converted to metres) and whose columns 3-11 are the three
            attitude vectors (read downstream as direction, up, right) rotated
            into ECEF. Any further columns of ``orbit_eci`` are copied as-is.
        orbit_eci, tsamp, orbit_eci_q: passed through unchanged from
            ``get_eci_orbit``.
    """
    # get_eci_orbit treats tf=None as "use the generator's default".
    orbit_eci, tsamp, orbit_eci_q = get_eci_orbit(tf)
    # One Earth-rotation matrix per time sample; np.matmul broadcasts them over
    # the sample axis (assumes get_Rz returns shape (N, 3, 3) — TODO confirm).
    Rzs = get_Rz(tsamp)

    def rotate_to_ecef(first_col):
        # Rotate the 3-vector stored in columns [first_col, first_col + 3).
        vecs = orbit_eci[:, first_col:first_col + 3, np.newaxis]
        return np.matmul(Rzs, vecs)[:, :, 0]

    # Position: ECI -> ECEF, then x1000 to metres (input presumably km).
    r_ecef_m = rotate_to_ecef(0) * 1000
    # Columns 3-11 are used downstream as the camera direction/up/right vectors
    # next to the ECEF position. They come from the ECI orbit, so they need the
    # same rotation as the position; otherwise the camera frame and position
    # are expressed in different frames.
    attitude_ecef = [rotate_to_ecef(col) for col in (3, 6, 9)]

    orbit_ecef = np.concatenate([r_ecef_m, *attitude_ecef, orbit_eci[:, 12:]], axis=1)
    return orbit_ecef, orbit_eci, tsamp, orbit_eci_q

In [3]:
# Build the orbit once for the whole notebook; tf=1000 is forwarded to get_random_orbit.
orbit_ecef, orbit_eci, tsamp, orbit_eci_q = get_orbit(tf=1000)

In [4]:
class Vector3D:
    """Cartesian x/y/z components kept as separate tensors on a shared device/dtype.

    Components may be scalars or equal-length array-likes; ``get`` stacks them
    along a new leading axis, so length-N components give a (3, N) tensor.
    """

    def __init__(self, x, y, z, device='cuda', dtype=torch.float64):
        def _to_tensor(component):
            return torch.tensor(component, device=device, dtype=dtype)

        self.x, self.y, self.z = _to_tensor(x), _to_tensor(y), _to_tensor(z)

    def get(self):
        """Return the three components stacked along dim 0."""
        components = (self.x, self.y, self.z)
        return torch.stack(components, dim=0)

In [23]:
class SatelliteCamera:
    """Pinhole camera that back-projects image pixels onto the WGS-84 ellipsoid.

    Args:
        fov: horizontal field of view in degrees.
        res: image size as ``[width, height]`` in pixels.
        device, dtype: torch placement for every tensor the camera creates.

    Image u grows along the camera's ``right`` vector and v along ``-up``
    (see ``get_extrinsic_matrices``).
    """

    # WGS-84 semi-major (equatorial) and semi-minor (polar) axes, metres.
    _A = 6378137.0
    _C = 6356752.31424518

    def __init__(self, fov, res, device='cuda', dtype=torch.float64):
        self.device = device
        self.dtype = dtype
        self.hfov = torch.tensor(fov, device=device, dtype=dtype)
        self.res = torch.tensor(res, device=device, dtype=dtype)
        # Focal length in pixels, derived from the horizontal FOV.
        self.focal_length = (self.res[0] / 2) / torch.tan(torch.deg2rad(self.hfov) / 2)
        # For a pinhole camera the *tangents* of the half-angles scale with the
        # image size, not the angles, so vfov comes from the focal length.
        self.vfov = torch.rad2deg(2 * torch.atan((self.res[1] / 2) / self.focal_length))
        self.K = self.intrinsic_matrix()
        self.pixel_locs = self.get_pixel_coords(int(res[1]), int(res[0]))
        # Set by get_extrinsic_matrices(); the ray methods require it.
        self.C_cw = None

    def get_intrinsic_matrix(self):
        """Return the cached 3x3 intrinsic matrix."""
        return self.K

    def intrinsic_matrix(self):
        """Build the 3x3 intrinsic matrix K with the principal point at the image centre."""
        K = torch.zeros((3, 3), device=self.device, dtype=self.dtype)
        K[0, 0] = self.focal_length
        K[1, 1] = self.focal_length
        K[0, 2] = self.res[0] / 2
        K[1, 2] = self.res[1] / 2
        K[2, 2] = 1.0
        return K

    def get_extrinsic_matrices(self, pos_vecs, dir_vecs, up_vecs, right_vecs):
        """Build batched 4x4 world->pixel matrices and cache them in ``self.C_cw``.

        Each argument exposes ``get()`` returning a (3, N) tensor, N = number of poses.

        Returns:
            (N, 4, 4) tensor ``[[K R, K t], [0, 0, 0, 1]]`` where the rows of R
            are (right, -up, dir) and ``t = -R @ pos``.
        """
        # World->camera rotation rows: x = right, y = -up (image v grows
        # opposite to up), z = viewing direction. Result shape (N, 3, 3).
        R_cw = torch.stack([right_vecs.get(), -up_vecs.get(), dir_vecs.get()]).permute(2, 0, 1)
        t_cw = -torch.matmul(R_cw, pos_vecs.get().t().unsqueeze(-1))  # (N, 3, 1)

        K_hom = torch.eye(4, device=self.device, dtype=self.dtype)
        K_hom[:3, :3] = self.K

        E_hom = torch.eye(4, device=self.device, dtype=self.dtype).repeat(R_cw.shape[0], 1, 1)
        E_hom[:, :3, :3] = R_cw
        E_hom[:, :3, 3:] = t_cw

        C_cw = torch.matmul(K_hom, E_hom)
        self.C_cw = C_cw
        return C_cw

    def get_pixel_coords(self, h, w):
        """Return (w*h, 3) homogeneous pixel-centre coordinates [u, v, 1], u-major order."""
        xs = torch.arange(0, w, device=self.device, dtype=self.dtype) + 0.5
        ys = torch.arange(0, h, device=self.device, dtype=self.dtype) + 0.5
        cartesian_prod = torch.cartesian_prod(xs, ys)
        pixel_locs = torch.cat((cartesian_prod, torch.ones((w*h, 1), device=self.device, dtype=self.dtype)), dim=1)
        return pixel_locs

    def _pixels_to_world_rays(self, pixel_locs):
        """Back-project (P, 3) pixel coords [u, v, 1] to unit world-frame rays of shape (N, 3, P)."""
        if self.C_cw is None:
            raise RuntimeError("call get_extrinsic_matrices() before requesting rays")
        C_cw_inv = torch.inverse(self.C_cw)
        # A trailing 0 makes every column a direction, so the translation part
        # of C_cw^-1 drops out and only R^T K^-1 [u, v, 1] remains.
        zeros = torch.zeros((pixel_locs.shape[0], 1), device=self.device, dtype=self.dtype)
        dirs = torch.matmul(C_cw_inv, torch.cat([pixel_locs, zeros], dim=1).t())[:, :3, :]
        return dirs / torch.norm(dirs, dim=1, keepdim=True)

    def get_corner_rays(self):
        """Unit rays through the image corners (0,0), (0,h), (w,h), (w,0); shape (N, 3, 4)."""
        w, h = self.res[0].item(), self.res[1].item()
        corners = torch.tensor(
            [[0.0, 0.0, 1.0], [0.0, h, 1.0], [w, h, 1.0], [w, 0.0, 1.0]],
            device=self.device, dtype=self.dtype,
        )
        return self._pixels_to_world_rays(corners)

    def get_ray_directions(self):
        """Unit rays through every pixel centre; shape (N, 3, w*h)."""
        return self._pixels_to_world_rays(self.pixel_locs)

    def cast_rays(self, rays, pos_vecs):
        """Intersect rays with the ellipsoid and return the nearest hit points.

        Args:
            rays: (N, 3, P) ray directions (unit length, as returned by the ray methods).
            pos_vecs: ray origins; ``.x``, ``.y``, ``.z`` are (N,) tensors.

        Returns:
            (N, 3, P) intersection points. Entries are NaN when the ray misses
            the ellipsoid (negative discriminant) or the hit lies behind the origin.
        """
        a = b = torch.tensor(self._A, device=self.device, dtype=self.dtype)
        c = torch.tensor(self._C, device=self.device, dtype=self.dtype)
        us = rays[:, 0, :]
        vs = rays[:, 1, :]
        ws = rays[:, 2, :]
        x = pos_vecs.x.unsqueeze(-1)
        y = pos_vecs.y.unsqueeze(-1)
        z = pos_vecs.z.unsqueeze(-1)
        # Closed-form smaller root of the ray/ellipsoid quadratic, scaled by a^2 b^2 c^2.
        value = -a**2*b**2*ws*z - a**2*c**2*vs*y - b**2*c**2*us*x
        radical = a**2*b**2*ws**2 + a**2*c**2*vs**2 - a**2*vs**2*z**2 + 2*a**2*vs*ws*y*z - a**2*ws**2*y**2 + b**2*c**2*us**2 - b**2*us**2*z**2 + 2*b**2*us*ws*x*z - b**2*ws**2*x**2 - c**2*us**2*y**2 + 2*c**2*us*vs*x*y - c**2*vs**2*x**2
        magnitude = a**2*b**2*ws**2 + a**2*c**2*vs**2 + b**2*c**2*us**2
        d = (value - a*b*c*torch.sqrt(radical)) / magnitude
        d[d < 0] = torch.nan

        xs_ecef = x + us*d
        ys_ecef = y + vs*d
        zs_ecef = z + ws*d

        return torch.stack([xs_ecef, ys_ecef, zs_ecef], dim=1)

    def ecef_to_llh(self, ecef_locs):
        """Convert ECEF points of shape (N, 3, ...) to geodetic longitude/latitude.

        Uses Bowring's closed-form latitude formula; ellipsoidal height is not
        computed (it would be ``p / cos(phi) - a / sqrt(1 - e^2 sin^2 phi)``).

        Returns:
            (lon, lat) in degrees, each shaped like ``ecef_locs[:, 0]``.
        """
        a = torch.tensor(self._A, device=self.device, dtype=self.dtype)
        c = torch.tensor(self._C, device=self.device, dtype=self.dtype)
        f = (a - c) / a
        e_sq = f * (2 - f)          # first eccentricity squared
        eps = e_sq / (1.0 - e_sq)   # second eccentricity squared
        xs = ecef_locs[:, 0]
        ys = ecef_locs[:, 1]
        zs = ecef_locs[:, 2]
        p = torch.sqrt(xs**2 + ys**2)
        q = torch.atan2(zs * a, p * c)  # parametric latitude
        sin_q3 = torch.sin(q)**3
        cos_q3 = torch.cos(q)**3
        phi = torch.atan2(zs + eps * c * sin_q3, p - e_sq * a * cos_q3)
        lam = torch.atan2(ys, xs)
        lat = torch.rad2deg(phi)
        lon = torch.rad2deg(lam)
        return lon, lat
    

# 66 degree horizontal FOV, 4608 x 2592 (width x height) pixel sensor.
cam = SatelliteCamera(fov=66, res=[4608, 2592])

In [25]:
device = 'cuda'
dtype = torch.float64
batch_size = 1
# Debug limit on how many batches are processed (one pose already yields
# ~12M rays). Set to None to sweep the whole orbit.
MAX_BATCHES = 1


def pose_vectors(orbit, start, end):
    """Return (pos, dir, up, right) Vector3D objects for rows [start, end) of ``orbit``.

    Column layout as used throughout this notebook: 0-2 position, 3-5 camera
    direction, 6-8 up, 9-11 right.
    """
    rows = orbit[start:end]
    return tuple(
        Vector3D(rows[:, col], rows[:, col + 1], rows[:, col + 2], device=device, dtype=dtype)
        for col in (0, 3, 6, 9)
    )


num_batches = -(-orbit_ecef.shape[0] // batch_size)  # ceiling division
if MAX_BATCHES is not None:
    num_batches = min(num_batches, MAX_BATCHES)

for batch in range(num_batches):
    start_idx = batch * batch_size
    end_idx = min(start_idx + batch_size, orbit_ecef.shape[0])
    pos_vecs, dir_vecs, up_vecs, right_vecs = pose_vectors(orbit_ecef, start_idx, end_idx)
    C_cw = cam.get_extrinsic_matrices(pos_vecs, dir_vecs, up_vecs, right_vecs)
    ray_vecs = cam.get_ray_directions()
    print(ray_vecs.shape)
    ecef_locs = cam.cast_rays(ray_vecs, pos_vecs)
    lon, lat = cam.ecef_to_llh(ecef_locs)
    print(lon, lat)



torch.Size([1, 3, 11943936])
torch.Size([1, 11943936]) tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float64)


In [None]:
# index = 0
# C_cw_inv = torch.inverse(C_cw[index])
# pixel_locs_hom = torch.cat([cam.pixel_locs, torch.zeros((cam.pixel_locs.shape[0], 1), device=cam.device, dtype=cam.dtype)], dim=1)
# ray_vecs = torch.matmul(C_cw_inv, pixel_locs_hom.t())
# print(ray_vecs.shape)
# corner_rays = cam.get_corner_rays()
# rays = corner_rays
# ecef_locs = cam.cast_rays(rays, pos_vecs)
# lon, lat = cam.ecef_to_llh(ecef_locs)
# print(lon, lat)

# a = torch.tensor(6378137.0, device='cuda', dtype=torch.double)
# c = torch.tensor(6356752.31424518, device='cuda', dtype=torch.double)
# f = (a - c) / a
# e_sq = f * (2 - f)
# eps = e_sq / (1.0 - e_sq)
# xs = ecef_locs[:, 0]
# ys = ecef_locs[:, 1]
# zs = ecef_locs[:, 2]
# p = torch.sqrt(xs**2 + ys**2)
# q = torch.atan2(zs * a, p * c)
# sin_q = torch.sin(q)
# cos_q = torch.cos(q)
# sin_q3 = sin_q**3
# cos_q3 = cos_q**3
# phi = torch.atan2(zs + eps * c * sin_q3, p - e_sq * a * cos_q3)
# lam = torch.atan2(ys, xs)
# v = a / torch.sqrt(1.0 - e_sq * torch.sin(phi)**2)
# h = p / torch.cos(phi) - v
# lat = torch.rad2deg(phi)
# lon = torch.rad2deg(lam)
# print(lat, lon)

# print(corner_rays.shape)
# a = b = torch.tensor(6378137.0, device=device, dtype=dtype)
# c = torch.tensor(6356752.31424518, device=device, dtype=dtype)
# us = rays[:, 0, :]
# vs = rays[:, 1, :]
# ws = rays[:, 2, :]
# x = pos_vecs.x.unsqueeze(-1)
# y = pos_vecs.y.unsqueeze(-1)
# z = pos_vecs.z.unsqueeze(-1)
# print(x.shape, us.shape)

# value = -a**2*b**2*ws*z - a**2*c**2*vs*y - b**2*c**2*us*x
# radical = a**2*b**2*ws**2 + a**2*c**2*vs**2 - a**2*vs**2*z**2 + 2*a**2*vs*ws*y*z - a**2*ws**2*y**2 + b**2*c**2*us**2 - b**2*us**2*z**2 + 2*b**2*us*ws*x*z - b**2*ws**2*x**2 - c**2*us**2*y**2 + 2*c**2*us*vs*x*y - c**2*vs**2*x**2
# magnitude = a**2*b**2*ws**2 + a**2*c**2*vs**2 + b**2*c**2*us**2
# d = (value - a*b*c*torch.sqrt(radical)) / magnitude
# d[d < 0] = torch.nan

# xs_ecef = x + us*d
# ys_ecef = y + vs*d
# zs_ecef = z + ws*d

# torch.stack([xs_ecef, ys_ecef, zs_ecef], dim=1).shape

# C_cw_inv = torch.inverse(C_cw)
# print(C_cw.shape)

# pixel_locs = cam.pixel_locs
# pixel_locs_hom = torch.cat([pixel_locs, torch.zeros((pixel_locs.shape[0], 1), device=device, dtype=dtype)], dim=1)
# corner_locs_hom = torch.tensor([[0, 0, 1, 0], [0, 2592, 1, 0], [4608, 0, 1, 0], [4608, 2592, 1, 0]], device=device, dtype=dtype)
# print(corner_locs_hom)

# ray_vecs = torch.matmul(C_cw_inv, corner_locs_hom.t())
# print(ray_vecs[0])
# print(ray_vecs[:, :3, :].shape, torch.norm(ray_vecs[:, :3, :], dim=1).unsqueeze(1).shape)
# ray_vecs = ray_vecs[:, :3, :] / torch.norm(ray_vecs[:, :3, :], dim=1).unsqueeze(1)

In [None]:
# R_cw = torch.stack([right_vecs.get().t(), -up_vecs.get().t(), dir_vecs.get().t()],dim=-1)
# print(R_cw.shape)
# pos_c = -torch.matmul(R_cw, pos_vecs.get().t().unsqueeze(-1))
# K = cam.K
# print(K)
# K_hom = torch.cat((K, torch.tensor([[0, 0, 0]], dtype=dtype, device=device)))
# K_hom = torch.cat((K_hom, torch.tensor([[0, 0, 0, 1]], dtype=dtype, device=device).t()), dim=1)
# print(K_hom)
# print(R_cw.shape, pos_c.shape)
# E = torch.cat((R_cw, pos_c), dim=-1)
# hombatch = torch.zeros((E.shape[0], 1, 4), dtype=dtype, device=device)
# hombatch[:,:,-1] = 1
# print(E.shape, hombatch.shape)
# E_hom = torch.cat((E, hombatch),dim=1)
# torch.matmul(K_hom, E_hom)


In [None]:
# device = 'cuda'
# dtype = torch.float64
# pose = [549496.7936069818, -6128343.990241624, 3287548.8567509046,
#         -0.07876803246160267, 0.878472092968163, -0.4712561712468741,
#         -0.04208625123151773, 0.46937312065726633, 0.8819963838144332,
#         0.9960041890069568, 0.08930652540881052, -0.0]
# pos_vec = Vector3D(*pose[:3], device, dtype)
# dir_vec = Vector3D(*pose[3:6], device, dtype)
# up_vec = Vector3D(*pose[6:9], device, dtype)
# right_vec = Vector3D(*pose[9:12], device, dtype)