In [None]:
import torch
from chronos import ChronosPipeline

In [None]:
def _series_dict_to_tensors(series_dict):
    """Convert every series in ``series_dict`` to a float32 torch tensor, preserving keys."""
    # Each value must expose a callable ``values()`` returning array data
    # (e.g. a darts TimeSeries) — presumably shaped (time, components); TODO confirm.
    return {
        key: torch.tensor(ts.values(), dtype=torch.float32)
        for key, ts in series_dict.items()
    }


def convert_to_tensors(train_series_dict, valid_series_dict):
    """Convert train and validation series dicts to dicts of float32 tensors.

    Args:
        train_series_dict: mapping key -> series object with a ``values()`` method.
        valid_series_dict: mapping key -> series object with a ``values()`` method.

    Returns:
        Tuple ``(train_series_tensors, valid_series_tensors)``. Each is a dict with
        the same keys as its input, holding ``torch.float32`` tensors built from
        ``ts.values()``.
    """
    # Train is converted before validation, as in the original two-loop version.
    train_series_tensors = _series_dict_to_tensors(train_series_dict)
    valid_series_tensors = _series_dict_to_tensors(valid_series_dict)
    return train_series_tensors, valid_series_tensors


In [None]:
def chronos_predict(
    train_series_dict,
    valid_series_dict,
    model_name="amazon/chronos-t5-base",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    num_samples=500,
    seed=None,
):
    """Produce one-step-ahead Chronos forecasts for each training series and evaluate them.

    Args:
        train_series_dict: mapping key -> series object with a ``values()`` method;
            used as forecasting context.
        valid_series_dict: mapping key -> series object with a ``values()`` method;
            passed (as tensors) to ``evaluate_chronos``.
        model_name: Hugging Face model id passed to ``ChronosPipeline.from_pretrained``.
        device_map: device for the model ("cuda", "cpu", or "mps" for Apple Silicon).
        torch_dtype: dtype the model weights are loaded in.
        num_samples: number of sample paths drawn per series.
        seed: if not None, ``torch.manual_seed(seed)`` is called before forecasting
            so the sampling can be repeated.

    Returns:
        Tuple ``(forecasts, valid_series_tensors)``. ``forecasts`` maps each training
        key to the tensor returned by ``pipeline.predict``.
        ``valid_series_tensors`` is the converted validation dict.
    """
    train_series_tensors, valid_series_tensors = convert_to_tensors(train_series_dict, valid_series_dict)

    # NOTE: the model is (re)loaded on every call, which is the expensive part.
    pipeline = ChronosPipeline.from_pretrained(
        model_name,
        device_map=device_map,
        torch_dtype=torch_dtype,
    )

    if seed is not None:
        torch.manual_seed(seed)

    forecasts = {}
    for key, context in train_series_tensors.items():
        # Chronos forecasts a univariate context. For a (time, components) tensor,
        # keep only the first component (assumed to be the target, e.g. 'crp' —
        # TODO confirm). A 1-D context is used as-is.
        if context.ndim > 1:
            context = context[:, 0]

        # prediction_length stays 1: evaluate_chronos scores each forecast against a
        # single true value, so longer horizons would not be evaluated meaningfully.
        forecasts[key] = pipeline.predict(
            context=context,
            prediction_length=1,
            num_samples=num_samples,
        )

    evaluate_chronos(forecasts, valid_series_tensors)
    return forecasts, valid_series_tensors



In [None]:
def evaluate_chronos(forecasts, valid_series_tensors, component=0):
    """Score one-step forecasts against the last value of each validation series.

    For each key in ``forecasts``, the mean over the sample axis (dim 1) is compared
    with the final entry of ``valid_series_tensors[key]``. Per-series MAE, MSE and
    SMAPE are averaged over all series. RMSE is the square root of that averaged
    MSE. MAE, MSE and RMSE are printed.

    Args:
        forecasts: dict key -> forecast tensor. Assumed to follow the Chronos
            ``predict`` layout (batch, num_samples, prediction_length) — TODO confirm.
        valid_series_tensors: dict with (at least) the same keys. Values are 1-D
            (time,) or 2-D (time, components) tensors.
        component: column of a 2-D validation tensor to score against. Defaults to 0,
            matching the first-component context slice in ``chronos_predict``.

    Returns:
        dict with keys "mae", "mse", "rmse" and "smape" (SMAPE in percent).

    Raises:
        ValueError: if ``forecasts`` is empty.
        KeyError: if a forecast key is missing from ``valid_series_tensors``.
    """
    if not forecasts:
        raise ValueError("evaluate_chronos: no forecasts to evaluate")

    total_mae = 0.0
    total_mse = 0.0
    total_smape = 0.0

    for key, forecast in forecasts.items():
        true_values = valid_series_tensors[key]
        # Score only the target column. Otherwise the univariate forecast would be
        # broadcast against every component of a multivariate validation series.
        if true_values.ndim > 1:
            true_values = true_values[:, component]
        # NOTE(review): the one-step forecast is compared with the LAST validation
        # value. If the validation series directly continues the training series,
        # the matching target would be true_values[0] — confirm how the split is made.
        last_value = true_values[-1]

        # Point forecast: mean over the sample axis (dim 1).
        forecast_mean = forecast.mean(dim=1)
        absolute_errors = torch.abs(forecast_mean - last_value)

        total_mae += absolute_errors.mean().item()
        total_mse += (absolute_errors ** 2).mean().item()

        # SMAPE. The denominator is 0 only when forecast and truth are both 0
        # (a perfect forecast), so that ratio is defined as 0 instead of NaN.
        denominator = (torch.abs(last_value) + torch.abs(forecast_mean)) / 2
        ratio = torch.where(
            denominator != 0,
            absolute_errors / denominator,
            torch.zeros_like(absolute_errors),
        )
        total_smape += ratio.mean().item() * 100

    num_series = len(forecasts)
    overall_mae = total_mae / num_series
    overall_mse = total_mse / num_series
    # Python-float sqrt; avoids the float32 round-trip of torch.tensor(float).
    overall_rmse = overall_mse ** 0.5
    overall_smape = total_smape / num_series

    print(f'MAE: {overall_mae}')
    print(f'MSE: {overall_mse}')
    print(f'RMSE: {overall_rmse}')

    return {
        "mae": overall_mae,
        "mse": overall_mse,
        "rmse": overall_rmse,
        "smape": overall_smape,
    }
    