In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

"""
Implementation based on original paper NeurIPS 2016 https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf
"""


## Discriminator

In [None]:
class Discriminator(torch.nn.Module):
    """3D-GAN discriminator: scores a voxel volume as real (-> 1) or fake (-> 0).

    Four Conv3d stages (kernel 4, stride 2, padding 1) each map a spatial size
    n to floor(n / 2), so a ``dim``-sized cube ends at ``int(dim / 16)`` per
    side with ``out_conv_channels`` channels. A linear layer plus sigmoid then
    maps that to one probability per sample.

    Args:
        in_channels: channels of the input volume.
        dim: edge length of the (cubic) input volume.
        out_conv_channels: channels after the last conv stage; the earlier
            stages use 1/8, 1/4 and 1/2 of this value.
    """

    def __init__(self, in_channels=3, dim=64, out_conv_channels=512):
        super(Discriminator, self).__init__()
        conv1_channels = int(out_conv_channels / 8)
        conv2_channels = int(out_conv_channels / 4)
        conv3_channels = int(out_conv_channels / 2)
        self.out_conv_channels = out_conv_channels
        # Spatial size after the four stride-2 convs (each maps n -> floor(n / 2)).
        self.out_dim = int(dim / 16)

        # Registration order (conv1..conv4, out) is kept so state_dict keys match.
        self.conv1 = self._down_block(in_channels, conv1_channels)
        self.conv2 = self._down_block(conv1_channels, conv2_channels)
        self.conv3 = self._down_block(conv2_channels, conv3_channels)
        self.conv4 = self._down_block(conv3_channels, out_conv_channels)
        self.out = nn.Sequential(
            nn.Linear(out_conv_channels * self.out_dim * self.out_dim * self.out_dim, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(c_in, c_out):
        """Conv3d that halves each spatial dim, followed by BatchNorm3d and LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv3d(
                in_channels=c_in, out_channels=c_out, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(c_out),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of real/fake probabilities for input volumes ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten per sample (keep the batch axis). With view(-1, ...) a
        # wrongly sized volume could be silently regrouped into a different
        # number of rows; this way the linear layer rejects it instead.
        x = torch.flatten(x, start_dim=1)
        return self.out(x)


## Generator

In [None]:
class Generator(torch.nn.Module):
    """3D-GAN generator: maps a latent noise vector to a voxel volume.

    A linear layer projects the noise to ``in_channels`` feature maps of size
    ``in_dim ** 3``, where ``in_dim = int(out_dim / 16)``. Four
    ConvTranspose3d stages (kernel 4, stride 2, padding 1) each double the
    spatial size; the first three also halve the channel count. The last
    stage emits ``out_channels`` channels, squashed by a sigmoid when
    ``activation == "sigmoid"`` and by a tanh for any other value.
    """

    def __init__(self, in_channels=512, out_dim=64, out_channels=1, noise_dim=200, activation="sigmoid"):
        super(Generator, self).__init__()
        self.in_channels = in_channels
        self.out_dim = out_dim
        self.in_dim = int(out_dim / 16)

        # Widths of the hidden upsampling stages: each halves the previous one.
        widths = [in_channels]
        for _ in range(3):
            widths.append(int(widths[-1] / 2))

        side = self.in_dim
        self.linear = torch.nn.Linear(noise_dim, in_channels * side * side * side)

        self.conv1 = self._up_block(widths[0], widths[1])
        self.conv2 = self._up_block(widths[1], widths[2])
        self.conv3 = self._up_block(widths[2], widths[3])
        # Final stage: no BatchNorm/ReLU, the output activation is applied in forward().
        self.conv4 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=widths[3], out_channels=out_channels, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            )
        )
        self.out = torch.nn.Sigmoid() if activation == "sigmoid" else torch.nn.Tanh()

    @staticmethod
    def _up_block(c_in, c_out):
        """ConvTranspose3d that doubles each spatial dim, then BatchNorm3d and ReLU."""
        return nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=c_in, out_channels=c_out, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(c_out),
            nn.ReLU(inplace=True)
        )

    def project(self, x):
        """
        Reshape the flat linear output into the starting feature volume.
        :param x: output of ``self.linear`` (one row per sample)
        :return: tensor shaped (batch, in_channels, in_dim, in_dim, in_dim)
        """
        edge = self.in_dim
        return x.view(-1, self.in_channels, edge, edge, edge)

    def forward(self, x):
        """Map noise vectors ``x`` to activated output volumes."""
        volume = self.project(self.linear(x))
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            volume = stage(volume)
        return self.out(volume)


## Test

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def test_gan3d(print_summary=True):
    """
    Smoke-test the generator and discriminator on one random latent vector.

    Builds both networks on ``device``, pushes a single noise vector through
    the generator, and scores the result with the discriminator. It prints
    the generated shape and the score. With ``print_summary`` it also prints
    torchsummary layer tables for both models.
    """
    noise_dim = 200  # latent space vector dim
    in_channels = 512  # convolutional channels
    dim = 64  # cube volume edge length
    model_generator = Generator(in_channels=in_channels, out_dim=dim, out_channels=1, noise_dim=noise_dim).to(device)
    noise = torch.rand(1, noise_dim).to(device)
    generated_volume = model_generator(noise)
    print("Generator output shape", generated_volume.shape)
    model_discriminator = Discriminator(in_channels=1, dim=dim, out_conv_channels=in_channels).to(device)
    out = model_discriminator(generated_volume)
    print("Discriminator output", out.item())
    if print_summary:
        # torchsummary's `summary` defaults to device="cuda". Pass the active
        # device type so the summary also runs on CPU-only machines.
        print("\n\nGenerator summary\n\n")
        summary(model_generator, (1, noise_dim), device=device.type)
        print("\n\nDiscriminator summary\n\n")
        summary(model_discriminator, (1, dim, dim, dim), device=device.type)

test_gan3d()

## Generator and Discriminator Training Steps

In [None]:
from torch.autograd.variable import Variable


def ones_target(size, device=None):
    '''
    Tensor of ones with shape (size, 1): the "real" label for `size` samples.

    Args:
        size: number of samples (rows).
        device: optional device to allocate on; None uses PyTorch's default device.

    Returns:
        A float tensor of ones that does not require grad.
    '''
    # Plain tensors replace the deprecated autograd.Variable wrapper.
    return torch.ones(size, 1, device=device)


def zeros_target(size, device=None):
    '''
    Tensor of zeros with shape (size, 1): the "fake" label for `size` samples.

    Args:
        size: number of samples (rows).
        device: optional device to allocate on; None uses PyTorch's default device.

    Returns:
        A float tensor of zeros that does not require grad.
    '''
    # Plain tensors replace the deprecated autograd.Variable wrapper.
    return torch.zeros(size, 1, device=device)


def train_discriminator(discriminator, optimizer, real_data, fake_data, loss):
    """
    Run one discriminator update on a batch of real and a batch of fake samples.

    Real samples are labelled 1 and fake samples 0. ``fake_data`` is detached,
    so this step neither backpropagates into the generator that produced it
    nor frees that generator's graph.

    Args:
        discriminator: model mapping a batch of samples to probabilities.
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples.
        fake_data: batch of generated samples (may be attached to the generator's graph).
        loss: criterion called as ``loss(prediction, target)``, e.g. ``nn.BCELoss()``.

    Returns:
        (total error, predictions on real data, predictions on fake data)
    """
    # Reset gradients
    optimizer.zero_grad()

    # 1.1 Train on real data. Targets are built with *_like so they share the
    # predictions' device, dtype and shape. Tensor.cuda() is not in-place, so
    # moving targets after creation without reassigning would leave them on the CPU.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    # 1.2 Train on fake data; detached so only the discriminator learns here.
    prediction_fake = discriminator(fake_data.detach())
    error_fake = loss(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    # 1.3 Update weights with the accumulated gradients
    optimizer.step()

    # Return error and predictions for real and fake inputs
    return error_real + error_fake, prediction_real, prediction_fake


def train_generator(discriminator, optimizer, fake_data, loss):
    """
    Run one generator update: push the discriminator to call ``fake_data`` real.

    Args:
        discriminator: model mapping a batch of samples to probabilities.
        optimizer: optimizer over the generator's parameters.
        fake_data: generated batch, still attached to the generator's graph.
        loss: criterion called as ``loss(prediction, target)``, e.g. ``nn.BCELoss()``.

    Returns:
        The generator loss for this batch.

    Note: backward() also accumulates gradients into the discriminator's
    parameters; the next discriminator step is expected to zero them.
    """
    # Reset gradients
    optimizer.zero_grad()
    prediction = discriminator(fake_data)
    # Target "real" labels on the predictions' own device/dtype/shape.
    # Tensor.cuda() is not in-place, so it cannot be used to move them afterwards.
    error = loss(prediction, torch.ones_like(prediction))
    # Backpropagate and update weights with gradients
    error.backward()
    optimizer.step()
    return error

## Converting Mesh to Voxel

In [None]:
import trimesh
import numpy as np
import os

def _fit_to_shape(matrix, target_shape):
    """Center-crop or zero-pad a 3-D array so its shape is exactly ``target_shape``."""
    result = np.zeros(target_shape, dtype=matrix.dtype)
    src, dst = [], []
    for have, want in zip(matrix.shape, target_shape):
        if have >= want:
            start = (have - want) // 2
            src.append(slice(start, start + want))
            dst.append(slice(0, want))
        else:
            start = (want - have) // 2
            src.append(slice(0, have))
            dst.append(slice(start, start + have))
    result[tuple(dst)] = matrix[tuple(src)]
    return result


def mesh_to_voxel(mesh_path, voxel_resolution=32):
    """
    Convert a mesh file to a fixed-size, solid voxel grid.

    The voxel pitch is derived from the mesh's largest extent, so the model
    spans the grid whatever units the file uses. The grid is filled and then
    center-cropped/zero-padded to exactly
    (voxel_resolution, voxel_resolution, voxel_resolution). A fixed shape
    lets samples be stacked into batches.

    Args:
    - mesh_path (str): Path to the mesh file (.off).
    - voxel_resolution (int): Edge length of the cubic voxel grid (e.g., 32 -> 32x32x32).

    Returns:
    - voxel_grid (numpy.ndarray): Boolean occupancy array of shape
      (voxel_resolution, voxel_resolution, voxel_resolution). An all-False
      grid is returned for an empty or degenerate mesh.

    Raises:
    - ValueError: if voxel_resolution < 1.
    """
    if voxel_resolution < 1:
        raise ValueError(f"voxel_resolution must be >= 1, got {voxel_resolution}")
    target_shape = (voxel_resolution,) * 3

    # Load the mesh file
    mesh = trimesh.load(mesh_path, force='mesh')
    if mesh.is_empty:
        return np.zeros(target_shape, dtype=bool)

    # A fixed pitch (e.g. 1 / resolution) would make the grid size depend on
    # the model's scale. Scale the pitch to the longest side instead. Using
    # resolution - 1 steps leaves room for the extra boundary voxel that
    # rounding can add; any overshoot is cropped below.
    max_extent = float(np.max(mesh.extents))
    if not np.isfinite(max_extent) or max_extent <= 0:
        return np.zeros(target_shape, dtype=bool)
    pitch = max_extent / max(voxel_resolution - 1, 1)

    # Convert to a voxel grid and fill the voxels inside the mesh
    voxel_grid = mesh.voxelized(pitch=pitch)
    voxel_grid = voxel_grid.fill()

    # Get the matrix representation, forced to the exact target shape
    matrix = np.asarray(voxel_grid.matrix, dtype=bool)
    return _fit_to_shape(matrix, target_shape)

def process_dataset(input_dir, output_dir, voxel_resolution=32):
    """
    Voxelize every ``.off`` mesh found under ``input_dir`` and save it as ``.npy``.

    The folder layout below ``input_dir`` is mirrored under ``output_dir``
    (``<input_dir>/a/b.off`` -> ``<output_dir>/a/b.npy``), creating folders as
    needed. One line is printed per saved file.

    Args:
    - input_dir (str): Directory containing the ModelNet10 dataset.
    - output_dir (str): Directory where the voxelized numpy arrays will be saved.
    - voxel_resolution (int): Resolution for the voxel grid.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for folder, _subfolders, names in os.walk(input_dir):
        for name in names:
            if not name.endswith('.off'):
                continue
            mesh_path = os.path.join(folder, name)
            grid = mesh_to_voxel(mesh_path, voxel_resolution=voxel_resolution)
            # Same relative path as the mesh, with the extension swapped to .npy
            stem = os.path.splitext(os.path.relpath(mesh_path, input_dir))[0]
            npy_path = os.path.join(output_dir, stem + '.npy')
            target_folder = os.path.dirname(npy_path)
            if not os.path.exists(target_folder):
                os.makedirs(target_folder)
            np.save(npy_path, grid)
            print(f'Saved {npy_path}')

# Example usage: convert the raw ModelNet10 meshes into voxel grids on disk.
voxel_resolution = 32  # Edge length of each cubic voxel grid
input_dir = './archive/ModelNet10'  # Update this path to your ModelNet10 dataset directory
output_dir = './archive/ModelNet10_voxelized'  # Choose where you want to save the voxelized models

process_dataset(
    input_dir,
    output_dir,
    voxel_resolution=voxel_resolution,
)

## Loading the Dataset

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms



class ModelNet10Dataset(Dataset):
    """Dataset of voxelized ModelNet10 objects stored as ``.npy`` files.

    Every ``.npy`` file found recursively under ``root_dir`` is one sample.
    The file list is sorted because ``os.walk`` order depends on the
    filesystem; sorting keeps sample indices stable across runs and machines.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the voxelized ModelNet10 objects.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.filepaths = sorted(
            os.path.join(dirpath, name)
            for dirpath, _dirnames, filenames in os.walk(root_dir)
            for name in filenames
            if name.endswith('.npy')
        )

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'volume': ndarray}`` and apply ``transform`` if set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # np.load keeps its default allow_pickle=False, so only plain arrays are read.
        volume = np.load(self.filepaths[idx])
        sample = {'volume': volume}

        if self.transform:
            sample = self.transform(sample)

        return sample

# Transform: NumPy voxel array -> float32 tensor with a leading channel axis (no normalization).
class ToTensor(object):
    """Convert ``sample['volume']`` to a float32 tensor shaped (1, *volume.shape)."""

    def __call__(self, sample):
        as_float = sample['volume'].astype(np.float32)
        # Indexing with [None] prepends a size-1 channel dimension, like unsqueeze(0).
        return {'volume': torch.from_numpy(as_float)[None]}

# Every .npy under ./archive is a sample; ToTensor adds the channel axis.
dataset = ModelNet10Dataset(transform=ToTensor(), root_dir='./archive')
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

## Training

In [None]:
# Hyperparameters
num_epochs = 50
learning_rate = 0.0002
beta1 = 0.5  # Beta1 hyperparam for Adam optimizers
noise_dim = 200  # Latent vector size fed to the generator (same value as in test_gan3d)
# The batch size is whatever `dataloader` was built with; the loop reads it per batch.

# Size both networks from the data rather than constructor defaults. The
# Discriminator defaults to 3 input channels and 64^3 volumes, which do not
# match the single-channel voxel grids produced by ToTensor.
sample_volume = dataset[0]['volume']  # ToTensor output: (channels, D, H, W)
volume_channels = sample_volume.shape[0]
volume_dim = sample_volume.shape[-1]
assert sample_volume.dim() == 4 and tuple(sample_volume.shape[1:]) == (volume_dim,) * 3, \
    "expected cubic voxel grids shaped (channels, D, D, D)"
assert volume_dim % 16 == 0, "both networks rescale the volume by a factor of 16"

# Initialize models
discriminator = Discriminator(in_channels=volume_channels, dim=volume_dim).to(device)
generator = Generator(out_dim=volume_dim, out_channels=volume_channels, noise_dim=noise_dim).to(device)

# Optimizers
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# Loss function
criterion = nn.BCELoss()

# Training loop
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # =============================
        # Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # =============================
        discriminator.zero_grad()

        # Train with real data. Batches are dicts (see ModelNet10Dataset), so take the tensor.
        real_volume = data['volume'].to(device)
        batch_size = real_volume.size(0)
        label = torch.full((batch_size,), 1., dtype=torch.float, device=device)

        output = discriminator(real_volume).view(-1)
        loss_d_real = criterion(output, label)
        loss_d_real.backward()
        D_x = output.mean().item()

        # Train with fake data (detached so this step does not update the generator)
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake = generator(noise)
        label.fill_(0.)
        output = discriminator(fake.detach()).view(-1)
        loss_d_fake = criterion(output, label)
        loss_d_fake.backward()
        D_G_z1 = output.mean().item()

        loss_d = loss_d_real + loss_d_fake
        optimizer_d.step()

        # =============================
        # Update Generator: maximize log(D(G(z)))
        # =============================
        generator.zero_grad()
        label.fill_(1.)  # fake labels are real for generator cost
        output = discriminator(fake).view(-1)
        loss_g = criterion(output, label)
        loss_g.backward()
        D_G_z2 = output.mean().item()

        optimizer_g.step()

        # Print statistics
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader), loss_d.item(), loss_g.item(), D_x, D_G_z1, D_G_z2))

## Saving Model

In [None]:
# Persist only the learned weights (state_dict), generator first, then discriminator.
for trained_model, checkpoint_path in ((generator, 'generator.pth'), (discriminator, 'discriminator.pth')):
    torch.save(trained_model.state_dict(), checkpoint_path)