In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation of Synthetic Time Series Forecasts\n",
    "\n",
    "This notebook evaluates the quality of synthetic time series forecasts generated by Chronos-T5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import required libraries\n",
    "import os\n",
    "import sys\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import json\n",
    "import yaml\n",
    "import torch\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error\n",
    "\n",
    "# Add the project root directory to path for importing project modules\n",
    "sys.path.append('..')\n",
    "from src.evaluation import TimeSeriesEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load configuration\n",
    "# The path is relative to the notebook's working directory.\n",
    "# yaml.safe_load is used so that only plain YAML data types are constructed.\n",
    "with open('../config/config.yaml', 'r') as file:\n",
    "    config = yaml.safe_load(file)\n",
    "\n",
    "# Display configuration\n",
    "# Dumped back as block-style YAML with the keys in their loaded order (sort_keys=False)\n",
    "print(\"Project Configuration:\")\n",
    "print(yaml.dump(config, sort_keys=False, default_flow_style=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Evaluation Results\n",
    "\n",
    "Let's load the evaluation results generated by the main process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Initialize the evaluator\n",
    "# It is given the path of the same YAML file that was loaded into `config`\n",
    "evaluator = TimeSeriesEvaluator('../config/config.yaml')\n",
    "\n",
    "# Evaluation results file path\n",
    "# The JSON file is looked up inside the evaluator's results directory\n",
    "results_file = os.path.join(evaluator.results_path, 'evaluation_results.json')\n",
    "\n",
    "# Load evaluation results if available\n",
    "# When the file is missing, `evaluation_results` is set to None so that\n",
    "# later cells can test it before reading any metrics\n",
    "if os.path.exists(results_file):\n",
    "    with open(results_file, 'r') as f:\n",
    "        evaluation_results = json.load(f)\n",
    "    print(\"Evaluation results loaded successfully.\")\n",
    "else:\n",
    "    print(f\"Evaluation results file not found at {results_file}\")\n",
    "    evaluation_results = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Display average metrics if available\n",
    "if evaluation_results and 'average_metrics' in evaluation_results:\n",
    "    avg_metrics = evaluation_results['average_metrics']\n",
    "    print(\"Average Evaluation Metrics:\")\n",
    "    for metric, value in avg_metrics.items():\n",
    "        # Values come straight from JSON, so a metric may be null, a string or nested.\n",
    "        # Only real numbers get the 4-decimal format; anything else is printed as-is\n",
    "        # instead of raising from the float format specifier.\n",
    "        if isinstance(value, (int, float)) and not isinstance(value, bool):\n",
    "            print(f\"{metric}: {value:.4f}\")\n",
    "        else:\n",
    "            print(f\"{metric}: {value}\")\n",
    "else:\n",
    "    print(\"No average metrics available.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Display individual metrics if available\n",
    "# Note: `metrics_df` is only created on the first branch, so it stays undefined\n",
    "# when no individual metrics were loaded.\n",
    "if evaluation_results and 'individual_metrics' in evaluation_results:\n",
    "    ind_metrics = evaluation_results['individual_metrics']\n",
    "    # orient='index': each key of `ind_metrics` becomes one row of the frame\n",
    "    metrics_df = pd.DataFrame.from_dict(ind_metrics, orient='index')\n",
    "    print(f\"Individual metrics for {len(metrics_df)} time series:\")\n",
    "    # Only the first 10 rows are rendered to keep the output short\n",
    "    display(metrics_df.head(10))\n",
    "    \n",
    "    # Summary statistics of metrics\n",
    "    print(\"\\nSummary Statistics:\")\n",
    "    display(metrics_df.describe())\n",
    "else:\n",
    "    print(\"No individual metrics available.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize Individual Forecasts\n",
    "\n",
    "Let's visualize some of the individual forecasts generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Function to display forecast images\n",
    "def display_forecast_images(results_path, max_images=5):\n",
    "    \"\"\"Display saved forecast plots found in a results directory.\n",
    "\n",
    "    Scans ``results_path`` (non-recursively) for files named\n",
    "    ``forecast_series_<id>.png`` and shows up to ``max_images`` of them,\n",
    "    one matplotlib figure per image, ordered by series id.\n",
    "\n",
    "    Args:\n",
    "        results_path: Directory to scan for forecast images.\n",
    "        max_images: Maximum number of images to show.\n",
    "\n",
    "    Returns:\n",
    "        None. Status messages are printed and figures are shown as a side effect.\n",
    "    \"\"\"\n",
    "    # Report a missing directory instead of letting os.listdir raise.\n",
    "    if not os.path.isdir(results_path):\n",
    "        print(f\"Results directory not found at {results_path}\")\n",
    "        return\n",
    "\n",
    "    def series_id_of(filename):\n",
    "        # 'forecast_series_12.png' -> '12'\n",
    "        return filename.split('_')[-1].split('.')[0]\n",
    "\n",
    "    def sort_key(filename):\n",
    "        # Numeric ids sort by value (2 before 10); any other id follows, ordered by name.\n",
    "        series_id = series_id_of(filename)\n",
    "        if series_id.isdecimal():\n",
    "            return (0, int(series_id), filename)\n",
    "        return (1, 0, filename)\n",
    "\n",
    "    # os.listdir returns entries in arbitrary order; sorting makes every run\n",
    "    # show the same images.\n",
    "    forecast_images = sorted(\n",
    "        (f for f in os.listdir(results_path) if f.startswith('forecast_series_') and f.endswith('.png')),\n",
    "        key=sort_key\n",
    "    )\n",
    "\n",
    "    if not forecast_images:\n",
    "        print(\"No forecast visualization images found.\")\n",
    "        return\n",
    "\n",
    "    print(f\"Found {len(forecast_images)} forecast visualization images.\")\n",
    "\n",
    "    # Display up to max_images\n",
    "    for image_file in forecast_images[:max_images]:\n",
    "        image_path = os.path.join(results_path, image_file)\n",
    "\n",
    "        # Display image without axes, titled with the series id taken from the filename\n",
    "        plt.figure(figsize=(12, 8))\n",
    "        img = plt.imread(image_path)\n",
    "        plt.imshow(img)\n",
    "        plt.axis('off')\n",
    "        plt.title(f\"Forecast for Series {series_id_of(image_file)}\")\n",
    "        plt.show()\n",
    "\n",
    "# Display forecast images\n",
    "display_forecast_images(evaluator.results_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and Analyze Synthetic Forecasts\n",
    "\n",
    "Let's load and analyze the synthetic forecasts generated by Chronos-T5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load synthetic data\n",
    "def load_synthetic_data(synthetic_path):\n",
    "    \"\"\"Load per-series forecast CSV files from a directory.\n",
    "\n",
    "    Files are expected to be named ``forecast_series_<id>.csv`` with an\n",
    "    integer ``<id>``. Files whose id is not an integer are skipped with a\n",
    "    message rather than aborting the whole load.\n",
    "\n",
    "    Args:\n",
    "        synthetic_path: Directory containing the forecast CSV files.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping the integer series id to the DataFrame read from its\n",
    "        file, in ascending id order. Empty if the directory does not exist\n",
    "        or contains no matching file.\n",
    "    \"\"\"\n",
    "    # Report a missing directory the same way as an empty one.\n",
    "    if not os.path.isdir(synthetic_path):\n",
    "        print(f\"Synthetic data directory not found at {synthetic_path}\")\n",
    "        return {}\n",
    "\n",
    "    # Find all forecast series files\n",
    "    forecast_files = [f for f in os.listdir(synthetic_path) if f.startswith('forecast_series_') and f.endswith('.csv')]\n",
    "\n",
    "    if not forecast_files:\n",
    "        print(\"No forecast series files found.\")\n",
    "        return {}\n",
    "\n",
    "    print(f\"Found {len(forecast_files)} forecast series files.\")\n",
    "\n",
    "    # Map series id -> file name, skipping names whose id part is not an integer\n",
    "    files_by_id = {}\n",
    "    for file in forecast_files:\n",
    "        id_text = file.split('_')[-1].split('.')[0]\n",
    "        try:\n",
    "            series_id = int(id_text)\n",
    "        except ValueError:\n",
    "            print(f\"Skipping {file}: series id '{id_text}' is not an integer.\")\n",
    "            continue\n",
    "        files_by_id[series_id] = file\n",
    "\n",
    "    # Load in ascending id order: os.listdir order is arbitrary, and callers\n",
    "    # iterate over or slice this dict, so insertion order must be stable.\n",
    "    forecasts = {}\n",
    "    for series_id in sorted(files_by_id):\n",
    "        path = os.path.join(synthetic_path, files_by_id[series_id])\n",
    "        forecasts[series_id] = pd.read_csv(path)\n",
    "\n",
    "    return forecasts\n",
    "\n",
    "# Load main synthetic data file\n",
    "# Any CSV that is not a per-series forecast counts as a candidate. The listing is\n",
    "# sorted so the file picked below is the same on every run (os.listdir order is arbitrary).\n",
    "synthetic_files = sorted(\n",
    "    f for f in os.listdir(config['data']['synthetic_path'])\n",
    "    if f.endswith('.csv') and not f.startswith('forecast_series_')\n",
    ")\n",
    "\n",
    "if synthetic_files:\n",
    "    synthetic_file = synthetic_files[0]  # Alphabetically first candidate\n",
    "    synthetic_path = os.path.join(config['data']['synthetic_path'], synthetic_file)\n",
    "    synthetic_df = pd.read_csv(synthetic_path)\n",
    "    print(f\"Loaded main synthetic data file: {synthetic_file}\")\n",
    "    print(f\"Shape: {synthetic_df.shape}\")\n",
    "    display(synthetic_df.head())\n",
    "else:\n",
    "    print(\"No main synthetic data file found.\")\n",
    "    synthetic_df = None\n",
    "\n",
    "# Load forecast series\n",
    "forecasts = load_synthetic_data(config['data']['synthetic_path'])\n",
    "\n",
    "# Display the forecast with the lowest series id as a sample\n",
    "if forecasts:\n",
    "    first_id = min(forecasts)\n",
    "    print(f\"\\nSample forecast for series {first_id}:\")\n",
    "    display(forecasts[first_id].head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze Forecast Uncertainty\n",
    "\n",
    "One of the key advantages of Chronos-T5 is generating probabilistic forecasts. Let's analyze the uncertainty in these forecasts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def analyze_forecast_uncertainty(forecasts, max_series=3):\n",
    "    \"\"\"Plot and summarize prediction-interval width for the first few forecasts.\n",
    "\n",
    "    For each analyzed series a two-panel figure is shown (median forecast with\n",
    "    its interval, then interval width over time), followed by printed width\n",
    "    statistics and the correlation between time index and width.\n",
    "\n",
    "    Args:\n",
    "        forecasts: Mapping of series id to a forecast DataFrame. A frame is\n",
    "            analyzed only if it has the columns ``time_idx``,\n",
    "            ``median_forecast``, ``lower_bound`` and ``upper_bound``.\n",
    "        max_series: Maximum number of series to analyze, taken in the\n",
    "            mapping's iteration order.\n",
    "\n",
    "    Returns:\n",
    "        None. The input DataFrames are not modified.\n",
    "    \"\"\"\n",
    "    if not forecasts:\n",
    "        print(\"No forecasts available for uncertainty analysis.\")\n",
    "        return\n",
    "\n",
    "    series_ids = list(forecasts.keys())[:max_series]\n",
    "\n",
    "    for series_id in series_ids:\n",
    "        forecast_df = forecasts[series_id]\n",
    "\n",
    "        if 'lower_bound' not in forecast_df.columns or 'upper_bound' not in forecast_df.columns:\n",
    "            print(f\"Series {series_id} does not have prediction interval data.\")\n",
    "            continue\n",
    "\n",
    "        # The plots also read these columns; report their absence instead of raising KeyError.\n",
    "        missing = [col for col in ('time_idx', 'median_forecast') if col not in forecast_df.columns]\n",
    "        if missing:\n",
    "            print(f\"Series {series_id} is missing required columns: {', '.join(missing)}\")\n",
    "            continue\n",
    "\n",
    "        # Derived values are kept in local variables so the caller's DataFrame is\n",
    "        # left untouched and rerunning the cell gives the same result.\n",
    "        time_idx = forecast_df['time_idx']\n",
    "        interval_width = forecast_df['upper_bound'] - forecast_df['lower_bound']\n",
    "\n",
    "        fig, (ax_forecast, ax_width) = plt.subplots(2, 1, figsize=(12, 8))\n",
    "\n",
    "        # Upper subplot: Forecast with intervals\n",
    "        ax_forecast.plot(time_idx, forecast_df['median_forecast'], 'r-', label='Median Forecast')\n",
    "        ax_forecast.fill_between(\n",
    "            time_idx,\n",
    "            forecast_df['lower_bound'],\n",
    "            forecast_df['upper_bound'],\n",
    "            color='r', alpha=0.3,\n",
    "            # NOTE(review): the 80% label assumes the stored bounds are the 10th/90th\n",
    "            # percentiles -- confirm against the code that wrote the CSV files\n",
    "            label='80% Prediction Interval'\n",
    "        )\n",
    "        ax_forecast.set_title(f'Probabilistic Forecast for Series {series_id}')\n",
    "        ax_forecast.set_ylabel('Value')\n",
    "        ax_forecast.grid(True)\n",
    "        ax_forecast.legend()\n",
    "\n",
    "        # Lower subplot: Interval width over time\n",
    "        ax_width.plot(time_idx, interval_width, 'b-', label='Interval Width')\n",
    "        ax_width.set_title('Prediction Interval Width (Uncertainty)')\n",
    "        ax_width.set_xlabel('Time Step')\n",
    "        ax_width.set_ylabel('Interval Width')\n",
    "        ax_width.grid(True)\n",
    "\n",
    "        fig.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "        # Display statistics\n",
    "        print(f\"\\nUncertainty Statistics for Series {series_id}:\")\n",
    "        print(f\"Average Interval Width: {interval_width.mean():.4f}\")\n",
    "        print(f\"Min Interval Width: {interval_width.min():.4f}\")\n",
    "        print(f\"Max Interval Width: {interval_width.max():.4f}\")\n",
    "\n",
    "        # Check if uncertainty grows over time. The 0.5 cut-off is a rule of thumb;\n",
    "        # an undefined (NaN) correlation falls through to the last branch.\n",
    "        corr = time_idx.corr(interval_width)\n",
    "        print(f\"Correlation between Time and Uncertainty: {corr:.4f}\")\n",
    "        if corr > 0.5:\n",
    "            print(\"Uncertainty significantly increases over time.\")\n",
    "        elif corr < -0.5:\n",
    "            print(\"Uncertainty significantly decreases over time.\")\n",
    "        else:\n",
    "            print(\"No strong trend in uncertainty over time.\")\n",
    "\n",
    "# Analyze forecast uncertainty\n",
    "analyze_forecast_uncertainty(forecasts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Manual Forecast Generation with Chronos-T5\n",
    "\n",
    "Let's try generating a forecast manually using the Chronos-T5 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "from chronos import ChronosPipeline\n",
    "\n",
    "def generate_manual_forecast(sample_data=None, prediction_length=24):\n",
    "    \"\"\"Generate and plot a probabilistic forecast with Chronos-T5-Small.\n",
    "\n",
    "    Args:\n",
    "        sample_data: Optional 1-D sequence of numeric values to use as the\n",
    "            historical series. When None, a 200-point demo series (quadratic\n",
    "            trend + two sine cycles + Gaussian noise) is generated.\n",
    "        prediction_length: Number of future time steps to forecast.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(forecast_np, median_forecast, lower_bound, upper_bound)``:\n",
    "        ``forecast_np`` holds the sampled trajectories with shape\n",
    "        ``[num_samples, prediction_length]``; the other three are arrays of\n",
    "        length ``prediction_length`` (median, 10th and 90th percentile across\n",
    "        samples). If anything fails, the error is printed and\n",
    "        ``(None, None, None, None)`` is returned.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # If no sample data provided, create a synthetic time series\n",
    "        if sample_data is None:\n",
    "            print(\"Creating synthetic time series...\")\n",
    "            # Create a synthetic time series with trend, seasonality, and noise\n",
    "            t = np.linspace(0, 4, 200)  # 200 time steps\n",
    "            trend = 0.01 * t**2  # Quadratic trend\n",
    "            seasonal_1 = 0.5 * np.sin(2 * np.pi * t * 10)  # Fast cycle\n",
    "            seasonal_2 = 0.8 * np.sin(2 * np.pi * t)       # Slow cycle\n",
    "            noise = 0.1 * np.random.randn(len(t))          # Random noise\n",
    "            series = trend + seasonal_1 + seasonal_2 + noise\n",
    "\n",
    "            # Convert to tensor\n",
    "            time_series = torch.tensor(series, dtype=torch.float32)\n",
    "        else:\n",
    "            time_series = torch.tensor(sample_data, dtype=torch.float32)\n",
    "\n",
    "        print(f\"Time series shape: {time_series.shape}\")\n",
    "\n",
    "        # Load Chronos-T5 model\n",
    "        print(\"Loading Chronos-T5-Small model...\")\n",
    "        pipeline = ChronosPipeline.from_pretrained(\n",
    "            \"amazon/chronos-t5-small\",\n",
    "            device_map=\"auto\",\n",
    "            torch_dtype=torch.float32\n",
    "        )\n",
    "\n",
    "        # Generate forecast\n",
    "        print(f\"Generating forecast for {prediction_length} future time steps...\")\n",
    "        forecast = pipeline.predict(\n",
    "            time_series,\n",
    "            prediction_length,\n",
    "            num_samples=20,  # Generate more samples for better uncertainty estimation\n",
    "            temperature=0.8\n",
    "        )\n",
    "\n",
    "        # predict() returns [num_series, num_samples, prediction_length]; a single\n",
    "        # 1-D context therefore comes back with a leading axis of size 1.\n",
    "        print(f\"Forecast shape: {forecast.shape}\")\n",
    "\n",
    "        # Move to CPU before converting, then drop the series axis so that the\n",
    "        # statistics below are taken across samples (axis 0), not across series.\n",
    "        forecast_np = forecast.detach().cpu().numpy()\n",
    "        if forecast_np.ndim == 3:\n",
    "            forecast_np = forecast_np[0]\n",
    "\n",
    "        # Calculate forecast statistics across the sample axis\n",
    "        median_forecast = np.median(forecast_np, axis=0)\n",
    "        lower_bound = np.percentile(forecast_np, 10, axis=0)\n",
    "        upper_bound = np.percentile(forecast_np, 90, axis=0)\n",
    "\n",
    "        # Visualize with multiple trajectory samples\n",
    "        plt.figure(figsize=(12, 8))\n",
    "\n",
    "        # Plot historical data\n",
    "        history_length = len(time_series)\n",
    "        plt.plot(range(history_length), time_series.numpy(), 'b-', label='Historical Data')\n",
    "\n",
    "        # Plot individual forecast trajectories (a subset)\n",
    "        forecast_idx = range(history_length, history_length + prediction_length)\n",
    "        for i in range(min(5, forecast_np.shape[0])):  # Plot up to 5 trajectories\n",
    "            plt.plot(forecast_idx, forecast_np[i], 'gray', alpha=0.3, linewidth=0.5)\n",
    "\n",
    "        # Plot median forecast and prediction intervals\n",
    "        # (10th to 90th percentile of the samples, i.e. an 80% interval)\n",
    "        plt.plot(forecast_idx, median_forecast, 'r-', label='Median Forecast')\n",
    "        plt.fill_between(forecast_idx, lower_bound, upper_bound, color='r', alpha=0.3, label='80% Prediction Interval')\n",
    "\n",
    "        plt.title('Manual Time Series Forecast with Chronos-T5')\n",
    "        plt.xlabel('Time Step')\n",
    "        plt.ylabel('Value')\n",
    "        plt.legend()\n",
    "        plt.grid(True)\n",
    "        plt.show()\n",
    "\n",
    "        return forecast_np, median_forecast, lower_bound, upper_bound\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best-effort demo: report the failure and let the notebook continue.\n",
    "        print(f\"Error generating manual forecast: {str(e)}\")\n",
    "        return None, None, None, None\n",
    "\n",
    "# Generate a forecast manually\n",
    "forecast_np, median_forecast, lower_bound, upper_bound = generate_manual_forecast(prediction_length=48)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions and Recommendations\n",
    "\n",
    "Based on the evaluation of synthetic time series forecasts generated by Chronos-T5, we can draw the following conclusions:\n",
    "\n",
    "1. **Forecast Accuracy**: The average metrics (RMSE, MAE, MAPE, R²) provide insights into how well the model is forecasting future values. Lower RMSE/MAE/MAPE and higher R² indicate better forecasts.\n",
    "\n",
    "2. **Uncertainty Quantification**: Chronos-T5 provides probabilistic forecasts, allowing for uncertainty quantification through prediction intervals. The width of these intervals can help assess the model's confidence in its predictions.\n",
    "\n",
    "3. **Forecast Quality**: Visualizing the individual forecasts allows us to qualitatively assess if the model is capturing relevant patterns (trend, seasonality, cycles) in the data.\n",
    "\n",
    "### Recommendations:\n",
    "\n",
    "- **Model Selection**: If forecast accuracy is not satisfactory, consider using a larger Chronos model (e.g., chronos-t5-base) for potentially better performance.\n",
    "  \n",
    "- **Temperature Tuning**: Adjust the temperature parameter (lower for more conservative forecasts, higher for more diverse but potentially less accurate forecasts).\n",
    "  \n",
    "- **Preprocessing Improvements**: Consider more advanced preprocessing of time series data, such as detrending, seasonal adjustment, or normalization.\n",
    "  \n",
    "- **Ensemble Approach**: For critical applications, consider an ensemble of different models for improved robustness.\n",
    "\n",
    "- **Data Quality**: The quality of forecasts depends significantly on the quality and quantity of historical data. Ensure that input data is clean and representative."
   ]
  }
,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Check if metrics distribution image exists\n",
    "metrics_dist_file = os.path.join(evaluator.results_path, 'metric_distributions.png')\n",
    "\n",
    "if os.path.exists(metrics_dist_file):\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    img = plt.imread(metrics_dist_file)\n",
    "    plt.imshow(img)\n",
    "    plt.axis('off')\n",
    "    plt.title(\"Distributions of Evaluation Metrics\")\n",
    "    plt.show()\n",
    "else:\n",
    "    print(f\"Metrics distribution visualization not found at {metrics_dist_file}\")\n",
    "    \n",
    "    # If we have the metrics dataframe, we can generate the visualization here\n",
    "    if 'metrics_df' in locals():\n",
    "        plt.figure(figsize=(15, 10))\n",
    "        \n",
    "        # RMSE\n",
    "        plt.subplot(2, 2, 1)\n",
    "        sns.histplot(metrics_df['rmse'].dropna())\n",
    "        plt.title('Distribution of RMSE')\n",
    "        plt.xlabel('RMSE')\n",
    "        \n",
    "        # MAE\n",
    "        plt.subplot(2, 2, 2)\n",
    "        sns.histplot(metrics_df['mae'].dropna())\n",
    "        plt.title('Distribution of MAE')\n",
    "        plt.xlabel('MAE')\n",
    "        \n",
    "        # MAPE\n",
    "        plt.subplot(2, 2, 3)\n",
    "        sns.histplot(metrics_df['mape'].dropna())\n",
    "        plt.title('Distribution of MAPE (%)')\n",
    "        plt.xlabel('MAPE (%)')\n",
    "        \n",
    "        # R²\n",
    "        plt.subplot(2, 2, 4)\n",
    "        sns.histplot(metrics_df['r2'].dropna())\n",
    "        plt.title('Distribution of R²')\n",
    "        plt.xlabel('R²')\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.show()