In [1]:
#| default_exp inference.core
#| default_cls_lvl 3

In [2]:
#| export
from tsfast.datasets.core import extract_mean_std_from_dls
from tsfast.data.loader import reset_model_state
from tsfast.models.layers import NormalizedModel
import numpy as np
import torch

In [3]:
#| export

class InferenceWrapper:
    """
    A wrapper class to simplify inference with a trained tsfast/fastai Learner
    on NumPy data. Handles normalization and state reset automatically.

    The learner's model is wrapped in a `NormalizedModel` built from the
    mean/std extracted from the learner's DataLoaders, and the recurrent state
    is reset before every call.
    """
    def __init__(self, learner, device='cpu'):
        """
        Initializes the inferencer.

        Args:
            learner: The trained tsfast/fastai Learner object. Must expose
                `.model` and `.dls` attributes.
            device: The device to run the inference on (anything accepted by
                `torch.nn.Module.to`). Defaults to 'cpu'.

        Raises:
            TypeError: If `learner` has no `model` or no `dls` attribute.
            ValueError: If no mean/std can be extracted from `learner.dls`.

        Note:
            `nn.Module.to` moves a module in place, so `learner.model` itself
            ends up on `device` as a side effect.
        """
        if not hasattr(learner, 'model') or not hasattr(learner, 'dls'):
            raise TypeError("Input 'learner' object does not appear to be a valid fastai/tsfast Learner.")

        self.device = device
        self.core_model = learner.model.to(self.device)

        # Extract the normalization statistics used during training.
        mean, std = extract_mean_std_from_dls(learner.dls)
        if mean is None or std is None:
            raise ValueError("Could not extract mean/std from learner's DataLoaders. Ensure normalization was used during training.")

        # Create and store the NormalizedModel that wraps the core model.
        self.norm_model = NormalizedModel(self.core_model, mean, std).to(self.device)
        self.norm_model.eval()  # inference only: evaluation mode

    @staticmethod
    def _to_batched(numpy_data: np.ndarray) -> np.ndarray:
        """Reshape a single sequence to `[1, seq_len, features]`; raise ValueError for unsupported shapes."""
        if numpy_data.ndim == 1:
            # [seq_len] -> [1, seq_len, 1]: treated as a single feature channel.
            return np.expand_dims(numpy_data, axis=(0, -1))
        if numpy_data.ndim == 2:
            # [seq_len, features] -> [1, seq_len, features]
            return np.expand_dims(numpy_data, axis=0)
        if numpy_data.ndim == 3 and numpy_data.shape[0] == 1:
            # Already an explicit batch of one.
            return numpy_data
        raise ValueError(
            "Input data should have 1 dimension [seq_len], 2 dimensions [seq_len, features] "
            f"or 3 dimensions [1, seq_len, features]. Provided shape: {numpy_data.shape}"
        )

    def inference(self, numpy_data: np.ndarray) -> np.ndarray:
        """
        Run the wrapped model on a single sequence and return its output as NumPy.

        Args:
            numpy_data: Array-like of shape `[seq_len]` (treated as one feature),
                `[seq_len, features]`, or `[1, seq_len, features]`. It is
                converted with `np.asarray`, so lists and non-contiguous /
                negatively-strided arrays are accepted.

        Returns:
            The model output with the leading batch axis of size one removed,
            as a float32 NumPy array on the CPU. If the model returns a tuple,
            only its first element is used.

        Raises:
            ValueError: If the input does not have one of the supported shapes.
            RuntimeError: If the model did not produce an output tensor.
        """
        batched = self._to_batched(np.asarray(numpy_data))

        # torch.from_numpy rejects negative strides (e.g. x[::-1]);
        # ascontiguousarray copies only when the array is not already C-contiguous.
        input_tensor = torch.from_numpy(np.ascontiguousarray(batched)).float().to(self.device)

        with torch.no_grad():
            # Start every call from a fresh model state so calls are independent.
            reset_model_state(self.core_model)
            model_output = self.norm_model(input_tensor)

        # Handle tuple outputs: keep only the first element.
        output_tensor = model_output[0] if isinstance(model_output, tuple) else model_output
        if output_tensor is None:
            raise RuntimeError("Model did not return a valid output tensor.")

        return output_tensor.squeeze(0).cpu().numpy()

    def __call__(self, numpy_data: np.ndarray) -> np.ndarray:
        """Allows calling the predictor instance like a function (alias for `inference`)."""
        return self.inference(numpy_data)

In [4]:
from tsfast.datasets.core import create_dls_test
from tsfast.learner.learner import RNNLearner

In [5]:
# Build the test DataLoaders and an RNN learner on them, then wrap it for NumPy inference.
# No training happens in this notebook, so the learner's weights are presumably untrained;
# the cells below only check output shapes. InferenceWrapper raises ValueError if no
# normalization mean/std can be extracted from `dls`.
dls = create_dls_test()
lrn = RNNLearner(dls)
model = InferenceWrapper(lrn)

In [6]:
# 2-D input [seq_len, features]; the assert turns this display cell into an nbdev test.
out_2d = model(np.random.randn(100, 1))
assert out_2d.shape == (100, 1)
out_2d.shape

(100, 1)

In [7]:
# 1-D input [seq_len] is treated as a single feature channel.
out_1d = model(np.random.randn(100))
assert out_1d.shape == (100, 1)
out_1d.shape

(100, 1)

In [8]:
# An explicit batch of one is accepted; the batch axis is squeezed from the output.
out_batched = model(np.random.randn(1, 100, 1))
assert out_batched.shape == (100, 1)
out_batched.shape

(100, 1)

In [9]:
#| include: false
# Export this notebook's `#| export` cells to the library module via nbdev.
import nbdev; nbdev.nbdev_export()