In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.autograd import Function

sns.set(context='paper')

# Seed PyTorch's global RNG so that the random weight initialisation (torch.randn in
# GRSTU.__init__) gives the same result on every Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

K = 2  # Number of hidden states

# Hyperparameters for 2 states (/3 states), defined in exhibit 2
H = 32  # GRSTU hidden state size
grad_clip = 0.001  # Gradient clipping threshold (NOTE(review): not referenced by the cells shown here)
B = 64  # Batch size (NOTE(review): not referenced by the cells shown here)
R = 20  # (/80) Length of truncated backpropagation (NOTE(review): not referenced by the cells shown here)
L = 30  # Window of past observations fed to the model
J = 10  # Parameter for the likelihood loss; the loss in Eq (12) is scaled by -1/J
nu = 0.005  # Learning rate
U = 30  # Frequency (in epochs) at which the distributions are updated
C_1 = 0.1  # Value of the entropy weight beta_e while e < E_1, Eq (14)
E_1 = 500  # (/400) Epoch from which beta_e is 0, Eq (14)
E_2 = 250  # (/250) Epoch from which lambda_e is held at lambda_1, Eq (15)
epochs = 750  # Number of epochs for which the model is trained
p = 7  # Rate of the exponential ramp of lambda_e, Eq (15)
lambda_1 = 3  # Value of lambda_e once e >= E_2, Eq (15)

In [None]:
class STE(Function):
    """Straight-through estimator: hard one-hot arg-max forward, identity backward.

    The forward pass replaces `x` by a one-hot tensor marking the arg-max along
    dimension 0. The backward pass hands the incoming gradient through unchanged,
    so gradients reach whatever produced `x` even though arg-max itself has no
    useful gradient.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tensor shaped like `x` with 1.0 at the arg-max along dim 0, 0.0 elsewhere.

        Args:
            ctx: Autograd context (unused, nothing has to be saved for backward).
            x: Tensor whose dimension 0 indexes the candidates, e.g. shape `(K,)`
                or `(K, batch)`.
        """
        # Returns a one-hot encoded vector (definition of s_t in bottom of page 5).
        # keepdim=True keeps the index tensor at the same rank as `x`, which is what
        # scatter_ expects; this makes inputs with trailing dimensions work too.
        index = x.argmax(dim=0, keepdim=True)
        return torch.zeros_like(x).scatter_(0, index, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        """Transparent backward pass: the gradient w.r.t. `x` is `grad_output` itself."""
        return grad_output


# First time index visited by the training loop (`for t in range(t_min, T)`).
# The loop feeds the model `x[t-L:t]`, so t_min has to be at least L for that
# slice to start at a non-negative index.
t_min = 50  # TODO define t_min


class GRSTU(nn.Module):
    """GRU-style recurrent cell with a softmax head over K states.

    One call to `step` updates the hidden state from an input window (Eq 3-6),
    maps it to state probabilities (Eq 7) and discretises those probabilities
    into a one-hot state with the straight-through estimator `STE`.

    Args:
        init_std: Standard deviation of the normal draws used for all weight matrices.
        input_size: Length of the input vector `v_t`; defaults to the notebook constant `L`.
        hidden_size: Length of the hidden state; defaults to the notebook constant `H`.
        n_states: Number of states in the probability head; defaults to the notebook constant `K`.
    """

    # TODO ensure optimal value for init_std (use Xavier initialization?)
    def __init__(self, init_std=0.02, input_size=None, hidden_size=None, n_states=None):
        super().__init__()

        # Fall back to the notebook-level hyperparameters when no size is passed, so
        # `GRSTU()` builds exactly the same model as before.
        n_in = L if input_size is None else input_size
        n_hidden = H if hidden_size is None else hidden_size
        n_out = K if n_states is None else n_states

        def weight(rows, cols):
            # Normal(0, init_std^2) initialisation.
            return nn.Parameter(torch.randn(rows, cols) * init_std)

        # Creation order is kept (W_* first, then R_*) so that seeded runs draw the
        # same random numbers for the same parameters.
        self.W_u = weight(n_hidden, n_in)
        self.W_r = weight(n_hidden, n_in)
        self.W_h = weight(n_hidden, n_in)
        self.W_y = weight(n_hidden, n_hidden)
        self.W_p = weight(n_out, n_hidden)
        self.R_u = weight(n_hidden, n_hidden)
        self.R_r = weight(n_hidden, n_hidden)
        self.R_h = weight(n_hidden, n_hidden)

        self.b_u = nn.Parameter(torch.zeros(n_hidden))
        self.b_r = nn.Parameter(torch.zeros(n_hidden))
        self.b_h = nn.Parameter(torch.zeros(n_hidden))
        self.b_y = nn.Parameter(torch.zeros(n_hidden))
        self.b_p = nn.Parameter(torch.zeros(n_out))

        # TODO how should BatchNorm be used for a 1D vector?
        # Placeholder: currently a no-op applied to the reset-gate pre-activation only.
        self.batch_norm = nn.Identity()

    def step(self, v_t, h_prev):
        """Advance the cell by one time step.

        Args:
            v_t: 1-D input vector of length `input_size`.
            h_prev: 1-D previous hidden state of length `hidden_size`.

        Returns:
            Tuple `(h_t, p_t, s_t)`: new hidden state, state probabilities and the
            one-hot state produced by `STE` from `p_t`.
        """
        h_t = self.forward_pass(v_t, h_prev)
        p_t = self.compute_state_probabilities(h_t)
        s_t = STE.apply(p_t)

        return h_t, p_t, s_t

    def forward(self, x, h_prev=None):
        """Run a single `step`, so the module can be called as `model(x, h_prev)`.

        Args:
            x: 1-D input vector of length `input_size` (the `v_t` of `step`).
            h_prev: Previous hidden state; a zero vector is used when omitted.

        Returns:
            The `(h_t, p_t, s_t)` tuple returned by `step`.
        """
        if h_prev is None:
            # Zero state with the dtype/device of the parameters.
            h_prev = self.R_u.new_zeros(self.R_u.shape[0])
        return self.step(x, h_prev)

    def forward_pass(self, v_t, h_prev):
        """Compute the new hidden state `h_t` from `v_t` and `h_prev` (Eq 3-6)."""
        z_t = torch.sigmoid(self.W_u @ v_t + self.R_u @ h_prev + self.b_u)  # Eq (3)
        r_t = torch.sigmoid(self.batch_norm(self.W_r @ v_t + self.R_r @ h_prev + self.b_r))  # Eq (4)
        h_dot_t = torch.tanh(self.W_h @ v_t + self.R_h @ (r_t * h_prev) + self.b_h)  # Eq (5)

        # Eq (6): convex combination of the old state and the candidate state.
        return (1 - z_t) * h_prev + z_t * h_dot_t  # h_t

    def compute_state_probabilities(self, h_t):
        """Map the hidden state to a probability vector over the states (Eq 7)."""
        hidden = torch.relu(self.W_y @ h_t + self.b_y)
        return torch.softmax(self.W_p @ hidden + self.b_p, dim=0)  # p_t

In [None]:
def compute_loss(x, t, e, mu_kt, sigma_sq_kt, p_t, p_prev):
    """Loss of a single time step, Eq (12).

    Args:
        x: Observation series; only `x[t]` is read.
        t: Index of the current observation in `x`.
        e: Current epoch, either a Python number or a 0-dim tensor.
        mu_kt: Mean of the Gaussian evaluated at `x[t]`.
        sigma_sq_kt: Variance of that Gaussian.
        p_t: State probabilities of the current step.
        p_prev: State probabilities `p_t` is compared with in the L1 term.

    Returns:
        Scalar tensor `-1/J * (log_gauss - beta_e * entropy + lambda_e * p_norm)`.
    """
    # Accept plain ints/floats as well as tensors (torch.exp below needs a tensor).
    e = torch.as_tensor(e)
    sigma_sq_kt = torch.as_tensor(sigma_sq_kt)

    p_norm = torch.sum(torch.abs(p_t - p_prev))

    # Eq (13). Clamp before the log: a probability that is exactly 0 would otherwise
    # give 0 * log(0) = NaN and poison the loss and its gradients.
    smallest = torch.finfo(p_t.dtype).tiny
    entropy = -torch.sum(p_t * torch.log(p_t.clamp_min(smallest)))

    beta_e = C_1 if e < E_1 else 0.0  # Eq (14)
    lambda_e = lambda_1 * torch.exp((p * e) / E_2 - p) if e < E_2 else lambda_1  # Eq (15)

    # Eq (9): log-density of N(mu_kt, sigma_sq_kt) at x[t].
    log_gauss = -0.5 * torch.log(2 * torch.pi * sigma_sq_kt) - 0.5 * (x[t] - mu_kt) ** 2 / sigma_sq_kt

    # NOTE(review): with the leading minus applied to the whole bracket, minimising this
    # value penalises entropy and rewards a large |p_t - p_prev|. Confirm the signs of the
    # two regularisers against Eq (12) of the paper; they are left as originally written.
    return -1 / J * torch.sum(log_gauss - beta_e * entropy + lambda_e * p_norm)  # Eq (12)

In [None]:
from torch import optim
from tqdm import tqdm

x = torch.tensor(pd.read_csv('simulated_hmm_500.csv')['returns_3'].values, dtype=torch.float32)
T = len(x)

# The model is fed x[t-L:t]; a start index below 0 would silently wrap around, and
# t_min >= T would leave nothing to train on.
assert L <= t_min < T, f"need L <= t_min < T, got L={L}, t_min={t_min}, T={T}"

# Initialize distributions (exhibit 3): every state starts from the sample mean/variance.
mu = torch.full((K,), x.mean().item())
sigma_square = torch.full((K,), x.var().item())

model = GRSTU()
optimizer = optim.Adam(model.parameters(), lr=nu)

# NOTE(review): `grad_clip`, `B` and `R` from the config cell are not applied anywhere in
# this loop (one optimizer step per time step, no clipping) - confirm whether intended.
# NOTE(review): `mu[k_t]` / `sigma_square[k_t]` are selected with an integer index, so the
# likelihood term carries no gradient back to the model through `s_t` - confirm.

losses = []  # mean per-step loss of each epoch
training_loop = tqdm(range(epochs))
x_train = x[t_min:]
n_steps = T - t_min

for e in training_loop:
    model.train()

    h_t = torch.zeros(H)
    p_prev = torch.zeros(K)
    k = torch.zeros(n_steps, dtype=torch.long)  # state assigned to each time step
    epoch_tensor = torch.tensor(e)  # built once per epoch instead of once per step
    epoch_loss = 0.0

    for t in range(t_min, T):
        h_t, p_t, s_t = model.step(x[t-L:t], h_t)
        k_t = torch.argmax(s_t)
        loss = compute_loss(x, t, epoch_tensor, mu[k_t], sigma_square[k_t], p_t, p_prev)

        k[t - t_min] = k_t

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Cut the graph between steps; p_prev must track the previous step's
        # probabilities, otherwise the switching term always compares against zeros.
        h_t = h_t.detach()
        p_prev = p_t.detach()
        epoch_loss += loss.item()

    losses.append(epoch_loss / n_steps)

    if e % U == 0:
        for i in range(K):
            x_i = x_train[k == i]
            # A state with fewer than two observations has no (unbiased) variance and an
            # empty one has no mean; keep its previous parameters instead of writing NaN.
            if x_i.numel() < 2:
                continue
            var_i = x_i.var()
            # A zero variance would make the Gaussian log-density infinite.
            if var_i > 0:
                mu[i] = x_i.mean()  # Eq (10)
                sigma_square[i] = var_i  # Eq (11)

    training_loop.set_postfix(loss=losses[-1])

In [None]:
# Line plot of the values appended to `losses` (one entry per epoch).
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses, label='Training Loss')
ax.set(title='Training Loss', xlabel='Epoch', ylabel='Loss')
plt.show()