In [1]:
import pickle
import dill
import itertools
import pandas as pd
import numpy as np

import torch
import torch.utils.data as data

from modules.RIFT_dataset import RIFT_Dataset
from modules.RIFT_model import RIFT_Model
from modules.RIFT_model_config import RIFT_Model_Config
from modules.train_model import Model_Trainer
from modules.radam import RAdam
from modules.train_utils import ret_seq_indices, shifted_diff, ts_moving_average, ts_moving_var, seq_corr_1d, seq_corr_3d

  from .autonotebook import tqdm as notebook_tqdm
2023-02-23 02:36:37 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [2]:
def get_input_ts_transform_list():
    """Build the ordered list of single-argument time-series transforms.

    The list contains, in order:
      1. ``ret_seq_indices`` itself;
      2. ten wrappers around ``shifted_diff`` for shifts 1 through 10;
      3. for every window in ``range(10, 200, 20)`` (10, 30, ..., 190), a
         wrapper around ``ts_moving_average`` followed by a wrapper around
         ``ts_moving_var`` for that window.

    Each wrapper takes the series as ``x`` and carries its shift/window as the
    default value of its second parameter ``i``, so the value is fixed when
    the wrapper is created rather than looked up when it is called.

    Returns:
        list: 31 callables.
    """
    transforms = [ret_seq_indices]

    for shift in range(1, 11):
        # The default argument freezes the current shift for this wrapper.
        def f_shifted_diff(x, i=shift):
            return shifted_diff(x, i)

        transforms.append(f_shifted_diff)

    for window in range(10, 200, 20):
        def f_moving_avg(x, i=window):
            return ts_moving_average(x, i)

        def f_moving_var(x, i=window):
            return ts_moving_var(x, i)

        # Average first, then variance, for every window size.
        transforms.extend((f_moving_avg, f_moving_var))

    return transforms


# Callables produced by get_input_ts_transform_list().
input_ts_transform_fns = get_input_ts_transform_list()


def _corr_on_prefix(n_steps):
    """Return a two-argument function applying seq_corr_1d to the first `n_steps` items of each input."""
    def _target(x1, x2):
        return seq_corr_1d(x1[0:n_steps], x2[0:n_steps])

    return _target


# Target functions: seq_corr_1d over the full inputs first, then over
# progressively shorter leading slices (120, 60, 20, 10 and 5 items).
target_fns = [lambda x1, x2: seq_corr_1d(x1, x2)]
target_fns += [_corr_on_prefix(n_steps) for n_steps in (120, 60, 20, 10, 5)]

# Passed to RIFT_Dataset as `days_lag` / `days_lead`.
DAYS_LAG = 500
DAYS_LEAD = 250

In [57]:
# Table handed to every RIFT_Dataset built below.
ts_df = pd.read_csv("data/ts_df/ts_df.csv", encoding='utf-8')

# One dataset is built per date; the same date is passed as both elements
# of the tuple given to RIFT_Dataset.
dates = [
    '2017-01-03',
    '2018-01-02',
    '2019-01-02',
    '2020-01-02',
    '2021-01-04',
    '2022-01-03',
]
embed_sets = [
    RIFT_Dataset(
        ts_df,
        (date, date),
        target_fns=target_fns,
        days_lag=DAYS_LAG,
        days_lead=DAYS_LEAD,
        sample_size="ALL",
    )
    for date in dates
]

# Serialize the list of datasets to disk with dill.
with open('data/embed/embed_sets.dill', 'wb') as handle:
    dill.dump(embed_sets, handle)

In [12]:
# Directory name under mlruns/ holding the config and weights to load.
best_model_id = '8e7ca76ba73e4bcdb646329e72817438'

# NOTE: dill.load (like pickle) can execute arbitrary code while
# deserializing; only load config files you produced yourself.
with open(f'mlruns/{best_model_id}/config.dill', 'rb') as handle:
    config = dill.load(handle)

model = RIFT_Model(config)
# Load the weights onto the CPU first: without map_location, a checkpoint
# that holds CUDA tensors cannot be deserialized on a machine without a GPU,
# even though the branch below is meant to support that case.
model.load_state_dict(
    torch.load(f"mlruns/{best_model_id}/model.pth", map_location="cpu")
)
# Move the model to the GPU only when one is actually available.
if torch.cuda.is_available():
    model = model.cuda()

final size of concatenated embeddings within the encoder is: 950


In [66]:
def torch_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array.

    The tensor is first detached from the autograd graph, then moved to the
    CPU, and finally exposed as a NumPy array.

    Args:
        tensor: the torch tensor to convert.

    Returns:
        The result of ``tensor.detach().cpu().numpy()``.
    """
    detached = tensor.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()


def get_embed(model, embed_set, batch_size=64):
    """Run the model's encoder over a dataset and collect one embedding per sample.

    The dataset is iterated in order (no shuffling, no dropped last batch).
    Each batch must yield ``(inputs, labels)`` where ``labels`` unpacks into
    four fields ``(date, rel_date_num, t, s)``. Only ``inputs[0]`` is fed to
    ``model.siamese_encoder.encoder_forward``, and element ``[0]`` of its
    return value is taken as the batch of embeddings, indexed by sample.

    Args:
        model: object exposing ``eval()`` and
            ``siamese_encoder.encoder_forward``.
        embed_set: dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: number of samples per encoder call.

    Returns:
        pandas.DataFrame with one column per value of ``t``, each column
        holding that sample's embedding. If a ``t`` value occurs more than
        once, the last occurrence wins.
    """
    model.eval()
    embed_dict = dict()

    embed_loader = data.DataLoader(embed_set, batch_size=batch_size, drop_last=False, shuffle=False)
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(embed_loader):
            date, rel_date_num, t, s = labels
            print(f"{i}: {date[0]}, {t}")
            # Encode the batch once. Doing this inside the per-sample loop
            # would repeat the identical forward pass for every sample.
            # NOTE(review): inputs are passed as yielded by the DataLoader,
            # with no device transfer here - confirm the encoder handles that
            # when the model lives on the GPU.
            batch_embed = torch_to_numpy(model.siamese_encoder.encoder_forward(inputs[0])[0])
            for j, key in enumerate(t):
                embed_dict[key] = batch_embed[j]

    embed_df = pd.DataFrame.from_dict(embed_dict)
    return embed_df

In [None]:
# Write one CSV of embeddings per dataset; the file is named after the entry
# of `dates` at the same position as the dataset in `embed_sets`.
for i, embed_set in enumerate(embed_sets):
    embed_df = get_embed(model, embed_set)
    csv_path = f"data/embed/{dates[i]}.csv"
    embed_df.to_csv(csv_path, encoding='utf-8')