In [1]:
import numpy as np
import torch

# Data parameters
data_root = "./data/"
data_filename = "nucleon_41x41x41_uint8.raw"
data_shape = (41, 41, 41)  # voxel extent of the raw volume along each axis
data_dtype = np.uint8  # element type stored in the .raw file (uint8 per the filename)

# Training parameters
batch_size = 1024
num_workers = 2
num_epochs = 20
lr = 0.008
seed = 0  # fixed seed so weight init and RNG-driven shuffling are reproducible
device = torch.device("cpu")  # or torch.device("cuda") for GPU
run_dtype = torch.float32  # or torch.float16 for half precision

# Seed the RNGs the notebook may rely on so Restart & Run All gives repeatable
# results (torch.manual_seed also seeds CUDA generators).
np.random.seed(seed)
torch.manual_seed(seed)

In [2]:
from networks import INR_Base

network_type = "mlp"
# The non-native encoder/network is only selected when CUDA is in play, which
# suggests it is a CUDA-only implementation. Key the choice off the configured
# `device` rather than torch.cuda.is_available(). Otherwise, on a CUDA-capable
# machine with device = cpu, the CUDA-only path would be chosen and the model
# then moved to the CPU.
on_cuda = device.type == "cuda"
use_native_encoder = not on_cuda
# KAN always uses the native network (presumably no fused KAN implementation
# exists — TODO confirm in networks.INR_Base).
use_native_network = not on_cuda or network_type == "kan"

# Create model
model = INR_Base(
    log2_hashmap_size=16,
    native_encoder=use_native_encoder,
    native_network=use_native_network,
    network_type=network_type,
)

In [3]:
from pathlib import Path

from torch.utils.data import DataLoader
from volumetric_dataset import VolumetricDataset

# Resolve the raw-volume path. Later cells reuse `data_filename`, so the name is
# kept, but the join is guarded so that re-running this cell does not prepend
# `data_root` a second time (Path("./data/", Path("data/x.raw")) -> "data/data/x.raw").
if not isinstance(data_filename, Path):
    data_filename = Path(data_root, data_filename)

# Create dataset and dataloader
dataset = VolumetricDataset(
    data_filename,
    data_shape,
    data_dtype,
    batch_size,
    normalize_coords=True,
    normalize_values=True,
    initial_shuffle=True,  # Shuffle dataset once initially
)

# Pinned (page-locked) host memory only speeds up host -> GPU copies.
pin = device.type == "cuda"
# batch_size=None disables the DataLoader's automatic batching. The dataset is
# given `batch_size` itself and presumably yields whole batches (TODO confirm).
loader = DataLoader(dataset, batch_size=None, num_workers=num_workers, pin_memory=pin)

In [4]:
from torch.optim import AdamW
import torch.nn as nn
from tqdm import tqdm


model.to(device, run_dtype)
optimizer = AdamW(model.parameters(), lr=lr)
loss_fn = nn.functional.mse_loss

# The reconstruction cell switches the model to eval(); make sure re-running
# this cell trains in training mode again.
model.train()

# Training loop
epoch_losses = []
for epoch in range(num_epochs):
    loss_total = 0.0
    num_batches = 0
    for x, target in tqdm(loader):
        x = x.to(device, run_dtype)
        target = target.to(device, run_dtype)
        pred = model(x)
        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_total += loss.item()
        num_batches += 1

    # Average over the batches actually seen. len(loader) is only an estimate
    # for iterable datasets, and can differ from the real count when several
    # workers each iterate the data.
    if num_batches == 0:
        raise RuntimeError(
            "DataLoader yielded no batches; check the dataset path/shape and num_workers"
        )
    avg_loss = loss_total / num_batches
    epoch_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

  0%|          | 0/68 [00:01<?, ?it/s]


RuntimeError: ZeroDivisionError

In [None]:
import matplotlib.pyplot as plt

# Plot training loss. The x-axis is derived from the losses actually recorded,
# so the plot still works if training stopped before num_epochs (e.g. after an
# exception in the training cell).
fig, ax = plt.subplots()
epochs_run = range(1, len(epoch_losses) + 1)
ax.plot(epochs_run, epoch_losses, marker="o")
ax.set_title("Training Loss over Epochs")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.show()

In [None]:
from torchmetrics.functional import mean_squared_error
from torchmetrics.functional.image import peak_signal_noise_ratio
from torchmetrics.functional.image import structural_similarity_index_measure

reconst_dataset = VolumetricDataset(
    file_path=data_filename,
    data_shape=data_shape,
    data_type=data_dtype,
    batch_size=batch_size,
    normalize_coords=True,
    normalize_values=True,
    initial_shuffle=False,  # No shuffle for reconstruction
)
reconst_dataloader = DataLoader(
    reconst_dataset,
    batch_size=None,
    num_workers=0,  # Single process for eval
    pin_memory=pin,
)
# Reconstruct the entire volume from the INR
model.eval()
with torch.no_grad():
    # Keep the config `data_shape` a tuple of ints. torch.zeros needs integer
    # sizes, and overwriting the global with a float tensor would also break
    # re-runs of this cell.
    grid_extent = torch.as_tensor(data_shape, device=device, dtype=run_dtype)
    reconst_data = torch.zeros(data_shape, device=device, dtype=run_dtype)
    for x, _ in tqdm(reconst_dataloader):
        x = x.to(device, dtype=run_dtype, non_blocking=True)
        y = model(x).to(dtype=run_dtype, non_blocking=True)

        # Map coordinates back to voxel indices. Assumes normalize_coords maps
        # each axis to [0, 1] (TODO confirm in VolumetricDataset). Round rather
        # than truncate: float error such as 39.99998 would otherwise land on
        # voxel 39 and leave voxel 40 empty.
        indices = (x * (grid_extent - 1)).round().long()
        i, j, k = indices.unbind(dim=-1)  # each (batch_size,)

        reconst_data[i, j, k] = y.reshape(-1)

In [None]:
# Ground truth comes from the dataset. The metrics below use data_range=1.0,
# which assumes volume_data() returns values already normalized to [0, 1]
# (normalize_values=True) — TODO confirm in VolumetricDataset.
gt_data = torch.as_tensor(
    reconst_dataset.volume_data(), device=device, dtype=run_dtype
).contiguous()
reconst_data = torch.clamp(reconst_data, 0.0, 1.0).contiguous()

# Process metrics slice-by-slice (more stable for 3D volumes)
psnr_values = []
mse_values = []
ssim_values = []

# Iterate over slices along the last axis
num_slices = gt_data.shape[2]
for i in tqdm(range(num_slices)):
    # (1, 1, H, W) — the N, C, H, W layout the 2D image metrics expect
    gt_slice = gt_data[:, :, i].unsqueeze(0).unsqueeze(0)
    reconst_slice = reconst_data[:, :, i].unsqueeze(0).unsqueeze(0)

    # Calculate metrics for this slice (inside the loop, so every slice counts)
    psnr_values.append(
        peak_signal_noise_ratio(reconst_slice, gt_slice, data_range=1.0)
    )
    mse_values.append(mean_squared_error(reconst_slice, gt_slice))
    ssim_values.append(
        structural_similarity_index_measure(reconst_slice, gt_slice, data_range=1.0)
    )

# Average the metrics across all slices
psnr = torch.stack(psnr_values).mean().item()
mse = torch.stack(mse_values).mean().item()
ssim = torch.stack(ssim_values).mean().item()
print(f"PSNR: {psnr}, MSE: {mse}, SSIM: {ssim}")