In [None]:
from typing import Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Function
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

sns.set_theme(context='paper')

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'Using {device} device')

K = 2  # Number of hidden states (regimes)

# Hyperparameters for 2 states; values written as "(/...)" are the 3-state
# settings. Source: exhibit 2.

# -- Model and optimisation --
H = 32  # GRSTU hidden state size
B = 64  # Batch size (consecutive windows per batch, the loader does not shuffle)
L = 30  # Window of past observations fed to the model (also its input size)
nu = 0.005  # Adam learning rate
grad_clip = 0.001  # Max gradient norm passed to clip_grad_norm_
R = 20  # (/80) Length of truncated backpropagation
# NOTE(review): R is not referenced by train_grstu, which detaches the hidden
# state after every batch instead.
epochs = 750  # Number of epochs for which the model is trained

# -- Loss schedule --
J = 10  # Number of most recent observations of each window scored by the likelihood
C_1 = 0.1  # Entropy-bonus weight (beta_e), applied while the epoch is < E_1
E_1 = 500  # (/400) Number of epochs during which the entropy bonus is active
lambda_1 = 3  # Jump penalty (final weight)
p = 7  # Jump-penalty warm-up steepness: the penalty starts at lambda_1 * exp(-p)
E_2 = 250  # (/250) Epochs over which the jump penalty grows to lambda_1
U = 30  # Frequency (in epochs) at which the Gaussian distributions are updated

t_min = L + 10  # First time index used to build windows (training and evaluation)

In [None]:
import math


def gaussian_log_likelihood(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    """Element-wise log-density of ``x`` under N(mean, var) (Equation 9).

    The three inputs broadcast against each other. ``var`` is a variance, not
    a standard deviation.
    """
    log_normalizer = torch.log(2 * np.pi * var)
    squared_error = (x - mean) ** 2
    return -0.5 * log_normalizer - 0.5 * squared_error / var


def gaussian_kl_divergence(mean1: torch.Tensor, var1: torch.Tensor,
                           mean2: torch.Tensor, var2: torch.Tensor) -> torch.Tensor:
    """KL(N(mean1, var1) || N(mean2, var2)) for univariate Gaussians (Equation 16).

    Arguments are variances (not standard deviations) and are combined
    element-wise with broadcasting.
    """
    log_var_ratio = torch.log(var2 / var1)
    var_ratio = var1 / var2
    scaled_mean_gap = ((mean1 - mean2) ** 2) / var2
    return 0.5 * (log_var_ratio + var_ratio + scaled_mean_gap - 1)


class StraightThroughEstimator(Function):
    """Hard arg-max one-hot in the forward pass, identity gradient backwards.

    The forward output has the same shape, dtype and device as the input, with
    1.0 at the arg-max position along dim 1 and 0.0 elsewhere (definition of
    s_t at the bottom of page 5). The backward pass forwards the incoming
    gradient unchanged, so the hard selection stays trainable.
    """

    @staticmethod
    def forward(ctx, x):
        winners = x.argmax(dim=1, keepdim=True)
        return torch.zeros_like(x).scatter_(1, winners, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: behave like the identity for gradients.
        return grad_output


class GRSTU(nn.Module):
    """Gated recurrent straight-through unit used for regime detection.

    A GRU-style recurrent cell (Equations 3-6) whose hidden state is mapped to
    regime probabilities (Equation 7) and to a hard one-hot regime indicator
    through a straight-through estimator (Equation 8).

    The recurrent state is stored on the module as ``self.hidden``. It is
    created lazily (zeros) on the first ``forward`` call, and the caller is
    responsible for detaching or resetting it (see ``reset_hidden``).

    Args:
        input_size: Number of input features per step (the length of the
            observation window fed to the model).
        hidden_size: Size of the recurrent hidden state.
        output_size: Number of regimes (length of the probability vector).
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size

        # Input-to-hidden weights of the update (u), reset (r) and candidate (h) paths,
        # plus the two-layer head (y, p) that produces the regime probabilities.
        self.W_u = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_r = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_h = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_y = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.W_p = nn.Parameter(torch.empty(hidden_size, output_size))
        # Hidden-to-hidden (recurrent) weights.
        self.R_u = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.R_r = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.R_h = nn.Parameter(torch.empty(hidden_size, hidden_size))

        self.b_u = nn.Parameter(torch.empty(hidden_size))
        self.b_r = nn.Parameter(torch.empty(hidden_size))
        self.b_h = nn.Parameter(torch.empty(hidden_size))
        self.b_y = nn.Parameter(torch.empty(hidden_size))
        self.b_p = nn.Parameter(torch.empty(output_size))

        # Normalizes the reset-gate pre-activation (Equation 4).
        self.batch_norm = nn.BatchNorm1d(hidden_size)

        self.hidden: Optional[torch.Tensor] = None

        self.ste = StraightThroughEstimator.apply

        # Same uniform initialization as PyTorch GRU, applied only to the
        # parameters registered directly on this module. ``recurse=False``
        # leaves the BatchNorm affine parameters at their defaults
        # (weight=1, bias=0). Re-drawing them uniformly around zero would start
        # the reset-gate normalization with a near-zero, possibly negative, gain.
        stdv = 1.0 / math.sqrt(hidden_size)
        for weight in self.parameters(recurse=False):
            nn.init.uniform_(weight, -stdv, stdv)

    def init_hidden(self, batch_size: int = 1, h_device=None) -> None:
        """Set the recurrent state to zeros of shape (batch_size, hidden_size)."""
        self.hidden = torch.zeros(batch_size, self.hidden_size, device=h_device)

    def reset_hidden(self) -> None:
        """Forget the recurrent state; the next ``forward`` re-creates it as zeros."""
        self.hidden = None

    def compute_state_probabilities(self) -> torch.Tensor:
        """Map the current hidden state to regime probabilities (Equation 7).

        Returns:
            Tensor of shape (batch, output_size) whose rows sum to 1.
        """
        relu = torch.relu(torch.matmul(self.hidden, self.W_y) + self.b_y)
        return torch.softmax(torch.matmul(relu, self.W_p) + self.b_p, dim=-1)

    def forward(self, x_t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance the recurrent state by one step.

        Args:
            x_t: Inputs of shape (batch, input_size).

        Returns:
            ``(s_t, p_t)``: the one-hot regime indicator and the regime
            probabilities, both of shape (batch, output_size).
        """
        batch_size = x_t.size(0)

        if self.hidden is None:
            self.init_hidden(batch_size, x_t.device)

        # Equation 3 (update gate)
        z_t = torch.sigmoid(x_t @ self.W_u + self.hidden @ self.R_u + self.b_u)

        # Equation 4 (reset gate). Batch norm is skipped for a batch of one,
        # presumably because BatchNorm1d cannot compute batch statistics from a
        # single sample in training mode.
        norm_in = x_t @ self.W_r + self.hidden @ self.R_r + self.b_r
        r_t = torch.sigmoid(self.batch_norm(norm_in) if batch_size > 1 else norm_in)

        # Equation 5 (candidate state)
        h_dot = torch.tanh(x_t @ self.W_h + (r_t * self.hidden) @ self.R_h + self.b_h)

        # Equation 6 (interpolate old and candidate state)
        self.hidden = (1 - z_t) * self.hidden + z_t * h_dot

        # Equation 7
        p_t = self.compute_state_probabilities()

        # Equation 8
        s_t = self.ste(p_t)

        return s_t, p_t

    def eval_predict(self, x: torch.Tensor, start: Optional[int] = None) -> torch.Tensor:
        """Run the model step by step over a series and collect one-hot regimes.

        The recurrent state is reset first. At step ``t`` the model receives
        ``x[:, t - input_size:t]``.

        Args:
            x: Series indexed as ``x[:, t]``. Only the first row's prediction is
                stored (assumes a batch of one; TODO confirm for other callers).
            start: First time index to predict. Defaults to the global ``t_min``.

        Returns:
            CPU tensor of shape ``(x.size(1) - start, output_size)`` (empty if
            the series is not longer than ``start``). Row ``t - start`` is the
            one-hot regime for step ``t``.
        """
        if start is None:
            start = t_min

        self.reset_hidden()

        data_length = x.size(1)
        n_regimes = self.W_p.size(1)
        out = torch.zeros(max(data_length - start, 0), n_regimes)

        with torch.no_grad():
            for t_pred in range(start, data_length):
                # forward returns (one-hot state, probabilities); keep the state.
                s_t, _ = self(x[:, t_pred - self.input_size:t_pred])
                out[t_pred - start] = s_t[0].cpu()

        return out


class GRSTULoss(nn.Module):
    """Per-sample GRSTU training loss with periodically re-estimated Gaussians.

    Each regime ``k`` has a univariate Gaussian emission ``N(means[k], vars[k])``
    held in buffers.

    ``forward`` combines three terms:
    - the negative Gaussian log-likelihood of the last ``J`` observations of
      each window (Equation 9);
    - an entropy bonus (Equations 13-14);
    - a jump penalty on consecutive probability vectors (Equations 12 and 15).

    ``update_gaussian_params`` re-estimates the emissions from hard state
    assignments (Equations 10-11).

    Args:
        data: Observation series. Every regime is initialized with its overall
            sample mean and variance.
    """

    def __init__(self, data: np.ndarray):
        super().__init__()

        # Equation 15 evaluated at e=0
        self.jump_penalty = lambda_1 * np.exp(-p)

        self.register_buffer('means', torch.full((K,), data.mean()))
        self.register_buffer('vars', torch.full((K,), data.var()))

    def update_gaussian_params(self, data: torch.Tensor, e: int, s_t: torch.Tensor,
                               start: Optional[int] = None):
        """Re-estimate each regime's Gaussian from hard state assignments.

        Also advances the jump-penalty schedule (Equation 15) to epoch ``e``.
        A regime keeps its previous mean when it received no points, and its
        previous variance when it received fewer than two points.

        Args:
            data: Full observation series (assumed 1-D, on the same device as
                ``s_t``).
            e: Current epoch.
            s_t: States, one row per consecutive dataset item in dataset order.
                The regime is the arg-max of each row.
            start: Index in ``data`` of the observation paired with ``s_t[0]``.
                Defaults to ``t_min - 1``: the training dataset starts at
                ``t_min``, and item ``i`` is the window ending just before
                ``t_min + i``, so its most recent observation is
                ``data[t_min - 1 + i]``. Taking the last ``len(s_t)`` points of
                ``data`` instead would be shifted by the tail that a
                ``drop_last`` loader discards.

        Raises:
            ValueError: if ``data`` has fewer than ``len(s_t)`` points from ``start``.
        """
        # Equation 15
        self.jump_penalty = lambda_1 * np.exp((p * e) / E_2 - p) if e < E_2 else lambda_1

        if start is None:
            start = t_min - 1

        new_means = self.means.clone()
        new_vars = self.vars.clone()

        k_t = torch.argmax(s_t, dim=1)
        window_data = data[start:start + s_t.size(0)]
        if window_data.size(0) != k_t.size(0):
            raise ValueError(
                f'Expected {k_t.size(0)} observations starting at index {start}, '
                f'but data only has {window_data.size(0)} there'
            )

        for i in range(K):
            in_regime = k_t == i
            N_i = torch.sum(in_regime)

            if N_i > 1:
                # Equation 10
                new_means[i] = torch.mean(window_data[in_regime])
                # Equation 11
                new_vars[i] = torch.var(window_data[in_regime], unbiased=True)
            elif N_i == 1:
                # A single point gives a mean but no usable sample variance.
                new_means[i] = torch.mean(window_data[in_regime])

        # Keep regime indices stable across updates.
        self._prevent_label_permutation(self.means.clone(), self.vars.clone(), new_means, new_vars)

    def _prevent_label_permutation(self, old_means, old_vars, new_means, new_vars) -> None:
        """Store the new estimates so that regime indices keep their meaning.

        For each old regime ``i`` in order, greedily pick the not-yet-used new
        estimate ``j`` with the smallest KL(old_i || new_j). Write it to the
        buffers at index ``i``.
        """
        available = list(range(K))

        for i in range(K):
            # Start from a still-available estimate. With NaN/inf divergences no
            # comparison succeeds, and the fallback must not be an index that
            # was already consumed (``remove`` would raise ValueError).
            best_j = available[0]
            min_kl = np.inf

            # Find the remaining new distribution closest to old distribution i
            for j in available:
                kl = gaussian_kl_divergence(old_means[i], old_vars[i], new_means[j], new_vars[j])

                if kl < min_kl:
                    min_kl = kl
                    best_j = j

            # Ensure that this distribution is not used again
            available.remove(best_j)

            self.means[i] = new_means[best_j]
            self.vars[i] = new_vars[best_j]

    def forward(self, e: int, v_t: torch.Tensor, s_t: torch.Tensor, p_t: torch.Tensor,
                p_prev: Optional[torch.Tensor]) -> torch.Tensor:
        """Per-sample loss for a batch of windows.

        Args:
            e: Current epoch (selects the entropy weight, Equation 14).
            v_t: Observation windows, shape (batch, window). Only the last
                ``J`` columns enter the likelihood.
            s_t: One-hot regime per window, shape (batch, K).
            p_t: Regime probabilities per window, shape (batch, K).
            p_prev: Probabilities of the preceding step for each row, or None to
                use ``p_t`` (zero jump). No gradient flows through it.

        Returns:
            Tensor of shape (batch,): negative mean log-likelihood, minus the
            weighted entropy, plus the jump penalty.
        """
        if p_prev is None:
            p_prev = p_t

        # Emission parameters of each row's selected regime.
        mean = (s_t @ self.means).unsqueeze(1)
        var = (s_t @ self.vars).unsqueeze(1)

        log_likelihood = gaussian_log_likelihood(v_t[:, -J:], mean, var)

        # Equation 13 (zero probabilities contribute 0 instead of NaN)
        entropy = -torch.sum(p_t * torch.log(torch.where(p_t > 0, p_t, 1.0)), dim=1)

        # Equation 14
        beta_e = C_1 if e < E_1 else 0.0

        # Equation 12
        jump_loss = self.jump_penalty * torch.norm(p_t - p_prev.detach(), p=1, dim=1)
        return -log_likelihood.mean(dim=1) - beta_e * entropy + jump_loss


class TimeSeriesWindowDataset(Dataset):
    """Map-style dataset of fixed-length look-back windows over a series.

    Item ``idx`` is ``time_series[start + idx - window_size : start + idx]``.
    That is the ``window_size`` entries (along the first dimension) strictly
    before position ``start + idx``. Items run until that end position reaches
    ``len(time_series)``.
    """

    def __init__(self, time_series: torch.Tensor, window_size: int, start: int):
        self.time_series = time_series
        self.window_size = window_size
        # First and one-past-last window end positions.
        self.t_min = start
        self.t_max = len(time_series)

    def __len__(self):
        return self.t_max - self.t_min

    def __getitem__(self, idx):
        stop = self.t_min + idx
        begin = stop - self.window_size
        return self.time_series[begin:stop]


def _previous_probabilities(last_probs: Optional[torch.Tensor],
                            current_probs: torch.Tensor) -> torch.Tensor:
    """Shift a batch of consecutive probability rows back by one step.

    Row 0 of the result is ``last_probs`` (the last row of the previous batch).
    When there is no previous batch, row 0 is paired with itself (zero jump).
    """
    if last_probs is None:
        last_probs = current_probs[0].detach()
    return torch.cat([last_probs.unsqueeze(0), current_probs[:-1]])


def train_grstu(data: np.ndarray):
    """Train a GRSTU regime model on a return series.

    Every ``U`` epochs the Gaussian emissions are re-estimated from the epoch's
    training states. A no-grad evaluation pass is then run, and its states and
    emission parameters are kept if its mean loss is the lowest so far.

    Args:
        data: Observation series (converted to a float tensor on ``device``).

    Returns:
        dict with:
        - ``'model'``: the trained model;
        - ``'best_states'``: numpy one-hot states from the best evaluation pass;
        - ``'means'`` and ``'vars'``: the emission parameters saved with those
          states;
        - ``'losses'``: mean training loss per epoch.

    Raises:
        ValueError: if ``data`` is too short to form a single full batch.
    """
    time_series = torch.tensor(data, dtype=torch.float, device=device)
    dataset = TimeSeriesWindowDataset(time_series, L, t_min)
    dataloader = DataLoader(dataset, batch_size=B, shuffle=False, drop_last=True)
    if len(dataloader) == 0:
        raise ValueError(
            f'train_grstu needs at least {t_min + B} observations, got {len(time_series)}'
        )

    model = GRSTU(input_size=L, hidden_size=H, output_size=K).to(device)
    loss_fn = GRSTULoss(data).to(device)
    optimizer = optim.Adam(model.parameters(), lr=nu)

    # Training variables
    best_loss = np.inf
    best_states = None
    best_means = None
    best_vars = None

    loss_history = []
    training_loop = tqdm(range(epochs))

    for e in training_loop:
        model.train()

        epoch_losses = []
        states = []
        last_probs = None

        # Each batch consists of consecutive windows t, t+1, ..., t+B-1 (no shuffling)
        for v_t in dataloader:
            current_state, current_probs = model(v_t)

            # The previous step of row 0 is the last row of the previous batch.
            prev_probs = _previous_probabilities(last_probs, current_probs)

            loss = loss_fn(e, v_t, current_state, current_probs, prev_probs).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            last_probs = current_probs[-1].detach()

            # Truncate backpropagation through time at the batch boundary.
            model.hidden = model.hidden.detach()

            optimizer.step()

            epoch_losses.append(loss.item())
            # Detached: the states are only used without grad (Gaussian update),
            # so there is no reason to keep every batch's autograd graph alive.
            states.append(current_state.detach())

        loss_history.append(np.mean(epoch_losses))

        model.reset_hidden()

        # Update Gaussian parameters every U epochs
        if e % U == 0:
            model.eval()

            with torch.no_grad():
                loss_fn.update_gaussian_params(time_series, e, torch.cat(states, dim=0))

                last_probs = None
                eval_losses = []
                eval_states = []

                # Collect predictions for all data
                for v_t in dataloader:
                    current_state, current_probs = model(v_t)
                    prev_probs = _previous_probabilities(last_probs, current_probs)

                    loss = loss_fn(e, v_t, current_state, current_probs, prev_probs).mean()

                    last_probs = current_probs[-1].detach()

                    eval_losses.append(loss.item())
                    eval_states.append(current_state)

                model.reset_hidden()

                # Save states prediction with the lowest evaluation loss
                eval_loss = np.mean(eval_losses)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    best_states = torch.cat(eval_states, dim=0).cpu().numpy()
                    best_means = loss_fn.means.clone().cpu().numpy()
                    best_vars = loss_fn.vars.clone().cpu().numpy()

        training_loop.set_postfix(loss=loss_history[-1])

    return {
        'model': model,
        'best_states': best_states,
        'means': best_means,
        'vars': best_vars,
        'losses': loss_history,
    }


# Fix the NumPy and PyTorch RNGs so training is repeatable.
np.random.seed(42)
torch.manual_seed(42)

# Load the S&P 500 data and train on its return series.
sp500 = pd.read_csv('sp500.csv', parse_dates=['date'])
data = sp500['returns'].values
results = train_grstu(data)

# Mean training loss per epoch.
_, ax = plt.subplots(figsize=(8, 4))
ax.plot(results['losses'], label='Training Loss')
ax.set(title='Training Loss', xlabel='Epoch', ylabel='Loss')
plt.show()

In [None]:
def plot_regimes(dates, prices, returns, regime_labels, means, vars):
    """Plot prices and returns, shading the background by detected regime.

    Args:
        dates: x-axis values. Must be indexable by position and have one
            entry per element of ``regime_labels``.
        prices: Values for the top panel.
        returns: Values for the bottom panel.
        regime_labels: Integer regime index for each position (array-like).
        means: Per-regime mean, shown in the legend. Its length sets how many
            regimes are drawn.
        vars: Per-regime variance; its square root is shown as sigma.
    """
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 9), sharex=True)

    ax1.plot(dates, prices, alpha=0.7, label='Price')
    ax1.set_ylabel('Price')
    ax1.set_title('Regime Detection')

    ax2.plot(dates, returns, alpha=0.7, label='Returns')
    ax2.set_ylabel('Returns (%)')
    ax2.set_xlabel('Date')

    # Shading colors. The first two match the 2-regime palette, and the list is
    # cycled so that any number of regimes can be drawn.
    colors = ['g', 'r', 'b', 'orange', 'purple']
    regime_labels = np.asarray(regime_labels)

    for i in range(len(means)):
        mask_indices = np.where(regime_labels == i)[0]
        if mask_indices.size == 0:
            continue

        # Split the positions into runs of consecutive time steps.
        segments = np.split(mask_indices, np.where(np.diff(mask_indices) != 1)[0] + 1)
        color = colors[i % len(colors)]
        regime_label = f'Regime {i + 1}: μ={means[i]:.4f}, σ={np.sqrt(vars[i]):.4f}'

        for seg_no, segment in enumerate(segments):
            start_idx = segment[0]
            end_idx = segment[-1]
            # Extend the span to the next time step, when there is one. This
            # gives single-step runs a visible width and lets adjacent runs of
            # different regimes meet without an unshaded gap.
            span_end = dates[end_idx + 1] if end_idx + 1 < len(dates) else dates[end_idx]
            # Label only the first run so each regime appears once in the legend.
            label = regime_label if seg_no == 0 else ''

            ax1.axvspan(dates[start_idx], span_end, alpha=0.4, color=color, label=label)
            ax2.axvspan(dates[start_idx], span_end, alpha=0.4, color=color)

    # Legend on the top panel only (price line plus one entry per regime).
    ax1.legend(loc='upper left')

    plt.tight_layout()
    plt.show()


dates = sp500['date'].values if 'date' in sp500.columns else np.arange(len(data))

prices = sp500['close'].values if 'close' in sp500.columns else np.cumsum(data)

regime_labels = np.argmax(results['best_states'], axis=1)

# The training dataset starts at index t_min, and item i is the window that
# ends just before t_min + i. Predicted state i therefore belongs to
# observation t_min - 1 + i, and the plotted series must start there
# (not at L).
state_offset = t_min - 1
aligned = slice(state_offset, state_offset + len(regime_labels))

# Plot results with dates
plot_regimes(
    dates[aligned],
    prices[aligned],
    data[aligned],
    regime_labels,
    results['means'],
    results['vars'],
)

In [None]:
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

y_pred = np.argmax(results['best_states'], axis=1)

# Prediction i belongs to observation t_min - 1 + i (the training dataset
# starts at t_min and each window ends just before its index), so the ground
# truth must be taken from the same offset rather than from index 0.
# NOTE(review): predicted regime indices are not guaranteed to use the same
# numbering as the 'states' column; a low score may just mean swapped labels.
state_offset = t_min - 1
y_test = sp500['states'].values[state_offset:state_offset + len(y_pred)]

print('Balanced Accuracy:', balanced_accuracy_score(y_test, y_pred))

confusion_matrix(y_test, y_pred)