In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from dataset import Dataset, to_device
from model import NNSingleFeatureModel, tims_mse_loss
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import pickle
import normalize_data
import random
from time import time
from tqdm.notebook import tqdm

pd.set_option('display.max_columns', 999)
pd.set_option('display.precision', 12)

In [2]:
raw_data = {} # loads raw data and stores as a dict cache

def dataset_key(dataset='', validation=False):
    return dataset+('test' if validation else 'train')

def load_data(raw, dataset='', validation=False):
    key = dataset+('test' if validation else 'train')
    if key not in raw:
        print(f"Loading data to cache for: {key}")
        raw[key] = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../t5_data/{key}.pkl')
    return raw[key]

In [3]:
def load_sub_model_with_config(train_config, model_configs, sub_model_key, X_count=0, force_recreate=False):
    path = train_config['model_path']
    prefix = train_config['model_prefix']
    model_config = model_configs[sub_model_key]
    f = f"{path}/{prefix}{sub_model_key}.pth"
    if os.path.exists(f) and not force_recreate:
        print("Loading existing model")
        checkpoint = torch.load(f)
        net = checkpoint['net']
        loss_func = checkpoint['loss_func']
        optimizer = checkpoint['optimizer']
        mean_losses = checkpoint['mean_losses']
        next_epoch = checkpoint['next_epoch']
    else:
        raise Exception('Model does not exist')
    return net, loss_func, optimizer, mean_losses, next_epoch

In [4]:
def predict(model, X, y, device='cpu'):
    pyt_device = torch.device(device)
    model.eval()
    X_tensor = torch.from_numpy(X.to_numpy()).float()
    nn_results = model(X_tensor).detach().numpy()
    return nn_results

In [5]:
def get_ref_X_y(df):
    ref_cols = [c for c in df.columns if c.startswith('__')]
    X_cols = [c for c in df.columns if c.startswith('X_')]
    y_cols = [c for c in df.columns if c.startswith('y_')]
    return (df[ref_cols], df[X_cols], df[y_cols])

In [6]:
train_config = {
    'dataset' : 'sample_', # '', 'sample_', 'secret_'
    'model_prefix' : "TRY_2_", 
    'model_path' : f"{os.environ['GP_HIST_PATH']}/../t5_models",
    'device' : 'cpu',
}

# sample_ uses train dataset, get SGP4 from that
sgp4xyz = pd.read_pickle(f'{os.environ["GP_HIST_PATH"]}/../3_min/train_sgp4rv.pkl')

In [7]:
%%time
test_df = normalize_data.normalize_all_columns(load_data(raw_data,dataset=train_config['dataset'],validation=True)).dropna()
ref_test, X_test, y_test = get_ref_X_y(test_df)
y_cols = ['y_INCLINATION', 'y_ECCENTRICITY', 'y_MEAN_MOTION', 'y_RA_OF_ASC_NODE_REG', 'y_ARG_OF_PERICENTER_REG', 'y_REV_MA_REG', 'y_BSTAR']
y_test = y_test[y_cols]

Loading data to cache for: sample_test
CPU times: user 410 ms, sys: 383 ms, total: 793 ms
Wall time: 647 ms


In [8]:
model_configs = {
    'y_INCLINATION': { 'feature_index': X_test.columns.get_loc('X_INCLINATION_1') },
    'y_ECCENTRICITY': { 'feature_index': X_test.columns.get_loc('X_ECCENTRICITY_1') },
    'y_MEAN_MOTION': { 'feature_index': X_test.columns.get_loc('X_MEAN_MOTION_1') },
    'y_RA_OF_ASC_NODE_REG': { 'feature_index': X_test.columns.get_loc('X_RA_OF_ASC_NODE_1') },
    'y_ARG_OF_PERICENTER_REG': { 'feature_index': X_test.columns.get_loc('X_ARG_OF_PERICENTER_1') },
    'y_REV_MA_REG': { 'feature_index': X_test.columns.get_loc('X_MEAN_ANOMALY_1') },
    'y_BSTAR': { 'feature_index': X_test.columns.get_loc('X_BSTAR_1') },
}

In [9]:
# Create or load all new sub models here if needed.
all_models = {}
pred_data = []
X_sample = X_test
y_sample = y_test[y_cols]
for sub_key in y_cols:
    # When new models are created, a dummy optimizer is used
    model, _, _, _, _ = load_sub_model_with_config(train_config, model_configs, sub_key)
    all_models[sub_key] = model
    y_sample_pred = predict(model, X_sample, y_sample, device="cpu") # get predictions for each train
    y_sample_pred_df = pd.DataFrame(y_sample_pred, columns=[sub_key], index=y_sample.index)  # put results into a dataframe
    pred_data.append(y_sample_pred_df)
pred_df = pd.concat(pred_data, axis=1)

Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model
Loading existing model


In [22]:
def denormalize_predictions(indf):
    indf = indf.copy()
    d360 = ['y_ECCENTRICITY','y_ARG_OF_PERICENTER_REG','y_RA_OF_ASC_NODE_REG','y_REV_MA_REG']
    indf['y_REV_MA_REG'] = normalize_data.normalize(indf['y_REV_MA_REG'],min=0,max=90,reverse=True)
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf[d360] = indf[d360]%360
    indf['y_INCLINATION'] = normalize_data.normalize(indf['y_INCLINATION'],min=0,max=180,reverse=True)
    indf['y_INCLINATION'] = indf['y_INCLINATION']%180
    indf['y_ECCENTRICITY'] = normalize_data.normalize(indf['y_ECCENTRICITY'],min=0,max=0.25,reverse=True)
    indf['y_MEAN_MOTION'] = normalize_data.normalize(indf['y_MEAN_MOTION'],min=11.25,max=20,reverse=True)
    indf.name="Predictions"
    indf.columns = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION','RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY','BSTAR']
    return indf

def denormalize_ground_truths(indf):
    indf = indf.copy()
    d360 = ['X_ECCENTRICITY_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1']
    indf[d360] = normalize_data.normalize(indf[d360],min=0,max=360,reverse=True)
    indf['X_INCLINATION_1'] = normalize_data.normalize(indf['X_INCLINATION_1'],min=0,max=180,reverse=True)
    indf['X_ECCENTRICITY_1'] = normalize_data.normalize(indf['X_ECCENTRICITY_1'],min=0,max=0.25,reverse=True)
    indf['X_MEAN_MOTION_1'] = normalize_data.normalize(indf['X_MEAN_MOTION_1'],min=11.25,max=20,reverse=True)
    indf.name="Ground Truths"
    indf.columns = ['INCLINATION', 'ECCENTRICITY', 'MEAN_MOTION','RA_OF_ASC_NODE', 'ARG_OF_PERICENTER', 'MEAN_ANOMALY','BSTAR']
    return indf

def get_ground_truths_xyz(df):
    s = ref_test['__GP_ID_2']
    s.name = 'GP_ID'
    df = df.merge(s, left_index=True, right_index=True)
    df = df.merge(sgp4xyz, left_on='GP_ID', right_index=True)
    return df

In [26]:
ground_truths = denormalize_ground_truths(X_test[['X_INCLINATION_1', 'X_ECCENTRICITY_1', 'X_MEAN_MOTION_1','X_RA_OF_ASC_NODE_1', 'X_ARG_OF_PERICENTER_1', 'X_MEAN_ANOMALY_1','X_BSTAR_1']])
ground_truths = get_ground_truths_xyz(ground_truths).sort_index()
display(ground_truths)

Unnamed: 0,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR,GP_ID,SAT_E,SAT_RX,SAT_RY,SAT_RZ,SAT_VX,SAT_VY,SAT_VZ
0,82.9566,1.451628,13.75958369,108.1802,1.1757,358.9487,0.00156370,46884464.0,0,-2247.603194235576,6976.489740265147,-0.010024535399,-0.859817553482,-0.284764916051,7.335147277747
1,82.9567,1.452384,13.75959654,107.8573,359.9304,0.1839,0.00156130,46884465.0,0,-2102.945801140380,7021.610715137136,-0.249484658555,-0.866219852855,-0.264526097323,7.334987886237
2,82.9567,1.448676,13.75963698,106.6730,355.5472,4.5301,0.00152510,46884466.0,0,-2017.158478116495,7046.930142214801,-0.021140294377,-0.869633200356,-0.252884757147,7.334800602015
3,82.9569,1.445616,13.75965806,105.9736,352.9911,7.0677,0.00149470,18895023.0,0,-1977.405754763013,7058.213969611379,-0.030735170327,-0.871246764901,-0.247293549558,7.334772939990
4,82.9568,1.448892,13.75966699,105.6506,351.6333,8.4145,0.00147820,46884467.0,0,-1924.326182443856,7073.082960036271,-0.307548697930,-0.873307395759,-0.239759332089,7.334567451932
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
335637,73.9219,12.169872,13.70709313,141.2754,126.1524,237.1379,0.00083759,166627814.0,0,-4005.029582473137,6263.597321480315,0.001428493234,-1.567743225900,-1.294414134696,7.006534060415
335638,73.9228,12.164976,13.70714688,137.9996,122.4672,240.9594,0.00092591,166792302.0,0,-3632.923410636105,6468.449550943837,-0.001317372866,-1.641186266291,-1.207826825990,7.021405970211
335639,73.9219,12.169872,13.70709313,141.2754,126.1524,237.1379,0.00083759,166792302.0,0,-3632.923410636105,6468.449550943837,-0.001317372866,-1.641186266291,-1.207826825990,7.021405970211
335640,73.9243,11.934972,13.70824688,55.0334,28.6057,333.2808,0.00089765,169801979.0,0,5320.203614858912,4759.098905177199,-0.004350172217,-1.449433720895,1.523145072713,7.297762542087


In [27]:
predictions = denormalize_predictions(pred_df)
display(predictions)

Unnamed: 0,INCLINATION,ECCENTRICITY,MEAN_MOTION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,BSTAR
0,82.964691162109,1.453434467316,13.760486207902,105.701118469238,351.519866943359,0.235107421875,0.001445628819
1,82.963737487793,1.452838420868,13.760498724878,105.254943847656,349.225891113281,4.694824218750,0.001442658831
2,82.964332580566,1.449891090393,13.760539144278,104.161117553711,345.425445556641,7.236816406250,0.001406512805
3,82.964988708496,1.447422385216,13.760560527444,103.502380371094,343.314636230469,8.359619140625,0.001376181724
4,82.964752197266,1.450529575348,13.760569393635,103.175910949707,341.830963134766,10.108154296875,0.001359240268
...,...,...,...,...,...,...,...
335637,73.920211791992,12.159126281738,13.707993403077,119.909706115723,106.867599487305,257.855468750000,0.000719451753
335638,73.921104431152,12.154232025146,13.708047121763,116.621109008789,103.181854248047,261.519531250000,0.000807795965
335639,73.918426513672,12.156836509705,13.707993142307,116.191574096680,105.110343933105,261.253906250000,0.000719053147
335640,73.925582885742,11.928043365479,13.709147572517,39.803482055664,12.246371269226,347.031250000000,0.000779947615


In [29]:
ref_test

Unnamed: 0,__NORAD_CAT_ID_1,__GP_ID_1,__GP_ID_2,__EPOCH_1,__EPOCH_2
0,21912,46884463,46884464.0,1992-01-16 07:04:00.872256,1992-01-16 17:32:17.511936
1,21912,46884464,46884465.0,1992-01-16 17:32:17.511936,1992-01-18 07:55:58.423584
2,21912,46884465,46884466.0,1992-01-18 07:55:58.423584,1992-01-19 06:37:14.335103
3,21912,46884466,18895023.0,1992-01-19 06:37:14.335103,1992-01-19 17:05:30.866783
4,21912,18895023,46884467.0,1992-01-19 17:05:30.866783,1992-01-20 07:03:12.825215
...,...,...,...,...,...
335637,38056,165853868,166627814.0,2020-11-19 04:29:44.826432,2020-11-30 10:14:15.340416
335638,38056,165990888,166792302.0,2020-11-21 03:47:20.066496,2020-12-02 09:31:44.091840
335639,38056,165853868,166792302.0,2020-11-19 04:29:44.826432,2020-12-02 09:31:44.091840
335640,38056,169258086,169801979.0,2021-01-10 01:47:56.188896,2021-01-18 00:41:58.076160
