In [33]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from dataset import Dataset, to_device
from model import NNSingleFeatureModel, tims_mse_loss
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import pickle
import normalize_data
import random
from time import time
from tqdm.notebook import tqdm
tqdm.pandas()
from datetime import datetime
from sgp4.api import Satrec, SatrecArray, WGS72
pd.set_option('display.max_columns', 999)
pd.set_option('display.precision', 12)

In [2]:
raw_data = dict()  # in-memory cache of loaded pickles, keyed by dataset prefix + 'train'/'test'

def dataset_key(dataset='', validation=False):
    """Build the cache/file key for one dataset split.

    The key is the dataset prefix followed by 'test' when `validation` is
    truthy and by 'train' otherwise.
    """
    if validation:
        split = 'test'
    else:
        split = 'train'
    return dataset + split

def load_data(raw, dataset='', validation=False):
    """Return the pickled data for a dataset split, reading it at most once.

    `raw` is a dict used as a cache. On a miss the pickle at
    $GP_HIST_PATH/../t5_data/<key>.pkl is read and stored under <key>,
    where <key> is the dataset prefix plus 'test' (validation) or 'train'.
    """
    split = 'test' if validation else 'train'
    key = dataset + split
    if key in raw:
        # Cache hit: no file access at all.
        return raw[key]
    print(f"Loading data to cache for: {key}")
    raw[key] = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../t5_data/{key}.pkl')
    return raw[key]

In [3]:
def load_sub_model_with_config(train_config, model_configs, sub_model_key, X_count=0, force_recreate=False):
    """Load the saved checkpoint of one sub-model.

    The checkpoint is read from
    ``{train_config['model_path']}/{train_config['model_prefix']}{sub_model_key}.pth``.

    Parameters
    ----------
    train_config : dict : needs 'model_path' and 'model_prefix'; an optional
        'device' entry (default 'cpu') is used as the load target device
    model_configs : dict : per-sub-model configuration; only used to validate
        that `sub_model_key` is known
    sub_model_key : str : name of the sub-model, e.g. 'y_INCLINATION'
    X_count : int : NOT USED - kept for interface compatibility
    force_recreate : bool : when True the saved file is ignored, which for this
        load-only helper means the call fails

    Returns
    -------
    tuple
        (net, loss_func, optimizer, mean_losses, next_epoch) as stored in the
        checkpoint

    Raises
    ------
    KeyError
        `sub_model_key` is not in `model_configs`, or a config/checkpoint
        entry is missing
    Exception
        the checkpoint file does not exist or `force_recreate` is set
    """
    path = train_config['model_path']
    prefix = train_config['model_prefix']
    # Fail fast on an unknown sub-model, exactly like the former config lookup.
    if sub_model_key not in model_configs:
        raise KeyError(sub_model_key)
    f = f"{path}/{prefix}{sub_model_key}.pth"
    if force_recreate or not os.path.exists(f):
        raise Exception(f'Model does not exist: {f}')

    print("Loading existing model")
    # map_location lets a checkpoint written on a GPU machine be restored on
    # the configured device instead of failing or landing on the wrong one.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects pickled module objects such as a whole
    # `net`; pass weights_only=False there if the checkpoint is trusted.
    checkpoint = torch.load(f, map_location=train_config.get('device', 'cpu'))
    net = checkpoint['net']
    loss_func = checkpoint['loss_func']
    optimizer = checkpoint['optimizer']
    mean_losses = checkpoint['mean_losses']
    next_epoch = checkpoint['next_epoch']
    return net, loss_func, optimizer, mean_losses, next_epoch

In [4]:
def predict(model, X, y, device='cpu'):
    """Run `model` over every row of `X` and return the raw outputs.

    Parameters
    ----------
    model : torch module : switched to eval mode and moved to `device`
    X : DataFrame : feature rows; converted to a float32 tensor
    y : NOT USED - kept for interface compatibility
    device : str : torch device name used for the forward pass

    Returns
    -------
    numpy.ndarray
        model output for all rows, always returned on the CPU
    """
    pyt_device = torch.device(device)
    model.to(pyt_device)
    model.eval()
    X_tensor = torch.from_numpy(X.to_numpy()).float().to(pyt_device)
    # Inference only: skip building the autograd graph, which otherwise costs
    # memory proportional to the number of rows.
    with torch.no_grad():
        nn_results = model(X_tensor)
    # .cpu() is a no-op on CPU tensors and required before .numpy() elsewhere.
    return nn_results.detach().cpu().numpy()

In [5]:
def get_ref_X_y(df):
    """Split a frame into (reference, feature, target) column groups.

    Columns are assigned by name prefix: '__' for reference columns, 'X_' for
    features and 'y_' for targets. Columns matching none of the prefixes are
    left out. Original column order is kept inside each group.
    """
    def _columns_with(prefix):
        # Sub-frame of the columns whose name starts with `prefix`.
        return df[[name for name in df.columns if name.startswith(prefix)]]

    reference = _columns_with('__')
    features = _columns_with('X_')
    targets = _columns_with('y_')
    return (reference, features, targets)

In [51]:
def __jday_convert(x):
    '''
    Split a timestamp into (julian date at 00:00 of that day, fraction of day).

    Same algorithm as python-sgp4:

    from sgp4.functions import jday
    jday(x.year, x.month, x.day, x.hour, x.minute, x.second + x.microsecond * 1e-6)
    '''
    # Whole-day part; "// 1.0" floors while keeping the value a float.
    year_term = 367.0 * x.year
    leap_term = (7 * (x.year + ((x.month + 9) // 12.0)) * 0.25) // 1.0
    month_term = (275 * x.month / 9.0) // 1.0
    jd = year_term - leap_term + month_term + x.day + 1721013.5

    # Time of day expressed as a fraction of 86400 seconds.
    seconds_of_day = x.second + (x.microsecond * 1e-6) + x.minute * 60.0 + x.hour * 3600.0
    fr = seconds_of_day / 86400.0
    return jd, fr

def get_satrec_erv(bst, ecc, aop, inc, mea, mem, raa, mmdot=0, mmddot=0, norad=0, epoch=None):
    '''
    Get cartesian coordinates of a satellite based on TLE parameters

    The same `epoch` is used both as the element-set epoch given to sgp4init
    and as the evaluation time given to sgp4, so the state is evaluated with
    zero time since epoch.

     Parameters
     ----------
     bst : float : B-star drag term (passed to sgp4init unchanged)
     ecc : float : eccentricity (dimensionless, passed to sgp4init unchanged)
     aop : float : argument of perigee (in degrees)
     inc : float : inclination (in degrees)
     mea : float : mean anomaly (in degrees)
     mem : float : mean motion (in revolutions per day; converted to radians/minute)
     raa : float : right ascension of ascending node (in degrees)
     mmdot : float : first derivative of mean motion - forwarded to sgp4init, which per the sgp4 docs does not use it
     mmddot : float : second derivative of mean motion - forwarded to sgp4init, which per the sgp4 docs does not use it
     norad : int : NORAD ID, forwarded to sgp4init as the satellite number
     epoch : Timestamp : moment in time to get position

     Returns
     -------
     pandas.Series
         [e, rx, ry, rz, vx, vy, vz] error, position xyz, velocity xyz.  error = 0 is good,
         999 means this function itself failed (for example epoch is None)
    '''
    try:
        # sgp4init wants the epoch as days since 1949 December 31 00:00 UT
        ref_epoch = datetime.strptime('12/31/1949 00:00:00', '%m/%d/%Y %H:%M:%S')
        epoch_days = (epoch-ref_epoch)/np.timedelta64(1, 'D')
        s = Satrec()
        s.sgp4init(
             WGS72,           # gravity model
             'i',             # 'a' = old AFSPC mode, 'i' = improved mode
             norad,               # satnum: Satellite number
             epoch_days,       # epoch: days since 1949 December 31 00:00 UT
             bst,      # bstar: drag coefficient (/earth radii)
             mmdot,   # ndot (NOT USED by sgp4): first derivative of mean motion
             mmddot,             # nddot (NOT USED by sgp4): second derivative of mean motion
             ecc,       # ecco: eccentricity
             aop*np.pi/180, # argpo: argument of perigee (radians)
             inc*np.pi/180, # inclo: inclination (radians)
             mea*np.pi/180, # mo: mean anomaly (radians)
             mem*np.pi/(4*180), # no_kozai: rev/day * 2*pi/1440 = mean motion (radians/minute)
             raa*np.pi/180, # nodeo: right ascension of ascending node (radians)
        )
        jd, fr = __jday_convert(epoch)
        e, position, velocity = s.sgp4(jd, fr)
        return pd.Series([e, *position, *velocity])
    except Exception:
        # e is SGP4 propagation errors, i've also added error 999 for when something goes wrong.
        # Only ordinary errors are converted: KeyboardInterrupt / SystemExit must still be
        # able to stop a long row-by-row apply.
        return pd.Series([999, 0,0,0,0,0,0])

In [7]:
# Run configuration: which dataset prefix to evaluate and where the saved
# sub-models live (resolved relative to $GP_HIST_PATH).
train_config = dict(
    dataset='sample_',  # '', 'sample_', 'secret_'
    model_prefix="TRY_2_",
    model_path=f"{os.environ['GP_HIST_PATH']}/../t5_models",
    device='cpu',
)

# sample_ uses train dataset, get SGP4 from that
sgp4xyz = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../3_min/train_sgp4rv.pkl')

In [8]:
%%time
# Load the validation split of the configured dataset (cached in raw_data),
# pass it through normalize_data.normalize_all_columns and drop rows with NaN.
test_df = normalize_data.normalize_all_columns(load_data(raw_data,dataset=train_config['dataset'],validation=True)).dropna()
# Split columns by name prefix: '__' reference, 'X_' features, 'y_' targets.
ref_test, X_test, y_test = get_ref_X_y(test_df)
# Targets evaluated in this notebook; one sub-model is loaded per entry, and
# this order is the column order of the stacked predictions built later.
y_cols = ['y_INCLINATION', 'y_ECCENTRICITY', 'y_MEAN_MOTION', 'y_RA_OF_ASC_NODE_REG', 'y_ARG_OF_PERICENTER_REG', 'y_REV_MA_REG', 'y_BSTAR']
y_test = y_test[y_cols]

Loading data to cache for: sample_test
CPU times: user 356 ms, sys: 600 ms, total: 956 ms
Wall time: 2.62 s


In [9]:
# Per-target configuration: 'feature_index' is the position inside X_test of
# the feature column paired with each target.
model_configs = {
    target: {'feature_index': X_test.columns.get_loc(feature)}
    for target, feature in (
        ('y_INCLINATION', 'X_INCLINATION_1'),
        ('y_ECCENTRICITY', 'X_ECCENTRICITY_1'),
        ('y_MEAN_MOTION', 'X_MEAN_MOTION_1'),
        ('y_RA_OF_ASC_NODE_REG', 'X_RA_OF_ASC_NODE_1'),
        ('y_ARG_OF_PERICENTER_REG', 'X_ARG_OF_PERICENTER_1'),
        ('y_REV_MA_REG', 'X_MEAN_ANOMALY_1'),
        ('y_BSTAR', 'X_BSTAR_1'),
    )
}

In [10]:
# Load the saved sub-model for every target column and collect its predictions.
# load_sub_model_with_config raises when a checkpoint is missing, so nothing is
# created here.
all_models = {}
pred_data = []
X_sample = X_test
y_sample = y_test[y_cols]
for sub_key in y_cols:
    # Only the network is needed; the other returned checkpoint fields are discarded.
    model, _, _, _, _ = load_sub_model_with_config(train_config, model_configs, sub_key)
    all_models[sub_key] = model
    y_sample_pred = predict(model, X_sample, y_sample, device="cpu") # outputs of this sub-model for all rows of X_sample
    y_sample_pred_df = pd.DataFrame(y_sample_pred, columns=[sub_key], index=y_sample.index)  # one column named after the target, aligned to the test index
    pred_data.append(y_sample_pred_df)
# Side-by-side frame with one prediction column per target, in y_cols order.
pred_df = pd.concat(pred_data, axis=1)

Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model


In [14]:
def denormalize_predictions(indf):
    """Convert normalized model predictions back to orbital-element values.

    Works on a copy of `indf`. Each column is passed through
    normalize_data.normalize(..., reverse=True) with the min/max listed below
    (presumably the inverse of the min/max scaling applied during training -
    confirm in normalize_data), angles are wrapped with a modulo, and the
    'y_*' columns are renamed to plain element names. 'y_BSTAR' is only
    renamed.

    Columns are renamed by name rather than by position, so the result is
    labelled correctly whatever the input column order is.

    Raises
    ------
    KeyError
        a column that needs denormalizing is missing
    ValueError
        the columns are not exactly the seven expected prediction columns
    """
    column_names = {
        'y_INCLINATION': 'INCLINATION',
        'y_ECCENTRICITY': 'ECCENTRICITY',
        'y_MEAN_MOTION': 'MEAN_MOTION',
        'y_RA_OF_ASC_NODE_REG': 'RA_OF_ASC_NODE',
        'y_ARG_OF_PERICENTER_REG': 'ARG_OF_PERICENTER',
        'y_REV_MA_REG': 'MEAN_ANOMALY',
        'y_BSTAR': 'BSTAR',
    }
    indf = indf.copy()
    d360 = ['y_ARG_OF_PERICENTER_REG','y_RA_OF_ASC_NODE_REG','y_REV_MA_REG']
    # NOTE(review): y_REV_MA_REG is reversed twice (0..90 here, then 0..360 with
    # the other angles below) - presumably mirrors a two-step normalization; confirm.
    indf['y_REV_MA_REG'] = normalize_data.normalize(indf['y_REV_MA_REG'],min=0,max=90,reverse=True)
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf[d360] = indf[d360]%360  # wrap angles into [0, 360)
    indf['y_INCLINATION'] = normalize_data.normalize(indf['y_INCLINATION'],min=0,max=180,reverse=True)
    indf['y_INCLINATION'] = indf['y_INCLINATION']%180  # wrap into [0, 180)
    indf['y_ECCENTRICITY'] = normalize_data.normalize(indf['y_ECCENTRICITY'],min=0,max=0.25,reverse=True)
    indf['y_MEAN_MOTION'] = normalize_data.normalize(indf['y_MEAN_MOTION'],min=11.25,max=20,reverse=True)

    # Anything other than exactly the expected columns used to fail (or be
    # silently mislabelled) at the positional relabel; keep it a ValueError.
    if len(indf.columns) != len(column_names) or set(indf.columns) != set(column_names):
        raise ValueError(f"Expected columns {list(column_names)}, got {list(indf.columns)}")
    indf = indf.rename(columns=column_names)
    indf.name="Predictions"
    return indf

def denormalize_ground_truths(indf):
    """Convert normalized 'X_*_1' element columns back to orbital-element values.

    Works on a copy of `indf`. Each column is passed through
    normalize_data.normalize(..., reverse=True) with the min/max listed below
    (presumably the inverse of the min/max scaling applied to the features -
    confirm in normalize_data) and the 'X_*_1' columns are renamed to plain
    element names. 'X_BSTAR_1' is only renamed.

    Columns are renamed by name rather than by position, so the result is
    labelled correctly whatever the input column order is.

    Raises
    ------
    KeyError
        a column that needs denormalizing is missing
    ValueError
        the columns are not exactly the seven expected element columns
    """
    column_names = {
        'X_INCLINATION_1': 'INCLINATION',
        'X_ECCENTRICITY_1': 'ECCENTRICITY',
        'X_MEAN_MOTION_1': 'MEAN_MOTION',
        'X_RA_OF_ASC_NODE_1': 'RA_OF_ASC_NODE',
        'X_ARG_OF_PERICENTER_1': 'ARG_OF_PERICENTER',
        'X_MEAN_ANOMALY_1': 'MEAN_ANOMALY',
        'X_BSTAR_1': 'BSTAR',
    }
    indf = indf.copy()
    d360 = ['X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1']
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf['X_INCLINATION_1'] = normalize_data.normalize(indf['X_INCLINATION_1'],min=0,max=180,reverse=True)
    indf['X_ECCENTRICITY_1'] = normalize_data.normalize(indf['X_ECCENTRICITY_1'],min=0,max=0.25,reverse=True)
    indf['X_MEAN_MOTION_1'] = normalize_data.normalize(indf['X_MEAN_MOTION_1'],min=11.25,max=20,reverse=True)

    # Anything other than exactly the expected columns used to fail (or be
    # silently mislabelled) at the positional relabel; keep it a ValueError.
    if len(indf.columns) != len(column_names) or set(indf.columns) != set(column_names):
        raise ValueError(f"Expected columns {list(column_names)}, got {list(indf.columns)}")
    indf = indf.rename(columns=column_names)
    indf.name="Ground Truths"
    return indf

def get_ground_truths_xyz(df, ref=None, xyz=None):
    """Attach the SGP4 state vectors of each row's target element set.

    Rows of `df` are matched by index to the '__GP_ID_2' reference column,
    which is added as 'GP_ID'; that id is then looked up in the index of the
    state-vector table and its columns are appended. Both joins are inner
    joins, so rows without a match are dropped.

    Parameters
    ----------
    df : DataFrame : rows to annotate, indexed like `ref`
    ref : DataFrame : frame holding '__GP_ID_2'; defaults to the notebook's ref_test
    xyz : DataFrame : state vectors indexed by GP id; defaults to the notebook's sgp4xyz

    Returns
    -------
    DataFrame
        `df` plus 'GP_ID' plus every column of `xyz`
    """
    ref = ref_test if ref is None else ref
    xyz = sgp4xyz if xyz is None else xyz
    # rename() returns a new Series, so the column taken from the shared
    # reference frame is never relabelled in place.
    gp_ids = ref['__GP_ID_2'].rename('GP_ID')
    df = df.merge(gp_ids, left_index=True, right_index=True)
    df = df.merge(xyz, left_on='GP_ID', right_index=True)
    return df

In [19]:
# Ground truth: SGP4 state vectors looked up by each row's GP id. Only the
# position/velocity columns are kept at the end.
ground_truths = (
    get_ground_truths_xyz(
        denormalize_ground_truths(
            X_test[['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']]
        )
    )
    .sort_index()
    [['SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']]
)
display(ground_truths)

Unnamed: 0,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,-2247.603194235576,6976.489740265147,-0.010024535399,-0.859817553482,-0.284764916051,7.335147277747
1,-2102.945801140380,7021.610715137136,-0.249484658555,-0.866219852855,-0.264526097323,7.334987886237
2,-2017.158478116495,7046.930142214801,-0.021140294377,-0.869633200356,-0.252884757147,7.334800602015
3,-1977.405754763013,7058.213969611379,-0.030735170327,-0.871246764901,-0.247293549558,7.334772939990
4,-1924.326182443856,7073.082960036271,-0.307548697930,-0.873307395759,-0.239759332089,7.334567451932
...,...,...,...,...,...,...
335637,-4005.029582473137,6263.597321480315,0.001428493234,-1.567743225900,-1.294414134696,7.006534060415
335638,-3632.923410636105,6468.449550943837,-0.001317372866,-1.641186266291,-1.207826825990,7.021405970211
335639,-3632.923410636105,6468.449550943837,-0.001317372866,-1.641186266291,-1.207826825990,7.021405970211
335640,5320.203614858912,4759.098905177199,-0.004350172217,-1.449433720895,1.523145072713,7.297762542087


In [53]:
# Model predictions: denormalize the predicted elements, attach the reference
# columns (for __EPOCH_2) and turn every row into an SGP4 state vector.
predictions = (
    denormalize_predictions(pred_df)
    .merge(ref_test, left_index=True, right_index=True)
    .progress_apply(
        lambda row: get_satrec_erv(
            bst=row.BSTAR,
            ecc=row.ECCENTRICITY,
            aop=row.ARG_OF_PERICENTER,
            inc=row.INCLINATION,
            mea=row.MEAN_ANOMALY,
            mem=row.MEAN_MOTION,
            raa=row.RA_OF_ASC_NODE,
            epoch=row.__EPOCH_2,
        ),
        axis=1,
    )
)
# get_satrec_erv returns an unlabeled 7-value Series per row: error code, position, velocity.
predictions.columns = ["SAT_E","SAT_RX","SAT_RY","SAT_RZ","SAT_VX","SAT_VY","SAT_VZ"]
display(predictions)

  0%|          | 0/335642 [00:00<?, ?it/s]

Unnamed: 0,SAT_E,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,0.0,-2153.676425554715,6729.127575410055,1940.772198652492,-0.299264635445,-2.143829207741,7.068600047701
1,0.0,-1859.778852188865,7088.942428998477,81.718689110822,-0.853224656114,-0.313662766549,7.334891482285
2,0.0,-1733.331500590539,7120.151273681155,-154.213652586247,-0.914883013071,-0.067578510818,7.333415732738
3,0.0,-1873.900945825183,6861.262065063071,1762.651682873544,-0.433470377125,-1.951257880759,7.116697180255
4,0.0,-1689.396035971658,7128.418339661092,231.397145215587,-0.826771881005,-0.436408716817,7.331389218671
...,...,...,...,...,...,...,...
335637,0.0,-3847.084763584748,6364.920522416050,-84.651401851257,-1.640725446148,-1.185990709626,7.005690838378
335638,0.0,-3458.940411394046,6565.798378081370,-125.354437086960,-1.727717173206,-1.058424857496,7.021128671368
335639,0.0,-3514.751310443825,6538.488051831133,-1.434938081017,-1.663140586150,-1.175871872771,7.016965430779
335640,0.0,5512.117941029148,4534.508132716936,144.556577326753,-1.500299968009,1.487289805918,7.293799698657


In [59]:
# Baseline: the denormalized input elements (X_*_1) pushed through the same
# SGP4 conversion as the predictions.
# NOTE(review): get_satrec_erv uses `epoch` both as the element epoch and as
# the evaluation time, so these X_*_1 elements are evaluated at __EPOCH_2 with
# no propagation interval - confirm that this is the intended baseline.
baseline = (
    denormalize_ground_truths(
        X_test[['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']]
    )
    .merge(ref_test, left_index=True, right_index=True)
    .progress_apply(
        lambda row: get_satrec_erv(
            bst=row.BSTAR,
            ecc=row.ECCENTRICITY,
            aop=row.ARG_OF_PERICENTER,
            inc=row.INCLINATION,
            mea=row.MEAN_ANOMALY,
            mem=row.MEAN_MOTION,
            raa=row.RA_OF_ASC_NODE,
            epoch=row.__EPOCH_2,
        ),
        axis=1,
    )
)
# get_satrec_erv returns an unlabeled 7-value Series per row: error code, position, velocity.
baseline.columns = ["SAT_E","SAT_RX","SAT_RY","SAT_RZ","SAT_VX","SAT_VY","SAT_VZ"]
display(baseline)

  0%|          | 0/335642 [00:00<?, ?it/s]

Unnamed: 0,SAT_E,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,0.0,-2286.893550009362,6963.735475550376,0.001162271689,-0.858003652017,-0.290232886069,7.335122477571
1,0.0,-2247.603194348781,6976.489740227655,-0.010023569655,-0.859817553182,-0.284764916982,7.335147277747
2,0.0,-2102.945801068716,7021.610715159019,-0.249485265383,-0.866219853031,-0.264526096734,7.334987886237
3,0.0,-2017.158478127208,7046.930142211687,-0.021140204038,-0.869633200331,-0.252884757235,7.334800602015
4,0.0,-1977.405754803880,7058.213969599779,-0.030734826287,-0.871246764807,-0.247293549893,7.334772939990
...,...,...,...,...,...,...,...
335637,0.0,-5865.871649518629,4703.585612312416,0.000469392287,-1.086129316288,-1.687013266652,6.928075759911
335638,0.0,-5577.232233189567,5021.832460461696,0.001293601536,-1.176503215710,-1.630820175520,6.940867775161
335639,0.0,-5865.871649610530,4703.585612169672,0.000469978520,-1.086129315822,-1.687013267026,6.928075759911
335640,0.0,4103.712172422393,5867.987157001509,0.002656784635,-1.788004589511,1.099155460874,7.274888944229


In [64]:
# Bundle the three state-vector tables and persist them for later comparison.
data_dict = dict(
    tle_ground_truth=ground_truths,
    sgp4_baseline=baseline,
    model_t5_predictions=predictions,
)
with open('compare_output.pkl', 'wb') as handle:
    pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)