# In [1]:
import torch
import torch.nn as nn

# Define the low-fidelity sub-network
class LowFidelityNet(nn.Module):
    """Two-layer perceptron used as the low-fidelity branch.

    Computes ``fc2(relu(fc1(x)))``: a linear projection from ``input_size``
    to ``hidden_size`` features, a ReLU, then a linear projection to
    ``output_size`` features.

    Args:
        input_size: Number of features in the last dimension of the input.
        hidden_size: Width of the hidden layer.
        output_size: Number of output features.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        # Keep creation order fc1 -> relu -> fc2: it fixes both the RNG draws
        # used for weight initialisation and the state_dict key order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply hidden projection, ReLU and output projection to *x*."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Define the high-fidelity sub-network
class HighFidelityNet(nn.Module):
    """Two-layer perceptron used as the high-fidelity branch.

    Computes ``fc2(relu(fc1(x)))``: a linear projection from ``input_size``
    to ``hidden_size`` features, a ReLU, then a linear projection to
    ``output_size`` features.

    Args:
        input_size: Number of features in the last dimension of the input.
        hidden_size: Width of the hidden layer.
        output_size: Number of output features.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        # Keep creation order fc1 -> relu -> fc2: it fixes both the RNG draws
        # used for weight initialisation and the state_dict key order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply hidden projection, ReLU and output projection to *x*."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Define the multi-fidelity neural network
class MFNN(nn.Module):
    """Route each call to either a low- or a high-fidelity sub-network.

    Args:
        low_fidelity_net: Module evaluated when ``fidelity == 'low'``.
        high_fidelity_net: Module evaluated when ``fidelity == 'high'``.
    """

    def __init__(self, low_fidelity_net: nn.Module, high_fidelity_net: nn.Module) -> None:
        super().__init__()
        # Assigning modules as attributes registers them as submodules, so
        # their parameters appear in mfnn.parameters() / state_dict().
        self.low_fidelity_net = low_fidelity_net
        self.high_fidelity_net = high_fidelity_net

    def forward(self, x: torch.Tensor, fidelity: str) -> torch.Tensor:
        """Evaluate *x* with the sub-network selected by *fidelity*.

        Args:
            x: Input passed unchanged to the selected sub-network.
            fidelity: Either ``'low'`` or ``'high'``.

        Returns:
            The selected sub-network's output.

        Raises:
            ValueError: If *fidelity* is neither ``'low'`` nor ``'high'``.
        """
        if fidelity == 'low':
            return self.low_fidelity_net(x)
        if fidelity == 'high':
            return self.high_fidelity_net(x)
        # Reject anything else explicitly: an unrecognised value (a typo or
        # wrong case such as 'Low') must not silently run the high branch.
        raise ValueError(f"fidelity must be 'low' or 'high', got {fidelity!r}")

# Layer sizes shared by both sub-networks.
input_size = 10
hidden_size = 20
output_size = 1

# Build the two branches (low first, then high, so weight initialisation
# consumes the global RNG in a fixed order) and wrap them in the router.
low_fidelity_net = LowFidelityNet(
    input_size=input_size, hidden_size=hidden_size, output_size=output_size
)
high_fidelity_net = HighFidelityNet(
    input_size=input_size, hidden_size=hidden_size, output_size=output_size
)
mfnn = MFNN(low_fidelity_net, high_fidelity_net)

# Smoke-test both branches: each pass draws a fresh random (1, input_size)
# input and prints the output of the sub-network selected by `fidelity`.
# After the loop, x / fidelity / output hold the values from the 'high' pass.
for fidelity in ('low', 'high'):
    x = torch.randn(1, input_size)
    output = mfnn(x, fidelity)
    print(output)


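# Example output from one run (values differ between runs because the
# weights and inputs are not seeded):
# tensor([[-0.0407]], grad_fn=<AddmmBackward0>)
# tensor([[0.0553]], grad_fn=<AddmmBackward0>)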
