In [None]:
class BiLSTM_layer(nn.Module):
    """LSTM whose final hidden states are averaged and fed to a linear head.

    The final hidden state ``h_n`` returned by ``nn.LSTM`` has one entry per
    (layer, direction) pair along its leading axis.  Those entries are
    averaged and projected by ``self.fc``.

    Args:
        input_size: Feature width of each time step.
        hidden_size: Hidden width of the LSTM (also the input width of the head).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        batch_first: Passed through to ``nn.LSTM``.
        output_size: Width of the linear head.  Defaults to 26, the value that
            was previously hard-coded.

    Note:
        ``nn.LSTM`` interprets a 2-D input as a single unbatched sequence, in
        which case the output has shape ``(output_size,)``.
        NOTE(review): confirm that callers pass 3-D input when a per-sample
        prediction is expected.
    """

    def __init__(self, input_size, hidden_size, num_layers, bidirectional,
                 batch_first=False, output_size=26):
        super(BiLSTM_layer, self).__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=batch_first
        )

        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        """Return the head's projection of the averaged final hidden state.

        Args:
            inputs: Tensor accepted by ``nn.LSTM`` for this configuration.

        Returns:
            Tensor of shape ``(batch, output_size)`` for batched input, or
            ``(output_size,)`` for unbatched input.
        """
        _, (h_n, _) = self.lstm(inputs, None)
        # Average over the leading (layers * directions) axis only.  The
        # previous ``h_n.squeeze(0)`` removed that axis whenever it had size 1
        # (unidirectional, single layer), so the following mean collapsed the
        # batch axis instead.
        pooled = h_n.mean(dim=0)
        return self.fc(pooled)


In [None]:
class DataEncoder(nn.Module):
    """Feed-forward encoder: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: Width of the incoming feature vector.
        output_dim: Width of the produced encoding.
        hidden_dim: Width of the single hidden layer.
        dropout: Drop probability applied after the ReLU.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=3, dropout=0.4):
        super().__init__()
        # Kept as attributes so that other modules can read the widths.
        self.input_dim, self.output_dim = input_dim, output_dim

        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` by running it through the stacked layers."""
        encoded = self.net(x)
        return encoded

class minmax_RuleEncoder(nn.Module):
    """Encoder used for the min/max rule branch.

    Architecturally a Linear -> ReLU -> Dropout -> Linear stack.

    Args:
        input_dim: Width of the incoming feature vector.
        output_dim: Width of the produced encoding.
        hidden_dim: Width of the single hidden layer.
        dropout: Drop probability applied after the ReLU.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=3, dropout=0.4):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim

        first = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        regulariser = nn.Dropout(dropout)
        last = nn.Linear(hidden_dim, output_dim)
        self.net = nn.Sequential(first, activation, regulariser, last)

    def forward(self, x):
        """Return the rule encoding of ``x``."""
        encoded = self.net(x)
        return encoded

class outbound_RuleEncoder(nn.Module):
    """Encoder used for the out-of-bound rule branch.

    Architecturally a Linear -> ReLU -> Dropout -> Linear stack.

    Args:
        input_dim: Width of the incoming feature vector.
        output_dim: Width of the produced encoding.
        hidden_dim: Width of the single hidden layer.
        dropout: Drop probability applied after the ReLU.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=3, dropout=0.4):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim

        self.net = nn.Sequential(
            *(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, output_dim),
            )
        )

    def forward(self, x):
        """Return the rule encoding of ``x``."""
        encoded = self.net(x)
        return encoded


In [None]:
class Net(nn.Module):
    """Combine two rule encoders and a data encoder, then decode with BiLSTM blocks.

    The three encoders are applied to the same input.  Their outputs are
    weighted by ``alpha`` (min/max rule), ``beta`` (out-of-bound rule) and
    ``1 - alpha - beta`` (data), merged, and passed through ``self.net``.

    Args:
        input_dim: Accepted for interface compatibility; not used here.
        output_dim: Accepted for interface compatibility; not forwarded,
            because the decision blocks are built without an output-size
            argument.
        minmax_rule_encoder: Module exposing ``input_dim`` / ``output_dim``.
        outbound_rule_encoder: Module applied alongside the other encoders.
        data_encoder: Module exposing ``input_dim`` / ``output_dim``.
        hidden_dim: Input width used for decision blocks after the first.
        n_layers: Number of ``BiLSTM_layer`` blocks in the decision stack.
        merge: ``'cat'`` concatenates the three weighted encodings on the
            last axis; ``'add'`` sums them.
        skip: When true, the raw input is added to the decision output.
        input_type: With ``skip``, ``'seq'`` adds ``x[:, -1, :]``; any other
            value adds ``x`` itself.

    Raises:
        ValueError: If ``merge`` is neither ``'cat'`` nor ``'add'``.
    """

    _MERGE_MODES = ('cat', 'add')

    def __init__(self, input_dim, output_dim, minmax_rule_encoder, outbound_rule_encoder, data_encoder, hidden_dim=3, n_layers=1, merge='cat', skip=False, input_type='state'):
        super(Net, self).__init__()
        # Fail early: an unknown mode used to surface later as an
        # AttributeError (constructor) or UnboundLocalError (forward).
        if merge not in self._MERGE_MODES:
            raise ValueError(
                "merge must be one of %r, got %r" % (self._MERGE_MODES, merge)
            )

        self.skip = skip
        self.input_type = input_type
        self.minmax_rule_encoder = minmax_rule_encoder
        self.outbound_rule_encoder = outbound_rule_encoder
        self.data_encoder = data_encoder
        self.n_layers = n_layers
        assert self.minmax_rule_encoder.input_dim == self.data_encoder.input_dim, \
            "rule and data encoders must share input_dim"
        assert self.minmax_rule_encoder.output_dim == self.data_encoder.output_dim, \
            "rule and data encoders must share output_dim"
        self.merge = merge

        latent_dim = self.minmax_rule_encoder.output_dim
        # Concatenation stacks the three encodings; addition keeps the width.
        self.input_dim_decision_block = latent_dim * 3 if merge == 'cat' else latent_dim

        blocks = [
            BiLSTM_layer(
                input_size=self.input_dim_decision_block if depth == 0 else hidden_dim,
                hidden_size=64,
                num_layers=1,
                bidirectional=True,
                batch_first=True
            )
            for depth in range(n_layers)
        ]
        self.net = nn.Sequential(*blocks)

    def get_z(self, x, alpha=0.0, beta=0.0):
        """Return the merged, weighted encoding of ``x``.

        Args:
            x: Input passed unchanged to all three encoders.
            alpha: Weight of the min/max rule encoding.
            beta: Weight of the out-of-bound rule encoding.

        Returns:
            The sum (``merge='add'``) or last-axis concatenation
            (``merge='cat'``) of the weighted encodings.

        Raises:
            ValueError: If ``self.merge`` was changed to an unsupported value.
        """
        # Encoder call order is kept (min/max, out-of-bound, data) so that
        # dropout consumes random numbers in the same sequence as before.
        minmax_rule_z = self.minmax_rule_encoder(x)
        outbound_rule_z = self.outbound_rule_encoder(x)
        data_z = self.data_encoder(x)

        weighted = (
            alpha * minmax_rule_z,
            beta * outbound_rule_z,
            (1 - alpha - beta) * data_z,
        )
        if self.merge == 'add':
            return weighted[0] + weighted[1] + weighted[2]
        if self.merge == 'cat':
            return torch.cat(weighted, dim=-1)
        raise ValueError(
            "merge must be one of %r, got %r" % (self._MERGE_MODES, self.merge)
        )

    def forward(self, x, alpha=0.0, beta=0.0):
        """Run the decision stack on the merged encoding of ``x``.

        Args:
            x: Input passed to the encoders (and added back when ``skip``).
            alpha: Weight of the min/max rule encoding.
            beta: Weight of the out-of-bound rule encoding.

        Returns:
            Output of ``self.net``, plus the skip term when enabled.
        """
        decided = self.net(self.get_z(x, alpha=alpha, beta=beta))

        if not self.skip:
            return decided
        if self.input_type == 'seq':
            # Only the last step of a sequence input is added back.
            return decided + x[:, -1, :]
        return decided + x

class DataonlyNet(nn.Module):
    """Rule-free variant: a data encoder followed by BiLSTM decision blocks.

    Args:
        input_dim: Accepted for interface compatibility; not used here.
        output_dim: Accepted for interface compatibility; not forwarded to
            the decision blocks.
        data_encoder: Module exposing ``output_dim``; applied to the input.
        hidden_dim: Input width used for decision blocks after the first.
        n_layers: Number of ``BiLSTM_layer`` blocks in the decision stack.
        skip: When true, the raw input is added to the decision output.
        input_type: With ``skip``, ``'seq'`` adds ``x[:, -1, :]``; any other
            value adds ``x`` itself.
    """

    def __init__(self, input_dim, output_dim, data_encoder, hidden_dim=4, n_layers=2, skip=False, input_type='state'):
        super().__init__()
        self.skip = skip
        self.input_type = input_type
        self.data_encoder = data_encoder
        self.n_layers = n_layers
        self.input_dim_decision_block = self.data_encoder.output_dim

        blocks = []
        for depth in range(n_layers):
            # The first block consumes the encoder output; later blocks use hidden_dim.
            width = hidden_dim if depth else self.input_dim_decision_block
            blocks.append(
                BiLSTM_layer(
                    input_size=width,
                    hidden_size=64,
                    num_layers=1,
                    bidirectional=True,
                    batch_first=True
                )
            )
        self.net = nn.Sequential(*blocks)

    def get_z(self, x, alpha=0.0):
        """Return the data encoding of ``x``; ``alpha`` is accepted but ignored."""
        return self.data_encoder(x)

    def forward(self, x, alpha=0.0):
        """Run the decision stack on the encoded input.

        ``alpha`` is accepted but ignored.  When ``skip`` is set, the raw
        input (or its last step for ``input_type == 'seq'``) is added back.
        """
        decided = self.net(self.data_encoder(x))

        if not self.skip:
            return decided
        residual = x[:, -1, :] if self.input_type == 'seq' else x
        return decided + residual

In [None]:
# --- Experiment configuration -------------------------------------------
merge = 'cat'          # how the three encodings are combined inside Net

input_dim = 3
input_dim_encoder = 3
output_dim_encoder = 2
hidden_dim_encoder = 64
hidden_dim_db = 64
output_dim = 26
n_layers = 1
use_type = ''          # set to 'no_rule' to train the data-only baseline

# --- Encoders (constructed in this order: out-of-bound, min/max, data) ---
outbound_rule_encoder = outbound_RuleEncoder(
    input_dim, output_dim_encoder, hidden_dim_encoder, dropout=0.2
)
minmax_rule_encoder = minmax_RuleEncoder(
    input_dim, output_dim_encoder, hidden_dim_encoder, dropout=0.2
)
data_encoder = DataEncoder(
    input_dim, output_dim_encoder, hidden_dim_encoder, dropout=0.2
)

# --- Model ----------------------------------------------------------------
if use_type != 'no_rule':
    model = Net(
        input_dim,
        output_dim,
        minmax_rule_encoder,
        outbound_rule_encoder,
        data_encoder,
        hidden_dim=hidden_dim_db,
        n_layers=n_layers,
        merge=merge,
    )
else:
    model = DataonlyNet(
        input_dim,
        output_dim,
        data_encoder,
        hidden_dim=hidden_dim_db,
        n_layers=n_layers,
    )

print(model)

# --- Optimiser ------------------------------------------------------------
# Alternative kept for reference:
# optim.RMSprop(model.parameters(), lr=0.001, eps=1e-08, weight_decay=0, momentum=0, centered=False)
optimizer = optim.Adam(model.parameters(), lr=0.01)
 

# Thresholds used by the out-of-bound rule, one per output column.
# Columns 0-12 pair with columns 13-25: the min/max rule penalises
# y_pred[:, 13 + i] exceeding y_pred[:, i].
_UPPER_THRESHOLDS = (
    3.35, 2.3, 6.3, 2.77, 2.3, 3.37, 2.54, 0.42,
    0.63, 0.63, 0.63, 0.63, 0.63,
)
# NOTE(review): column 20 (eighth entry) was 2.38 in the original expression.
# That value exceeds the paired upper threshold (0.42), and the earlier
# commented-out rule formulation used 0.38 as this column's lower bound, so
# it is treated as a typo and set to 0.38. Confirm against the data spec.
_LOWER_THRESHOLDS = (
    3.25, 2.2, 6.14, 2.63, 2.1, 3.23, 2.34, 0.38,
    0.53, 0.53, 0.53, 0.53, 0.53,
)


def custom_mse(y_true, y_pred, alpha, beta):
    """Weighted sum of two rule penalties and the MSE task loss.

    ``loss = alpha * minmax + beta * 10 * outbound + (1 - alpha - beta) * mse``

    * ``minmax``: per row, sum over ``i`` in 0..12 of
      ``relu(y_pred[:, 13 + i] - y_pred[:, i])``.
    * ``outbound``: per row, sum over the 26 columns of
      ``relu((y_true - t) * (t - y_pred))`` with the column's threshold ``t``;
      a term is positive only when target and prediction lie on opposite
      sides of the threshold.
    * ``mse``: ``nn.MSELoss`` over all elements (a scalar).

    Args:
        y_true: Target tensor, 2-D with at least 26 columns.
        y_pred: Prediction tensor of the same shape.
        alpha: Weight of the min/max rule penalty.
        beta: Weight of the out-of-bound rule penalty.

    Returns:
        A 1-D tensor with one value per row.  The rule penalties are
        per-row and the scalar task loss is broadcast onto them, so the
        caller has to reduce the result before calling ``backward()``.
    """
    n_pairs = len(_UPPER_THRESHOLDS)
    n_cols = 2 * n_pairs

    loss_task = nn.MSELoss()(y_true, y_pred)

    upper_pred = y_pred[:, :n_pairs]
    lower_pred = y_pred[:, n_pairs:n_cols]
    loss_rule_minmax = F.relu(lower_pred - upper_pred).sum(dim=1)

    # Built from y_pred so that dtype and device match the predictions.
    thresholds = y_pred.new_tensor(_UPPER_THRESHOLDS + _LOWER_THRESHOLDS)
    crossed = (y_true[:, :n_cols] - thresholds) * (thresholds - y_pred[:, :n_cols])
    # The factor 10 is the fixed rescaling applied to the out-of-bound penalty.
    loss_rule_outbound = F.relu(crossed).sum(dim=1) * 10

    return (
        alpha * loss_rule_minmax
        + beta * loss_rule_outbound
        + (1 - beta - alpha) * loss_task
    )





Net(
  (minmax_rule_encoder): minmax_RuleEncoder(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (outbound_rule_encoder): outbound_RuleEncoder(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (data_encoder): DataEncoder(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (net): Sequential(
    (0): BiLSTM_layer(
      (lstm): LSTM(6, 64, batch_first=True, bidirectional=True)
      (fc): Linear(in_features=64, out_features=26, bias=True)
    )
  )
)
