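# Predict with fitted models

Load each stock's preprocessed dataset, run the matching saved Keras model, and write the predicted returns to CSV.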

In [8]:
import os
import warnings
import tensorflow as tf
import numpy as np
import pandas as pd
import h5py

# Suppress all warnings so the cell outputs below stay readable.
warnings.filterwarnings("ignore")

# Input tag (presumably the input window length); it selects both the
# dataset folder and the model-name prefix used by load_model.
INPUT = "120"
DATA_DIR = f"train_{INPUT}"   # folder holding the preprocessed *.hdf5 datasets
OUTPUT_DIR = "pred/train"     # destination for the <name>_pred.csv files
MODEL_DIR = "saved_model"     # folder holding the fitted Keras models

In [9]:
def load_dataset(dir, csv_dir="train_valid"):
    r"""
    Load preprocessed datasets.

    Iterates over every file in `dir` ending in ``hdf5`` (in sorted order, so
    repeated runs visit datasets in the same order), applies a log(1 + x)
    transform to rows 5 and 6 of ``X`` and to ``y``, and pairs ``y`` with the
    trailing timestamps of the matching CSV in `csv_dir`.

    Parameters
    ----------
    dir: str
        Folder containing the ``.hdf5`` datasets (each with ``X`` and ``y``).
    csv_dir: str, optional
        Folder containing ``<file stem>.csv`` files with a ``timestamp``
        column. Defaults to ``"train_valid"``, the previously hard-coded path.

    Returns
    -------
    value: generator
        A generator with elements like (name, X, y, t), where ``name`` is the
        file-name prefix before the first ``_`` and ``t`` holds the last
        ``len(y)`` timestamps (empty when ``y`` is empty).
    """
    for file in sorted(os.listdir(dir)):
        if not file.endswith("hdf5"):
            continue
        name = file.split("_")[0]
        stem = file.split(".")[0]
        # [...] materializes both datasets in memory, so the file can be
        # closed before any further processing.
        with h5py.File(os.path.join(dir, file), "r") as f:
            X, y = f["X"][...], f["y"][...]
        # An integer X would silently truncate the in-place log assignment below.
        if not np.issubdtype(X.dtype, np.floating):
            X = X.astype(np.float64)
        # NOTE(review): this indexes axis 0, which is the batch axis passed to
        # model.predict, so it transforms samples 5 and 6 rather than feature
        # columns. If features were intended this should probably be
        # X[..., 5] / X[..., 6]. Kept as-is so inputs match whatever
        # preprocessing the fitted models were trained with -- TODO confirm.
        X[5, :] = np.log1p(X[5, :])
        X[6, :] = np.log1p(X[6, :])
        y = np.log1p(y)
        t = pd.read_csv(os.path.join(csv_dir, f"{stem}.csv"))["timestamp"]
        # Keep the last len(y) timestamps so t lines up with y. A plain
        # t[-len(y):] returns the WHOLE series when len(y) == 0 (since -0 == 0).
        t = t.iloc[max(len(t) - len(y), 0):]
        yield (name, X, y, t)

def load_model(dir, name, more):
    r"""
    Load fitted model.

    Parameters
    ----------
    dir: str
        Folder containing the saved models.
    name: str
        Name of fitted model.
    more: str
        More flags to find the model; a matching entry in `dir` must start
        with ``f"{name}_{more}"``.

    Returns
    -------
    model: tf.keras.Model
        The first matching model. Entries are scanned in sorted order so the
        choice is deterministic when several entries match.

    Raises
    ------
    FileNotFoundError
        If no entry in `dir` starts with ``f"{name}_{more}"``. (Returning
        ``None`` here would only fail later, at ``model.predict``.)
    """
    prefix = f"{name}_{more}"
    for entry in sorted(os.listdir(dir)):
        if entry.startswith(prefix):
            return tf.keras.models.load_model(os.path.join(dir, entry))
    raise FileNotFoundError(f"No model starting with {prefix!r} in {dir!r}")

In [10]:
# Predict every dataset with its matching model and undo the log(y + 1)
# target transform applied in load_dataset.
res = dict()
for name, X, y, t in load_dataset(DATA_DIR):
    print(f"Testing {name}...")
    model = load_model(MODEL_DIR, name, INPUT)
    # squeeze drops singleton axes (presumably an (n, 1) output -- confirm);
    # atleast_1d keeps a one-sample batch from collapsing to a 0-d scalar.
    y_pred = np.atleast_1d(model.predict(X).squeeze())
    # expm1 is the numerically accurate inverse of log(y + 1).
    y_pred = np.expm1(y_pred)
    assert len(y_pred) == len(t), f"{name}: {len(y_pred)} predictions vs {len(t)} timestamps"
    res[name] = pd.DataFrame({
        # t keeps its original CSV row labels; drop them so rows align by position.
        "timestamp": t.reset_index(drop=True),
        "Predict_return": y_pred,
    })

Testing AR...
Testing BAH...
Testing FTI...
Testing HII...
Testing LMT...
Testing MLI...
Testing NFE...
Testing NOC...
Testing PBR...
Testing STLD...


In [11]:
# Write one CSV per dataset. Create the output folder first so a fresh
# checkout (where pred/train does not exist yet) does not fail in to_csv.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for name, pred in res.items():
    pred.to_csv(os.path.join(OUTPUT_DIR, f"{name}_pred.csv"), index=False)
print("All saved.")

All saved.
