In [1]:
import os
import glob

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
# Seed the global random number generators. tf.random.set_seed only affects
# TensorFlow, so NumPy's global RNG is seeded separately to make any
# NumPy-based randomness in this notebook reproducible as well.
SEED = 1234
tf.random.set_seed(SEED)
np.random.seed(SEED)


2024-06-21 04:29:01.304109: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-21 04:29:01.304202: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-21 04:29:01.478767: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm

# Encoder definition
class PointNetEncoder(nn.Module):
    """PointNet-style global feature extractor.

    Three shared per-point 1x1 convolutions (3 -> 64 -> 128 -> 1024 channels)
    are followed by a max-pool over the point axis. The pooled vector then goes
    through an MLP (1024 -> 512 -> 256 -> 1024).

    Input:  tensor of shape (batch, 3, num_points).
    Output: tensor of shape (batch, 1024), with no activation on the last layer.
    """

    def __init__(self):
        super().__init__()
        # Per-point layers: kernel size 1 transforms every point independently.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        # Layers applied to the pooled global feature vector.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1024)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        # Max over the point axis is symmetric, so the result ignores point order.
        pooled = x.max(dim=2).values
        for fc in (self.fc1, self.fc2):
            pooled = F.relu(fc(pooled))
        return self.fc3(pooled)

# Decoder definition
class PointCloudDecoder(nn.Module):
    """Fully-connected decoder mapping a 1024-d latent code to ``num_points`` 3-D points.

    Input:  tensor of shape (batch, 1024).
    Output: tensor of shape (batch, 3, num_points), with coordinates on dim 1.
    """

    def __init__(self, num_points):
        super().__init__()
        self.num_points = num_points
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 512)
        # The final layer emits every coordinate of every point as one flat vector.
        self.fc3 = nn.Linear(512, num_points * 3)

    def forward(self, x):
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        flat = self.fc3(hidden)
        # Reinterpret the flat vector as (batch, 3, num_points); -1 infers the batch size.
        return flat.view(-1, 3, self.num_points)

# Model combining encoder and decoder
class PointCompletionNet(nn.Module):
    """Encoder-decoder network that maps a partial point cloud to a completed one.

    Input:  tensor of shape (batch, num_input_points, 3).
    Output: tensor of shape (batch, num_points, 3). ``num_points`` (default 2048)
            is fixed by the decoder, independent of the input size.
    """

    def __init__(self, num_points=2048):
        super().__init__()
        self.encoder = PointNetEncoder()
        self.decoder = PointCloudDecoder(num_points)

    def forward(self, x):
        # The encoder's Conv1d layers expect channels first: (batch, 3, points).
        latent = self.encoder(x.transpose(1, 2))
        # The decoder yields (batch, 3, num_points); return channels last like the input.
        return self.decoder(latent).transpose(1, 2)

# Dataset class
class PointCloudDataset(Dataset):
    """Paired (partial, ground-truth) point clouds for supervised completion.

    Args:
        partial_data: array-like indexed by sample on axis 0 (presumably
            (num_samples, num_partial_points, 3); TODO confirm against the loader).
        gt_data: array-like with the same number of samples as ``partial_data``.

    Both inputs are converted to float32 NumPy arrays. An input that is already
    a float32 ndarray is used without copying, so later in-place edits to it
    are visible through the dataset.

    Raises:
        ValueError: if the two inputs contain different numbers of samples.
    """

    def __init__(self, partial_data, gt_data):
        # asarray converts lists/other dtypes but avoids duplicating large
        # arrays that are already float32.
        self.partial_data = np.asarray(partial_data, dtype=np.float32)
        self.gt_data = np.asarray(gt_data, dtype=np.float32)
        # Catch mismatched pairs up front. Otherwise they surface as an
        # IndexError mid-epoch, or trailing gt samples are silently ignored.
        if len(self.partial_data) != len(self.gt_data):
            raise ValueError(
                f"partial_data has {len(self.partial_data)} samples "
                f"but gt_data has {len(self.gt_data)}"
            )

    def __len__(self):
        return len(self.partial_data)

    def __getitem__(self, idx):
        """Return the ``(partial, gt)`` float32 arrays for sample ``idx``."""
        return self.partial_data[idx], self.gt_data[idx]

# Chamfer distance between two batched point clouds
def chamfer_distance(pred, gt):
    """Symmetric Chamfer distance (Euclidean, not squared) for each batch element.

    Args:
        pred: tensor of shape (batch, N, 3), the predicted points.
        gt: tensor of shape (batch, M, 3), the ground-truth points. M may differ from N.

    Returns:
        Tensor of shape (batch,). Each entry is the mean distance from every gt
        point to its nearest predicted point, plus the mean distance from every
        predicted point to its nearest gt point.
    """
    # Pairwise distances of shape (batch, N, M), computed without materialising
    # repeated copies of either cloud. The direct (non-matmul) mode keeps small
    # distances exact instead of using the ||a||^2 + ||b||^2 - 2ab expansion.
    dist = torch.cdist(pred, gt, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")
    gt_to_pred = dist.min(dim=1).values  # (batch, M): nearest prediction per gt point
    pred_to_gt = dist.min(dim=2).values  # (batch, N): nearest gt point per prediction
    return gt_to_pred.mean(dim=1) + pred_to_gt.mean(dim=1)

# Load your datasets
# NOTE(review): `partial_dataset` and `gt_dataset` are not created in this cell.
# They must already exist in the kernel (presumably loaded by an earlier cell),
# so fail with a clear message instead of relying on hidden state.
_missing = [name for name in ("partial_dataset", "gt_dataset") if name not in globals()]
if _missing:
    raise NameError(
        f"Load the datasets before running this cell; undefined: {', '.join(_missing)}"
    )

# Seed PyTorch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(1234)

# Create DataLoader
train_dataset = PointCloudDataset(partial_dataset, gt_dataset)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Hyperparameters and model initialization
# Fall back to CPU so the cell still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_points = 2048
model = PointCompletionNet(num_points).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 25

# Training loop with progress bar
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    # Initialize the progress bar
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
        for batch_idx, (partial, complete) in enumerate(train_loader, start=1):
            partial, complete = partial.to(device), complete.to(device)
            optimizer.zero_grad()

            reconstructed = model(partial)
            loss = chamfer_distance(reconstructed, complete).mean()

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Running mean over the batches processed so far, so the bar shows a
            # meaningful loss mid-epoch rather than a partial sum over all batches.
            pbar.set_postfix(loss=running_loss / batch_idx)
            pbar.update(1)

    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}')


Epoch 1/25: 100%|██████████| 906/906 [02:08<00:00,  7.05batch/s, loss=0.0886]


Epoch [1/25], Loss: 0.0886


Epoch 2/25: 100%|██████████| 906/906 [02:09<00:00,  7.02batch/s, loss=0.064]  


Epoch [2/25], Loss: 0.0640


Epoch 3/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0607] 


Epoch [3/25], Loss: 0.0607


Epoch 4/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0589] 


Epoch [4/25], Loss: 0.0589


Epoch 5/25: 100%|██████████| 906/906 [02:09<00:00,  6.99batch/s, loss=0.0574] 


Epoch [5/25], Loss: 0.0574


Epoch 6/25: 100%|██████████| 906/906 [02:09<00:00,  6.99batch/s, loss=0.0563] 


Epoch [6/25], Loss: 0.0563


Epoch 7/25: 100%|██████████| 906/906 [02:09<00:00,  6.99batch/s, loss=0.0554] 


Epoch [7/25], Loss: 0.0554


Epoch 8/25: 100%|██████████| 906/906 [02:09<00:00,  6.99batch/s, loss=0.0547] 


Epoch [8/25], Loss: 0.0547


Epoch 9/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0538] 


Epoch [9/25], Loss: 0.0538


Epoch 10/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0533] 


Epoch [10/25], Loss: 0.0533


Epoch 11/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0528] 


Epoch [11/25], Loss: 0.0528


Epoch 12/25: 100%|██████████| 906/906 [02:09<00:00,  6.99batch/s, loss=0.0522] 


Epoch [12/25], Loss: 0.0522


Epoch 13/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.052]  


Epoch [13/25], Loss: 0.0520


Epoch 14/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0516] 


Epoch [14/25], Loss: 0.0516


Epoch 15/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0513] 


Epoch [15/25], Loss: 0.0513


Epoch 16/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0509] 


Epoch [16/25], Loss: 0.0509


Epoch 17/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0506] 


Epoch [17/25], Loss: 0.0506


Epoch 18/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0504] 


Epoch [18/25], Loss: 0.0504


Epoch 19/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0502] 


Epoch [19/25], Loss: 0.0502


Epoch 20/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0499] 


Epoch [20/25], Loss: 0.0499


Epoch 21/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0497] 


Epoch [21/25], Loss: 0.0497


Epoch 22/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0495] 


Epoch [22/25], Loss: 0.0495


Epoch 23/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0494] 


Epoch [23/25], Loss: 0.0494


Epoch 24/25: 100%|██████████| 906/906 [02:09<00:00,  7.01batch/s, loss=0.0492] 


Epoch [24/25], Loss: 0.0492


Epoch 25/25: 100%|██████████| 906/906 [02:09<00:00,  7.00batch/s, loss=0.0491] 

Epoch [25/25], Loss: 0.0491



