In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as topt
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from neuralop.models import FNO

In [2]:
class Encoder(nn.Module):
    """Work-in-progress encoder returning a latent sample ``z`` plus ``mean`` and ``logvar``.

    NOTE(review): this class is an unfinished stub. ``mean`` and ``logvar`` in
    ``forward`` are still the literal ``...`` placeholder and the FNO backbone
    created in ``__init__`` is never called, so ``forward`` cannot run as written.
    """
    def __init__(self, latent_dim=64, input_dim=2, output_dim=1):
        """Store the size hyper-parameters and create the FNO backbone.

        Args:
            latent_dim: width of the noise tensor drawn in ``forward`` (and hence
                of ``z``).
            input_dim: stored on the instance; not used by the visible code.
            output_dim: stored on the instance; not used by the visible code.
        """
        super(Encoder, self).__init__()

        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.output_dim = output_dim

        # NOTE(review): FNO is constructed without any arguments and is not
        # referenced in forward(). The neuralop FNO constructor presumably needs
        # at least mode/channel settings -- TODO confirm and wire it in.
        self.fno = FNO()

    def forward(self, u):
        """Draw ``z = mean + eps * exp(0.5 * logvar)`` (reparameterisation trick).

        Args:
            u: tensor-like input; only ``u.shape[0]`` and ``u.device`` are read
                here. Assumes the first axis is the batch axis -- TODO confirm.

        Returns:
            The tuple ``(z, mean, logvar)``.
        """

        # Standard-normal noise, one row of size latent_dim per entry of u's first axis.
        eps = torch.randn(u.shape[0], self.latent_dim, device=u.device)
        # TODO: placeholders (the Ellipsis object). These must be replaced by real
        # tensors, presumably computed from `u` via self.fno, before the line
        # below can be evaluated.
        mean = ...
        logvar = ...

        # exp(0.5 * logvar) is the standard deviation when logvar = log(sigma^2).
        z = mean + eps * torch.exp(0.5 * logvar)

        return z, mean, logvar

class Decoder(nn.Module):
    """Coordinate decoder: evaluates u(x) at query locations, conditioned on a latent z.

    The locations ``x`` and the latent ``z`` are embedded by two independent
    MLPs into ``feature_dim``-wide vectors, concatenated along the feature
    axis, and mapped to ``output_dim`` values by a joint MLP.

    Args:
        latent_dim: size of the latent vector ``z``.
        input_dim: size of one spatial location ``x``.
        output_dim: number of values predicted per location.
        feature_dim: width of each of the two embeddings; the joint MLP takes
            ``2 * feature_dim`` inputs. The default (32) reproduces the
            original fixed 32 + 32 -> 64 layout.
    """

    def __init__(self, latent_dim=64, input_dim=2, output_dim=1, feature_dim=32):
        super().__init__()

        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.feature_dim = feature_dim

        # (original) NeRF-like architecture.
        # Submodule names and layer order are kept so existing state dicts and
        # seeded initialisation stay identical.
        self.mlp_x = nn.Sequential(
            nn.Linear(self.input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, self.feature_dim),
        )
        self.mlp_z = nn.Sequential(
            nn.Linear(self.latent_dim, 2 * self.latent_dim),
            nn.ReLU(),
            nn.Linear(2 * self.latent_dim, 2 * self.latent_dim),
            nn.ReLU(),
            nn.Linear(2 * self.latent_dim, self.feature_dim),
        )
        # Input width is derived from the two embeddings instead of a literal 64.
        self.joint_mlp = nn.Sequential(
            nn.Linear(2 * self.feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, self.output_dim),
        )

    def forward(self, x, z):
        """
        Computes u(x) by conditioning on z, a latent representation of u.

        Args:
            x: (..., input_dim) tensor of spatial locations, e.g.
                (batch_size, input_dim) or (batch_size, n_points, input_dim)
            z: (..., latent_dim) tensor of latent representations. It may have
                fewer leading axes than ``x`` (e.g. (batch_size, latent_dim)
                with a 3-D ``x``); it is then broadcast over the extra
                query-point axes of ``x``.

        Returns:
            Tensor of shape (..., output_dim) with the leading axes of ``x``.
        """
        x_feat = self.mlp_x(x)
        z_feat = self.mlp_z(z)

        if z_feat.dim() < x_feat.dim():
            # One latent per sample but several query points: insert singleton
            # axes just before the feature axis, then broadcast to x's leading
            # shape (expand creates a view, no copy).
            for _ in range(x_feat.dim() - z_feat.dim()):
                z_feat = z_feat.unsqueeze(-2)
            z_feat = z_feat.expand(*x_feat.shape[:-1], -1)

        # Concatenate on the feature axis (last), not axis 1, so inputs with
        # extra leading axes are handled correctly.
        xz = torch.cat([x_feat, z_feat], dim=-1)

        return self.joint_mlp(xz)
