In [1]:
import json
import numpy as np
from tqdm import tqdm
import trimesh
import torchvision.transforms as T

import torch
import torch.nn as nn
import torch.nn.functional as F

import open3d as o3d

from sklearn.preprocessing import LabelEncoder

from utils.data_load import load_image, load_3d_model, visualize_data # our own utils

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


# Functions for model

In [2]:
def chamfer_distance(point_cloud1, point_cloud2):
    """
    Symmetric Chamfer Distance between two batches of point clouds.

    Args:
        point_cloud1: tensor of shape (B, N, D).
        point_cloud2: tensor of shape (B, M, D).

    Returns:
        Scalar tensor. For each batch element the mean nearest-neighbour
        Euclidean distance from cloud 1 to cloud 2 is added to the same
        quantity measured from cloud 2 to cloud 1; the result is then
        averaged over the batch.
    """
    batch_size, n_first, n_dims = point_cloud1.shape
    _, n_second, _ = point_cloud2.shape
    pairwise_shape = (batch_size, n_first, n_second, n_dims)

    # Line both clouds up on a common (B, N, M, D) grid so that entry
    # [b, i, j] pairs point i of cloud 1 with point j of cloud 2.
    first_grid = point_cloud1.unsqueeze(2).expand(*pairwise_shape)
    second_grid = point_cloud2.unsqueeze(1).expand(*pairwise_shape)

    # (B, N, M) matrix of Euclidean distances between every pair of points
    pairwise_dist = torch.norm(first_grid - second_grid, dim=3)

    nearest_to_first = pairwise_dist.min(dim=2).values   # (B, N): cloud 1 -> cloud 2
    nearest_to_second = pairwise_dist.min(dim=1).values  # (B, M): cloud 2 -> cloud 1

    # One Chamfer value per batch element, then the batch average
    per_sample = nearest_to_first.mean(dim=1) + nearest_to_second.mean(dim=1)
    return per_sample.mean()

# Get main data

In [3]:
# Root folder of the dataset; the metadata file and every image / mask / model
# path read from it are resolved relative to this folder.
# NOTE(review): this name shadows Python's built-in `dir()` for the rest of the notebook.
dir = "data/pix3d"

In [4]:
# Parse the dataset's JSON metadata file into memory.
metadata_path = f"{dir}/pix3d.json"
with open(metadata_path, mode="rb") as metadata_file:
    metadata = json.load(metadata_file)

In [5]:
# torchvision's ToTensor; the loading loop applies it to each sampled point cloud
# before the cloud is stored.
transform = T.ToTensor()

In [None]:
# Load a random subset of the dataset into memory: image, mask, a point cloud
# sampled from the 3D model's surface, and the category label of every entry.
SEED = 42                # makes the subset selection repeatable across runs
MAX_SAMPLES = 1000       # upper bound on the number of metadata entries loaded
POINTS_PER_MODEL = 1024  # points sampled from each model's surface

imgs, masks, models, categories = [], [], [], []

# Seed NumPy's global RNG and draw a permutation of indices instead of shuffling
# `metadata` in place: re-running this cell then selects the same entries and
# leaves the metadata list untouched.
# NOTE(review): trimesh's surface sampler presumably draws from the same global
# RNG when no seed is passed, which would make the sampled points repeatable too — confirm.
np.random.seed(SEED)
selected_indices = np.random.permutation(len(metadata))[:MAX_SAMPLES]

for entry_idx in tqdm(selected_indices):
    entry = metadata[entry_idx]
    img_path = f"{dir}/{entry['img']}"
    mask_path = f"{dir}/{entry['mask']}"
    model_path = f"{dir}/{entry['model']}"

    # take actual img, mask and model
    img, mask = load_image(img_path, mask_path, (224, 224))
    mesh = load_3d_model(model_path)

    # combine all geometries, if this is a scene
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_geometry()

    point_cloud, _ = trimesh.sample.sample_surface(mesh, count=POINTS_PER_MODEL)

    imgs.append(img)
    masks.append(mask)
    models.append(transform(point_cloud))
    categories.append(entry["category"])

In [None]:
# categories

In [None]:
# Encode the category labels as integer class ids. The fitted encoder is kept
# (instead of being thrown away) so predicted ids can later be mapped back to
# their original labels with `label_encoder.inverse_transform`.
# NOTE: `categories` is overwritten here, so run this cell once per data load.
label_encoder = LabelEncoder()
categories = label_encoder.fit_transform(categories)
n_categories = len(label_encoder.classes_)

In [None]:
# Number of distinct categories; later passed to the model as its number of classes.
n_categories

In [None]:
# 70 / 15 / 15 train / test / validation sizes.
# The sizes are used as slice bounds, which must be integers, so they are
# computed with integer arithmetic. Validation takes the remainder so the three
# parts always add up to the full dataset.
total_amount = len(imgs)

train_size = total_amount * 70 // 100
test_size = total_amount * 15 // 100
validation_size = total_amount - train_size - test_size

In [None]:
# split into train, test and validation
from_dp, to_dp = 0, train_size
train, ,  = (imgs[from_dp:to_dp], masks[from_dp:to_dp], models[from_dp:to_dp], categories[from_dp:to_dp])

from_dp, to_dp = train_size, train_size+test_siz
test = (imgs[from_dp:to_dp], masks[from_dp:to_dp], models[from_dp:to_dp], categories[from_dp:to_dp])

from_dp, to_dp = train_size+test_size, train_size+test_size+validation_size
validation = (imgs[from_dp:to_dp], masks[from_dp:to_dp], models[from_dp:to_dp], categories[from_dp:to_dp])

# Analyze and preprocess the data

In [None]:
# Sanity check: shape of the first image, mask and point cloud.
for first_sample in (imgs[0], masks[0], models[0]):
    print(first_sample.shape)

# Modeling

### Data pipeline

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class PointCloudDataset(Dataset):
    """In-memory dataset that groups each image with its mask, point cloud and category."""

    def __init__(self, images, depth_maps, point_clouds, categories, transform=None):
        """
        Args:
            images (list of tensors): List of image tensors.
            depth_maps (list of tensors): List of depth map tensors
                (returned under the 'mask' key).
            point_clouds (list of tensors): List of point cloud tensors.
            categories (sequence): Class label of every sample.
            transform (callable, optional): Optional transform applied to the
                image and the depth map (not to the point cloud or category).

        Raises:
            AssertionError: if the four inputs do not all have the same length.
        """
        self.images = images
        self.depth_maps = depth_maps
        self.point_clouds = point_clouds
        self.categories = categories
        self.transform = transform

        # Categories are checked too: a mismatch would otherwise only surface
        # later as an IndexError (or silently unused labels) inside __getitem__.
        assert len(images) == len(depth_maps) == len(point_clouds) == len(categories), \
            "Images, depth maps, point clouds, and categories must have the same length"

    def __len__(self):
        # Returns the number of samples in the dataset
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at `idx` as a dict with keys 'image', 'mask', 'model', 'category'."""
        image = self.images[idx]
        depth_map = self.depth_maps[idx]
        point_cloud = self.point_clouds[idx]
        category = self.categories[idx]

        # Apply any transforms (e.g., data augmentation)
        if self.transform:
            image = self.transform(image)
            depth_map = self.transform(depth_map)

        # Return a dictionary of the inputs and target (point cloud)
        return {
            'image': image,
            'mask': depth_map,
            'model': point_cloud,
            'category': category
        }

In [None]:
# # Create dataset
# dataset = PointCloudDataset(imgs, masks, models, categories)

# # Create DataLoader for batching
# data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
# Wrap every split in a Dataset and then in a DataLoader; only the training
# loader shuffles its samples.
train_dataset, test_dataset, validation_dataset = (
    PointCloudDataset(*split) for split in (train, test, validation)
)

loader_batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=loader_batch_size)
validation_loader = DataLoader(validation_dataset, batch_size=loader_batch_size)

### Model

In [None]:
class ImageFeatureNet(nn.Module):
    """Three strided convolution blocks followed by a linear layer that yields an image feature vector.

    Args:
        in_channels: Number of channels of the input image (default 3).
        input_size: (height, width) of the input images. It determines the size
            of the flattened convolution output fed to the linear layer
            (default (224, 224), which gives 256 * 28 * 28).
        feature_dim: Length of the returned feature vector (default 256).
    """

    def __init__(self, in_channels=3, input_size=(224, 224), feature_dim=256):
        super(ImageFeatureNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.bn3 = nn.BatchNorm2d(256)

        # Derive the flattened size from the input resolution instead of
        # hard-coding the 28 x 28 map that only a 224 x 224 input produces.
        height, width = input_size
        out_height = self._conv_output_size(height)
        out_width = self._conv_output_size(width)
        self.fc1 = nn.Linear(256 * out_height * out_width, feature_dim)

    @staticmethod
    def _conv_output_size(size, num_layers=3, kernel_size=5, stride=2, padding=2):
        """Spatial size left after `num_layers` convolutions with the given geometry."""
        for _ in range(num_layers):
            size = (size + 2 * padding - kernel_size) // stride + 1
        return size

    def forward(self, x):
        """Map a (batch_size, in_channels, H, W) image batch to (batch_size, feature_dim) features."""
        # Each block halves the spatial resolution (rounding up); the shape
        # comments below are for the default 224 x 224 input.
        x = F.relu(self.bn1(self.conv1(x)))  # (batch_size, 64, 112, 112)
        x = F.relu(self.bn2(self.conv2(x)))  # (batch_size, 128, 56, 56)
        x = F.relu(self.bn3(self.conv3(x)))  # (batch_size, 256, 28, 28)
        x = x.view(x.size(0), -1)  # Flatten everything except the batch axis
        x = F.relu(self.fc1(x))  # (batch_size, feature_dim)
        return x

In [None]:
class PointCloudDecoder(nn.Module):
    """Fully connected decoder turning a 256-dimensional feature vector into `num_points` 3D points."""

    def __init__(self, num_points=1024):
        super().__init__()
        self.num_points = num_points
        # Layer widths: 256 -> 512 -> 1024 -> num_points * 3 (x, y, z per point)
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, num_points * 3)

    def forward(self, x):
        """Decode (batch_size, 256) features into a (batch_size, num_points, 3) point cloud."""
        hidden = F.relu(self.fc1(x))       # (batch_size, 512)
        hidden = F.relu(self.fc2(hidden))  # (batch_size, 1024)
        flat_coords = self.fc3(hidden)     # (batch_size, num_points * 3)
        # Group the flat coordinate vector into one (x, y, z) triple per point
        return flat_coords.view(-1, self.num_points, 3)

In [None]:
class FullModel(nn.Module):
    """Shared image encoder feeding two heads: category logits and a reconstructed point cloud."""

    def __init__(self, num_classes, num_points=1024):
        super().__init__()
        # Encoder producing a 256-dimensional feature vector per image
        self.image_feature_net = ImageFeatureNet()
        # Classification head
        self.classification_fc = nn.Linear(256, num_classes)
        # Point cloud reconstruction head
        self.point_cloud_decoder = PointCloudDecoder(num_points=num_points)

    def forward(self, image):
        """Return `(class_output, point_cloud_output)` for a batch of images.

        `class_output` has shape (batch_size, num_classes) and
        `point_cloud_output` has shape (batch_size, num_points, 3).
        """
        features = self.image_feature_net(image)  # (batch_size, 256)
        return self.classification_fc(features), self.point_cloud_decoder(features)

In [None]:
# Instantiate the two-headed network: one logit per category and 1024 predicted points.
model = FullModel(num_classes=n_categories, num_points=1024)

### Train

In [None]:
# Cross-entropy is the loss for the classification output; Adam updates all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Display the module structure of the network.
model

In [None]:
# Number of passes over the training set, and the per-epoch loss histories
# that the training loop appends to.
num_epochs = 10
train_losses, validation_losses = [], []

In [None]:
# Joint training loop.
# The classification head is trained with cross-entropy and the point-cloud head
# with Chamfer distance. The Chamfer term must be part of the optimised loss:
# without it the point-cloud decoder never receives a gradient and keeps its
# random initial weights. Both terms are accumulated per sample so the reported
# numbers are epoch averages rather than the value of the last batch.
chamfer_weight = 1.0  # weight of the Chamfer term relative to the cross-entropy term

for epoch in range(num_epochs):

    model.train()
    running_loss = 0.0
    point_clouds_train_loss, point_clouds_validation_loss = 0.0, 0.0

    for batch in tqdm(train_loader):
        images_batch = batch['image']
        point_clouds_batch = batch['model']
        category = batch['category']

        optimizer.zero_grad()
        class_output, point_cloud_output = model(images_batch)

        classification_loss = criterion(class_output, category)

        # compute loss for point cloud prediction
        # Index 0 drops the extra axis each stored cloud carries (presumably a
        # singleton axis added by ToTensor — confirm); the cast makes the target
        # dtype match the prediction.
        target_points = point_clouds_batch[:, 0, :].to(point_cloud_output.dtype)
        chamfer_loss = chamfer_distance(point_cloud_output, target_points)

        loss = classification_loss + chamfer_weight * chamfer_loss
        loss.backward()
        optimizer.step()

        samples_in_batch = images_batch.size(0)
        running_loss += classification_loss.item() * samples_in_batch
        point_clouds_train_loss += chamfer_loss.item() * samples_in_batch

    # max(..., 1) avoids a ZeroDivisionError on an empty split
    n_train = max(len(train_loader.dataset), 1)
    train_loss = running_loss / n_train
    chamfer_distance_train = point_clouds_train_loss / n_train
    train_losses.append(train_loss)

    # Validate at each epoch
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for batch in validation_loader:
            images_batch = batch['image']
            point_clouds_batch = batch['model']
            category = batch['category']

            class_output, point_cloud_output = model(images_batch)
            classification_loss = criterion(class_output, category)

            # compute loss for point cloud prediction
            target_points = point_clouds_batch[:, 0, :].to(point_cloud_output.dtype)
            chamfer_loss = chamfer_distance(point_cloud_output, target_points)

            samples_in_batch = images_batch.size(0)
            running_loss += classification_loss.item() * samples_in_batch
            point_clouds_validation_loss += chamfer_loss.item() * samples_in_batch

    n_validation = max(len(validation_loader.dataset), 1)
    validation_loss = running_loss / n_validation
    chamfer_distance_validation = point_clouds_validation_loss / n_validation
    validation_losses.append(validation_loss)

    print(f"Epochs {epoch+1}/{num_epochs} - Train Classification Loss: {train_loss}, Validation Classification Loss: {validation_loss}")
    print(f"\t\tTrain Chamfer Distance: {chamfer_distance_train}, Validation Chamfer Distance: {chamfer_distance_validation}")

In [None]:
def visualize_point_cloud(data):
    """Show a point cloud in an Open3D viewer window.

    Args:
        data: Array-like of shape (N, 3) holding the x, y, z coordinates.

    Raises:
        ValueError: if `data` is not a 2D array with exactly three columns.
    """
    # Open3D's Vector3dVector is documented as taking a float64 (n, 3) array,
    # so convert up front (e.g. float32 network output) and validate the shape.
    points = np.asarray(data, dtype=np.float64)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError(f"expected an (N, 3) array of points, got shape {points.shape}")

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    o3d.visualization.draw_geometries([pcd])


# Backward-compatible alias: the helper was originally defined under this misspelled name.
visuzlie_point_cloud = visualize_point_cloud

In [None]:
# Take the first sample of the last validation batch left over from the training
# loop: its ground-truth cloud and the model's prediction for it.
# NOTE(review): the [-1] index selects along the sample's leading axis —
# presumably a singleton axis, so it equals index 0; confirm.
point_cloud_example = point_clouds_batch[0].cpu().numpy()[-1, :, :]
point_cloud_output_example = point_cloud_output[0].detach().cpu().numpy()

In [None]:
import plotly.graph_objs as go

def visualize_point_cloud_interactive(point_cloud):
    """Render a point cloud as an interactive Plotly 3D scatter plot.

    Columns 0, 1 and 2 of `point_cloud` are used as the x, y and z coordinates.
    Points are coloured by their z-coordinate using the Viridis colour scale.
    """
    xs, ys, zs = (point_cloud[:, axis] for axis in range(3))

    marker_style = {
        'size': 2,                # Marker size
        'color': zs,              # Color by z-coordinate
        'colorscale': 'Viridis',
        'opacity': 0.8,
    }
    scatter = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker_style)

    axis_titles = {'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': 'Z'}
    no_margin = {'l': 0, 'r': 0, 'b': 0, 't': 0}  # Minimal margin for full use of space
    layout = go.Layout(scene=axis_titles, margin=no_margin)

    # Build the figure and show it
    go.Figure(data=[scatter], layout=layout).show()

In [None]:
# visualize_point_cloud_interactive(point_cloud_example)

In [None]:
# visualize_point_cloud_interactive(point_cloud_output_example)

In [None]:
# visuzlie_point_cloud(point_cloud_example)

In [None]:
# Show the predicted point cloud in an Open3D viewer window.
visuzlie_point_cloud(point_cloud_output_example)