<a href="https://colab.research.google.com/github/boneseva/Diffusion-SDF/blob/main/testing/VAE_sdf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Voxel VAE Training (Diffusion-SDF)**
# This notebook trains a 3D VAE on Signed Distance Functions (SDFs) following the Diffusion-SDF paper architecture.

In [1]:
!pip install pytorch-lightning wandb nibabel torch edt

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting edt
  Downloading edt-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.2-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)


In [2]:
import warnings
warnings.filterwarnings(
    "ignore",
    message="Torchmetrics v0.9 introduced.*full_state_update",
    category=UserWarning
)

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import nibabel as nib
import wandb
from pytorch_lightning.loggers import WandbLogger

In [3]:
class VoxelSDFDataset(Dataset):
    def __init__(self, sdf_dir):
        self.sdf_files = [os.path.join(sdf_dir, f) for f in os.listdir(sdf_dir)
                          if f.endswith(('.nii', '.nii.gz'))]

    def __len__(self):
        return len(self.sdf_files)

    def __getitem__(self, idx):
        # Load and verify SDF data
        sdf = nib.load(self.sdf_files[idx]).get_fdata()
        assert np.abs(sdf).max() <= 1.0, "SDF values not normalized!"
        return torch.tensor(sdf, dtype=torch.float32).unsqueeze(0)  # (1, 80, 80, 80)

class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

class VoxelVAE(pl.LightningModule):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.save_hyperparameters()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 64, 3, stride=2, padding=1),  # 80 -> 40
            nn.ReLU(),
            nn.Conv3d(64, 128, 3, stride=2, padding=1),  # 40 -> 20
            nn.ReLU(),
            nn.Conv3d(128, 256, 3, stride=2, padding=1),  # 20 -> 10
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 10 * 10 * 10, latent_dim * 2)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * 10 * 10 * 10),
            View((-1, 256, 10, 10, 10)),
            nn.ConvTranspose3d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu, logvar = mu_logvar.chunk(2, dim=1)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def training_step(self, batch, batch_idx):
        recon, mu, logvar = self(batch)

        recon_loss = F.l1_loss(recon, batch)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch.size(0)
        total_loss = recon_loss + 1e-5 * kl_loss

        self.log('train_loss', total_loss)
        self.log('recon_loss', recon_loss)
        self.log('kl_loss', kl_loss)

        if batch_idx == 0 and self.current_epoch % 5 == 0:
            with torch.no_grad():
                input_slice = batch[0][0][40].cpu().numpy()
                recon_slice = recon[0][0][40].cpu().numpy()

                if isinstance(self.logger.experiment, wandb.sdk.wandb_run.Run):
                    self.logger.experiment.log({
                        "epoch": self.current_epoch,
                        "input_slice": wandb.Image(input_slice),
                        "reconstruction_slice": wandb.Image(recon_slice),
                        "train_loss": total_loss.item(),
                        "recon_loss": recon_loss.item(),
                        "kl_loss": kl_loss.item()
                    })

        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [6]:
def train():
    # Mount Google Drive if using Colab
    from google.colab import drive
    drive.mount('/content/drive')

    config = {
        'batch_size': 8,
        'latent_dim': 256,
        'max_epochs': 200,  # Reduced for Colab demo
        'data_path': '/content/drive/MyDrive/Lyso_SDF'  # Update this path
    }

    # Initialize WandB
    wandb.login()
    wandb_logger = WandbLogger(
        project="Diffusion-SDF-VAE",
        name="colab-vae-training",
        config=config,
    )

    # Dataset & DataLoader
    dataset = VoxelSDFDataset(config['data_path'])
    loader = DataLoader(dataset,
                        batch_size=config['batch_size'],
                        shuffle=True,
                        num_workers=2,
                        pin_memory=True)

    # Model & Trainer
    model = VoxelVAE(latent_dim=config['latent_dim'])

    trainer = pl.Trainer(
        accelerator='auto',
        devices=1 if torch.cuda.is_available() else None,
        max_epochs=config['max_epochs'],
        logger=wandb_logger,
        callbacks=[
            ModelCheckpoint(
                dirpath='checkpoints',
                filename='vae-{epoch}-{train_loss:.2f}',
                save_top_k=3,
                monitor='train_loss'
            )
        ]
    )

    trainer.fit(model, loader)
    wandb.finish()

In [7]:
train()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 132 M  | train
1 | decoder | Sequential | 66.9 M | train
-----------------------------------------------
199 M     Trainable params
0         Non-trainable params
199 M     Total params
796.322   Total estimated model params size (MB)
18        Modules in train mode
0         Modules in eval mode
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


0,1
epoch,▁▃▅██
kl_loss,▁█▅▅▅
recon_loss,█▄▂▁▁
train_loss,█▄▂▁▁
trainer/global_step,▁▁▁▁█

0,1
epoch,16.0
kl_loss,251.66876
recon_loss,0.06723
train_loss,0.06975
trainer/global_step,49.0
