In [None]:
from typing import Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Function
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Global seaborn/matplotlib styling; "paper" is seaborn's smallest scaling context
# (smaller fonts and line widths than the default "notebook" context).
sns.set_theme(context='paper')

# Hyperparameters from exhibit 2, given for K = 2 states. Where the 3-state
# setting differs, its value is noted as "(K=3: ...)".

# --- Model size ---
K = 2  # number of hidden states (regimes)
H = 32  # width of the GRSTU hidden state
L = 30  # number of past observations in each input window

# --- Optimisation ---
nu = 0.005  # Adam learning rate
grad_clip = 0.001  # max gradient norm handed to clip_grad_norm_
B = 64  # batch size (NOTE(review): confirm the training loop actually batches)
R = 20  # truncated-backpropagation length (K=3: 80) (NOTE(review): confirm it is used)
epochs = 750  # total number of training epochs

# --- Loss schedule (Equations 12-15) ---
J = 10  # number of observations scored by the likelihood term
U = 30  # per-state Gaussian parameters are refitted every U epochs
C_1 = 0.1  # entropy-bonus weight beta_e, applied while epoch < E_1
E_1 = 500  # epochs during which the entropy bonus is active (K=3: 400)
E_2 = 250  # epochs over which the jump penalty ramps up to lambda_1 (K=3: 250)
lambda_1 = 3  # final jump-penalty weight
p = 7  # steepness of the exponential jump-penalty ramp

In [None]:
def gaussian_log_likelihood(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    """Element-wise Gaussian log-density log N(x; mean, var) (Equation 9).

    Inputs broadcast against each other; ``var`` is a variance, not a std.
    """
    squared_error = (x - mean) ** 2
    return -0.5 * (torch.log(2 * np.pi * var) + squared_error / var)


class StraightThroughEstimator(Function):
    """Hard argmax one-hot in the forward pass, identity gradient in the backward pass.

    Forward turns each row of ``x`` into a one-hot vector at its argmax over
    dim 1 (the definition of s_t, bottom of page 5). Backward passes the
    incoming gradient through unchanged, which is the straight-through trick
    that lets the non-differentiable argmax sit inside the graph.
    """

    @staticmethod
    def forward(ctx, x):
        winners = x.argmax(dim=1, keepdim=True)
        return torch.zeros_like(x).scatter_(1, winners, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the one-hot step as the identity.
        return grad_output


class GRSTU(nn.Module):
    """Gated recurrent cell that emits regime probabilities and a hard one-hot regime.

    Each call to :meth:`forward` advances the recurrent state ``self.hidden`` by
    one step (Equations 3-8). The state lives on the module, so callers reset
    it with :meth:`init_hidden` and are responsible for detaching it between
    optimisation steps.
    """

    def __init__(self, init_std: float = 0.5):
        super().__init__()

        # Input -> hidden weights; the input is a window of L observations.
        self.W_u = nn.Parameter(torch.randn(H, L) * init_std)
        self.W_r = nn.Parameter(torch.randn(H, L) * init_std)
        self.W_h = nn.Parameter(torch.randn(H, L) * init_std)
        # Hidden -> regime-probability head (Equation 7).
        self.W_y = nn.Parameter(torch.randn(H, H) * init_std)
        self.W_p = nn.Parameter(torch.randn(K, H) * init_std)
        # Hidden -> hidden (recurrent) weights.
        self.R_u = nn.Parameter(torch.randn(H, H) * init_std)
        self.R_r = nn.Parameter(torch.randn(H, H) * init_std)
        self.R_h = nn.Parameter(torch.randn(H, H) * init_std)

        self.b_u = nn.Parameter(torch.zeros(H))
        self.b_r = nn.Parameter(torch.zeros(H))
        self.b_h = nn.Parameter(torch.zeros(H))
        self.b_y = nn.Parameter(torch.zeros(H))
        self.b_p = nn.Parameter(torch.zeros(K))

        # Applied to the reset-gate pre-activation only for batches larger than 1
        # (see forward).
        self.batch_norm = nn.BatchNorm1d(H)

        # Recurrent state of shape (batch, H); created lazily on first forward.
        self.hidden: Optional[torch.Tensor] = None

        self.ste = StraightThroughEstimator.apply

    def init_hidden(self, batch_size: int = 1, device: Optional[torch.device] = None,
                    dtype: Optional[torch.dtype] = None):
        """Reset the recurrent state to zeros of shape (batch_size, H).

        ``device``/``dtype`` default to torch's defaults, as before.
        """
        self.hidden = torch.zeros(batch_size, H, device=device, dtype=dtype)

    def compute_state_probabilities(self):
        """Softmax regime probabilities of shape (batch, K) from the current hidden state (Equation 7)."""
        relu = torch.relu(torch.matmul(self.hidden, self.W_y.t()) + self.b_y)
        return torch.softmax(torch.matmul(relu, self.W_p.t()) + self.b_p, dim=1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance the recurrent state by one step.

        Args:
            x: input batch; multiplying by ``W_*.t()`` (``W_*`` is (H, L))
                requires shape (batch, L).

        Returns:
            ``(p_t, s_t)``: regime probabilities of shape (batch, K) and their
            straight-through one-hot version of the same shape.
        """
        batch_size = x.size(0)

        # (Re)initialise when there is no state yet or the batch size changed.
        # Without the size check, a stale (1, H) state silently broadcasts
        # against a larger batch, and a larger stale state turns a batch-1
        # input into a multi-row output.
        if self.hidden is None or self.hidden.size(0) != batch_size:
            self.init_hidden(batch_size, device=x.device, dtype=x.dtype)

        # Equation 3: update gate
        z_t = torch.sigmoid(torch.matmul(x, self.W_u.t()) + torch.matmul(self.hidden, self.R_u.t()) + self.b_u)

        # Equation 4: reset gate. BatchNorm1d cannot normalise a single sample
        # in training mode, so it is skipped for batch size 1.
        norm_in = torch.matmul(x, self.W_r.t()) + torch.matmul(self.hidden, self.R_r.t()) + self.b_r
        r_t = torch.sigmoid(self.batch_norm(norm_in) if batch_size > 1 else norm_in)

        # Equation 5: candidate state
        h_dot = torch.tanh(torch.matmul(x, self.W_h.t()) + torch.matmul(r_t * self.hidden, self.R_h.t()) + self.b_h)

        # Equation 6: gated interpolation between old state and candidate
        self.hidden = (1 - z_t) * self.hidden + z_t * h_dot

        # Equation 7
        p_t = self.compute_state_probabilities()

        # Equation 8: hard one-hot regime with straight-through gradient
        s_t = self.ste(p_t)

        return p_t, s_t


class GRSTULoss(nn.Module):
    """GRSTU training objective plus the per-state Gaussian emission parameters.

    ``means`` / ``vars`` (one entry per state) are buffers, not parameters.
    They are refitted from hard state assignments by
    :meth:`update_gaussian_params` rather than learned by gradient descent.
    """

    # Lower bound on p_t inside the entropy log. A softmax probability that
    # underflows to exactly 0 then contributes 0 instead of 0 * -inf = nan.
    prob_floor: float = 1e-12
    # Lower bound on a refitted variance. A state whose assigned points are
    # all equal would otherwise get var = 0 and a -inf log-likelihood.
    var_floor: float = 1e-12

    def __init__(self, data: np.ndarray):
        super().__init__()

        # Every state starts from the global mean/variance of the series.
        self.register_buffer('means', torch.full((K,), data.mean()))
        self.register_buffer('vars', torch.full((K,), data.var()))
        # Jump-penalty weight lambda_e recorded per epoch (for plotting the schedule).
        self.register_buffer('lambda_epochs', torch.zeros(epochs))

    def update_gaussian_params(self, data: torch.Tensor, s_t: torch.Tensor):
        """Refit each state's Gaussian from the observations assigned to it (Eqs. 10-11).

        Args:
            data: 1-D observations; must be aligned row-for-row with ``s_t``
                because it is indexed with a mask built from ``s_t``.
            s_t: per-observation state scores; the argmax over dim 1 is the state.

        A state with more than one point gets a new mean and unbiased variance.
        A state with exactly one point gets only a new mean. A state with no
        points is left unchanged. The refitted Gaussians are then reordered so
        that labels stay attached to the closest previous Gaussian.
        """
        old_means = self.means.clone()
        old_vars = self.vars.clone()

        k_t = torch.argmax(s_t, dim=1)

        # Diagnostics
        state_counts = torch.bincount(k_t, minlength=K)
        print(f"State distribution: {state_counts.tolist()} out of {len(k_t)} points")

        # Meaningful only when s_t holds soft probabilities; a one-hot s_t always gives 1.0.
        probs_max = torch.max(s_t, dim=1)[0]
        nearly_binary = (probs_max > 0.99).float().mean()
        print(f"Proportion of nearly-binary state assignments: {nearly_binary:.4f}")

        for i in range(K):
            members = data[k_t == i]
            if members.numel() > 1:
                # Equation 10
                self.means[i] = members.mean()
                # Equation 11
                self.vars[i] = members.var(unbiased=True).clamp_min(self.var_floor)
            elif members.numel() == 1:
                # A single point gives a mean but no variance estimate.
                self.means[i] = members.mean()

        self._prevent_label_permutation(old_means, old_vars)

    def _prevent_label_permutation(self, old_means: torch.Tensor, old_vars: torch.Tensor):
        """Reorder the refitted Gaussians so each keeps the slot of the old Gaussian it matches.

        Picks the permutation ``perm`` minimising the total
        sum_i KL(N(new_i) || N(old_perm[i])). New Gaussian i is then moved into
        slot ``perm[i]``. Ties keep the identity ordering. The search is
        exhaustive over K! permutations, which is cheap for the small K used
        here (2 or 3).

        The previous greedy loop swapped in place while re-reading the
        already-swapped parameters, so for K > 2 its result depended on
        iteration order.
        """
        import itertools

        new_means = self.means.clone()
        new_vars = self.vars.clone()
        n_states = len(new_means)

        # kl[i, j] = KL( N(new_means[i], new_vars[i]) || N(old_means[j], old_vars[j]) )
        kl = 0.5 * (
            torch.log(old_vars[None, :] / new_vars[:, None])
            + new_vars[:, None] / old_vars[None, :]
            + (new_means[:, None] - old_means[None, :]) ** 2 / old_vars[None, :]
            - 1
        )

        best_perm = min(
            itertools.permutations(range(n_states)),
            key=lambda perm: float(sum(kl[i, j] for i, j in enumerate(perm))),
        )

        for i, j in enumerate(best_perm):
            self.means[j] = new_means[i]
            self.vars[j] = new_vars[i]

    def forward(self, e: int, current_x: torch.Tensor, p_t: torch.Tensor,
                s_t: torch.Tensor, p_prev: Optional[torch.Tensor]) -> torch.Tensor:
        """Batch-mean GRSTU loss (Equation 12).

        Args:
            e: current epoch. It selects beta_e (Eq. 14) and lambda_e (Eq. 15)
                and indexes ``lambda_epochs``.
            current_x: observations of shape (batch, >= J); the first J columns are scored.
            p_t: regime probabilities of shape (batch, K).
            s_t: one-hot regimes. The Gaussian of each row's argmax state scores that row.
            p_prev: previous step's probabilities, or None at the start of a
                sequence (no jump penalty for that step).

        Returns:
            Scalar: mean over the batch of
            ``-loglik / J - beta_e * entropy + lambda_e * ||p_t - p_prev||_1``.

        Raises:
            ValueError: if ``current_x`` has fewer than J columns.
        """
        if current_x.size(1) < J:
            raise ValueError(f"current_x needs at least J={J} columns, got {current_x.size(1)}")

        # Only the first step of a sequence has no predecessor. Previously
        # p_prev was unconditionally overwritten with p_t, which made the jump
        # penalty identically zero.
        if p_prev is None:
            p_prev = p_t

        # Gaussian of each row's hard state, broadcast over the J observations.
        k_t = torch.argmax(s_t, dim=1)
        mean = self.means[k_t].unsqueeze(1)
        var = self.vars[k_t].unsqueeze(1)
        likelihood_sum = gaussian_log_likelihood(current_x[:, :J], mean, var).sum(dim=1)

        # Equation 13
        entropy = -torch.sum(p_t * torch.log(p_t.clamp_min(self.prob_floor)), dim=1)

        # Equation 14
        beta_e = C_1 if e < E_1 else 0.0

        # Equation 15
        lambda_e = lambda_1 * np.exp((p * e) / E_2 - p) if e < E_2 else lambda_1

        self.lambda_epochs[e] = lambda_e

        # Equation 12: L1 jump penalty between consecutive probability vectors
        jump_loss = lambda_e * (p_t - p_prev).abs().sum(dim=1)
        return torch.mean(-likelihood_sum / J - beta_e * entropy + jump_loss)


class TimeSeriesDataset(Dataset):
    """
    Map-style dataset of trailing windows over a 1-D series.

    Item ``idx`` is ``data[idx - window_size:idx]`` as a float32 tensor, i.e.
    the observations strictly before ``idx``. For ``idx < window_size`` the
    window is truncated (item 0 is empty), so items only have equal length
    from ``idx = window_size`` onwards.
    """

    def __init__(self, data: np.ndarray, window_size: int):
        self.data = data
        self.window_size = window_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        lo = idx - self.window_size
        if lo < 0:
            lo = 0
        return torch.tensor(self.data[lo:idx], dtype=torch.float32)


def _train_epoch(model: nn.Module, loss_fn: nn.Module, optimizer: optim.Optimizer,
                 data: np.ndarray, e: int):
    """One online pass (batch size 1) over the series; returns (mean loss, stacked one-hot states)."""
    model.train()
    model.init_hidden(1)

    losses = []
    states_seen = []
    prev_probs = None

    for t in range(L + 1, len(data)):
        # The L observations preceding t, as a (1, L) batch.
        x_batch = torch.FloatTensor(data[t - L:t]).reshape(1, -1)

        optimizer.zero_grad()
        probs, states = model(x_batch)

        # Score the J most recent observations, the same columns evaluation uses.
        loss = loss_fn(e, x_batch[:, -J:], probs, states, prev_probs)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        prev_probs = probs.detach()
        losses.append(loss.item())
        # Detach so the stored states do not keep every step's graph alive.
        states_seen.append(states.detach())
        # Truncate backpropagation through time after every step.
        model.hidden = model.hidden.detach()

    return np.mean(losses), torch.cat(states_seen, dim=0)


@torch.no_grad()
def _evaluate(model: nn.Module, loss_fn: nn.Module, dataset: Dataset, e: int):
    """Replay the series without gradients; returns (mean loss, stacked one-hot states)."""
    model.eval()
    # One reset for the whole replay so the recurrence carries across steps,
    # as it does in training (previously the state was reset at every step).
    model.init_hidden(1)

    losses = []
    states_seen = []
    prev_probs = None

    for t in range(L + 1, len(dataset)):
        x_batch = dataset[t].reshape(1, -1)
        probs, states = model(x_batch)
        loss = loss_fn(e, x_batch[:, -J:], probs, states, prev_probs)
        prev_probs = probs
        losses.append(loss.item())
        states_seen.append(states)

    return np.mean(losses), torch.cat(states_seen, dim=0)


def train_grstu(data: np.ndarray):
    """Train a GRSTU regime model on a 1-D series and keep the best decoding.

    Each epoch runs one online pass over the windows ``data[t-L:t]`` for
    ``t = L+1 .. len(data)-1``. Every U epochs three things happen:
    the per-state Gaussians are refitted from that epoch's state assignments,
    the series is replayed without gradients, and the decoding with the lowest
    replay loss so far is kept.

    NOTE(review): R (truncated-BPTT length) and B (batch size) from the config
    cell are not used. Training uses batch size 1 and detaches the hidden
    state after every step.
    NOTE(review): replay losses are compared across epochs even though the
    objective itself changes with e (beta_e, lambda_e) and with the refitted
    Gaussians.

    Args:
        data: 1-D array of observations.

    Returns:
        dict with keys:
        - 'model'
        - 'best_states': (len(data) - L - 1, K) one-hot array
        - 'means', 'vars': Gaussians at the best evaluation
        - 'losses': mean training loss per epoch
        - 'lambda_epochs'

    Raises:
        ValueError: if the series has no more than L + 1 observations.
    """
    if len(data) <= L + 1:
        raise ValueError(f"need more than L + 1 = {L + 1} observations, got {len(data)}")

    eval_data = torch.FloatTensor(data)
    dataset = TimeSeriesDataset(data, L)

    model = GRSTU()
    loss_fn = GRSTULoss(data)
    optimizer = optim.Adam(model.parameters(), lr=nu)

    best_loss = np.inf
    best_states = None
    best_means = None
    best_vars = None
    avg_losses = []

    training_loop = tqdm(range(epochs))

    for e in training_loop:
        avg_loss, train_states = _train_epoch(model, loss_fn, optimizer, data, e)
        avg_losses.append(avg_loss)

        # Refit the Gaussians and evaluate every U epochs.
        if e % U == 0:
            with torch.no_grad():
                # The state from step t is paired with observation data[t].
                loss_fn.update_gaussian_params(eval_data[L + 1:], train_states)

            eval_loss, eval_states = _evaluate(model, loss_fn, dataset, e)

            # Keep the decoding with the lowest evaluation loss.
            if eval_loss < best_loss:
                best_loss = eval_loss
                best_states = eval_states.cpu().numpy()
                best_means = loss_fn.means.clone().cpu().numpy()
                best_vars = loss_fn.vars.clone().cpu().numpy()

        training_loop.set_postfix(loss=avg_loss)

    return {
        'model': model,
        'best_states': best_states,
        'means': best_means,
        'vars': best_vars,
        'losses': avg_losses,
        'lambda_epochs': loss_fn.lambda_epochs.cpu().numpy(),
    }


# Seed both RNGs before the model is built so weight initialisation is repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Load the return series (column 'returns') and fit the model.
sp500 = pd.read_csv('sp500.csv')
data = sp500['returns'].values
results = train_grstu(data)

# Mean training loss per epoch.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(results['losses'], label='Training Loss')
ax.set(title='Training Loss', xlabel='Epoch', ylabel='Loss')
plt.show()

# Jump-penalty weight lambda_e recorded for each epoch (Equation 15 schedule).
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(results['lambda_epochs'])
ax.set(title='Jump Penalty', xlabel='Epoch', ylabel='$\\lambda$(Epochs)')
plt.show()

In [None]:
import matplotlib.dates as mdates


def plot_results(data, results, window_size=30):
    """Overlay the decoded regime index (right axis) on the raw series (left axis).

    Args:
        data: 1-D series that was passed to ``train_grstu``.
        results: dict from ``train_grstu``; ``results['best_states']`` holds
            one score row per decoded step (argmax over axis 1 is the regime).
        window_size: index in ``data`` at which the first decoded state is drawn.
    """
    regimes = np.argmax(results['best_states'], axis=1)

    fig, ax1 = plt.subplots(figsize=(14, 8))
    ax2 = ax1.twinx()

    # Shift the x positions past the input window.
    time_idx = np.arange(len(data))[window_size:window_size + len(regimes)]
    plot_data = data[window_size:window_size + len(regimes)]

    data_lines = ax1.plot(time_idx, plot_data, label='Data')
    regime_lines = ax2.plot(time_idx, regimes, label='Regime', color='r')

    ax1.set_xlabel('Time')
    ax1.set_ylabel('Value')
    ax2.set_ylabel('Regime')

    # plt.legend() targets the current axes, which after twinx() is the twin.
    # It would therefore list only 'Regime'. Build one legend from both axes'
    # lines instead, and place it on the twin since that axes is drawn on top.
    handles = data_lines + regime_lines
    ax2.legend(handles, [line.get_label() for line in handles])
    fig.tight_layout()
    plt.show()


plot_results(data, results, window_size=L)

In [None]:
from datetime import datetime


def plot_regimes(dates, prices, returns, regime_labels, means, vars):
    """Plot prices (top) and returns (bottom) with the decoded regimes shaded behind them.

    Args:
        dates: x-axis values, one per observation.
        prices: price series aligned with ``dates`` (top panel).
        returns: return series aligned with ``dates`` (bottom panel, labelled in %).
        regime_labels: integer regime index per observation, aligned with ``dates``.
        means: per-regime means; ``len(means)`` sets how many regimes are drawn.
        vars: per-regime variances, shown as standard deviations in the legend.
            (Shadows the builtin ``vars`` inside this function; name kept for callers.)
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True)

    ax1.plot(dates, prices, 'k-', alpha=0.7, label='S&P 500')
    ax1.set_ylabel('Price')
    ax1.set_title('GRSTU Regime Detection - Prices')

    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax1.xaxis.set_major_locator(mdates.YearLocator())

    ax2.plot(dates, returns, 'k-', alpha=0.3, label='Returns')
    ax2.set_ylabel('Returns (%)')
    ax2.set_xlabel('Date')

    n_regimes = len(means)
    colors = plt.cm.tab10(np.linspace(0, 1, n_regimes))
    last = len(dates) - 1

    # Shade every run of consecutive observations that share a regime.
    for i in range(n_regimes):
        indices = np.flatnonzero(regime_labels == i)
        if indices.size == 0:
            continue
        runs = np.split(indices, np.flatnonzero(np.diff(indices) != 1) + 1)

        for run_no, run in enumerate(runs):
            start = dates[run[0]]
            # Extend each span to the next observation. Otherwise a one-point
            # run has zero width (invisible) and adjacent runs of different
            # regimes leave a gap between them.
            end = dates[min(run[-1] + 1, last)]

            # Label only the first run so each regime appears once in the legend.
            ax1.axvspan(start, end, alpha=0.3, color=colors[i],
                        label=f'Regime {i + 1}' if run_no == 0 else "")
            ax2.axvspan(start, end, alpha=0.3, color=colors[i])

    # Regime statistics for the returns-panel legend.
    legend_elements = [
        plt.Line2D([0], [0], color=colors[i], lw=4,
                   label=f'Regime {i + 1}: μ={means[i]:.4f}, σ={np.sqrt(vars[i]):.4f}')
        for i in range(n_regimes)
    ]

    ax1.legend(loc='upper left')
    ax2.legend(handles=legend_elements, loc='upper left')

    fig.autofmt_xdate()
    plt.tight_layout()
    plt.show()

# Synthetic daily calendar ending now, one date per observation.
# NOTE(review): these are neither trading days nor the CSV's real dates, and
# they shift on every run; use the CSV's own date column if it has one.
dates = pd.date_range(end=datetime.now(), periods=len(data)).values

# Rebuild a price path from the returns, starting at 1000. Assumes the returns
# are in percent (TODO confirm). data[0] is not applied because prices[0] is
# the starting level.
prices = 1000 * np.concatenate(([1.0], np.cumprod(1 + np.asarray(data[1:], dtype=float) / 100)))

# Extract results
states = results['best_states']
regime_labels = np.argmax(states, axis=1)
regime_means = results['means']
regime_vars = results['vars']

# Skip the first L observations (the input window), drawing the first decoded
# state at index L.
dates_adjusted = dates[L:L + len(regime_labels)]
prices_adjusted = prices[L:L + len(regime_labels)]
returns_adjusted = data[L:L + len(regime_labels)]

# Plot results with dates
plot_regimes(
    dates_adjusted,
    prices_adjusted,
    returns_adjusted,
    regime_labels,
    regime_means,
    regime_vars
)

In [None]:
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, ConfusionMatrixDisplay

y_pred = np.argmax(results['best_states'], axis=1)

# The k-th decoded state belongs to series index L + 1 + k. Evaluation runs
# t = L+1 .. len(data)-1, and the Gaussian refit pairs that state with data[t].
# Slicing the labels from index 0 compared each prediction with the label
# L + 1 steps earlier.
label_offset = L + 1
y_test = sp500['states'].values[label_offset:label_offset + len(y_pred)]
assert len(y_test) == len(y_pred), "ground-truth labels do not cover every decoded state"

# NOTE(review): regime indices from an unsupervised model are arbitrary. If
# the model's labelling is swapped relative to the ground truth, the accuracy
# below is correspondingly low; consider scoring under the best label permutation.
print('Balanced Accuracy:', balanced_accuracy_score(y_test, y_pred))

confusion_matrix(y_test, y_pred)