In [None]:
# imports
from pathlib import Path

import torch
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from src.data.dataset import RawTrainingDataset, TrainingDataset, TrainingMode, RefinementDataset, RefinementCollator
from src.model.loss import ODRMSELoss
from src.model.refinement import RefinementTransformer
from src.model.residual import TimeSeriesDecoderOnlyTransformer
from src.model.scaler import ZTransform

In [None]:
# Prefer the first GPU, but fall back to the CPU when none is visible so that
# "Restart & Run All" still works on a machine without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DATA_PATH = Path('../data/training_ds/residual_model_ds.npz')

In [None]:
# load models
# Residual decoder-only transformer. The positional hyper-parameters must match the
# saved checkpoint. 904 equals the dataset's `n_features` and 432 equals the forecast
# horizon used by the loss cell. The remaining values (112, 4, 4, 4 * 112, 0) are
# presumably model width / heads / layers / feed-forward size / dropout.
# TODO confirm against TimeSeriesDecoderOnlyTransformer's signature.
model = TimeSeriesDecoderOnlyTransformer(
    904,
    432,
    112,
    4,
    4,
    4 * 112,
    0
)
# map_location lets a checkpoint that was saved from GPU tensors load on any DEVICE.
# weights_only=True refuses arbitrary pickled objects; a state_dict only needs tensors.
state_dict = torch.load("../submission/model.pt", map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(DEVICE)
model.compile()
model.eval()

# Z-transform scalers saved next to the model. result_scaler is used later to map model
# outputs back from z-space. feature_scaler is not used in the cells shown here.
feature_scaler = ZTransform.load("../submission/feature_transformer.npz")
result_scaler = ZTransform.load("../submission/result_transformer.npz")

In [None]:
# load data
# Dataset over the full residual-model file, with both noise std parameters set to 0.0.
# NOTE(review): named as a validation set but built with TrainingMode.TRAIN; confirm intended.
val_dataset_m1 = TrainingDataset(
    raw_dataset=RawTrainingDataset.load(DATA_PATH),
    training_mode=TrainingMode.TRAIN,
    device=DEVICE,
    n_features=904,
    input_noise_std=0.0,
    output_noise_std=0.0,
)

# shuffle=False: batches come out in dataset order, so "the first batch" is stable across runs.
train_loader = DataLoader(val_dataset_m1, shuffle=False, batch_size=128)

In [None]:
# get one batch
# Only the first batch is pulled from the loader; the prediction cells work on it.
first_batch = next(iter(train_loader))
(batch_features, batch_target) = first_batch

# Shape sanity check, printed as: <feature shape> <target shape>
print(batch_features.shape, batch_target.shape)

In [None]:
# predict output and calculate loss
# model.eval() does not disable autograd; no_grad avoids building a graph for pure inference.
with torch.no_grad():
    out = model.predict(batch_features)

# Targets gain a trailing singleton axis to match the model output
# (the plotting cell indexes the output as out[i, :, 0]).
target_expanded = batch_target.unsqueeze(-1)

# Forecast horizon: 432 steps of 10-minute (600 s) windows.
n_horizon_steps = 432
step_seconds = 10 * 60
step_times = torch.arange(0, n_horizon_steps, device=DEVICE) * step_seconds

loss = torch.nn.functional.mse_loss(out, target_expanded)
odrmse = ODRMSELoss(n_horizon_steps * step_seconds).forward(out, target_expanded, step_times)
total_loss = loss + odrmse

print(loss.item(), odrmse.item(), total_loss.item())

In [None]:
index = 7
# Overlay one sample's predicted residuals against its ground truth (both in z-space).
fig, ax = plt.subplots()
ax.plot(out[index, :, 0].detach().cpu().numpy(), label="Residual predictions")
ax.plot(batch_target[index].detach().cpu().numpy(), label="Ground truth residuals")
ax.set_xlabel("Time-window ID (10 min. windows)")
ax.set_ylabel("Z-Transformed residual")
ax.legend()
fig.savefig("results-bad.png")
plt.show()

In [None]:
# Reduced variant of the residual-model dataset. For an .npz file np.load returns a lazy
# NpzFile: arrays are read when indexed by key, and the archive stays open meanwhile.
# NOTE(review): `test` is a vague name, presumably not a held-out test split; confirm.
test = np.load("../data/training_ds/residual_model_ds-reduced.npz")

In [None]:
# Horizon weights for the OD-RMSE skill score: exp(ln(eps) * t / T) with eps = 1e-5 and
# T = 432 steps, i.e. 1.0 at the first step decaying to ~1e-5 at the end, so early steps
# dominate. (The exponent previously had the opposite sign, growing the weights to ~1e5.)
odrmse_weights = torch.exp(torch.log(torch.tensor(1e-5)) / 432 * torch.arange(0, 432))

def odrmse_score(scaled_out, scaled_msis, scaled_target):
    """Horizon-weighted skill of a forecast relative to the MSIS baseline.

    For every time step t (the last axis), the RMSE is taken over all leading axes
    (e.g. the batch). The per-step skill 1 - RMSE_model(t) / RMSE_msis(t) is then
    averaged with ``odrmse_weights``. Summing the errors into one scalar first, as
    was done before, made the weighting a no-op.

    Args:
        scaled_out: model predictions, shape (..., 432) and broadcastable with the target.
        scaled_msis: baseline (MSIS) predictions, same layout as ``scaled_out``.
        scaled_target: ground truth, same layout as ``scaled_out``.

    Returns:
        0-dim tensor: 1.0 for a perfect forecast, 0.0 when no better than the baseline.
        A step where the baseline is exact (RMSE_msis == 0) yields inf/nan.
    """
    n_steps = scaled_out.shape[-1]
    # Flatten every leading axis so that 1-D (single sample) and batched inputs both
    # become (rows, n_steps).
    model_err = (scaled_out - scaled_target).reshape(-1, n_steps)
    msis_err = (scaled_msis - scaled_target).reshape(-1, n_steps)
    rmse_test = torch.sqrt(model_err.pow(2).mean(dim=0))
    rmse_msis = torch.sqrt(msis_err.pow(2).mean(dim=0))
    # Match the inputs' device and dtype; the weights are created on the CPU.
    weights = odrmse_weights.to(rmse_test)
    return torch.sum(weights * (1 - rmse_test / rmse_msis)) / weights.sum()

In [None]:
# Model output mapped back from z-space to the original residual scale
# (reverse_z_transform of the saved result scaler).
out_scaled = torch.as_tensor(result_scaler.reverse_z_transform(out.detach().cpu().squeeze().numpy()))

# Baseline NRLM predictions from the full residual dataset (same file as DATA_PATH).
# The context manager closes the .npz archive; indexing already materialises the array.
with np.load(DATA_PATH) as residual_npz:
    nrlm_predictions = residual_npz['nrlm_predictions']

In [None]:
# NaN check on rows 4958:5451 of the reduced dataset's main features.
# Summarise the mask rather than dumping the whole boolean array into the notebook.
nan_rows = slice(4958, 5451)
with np.load("../data/training_ds/residual_model_ds-reduced.npz") as reduced_npz:
    nan_mask = np.isnan(reduced_npz['main_features'][nan_rows])
# Collapse every axis but the first: one flag per row (axis=() leaves a 1-D mask as-is).
row_has_nan = nan_mask.any(axis=tuple(range(1, nan_mask.ndim)))
print(f"{int(nan_mask.sum())} NaN values; {int(row_has_nan.sum())} of {row_has_nan.size} rows affected")
# Absolute row indices (in main_features) that contain at least one NaN.
nan_rows.start + np.flatnonzero(row_has_nan)

In [None]:
index = 7
true_vals = test['y'][index]
nrlm_vals = test['nrlm_predictions'][index]
# Undo the z-transform of y using the statistics stored in the reduced dataset.
scaled_true_vals = true_vals * test['y_std'] + test['y_mean']

# NRLM baseline with and without the de-normalised target added on top.
fig, ax = plt.subplots()
ax.plot(nrlm_vals + scaled_true_vals, label="NRLM prediction + de-normalised target")
ax.plot(nrlm_vals, label="NRLM prediction")
ax.set_title(f"Reduced dataset, sample {index}")
ax.set_xlabel("Time step")
ax.legend()
plt.show()


In [None]:
import math

# Illustration only: one orbit revolution sampled at 1000 points, with the density along
# the orbit drawn as a pure sinusoid around its (zero, z-transformed) average.
n_orbit_samples = 1000
orbit_phase = np.arange(n_orbit_samples) * 2 * math.pi / n_orbit_samples

fig, ax = plt.subplots()
ax.plot(np.sin(orbit_phase), label="Density in Propagated Orbit")
ax.plot(np.zeros(n_orbit_samples, dtype=np.float32), label="Average Density of Orbit")
ax.set_title("Density of propagated orbit vs. average density\nin a \"perfect\" environment")
ax.set_xlabel("Sample points for a full orbit revolution")
ax.set_ylabel("Z-Transformed density")
ax.legend()
fig.savefig("avg-vs-propagated-orbit.png")
plt.show()

In [None]:
import matplotlib.ticker as plticker

# Spectrum of one predicted residual series, with the mean removed so the zero-frequency
# bin does not dominate. Assumes `out` is (batch, steps, 1), as indexed by out[i, :, 0].
example = 24
residual_series = out[example].squeeze()
out_fft = torch.fft.fft(residual_series - residual_series.mean()).detach().cpu().numpy()
n = out_fft.size
# Each sample is a 10-minute (600 s) window. d=0.1 expresses that spacing in units of
# 100 minutes, so the frequencies are in cycles per 6000 s (units of 1/6000 Hz).
fft_freq = np.fft.fftfreq(n, d=0.1)

# Move zero frequency to index n // 2. fftshift keeps every bin, including Nyquist,
# which the manual concatenate used to drop. Then keep the n // bounds bins on each side.
bounds = 20
centre = n // 2
half_width = n // bounds
low_freq_window = slice(centre - half_width, centre + half_width + 1)
out_fft = np.fft.fftshift(out_fft)[low_freq_window]
fft_freq = np.fft.fftshift(fft_freq)[low_freq_window]
# The FFT is complex: plot and search its magnitude. Plotting the complex array directly
# silently keeps only the real part, and argmax on complex values is lexicographic.
fft_magnitude = np.abs(out_fft)

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(fft_freq, fft_magnitude)
ax.xaxis.set_major_locator(plticker.MultipleLocator(0.025))
ax.tick_params(axis="x", labelrotation=60)
ax.grid(True)
ax.set_title("Frequency analysis of residuals")
ax.set_xlabel("Frequency (1/6000 Hz, i.e. cycles per 100 min)")
ax.set_ylabel("Intensity of frequency")
fig.savefig("fft-freq.png")
plt.show()

# Dominant frequency in the window. A real signal's magnitude spectrum is symmetric, so
# report |f| rather than whichever sign argmax happens to reach first.
abs(fft_freq[fft_magnitude.argmax()])