In [1]:
import torch
import numpy as np
import torch.nn as nn
from math import sqrt
import torch.nn.functional as F
import pytorch_lightning as pl

def weight_init(m, mean=0.0, std=0.05):
    """Initialise convolution weights from a truncated normal distribution.

    Replaces Keras' ``kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.05, seed=42)``.
    Keras discards values more than two standard deviations from the mean, so the
    truncation bounds are ``mean -/+ 2 * std``. torch's own defaults (``a=-2``,
    ``b=2``) are absolute values and would leave a 0.05-std normal effectively
    untruncated.

    Intended for ``module.apply(weight_init)``. Only modules whose class name
    contains ``'Conv'`` and that own a ``weight`` tensor are modified; every other
    module is left untouched. The Keras ``seed=42`` is not reproduced here; seed
    torch globally if reproducible initialisation is required.

    Args:
        m: module visited by ``nn.Module.apply``.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
    """
    if 'Conv' not in type(m).__name__:
        return
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    torch.nn.init.trunc_normal_(weight, mean=mean, std=std,
                                a=mean - 2 * std, b=mean + 2 * std)

class CausalConv1d(torch.nn.Conv1d):
    """Conditioned residual block built on a dilated causal 1-D convolution.

    The input window is flattened, concatenated with the condition vector and
    projected back to its original shape by ``hidden_layer0``. The result is
    convolved (channels = axis 1 of the input, length = last axis), trimmed so
    each output position only depends on earlier-or-equal positions, passed
    through SELU and then through a 1x1 convolution (``cnn_block_1``).

    Args:
        num_conditions: width of the ``cond`` tensor passed to ``forward``.
        num_features: size of the last axis of ``x_input``.
        input_seq_length: size of axis 1 of ``x_input``.
        in_channels: convolution channels. It must equal ``input_seq_length``,
            because the projected input is fed straight into the convolution.
        out_channels: accepted for API compatibility but NOT used. The
            convolution always outputs ``in_channels`` channels so that
            ``cnn_block_1`` (``input_seq_length`` -> ``input_seq_length``) can
            consume its output and the residual sum is shape-compatible.
        kernel_size, stride, dilation, groups, bias: forwarded to ``nn.Conv1d``.
    """

    def __init__(self, num_conditions, num_features, input_seq_length, in_channels, out_channels,
                 kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # Conv1d pads both ends by this amount; forward() drops the trailing
        # part again so the convolution is causal.
        padding = (kernel_size - 1) * dilation

        super(CausalConv1d, self).__init__(
            in_channels, in_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.__padding = padding

        # Mixes the conditions into the flattened input window.
        self.hidden_layer0 = nn.Sequential(
            nn.Linear(input_seq_length * num_features + num_conditions,
                      input_seq_length * num_features))

        self.cnn_block_1 = nn.Conv1d(
            in_channels=input_seq_length, out_channels=input_seq_length, kernel_size=1, bias=False)

    def forward(self, x_input, cond):
        """Run the block.

        Args:
            x_input: tensor ``[batch_size, input_seq_length, num_features]``.
            cond: tensor ``[batch_size, num_conditions]``.

        Returns:
            ``(network_out, skip_out)``:
                ``network_out`` is the residual output, i.e. the projected input
                plus ``skip_out``.
                ``skip_out`` is the 1x1-conv output used as the skip connection.
                With ``stride=1`` both have the shape of ``x_input``.
        """
        seq_len, n_feat = x_input.shape[1], x_input.shape[2]
        x_flat = x_input.reshape(-1, seq_len * n_feat)
        # The conditions feed a float Linear layer: casting them to the input
        # dtype (rather than ``long``) keeps fractional values intact and avoids
        # a Float/Long mismatch in torch.cat on older torch versions.
        cond = cond.to(dtype=x_flat.dtype)
        x = torch.cat([x_flat, cond], 1)  # [N, seq_len * n_feat + num_conditions]
        x = self.hidden_layer0(x)
        x = x.reshape(-1, seq_len, n_feat)

        #------- Encoder -----------#
        residual = x
        layer_out = super(CausalConv1d, self).forward(x)
        if self.__padding != 0:
            # Drop the right-hand padding so position t only sees inputs <= t.
            layer_out = layer_out[:, :, :-self.__padding]
        layer_out = F.selu(layer_out)

        #------- Decoder -----------#
        # A single 1x1 conv output feeds both the skip and the residual path.
        skip_out = self.cnn_block_1(layer_out)
        network_out = residual + skip_out
        return network_out, skip_out

class CausalModel(pl.LightningModule):
    """Conditional dilated causal CNN for sequence forecasting (Lightning module).

    ``forward`` runs the input window through a stack of ``CausalConv1d`` blocks
    (kernel size 2, dilations 1, 2, ..., 64). It sums their skip outputs, applies
    a leaky ReLU and finishes with a 1x1 convolution mapping ``input_seq_length``
    channels to ``output_seq_length`` channels.

    All layers are created once in ``__init__`` and registered as submodules.
    They are therefore optimised, follow the module's device and are saved in
    checkpoints.

    Training/validation compute ``rmse_loss`` (a helper defined elsewhere in the
    notebook) on the elements selected by ``mask``.

    Args (stored as attributes and via ``save_hyperparameters``):
        w_decay: Adam weight decay (0.001 is used when ``None``).
        dropout, alpha, gamma: stored only; not used by this class.
        input_seq_length: size of axis 1 of the input window.
        output_seq_length: number of predicted steps (axis 1 of the prediction).
        num_features: size of the last axis of input and prediction.
        lr: Adam learning rate (0.00075 is used when ``None``).
        num_conditions: number of leading condition columns fed to the model.
        path, net: forwarded to the plotting helpers in ``validation_epoch_end``.
        feature_list: stored only.
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            arguments of the legacy ``self.CausalConv1d`` block (see ``__init__``).

    NOTE(review): ``training_epoch_end`` / ``validation_epoch_end`` are Lightning
    1.x hooks (removed in 2.0). Confirm the pinned pytorch_lightning version.
    """

    # Stack layout (previously hard-coded inside forward()).
    _DILATIONS = (1, 2, 4, 8, 16, 32, 64)
    _BLOCK_KERNEL_SIZE = 2
    # Skip outputs of the last `_NUM_DROPPED_SKIPS` blocks are dropped out with
    # probability `_SKIP_DROPOUT` to limit the influence of earlier data.
    _NUM_DROPPED_SKIPS = 2
    _SKIP_DROPOUT = 0.8
    # Optimizer fallbacks: the values the original code hard-coded.
    _DEFAULT_LR = 0.00075
    _DEFAULT_W_DECAY = 0.001

    def __init__(self, w_decay, dropout, alpha, gamma, input_seq_length, output_seq_length,
                 num_features, lr, num_conditions, path, feature_list,
                 net, in_channels, out_channels, kernel_size, stride, dilation, groups, bias):

        super(CausalModel, self).__init__()
        # NOTE(review): this block is built from the constructor's conv arguments
        # but is not used by forward(); kept so `model.CausalConv1d` still exists.
        self.CausalConv1d = CausalConv1d(num_conditions, num_features, input_seq_length, in_channels,
                                         out_channels, kernel_size, stride, dilation, groups, bias)
        self.w_decay = w_decay
        self.dropout = dropout
        self.alpha = alpha
        self.gamma = gamma
        self.input_seq_length = input_seq_length
        self.output_seq_length = output_seq_length
        self.num_features = num_features
        self.lr = lr
        self.num_conditions = num_conditions
        self.path = path
        self.feature_list = feature_list
        self.net = net
        self.save_hyperparameters()

        # Every block keeps the [N, input_seq_length, num_features] shape, so all
        # of them share the same construction arguments except the dilation.
        # (out_channels is ignored by CausalConv1d; passed as before.)
        self.blocks = nn.ModuleList(
            CausalConv1d(num_conditions=num_conditions, num_features=num_features,
                         input_seq_length=input_seq_length, in_channels=input_seq_length,
                         out_channels=output_seq_length, kernel_size=self._BLOCK_KERNEL_SIZE,
                         dilation=d)
            for d in self._DILATIONS)
        self.head = nn.Conv1d(in_channels=input_seq_length, out_channels=output_seq_length,
                              kernel_size=1, bias=False)

        # Initialize weights (truncated normal on every convolution layer)
        self.apply(weight_init)

    def forward(self, x_input, cond, output_seq_length):
        """Predict ``output_seq_length`` steps from an input window.

        Args:
            x_input: float tensor ``[batch, input_seq_length, num_features]``.
            cond: tensor ``[batch, num_conditions]``.
            output_seq_length: expected number of output steps; must equal the
                constructor's value because the head layer is sized from it.

        Returns:
            Tensor ``[batch, output_seq_length, num_features]``.

        Raises:
            ValueError: if ``output_seq_length`` differs from ``self.output_seq_length``.
        """
        if output_seq_length != self.output_seq_length:
            raise ValueError(
                f"output_seq_length={output_seq_length} does not match the model's "
                f"output_seq_length={self.output_seq_length}")

        first_dropped = len(self.blocks) - self._NUM_DROPPED_SKIPS
        x = x_input
        skip_sum = None
        for idx, block in enumerate(self.blocks):
            x, skip = block(x, cond)
            if idx >= first_dropped:
                # F.dropout honours train/eval mode; a fresh nn.Dropout() created
                # here would always be in training mode and drop at inference too.
                skip = F.dropout(skip, p=self._SKIP_DROPOUT, training=self.training)
            skip_sum = skip if skip_sum is None else skip_sum + skip

        return self.head(F.leaky_relu(skip_sum))

    def _shared_step(self, batch):
        """Run the model on one batch and compute the masked RMSE.

        ``batch`` is unpacked as ``(target_in, target_out, condition, mask)``;
        ``mask`` selects the elements entering the loss via ``torch.masked_select``.
        Returns the dict that training/validation steps hand to Lightning.
        """
        target_in, target_out, condition, mask = batch
        condition = condition[:, :self.num_conditions]  # keep the configured number of conditions

        # as_tensor avoids the copy/warning of torch.tensor(<tensor>) and keeps the device.
        target_in = torch.as_tensor(target_in, dtype=torch.float32)
        target_out = torch.as_tensor(target_out, dtype=torch.float32)
        y_pred = self(target_in, condition, target_out.shape[1])

        #------- Computing RMSE loss (using masking)-------#
        synth_mask = torch.masked_select(y_pred, mask)
        real_mask = torch.masked_select(target_out, mask)
        rmse = rmse_loss(synth_mask, real_mask)

        return {"loss": rmse,
                "past": target_in, "ytrue": target_out, "ypred": y_pred,
                "conditions": condition, "mask": mask}

    @staticmethod
    def _outputs_to_column(step_outputs, key):
        """Concatenate ``step_outputs[i][key]`` into a numpy column of shape (total, 1).

        Concatenating flattened tensors gives the same element order as
        ``flatten(stack(...))`` but also works when the last batch is smaller.
        """
        flat = torch.cat([out[key].detach().reshape(-1) for out in step_outputs])
        return flat.cpu().numpy().reshape(-1, 1)

    def training_step(self, batch, batch_idx):
        """Compute and log the masked training RMSE for one batch."""
        out = self._shared_step(batch)
        self.log("loss_train", out["loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return out

    def training_epoch_end(self, training_step_outputs):
        """Pass the per-batch training losses to ``saving_logs_training``."""
        if not training_step_outputs:
            return
        saving_logs_training(self._outputs_to_column(training_step_outputs, "loss"))

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked validation RMSE for one batch."""
        out = self._shared_step(batch)
        self.log("loss_val", out["loss"])
        return out

    def validation_epoch_end(self, validation_step_outputs):
        """Save validation losses, then call ``plotting_predictions`` and ``dwprobability``."""
        if not validation_step_outputs:
            return
        collect = self._outputs_to_column
        saving_logs_validation(collect(validation_step_outputs, "loss"))

        past_val = collect(validation_step_outputs, "past")
        ytrue_val = collect(validation_step_outputs, "ytrue")
        ypred_val = collect(validation_step_outputs, "ypred")
        conditions_val = collect(validation_step_outputs, "conditions")
        mask_val = collect(validation_step_outputs, "mask")

        plotting_predictions(past_val, ytrue_val, ypred_val, mask_val, self.input_seq_length, self.output_seq_length,
                             self.num_features, self.num_conditions, conditions_val, self.path, "dcnn", 1, "val",
                             self.current_epoch)
        dwprobability(ytrue_val, ypred_val, self.output_seq_length, self.num_features, self.path, self.net,
                      self.current_epoch, "val")

    def configure_optimizers(self):
        """Adam using the ``lr`` / ``w_decay`` hyperparameters.

        ``weight_decay`` replaces Keras' ``kernel_regularizer=l2(l2_layer_reg)``.
        The previously hard-coded values are used when a hyperparameter is ``None``.
        """
        lr = self.lr if self.lr is not None else self._DEFAULT_LR
        w_decay = self.w_decay if self.w_decay is not None else self._DEFAULT_W_DECAY
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=w_decay)
        return {"optimizer": optimizer, "monitor": "loss"}

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Sample input of shape [batch_size, input_seqlength, num_features]. A seeded local
# generator makes the displayed values reproducible on Restart & Run All without
# changing the global RNG state used elsewhere.
tensor = torch.rand([1, 6, 7], generator=torch.Generator().manual_seed(42))
tensor

tensor([[[0.0852, 0.3519, 0.0379, 0.4666, 0.5846, 0.8126, 0.3698],
         [0.7432, 0.1818, 0.4830, 0.7041, 0.3490, 0.7891, 0.1373],
         [0.5632, 0.2166, 0.4447, 0.5470, 0.3433, 0.3877, 0.2046],
         [0.7716, 0.6005, 0.5688, 0.4146, 0.4406, 0.4069, 0.4131],
         [0.2992, 0.6376, 0.8898, 0.7747, 0.6553, 0.1490, 0.0798],
         [0.3513, 0.0342, 0.7974, 0.9686, 0.3627, 0.5371, 0.7010]]])
