In [1]:
from torch import optim
from lib.datasets import get_stock_price,sample_indices,train_test_split
from lib.aug import apply_augmentations,parse_augmentations
import torch
from torch import nn
import torch.nn.functional as F
from typing import List
from lib.utils import sample_indices


In [2]:
# Settings handed to get_stock_price() to locate and window the price series.
data_config = {
    "ticker": "^GSPC",
    "interval": "1d",
    "column": 1,  # NOTE(review): presumably the index of the price column to use - confirm in lib.datasets
    "window_size": 20,  # length of each rolled window (dim 1 of the data tensor)
    "dir": "datasets",
    "subdir": "stock",
}

# Settings for the path augmentations applied before the encoder.
sig_config = {
    "augmentations": [
        {"name": "AddTime"},
        {"name": "LeadLag"},
    ],
    # Fall back to the CPU so the notebook still runs on machines without a GPU.
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "depth": 4,
}

In [3]:
# Load the rolled price windows and split them into train / test tensors.
tensor_data = get_stock_price(data_config)
x_real_train, x_real_test = train_test_split(tensor_data, train_test_ratio=0.8, device=sig_config["device"])

# NOTE(review): the raw augmentation config is overwritten with the parsed
# objects, so re-running this cell feeds already-parsed augmentations back into
# parse_augmentations - re-run the config cell first, or confirm that is safe.
if sig_config["augmentations"] is not None:
    sig_config["augmentations"] = parse_augmentations(sig_config["augmentations"])
print("Before augmentation shape:",x_real_train.shape)
if sig_config["augmentations"] is not None:
    # apply_augmentations prints the tensor shape after each augmentation.
    x_aug_sig = apply_augmentations(x_real_train,sig_config["augmentations"])
else:
    # No augmentation configured: train on the raw windows instead of leaving
    # x_aug_sig / input_dim undefined (which raised NameError just below).
    x_aug_sig = x_real_train
# Input dimension of the encoder: each sample is flattened to (dim 1 * dim 2).
input_dim = x_aug_sig.shape[1]*x_aug_sig.shape[2]
print("After augmentation shape:",x_aug_sig.shape)
x_aug_sig = x_aug_sig.to(sig_config["device"])

Rolled data for training, shape torch.Size([1232, 20, 1])
Before augmentation shape: torch.Size([985, 20, 1])
torch.Size([985, 20, 2])
torch.Size([985, 39, 4])
After augmentation shape: torch.Size([985, 39, 4])


In [4]:
class BetaVAE(nn.Module):
    """β-VAE operating on flattened input tensors.

    Two independent MLP encoders map the flattened input to the latent mean
    and the latent log-variance; one MLP decoder maps a latent sample back to
    the flattened input dimension.

    Args:
        x_aug_sig: training tensor stored on the instance so that a training
            loop can draw minibatches from it (indexed along dim 0).
        epoch: number of training iterations, stored for the training loop.
        batch_size: minibatch size, stored for the training loop.
        hidden_dims: ``[input_dim, hidden_dim, latent_dim]``. Only the first
            three entries are read.
        device: device the sub-networks are moved to at construction time.
    """

    def __init__(self, x_aug_sig, epoch, batch_size, hidden_dims: List, device) -> None:
        super().__init__()

        self.x_aug_sig = x_aug_sig
        print("Input tensor shape: {}".format(x_aug_sig.shape))
        self.epoch = epoch
        self.batch_size = batch_size
        self.device = device

        # Only the first three entries are used: input, hidden and latent width.
        input_dim, hidden_dim, latent_dim = hidden_dims[0], hidden_dims[1], hidden_dims[2]

        # NOTE(review): every head below ends in LeakyReLU, so the latent mean,
        # the log-variance and the reconstruction all pass through a final
        # non-linearity. Kept as-is to preserve the architecture - confirm this
        # is intended rather than a plain linear output layer.
        self.encoder_mu = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(),
        )
        self.encoder_sigma = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.LeakyReLU(),
        )

        # Move all three sub-networks to the requested device.
        self.to(device)

    def encode(self, x):
        """Encode a batch and draw a latent sample via the reparameterisation trick.

        Args:
            x: input batch; every dimension after the first is flattened.

        Returns:
            Tuple ``(mean, log_var, z)``, each of shape ``[batch, latent_dim]``.
            ``log_var`` is clamped to ``[-10, 10]``.
        """
        # reshape (not view) so non-contiguous batches are accepted as well.
        x_flatten = x.reshape(x.shape[0], -1)
        mean = self.encoder_mu(x_flatten)
        # Clip the log-variance so exp() below cannot overflow or vanish.
        log_var = torch.clamp(self.encoder_sigma(x_flatten), min=-10, max=10)
        # Sample the noise directly on mean's device/dtype: avoids a host-to-device
        # copy per step and stays correct if the model is moved after construction.
        noise = torch.randn_like(mean)
        z = mean + torch.exp(0.5 * log_var) * noise
        return mean, log_var, z

    def decode(self, z):
        """Map latent samples ``z`` back to the flattened input dimension."""
        return self.decoder(z)

    def loss(self, mean, log_var, sample_data, reconstructed_data, beta=1.0):
        """Compute the β-VAE loss for one batch.

        Args:
            mean: latent mean, ``[batch_size, latent_dim]``.
            log_var: latent log-variance, ``[batch_size, latent_dim]``.
            sample_data: input batch; flattened here to ``[batch_size, input_dim]``.
            reconstructed_data: decoder output, ``[batch_size, input_dim]``.
            beta: weight of the KL term (the β parameter).

        Returns:
            Scalar tensor ``recon_loss + beta * kl_loss``, where both terms are
            summed over features and averaged over the batch.
        """
        # Flatten sample_data if it is not already flat.
        sample_data = sample_data.reshape(sample_data.size(0), -1)
        batch_size = sample_data.size(0)

        # 1. Reconstruction loss: summed squared error, averaged over the batch.
        recon_loss = F.mse_loss(reconstructed_data, sample_data, reduction='sum') / batch_size

        # 2. Analytical KL divergence D_KL(N(μ, σ^2) || N(0, 1)), batch-averaged.
        kl_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) / batch_size

        # Total loss = reconstruction + β * KL
        return recon_loss + beta * kl_loss

    def generate(self, x: torch.Tensor):
        """Encode ``x``, sample a latent ``z`` and return its flattened reconstruction."""
        _, _, z = self.encode(x)
        return self.decode(z)

In [5]:
def train(model, optimizer, early_stop=500, log_every=500):
    """Train ``model`` on random minibatches drawn from ``model.x_aug_sig``.

    Each iteration samples one minibatch, so ``model.epoch`` is the number of
    optimisation steps rather than the number of passes over the data.

    Args:
        model: object exposing ``x_aug_sig``, ``batch_size``, ``epoch``,
            ``device``, ``encode``, ``decode`` and ``loss``.
        optimizer: optimizer already bound to the model's parameters.
        early_stop: stop once the minibatch loss has failed to reach a new
            minimum for more than this many consecutive iterations.
        log_every: print the loss every ``log_every`` iterations (must be > 0).
    """
    stale_steps = 0
    min_loss = float('inf')
    for i in range(model.epoch):
        # Sample indices of size equal to the batch size from model.x_aug_sig.
        # The third argument follows the model's device instead of a hard-coded
        # "cuda" (assumed to be the device of the returned indices - see
        # lib.utils.sample_indices).
        time_indices = sample_indices(model.x_aug_sig.shape[0], model.batch_size, model.device)
        sample_data = model.x_aug_sig[time_indices]

        # Encode, then decode the sampled latent.
        mean, log_var, z = model.encode(sample_data)
        reconstructed_data = model.decode(z)

        # model.loss flattens sample_data itself, so it is passed as-is.
        loss = model.loss(mean, log_var, sample_data, reconstructed_data)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        if i % log_every == 0:
            print("Epoch {} loss {}".format(i, loss_value))

        # Early stopping on the (noisy) per-minibatch loss.
        if loss_value < min_loss:
            min_loss = loss_value
            stale_steps = 0
        else:
            stale_steps += 1
            if stale_steps > early_stop:
                break

In [6]:
# Seed torch's RNG so weight initialisation is reproducible across runs.
torch.manual_seed(0)

lr = 1e-4
batch_size = 200
epoch = 20001
# [flattened input width, hidden width, latent width]
hidden_dims = [input_dim,12,3]
# Build the model on the same device the data was moved to, rather than a
# separately hard-coded 'cuda'.
model = BetaVAE(x_aug_sig=x_aug_sig,epoch=epoch,batch_size=batch_size,hidden_dims=hidden_dims,device=sig_config["device"])
print(model)
optimizer = torch.optim.Adam(model.parameters(),lr=lr)

Input tensor shape: torch.Size([985, 39, 4])
BetaVAE(
  (encoder_mu): Sequential(
    (0): Linear(in_features=156, out_features=12, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=12, out_features=3, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
  )
  (encoder_sigma): Sequential(
    (0): Linear(in_features=156, out_features=12, bias=True)
    (1): Tanh()
    (2): Linear(in_features=12, out_features=3, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
  )
  (decoder): Sequential(
    (0): Linear(in_features=3, out_features=12, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=12, out_features=156, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
  )
)


In [7]:
# Run the training loop defined above with the model and its Adam optimizer.
train(model, optimizer)

Epoch 0 loss 1463627392.0
Epoch 500 loss 434325632.0
Epoch 1000 loss 127512672.0
Epoch 1500 loss 65280000.0
Epoch 2000 loss 44557572.0
Epoch 2500 loss 32601516.0
Epoch 3000 loss 25540910.0
Epoch 3500 loss 20314860.0
Epoch 4000 loss 16063943.0
Epoch 4500 loss 13197256.0
Epoch 5000 loss 10956188.0
Epoch 5500 loss 8774265.0
Epoch 6000 loss 7202094.5
Epoch 6500 loss 5849141.0
Epoch 7000 loss 4701825.0
Epoch 7500 loss 3961849.5
Epoch 8000 loss 3127251.25
Epoch 8500 loss 2629494.5
Epoch 9000 loss 2347202.0
Epoch 9500 loss 2043580.5
Epoch 10000 loss 1808552.125
Epoch 10500 loss 1769922.625
Epoch 11000 loss 1624679.75
Epoch 11500 loss 1464341.875
Epoch 12000 loss 1426667.0
Epoch 12500 loss 1340772.25
Epoch 13000 loss 1284507.0
Epoch 13500 loss 1224837.375
Epoch 14000 loss 1146381.0
Epoch 14500 loss 1310418.875
Epoch 15000 loss 1276603.5
Epoch 15500 loss 1247635.875
