In [4]:
# Library
# Add the parent directory to the module search path; the local modules imported
# right below (hdf5_loader, myconfig, subclass) presumably live there — TODO confirm.
import sys, os
sys.path.append(os.path.abspath('..'))

# Project-local modules.
# NOTE(review): the star import hides where names come from; ST, ED and THZ used in
# later cells are not defined in this notebook and presumably come from myconfig.
from hdf5_loader import StockDatasetHDF5
from myconfig import *
import subclass as sc

# Third-party / standard-library tooling.
import pandas as pd
import numpy as np
import seaborn as sns
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
# NOTE(review): os is imported a second time here (already imported at the top of this cell).
import os, shutil, wandb
from itertools import permutations

# PyTorch.
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader
import torch.nn.functional as F
import torchsummary

# np.set_printoptions(precision=4, suppress=True, linewidth=120)
# Print tensors in fixed-point notation (no scientific notation) with 4 decimals.
torch.set_printoptions(sci_mode=False, precision=4)
# _ = plt.tight_layout()

### Model

#### Preparing

In [5]:
import models.encdec as encdec
import models.mybuffer as buf

# Ticker symbols handed to the dataset loader.
ticker_list = ['AAPL', 'MSFT', 'GOOGL', 'META', 'IBM', 'INTC']

# [start, end] pair built from the ST / ED names (not defined in this notebook).
date_range = [ST, ED]

# Every entry of THZ is mapped to the same value, 128.
hz_dim = dict.fromkeys(THZ, 128)
targ_hz = '5m'

# One weight per THZ entry, paired in order.
label_weight = dict(zip(THZ, [0.1, 0.3, 0.5, 0.1, 0]))
batch_size = 3

In [6]:
# Re-execute the project modules so that edits made on disk after their first
# import are picked up without restarting the kernel.
# Each result is assigned to `_`, so the cell ends in an assignment rather than
# a displayed expression.
# Caveat: objects created before a reload keep referring to the old definitions,
# so cells that build instances from these modules must be re-run afterwards.
import importlib
_ = importlib.reload(encdec)
_ = importlib.reload(sc)
_ = importlib.reload(buf)

In [7]:
# Build the dataset object for the selected tickers and date range.
# NOTE(review): StockDatasetHDF5 is also imported directly from hdf5_loader in the
# first cell, but is reached through the `sc` module here.
hdf5_inst = sc.StockDatasetHDF5(ticker_list, date_range)
# Sample source that is later passed to sc.batch_maker in the buffer-filling loop.
# tensor=True presumably makes the samples torch tensors — TODO confirm in subclass.
envgen = sc.get_samples(hdf5_inst, hz_dim, targ_hz, tensor=True)

#### Modeling

In [8]:
# Buffer that the loop further down fills with (q, rel_charts) pairs.
# presumably 1000 is the capacity; the role of batch_size is defined in
# models/mybuffer — TODO confirm both.
buffer = buf.mybuffer(1000, batch_size)

In [9]:
def kl_divergence_multivariate(x, y):
    """
    Pairwise KL divergence KL(x_i || y_j) between diagonal Gaussians.

    x: indexable with x[0] = mu and x[1] = sigma, each (batch_x, latent_dim)
    y: indexable with y[0] = mu and y[1] = sigma, each (batch_y, latent_dim)

    Note: mu / sigma are picked along the FIRST axis (e.g. a (mu, sigma) pair),
    unlike hellinger_distance, which expects tensors laid out as
    (batch, 2, outdim). A tensor in that layout must be rearranged before
    being passed here.

    Returns:
        (batch_x, batch_y) matrix of divergences summed over latent_dim
    """
    # Insert broadcast axes: x varies along rows (axis 0), y along columns (axis 1)
    mu_p, sigma_p = (part.unsqueeze(1) for part in (x[0], x[1]))  # (batch_x, 1, latent_dim)
    mu_q, sigma_q = (part.unsqueeze(0) for part in (y[0], y[1]))  # (1, batch_y, latent_dim)

    # Closed-form KL between two univariate normals, evaluated per latent dimension
    log_ratio = torch.log(sigma_q / sigma_p)
    scaled_sq = (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
    per_dim = log_ratio + scaled_sq - 0.5

    # Sum over the latent dimension to obtain the final pairwise matrix
    return per_dim.sum(dim=2)

def hellinger_distance(x, y):
    """
    Pairwise Hellinger distance between two batches of Gaussians.

    The distance is evaluated independently for every output dimension
    (univariate normal closed form) and then averaged over outdim.

    x: (batch_x, 2, outdim)  # [mu, sigma]
    y: (batch_y, 2, outdim)  # [mu, sigma]

    Returns:
        H: (batch_x, batch_y) Hellinger distance matrix
    """
    mu_x, sigma_x = x[:, 0, :], x[:, 1, :]
    mu_y, sigma_y = y[:, 0, :], y[:, 1, :]

    # Add broadcast axes so every x row is compared with every y row
    mu_x = mu_x.unsqueeze(1)  # (batch_x, 1, outdim)
    sigma_x = sigma_x.unsqueeze(1)

    mu_y = mu_y.unsqueeze(0)  # (1, batch_y, outdim)
    sigma_y = sigma_y.unsqueeze(0)

    # Bhattacharyya coefficient of two univariate normals: term1 * term2
    term1 = torch.sqrt(2 * sigma_x * sigma_y) / torch.sqrt(sigma_x**2 + sigma_y**2)
    term2 = torch.exp(-((mu_x - mu_y) ** 2) / (4 * (sigma_x**2 + sigma_y**2)))

    # Mathematically 1 - term1 * term2 lies in [0, 1], but for (near-)identical
    # distributions rounding can push it slightly below zero, which would make
    # sqrt return NaN. Clamp at 0 before taking the root.
    # Note: the gradient of sqrt is unbounded at exactly 0.
    squared = (1 - term1 * term2).clamp(min=0)

    H = torch.sqrt(squared).mean(dim=-1)  # (batch_x, batch_y)

    return H

In [None]:
# Default weight of each horizon when per-horizon correlations are combined.
similarity_hz_weight = [0.1, 0.3, 0.5, 0.1, 0.0]

def ts_similarity(ts1, ts2):
    """Return the Pearson correlation coefficient between two 1-D series."""
    return np.corrcoef(ts1, ts2)[0, 1]

def chart_similarity(data:torch.Tensor, weights=None):
    """
    Weighted pairwise correlation between the samples in `data`.

    data: indexable as data[sample, horizon] -> 1-D series; it must provide at
        least len(weights) horizons per sample.
    weights: one weight per horizon; defaults to `similarity_hz_weight`.

    Returns:
        (n_samples, n_samples) numpy array. Entry (r, c) is the weighted sum
        over horizons of the correlation between sample r and sample c.
        Diagonal entries use a correlation of 1 for every horizon, so they
        equal sum(weights).
    """
    if weights is None:
        weights = similarity_hz_weight
    weights = np.asarray(weights, dtype=float)

    dnum = len(data)
    n_hz = len(weights)  # number of horizons follows the weights, not a fixed 5

    # Start from ones so the diagonal (self-similarity) needs no computation.
    sim_list = np.ones((dnum, dnum, n_hz))
    # Correlation is symmetric: compute each unordered pair once and mirror it.
    for r in range(dnum):
        for c in range(r + 1, dnum):
            for i in range(n_hz):
                value = ts_similarity(data[r, i], data[c, i])
                sim_list[r, c, i] = value
                sim_list[c, r, i] = value

    # Collapse the horizon axis with the weights.
    return np.einsum("ijk,k->ij", sim_list, weights)

In [11]:
# Instantiate the CustomCNN model from models.encdec.
# NOTE(review): the trailing 2 presumably corresponds to the (mu, sigma) axis of
# the model output — TODO confirm against encdec.CustomCNN.
model = encdec.CustomCNN(hz_dim, THZ, 2)

* Batch data Dimension: $(batch, 5(hz), feature, seqlen)$
* Normal distribution Dimension: $(batch, 2(\mu|\sigma), outdim)$

In [None]:
# Reset the buffer before filling it (presumably empties it — TODO confirm in models/mybuffer).
buffer.clear()
# Start
# Runs 100 forward passes and stores each result; no loss or optimizer step
# happens in this loop.
for i in tqdm(range(100)):
    # Draw one batch; only rel_charts is used below (features, labels, infos are unused here).
    rel_charts, features, labels, infos = sc.batch_maker(envgen, batch_size)
    q = model(rel_charts)

    # NOTE(review): q_buf / x_buf are fetched but never used in this loop, and on the
    # first iteration this runs before anything has been put into the buffer.
    q_buf, x_buf = buffer.get(4)

    # Store detached tensors so buffered entries do not keep the autograd graph.
    buffer.put(q.detach(), rel_charts.detach())
# After the loop, `q` and `rel_charts` hold the last iteration's values; a later
# cell passes this `q` to hellinger_distance.

100%|██████████| 100/100 [00:06<00:00, 15.38it/s]


In [16]:
# Fetch 4 buffered entries: qs are stored model outputs, xs the matching rel_charts
# (the order in which buffer.put stored them). Whether get() samples at random is
# defined in models/mybuffer — TODO confirm.
qs, xs = buffer.get(4)

In [19]:
# Rows: q from the last loop iteration (still attached to the graph);
# columns: the buffered qs fetched above.
hellinger_distance(q, qs)

tensor([[0.0062, 0.0062, 0.0062, 0.0043],
        [0.0060, 0.0060, 0.0060, 0.0042],
        [0.0076, 0.0076, 0.0076, 0.0056]], grad_fn=<MeanBackward1>)

In [17]:
# Pick index 3 on the third axis of xs and compare the buffered charts pairwise.
# assumes xs follows the (batch, hz, feature, seqlen) layout noted above, so the
# slice is (batch, hz, seqlen) — TODO confirm which feature index 3 is.
chart_similarity(xs[:,:,3])

array([[ 1.        ,  0.98745249,  0.97575433, -0.67505496],
       [ 0.98745249,  1.        ,  0.98552752, -0.67995366],
       [ 0.97575433,  0.98552752,  1.        , -0.68242266],
       [-0.67505496, -0.67995366, -0.68242266,  1.        ]])