In [35]:
import os
import torch
import numpy as np
import pandas as pd
from model import LSTMRegressor
from config import p
from config_predict import last_date, checkpoint_path, output_csv, \
prediction_days
# from config_predict import input_path as input_csv
input_path="../data/transformed/influenza_features.parquet"

In [36]:
def load_model(checkpoint_path: str, config: dict):
    """
    Load the trained model from a checkpoint file.
    Args:
        checkpoint_path (str): Path to the saved model checkpoint.
        config (dict): Model configuration parameters.
    Returns:
        torch.nn.Module: Loaded LSTM model.
    """
    model = LSTMRegressor(
        n_features=config["n_features"],
        hidden_size=config["hidden_size"],
        num_layers=config["num_layers"],
        dropout=config["dropout"],
        learning_rate=config["learning_rate"],
        criterion=config["criterion"],
        output_size=config["output_size"],
    )
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # Set the model to evaluation mode
    return model

In [37]:
print("Loading model...")
model = load_model(checkpoint_path, p)

Loading model...


  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))


In [38]:
print("Loading and preprocessing input data...")
raw_data = pd.read_parquet(input_path, columns=["log_cases_14d_moving_avg", "cases_14d_moving_avg", "diff_log_14d"])
raw_data = raw_data.set_index(pd.to_datetime(raw_data.index))
raw_data = raw_data.sort_index()

Loading and preprocessing input data...


In [43]:
initial_input_data = raw_data.iloc[-p["seq_len"]:]
initial_input_data

Unnamed: 0,log_cases_14d_moving_avg,cases_14d_moving_avg,diff_log_14d
1970-01-01 00:00:00.000002616,3.840987,46.571429,0.02642
1970-01-01 00:00:00.000002617,3.840987,46.571429,0.0
1970-01-01 00:00:00.000002618,4.004732,54.857143,0.163745
1970-01-01 00:00:00.000002619,4.135167,62.5,0.130434
1970-01-01 00:00:00.000002620,4.255613,70.5,0.120446
1970-01-01 00:00:00.000002621,4.351199,77.571429,0.095586
1970-01-01 00:00:00.000002622,4.465087,86.928571,0.113888
1970-01-01 00:00:00.000002623,0.0,0.0,-4.465087
1970-01-01 00:00:00.000002624,0.0,0.0,0.0
1970-01-01 00:00:00.000002625,0.0,0.0,0.0


In [42]:
data = initial_input_data.values
data

array([[ 3.84098723e+00,  4.65714286e+01,  2.64195630e-02],
       [ 3.84098723e+00,  4.65714286e+01,  0.00000000e+00],
       [ 4.00473240e+00,  5.48571429e+01,  1.63745171e-01],
       [ 4.13516656e+00,  6.25000000e+01,  1.30434153e-01],
       [ 4.25561271e+00,  7.05000000e+01,  1.20446153e-01],
       [ 4.35119917e+00,  7.75714286e+01,  9.55864611e-02],
       [ 4.46508676e+00,  8.69285714e+01,  1.13887592e-01],
       [ 0.00000000e+00,  0.00000000e+00, -4.46508676e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]])

In [48]:
seq_len = p["seq_len"]
initial_input = torch.tensor(np.array(data), dtype=torch.float32)
initial_input

tensor([[ 3.8410e+00,  4.6571e+01,  2.6420e-02],
        [ 3.8410e+00,  4.6571e+01,  0.0000e+00],
        [ 4.0047e+00,  5.4857e+01,  1.6375e-01],
        [ 4.1352e+00,  6.2500e+01,  1.3043e-01],
        [ 4.2556e+00,  7.0500e+01,  1.2045e-01],
        [ 4.3512e+00,  7.7571e+01,  9.5586e-02],
        [ 4.4651e+00,  8.6929e+01,  1.1389e-01],
        [ 0.0000e+00,  0.0000e+00, -4.4651e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00]])

In [50]:
with torch.no_grad():
    pred, _ = model(initial_input)

RuntimeError: For unbatched 2-D input, hx and cx should also be 2-D but got (3-D, 3-D) tensors