In [None]:
# 📌 Brent Oil Change Point Detection (Local Jupyter Version)
# Author: yitbarek geletaw
# Description: Detects structural changes (shifts in mean and volatility) in Brent oil log returns and links each one to the nearest historical event.

# --- SETUP ---

import os
os.environ["MKL_THREADING_LAYER"] = "GNU"  # Prevent MKL issues

import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- CONFIGURE PATHS ---

# Input CSVs (relative to the notebook's working directory).
data_path = "data/processed/brent_oil_log_returns.csv"
events_path = "data/processed/events.csv"

# Root directory for every artefact this notebook writes (CSVs and figures).
output_dir = "results"

# makedirs also creates the intermediate results/ and results/figures/
# directories that later cells save into.
trace_plot_dir = os.path.join(output_dir, 'figures', 'trace_plots')
os.makedirs(trace_plot_dir, exist_ok=True)

In [None]:

# --- LOAD DATA ---

# The change-point model uses row position as its time axis, so the returns
# must be in chronological order: sort explicitly (stable, so same-date rows
# keep their file order) instead of trusting the CSV's row order.
# reset_index gives a clean 0..n-1 index after rows with missing returns are
# dropped, so positional indices from the model line up with `dates`.
df = (
    pd.read_csv(data_path, parse_dates=["Date"])
    .dropna(subset=["LogReturn"])
    .sort_values("Date", kind="stable")
    .reset_index(drop=True)
)
log_returns = df["LogReturn"]
dates = df["Date"]
events_df = pd.read_csv(events_path, parse_dates=["Date"])

In [None]:
# --- RUN MODEL ---

model, trace = run_change_point_model(log_returns, dates, num_change_points=3)

# --- POSTERIOR SUMMARY ---

summary = az.summary(trace)
summary.to_csv(os.path.join(output_dir, 'posterior_summary.csv'))

# --- TRACE PLOT ---

az.plot_trace(trace)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'figures', 'trace_plots', 'trace_plot.png'), dpi=300)
plt.show()

In [None]:
# --- MODEL FUNCTION ---

def run_change_point_model(data, dates, num_change_points=3):
    n = len(data)
    data_std = np.std(data)
    idx = np.arange(n)

    with pm.Model() as model:
        if num_change_points == 3:
            tau_1 = pm.DiscreteUniform("tau_1", lower=0, upper=n//3)
            tau_2 = pm.DiscreteUniform("tau_2", lower=n//3, upper=2*n//3)
            tau_3 = pm.DiscreteUniform("tau_3", lower=2*n//3, upper=n-1)

            mu_1 = pm.Normal("mu_1", mu=0, sigma=data_std / 10)
            mu_2 = pm.Normal("mu_2", mu=0, sigma=data_std / 10)
            mu_3 = pm.Normal("mu_3", mu=0, sigma=data_std / 10)
            mu_4 = pm.Normal("mu_4", mu=0, sigma=data_std / 10)

            sigma_1 = pm.HalfNormal("sigma_1", sigma=data_std / 2)
            sigma_2 = pm.HalfNormal("sigma_2", sigma=data_std / 2)
            sigma_3 = pm.HalfNormal("sigma_3", sigma=data_std / 2)
            sigma_4 = pm.HalfNormal("sigma_4", sigma=data_std / 2)

            mu = pm.math.switch(tau_1 >= idx, mu_1,
                  pm.math.switch(tau_2 >= idx, mu_2,
                  pm.math.switch(tau_3 >= idx, mu_3, mu_4)))
            
            sigma = pm.math.switch(tau_1 >= idx, sigma_1,
                     pm.math.switch(tau_2 >= idx, sigma_2,
                     pm.math.switch(tau_3 >= idx, sigma_3, sigma_4)))
        else:
            tau = pm.DiscreteUniform("tau", lower=0, upper=n-1)
            mu_1 = pm.Normal("mu_1", mu=0, sigma=data_std / 10)
            mu_2 = pm.Normal("mu_2", mu=0, sigma=data_std / 10)
            sigma_1 = pm.HalfNormal("sigma_1", sigma=data_std / 2)
            sigma_2 = pm.HalfNormal("sigma_2", sigma=data_std / 2)

            mu = pm.math.switch(tau >= idx, mu_1, mu_2)
            sigma = pm.math.switch(tau >= idx, sigma_1, sigma_2)

        returns = pm.Normal("returns", mu=mu, sigma=sigma, observed=data)
        trace = pm.sample(10000, tune=5000, chains=4, random_seed=42, return_inferencedata=True)
    
    return model, trace

In [None]:
# --- IDENTIFY CHANGE POINTS ---

# Collect every change-point variable present in the posterior ("tau" for a
# single-change-point model, "tau_1" ... "tau_k" otherwise) rather than
# hard-coding three names. Sort numerically so "tau_10" follows "tau_9".
tau_vars = sorted(
    (name for name in trace.posterior.data_vars
     if name == "tau" or (name.startswith("tau_") and name[4:].isdigit())),
    key=lambda name: 0 if name == "tau" else int(name[4:]),
)
change_points = []
for var in tau_vars:
    # tau is discrete, so its posterior mode is the most frequently sampled index.
    tau_vals = trace.posterior[var].values.ravel().astype(int)
    tau_mode = int(np.bincount(tau_vals).argmax())
    # The model assigns positions <= tau to the earlier regime, so this is the
    # date of the last observation before the switch.
    change_points.append(dates.iloc[tau_mode])

# --- PLOT LOG RETURNS WITH CHANGE POINTS ---
# One red dashed vertical line per detected change point, overlaid on the returns.

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(dates, log_returns, label='Log Returns', alpha=0.7)
for cp in change_points:
    ax.axvline(cp, color='red', linestyle='--', label=f'Change Point: {cp.strftime("%Y-%m-%d")}')
ax.set(
    title="Brent Oil Log Returns with Change Points",
    xlabel="Date",
    ylabel="Log Return",
)
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'figures', 'log_returns_change_point.png'), dpi=300)
plt.show()

# --- MATCH CHANGE POINTS WITH EVENTS ---

# Pair each change point with the event whose date is closest to it (before or
# after). Distances are kept in a local Series instead of a new events_df
# column, so re-running this cell does not mutate events_df.
matched_columns = ["Change_Point_Date", "Closest_Event_Date", "Description", "Type", "Date_Diff_Days"]
matched_events = []
for cp_date in change_points:
    date_diff = (events_df["Date"] - cp_date).abs()
    if date_diff.isna().all():
        # No dated events to compare against (empty table or all dates missing):
        # keep the change point in the output instead of crashing in idxmin().
        matched_events.append({
            "Change_Point_Date": cp_date,
            "Closest_Event_Date": pd.NaT,
            "Description": None,
            "Type": None,
            "Date_Diff_Days": None,
        })
        continue
    closest_idx = date_diff.idxmin()
    closest_event = events_df.loc[closest_idx]
    matched_events.append({
        "Change_Point_Date": cp_date,
        "Closest_Event_Date": closest_event["Date"],
        "Description": closest_event["Event_Description"],
        "Type": closest_event["Event_Type"],
        "Date_Diff_Days": date_diff.loc[closest_idx].days,
    })

# Explicit columns keep the CSV schema stable even when no change points were found.
matched_df = pd.DataFrame(matched_events, columns=matched_columns)
matched_df.to_csv(os.path.join(output_dir, "matched_events.csv"), index=False)
matched_df