In [None]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from dataset import Dataset, to_device
from model import NNSingleFeatureModel, tims_mse_loss
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import pickle
import normalize_data
import random
from time import time
from tqdm.notebook import tqdm
tqdm.pandas()
from datetime import datetime
from sgp4.api import Satrec, SatrecArray, WGS72
pd.set_option('display.max_columns', 999)
pd.set_option('display.precision', 12)

In [2]:
raw_data = {}  # in-memory cache of loaded frames, keyed by dataset_key(...)


def dataset_key(dataset='', validation=False):
    """Build the cache key (and pickle file stem) for one dataset split.

    Parameters
    ----------
    dataset : str
        Dataset prefix, e.g. '' or 'sample_'.
    validation : bool
        True selects the 'test' split, False the 'train' split.

    Returns
    -------
    str
        ``dataset`` followed by 'test' or 'train'.
    """
    return dataset + ('test' if validation else 'train')


def load_data(raw, dataset='', validation=False):
    """Return the requested dataset split, reading it from disk on first use.

    Parameters
    ----------
    raw : dict
        Cache mapping dataset keys to loaded objects; it is updated in place
        when the requested key is missing.
    dataset : str
        Dataset prefix, forwarded to ``dataset_key``.
    validation : bool
        True selects the 'test' split, False the 'train' split.

    Returns
    -------
    object
        Whatever is stored in the cache for the key (the unpickled file
        contents on a cache miss).

    Raises
    ------
    KeyError
        On a cache miss when the GP_HIST_PATH environment variable is not set.
    """
    # Single source of truth for the key so the cache and the file name cannot drift apart.
    key = dataset_key(dataset, validation)
    if key not in raw:
        print(f"Loading data to cache for: {key}")
        # NOTE: unpickling can execute arbitrary code; only point this at trusted files.
        raw[key] = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../t5_data/{key}.pkl')
    return raw[key]

In [3]:
def load_sub_model_with_config(train_config, model_configs, sub_model_key, X_count=0, force_recreate=False):
    """Load a previously saved sub-model checkpoint from disk.

    The checkpoint file is ``{model_path}/{model_prefix}{sub_model_key}.pth``.

    Parameters
    ----------
    train_config : dict
        Must contain 'model_path' and 'model_prefix'. If it contains 'device',
        that value is used as ``map_location`` so checkpoints saved on another
        device can still be loaded.
    model_configs : dict
        Per-sub-model configuration; only checked for containing ``sub_model_key``.
    sub_model_key : str
        Name of the sub-model, e.g. 'y_INCLINATION'.
    X_count : int
        Unused; kept for signature compatibility.
    force_recreate : bool
        This loader cannot create models, so True always results in an error.

    Returns
    -------
    tuple
        ``(net, loss_func, optimizer, mean_losses, next_epoch)`` as stored in
        the checkpoint.

    Raises
    ------
    KeyError
        If ``sub_model_key`` is not in ``model_configs`` or a required key is
        missing from ``train_config`` or the checkpoint.
    FileNotFoundError
        If the checkpoint file does not exist or ``force_recreate`` is True.
    """
    path = train_config['model_path']
    prefix = train_config['model_prefix']
    if sub_model_key not in model_configs:
        raise KeyError(sub_model_key)
    checkpoint_path = f"{path}/{prefix}{sub_model_key}.pth"
    if force_recreate or not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'Model does not exist: {checkpoint_path}')

    print("Loading existing model")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location=None (no 'device' configured) keeps torch's default placement.
    checkpoint = torch.load(checkpoint_path, map_location=train_config.get('device'))
    return (
        checkpoint['net'],
        checkpoint['loss_func'],
        checkpoint['optimizer'],
        checkpoint['mean_losses'],
        checkpoint['next_epoch'],
    )

In [4]:
def predict(model, X, y, device='cpu'):
    """Run a model in inference mode over a feature table.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate; it is moved to ``device`` and switched to eval mode.
    X : pandas.DataFrame or array-like
        Feature rows. Objects exposing ``to_numpy()`` are converted with it,
        anything else goes through ``numpy.asarray``.
    y : object
        Unused; kept for signature compatibility.
    device : str
        Torch device string on which the forward pass runs.

    Returns
    -------
    numpy.ndarray
        Model output copied back to the CPU.
    """
    pyt_device = torch.device(device)
    model.to(pyt_device)
    model.eval()
    features = X.to_numpy() if hasattr(X, 'to_numpy') else np.asarray(X)
    X_tensor = torch.from_numpy(features).float().to(pyt_device)
    # No autograd graph is needed for inference; skipping it saves memory on large inputs.
    with torch.no_grad():
        nn_results = model(X_tensor).cpu().numpy()
    return nn_results

In [5]:
def get_ref_X_y(df):
    """Split a frame into reference, feature and target column groups.

    Columns are grouped by name prefix: '__' for reference columns, 'X_' for
    features and 'y_' for targets. Columns with any other prefix are dropped.

    Returns
    -------
    tuple
        ``(reference_frame, feature_frame, target_frame)``, each keeping the
        original column order.
    """
    def columns_with(prefix):
        # Sub-frame of the columns whose name starts with the given prefix.
        return df[[name for name in df.columns if name.startswith(prefix)]]

    return tuple(columns_with(prefix) for prefix in ('__', 'X_', 'y_'))

In [6]:
def __jday_convert(x):
    '''
    Split a timestamp into a Julian day number and a fraction of a day.

    Same algorithm as python-sgp4's jday():

    from sgp4.functions import jday
    jday(x.year, x.month, x.day, x.hour, x.minute, x.second + x.microsecond * 1e-6)

    x must expose year, month, day, hour, minute, second and microsecond.
    Returns (jd, fr): jd is built from the calendar date only, fr is the
    time of day divided by 86400 seconds.
    '''
    year, month = x.year, x.month
    # "// 1.0" floors each term while keeping it a float, exactly as in the reference algorithm.
    year_term = 7 * (year + ((month + 9) // 12.0)) * 0.25 // 1.0
    month_term = 275 * month / 9.0 // 1.0
    jd = 367.0 * year - year_term + month_term + x.day + 1721013.5

    seconds_of_day = x.second + (x.microsecond * 1e-6) + x.minute * 60.0 + x.hour * 3600.0
    fr = seconds_of_day / 86400.0
    return jd, fr

def get_satrec_erv(bst, ecc, aop, inc, mea, mem, raa, mmdot=0, mmddot=0, norad=0, epoch=None):
    '''
    Get cartesian coordinates of a satellite based on TLE parameters.

    The element set is initialised with `epoch` as its own epoch and then
    evaluated at that same instant, i.e. with zero propagation time.

    Parameters
    ----------
    bst : float : B-star drag term
    ecc : float : eccentricity (dimensionless)
    aop : float : argument of perigee (in degrees)
    inc : float : inclination (in degrees)
    mea : float : mean anomaly (in degrees)
    mem : float : mean motion (in revolutions per day; converted below to radians/minute)
    raa : float : right ascension of ascending node (in degrees)
    mmdot : float : NOT USED by SGP4 - passed through as ndot
    mmddot : float : NOT USED by SGP4 - passed through as nddot
    norad : int : NORAD ID, passed through as the satellite number
    epoch : Timestamp : moment in time to get position

    Returns
    -------
    pandas.Series
        [e, rx, ry, rz, vx, vy, vz] error, position xyz, velocity xyz.
        error = 0 is good; 999 marks a failure inside this wrapper
        (bad input, missing epoch, ...), in which case the vectors are zeros.
    '''
    try:
        # sgp4init expects the epoch as days since 1949 December 31 00:00 UT.
        reference_epoch = datetime.strptime('12/31/1949 00:00:00', '%m/%d/%Y %H:%M:%S')
        epoch_days = (epoch-reference_epoch)/np.timedelta64(1, 'D')
        s = Satrec()
        s.sgp4init(
             WGS72,           # gravity model
             'i',             # 'a' = old AFSPC mode, 'i' = improved mode
             norad,               # satnum: Satellite number
             epoch_days,       # epoch: days since 1949 December 31 00:00 UT
             bst,      # bstar: drag coefficient (/earth radii)
             mmdot,   # ndot (NOT USED): ballistic coefficient (revs/day)
             mmddot,             # nddot (NOT USED): mean motion 2nd derivative (revs/day^3)
             ecc,       # ecco: eccentricity
             aop*np.pi/180, # argpo: argument of perigee (radians)
             inc*np.pi/180, # inclo: inclination (radians)
             mea*np.pi/180, # mo: mean anomaly (radians)
             mem*np.pi/(4*180), # no_kozai: revs/day * 2*pi/1440 = mean motion (radians/minute)
             raa*np.pi/180, # nodeo: right ascension of ascending node (radians)
        )
        jday = __jday_convert(epoch)
        e, position, velocity = s.sgp4(*jday)
        return pd.Series([e, *position, *velocity])
    except Exception:
        # e is SGP4 propagation errors, i've also added error 999 for when something goes wrong.
        # Only Exception is caught so that KeyboardInterrupt/SystemExit still stop a long
        # row-wise apply instead of being turned into a 999 row.
        return pd.Series([999, 0,0,0,0,0,0])

In [7]:
# Run configuration: which dataset split to score and where the trained sub-models live.
train_config = dict(
    dataset='sample_big_',  # '', 'sample_', 'secret_'
    model_prefix="TRY_2_",
    model_path=f"{os.environ['GP_HIST_PATH']}/../t5_models",
    device='cpu',
)

# Reference SGP4 position/velocity table; it is later joined on its index against GP_ID.
# sample_ uses train dataset, get SGP4 from that:
# sgp4xyz = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../3_min/train_sgp4rv.pkl')
# NOTE(review): the note above mentions 'sample_' but the active dataset is 'sample_big_'
# and the test table is loaded - confirm this is the matching SGP4 file.
sgp4xyz = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../3_min/test_sgp4rv.pkl')

In [8]:
%%time
# Targets predicted by the sub-models, in the order used throughout the notebook.
y_cols = ['y_INCLINATION', 'y_ECCENTRICITY', 'y_MEAN_MOTION', 'y_RA_OF_ASC_NODE_REG', 'y_ARG_OF_PERICENTER_REG', 'y_REV_MA_REG', 'y_BSTAR']

# Load the validation split (cached in raw_data), normalize every column, drop incomplete rows.
raw_test_df = load_data(raw_data, dataset=train_config['dataset'], validation=True)
test_df = normalize_data.normalize_all_columns(raw_test_df).dropna()

# Split into reference / feature / target frames and keep only the modelled targets.
ref_test, X_test, y_test = get_ref_X_y(test_df)
y_test = y_test[y_cols]

Loading data to cache for: sample_big_test
CPU times: user 3.88 s, sys: 4.49 s, total: 8.37 s
Wall time: 6.61 s


In [9]:
# Each target is paired with the "_1" input feature whose position in X_test
# is stored as that sub-model's 'feature_index'.
feature_for_target = {
    'y_INCLINATION': 'X_INCLINATION_1',
    'y_ECCENTRICITY': 'X_ECCENTRICITY_1',
    'y_MEAN_MOTION': 'X_MEAN_MOTION_1',
    'y_RA_OF_ASC_NODE_REG': 'X_RA_OF_ASC_NODE_1',
    'y_ARG_OF_PERICENTER_REG': 'X_ARG_OF_PERICENTER_1',
    'y_REV_MA_REG': 'X_MEAN_ANOMALY_1',
    'y_BSTAR': 'X_BSTAR_1',
}
model_configs = {
    target: {'feature_index': X_test.columns.get_loc(feature)}
    for target, feature in feature_for_target.items()
}

In [10]:
# Create or load all new sub models here if needed.
all_models = {}
pred_data = []
X_sample = X_test
y_sample = y_test[y_cols]
for sub_key in y_cols:
    # When new models are created, a dummy optimizer is used
    model, _, _, _, _ = load_sub_model_with_config(train_config, model_configs, sub_key)
    all_models[sub_key] = model
    y_sample_pred = predict(model, X_sample, y_sample, device="cpu") # get predictions for each train
    y_sample_pred_df = pd.DataFrame(y_sample_pred, columns=[sub_key], index=y_sample.index)  # put results into a dataframe
    pred_data.append(y_sample_pred_df)
pred_df = pd.concat(pred_data, axis=1)

Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model


In [11]:
def denormalize_predictions(indf):
    """Convert normalized model outputs back to orbital-element values.

    Each column is passed through ``normalize_data.normalize(..., reverse=True)``
    with that element's min/max, angles are wrapped with a modulo, and the
    ``y_*`` columns are renamed to plain element names. ``y_BSTAR`` is passed
    through unchanged. The input frame is not modified.

    Parameters
    ----------
    indf : pandas.DataFrame
        Must contain the seven ``y_*`` prediction columns, in any order.

    Returns
    -------
    pandas.DataFrame
        Copy with columns renamed to INCLINATION, ECCENTRICITY, MEAN_MOTION,
        RA_OF_ASC_NODE, ARG_OF_PERICENTER, MEAN_ANOMALY and BSTAR, and a
        ``name`` attribute set to "Predictions".
    """
    indf = indf.copy()
    d360 = ['y_ARG_OF_PERICENTER_REG','y_RA_OF_ASC_NODE_REG','y_REV_MA_REG']
    # NOTE(review): y_REV_MA_REG goes through the reverse transform twice (0..90 here,
    # then 0..360 with the other angles) - presumably mirroring how it was normalized;
    # confirm against normalize_data.
    indf['y_REV_MA_REG'] = normalize_data.normalize(indf['y_REV_MA_REG'],min=0,max=90,reverse=True)
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf[d360] = indf[d360]%360
    indf['y_INCLINATION'] = normalize_data.normalize(indf['y_INCLINATION'],min=0,max=180,reverse=True)
    indf['y_INCLINATION'] = indf['y_INCLINATION']%180
    indf['y_ECCENTRICITY'] = normalize_data.normalize(indf['y_ECCENTRICITY'],min=0,max=0.25,reverse=True)
    indf['y_MEAN_MOTION'] = normalize_data.normalize(indf['y_MEAN_MOTION'],min=11.25,max=20,reverse=True)
    # Rename by label rather than by position so a different input column order
    # can never attach an element name to the wrong data.
    outdf = indf.rename(columns={
        'y_INCLINATION': 'INCLINATION',
        'y_ECCENTRICITY': 'ECCENTRICITY',
        'y_MEAN_MOTION': 'MEAN_MOTION',
        'y_RA_OF_ASC_NODE_REG': 'RA_OF_ASC_NODE',
        'y_ARG_OF_PERICENTER_REG': 'ARG_OF_PERICENTER',
        'y_REV_MA_REG': 'MEAN_ANOMALY',
        'y_BSTAR': 'BSTAR',
    })
    outdf.name = "Predictions"
    return outdf

def denormalize_ground_truths(indf):
    """Convert normalized ``X_*_1`` element columns back to orbital-element values.

    Each column is passed through ``normalize_data.normalize(..., reverse=True)``
    with that element's min/max and the columns are renamed to plain element
    names. ``X_BSTAR_1`` is passed through unchanged. The input frame is not
    modified.

    Parameters
    ----------
    indf : pandas.DataFrame
        Must contain the seven ``X_*_1`` element columns, in any order.

    Returns
    -------
    pandas.DataFrame
        Copy with columns renamed to INCLINATION, ECCENTRICITY, MEAN_MOTION,
        RA_OF_ASC_NODE, ARG_OF_PERICENTER, MEAN_ANOMALY and BSTAR, and a
        ``name`` attribute set to "Ground Truths".
    """
    indf = indf.copy()
    d360 = ['X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1']
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf['X_INCLINATION_1'] = normalize_data.normalize(indf['X_INCLINATION_1'],min=0,max=180,reverse=True)
    indf['X_ECCENTRICITY_1'] = normalize_data.normalize(indf['X_ECCENTRICITY_1'],min=0,max=0.25,reverse=True)
    indf['X_MEAN_MOTION_1'] = normalize_data.normalize(indf['X_MEAN_MOTION_1'],min=11.25,max=20,reverse=True)
    # Rename by label rather than by position so a different input column order
    # can never attach an element name to the wrong data.
    outdf = indf.rename(columns={
        'X_INCLINATION_1': 'INCLINATION',
        'X_ECCENTRICITY_1': 'ECCENTRICITY',
        'X_MEAN_MOTION_1': 'MEAN_MOTION',
        'X_RA_OF_ASC_NODE_1': 'RA_OF_ASC_NODE',
        'X_ARG_OF_PERICENTER_1': 'ARG_OF_PERICENTER',
        'X_MEAN_ANOMALY_1': 'MEAN_ANOMALY',
        'X_BSTAR_1': 'BSTAR',
    })
    outdf.name = "Ground Truths"
    return outdf

def get_ground_truths_xyz(df):
    """Attach the reference SGP4 table to ``df`` via each row's ``__GP_ID_2``.

    Uses the notebook-level ``ref_test`` (for the ``__GP_ID_2`` column, joined
    on the row index) and ``sgp4xyz`` (joined on its index against that id).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame sharing its index with ``ref_test``.

    Returns
    -------
    pandas.DataFrame
        ``df`` with an added ``GP_ID`` column and all columns of ``sgp4xyz``.
        Both joins are inner joins, so rows without a match are dropped.
    """
    # rename() returns a new Series; assigning .name would relabel the Series
    # object taken from the shared ref_test frame.
    gp_ids = ref_test['__GP_ID_2'].rename('GP_ID')
    with_ids = df.merge(gp_ids, left_index=True, right_index=True)
    return with_ids.merge(sgp4xyz, left_on='GP_ID', right_index=True)

In [24]:
# Column groups used to assemble the ground-truth comparison frame.
tle_input_cols = ['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']
orbit_cols = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION', 'RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY', 'BSTAR']
state_cols = ['SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']

# Denormalized "_1" elements + reference SGP4 state vectors + the epoch difference.
# NOTE(review): X_delta_EPOCH is multiplied by 7 - presumably undoing its normalization; confirm.
ground_truths = (
    denormalize_ground_truths(X_test[tle_input_cols])
    .pipe(get_ground_truths_xyz)
    .sort_index()
    .loc[:, orbit_cols + state_cols]
    .merge(X_test[['X_delta_EPOCH']]*7, left_index=True, right_index=True)
    .rename(columns={'X_delta_EPOCH':'EPOCH_DIFF'})
    .loc[:, ['EPOCH_DIFF'] + orbit_cols + state_cols]
)
display(ground_truths)

Unnamed: 0,EPOCH_DIFF,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,2.307414610000,66.8093,0.0019542,13.87024311,74.4117,266.2382,93.6431,0.00256590,2644.649637377463,6825.762773305299,-0.302831316117,-2.705632314121,1.056823921345,6.785076780971
1,3.028445960000,66.8114,0.0019826,13.87036273,68.8221,266.2533,93.6251,0.00248920,3495.093500369564,6432.239204534228,0.005151885234,-2.548367470561,1.393534023780,6.784821208027
2,2.018816670000,66.7968,0.0020405,13.87143821,35.9336,259.6156,100.2594,0.00325870,6273.495433832282,3775.338747683804,-0.331242471557,-1.491277011755,2.493421345878,6.782718984472
3,3.388695700000,66.8176,0.0020565,13.87160098,16.7307,259.7897,100.0797,0.00133350,7241.256792328082,1083.350859378716,-0.713947246314,-0.420967584521,2.872881299793,6.783504172538
4,2.235077099988,66.8148,0.0020777,13.87170453,359.9450,256.3677,103.4965,0.00097176,7289.194851477701,-698.480230753366,-0.209235331193,0.284962502223,2.889168477739,6.782835542171
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3422007,12.024337480000,97.9139,0.0006253,14.81192989,68.1166,181.5303,178.5891,0.00040992,1218.528488626521,6903.782762870275,0.001091974984,1.019946207278,-0.190272563808,7.468256894684
3422008,12.767355500000,97.9140,0.0006283,14.81196521,68.6501,180.2147,179.9064,0.00040404,1065.508435020738,6928.909686514208,0.003209085554,1.023898435623,-0.167833176198,7.468391146110
3422009,12.767326070000,97.9139,0.0006281,14.81199888,69.1170,179.3407,180.7815,0.00041578,1009.017844810137,6937.329036388848,0.002650724318,1.025174399674,-0.159636823649,7.468425421362
3422010,12.429543600000,97.9140,0.0006306,14.81202997,69.6506,178.0071,182.1167,0.00039314,984.780427324029,6940.786868562617,0.002730168807,1.025725165427,-0.156082140598,7.468448787356


In [34]:
# Column groups used to assemble the model-prediction comparison frame.
orbit_cols = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION', 'RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY', 'BSTAR']
state_cols = ['SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']

# Denormalized predictions plus the target epoch needed by SGP4.
predictions = denormalize_predictions(pred_df).merge(ref_test[['__EPOCH_2']], left_index=True, right_index=True)

# Row-wise SGP4 evaluation of the predicted elements at __EPOCH_2.
predictions_sgp4 = predictions.progress_apply(
    lambda row: get_satrec_erv(
        bst=row['BSTAR'],
        ecc=row['ECCENTRICITY'],
        aop=row['ARG_OF_PERICENTER'],
        inc=row['INCLINATION'],
        mea=row['MEAN_ANOMALY'],
        mem=row['MEAN_MOTION'],
        raa=row['RA_OF_ASC_NODE'],
        epoch=row['__EPOCH_2'],
    ),
    axis=1,
)
predictions_sgp4.columns = ["SAT_E"] + state_cols

# NOTE(review): SAT_E (the SGP4 error code, 999 on wrapper failure) is not kept by the
# final column selection, so failed rows cannot be told apart downstream.
# NOTE(review): X_delta_EPOCH is multiplied by 7 - presumably undoing its normalization; confirm.
predictions = (
    predictions
    .merge(predictions_sgp4, left_index=True, right_index=True)
    .merge(X_test[['X_delta_EPOCH']]*7, left_index=True, right_index=True)
    .rename(columns={'X_delta_EPOCH':'EPOCH_DIFF'})
    .loc[:, ['EPOCH_DIFF'] + orbit_cols + state_cols]
)
display(predictions)

Unnamed: 0,EPOCH_DIFF,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,2.307414610000,66.815696716309,0.001953178784,13.871145136654,66.414474487305,264.176208496094,95.3310546875,0.002447555074,2946.364711558203,6701.105060456012,-44.353246805382,-2.638945536749,1.213345814458,6.784808118071
1,3.028445960000,66.817138671875,0.001979251858,13.871264569461,59.248462677002,264.728912353516,95.7597656250,0.002370860660,3716.559118870705,6306.302952619862,71.248527698613,-2.531576999215,1.424061302682,6.784778055319
2,2.018816670000,66.803459167480,0.002040410647,13.872340247035,28.627645492554,258.254638671875,101.7626953125,0.003140114713,6423.193733862166,3514.005127798604,16.262919938476,-1.400615155389,2.545000430665,6.783039033401
3,3.388695700000,66.823013305664,0.002051989082,13.872502706945,6.396644592285,258.391571044922,102.5703125000,0.001215097262,7268.701775423152,869.758595269584,127.296465962377,-0.454331774117,2.869372675449,6.782873790070
4,2.235077099988,66.821258544922,0.002076912671,13.872606493533,352.076721191406,254.253509521484,105.1552734375,0.000852736994,7249.112023955263,-1032.752206535539,-55.143619764548,0.467763800085,2.865199883626,6.782768554498
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3422007,12.024337480000,97.911499023438,0.000592916971,14.812830053270,80.813552856445,144.399032592773,216.0078125000,0.000292252982,1123.170778140166,6919.134245423385,29.780704622485,1.017068188484,-0.208018410927,7.468966000592
3422008,12.767355500000,97.910919189453,0.000593518373,14.812865257263,82.402145385742,140.889038085938,219.4765625000,0.000286324561,930.160328805034,6947.578097815735,24.381555201135,1.023142517613,-0.174135089156,7.469166579682
3422009,12.767326070000,97.910820007324,0.000593317905,14.812898896635,82.869338989258,140.106063842773,220.3437500000,0.000297928578,874.861051708502,6954.651787383302,34.494098021208,1.023144397877,-0.176744612349,7.469167465893
3422010,12.429543600000,97.911231994629,0.000596908503,14.812930189073,82.922065734863,139.802780151367,220.6562500000,0.000275478582,868.604343780467,6955.424035928598,35.544425510150,1.023213914204,-0.176978610355,7.469154220376


In [35]:
# Column groups used to assemble the baseline comparison frame.
tle_input_cols = ['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']
orbit_cols = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION', 'RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY', 'BSTAR']
state_cols = ['SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']

# Denormalized "_1" input elements plus every reference column (including __EPOCH_2).
baseline = denormalize_ground_truths(X_test[tle_input_cols]).merge(ref_test, left_index=True, right_index=True)

# Row-wise SGP4 evaluation of the unmodified input elements at __EPOCH_2.
baseline_sgp4 = baseline.progress_apply(
    lambda row: get_satrec_erv(
        bst=row['BSTAR'],
        ecc=row['ECCENTRICITY'],
        aop=row['ARG_OF_PERICENTER'],
        inc=row['INCLINATION'],
        mea=row['MEAN_ANOMALY'],
        mem=row['MEAN_MOTION'],
        raa=row['RA_OF_ASC_NODE'],
        epoch=row['__EPOCH_2'],
    ),
    axis=1,
)
baseline_sgp4.columns = ["SAT_E"] + state_cols

# NOTE(review): SAT_E (the SGP4 error code, 999 on wrapper failure) is not kept by the
# final column selection, so failed rows cannot be told apart downstream.
# NOTE(review): X_delta_EPOCH is multiplied by 7 - presumably undoing its normalization; confirm.
baseline = (
    baseline
    .merge(baseline_sgp4, left_index=True, right_index=True)
    .merge(X_test[['X_delta_EPOCH']]*7, left_index=True, right_index=True)
    .rename(columns={'X_delta_EPOCH':'EPOCH_DIFF'})
    .loc[:, ['EPOCH_DIFF'] + orbit_cols + state_cols]
)
display(baseline)

Unnamed: 0,EPOCH_DIFF,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,2.307414610000,66.8093,0.0019542,13.87024311,74.4117,266.2382,93.6431,0.00256590,1967.255600834512,7050.928302726330,-0.343841009810,-2.795983604719,0.788171511598,6.784959890251
1,3.028445960000,66.8114,0.0019826,13.87036273,68.8221,266.2533,93.6251,0.00248920,2644.649637111497,6825.762773409188,-0.302830649138,-2.705632314385,1.056823920663,6.785076780971
2,2.018816670000,66.7968,0.0020405,13.87143821,35.9336,259.6156,100.2594,0.00325870,5928.325476581760,4296.525654316197,-0.314913314488,-1.698605238275,2.357652232060,6.782897263670
3,3.388695700000,66.8176,0.0020565,13.87160098,16.7307,259.7897,100.0797,0.00133350,7011.637541529944,2107.353322155683,-0.747620531544,-0.827352794897,2.783001866838,6.783998580221
4,2.235077099988,66.8148,0.0020777,13.87170453,359.9450,256.3677,103.4965,0.00097176,7322.349588755003,-7.638277259520,-1.421796170572,0.012315250278,2.903330762735,6.783024731122
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3422007,12.024337480000,97.9139,0.0006253,14.81192989,68.1166,181.5303,178.5891,0.00040992,2613.169080256581,6505.910102512269,-0.004918153408,0.959709208525,-0.393940033423,7.467768969991
3422008,12.767355500000,97.9140,0.0006283,14.81196521,68.6501,180.2147,179.9064,0.00040404,2552.482921061801,6529.969889255255,0.001710257526,0.963303423823,-0.385098286695,7.467749172096
3422009,12.767326070000,97.9139,0.0006281,14.81199888,69.1170,179.3407,180.7815,0.00041578,2499.182008345241,6550.541275136905,0.001870179208,0.966372509472,-0.377299143750,7.467758484808
3422010,12.429543600000,97.9140,0.0006306,14.81202997,69.6506,178.0071,182.1167,0.00039314,2438.069235809936,6573.537241698386,-0.008549272301,0.969820528433,-0.368380423596,7.467745887145


In [36]:
# Bundle the three comparison frames with the reference columns and persist them.
data_dict = dict(
    tle_ground_truth=ground_truths,
    sgp4_baseline=baseline,
    model_t5_predictions=predictions,
    ref=ref_test,
)
# The dataset prefix is used verbatim, so a prefix ending in '_' yields a double underscore.
compare_output_path = f'{os.environ["GP_HIST_PATH"]}/../t5_data/{train_config["dataset"]}_compare_output.pkl'
with open(compare_output_path, 'wb') as out_file:
    pickle.dump(data_dict, out_file, protocol=pickle.HIGHEST_PROTOCOL)