# Handling Seasonal Data with `innovate`

This notebook demonstrates how to handle time series data with seasonality when using the `innovate` library. We will use STL (Seasonal and Trend decomposition using Loess) to separate the trend component from the seasonal and residual components, and then fit a diffusion model to the extracted trend.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from innovate.preprocess import stl_decomposition
from innovate.diffuse.logistic import LogisticModel
from innovate.fitters.scipy_fitter import ScipyFitter
from innovate.plots.comparison import plot_diffusion_curves

## 1. Generate Synthetic Seasonal Data

In [None]:
# Synthetic monthly adoption series: logistic trend + 12-month seasonality + Gaussian noise.
N_MONTHS = 120        # ten years of monthly observations
SEASONAL_PERIOD = 12  # samples per seasonal cycle; must agree with the STL `period` used later
SEED = 42             # fixed seed so Restart & Run All regenerates exactly the same series

rng = np.random.default_rng(SEED)
month_index = np.arange(N_MONTHS)

# 'MS' (month start) is accepted by old and new pandas alike; the month-end
# alias 'M' is deprecated in recent pandas releases in favour of 'ME'.
dates = pd.date_range(start='2015-01-01', periods=N_MONTHS, freq='MS')

# Logistic S-curve: saturation level 500, growth rate 0.1, midpoint at month 60.
trend = 500 / (1 + np.exp(-0.1 * (month_index - 60)))

# Drive the sine from the integer month index so every cycle spans exactly
# SEASONAL_PERIOD samples. The previous np.linspace(0, 20*pi, 120) included the
# endpoint, which stretched each cycle to 119/10 = 11.9 samples and drifted
# against the 12-month period assumed by the decomposition.
seasonal = 50 * np.sin(2 * np.pi * month_index / SEASONAL_PERIOD)

# Zero-mean Gaussian noise with standard deviation 5, drawn from the seeded generator.
noise = rng.normal(0, 5, N_MONTHS)

seasonal_data = pd.Series(trend + seasonal + noise, index=dates, name='Adoptions')

## 2. Decompose the Time Series using STL

In [None]:
# Split the series into components, telling STL that one seasonal cycle is 12 samples.
decomposed = stl_decomposition(seasonal_data, period=12)

# One stacked panel per series, all sharing the x-axis.
component_titles = {
    'trend': 'Trend',
    'seasonal': 'Seasonal',
    'residual': 'Residual',
}
fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

# Top panel: the raw input; remaining panels: each decomposed component in turn.
seasonal_data.plot(ax=axes[0], title='Original Series')
for ax, (key, title) in zip(axes[1:], component_titles.items()):
    decomposed[key].plot(ax=ax, title=title)

plt.tight_layout()
plt.show()

## 3. Fit a Diffusion Model to the Trend

In [None]:
# Fit only the smooth trend component; the time axis is a plain
# 0..n-1 integer index with one step per observation.
trend_data = decomposed['trend']
t = np.arange(len(trend_data))
trend_values = trend_data.values

# The fitter is handed the model instance together with the (t, y) arrays;
# the fitted parameters are read back from the model afterwards.
logistic_model = LogisticModel()
fitter = ScipyFitter()
fitter.fit(logistic_model, t, trend_values)

print("Fitted Logistic Model Parameters:", logistic_model.params_)

## 4. Visualize the Fit

In [None]:
# Evaluate the fitted model on the same time index it was fitted on.
predictions = logistic_model.predict(t)

# Overlay the STL trend (solid) and the model curve (dashed) on one set of axes,
# using explicit figure/axes handles instead of the implicit pyplot state.
fit_fig, fit_ax = plt.subplots(figsize=(10, 6))
fit_ax.plot(t, trend_data.values, label='STL Trend')
fit_ax.plot(t, predictions, label='Fitted Logistic Model', linestyle='--')
fit_ax.set_title('Logistic Model Fit to Trend Component')
fit_ax.set_xlabel('Time')
fit_ax.set_ylabel('Cumulative Adoptions')
fit_ax.legend()
fit_ax.grid(True)
plt.show()