# In [None]:
import torch
import torch.nn as nn

class ResidualBlock_1(nn.Module):
    """Input projection followed by two bidirectional GRUs with a skip connection.

    Data flow (``F`` = ``input_size``, ``H`` = ``hidden_size``)::

        x -> Linear(F, F) + ReLU -> BiGRU -> Linear(2H, F) + ReLU  = out_1
        out_1 -> BiGRU -> Linear(2H, F) + ReLU -> (+ out_1) -> BatchNorm1d

    Args:
        input_size: Number of features in the last dimension of ``x``. The
            block returns the same number of features.
        hidden_size: Hidden size ``H`` of each GRU direction. GRU outputs are
            ``2 * H`` wide because the GRUs are bidirectional.
        bn_channels: ``num_features`` of the final ``BatchNorm1d``. Defaults
            to 45, the value that was previously hard-coded.

    NOTE: ``nn.BatchNorm1d`` treats dimension 1 as the channel dimension.
    The GRUs are ``batch_first``, so for batched input of shape
    (batch, seq, input_size) dimension 1 is the *sequence* axis. In that case
    ``bn_channels`` must equal the sequence length, and statistics are kept
    per time step rather than per feature. For an unbatched (seq, input_size)
    input, dimension 1 is the feature axis instead.
    """

    def __init__(self, input_size, hidden_size, bn_channels=45):
        super().__init__()
        # Submodule names and registration order define the state_dict keys;
        # keep them stable so existing checkpoints still load.
        self.relu1 = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(),
        )
        self.gru1 = nn.GRU(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.relu2 = nn.Sequential(
            nn.Linear(hidden_size * 2, input_size),
            nn.ReLU(),
        )
        self.gru2 = nn.GRU(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.relu3 = nn.Sequential(
            nn.Linear(hidden_size * 2, input_size),
            nn.ReLU(),
        )
        self.BN = nn.BatchNorm1d(bn_channels, affine=False)

    def forward(self, x):
        """Apply the block. The output has the same shape as ``x``."""
        out = self.relu1(x)
        out, _ = self.gru1(out)  # GRU returns (output, h_n); h_n is unused
        out_1 = self.relu2(out)  # source of the skip connection
        out, _ = self.gru2(out_1)
        out = self.relu3(out)
        out = out + out_1  # residual connection
        return self.BN(out)



class ResidualBlock_2(nn.Module):
    """Two bidirectional GRUs with a skip connection, followed by BatchNorm1d.

    Data flow (``F`` = ``input_size``, ``H`` = ``hidden_size``)::

        x -> BiGRU -> Linear(2H, F) + ReLU  = out_1
        out_1 -> BiGRU -> Linear(2H, F) + ReLU -> (+ out_1) -> BatchNorm1d

    Args:
        input_size: Number of features in the last dimension of ``x``. The
            block returns the same number of features.
        hidden_size: Hidden size ``H`` of each GRU direction. GRU outputs are
            ``2 * H`` wide because the GRUs are bidirectional.
        bn_channels: ``num_features`` of the final ``BatchNorm1d``. Defaults
            to 45, the value that was previously hard-coded.

    NOTE: ``nn.BatchNorm1d`` treats dimension 1 as the channel dimension.
    The GRUs are ``batch_first``, so for batched input of shape
    (batch, seq, input_size) dimension 1 is the *sequence* axis. In that case
    ``bn_channels`` must equal the sequence length, and statistics are kept
    per time step rather than per feature. For an unbatched (seq, input_size)
    input, dimension 1 is the feature axis instead.
    """

    def __init__(self, input_size, hidden_size, bn_channels=45):
        super().__init__()
        # Submodule names and registration order define the state_dict keys;
        # keep them stable so existing checkpoints still load.
        self.gru1 = nn.GRU(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.relu2 = nn.Sequential(
            nn.Linear(hidden_size * 2, input_size),
            nn.ReLU(),
        )
        self.gru2 = nn.GRU(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.relu3 = nn.Sequential(
            nn.Linear(hidden_size * 2, input_size),
            nn.ReLU(),
        )
        self.BN1 = nn.BatchNorm1d(bn_channels, affine=False)

    def forward(self, x):
        """Apply the block. The output has the same shape as ``x``."""
        out, _ = self.gru1(x)  # GRU returns (output, h_n); h_n is unused
        out_1 = self.relu2(out)  # source of the skip connection
        out, _ = self.gru2(out_1)
        out = self.relu3(out)
        out = out + out_1  # residual connection
        return self.BN1(out)


class DeepResidualBidirGRU(nn.Module):
    """Two residual bidirectional-GRU blocks followed by a linear output head.

    Args:
        input_size: Feature size of the input. Both residual blocks map their
            input back to this size.
        hidden_size: Hidden size passed to both residual blocks.
        output_size: Feature size produced by the final ``nn.Linear``.

    NOTE(review): both residual blocks end in a ``BatchNorm1d`` with 45
    channels. For batched batch-first input this presumably means the
    sequence length must be 45 — confirm against the data pipeline.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Constructor arguments, kept for inspection.
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )

        # Attribute names and registration order define the state_dict keys
        # and parameter order; do not rename or reorder.
        self.res_block1 = ResidualBlock_1(input_size, hidden_size)
        self.res_block2 = ResidualBlock_2(input_size, hidden_size)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Run ``x`` through both residual blocks, then project to ``output_size``."""
        features = x
        for stage in (self.res_block1, self.res_block2):
            features = stage(features)
        return self.linear(features)

    def compute_l2_loss(self, w):
        """Return the sum of the squared elements of tensor ``w`` (no square root)."""
        return torch.sum(torch.square(w))