In [4]:
import pandas as pd
import torch
import utils

from constants import *
from models import *

In [5]:
# Configuration: which orbit file to load (passed to ORBIT_FILE below) and how
# many consecutive rows of the orbit form one classification window.
ORBIT_N, WINDOW_SIZE = 42, 10

In [6]:
# Load the selected orbit: only PRED_COLS are read, and rows are indexed by
# the parsed DATE_COL timestamps.
orbit_csv = utils.resolve_path(ORBIT_DIR, ORBIT_FILE(ORBIT_N))
df_orbit = pd.read_csv(
    orbit_csv,
    index_col=DATE_COL,
    parse_dates=True,
    usecols=PRED_COLS,
)

# Normalize features with the project helpers. Their return values are
# discarded, so they presumably modify df_orbit in place — TODO confirm in utils.
# Each entry of COLS_3D is handed to normalize_coupled (presumably a group of
# columns normalized jointly); COLS_SINGLE goes through normalize_decoupled.
for feat_3d in COLS_3D:
    utils.normalize_coupled(df_orbit, feat_3d)
utils.normalize_decoupled(df_orbit, COLS_SINGLE)

In [7]:
# Rebuild the classifier and load its trained weights, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file; if model.pth can ever come from
# an untrusted source, consider weights_only=True (supported by recent PyTorch).
cpu = torch.device("cpu")
weights_path = utils.resolve_path(MODELS_DIR, "model.pth")

model = BaseNet(num_bands=2)
state_dict = torch.load(weights_path, map_location=cpu)
model.load_state_dict(state_dict)

# Switch to evaluation mode; as the last expression, the notebook also
# displays the module's architecture.
model.eval()

BaseNet(
  (flatten): Flatten()
  (linear1): Linear(in_features=6, out_features=16, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=16, out_features=5, bias=True)
)

In [8]:
# Slide a WINDOW_SIZE-row window over the orbit, classify each window, and
# record the predicted class on every timestep (column) the window covers.
# Window i writes into row i % WINDOW_SIZE: the windows covering any single
# column are at most WINDOW_SIZE consecutive values of i, so they land in
# distinct rows and never overwrite each other. Cells that no window wrote
# to stay NaN, meaning "no vote".
MAX_STEPS = 24  # debug limit on timesteps to classify; set to None for the full orbit

# Clamp to the data so every window is a full WINDOW_SIZE rows long. If the
# orbit is shorter than WINDOW_SIZE, no window runs and x stays all-NaN.
n_steps = len(df_orbit) if MAX_STEPS is None else min(MAX_STEPS, len(df_orbit))

x = torch.full((WINDOW_SIZE, n_steps), float("nan"))

with torch.no_grad():
    for i in range(n_steps - WINDOW_SIZE + 1):
        window = df_orbit.iloc[i:(i + WINDOW_SIZE)]
        feats_3d = [torch.tensor(window[feat].values, dtype=torch.float) for feat in COLS_3D]
        # Leading dimension of size 1 — presumably the batch axis the model expects.
        sample = torch.stack(feats_3d, dim=1).unsqueeze(0)
        logits = model(sample)
        pred = torch.argmax(logits, dim=1)[0]  # predicted class index for this window
        x[i % WINDOW_SIZE, i:(i + WINDOW_SIZE)] = pred

In [9]:
print(x)

# Majority vote per timestep across the overlapping windows (the rows of x).
# torch.mode has no NaN-aware variant and does not document how it treats NaN,
# so tally the votes explicitly instead. NaN placeholders never compare equal
# to a class index and are therefore ignored. Ties resolve to the lowest class
# index (torch.argmax returns the first maximum). Timesteps with no votes at
# all stay NaN.
has_vote = ~torch.isnan(x)
n_classes = int(x[has_vote].max().item()) + 1 if has_vote.any() else 1
vote_counts = torch.stack([(x == c).sum(dim=0) for c in range(n_classes)])
x_majority = vote_counts.argmax(dim=0).to(x.dtype)
x_majority[~has_vote.any(dim=0)] = float("nan")
print(x_majority)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan, nan, nan, nan],
        [nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan, nan, nan],
        [nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan, nan],
        [nan, nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan],
        [nan, nan, nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [nan, nan, nan, nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., n