In [None]:
import os

import numpy as np
import pandas as pd
from fredapi import Fred
import matplotlib.pyplot as plt

from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

In [None]:
# Load CPI from FRED. The API key comes from the environment (set FRED_API_KEY before
# starting the kernel) so it is never committed with the notebook.
fred = Fred(api_key=os.environ['FRED_API_KEY'])
cpi = fred.get_series('CPIAUCSL')

# Inflation in percent as the 12-observation change of the CPI level
# (year-over-year for monthly observations).
# fill_method=None: do not forward-fill missing CPI values. pandas' old implicit padding
# is deprecated and would invent changes for months with no published index.
inflation = (cpi.pct_change(periods=12, fill_method=None) * 100).dropna()

inflation_df = inflation.to_frame(name='inflation')
inflation_df.index.name = 'date'
inflation_df.head(10)

In [None]:
# Correlograms of the (undifferenced) inflation series, 40 lags each.
fig, (ax_acf, ax_pacf) = plt.subplots(2, 1, figsize=(10, 7))

plot_acf(inflation_df['inflation'], lags=40, ax=ax_acf)
ax_acf.set_title('Autocorrelation Function (ACF)', fontsize=14)
ax_acf.set_xlabel('Lags')

# Explicit PACF estimator: the default changed between statsmodels releases
# (older versions emit a FutureWarning), so pin it for reproducible plots.
plot_pacf(inflation_df['inflation'], lags=40, ax=ax_pacf, method='ywm')
ax_pacf.set_title('Partial Autocorrelation Function (PACF)', fontsize=14)
ax_pacf.set_xlabel('Lags')

fig.tight_layout()
plt.show()

In [None]:
# First difference of the inflation rate; the leading NaN produced by diff() is dropped.
inflation_diff = inflation_df['inflation'].diff().dropna()

# Correlograms of the differenced series, laid out like the undifferenced ones.
fig, (ax_acf, ax_pacf) = plt.subplots(2, 1, figsize=(10, 8))

plot_acf(inflation_diff, lags=40, ax=ax_acf)
ax_acf.set_title('ACF of First-Differenced Inflation', fontsize=14)
ax_acf.set_xlabel('Lags')

# Pin the PACF estimator so the plot does not depend on the statsmodels version default.
plot_pacf(inflation_diff, lags=40, ax=ax_pacf, method='ywm')
ax_pacf.set_title('PACF of First-Differenced Inflation', fontsize=14)
ax_pacf.set_xlabel('Lags')

fig.tight_layout()
plt.show()

In [None]:
# Chronological 80/20 split (no shuffling): the model is trained on the first 80% of
# observations and evaluated on the remaining, later 20%.
train_size = int(len(inflation_df) * 0.8)
# .iloc makes the positional slice explicit instead of relying on pandas falling back
# to positions for integer slices on a DatetimeIndex.
train = inflation_df['inflation'].iloc[:train_size]
test = inflation_df['inflation'].iloc[train_size:]

# The split must be exhaustive and strictly time-ordered (no look-ahead leakage).
assert len(train) + len(test) == len(inflation_df)
assert train.index.max() < test.index.min()

In [None]:
# SARIMA(p, d, q) x (P, D, Q, s) specification; s=12 assumes monthly observations.
# NOTE(review): the series is already a 12-period change (pct_change(periods=12)), so the
# extra seasonal difference D=1 may over-difference — worth confirming with diagnostics.
nonseasonal_order = (1, 1, 1)
seasonal_spec = (2, 1, 1, 12)

model = SARIMAX(
    train,
    order=nonseasonal_order,
    seasonal_order=seasonal_spec,
)
fit = model.fit(disp=False)  # disp=False suppresses the optimizer's convergence printout


In [None]:
# Multi-step out-of-sample forecast covering the whole test horizon in one call
# (the model is not re-fitted as test observations arrive).
forecast = fit.forecast(steps=len(test))

# Quantify the fit instead of judging it by eye. Compare raw arrays: if statsmodels could
# not infer a date frequency, `forecast` may carry an integer index, and index-aligned
# arithmetic against `test` would silently produce all-NaN errors.
forecast_errors = test.to_numpy() - forecast.to_numpy()
rmse = float(np.sqrt(np.mean(forecast_errors ** 2)))
mae = float(np.mean(np.abs(forecast_errors)))
print(f'Test RMSE: {rmse:.3f} pp, MAE: {mae:.3f} pp')  # pp = percentage points

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train.index, train, label='Training Data')
ax.plot(test.index, test, label='Test Data', color='orange')
ax.plot(test.index, forecast.to_numpy(), label='Forecast', color='green')
ax.set_xlabel('Date')
ax.set_ylabel('Inflation Rate (%)')
ax.set_title('SARIMA Model Forecast vs Actual Inflation Rates')
ax.legend()
plt.show()