In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from dataset import Dataset, to_device
from model import NNSingleFeatureModel, tims_mse_loss
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import pickle
import normalize_data
import random
from time import time
from tqdm.notebook import tqdm
tqdm.pandas()
from datetime import datetime
from sgp4.api import Satrec, SatrecArray, WGS72
pd.set_option('display.max_columns', 999)
pd.set_option('display.precision', 12)

In [2]:
raw_data = {} # loads raw data and stores as a dict cache

def dataset_key(dataset='', validation=False):
    return dataset+('test' if validation else 'train')

def load_data(raw, dataset='', validation=False):
    key = dataset+('test' if validation else 'train')
    if key not in raw:
        print(f"Loading data to cache for: {key}")
        raw[key] = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../t5_data/{key}.pkl')
    return raw[key]

In [3]:
def load_sub_model_with_config(train_config, model_configs, sub_model_key, X_count=0, force_recreate=False):
    path = train_config['model_path']
    prefix = train_config['model_prefix']
    model_config = model_configs[sub_model_key]
    f = f"{path}/{prefix}{sub_model_key}.pth"
    if os.path.exists(f) and not force_recreate:
        print("Loading existing model")
        checkpoint = torch.load(f)
        net = checkpoint['net']
        loss_func = checkpoint['loss_func']
        optimizer = checkpoint['optimizer']
        mean_losses = checkpoint['mean_losses']
        next_epoch = checkpoint['next_epoch']
    else:
        raise Exception('Model does not exist')
    return net, loss_func, optimizer, mean_losses, next_epoch

In [4]:
def predict(model, X, y, device='cpu'):
    pyt_device = torch.device(device)
    model.eval()
    X_tensor = torch.from_numpy(X.to_numpy()).float()
    nn_results = model(X_tensor).detach().numpy()
    return nn_results

In [5]:
def get_ref_X_y(df):
    ref_cols = [c for c in df.columns if c.startswith('__')]
    X_cols = [c for c in df.columns if c.startswith('X_')]
    y_cols = [c for c in df.columns if c.startswith('y_')]
    return (df[ref_cols], df[X_cols], df[y_cols])

In [6]:
def __jday_convert(x):
    '''
    Algorithm from python-sgp4:

    from sgp4.functions import jday
    jday(x.year, x.month, x.day, x.hour, x.minute, x.second + x.microsecond * 1e-6)
    '''
    jd = (367.0 * x.year
         - 7 * (x.year + ((x.month + 9) // 12.0)) * 0.25 // 1.0
           + 275 * x.month / 9.0 // 1.0
           + x.day
         + 1721013.5)
    fr = (x.second + (x.microsecond * 1e-6) + x.minute * 60.0 + x.hour * 3600.0) / 86400.0;
    return jd, fr

def get_satrec_erv(bst, ecc, aop, inc, mea, mem, raa, mmdot=0, mmddot=0, norad=0, epoch1=None, epoch2=None):
    '''
    Get cartesian coordinates of a satellite based on TLE parameters

     Parameters
     ----------
     bst : float : B-star
     ecc : float : eccentricity (in degrees)
     aop : float : argument of perigee (in degrees)
     inc : float : inclination (in degrees)
     mea : float : mean anomaly (in degrees)
     mem : float : mean motion (in degrees per minute)
     raa : float : right ascension of ascending node (in degrees)
     mmdot : float : NOT USED - ballistic coefficient
     mmddot : float : NOT USED - mean motion 2nd derivative
     norad : int : NOT USED - NORAD ID
     epoch : Timestamp : moment in time to get position

     Returns
     -------
     list
         [e, rx, ry, rz, vx, vy, yz] error, position xyz, velocity xyz.  error = 0 is good
    '''
    try:
        r = datetime.strptime('12/31/1949 00:00:00', '%m/%d/%Y %H:%M:%S')
        epoch_days = (epoch1-r)/np.timedelta64(1, 'D')
        s = Satrec()
        s.sgp4init(
             WGS72,           # gravity model
             'i',             # 'a' = old AFSPC mode, 'i' = improved mode
             norad,               # satnum: Satellite number
             epoch_days,       # epoch: days since 1949 December 31 00:00 UT
             bst,      # bstar: drag coefficient (/earth radii)
             mmdot,   # ndot (NOT USED): ballistic coefficient (revs/day)
             mmddot,             # nddot (NOT USED): mean motion 2nd derivative (revs/day^3)
             ecc,       # ecco: eccentricity
             aop*np.pi/180, # argpo: argument of perigee (radians)
             inc*np.pi/180, # inclo: inclination (radians)
             mea*np.pi/180, # mo: mean anomaly (radians)
             mem*np.pi/(4*180), # no_kozai: mean motion (radians/minute)
             raa*np.pi/180, # nodeo: right ascension of ascending node (radians)
        )
        jday = __jday_convert(epoch2)
        e,r,v = s.sgp4(*jday)
        return pd.Series([e, *r, *v])
    except:
        # e is SGP4 propagation errors, i've also added error 999 for when something goes wrong
        return pd.Series([999, 0,0,0,0,0,0])

In [7]:
train_config = {
    'dataset' : 'sample4_', # '', 'sample_', 'secret_'
    'model_prefix' : "TRY_2_", 
    'model_path' : f"{os.environ['GP_HIST_PATH']}/../t5_models",
    'device' : 'cpu',
}

# sample_ uses train dataset, get SGP4 from that
# sgp4xyz = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../3_min/train_sgp4rv.pkl')

sgp4xyz = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../3_min/test_sgp4rv.pkl')

In [8]:
%%time
test_df = normalize_data.normalize_all_columns(load_data(raw_data,dataset=train_config['dataset'],validation=True)).dropna()
ref_test, X_test, y_test = get_ref_X_y(test_df)
y_cols = ['y_INCLINATION', 'y_ECCENTRICITY', 'y_MEAN_MOTION', 'y_RA_OF_ASC_NODE_REG', 'y_ARG_OF_PERICENTER_REG', 'y_REV_MA_REG', 'y_BSTAR']
y_test = y_test[y_cols]

Loading data to cache for: sample4_test
CPU times: user 2.21 s, sys: 3.3 s, total: 5.52 s
Wall time: 13.3 s


In [9]:
model_configs = {
    'y_INCLINATION': { 'feature_index': X_test.columns.get_loc('X_INCLINATION_1') },
    'y_ECCENTRICITY': { 'feature_index': X_test.columns.get_loc('X_ECCENTRICITY_1') },
    'y_MEAN_MOTION': { 'feature_index': X_test.columns.get_loc('X_MEAN_MOTION_1') },
    'y_RA_OF_ASC_NODE_REG': { 'feature_index': X_test.columns.get_loc('X_RA_OF_ASC_NODE_1') },
    'y_ARG_OF_PERICENTER_REG': { 'feature_index': X_test.columns.get_loc('X_ARG_OF_PERICENTER_1') },
    'y_REV_MA_REG': { 'feature_index': X_test.columns.get_loc('X_MEAN_ANOMALY_1') },
    'y_BSTAR': { 'feature_index': X_test.columns.get_loc('X_BSTAR_1') },
}

In [10]:
# Create or load all new sub models here if needed.
all_models = {}
pred_data = []
X_sample = X_test
y_sample = y_test[y_cols]
for sub_key in y_cols:
    # When new models are created, a dummy optimizer is used
    model, _, _, _, _ = load_sub_model_with_config(train_config, model_configs, sub_key)
    all_models[sub_key] = model
    y_sample_pred = predict(model, X_sample, y_sample, device="cpu") # get predictions for each train
    y_sample_pred_df = pd.DataFrame(y_sample_pred, columns=[sub_key], index=y_sample.index)  # put results into a dataframe
    pred_data.append(y_sample_pred_df)
pred_df = pd.concat(pred_data, axis=1)

Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model


In [11]:
def denormalize_predictions(indf):
    indf = indf.copy()
    d360 = ['y_ARG_OF_PERICENTER_REG','y_RA_OF_ASC_NODE_REG','y_REV_MA_REG']
    indf['y_REV_MA_REG'] = normalize_data.normalize(indf['y_REV_MA_REG'],min=0,max=90,reverse=True)
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf[d360] = indf[d360]%360
    indf['y_INCLINATION'] = normalize_data.normalize(indf['y_INCLINATION'],min=0,max=180,reverse=True)
    indf['y_INCLINATION'] = indf['y_INCLINATION']%180
    indf['y_ECCENTRICITY'] = normalize_data.normalize(indf['y_ECCENTRICITY'],min=0,max=0.25,reverse=True)
    indf['y_MEAN_MOTION'] = normalize_data.normalize(indf['y_MEAN_MOTION'],min=11.25,max=20,reverse=True)
    indf.name="Predictions"
    indf.columns = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION','RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY','BSTAR']
    return indf

def denormalize_ground_truths(indf):
    indf = indf.copy()
    d360 = ['X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1']
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf['X_INCLINATION_1'] = normalize_data.normalize(indf['X_INCLINATION_1'],min=0,max=180,reverse=True)
    indf['X_ECCENTRICITY_1'] = normalize_data.normalize(indf['X_ECCENTRICITY_1'],min=0,max=0.25,reverse=True)
    indf['X_MEAN_MOTION_1'] = normalize_data.normalize(indf['X_MEAN_MOTION_1'],min=11.25,max=20,reverse=True)
    indf.name="Ground Truths"
    indf.columns = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION','RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY','BSTAR']
    return indf

def get_ground_truths_xyz(df):
    s = ref_test['__GP_ID_2']
    s.name = 'GP_ID'
    df = df.merge(s, left_index=True, right_index=True)
    df = df.merge(sgp4xyz, left_on='GP_ID', right_index=True)
    return df

In [12]:
ground_truths = denormalize_ground_truths(X_test[['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']])
ground_truths = get_ground_truths_xyz(ground_truths).sort_index()
ground_truths = ground_truths[['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION','RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY','BSTAR', 'SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']]
ground_truths = ground_truths.merge(X_test[['X_delta_EPOCH']]*7, left_index=True, right_index=True)
ground_truths.rename(columns={'X_delta_EPOCH':'EPOCH_DIFF'}, inplace=True)
ground_truths = ground_truths[['EPOCH_DIFF', 'INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION', 'RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY', 'BSTAR', 'SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']]
display(ground_truths)

Unnamed: 0,EPOCH_DIFF,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,1.001690209988,96.4160,0.1186894,11.98433914,248.8476,0.5577,359.6540,0.000055674,-2507.664802091827,-6654.727034856126,0.819542599571,-0.833354682012,0.297326304232,7.870415717233
1,1.001688860012,96.4167,0.1186916,11.98434376,249.3531,358.4278,1.3246,0.000081452,-2449.221810602869,-6677.845227561918,2.134646089881,-0.846439422876,0.261524056737,7.868982855997
2,1.001686159988,96.4183,0.1187012,11.98435476,249.8605,356.2713,3.0202,0.000145080,-2390.820222783404,-6701.432486914496,1.541446467213,-0.859343019036,0.223806882668,7.866428283328
3,1.001686330012,96.4197,0.1187132,11.98436032,250.3680,354.1374,4.6816,0.000173990,-2332.746267185475,-6725.422764420366,2.159592278920,-0.871099667952,0.186984283302,7.862731221566
4,2.003368019988,96.4201,0.1187358,11.98437786,250.8726,351.9626,6.3943,0.000259430,-2216.480894902044,-6775.023827650037,1.603887588121,-0.892934381598,0.111533315375,7.851826596835
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2069772,7.836440310000,51.3698,0.0650555,13.26524224,304.8717,300.2421,53.5600,0.004909800,935.444524536351,-7092.823953677166,-0.007327737789,4.764777746522,0.372308807175,5.979705127039
2069773,7.309026610000,51.3698,0.0647227,13.26532258,302.5002,302.3334,51.6458,0.006155600,870.175916853176,-7099.599389502044,0.002198767252,4.768568926974,0.331961013595,5.981043018447
2069774,7.082984530000,51.3698,0.0645659,13.26511170,301.4465,303.2644,50.7964,-0.002664800,837.530598966206,-7102.759586249131,0.007481255699,4.770349597258,0.311709105725,5.981707024254
2069775,6.932372470000,51.3672,0.0640280,13.26510920,298.0233,306.3596,47.9818,-0.000513510,477.790546140828,-7128.199486736229,-0.000259202627,4.783847045830,0.088997394621,5.988580509084


In [13]:
predictions = denormalize_predictions(pred_df)
predictions = predictions.merge(ref_test[['__EPOCH_2']], left_index=True, right_index=True)
predictions_sgp4 = predictions.progress_apply(lambda x:get_satrec_erv(bst=x.BSTAR,
                                                                      ecc=x.ECCENTRICITY,
                                                                      aop=x.ARG_OF_PERICENTER,
                                                                      inc=x.INCLINATION,
                                                                      mea=x.MEAN_ANOMALY,
                                                                      mem=x.MEAN_MOTION,
                                                                      raa=x.RA_OF_ASC_NODE,
                                                                      epoch1=x.__EPOCH_2,
                                                                      epoch2=x.__EPOCH_2,
                                                                     ), axis=1)
predictions_sgp4.columns = ["SAT_E","SAT_RX","SAT_RY","SAT_RZ","SAT_VX","SAT_VY","SAT_VZ"]
predictions = predictions.merge(predictions_sgp4, left_index=True, right_index=True)
predictions = predictions.merge(X_test[['X_delta_EPOCH']]*7, left_index=True, right_index=True)
predictions.rename(columns={'X_delta_EPOCH':'EPOCH_DIFF'}, inplace=True)
predictions = predictions[['EPOCH_DIFF', 'INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION', 'RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY', 'BSTAR', 'SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']]
display(predictions)

  0%|          | 0/2069777 [00:00<?, ?it/s]

Unnamed: 0,EPOCH_DIFF,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,1.001690209988,96.423591613770,0.118692599237,11.985241398215,246.764831542969,1.012334465981,1.41015625000,-0.000061972532,-2836.622314962749,-6512.050123920332,333.001461484148,-0.672829354442,0.675586431786,7.862899722977
1,1.001688860012,96.424285888672,0.118694797158,11.985246026888,247.249267578125,354.499145507812,3.07763671875,-0.000036701167,-2727.951126777194,-6565.619786579169,-207.435386300792,-0.925141728882,0.082299167369,7.863890196031
2,1.001686159988,96.425888061523,0.118704393506,11.985257044435,247.754440307617,351.930480957031,4.76513671875,0.000026964088,-2664.920598130678,-6592.302790997095,-258.450533117073,-0.958895542975,-0.007920225275,7.857804874539
3,1.001686330012,96.427284240723,0.118716396391,11.985262585804,248.265884399414,349.238494873047,6.42919921875,0.000055864773,-2598.812886241682,-6619.254941196235,-328.638067217469,-0.998409775422,-0.118177292428,7.848515445676
4,2.003368019988,96.426773071289,0.118735760450,11.985279861838,248.657424926758,346.298919677734,9.77929687500,0.000141328987,-2576.375297871884,-6644.634431186554,-164.871140016340,-0.952745585357,-0.011458926278,7.845260879941
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2069772,7.836440310000,51.371189117432,0.065036632121,13.266142737120,278.438690185547,321.141693115234,36.48437500000,0.004791674670,1224.506459272894,-7047.173421781984,221.452515433560,4.724320252092,0.698149492021,5.979651153581
2069773,7.309026610000,51.371669769287,0.064705535769,13.266223184764,276.785003662109,321.209289550781,36.12500000000,0.006037420593,993.244587024602,-7084.342454494744,186.744992892940,4.747556227135,0.518482801961,5.979675284744
2069774,7.082984530000,51.371871948242,0.064549461007,13.266012221575,275.712799072266,321.082305908203,35.26562500000,-0.002783076139,775.771199269524,-7109.768343587401,80.290766280163,4.771267545014,0.291400376039,5.980994290170
2069775,6.932372470000,51.369407653809,0.064012050629,13.266009744257,272.290069580078,324.178161621094,33.21093750000,-0.000631814066,409.841153905260,-7131.154224020785,155.775034592085,4.781746552674,0.127442743389,5.988530099087


In [14]:
baseline = denormalize_ground_truths(X_test[['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']])
baseline = baseline.merge(ref_test, left_index=True, right_index=True)
baseline_sgp4 = baseline.progress_apply(lambda x:get_satrec_erv(bst=x.BSTAR,
                                                                ecc=x.ECCENTRICITY,
                                                                aop=x.ARG_OF_PERICENTER,
                                                                inc=x.INCLINATION,
                                                                mea=x.MEAN_ANOMALY,
                                                                mem=x.MEAN_MOTION,
                                                                raa=x.RA_OF_ASC_NODE,
                                                                epoch1=x.__EPOCH_1,
                                                                epoch2=x.__EPOCH_2,), axis=1)
baseline_sgp4.columns = ["SAT_E","SAT_RX","SAT_RY","SAT_RZ","SAT_VX","SAT_VY","SAT_VZ"]
baseline = baseline.merge(baseline_sgp4, left_index=True, right_index=True)
baseline = baseline.merge(X_test[['X_delta_EPOCH']]*7, left_index=True, right_index=True)
baseline.rename(columns={'X_delta_EPOCH':'EPOCH_DIFF'}, inplace=True)
baseline = baseline[['EPOCH_DIFF', 'INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION', 'RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY', 'BSTAR', 'SAT_RX', 'SAT_RY', 'SAT_RZ', 'SAT_VX', 'SAT_VY', 'SAT_VZ']]
display(baseline)

  0%|          | 0/2069777 [00:00<?, ?it/s]

Unnamed: 0,EPOCH_DIFF,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,1.001690209988,96.4160,0.1186894,11.98433914,248.8476,0.5577,359.6540,0.000055674,-2507.775862180598,-6654.697957829078,-0.009480443650,-0.833468899994,0.296748040243,7.870414192031
1,1.001688860012,96.4167,0.1186916,11.98434376,249.3531,358.4278,1.3246,0.000081452,-2449.416714848534,-6677.807896751887,-0.003920078561,-0.846738658136,0.260066043474,7.868968908199
2,1.001686159988,96.4183,0.1187012,11.98435476,249.8605,356.2713,3.0202,0.000145080,-2390.955382562042,-6701.445078616598,-0.000971333184,-0.859537539716,0.222841196999,7.866378588928
3,1.001686330012,96.4197,0.1187132,11.98436032,250.3680,354.1374,4.6816,0.000173990,-2332.457243028695,-6725.553611816860,-1.172266416477,-0.871839007710,0.184667485792,7.862677979344
4,2.003368019988,96.4201,0.1187358,11.98437786,250.8726,351.9626,6.3943,0.000259430,-2216.498574319953,-6774.998242545674,-3.454887464347,-0.893826245490,0.107862380389,7.851798530630
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2069772,7.836440310000,51.3698,0.0650555,13.26524224,304.8717,300.2421,53.5600,0.004909800,1219.580485250314,-7050.327806212666,366.119792062162,4.704651960370,0.803207259054,5.971229096850
2069773,7.309026610000,51.3698,0.0647227,13.26532258,302.5002,302.3334,51.6458,0.006155600,1185.347986487763,-7054.395234630433,405.030608668085,4.703741072509,0.816117053116,5.969674078726
2069774,7.082984530000,51.3698,0.0645659,13.26511170,301.4465,303.2644,50.7964,-0.002664800,904.176889587085,-7090.552971389384,90.684880678329,4.766434299039,0.392269130560,5.984447768411
2069775,6.932372470000,51.3672,0.0640280,13.26510920,298.0233,306.3596,47.9818,-0.000513510,620.631477140667,-7114.228171212158,186.323015698103,4.774972399386,0.296161521045,5.989414659116


In [15]:
data_dict = {
    'tle_ground_truth':ground_truths,
    'sgp4_baseline':baseline,
    'model_t5_predictions':predictions,
    'ref':ref_test,
}
prefix = train_config["dataset"]
if prefix == "":
    prefix = "full_test_set"

with open(f'{os.environ["GP_HIST_PATH"]}/../t5_data/{prefix}_compare_output.pkl', 'wb') as handle:
    pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)