# In [1]:
#model: https://github.com/kristpapadopoulos/seriesnet/blob/master/seriesnet.py
#causalconv1: https://github.com/pytorch/pytorch/issues/1333

import torch
import numpy as np
import torch.nn as nn
from math import sqrt
import torch.nn.functional as F
import pytorch_lightning as pl

def weight_init(m):
    """Re-initialise the weights of a ``Conv1d`` layer in place.

    Intended to be passed to ``Module.apply``. Modules that are not
    ``nn.Conv1d`` instances are left untouched.

    Args:
        m: Any module; only ``nn.Conv1d`` (and subclasses) are modified.
    """
    if not isinstance(m, nn.Conv1d):
        return
    # Truncated normal with mean 0 and standard deviation 0.05.
    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.05)

class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution whose output at step ``t`` only sees inputs ``<= t``.

    The parent ``Conv1d`` pads both ends of the sequence with
    ``(kernel_size - 1) * dilation`` zeros; ``forward`` then drops the same
    number of trailing outputs, so the result has the input's length.

    Args:
        num_conditions: Accepted for interface compatibility; not used.
        num_features: Number of input channels of the convolution.
        in_channels: Accepted for interface compatibility; not used (the
            convolution reads ``num_features`` channels).
        out_channels: Accepted for interface compatibility; not used (the
            convolution always produces 32 channels).
        kernel_size: Width of the convolution kernel.
        dilation: Spacing between kernel taps.
        device: Device on which the weight tensor is allocated.
    """

    def __init__(self,num_conditions,num_features,in_channels,out_channels,kernel_size,dilation,device):
        causal_padding = (kernel_size - 1) * dilation

        # Input layout is (batch, in_channels, in_length).
        # The kernel width must be the same one used to compute the padding,
        # otherwise the trimmed output no longer matches the input length.
        super(CausalConv1d, self).__init__(in_channels=num_features, out_channels=32,
            kernel_size=kernel_size, stride=1, padding=causal_padding, dilation=dilation,
            groups=1, bias=False, device=device)
        self.__padding = causal_padding

    def forward(self, x_input,cond):
        """Apply the causal convolution.

        Args:
            x_input: Tensor laid out as (batch, num_features, length),
                e.g. ``[32, 7, 6]``.
            cond: Ignored; present so all layers share one call signature.

        Returns:
            Tensor of shape (batch, 32, length).
        """
        conv1d_out = super(CausalConv1d, self).forward(x_input)
        if self.__padding != 0:
            # Drop the outputs that were computed from right-hand padding.
            conv1d_out = conv1d_out[:, :, :-self.__padding]
        return conv1d_out

class DC_CNN_Block(torch.nn.Module):
    """Residual block built around one dilated causal convolution.

    ``forward`` returns two tensors: the residual output that feeds the next
    block and a skip output that the caller accumulates.

    Both 1x1 convolutions are created here rather than inside ``forward`` so
    that they are registered sub-modules: their weights persist between
    calls, show up in ``parameters()`` and follow ``.to(device)``.

    Args:
        num_conditions: Forwarded to ``CausalConv1d``.
        num_features: Forwarded to ``CausalConv1d`` (its input channel count).
        in_channels: Input channels of the two 1x1 convolutions. They run on
            a tensor whose channel axis is the time axis of the block input,
            so this must equal the sequence length of ``x_input``.
        out_channels: Output channels of the two 1x1 convolutions.
        kernel_size: Forwarded to ``CausalConv1d``.
        dilation: Forwarded to ``CausalConv1d``.
    """

    def __init__(self,num_conditions,num_features,in_channels,out_channels,kernel_size,dilation):
        super(DC_CNN_Block, self).__init__()

        # All layers of the block are allocated on the same device so that a
        # block used directly after construction is internally consistent.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.causal = CausalConv1d(num_conditions,num_features,in_channels, out_channels,kernel_size,dilation,device=device)
        self.in_channels= in_channels
        self.out_channels = out_channels
        self.num_features = num_features
        self.causal.apply(weight_init)

        # 1x1 convolutions for the skip path and the residual path.
        self.skip_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                   kernel_size=1, stride=5, bias=False, device=device)
        self.residual_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                       kernel_size=1, stride=5, bias=False, device=device)

    def forward(self, x_input,cond):
        """Run the block.

        Args:
            x_input: Tensor laid out as (batch, seq_len, num_features),
                e.g. ``[32, 6, 7]``.
            cond: Passed through to the causal convolution.

        Returns:
            Tuple ``(network_out, skip_out)``: ``network_out`` is ``x_input``
            plus the residual-path projection, ``skip_out`` is the skip-path
            projection.
        """
        residual = x_input  # [B, L, F]

        # [B, L, F] => [B, F, L]: Conv1d expects channels before length.
        x = x_input.permute(0, 2, 1)

        # [B, F, L] => [B, C, L] (C = 32 filters in the current CausalConv1d)
        layer_out = self.causal(x, cond)

        # [B, C, L] => [B, L, C]
        layer_out = layer_out.permute(0, 2, 1)
        layer_out = F.selu(layer_out)

        # [B, L, C] => [B, out_channels, floor((C - 1) / 5) + 1]
        # Stride 5 along the filter axis turns 32 filter responses into 7
        # positions; the residual addition below needs that to be
        # broadcast-compatible with the last dimension of x_input.
        skip_out = self.skip_conv(layer_out)
        network_in = self.residual_conv(layer_out)

        network_out = residual + network_in
        return network_out, skip_out

class CausalModel(pl.LightningModule):
    """SeriesNet-style stack of dilated causal convolution blocks.

    Seven ``DC_CNN_Block`` stages (dilations 1, 2, 4, ..., 64) are chained
    through their residual outputs. Their skip outputs are summed, passed
    through ReLU and projected by a final 1x1 convolution.

    Every layer is created once in ``__init__`` so that it is a registered
    sub-module: the weights persist across ``forward`` calls, are returned by
    ``parameters()`` (and therefore optimised), and follow the module when it
    is moved between devices.
    """

    # Dilations of the blocks applied after ``self.block`` (dilation 1).
    _DEEPER_DILATIONS = (2, 4, 8, 16, 32, 64)
    # Number of trailing blocks whose skip output is passed through dropout.
    _NUM_DROPOUT_BLOCKS = 2
    # Dropout probability applied to those skip outputs.
    _SKIP_DROPOUT_P = 0.8

    def __init__(self, w_decay,dropout,alpha,gamma,input_seq_length,output_seq_length,
                 num_features,lr,num_conditions, path,feature_list,
                 net,in_channels,out_channels,kernel_size,stride,dilation,groups,bias):
        """Store the hyper-parameters and build all layers.

        All arguments are recorded through ``save_hyperparameters``. Of
        these, the layers are built from ``num_conditions``,
        ``num_features``, ``input_seq_length`` and ``output_seq_length``.
        """
        super(CausalModel,self).__init__()
        self.w_decay = w_decay
        self.dropout = dropout
        self.alpha = alpha
        self.gamma = gamma
        self.input_seq_length = input_seq_length
        self.output_seq_length = output_seq_length
        self.num_features = num_features
        self.lr = lr
        self.num_conditions = num_conditions
        self.path = path
        self.feature_list = feature_list
        self.net = net
        self.save_hyperparameters()

        # First stage (dilation 1).
        self.block = DC_CNN_Block(num_conditions=self.num_conditions,num_features=self.num_features,
                                  in_channels=self.input_seq_length, out_channels=self.output_seq_length,kernel_size=2,dilation=1)

        # Remaining stages, applied in order after ``self.block``.
        self.deeper_blocks = nn.ModuleList([
            DC_CNN_Block(num_conditions=self.num_conditions, num_features=self.num_features,
                         in_channels=self.input_seq_length, out_channels=self.output_seq_length,
                         kernel_size=2, dilation=block_dilation)
            for block_dilation in self._DEEPER_DILATIONS
        ])

        # Initialize weights
        self.block.apply(weight_init)
        self.deeper_blocks.apply(weight_init)

        # Final 1x1 projection. Its input is the sum of the skip outputs,
        # whose channel count is the blocks' out_channels. It is allocated on
        # the same device as the blocks so the freshly built model is
        # consistent even before it is moved anywhere.
        block_device = next(self.block.parameters()).device
        self.output_conv = nn.Conv1d(in_channels=self.output_seq_length,
                                     out_channels=self.output_seq_length,
                                     kernel_size=1, bias=False, device=block_device)

    def forward(self,x_input,cond):
        """
        input sequence: [batch_size, input_seq_length,num_features]
        condition: [batch_size,num_conditions]

        Returns the 1x1 projection of the ReLU of the summed skip outputs.
        """
        stages = [self.block, *self.deeper_blocks]
        first_dropout_stage = len(stages) - self._NUM_DROPOUT_BLOCKS

        residual = x_input
        skip_sum = None
        for stage_index, stage in enumerate(stages):
            residual, skip = stage(residual, cond)
            if stage_index >= first_dropout_stage:
                # dropout used to limit influence of earlier data; it is
                # tied to self.training so it is switched off in eval mode.
                skip = F.dropout(skip, p=self._SKIP_DROPOUT_P, training=self.training)
            skip_sum = skip if skip_sum is None else skip_sum + skip

        return self.output_conv(F.relu(skip_sum))

    def _shared_step(self, batch):
        """Compute predictions and the RMSE loss for one batch.

        Args:
            batch: Tuple ``(target_in, target_out, condition, mask)``.

        Returns:
            Dict with the loss, the float32 inputs/targets, the prediction,
            the truncated conditions and the mask (passed through unchanged).
        """
        target_in, target_out, condition, mask = batch
        condition = condition[:, :self.num_conditions]  # keep only the configured number of conditions
        # Cast without torch.tensor(tensor), which triggers a copy-construct warning.
        target_in = target_in.detach().to(dtype=torch.float32)
        target_out = target_out.detach().to(dtype=torch.float32)

        y_pred = self(target_in, condition)

        # rmse_loss is a module-level helper that is not defined in this section.
        rmse = rmse_loss(y_pred, target_out)
        return {"loss":rmse,"past":target_in,"ytrue":target_out,"ypred":y_pred,"conditions":condition, "mask":mask}

    def training_step(self, batch, batch_idx):
        """Run one training batch and log the loss as ``loss_train``."""
        outputs = self._shared_step(batch)
        self.log("loss_train", outputs["loss"],on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log the loss as ``loss_val``."""
        outputs = self._shared_step(batch)
        self.log("loss_val", outputs["loss"])
        return outputs

    def configure_optimizers(self):
        """Return the Adam optimizer configuration used for training."""
        # weight_decay replaces Keras' kernel_regularizer=l2(l2_layer_reg)
        # NOTE(review): lr and weight_decay are fixed here; self.lr and
        # self.w_decay are not consulted.
        optimizer = torch.optim.Adam(self.parameters(),lr=0.00075,betas=(0.9, 0.999),weight_decay=0.0)
        return {"optimizer": optimizer,"monitor": "loss",}

# Notebook cell output: warn(f"Failed to load image Python extension: {e}")
