In [2]:
# general imports
import torch

# neuralop imports
from neuralop.layers.gno_block import GNOBlock
from neuralop.layers.channel_mlp import ChannelMLP
from neuralop.layers.fno_block import FNOBlocks



Import Data

In [16]:
import dataset_utils 
from torch.utils.data import DataLoader, random_split

# --- configuration ---
B = 1                 # batch size; later cells drop the batch dim with squeeze(0), i.e. they expect B == 1
train_percent = 0.8   # fraction of the dataset used for training
grid_size = 5         # number of latent grid nodes per axis
# NOTE(review): the latent grid built later spans [-1, 1], so adjacent nodes are
# 2 / (grid_size - 1) apart (0.5 for grid_size = 5), which is more than twice this
# radius. If the radius is the neighbour-search ball radius, input points between
# nodes are never picked up — confirm this is intended.
radius = 1 / grid_size
SEED = 0              # fixes the train/val split and shuffle order across kernel restarts
print("Grid size:", grid_size, "Radius:", radius)

# load dataset
dataset = dataset_utils.SDFDataset("./cars100")

# split dataset into training and validation sets (seeded so the split is reproducible)
train_size = int(train_percent * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# create data loaders for training and validation sets
loader_generator = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(
    train_dataset,
    batch_size=B,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    generator=loader_generator,  # seeded shuffle order
)
val_loader = DataLoader(val_dataset, batch_size=B, shuffle=False, num_workers=4, pin_memory=True)

# Get 1 batch. All points are kept: the previous `batch[:, :, :]` slice was a
# no-op and did not limit the point count.
batch = next(iter(train_loader))
print("Batch shape:", batch.shape)

Grid size: 5 Radius: 0.2
Batch shape: torch.Size([1, 50000, 4])


Set up Data

In [17]:
# The GNO encoder below is fed a single, unbatched sample, so fail loudly
# instead of silently producing tensors of the wrong rank for larger batches.
if batch.shape[0] != 1:
    raise ValueError(
        f"This cell expects a batch holding exactly one sample, got {batch.shape[0]}"
    )
sample = batch[0]  # drop the batch dim -> [num_points, 4]

# y: support points, i.e. the x, y, z coordinates of the input point cloud
input_geom = sample[:, :3]

# x: latent query points, a regular grid_size^3 grid with bounds [-1, 1] in each dimension
coords = torch.linspace(-1.0, 1.0, grid_size)
x, y, z = torch.meshgrid(coords, coords, coords, indexing='ij')
latent_queries = torch.stack((x, y, z), dim=-1)  # [grid_size, grid_size, grid_size, 3]

# f_y: the per-point feature stored in the 4th column
# (presumably the SDF value, given the SDFDataset name — TODO confirm)
features = sample[:, 3:4]  # slice keeps the trailing channel dim -> [num_points, 1]


print(f'Supports shape: {input_geom.shape}, Latent queries shape: {latent_queries.shape}, Features shape: {features.shape}')

Supports shape: torch.Size([50000, 3]), Latent queries shape: torch.Size([5, 5, 5, 3]), Features shape: torch.Size([50000, 1])


Set up Layers

In [18]:
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
IN_CHANNELS = 1            # channels of the per-point input feature
OUT_CHANNELS = 10          # channels of the latent embedding produced by the GNO encoder
COORD_DIM = 3              # spatial dimension of the point coordinates
LIFTING_CHANNELS = 16      # hidden width of the lifting MLP (from paper)
FNO_HIDDEN_CHANNELS = 32   # channel width inside the FNO blocks (from paper)

FNO_N_LAYERS = 4           # number of FNO layers in the FNOBlocks (from paper)
# Fourier modes per dimension (from paper).
# NOTE(review): 16 modes per axis exceeds what a grid with only `grid_size`
# nodes per axis can represent — confirm the value is meant for this grid.
fno_n_modes = (16, 16, 16)

# ---------------------------------------------------------------------------
# GNO block: encodes the point-cloud features onto the latent query points
# ---------------------------------------------------------------------------
gno_in = GNOBlock(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, coord_dim=COORD_DIM, radius=radius)

print(gno_in)

# ---------------------------------------------------------------------------
# Lifting: projects the per-grid-point GNO features into the FNO channel space
# ---------------------------------------------------------------------------
lifting = ChannelMLP(
    in_channels=OUT_CHANNELS,
    hidden_channels=LIFTING_CHANNELS,
    out_channels=FNO_HIDDEN_CHANNELS,
    n_layers=2,
)

# ---------------------------------------------------------------------------
# FNO blocks: channel-preserving stack applied on the latent grid
# ---------------------------------------------------------------------------
fno_blocks = FNOBlocks(
    in_channels=FNO_HIDDEN_CHANNELS,
    out_channels=FNO_HIDDEN_CHANNELS,
    n_modes=fno_n_modes,
    n_layers=FNO_N_LAYERS,
)
print(fno_blocks)
def latent_embedding(in_p, blocks=None):
    """Run ``in_p`` through every layer of an FNO block stack, in order.

    Args:
        in_p: Input to the first layer; each layer's output is fed to the next.
        blocks: Callable invoked as ``blocks(x, layer_index)`` that exposes an
            ``n_layers`` attribute. Defaults to the module-level ``fno_blocks``
            so existing single-argument calls keep working.

    Returns:
        The output of the last layer, or ``in_p`` unchanged when the stack
        has zero layers.
    """
    if blocks is None:
        # Resolved at call time so that re-running the layer-setup cell is picked up.
        blocks = fno_blocks

    for layer_idx in range(blocks.n_layers):
        in_p = blocks(in_p, layer_idx)

    return in_p
    

GNOBlock(
  (pos_embedding): SinusoidalEmbedding()
  (neighbor_search): NeighborSearch()
  (channel_mlp): LinearChannelMLP(
    (fcs): ModuleList(
      (0): Linear(in_features=384, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): Linear(in_features=128, out_features=10, bias=True)
    )
  )
  (integral_transform): IntegralTransform(
    (channel_mlp): LinearChannelMLP(
      (fcs): ModuleList(
        (0): Linear(in_features=384, out_features=128, bias=True)
        (1): Linear(in_features=128, out_features=256, bias=True)
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): Linear(in_features=128, out_features=10, bias=True)
      )
    )
  )
)
FNOBlocks(
  (convs): ModuleList(
    (0-3): 4 x SpectralConv(
      (weight): DenseTensor(shape=torch.Size([32, 32, 16, 16, 9]), rank=None)
    )
  )
  (fno_skips): ModuleList(
    (0-3): 4 x Flatten

In [19]:
# Flatten the query grid into a list of points [n_nodes, coord_dim] for the GNO block.
reshaped_queries = latent_queries.reshape(-1, latent_queries.shape[-1])
print(f'Reshaped queries shape: {reshaped_queries.shape}')

# --- GNO encoding ---------------------------------------------------------
# Presumably returns one OUT_CHANNELS vector per query point — TODO confirm.
in_p = gno_in(y=input_geom, x=reshaped_queries, f_y=features)

# Restore the grid layout and add the batch dim: [B, *grid_shape, channels]
grid_shape = latent_queries.shape[:-1]  # drop the trailing coordinate dim
in_p = in_p.reshape(B, *grid_shape, -1)

# --- Lifting to the FNO channel space -------------------------------------
# Move channels right after the batch dim: [B, channels, *grid_shape].
# movedim works for any number of grid dimensions.
in_p = in_p.movedim(-1, 1)
in_p = lifting(in_p)

# --- Latent embedding with the FNO blocks ---------------------------------
latent_embed = latent_embedding(in_p)
# Back to channels-last: [B, *grid_shape, FNO_HIDDEN_CHANNELS]
latent_embed = latent_embed.movedim(1, -1)
print(f'Latent embedding shape: {latent_embed.shape}')

# Flatten the grid dims: [B, n_nodes, FNO_HIDDEN_CHANNELS].
# reshape (not view) because the tensor is non-contiguous after movedim; view
# only succeeds when the strides happen to be compatible.
latent_embed = latent_embed.reshape(B, -1, FNO_HIDDEN_CHANNELS)
assert latent_embed.shape == (B, grid_shape.numel(), FNO_HIDDEN_CHANNELS), latent_embed.shape
print(f'Flattened latent embedding shape: {latent_embed.shape}')

Reshaped queries shape: torch.Size([125, 3])
Latent embedding shape: torch.Size([1, 5, 5, 5, 32])
Flattened latent embedding shape: torch.Size([1, 125, 32])
