In [None]:
# Build one sample batch first; the model and loss cells further down run against it.
from src.res.model.util import BatchInput , BatchData
# The recorded output of this cell lists x as two tensors ([5171, 30, 6] and [5171, 30, 10]) and y as [5171, 2].
# NOTE(review): presumably 'day+style' selects those two input blocks and 2 is the label count -- confirm against BatchInput.generate.
batch_input = BatchInput.generate('day+style' , 2)

BatchData:
x : (torch.Size([5171, 30, 6]), torch.Size([5171, 30, 10]))
y : torch.Size([5171, 2])
w : None
i : torch.Size([5171, 2])
valid : torch.Size([5171])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from src.res.algo.nn import layer as Layer
from src.res.model.util import BatchOutput

class mod_gru(nn.Module):
    '''
    Thin batch-first GRU wrapper that returns only the per-step outputs.

    input_dim  : size of the last dim of the input
    output_dim : GRU hidden size (size of the last dim of the output)
    dropout    : dropout between stacked GRU layers
    num_layers : requested number of stacked layers , silently capped at 3

    in: [bs x seq_len x input_dim]
    out:[bs x seq_len x output_dim]
    '''
    def __init__(self , input_dim , output_dim , dropout=0.0 , num_layers = 2):
        super().__init__()
        num_layers = min(3,num_layers)
        # nn.GRU only applies dropout *between* stacked layers; with a single layer a
        # non-zero value has no effect and merely triggers a UserWarning , so drop it.
        gru_dropout = dropout if num_layers > 1 else 0.0
        self.gru = nn.GRU(input_dim , output_dim , num_layers = num_layers , dropout = gru_dropout , batch_first = True)
    def forward(self, x : Tensor) -> Tensor:
        # nn.GRU returns (output , h_n); only the output sequence is passed on
        return self.gru(x)[0]

class Astgnn(nn.Module):
    '''
    GRU encoder followed by two linear heads ("alphas" , "betas") and an element-wise output map.

    Pipeline:
        input                              [bs x seq_len x input_dim]
        fc_enc_in  (Linear + Tanh)      -> [bs x seq_len x enc_in_dim]
        fc_rnn     (mod_gru)            -> [bs x seq_len x hidden_dim]
        last / second-last time step    -> [bs x hidden_dim] each
        alpha_net  (last step)          -> alphas            [bs x alpha_num]
        beta_net   (both steps)         -> betas , betas_1   [bs x beta_num]
        alpha_map_out (Layer.EwLinear)  -> pred  ([bs x 1] in the recorded run)

    Parameters
        input_dim   : width of the (concatenated) input features
        hidden_dim  : GRU hidden size
        dropout     : dropout used inside the GRU and after each head activation
        rnn_layers  : number of GRU layers requested from mod_gru
        enc_in_dim  : width of the input encoder
        act_type    : activation name handed to Layer.Act.get_activation_fn
        alpha_num   : width of the alpha head
        beta_num    : width of the beta head
        verbose     : print intermediate tensor shapes on every forward pass;
                      defaults to True , which is the behaviour this class always had
        enc_in , dec_mlp_layers , dec_mlp_dim , **kwargs :
                      accepted but not used anywhere in this class
    '''
    def __init__(self,input_dim,hidden_dim = 128,dropout = 0.1,rnn_layers = 2,enc_in=None,enc_in_dim=64,
                 act_type='leaky',dec_mlp_layers=2,dec_mlp_dim=128,
                 alpha_num = 60 , beta_num = 10 , verbose = True ,
                 **kwargs):
        super().__init__()
        self.verbose = verbose
        self.fc_enc_in = nn.Sequential(nn.Linear(input_dim, enc_in_dim),nn.Tanh())

        rnn_kwargs = {'input_dim':enc_in_dim,'output_dim':hidden_dim,'num_layers':rnn_layers, 'dropout':dropout}
        self.fc_rnn = mod_gru(**rnn_kwargs)

        self.alpha_net = nn.Sequential(
            nn.Linear(hidden_dim , alpha_num),
            Layer.Act.get_activation_fn(act_type),
            nn.Dropout(dropout)
        )
        self.beta_net = nn.Sequential(
            nn.Linear(hidden_dim , beta_num),
            Layer.Act.get_activation_fn(act_type),
            nn.Dropout(dropout)
        )
        self.alpha_map_out = Layer.EwLinear()

    def _log(self , message : str) -> None:
        '''print a debug line only when the module was built with verbose=True'''
        if self.verbose:
            print(message)

    def forward(self, input : Tensor | tuple[Tensor,...] | list[Tensor]):
        '''
        in:  [bs x seq_len x input_dim] , or a tuple/list of tensors that are concatenated on the last dim
        out: (pred , {'alphas':... , 'betas':... , 'betas_1':...})
             alphas  [bs x alpha_num] , betas / betas_1 [bs x beta_num]
        seq_len must be at least 2 because the second-last time step is indexed.
        '''
        x = input if isinstance(input , Tensor) else torch.concat(input , dim = -1)
        self._log(f'input shape: {x.shape}')
        x = self.fc_enc_in(x)
        self._log(f'enc_in shape: {x.shape}')
        x = self.fc_rnn(x)
        self._log(f'rnn output shape: {x.shape}')

        # hidden states at t (last step) and t-1 (second-last step)
        h_last , h_prev = x[:,-1] , x[: , -2]
        self._log(f'last rnn output shape: {h_last.shape}')
        self._log(f'second last rnn output shape: {h_prev.shape}')

        alphas = self.alpha_net(h_last)
        self._log(f'alphas shape: {alphas.shape}')
        betas = self.beta_net(h_last)
        self._log(f'betas shape: {betas.shape}')
        # same head on the previous step; in training mode each call draws its own dropout mask
        betas_1 = self.beta_net(h_prev)

        pred = self.alpha_map_out(alphas)
        self._log(f'pred shape: {pred.shape}')
        return pred , {'alphas':alphas , 'betas':betas , 'betas_1':betas_1}

class AstgnnLoss(nn.Module):
    '''
    Composite loss: mse + rsquare + lamb * corr + turnover.

    lamb    : weight of the correlation penalty on betas
    verbose : print every component (plus the corr2 cross-check) on each call;
              defaults to True , which is the behaviour this class always had.
              With verbose=False nothing is printed and corr2 is not computed.
    '''
    def __init__(self, lamb : float = 0.1 , verbose : bool = True):
        super().__init__()
        self.lamb = lamb
        self.verbose = verbose

    def forward(self, pred , labels , alphas , betas , betas_1 , **kwargs):
        '''
        pred    : prediction , compared against labels[...,0] after squeezing both
        labels  : last dim must be 2; column 0 feeds the mse , column 1 feeds the rsquare term
        alphas  : 2-d matrix regressed onto labels[...,1]
        betas   : penalised for cross-column correlation
        betas_1 : previous-step betas , used for the turnover term
        extra keyword arguments are ignored
        returns the scalar total loss
        '''
        assert labels.shape[-1] == 2 , labels.shape
        mse = F.mse_loss(pred.squeeze() , labels[...,0].squeeze())
        rsquare = self.rsquare_loss(alphas , labels[...,1])
        corr = self.corr_loss(betas)
        turnover = self.turnover_loss(betas , betas_1)
        if self.verbose:
            print('mse' , mse)
            print('rsquare' , rsquare)
            print('corr' , corr)
            # corr2 is only a printed cross-check (algebraically equal to corr) and never enters the total
            print('corr2' , self.corr_loss2(betas))
            print('turnover' , turnover)
        return mse + rsquare + self.lamb * corr + turnover

    def rsquare_loss(self, hiddens : Tensor , label : Tensor , **kwargs):
        '''
        Least-squares fit of label on the columns of hiddens (no intercept , normal equations) ,
        returning 1 - ||label - fit|| / ||label||.
        hiddens must be 2-d; hiddens.T @ hiddens must be invertible and label must be non-zero.
        NOTE(review): this value grows as the fit improves , so adding it to a minimised total
        rewards a larger residual -- confirm the sign is intended.
        '''
        assert hiddens.ndim == 2 , hiddens.shape
        y_norm = label.norm()
        pred = hiddens @ (hiddens.T @ hiddens).inverse() @ hiddens.T @ label
        res_norm = (label - pred).norm()
        return 1 - res_norm / y_norm

    def corr_loss(self, hiddens : Tensor , **kwargs):
        '''
        Standardise every column (mean 0 , std + 1e-6) and return the Frobenius norm of the
        column-by-column covariance matrix. The diagonal is part of the norm.
        '''
        h = (hiddens - hiddens.mean(dim=0,keepdim=True)) / (hiddens.std(dim=0,keepdim=True) + 1e-6)
        pen = h.T.cov().norm()
        return pen

    def corr_loss2(self, hiddens : Tensor , **kwargs):
        '''
        Same quantity as corr_loss written as rms(cov) * n_columns:
        sqrt(mean(c^2)) * K == sqrt(sum(c^2)) for a K x K matrix.
        '''
        h = (hiddens - hiddens.mean(dim=0,keepdim=True)) / (hiddens.std(dim=0,keepdim=True) + 1e-6)
        pen = h.T.cov().square().mean().sqrt() * h.shape[-1]
        return pen

    def turnover_loss(self, betas : Tensor , betas_1 : Tensor , **kwargs):
        '''norm of the change between the current and previous betas , taken over all elements'''
        return (betas - betas_1).norm()

# Input width 16 matches the recorded 'input shape' [5171, 30, 16] , i.e. the two x blocks (6 + 10) concatenated.
# Build the network and move it to the batch's device in one step (Module.to returns the module itself).
model = Astgnn(16).to(batch_input.device)
# Run one forward pass and bundle the inputs with the outputs for the loss cell below.
batch_output = BatchOutput.from_module(model , batch_input)
batch_data = BatchData(batch_input , batch_output)

input shape: torch.Size([5171, 30, 16])
enc_in shape: torch.Size([5171, 30, 64])
rnn outpur shape: torch.Size([5171, 30, 128])
last rnn output shape: torch.Size([5171, 128])
second last rnn output shape: torch.Size([5171, 128])
alphas shape: torch.Size([5171, 60])
betas shape: torch.Size([5171, 10])
pred shape: torch.Size([5171, 1])
BatchOutput:
pred : torch.Size([5171, 1])
alphas : torch.Size([5171, 60])
betas : torch.Size([5171, 10])
betas_1 : torch.Size([5171, 10])


In [None]:
# Evaluate the composite loss once on the sample batch, using the default lamb.
model_loss = AstgnnLoss()
# The bare call is the cell's last expression, so the total loss is displayed as the cell result.
# NOTE(review): presumably loss_inputs() returns a dict with pred / labels / alphas / betas / betas_1 -- confirm against BatchData.
model_loss(**batch_data.loss_inputs())

mse tensor(1.0023, device='mps:0', grad_fn=<MseLossBackward0>)
rsquare tensor(0.0525, device='mps:0', grad_fn=<RsubBackward1>)
corr tensor(3.6176, device='mps:0', grad_fn=<NormBackward1>)
corr2 tensor(3.6176, device='mps:0', grad_fn=<MulBackward0>)
turnover tensor(9.6108, device='mps:0', grad_fn=<NormBackward1>)


tensor(11.0274, device='mps:0', grad_fn=<AddBackward0>)