In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder that returns a latent map plus its feature maps.

    Three 3x3 convolutions (padding 1, so each keeps the spatial size of its
    input) widen the channels from ``nf`` to ``nf*2`` to ``nf*4``.  A 2x2
    max-pool sits in front of the second and third convolution, so the three
    feature maps live at full, half and quarter resolution.  ``fc`` is, despite
    its name, another 3x3 convolution that maps the last feature map to
    ``latent_dim`` channels.

    Args:
        in_channels: Number of channels of the input tensor.
        nf: Base number of feature channels.
        latent_dim: Number of channels of the latent map.
    """

    def __init__(self, in_channels, nf, latent_dim):
        super().__init__()
        base, double, quad = nf, nf * 2, nf * 4
        # Sub-modules are created in a fixed order so that parameter
        # registration (and seeded initialisation) stays reproducible.
        self.conv1 = nn.Conv2d(in_channels, base, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(base, double, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(double, quad, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Conv2d(quad, latent_dim, kernel_size=3, padding=1)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A pair ``(z, features)`` where ``z`` is the latent map and
            ``features`` is the list of the three ReLU-activated feature maps,
            ordered from the finest to the coarsest resolution.
        """
        hidden = F.relu(self.conv1(x))
        features = [hidden]
        # The remaining stages each down-sample first, then convolve.
        for conv in (self.conv2, self.conv3):
            hidden = F.relu(conv(self.pool(hidden)))
            features.append(hidden)
        return self.fc(hidden), features

class Decoder(nn.Module):
    """Reconstruction head that expands a latent map by a factor of four.

    ``fc`` (a 3x3 convolution, not a linear layer) lifts the latent map to
    ``nf*4`` channels.  Two "upsample x2 -> 3x3 transposed convolution -> ReLU"
    stages then narrow the channels to ``nf*2`` and ``nf``, and a final
    transposed convolution without activation produces ``out_channels``
    channels.  No encoder feature maps are used here.

    Args:
        latent_dim: Number of channels of the latent map.
        nf: Base number of feature channels.
        out_channels: Number of channels of the reconstruction.
    """

    def __init__(self, latent_dim, nf, out_channels):
        super().__init__()
        quad, double = nf * 4, nf * 2
        self.fc = nn.Conv2d(latent_dim, quad, kernel_size=3, padding=1)
        self.deconv1 = nn.ConvTranspose2d(quad, double, kernel_size=3, padding=1)
        self.deconv2 = nn.ConvTranspose2d(double, nf, kernel_size=3, padding=1)
        self.deconv3 = nn.ConvTranspose2d(nf, out_channels, kernel_size=3, padding=1)
        # nn.Upsample is left on its default interpolation mode.
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, z):
        """Return the reconstruction decoded from the latent map ``z``."""
        hidden = self.fc(z)
        for deconv in (self.deconv1, self.deconv2):
            hidden = F.relu(deconv(self.upsample(hidden)))
        # Last layer is linear: no ReLU on the reconstruction itself.
        return self.deconv3(hidden)

class Regressor(nn.Module):
    """Regression head that decodes the latent map with skip connections.

    The layout mirrors a decoder with two "upsample x2" stages, but before each
    transposed convolution the upsampled tensor is concatenated (along the
    channel axis) with an encoder feature map of matching resolution.  That is
    why ``conv1`` and ``conv2`` take ``nf*4 + nf*2`` and ``nf*2 + nf`` input
    channels.  Note that ``conv1``..``conv3`` are transposed convolutions and
    ``fc`` is a 3x3 convolution, whatever their names suggest.

    Args:
        latent_dim: Number of channels of the latent map.
        nf: Base number of feature channels (must match the encoder's).
        out_channels: Number of channels of the regression output.
    """

    def __init__(self, latent_dim, nf, out_channels):
        super().__init__()
        quad, double = nf * 4, nf * 2
        self.fc = nn.Conv2d(latent_dim, quad, kernel_size=3, padding=1)
        self.conv1 = nn.ConvTranspose2d(quad + double, double, kernel_size=3, padding=1)
        self.conv2 = nn.ConvTranspose2d(double + nf, nf, kernel_size=3, padding=1)
        self.conv3 = nn.ConvTranspose2d(nf, out_channels, kernel_size=3, padding=1)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, z, intermediate_outputs):
        """Decode ``z`` into the regression output.

        Args:
            z: Latent map.
            intermediate_outputs: Exactly three feature maps, finest first.
                Only the first two are consumed; the third is unpacked and
                ignored.

        Returns:
            The regression output (no activation on the last layer).
        """
        fine_skip, mid_skip, _ = intermediate_outputs
        hidden = self.fc(z)
        # Coarse-to-fine: each stage upsamples, appends its skip, then convolves.
        for skip, conv in ((mid_skip, self.conv1), (fine_skip, self.conv2)):
            merged = torch.cat([self.upsample(hidden), skip], dim=1)
            hidden = F.relu(conv(merged))
        return self.conv3(hidden)

class HybridModel(nn.Module):
    """Shared encoder feeding a reconstruction decoder and a regression head.

    The input is bilinearly resized to ``work_size`` before encoding, and both
    outputs are bilinearly resized to ``output_size`` afterwards.  The defaults
    reproduce the resolutions this model was written for: inputs of spatial
    size (150, 49) are processed at (152, 48) and returned at (150, 49).

    Args:
        in_channels: Number of channels of the input tensor.
        nf: Base number of feature channels shared by all sub-networks.
        latent_dim: Number of channels of the latent map.
        rec_channels: Number of channels of the reconstruction ``x_hat``.
        out_channels: Number of channels of the regression output ``y_hat``.
        work_size: ``(height, width)`` at which the network operates.  Both
            values must be divisible by 4: the encoder halves the resolution
            twice and the regressor concatenates its twice-upsampled features
            with the encoder's feature maps, which only line up when no
            pixels are lost to rounding in the pooling steps.
        output_size: ``(height, width)`` of both outputs, or ``None`` to give
            the outputs the spatial size of whatever input is passed in.

    Raises:
        ValueError: If ``work_size`` or ``output_size`` is not a pair of
            positive integers, or ``work_size`` is not divisible by 4.
    """

    def __init__(self, in_channels, nf, latent_dim, rec_channels, out_channels,
                 work_size=(152, 48), output_size=(150, 49)):
        super(HybridModel, self).__init__()
        self.work_size = self._as_hw(work_size, "work_size")
        if any(side % 4 for side in self.work_size):
            raise ValueError(
                f"work_size must be divisible by 4 in both dimensions, got {work_size!r}"
            )
        self.output_size = (
            None if output_size is None else self._as_hw(output_size, "output_size")
        )
        self.encoder = Encoder(in_channels, nf, latent_dim)
        self.decoder = Decoder(latent_dim, nf, rec_channels)
        self.regressor = Regressor(latent_dim, nf, out_channels)

    @staticmethod
    def _as_hw(size, name):
        """Validate ``size`` as two positive ints and return it as a tuple."""
        hw = tuple(size)
        if len(hw) != 2 or not all(isinstance(s, int) and s > 0 for s in hw):
            raise ValueError(
                f"{name} must be two positive integers (height, width), got {size!r}"
            )
        return hw

    def forward(self, x):
        """Run the full model.

        Args:
            x: Input batch with ``in_channels`` channels.

        Returns:
            ``(x_hat, y_hat)``: the reconstruction and the regression output,
            both resized to the configured output size.
        """
        # Decide the output resolution before `x` is overwritten below.
        target_size = self.output_size
        if target_size is None:
            target_size = tuple(x.shape[-2:])

        # Rescale the input to the working resolution the conv stack expects.
        x = F.interpolate(x, size=self.work_size, mode='bilinear', align_corners=False)

        z, intermediate_outputs = self.encoder(x)
        x_hat = self.decoder(z)
        x_hat = F.interpolate(x_hat, size=target_size, mode='bilinear', align_corners=False)
        y_hat = self.regressor(z, intermediate_outputs)

        # Rescale the regression output from the working to the output resolution.
        y_hat = F.interpolate(y_hat, size=target_size, mode='bilinear', align_corners=False)

        return x_hat, y_hat

# Example usage: build a small model, run one random sample, export to ONNX.
in_channels, rec_channels, out_channels = 29, 13, 9
latent_dim = 128
nx, ny = 128, 48  # defined here but not referenced anywhere in this cell
nf = 16

hybrid_model = HybridModel(
    in_channels=in_channels,
    nf=nf,
    latent_dim=latent_dim,
    rec_channels=rec_channels,
    out_channels=out_channels,
)

# A single random sample at the native (150, 49) resolution.
x = torch.randn(1, in_channels, 150, 49)

# Forward pass: reconstruction first, regression output second.
x_hat, y_hat = hybrid_model(x)

print("x_hat shape:", x_hat.shape)
print("y_hat shape:", y_hat.shape)

# Trace the model with the sample input and write the ONNX graph to disk;
# verbose=True also prints the exported graph.
torch.onnx.export(
    hybrid_model,
    x,
    "hybrid_model.onnx",
    verbose=True,
    input_names=['input'],
    output_names=['output1', 'output2'],
)

x_hat shape: torch.Size([1, 13, 150, 49])
y_hat shape: torch.Size([1, 9, 150, 49])
Exported graph: graph(%input : Float(1, 29, 150, 49, strides=[213150, 7350, 49, 1], requires_grad=0, device=cpu),
      %encoder.conv1.weight : Float(16, 29, 3, 3, strides=[261, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv1.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %encoder.conv2.weight : Float(32, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv2.bias : Float(32, strides=[1], requires_grad=1, device=cpu),
      %encoder.conv3.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv3.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %encoder.fc.weight : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.fc.bias : Float(128, strides=[1], requires_grad=1, device=cpu),
      %decoder.fc.weight : Float(64, 128, 3, 3, strides=[1152, 9, 3