In [3]:
import yaml , os

# Record that gets written to the running-results YAML file.
out_dict = dict(a=10)
out_path = f'./results/running_results.yaml'
# Start a new file (creating its folder first) when none exists yet; otherwise append.
if not os.path.exists(out_path):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    out_type = 'w'
else:
    out_type = 'a'

with open(out_path, mode=out_type, encoding='utf-8') as f:
    yaml.dump(out_dict, stream=f)

In [4]:
%run gen_data.py

[1m[37m[41m23-11-09 00:13:41|MOD:gen_data    |[0m: [1m[31mData loading start![0m


Preparing day_trading_data data...
arr shape : (5264, 6202, 1, 6) , row shape : (5264,) , col shape : (6202,)
Preparing day_ylabels_data data...
arr shape : (5249, 6191, 1, 2) , row shape : (5249,) , col shape : (6191,)
Preparing 15m_trading_data data...


[1m[37m[45m23-11-09 00:20:02|MOD:gen_data    |[0m: [1m[35m[day] Data avg and std generation start![0m


arr shape : (5204, 3302, 16, 6) , row shape : (5204,) , col shape : (3302,)
Loading day trading data finished, cost 5.64 Secs
torch.Size([5264, 920, 1, 6])


[1m[37m[45m23-11-09 00:20:19|MOD:gen_data    |[0m: [1m[35m[15m] Data avg and std generation start![0m


Loading 15m trading data finished, cost 45.60 Secs
torch.Size([5204, 340, 1, 6])


[1m[37m[41m23-11-09 00:21:36|MOD:gen_data    |[0m: [1m[31mData loading Finished! Cost 475.09 Seconds[0m


In [None]:
%run run_model.py --process=0 --rawname=1 --resume=0 --anchoring=0

In [49]:
import shutil , os
# Root folder of saved models and the sub-folders that must never be deleted.
folder = './model'
save_model = ['.ipynb_checkpoints', 'LSTM_day','GRU_day','Transformer_day','GeneralRNN_day','GeneralRNN_day_Trans_vs_LSTM']
# [print(f'{folder}/{p}') for p in os.listdir(folder) if not p in save_model]

# Delete one stale experiment folder.
# The existence check keeps the cell re-runnable: a bare rmtree raises
# FileNotFoundError once the folder is gone. The keep-list check refuses to
# remove a protected folder if the target name is ever edited to one of them.
stale_dir = f'{folder}/DoubleGRU_both_single'
if os.path.basename(stale_dir) in save_model:
    raise ValueError(f'{stale_dir} is listed in save_model and will not be removed')
if os.path.isdir(stale_dir):
    shutil.rmtree(stale_dir)

In [7]:
import torch
import torch.nn as nn

# Toy dimensions: batch, sequence length, input features, hidden width.
b, s, f, h = (2000, 30, 6, 8)
x = torch.rand((b, s, f))
# Overwrite the last 10% of the batch with NaNs.
x[int(b * 0.9):] = float('nan')

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first sequences.

    A table `P` of shape (1, max_len, input_dim) is built once: even feature
    columns hold sin(pos / 10000^(2i/input_dim)) and odd columns hold the
    matching cos. `forward` adds the first `inputs.shape[1]` rows of the
    table to the input and applies dropout.

    `P` is registered as a non-persistent buffer, so it follows the module
    through `.to()` / `.cuda()` / `.double()` while staying out of
    `state_dict()` (checkpoints are unaffected).

    Args:
        input_dim (int): size of the last (feature) dimension of the inputs
        dropout (float): dropout probability applied after the addition
        max_len (int): number of positions stored in the table
        **kwargs: accepted and ignored
    """
    def __init__(self, input_dim, dropout=0.0, max_len=1000,**kwargs):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.seq_len = max_len
        positions = torch.arange(self.seq_len, dtype=torch.float).reshape(-1, 1)
        divisors = torch.pow(10000, torch.arange(0, input_dim, 2, dtype=torch.float) / input_dim)
        angles = positions / divisors
        table = torch.zeros(1, self.seq_len, input_dim)
        table[:, :, 0::2] = torch.sin(angles)
        # For an odd input_dim there is one fewer odd column than even columns.
        table[:, :, 1::2] = torch.cos(angles[:, :input_dim // 2])
        self.register_buffer('P', table, persistent=False)
    def forward(self, inputs):
        """Return dropout(inputs + P[:, :seq]) for inputs indexed as (batch, seq, feature)."""
        # `.to` is a no-op when the buffer already sits on the input's device.
        return self.dropout(inputs + self.P[:, :inputs.shape[1], :].to(inputs.device))
    
class TimeWiseTranformer(nn.Module):
    """Transformer encoder over a batch-first sequence.

    Inputs are projected from `input_dim` to `hidden_dim` by a linear layer,
    passed through `PositionalEncoding`, then through `num_enclayer` stacked
    `nn.TransformerEncoderLayer`s.

    Args:
        input_dim (int): feature size of the inputs
        hidden_dim (int): model width; must be divisible by `num_heads`
        ffn_dim (int): feed-forward width, `4 * hidden_dim` when None
        num_heads (int): number of attention heads
        num_enclayer (int): number of encoder layers
        dropout (float): dropout used by the positional encoding and the encoder layers
    """
    def __init__(self , input_dim , hidden_dim , ffn_dim = None , num_heads = 8 , num_enclayer = 2 , dropout=0.0):
        super(TimeWiseTranformer, self).__init__()
        assert hidden_dim % num_heads == 0
        if ffn_dim is None:
            ffn_dim = 4 * hidden_dim
        self.fc_in = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.pos_enc = PositionalEncoding(hidden_dim, dropout=dropout)
        self.trans = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=ffn_dim,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_enclayer,
        )
    def forward(self, inputs):
        """Project, add positional encoding, and encode; output keeps the (batch, seq) layout."""
        projected = self.fc_in(inputs)
        return self.trans(self.pos_enc(projected))

# Six-layer encoder mapping f input features to an h-wide hidden state.
tf = TimeWiseTranformer(input_dim=f, hidden_dim=h, num_enclayer=6)
# tf(x).select(-2,-1).shape
tf(x).shape

# Fetch the class object by name from the current namespace and show it.
a = locals()['TimeWiseTranformer']
a

__main__.TimeWiseTranformer

In [2]:
import torch
import torch.nn as nn
from mymodel import *

# Toy dimensions: batch, sequence length, input features, hidden width.
b, s, f, h = (2000, 30, 6, 8)
x = torch.rand((b, s, f))
# Overwrite the last 10% of the batch with NaNs.
x[int(b * 0.9):] = float('nan')

# `mod_transformer` is provided by the `mymodel` star import; it is built from (f, h).
net = mod_transformer(f, h)
net(x).shape

torch.Size([2000, 30, 8])

In [92]:
import torch
import matplotlib.pyplot as plt
# Learning-rate settings for the scratch experiment.
base_lr = 5e-2
max_lr  = 1e-1
min_lr = base_lr * 1e-4

net = torch.nn.Sequential(torch.nn.Linear(10,10) , torch.nn.Linear(10,1))
# One parameter group per layer, each with its own learning rate.
# Zipping the rates with `net.parameters()` would pair them with the first
# layer's weight and bias only, leaving the second layer out of the optimizer.
# 'lr_param' is an extra key stored alongside 'lr' in every group.
layer_lrs = [base_lr , base_lr/2]
layers = list(net.children())
assert len(layer_lrs) == len(layers), 'need exactly one learning rate per layer'
net_base_lr = [{'params': list(layer.parameters()) , 'lr':l , 'lr_param' : l} for l,layer in zip(layer_lrs , layers)]
# The positional 0.001 is Adam's default lr; every group above overrides it.
opt = torch.optim.Adam(net_base_lr , 0.001)


In [6]:
from collections import deque
# Bounded deque (capacity 10) filled with 0..9.
a = deque(maxlen=10)
a.extend(range(10))
print(a)

deque([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], maxlen=10)


In [None]:
# TRA model
import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


class TRAModel(Model):
    """Trainer that couples a backbone network with a `TRA` head.

    The backbone class is resolved from its name (`model_type`) and built from
    `model_config`; a `TRA` module built from `tra_config` consumes the
    backbone's output. Both are optimized jointly by one Adam optimizer.

    `get_module_logger`, `TRA`, `tqdm`, `sinkhorn`, `evaluate` and
    `average_params` are not defined in this cell and must already be in scope.

    Args:
        model_config (dict): keyword arguments for the backbone constructor
        tra_config (dict): keyword arguments for `TRA` (after the input size)
        model_type (str): name of the backbone class
        lr (float): Adam learning rate
        n_epochs (int): maximum number of training epochs
        early_stop (int): stop after this many epochs without a better valid IC
        smooth_steps (int): number of most recent epochs whose parameters are
            averaged before evaluation
        max_steps_per_epoch (int): optional cap on the batches used per epoch
        freeze_model (bool): disable gradients for the backbone parameters
        model_init_state (str): optional checkpoint path whose "model" entry is
            loaded into the backbone
        lamb (float): initial weight of the routing regularizer
        rho (float): per-step decay factor applied to `lamb`
        seed (int): seed for numpy and torch; nothing is seeded when None
        logdir (str): optional output folder for logs, weights and predictions
        eval_train (bool): also evaluate on the training set every epoch
        eval_test (bool): also evaluate on the test set every epoch
        avg_params (bool): stored on the instance; not read inside this class
    """

    def __init__(
        self,
        model_config,
        tra_config,
        model_type="LSTM",
        lr=1e-3,
        n_epochs=500,
        early_stop=50,
        smooth_steps=5,
        max_steps_per_epoch=None,
        freeze_model=False,
        model_init_state=None,
        lamb=0.0,
        rho=0.99,
        seed=None,
        logdir=None,
        eval_train=True,
        eval_test=False,
        avg_params=True,
        **kwargs,
    ):
        # `torch.manual_seed(None)` raises TypeError, so the default `seed=None`
        # must skip seeding instead of crashing the constructor.
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        self.logger = get_module_logger("TRA")
        self.logger.info("TRA Model...")

        # NOTE(security): `eval` executes `model_type` as Python code. Only pass
        # trusted class names here; never forward user-supplied strings.
        self.model = eval(model_type)(**model_config).to(device)
        if model_init_state:
            self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
        if freeze_model:
            for param in self.model.parameters():
                param.requires_grad_(False)
        else:
            self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()]))

        self.tra = TRA(self.model.output_size, **tra_config).to(device)
        self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]))

        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)

        self.model_config = model_config
        self.tra_config = tra_config
        self.lr = lr
        self.n_epochs = n_epochs
        self.early_stop = early_stop
        self.smooth_steps = smooth_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lamb = lamb
        self.rho = rho
        self.seed = seed
        self.logdir = logdir
        self.eval_train = eval_train
        self.eval_test = eval_test
        self.avg_params = avg_params

        # `Logger.warn` is a deprecated alias; `warning` is the supported name.
        if self.tra.num_states > 1 and not self.eval_train:
            self.logger.warning("`eval_train` will be ignored when using TRA")

        if self.logdir is not None:
            if os.path.exists(self.logdir):
                self.logger.warning(f"logdir {self.logdir} is not empty")
            os.makedirs(self.logdir, exist_ok=True)

        self.fitted = False
        self.global_step = -1

    def train_epoch(self, data_set):
        """Run one training epoch over `data_set`.

        Returns:
            float: sample-weighted mean of the batch losses (NaN when no batch
            was processed).
        """
        self.model.train()
        self.tra.train()

        data_set.train()

        # Steps per epoch are bounded by the number of batches, optionally capped
        # by `max_steps_per_epoch`. (Capping by `n_epochs`, as before, silently
        # truncated every epoch to `n_epochs` batches.)
        max_steps = len(data_set)
        if self.max_steps_per_epoch is not None:
            max_steps = min(self.max_steps_per_epoch, max_steps)

        count = 0
        total_loss = 0
        total_count = 0
        for batch in tqdm(data_set, total=max_steps):
            count += 1
            if count > max_steps:
                break

            self.global_step += 1

            data, label, index = batch["data"], batch["label"], batch["index"]

            # The trailing `num_states` columns hold the stored per-state losses;
            # the remaining columns are the input features.
            feature = data[:, :, : -self.tra.num_states]
            hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]

            hidden = self.model(feature)
            pred, all_preds, prob = self.tra(hidden, hist_loss)

            loss = (pred - label).pow(2).mean()

            L = (all_preds.detach() - label[:, None]).pow(2)
            L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input

            data_set.assign_data(index, L)  # save loss to memory

            if prob is not None:
                P = sinkhorn(-L, epsilon=0.01)  # sample assignment matrix
                lamb = self.lamb * (self.rho**self.global_step)
                reg = prob.log().mul(P).sum(dim=-1).mean()
                loss = loss - lamb * reg

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # `loss` is a batch mean, so weight it by the batch size before
            # dividing by the total number of samples below.
            total_loss += loss.item() * len(pred)
            total_count += len(pred)

        if total_count == 0:
            return float("nan")

        total_loss /= total_count

        return total_loss

    def test_epoch(self, data_set, return_pred=False):
        """Evaluate on `data_set` and refresh its stored per-state losses.

        Returns:
            tuple: `(metrics, preds)` where `metrics` is a dict with the keys
            MSE, MAE, IC and ICIR averaged over batches, and `preds` is a
            DataFrame of predictions when `return_pred` is True, otherwise an
            empty list.
        """
        self.model.eval()
        self.tra.eval()
        data_set.eval()

        preds = []
        metrics = []
        for batch in tqdm(data_set):
            data, label, index = batch["data"], batch["label"], batch["index"]

            feature = data[:, :, : -self.tra.num_states]
            hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]

            with torch.no_grad():
                hidden = self.model(feature)
                pred, all_preds, prob = self.tra(hidden, hist_loss)

            L = (all_preds - label[:, None]).pow(2)

            L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input

            data_set.assign_data(index, L)  # save loss to memory

            X = np.c_[
                pred.cpu().numpy(),
                label.cpu().numpy(),
            ]
            columns = ["score", "label"]
            if prob is not None:
                # One score and one routing probability column per state.
                X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
                columns += ["score_%d" % d for d in range(all_preds.shape[1])] + [
                    "prob_%d" % d for d in range(all_preds.shape[1])
                ]

            pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)

            metrics.append(evaluate(pred))

            if return_pred:
                preds.append(pred)

        metrics = pd.DataFrame(metrics)
        metrics = {
            "MSE": metrics.MSE.mean(),
            "MAE": metrics.MAE.mean(),
            "IC": metrics.IC.mean(),
            "ICIR": metrics.IC.mean() / metrics.IC.std(),
        }

        if return_pred:
            preds = pd.concat(preds, axis=0)
            # Map positional row numbers back to the dataset's own index.
            preds.index = data_set.restore_index(preds.index)
            preds.index = preds.index.swaplevel()
            preds.sort_index(inplace=True)

        return metrics, preds

    def fit(self, dataset, evals_result=None):
        """Train with early stopping on valid IC, then evaluate on the test set.

        Args:
            dataset: object whose `prepare(["train", "valid", "test"])` returns
                the three splits
            evals_result (dict): optional dict that receives the per-epoch
                metrics under "train", "valid" and "test"; a fresh dict is used
                when omitted (a shared mutable default would leak results
                between calls)
        """
        if evals_result is None:
            evals_result = {}

        train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])

        best_score = -1
        best_epoch = 0
        stop_rounds = 0
        best_params = {
            "model": copy.deepcopy(self.model.state_dict()),
            "tra": copy.deepcopy(self.tra.state_dict()),
        }
        # Rolling window of the latest `smooth_steps` parameter snapshots.
        params_list = {
            "model": collections.deque(maxlen=self.smooth_steps),
            "tra": collections.deque(maxlen=self.smooth_steps),
        }
        evals_result["train"] = []
        evals_result["valid"] = []
        evals_result["test"] = []

        # train
        self.fitted = True
        self.global_step = -1

        if self.tra.num_states > 1:
            self.logger.info("init memory...")
            self.test_epoch(train_set)

        for epoch in range(self.n_epochs):
            self.logger.info("Epoch %d:", epoch)

            self.logger.info("training...")
            self.train_epoch(train_set)

            self.logger.info("evaluating...")
            # average params for inference
            params_list["model"].append(copy.deepcopy(self.model.state_dict()))
            params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
            self.model.load_state_dict(average_params(params_list["model"]))
            self.tra.load_state_dict(average_params(params_list["tra"]))

            # NOTE: during evaluating, the whole memory will be refreshed
            if self.tra.num_states > 1 or self.eval_train:
                train_set.clear_memory()  # NOTE: clear the shared memory
                train_metrics = self.test_epoch(train_set)[0]
                evals_result["train"].append(train_metrics)
                self.logger.info("\ttrain metrics: %s" % train_metrics)

            valid_metrics = self.test_epoch(valid_set)[0]
            evals_result["valid"].append(valid_metrics)
            self.logger.info("\tvalid metrics: %s" % valid_metrics)

            if self.eval_test:
                test_metrics = self.test_epoch(test_set)[0]
                evals_result["test"].append(test_metrics)
                self.logger.info("\ttest metrics: %s" % test_metrics)

            if valid_metrics["IC"] > best_score:
                best_score = valid_metrics["IC"]
                stop_rounds = 0
                best_epoch = epoch
                best_params = {
                    "model": copy.deepcopy(self.model.state_dict()),
                    "tra": copy.deepcopy(self.tra.state_dict()),
                }
            else:
                stop_rounds += 1
                if stop_rounds >= self.early_stop:
                    self.logger.info("early stop @ %s" % epoch)
                    break

            # restore parameters
            # (training continues from the raw weights, not the averaged ones)
            self.model.load_state_dict(params_list["model"][-1])
            self.tra.load_state_dict(params_list["tra"][-1])

        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
        self.model.load_state_dict(best_params["model"])
        self.tra.load_state_dict(best_params["tra"])

        metrics, preds = self.test_epoch(test_set, return_pred=True)
        self.logger.info("test metrics: %s" % metrics)

        if self.logdir:
            self.logger.info("save model & pred to local directory")

            pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
                self.logdir + "/logs.csv", index=False
            )

            torch.save(best_params, self.logdir + "/model.bin")

            preds.to_pickle(self.logdir + "/pred.pkl")

            info = {
                "config": {
                    "model_config": self.model_config,
                    "tra_config": self.tra_config,
                    "lr": self.lr,
                    "n_epochs": self.n_epochs,
                    "early_stop": self.early_stop,
                    "smooth_steps": self.smooth_steps,
                    "max_steps_per_epoch": self.max_steps_per_epoch,
                    "lamb": self.lamb,
                    "rho": self.rho,
                    "seed": self.seed,
                    "logdir": self.logdir,
                },
                "best_eval_metric": -best_score,  # NOTE: negated so that the stored metric is one to minimize
                "metric": metrics,
            }
            with open(self.logdir + "/info.json", "w") as f:
                json.dump(info, f)

    def predict(self, dataset, segment="test"):
        """Return predictions for `segment`.

        Raises:
            ValueError: if `fit` has not been called yet.
        """
        if not self.fitted:
            raise ValueError("model is not fitted yet!")

        test_set = dataset.prepare(segment)

        metrics, preds = self.test_epoch(test_set, return_pred=True)
        self.logger.info("test metrics: %s" % metrics)

        return preds



In [None]:
# TRA dataset
import copy
import torch
import numpy as np
import pandas as pd

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _to_tensor(x):
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, dtype=torch.float, device=device)
    return x


def _create_ts_slices(index, seq_len):
    """
    create time series slices from pandas index

    Args:
        index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
        seq_len (int): sequence length
    """
    assert index.is_lexsorted(), "index should be sorted"

    # number of dates for each code
    sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values

    # start_index for each code
    start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
    start_index_of_codes[0] = 0

    # all the [start, stop) indices of features
    # features btw [start, stop) are used to predict the `stop - 1` label
    slices = []
    for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):
        for stop in range(1, cur_cnt + 1):
            end = cur_loc + stop
            start = max(end - seq_len, 0)
            slices.append(slice(start, end))
    slices = np.array(slices)

    return slices


def _get_date_parse_fn(target):
    """get date parse function

    This method is used to parse date arguments as target type.

    Example:
        get_date_parse_fn('20120101')('2017-01-01') => '20170101'
        get_date_parse_fn(20120101)('2017-01-01') => 20170101
    """
    if isinstance(target, pd.Timestamp):
        _fn = lambda x: pd.Timestamp(x)  # Timestamp('2020-01-01')
    elif isinstance(target, str) and len(target) == 8:
        _fn = lambda x: str(x).replace("-", "")[:8]  # '20200201'
    elif isinstance(target, int):
        _fn = lambda x: int(str(x).replace("-", "")[:8])  # 20200201
    else:
        _fn = lambda x: x
    return _fn


class MTSDatasetH(DatasetH):
    """Memory Augmented Time Series Dataset

    The feature matrix is extended with `num_states` trailing "memory" columns
    that `assign_data` fills and `clear_memory` resets. Iterating the dataset
    yields dicts with the keys "data", "label" and "index".

    Args:
        handler (DataHandler): data handler
        segments (dict): data split segments
        seq_len (int): time series sequence length
        horizon (int): label horizon (to mask historical loss for TRA)
        num_states (int): how many memory states to be added (for TRA)
        batch_size (int): batch size (<0 means daily batch)
        shuffle (bool): whether shuffle data
        pin_memory (bool): whether pin data to gpu memory
        drop_last (bool): whether drop last batch < batch_size
    """

    def __init__(
        self,
        handler,
        segments,
        seq_len=60,
        horizon=0,
        num_states=1,
        batch_size=-1,
        shuffle=True,
        pin_memory=False,
        drop_last=False,
        **kwargs,
    ):
        # The default of 0 is deliberately invalid so callers must choose a horizon.
        assert horizon > 0, "please specify `horizon` to avoid data leakage"

        self.seq_len = seq_len
        self.horizon = horizon
        self.num_states = num_states
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.params = (batch_size, drop_last, shuffle)  # for train/eval switch

        super().__init__(handler, segments, **kwargs)

    def setup_data(self, handler_kwargs: dict = None, **kwargs):
        """Load the handler data and precompute the per-sample and per-day slices."""
        # NOTE(review): `handler_kwargs` and `kwargs` are not forwarded to the
        # parent call; confirm that this is intended.
        super().setup_data()

        # change index to <code, date>
        # NOTE: we will use inplace sort to reduce memory use
        df = self.handler._data
        df.index = df.index.swaplevel()
        df.sort_index(inplace=True)

        self._data = df["feature"].values.astype("float32")
        self._label = df["label"].squeeze().astype("float32")
        self._index = df.index

        # add memory to feature
        self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]

        # padding tensor
        self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)

        # pin memory
        if self.pin_memory:
            self._data = _to_tensor(self._data)
            self._label = _to_tensor(self._label)
            self.zeros = _to_tensor(self.zeros)

        # create batch slices
        self.batch_slices = _create_ts_slices(self._index, self.seq_len)

        # create daily slices
        index = [slc.stop - 1 for slc in self.batch_slices]
        act_index = self.restore_index(index)
        daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}
        for i, (code, date) in enumerate(act_index):
            daily_slices[date].append(self.batch_slices[i])
        self.daily_slices = list(daily_slices.values())

    def _prepare_seg(self, slc, **kwargs):
        """Return a shallow copy restricted to samples whose date lies in `slc`.

        Args:
            slc (slice | list | tuple): `[start, stop]` date bounds, both inclusive

        Raises:
            NotImplementedError: if `slc` is neither a slice nor a list/tuple
        """
        # Parse the bounds into the same type as the dates stored in the index.
        fn = _get_date_parse_fn(self._index[0][1])

        if isinstance(slc, slice):
            start, stop = slc.start, slc.stop
        elif isinstance(slc, (list, tuple)):
            start, stop = slc
        else:
            raise NotImplementedError("This type of input is not supported")
        start_date = fn(start)
        end_date = fn(stop)
        obj = copy.copy(self)  # shallow copy
        # NOTE: Seriable will disable copy `self._data` so we manually assign them here
        obj._data = self._data
        obj._label = self._label
        obj._index = self._index
        new_batch_slices = []
        for batch_slc in self.batch_slices:
            date = self._index[batch_slc.stop - 1][1]
            if start_date <= date <= end_date:
                new_batch_slices.append(batch_slc)
        obj.batch_slices = np.array(new_batch_slices)
        new_daily_slices = []
        for daily_slc in self.daily_slices:
            # All slices of one daily group end on the same date.
            date = self._index[daily_slc[0].stop - 1][1]
            if start_date <= date <= end_date:
                new_daily_slices.append(daily_slc)
        obj.daily_slices = new_daily_slices
        return obj

    def restore_index(self, index):
        """Map positional row numbers back to entries of the stored index."""
        if isinstance(index, torch.Tensor):
            index = index.cpu().numpy()
        return self._index[index]

    def assign_data(self, index, vals):
        """Write `vals` into the memory columns of the rows selected by `index`."""
        # Without memory columns `-0:` would select every column and overwrite
        # the features, so there is nothing to assign.
        if self.num_states <= 0:
            return
        if isinstance(self._data, torch.Tensor):
            vals = _to_tensor(vals)
        else:
            # numpy storage: convert each tensor argument on its own, so a
            # tensor `index` is handled even when `vals` is already an array.
            if isinstance(vals, torch.Tensor):
                vals = vals.detach().cpu().numpy()
            if isinstance(index, torch.Tensor):
                index = index.detach().cpu().numpy()
        self._data[index, -self.num_states :] = vals

    def clear_memory(self):
        """Reset all memory columns to zero."""
        # Guard against `-0:` selecting (and zeroing) every feature column.
        if self.num_states <= 0:
            return
        self._data[:, -self.num_states :] = 0

    def train(self):
        """enable training mode"""
        self.batch_size, self.drop_last, self.shuffle = self.params

    def eval(self):
        """enable evaluation mode"""
        self.batch_size = -1
        self.drop_last = False
        self.shuffle = False

    def _get_slices(self):
        """Return the slices to iterate over and the positive batch size."""
        if self.batch_size < 0:
            # Daily mode: each element is the list of slices for one date.
            slices = self.daily_slices.copy()
            batch_size = -1 * self.batch_size
        else:
            slices = self.batch_slices.copy()
            batch_size = self.batch_size
        return slices, batch_size

    def __len__(self):
        slices, batch_size = self._get_slices()
        if self.drop_last:
            return len(slices) // batch_size
        return (len(slices) + batch_size - 1) // batch_size

    def __iter__(self):
        slices, batch_size = self._get_slices()
        if self.shuffle:
            np.random.shuffle(slices)

        for i in range(len(slices))[::batch_size]:
            if self.drop_last and i + batch_size > len(slices):
                break
            # get slices for this batch
            slices_subset = slices[i : i + batch_size]
            if self.batch_size < 0:
                # Flatten the per-date groups into one list of slices.
                slices_subset = np.concatenate(slices_subset)
            # collect data
            data = []
            label = []
            index = []
            for slc in slices_subset:
                _data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()
                if len(_data) != self.seq_len:
                    # Left-pad short windows with zero rows up to `seq_len`.
                    if self.pin_memory:
                        _data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
                    else:
                        _data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
                if self.num_states > 0:
                    # Blank the memory columns of the last `horizon` steps.
                    _data[-self.horizon :, -self.num_states :] = 0
                data.append(_data)
                label.append(self._label[slc.stop - 1])
                index.append(slc.stop - 1)
            # concate
            index = torch.tensor(index, device=device)
            if isinstance(data[0], torch.Tensor):
                data = torch.stack(data)
                label = torch.stack(label)
            else:
                data = _to_tensor(np.stack(data))
                label = _to_tensor(np.stack(label))
            # yield -> generator
            yield {"data": data, "label": label, "index": index}

In [23]:
import torch
import torch.nn as nn

class TRA(nn.Module):
    """Temporal Routing Adaptor (TRA)

    TRA takes historical prediction errors & latent representation as inputs,
    then routes the input sample to a specific predictor for training & inference.

    Args:
        input_size (int): input size (RNN/Transformer's hidden size)
        num_states (int): number of latent states (i.e., trading patterns)
            If `num_states=1`, then TRA falls back to traditional methods
        hidden_size (int): hidden size of the router
        tau (float): gumbel softmax temperature
        horizon (int): number of most recent steps of `hist_loss` that are
            dropped before the history is fed to the router
        src_info (str): which information the router uses; "LR" enables the
            latent representation and "TPE" the temporal prediction error.
            A disabled source is replaced by Gaussian noise of the same shape.
    """

    def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, horizon = 20 , src_info="LR_TPE"):
        super().__init__()
        self.num_states = num_states
        self.tau = tau
        self.horizon = horizon
        self.src_info = src_info

        if num_states > 1:
            self.router = nn.LSTM(
                input_size=num_states,
                hidden_size=hidden_size,
                num_layers=1,
                batch_first=True,
            )
            self.fc = nn.Linear(hidden_size + input_size, num_states)
        self.predictors = nn.Linear(input_size, num_states)

    def forward(self, hidden, hist_loss):
        """Route `hidden` to one of the `num_states` predictors.

        Args:
            hidden (Tensor): latent representation, (batch, input_size)
            hist_loss (Tensor): per-state loss history, (batch, steps, num_states)

        Returns:
            tuple: `(final_pred, preds, prob)`; `final_pred` is (batch, 1),
            `preds` is (batch, num_states), and `prob` holds the routing
            probabilities (None when `num_states == 1`).

        Raises:
            ValueError: if `hist_loss` has no step left once `horizon` steps
                are dropped.
        """
        preds = self.predictors(hidden)

        if self.num_states == 1:
            final_pred = preds
            prob = None
        else:
            # information type
            # NOTE(review): the previous `hist_loss[:, -self.horizon]` picked one
            # single step, which handed the LSTM a 2-D input and made the
            # `torch.cat` below fail on mismatched ranks. Dropping the last
            # `horizon` steps is the reading applied here; confirm it matches
            # how callers pre-trim `hist_loss`.
            usable_steps = hist_loss.shape[1] - self.horizon
            if usable_steps <= 0:
                raise ValueError(
                    f"hist_loss has {hist_loss.shape[1]} steps, need more than horizon={self.horizon}"
                )
            router_out, _ = self.router(hist_loss[:, :usable_steps])
            if "LR" in self.src_info:
                latent_representation = hidden
            else:
                latent_representation = torch.randn_like(hidden)
            if "TPE" in self.src_info:
                temporal_pred_error = router_out[:, -1]
            else:
                temporal_pred_error = torch.randn_like(router_out[:, -1])

            out = self.fc(torch.cat([temporal_pred_error, latent_representation], dim=-1))
            prob = nn.functional.gumbel_softmax(out, dim=-1, tau=self.tau, hard=False)

            if self.training:
                # Soft routing: probability-weighted mix of all predictors.
                final_pred = (preds * prob).sum(dim=-1 , keepdim = True)
            else:
                # Hard routing: keep only the most probable predictor.
                final_pred = preds[range(len(preds)), prob.argmax(dim=-1)].unsqueeze(-1)

        return final_pred, preds, prob


In [91]:
# Toy sizes: batch, sequence length, input features.
b, s, f = 100, 60, 10
horizon = 20
num_states = 3
x = torch.rand((b, s, f))
y = torch.rand((b, s))
# Target is the value at the last step, kept as a (b, 1) column.
label = y[:, -1].unsqueeze(dim=-1)

rnn = nn.LSTM(input_size=f, hidden_size=16, num_layers=2, batch_first=True)
tra = TRA(16 , num_states)
hidden = rnn(x)[0]

tra.predictors(hidden)
# Signed error of every predictor at every step.
hist_loss1 = tra.predictors(hidden) - y.unsqueeze(dim=-1)
# Route on the last hidden state, passing the history without its last `horizon` steps.
pred, all_preds, prob = tra(hidden[:, -1], hist_loss1[:, :-horizon])

torch.Size([100, 60, 16])

In [94]:
from scripts.special.TRA import *


def optimal_transport_penalty(preds , label , prob , global_batch_steps = 0 , lamb = 0.01 , rho = 0.99):
    """Regularization term computed from the per-predictor squared errors.

    Args:
        preds (Tensor): predictions of every predictor
        label (Tensor): targets, broadcastable against `preds`
        prob (Tensor | None): routing probabilities; None disables the penalty
        global_batch_steps (int): exponent applied to `rho`
        lamb (float): initial penalty weight
        rho (float): decay base; the weight used is `lamb * rho ** global_batch_steps`

    Returns:
        `-weight * mean(sum(log(prob) * P, dim=-1))`, or the integer 0 when
        `prob` is None. `P` is whatever `sinkhorn` returns for the negated,
        min-shifted squared errors (`sinkhorn` comes from the star import).
    """
    if prob is None:
        return 0
    sq_err = (preds - label).square()
    # Shift each row so that its smallest error becomes zero.
    sq_err -= sq_err.min(dim=-1, keepdim=True).values
    assignment = sinkhorn(-sq_err, epsilon=0.01)
    decayed_weight = lamb * (rho**global_batch_steps)
    weighted_log_prob = prob.log().mul(assignment).sum(dim=-1).mean()
    return -decayed_weight * weighted_log_prob

# Penalty schedule: initial weight, decay base and the step at which it is evaluated.
init_lamb = 0.01
rho = 0.999
global_batch_steps = 500
# Toy sizes: batch, sequence length, input features.
b, s, f = 100, 60, 10
horizon = 20
num_states = 3
x = torch.rand((b, s, f))
y = torch.rand((b, s))
# Target is the value at the last step, kept as a (b, 1) column.
label = y[:, -1].unsqueeze(dim=-1)

rnn = nn.LSTM(input_size=f, hidden_size=16, num_layers=2, batch_first=True)
tra = TRA(16 , num_states)
hidden = rnn(x)[0]

tra.predictors(hidden)
# Signed error of every predictor at every step.
hist_loss1 = tra.predictors(hidden) - y.unsqueeze(dim=-1)
pred, all_preds, prob = tra(hidden[:, -1], hist_loss1[:, :-horizon])

optimal_transport_penalty(
    all_preds,
    label,
    prob,
    global_batch_steps=global_batch_steps,
    lamb=init_lamb,
    rho=rho,
)

tensor(0.0093, grad_fn=<MulBackward0>)