In [1]:
import torch
import pytest
from tensorly import tenalg
tenalg.set_backend("einsum")

# Parameterize use of torch_scatter if it is built
try: 
    from torch_scatter import segment_csr
    use_torch_scatter = [True, False]
except:
    use_torch_scatter = [False]
    print("Install torch_scatter to use it in GINO., your version of torch is:", torch.__version__)

try:
    import open3d
except:
    print("Install open3d to use GINO.")
from neuralop.models.gino import GINO

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# general imports
import torch
import torch.nn as nn

# neuralop imports
from neuralop.layers.gno_block import GNOBlock
from neuralop.layers.channel_mlp import ChannelMLP
from neuralop.layers.fno_block import FNOBlocks

# data set imports
import dataset_utils as du
from torch.utils.data import DataLoader, random_split

In [3]:
B = 1 # must use batch size of 1 for GNOBlock

# load dataset
dataset = du.SDFDataset("./cars100")
train_percent = 0.95 

# split dataset into training and validation sets
train_size = int(train_percent * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# create data loaders for training and validation sets
train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=B, shuffle=False, num_workers=4, pin_memory=True)

print(f"Train dataset size: {len(train_dataset)}, Validation dataset size: {len(val_dataset)}")

Train dataset size: 95, Validation dataset size: 5


In [4]:
def process_batch(batch, grid_size, query_size, input_size):
    # y
    input_geom = batch[:, :input_size, :3]  # x, y, z coordinates
    input_geom = input_geom.squeeze(0) 

    # x (grid points in 3D space) generate 64x64x64 grid with bounds [-1, 1] in each dimension
    coords = torch.linspace(-1.0, 1.0, grid_size) 
    x, y, z = torch.meshgrid(coords, coords, coords, indexing='ij') 
    latent_queries = torch.stack((x, y, z), dim=-1)
    # transform to match batch size
    latent_queries = latent_queries.repeat(B, 1, 1, 1, 1)  # Repeat for batch size B
    latent_queries = latent_queries

    # f_y
    features = batch[:, :input_size, 3]  # features (e.g., colors, normals)
    features = features.unsqueeze(-1)

    # queries (for now just the same as input_geom)
    output_queries = input_geom.clone().squeeze(0)[:query_size,:]  # !For now, just use first 1000 points of input_geom as output_queries
    output_labels = features.clone()[:query_size,:]  # !For now, just use first 1000 points of features as output_labels
    return input_geom, latent_queries, features, output_queries, output_labels

# test
batch = next(iter(train_loader))
input_geom, latent_queries, features, output_queries, output_labels = process_batch(batch, grid_size=64, query_size=1000, input_size=1000)
print(f'input_geom shape: {input_geom.shape}, latent_queries shape: {latent_queries.shape}, features shape: {features.shape}, output_queries shape: {output_queries.shape}, output_labels shape: {output_labels.shape}')
        
    

input_geom shape: torch.Size([1000, 3]), latent_queries shape: torch.Size([1, 64, 64, 64, 3]), features shape: torch.Size([1, 1000, 1]), output_queries shape: torch.Size([1000, 3]), output_labels shape: torch.Size([1, 1000, 1])


Test and train prebuilt GINO

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
model = GINO(
        in_channels=1,
        out_channels=1
    ).to(device)

# process the batch
input_geom, latent_queries, features, output_queries, output_labels = process_batch(batch, grid_size=64, query_size=1000, input_size=1000)

out = model(x=features, input_geom=input_geom, latent_queries=latent_queries, output_queries=output_queries)

In [11]:
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
model = GINO(
        in_channels=1,
        out_channels=1,
        gno_use_open3d=True,         
        gno_use_torch_scatter=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()  # Mean Squared Error loss for SDF regression

num_epochs = 5
for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_loader, desc="Training Batches"):
        batch = batch.to(device)  # Move batch to device
        optimizer.zero_grad()  # Zero the gradients
        
        # process the batch
        input_geom, latent_queries, features, output_queries, output_labels = process_batch(batch, grid_size=10, query_size=1000, input_size=1000)
        
        out = model(x=features, input_geom=input_geom, latent_queries=latent_queries, output_queries=output_queries)
        
        # Compute loss
        loss = loss_fn(out, output_labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Training Batches: 100%|██████████| 95/95 [01:47<00:00,  1.13s/it]
Training Epochs:  20%|██        | 1/5 [01:47<07:09, 107.41s/it]

Epoch 1/5, Loss: 0.0329


Training Batches: 100%|██████████| 95/95 [01:46<00:00,  1.12s/it]
Training Epochs:  40%|████      | 2/5 [03:33<05:20, 106.86s/it]

Epoch 2/5, Loss: 0.0255


Training Batches: 100%|██████████| 95/95 [01:48<00:00,  1.14s/it]
Training Epochs:  60%|██████    | 3/5 [05:22<03:35, 107.50s/it]

Epoch 3/5, Loss: 0.0255


Training Batches: 100%|██████████| 95/95 [01:47<00:00,  1.13s/it]
Training Epochs:  80%|████████  | 4/5 [07:09<01:47, 107.42s/it]

Epoch 4/5, Loss: 0.0256


Training Batches: 100%|██████████| 95/95 [01:45<00:00,  1.11s/it]
Training Epochs: 100%|██████████| 5/5 [08:54<00:00, 106.91s/it]

Epoch 5/5, Loss: 0.0256



