In [1]:
import json
import numpy as np
import trimesh
import torchvision.transforms as T

import torch
import torch.nn as nn
import torch.nn.functional as F

import open3d as o3d

from sklearn.preprocessing import LabelEncoder

from utils.data_load import load_image, load_3d_model, visualize_data # our own utils

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


# Load the Pix3D data

In [2]:
# Root directory of the Pix3D dataset: pix3d.json and the img/mask/model paths it lists live under it.
# NOTE(review): this name shadows the built-in dir(); consider renaming it (e.g. DATA_DIR) in every cell that uses it.
dir = "data/pix3d"

In [3]:
# Pix3D annotation index: a list of per-sample records (image / mask / model paths, category).
with open(f"{dir}/pix3d.json", "rb") as annotation_file:
    metadata = json.loads(annotation_file.read())

In [4]:
# Converts each sampled numpy point cloud to a tensor. For a 2-D (N, 3) array ToTensor adds a
# leading axis, so every stored point cloud ends up shaped (1, N, 3) (see the shapes printed below).
transform = T.ToTensor()

In [5]:
# Seed numpy's global RNG so the shuffle below is reproducible. trimesh's surface sampling is
# presumably drawn from the same global RNG when no seed is passed -- verify for the installed version.
SEED = 42
N_SAMPLES = 100   # number of shuffled Pix3D records to load
N_POINTS = 1024   # points sampled from each mesh surface

np.random.seed(SEED)

imgs, masks, models, categories = [], [], [], []

# NOTE: shuffles `metadata` in place.
np.random.shuffle(metadata)

for entry in metadata[:N_SAMPLES]:
    img_path = f"{dir}/{entry['img']}"
    mask_path = f"{dir}/{entry['mask']}"
    model_path = f"{dir}/{entry['model']}"

    # take the actual image and mask (resized to 224x224) and the 3D model
    img, mask = load_image(img_path, mask_path, (224, 224))
    model_img = load_3d_model(model_path)

    # a Scene may hold several geometries; merge them so the whole object surface is sampled
    if isinstance(model_img, trimesh.Scene):
        model_img = model_img.to_geometry()

    point_cloud, _ = trimesh.sample.sample_surface(model_img, count=N_POINTS)

    imgs.append(img)
    masks.append(mask)
    models.append(transform(point_cloud))
    categories.append(entry["category"])

In [6]:
# categories

In [7]:
# Encode category names as integer class ids. The fitted encoder is kept so predicted ids can be
# mapped back to names later (label_encoder.classes_ / label_encoder.inverse_transform).
label_encoder = LabelEncoder()
categories = label_encoder.fit_transform(categories)
n_categories = len(label_encoder.classes_)

In [8]:
n_categories  # number of distinct category labels among the loaded samples

8

# Inspect the loaded data (tensor shapes and types)

In [9]:
# Shapes of the first image, mask and point cloud, one per line
print(*(sample.shape for sample in (imgs[0], masks[0], models[0])), sep="\n")

torch.Size([3, 224, 224])
torch.Size([1, 224, 224])
torch.Size([1, 1024, 3])


In [10]:
# Python types of the first image, mask and point cloud, one per line
print(*(type(sample) for sample in (imgs[0], masks[0], models[0])), sep="\n")

<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>


# Modeling

### Data pipeline

In [11]:
import torch
from torch.utils.data import Dataset, DataLoader

class PointCloudDataset(Dataset):
    """Map-style dataset pairing each image with its mask, point cloud and category label.

    Each sample is a dict with keys 'image', 'mask', 'model' and 'category'.
    """

    def __init__(self, images, depth_maps, point_clouds, categories, transform=None):
        """
        Args:
            images (sequence of tensors): Image tensors.
            depth_maps (sequence of tensors): Per-image map tensors, returned under the
                'mask' key. Despite the parameter name, this notebook passes segmentation
                masks here.
            point_clouds (sequence of tensors): Target point cloud tensors.
            categories (sequence of int): Integer class label per sample (e.g. the
                LabelEncoder output).
            transform (callable, optional): Applied to the image and to the depth map / mask
                (not to the point cloud) each time a sample is fetched.

        Raises:
            AssertionError: If the four sequences do not all have the same length.
        """
        # Validate all four inputs up front; a short `categories` would otherwise only fail
        # (or silently misalign) at indexing time.
        assert len(images) == len(depth_maps) == len(point_clouds) == len(categories), \
            "Images, depth maps, point clouds and categories lists must have the same length"

        self.images = images
        self.depth_maps = depth_maps
        self.point_clouds = point_clouds
        self.categories = categories
        self.transform = transform

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of image, mask, point cloud and category."""
        image = self.images[idx]
        depth_map = self.depth_maps[idx]
        point_cloud = self.point_clouds[idx]
        category = self.categories[idx]

        # Optional per-sample transform (e.g. augmentation); `is not None` so a transform
        # object that happens to be falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)
            depth_map = self.transform(depth_map)

        return {
            'image': image,
            'mask': depth_map,
            'model': point_cloud,
            'category': category
        }

In [12]:
# Create dataset
dataset = PointCloudDataset(imgs, masks, models, categories)

# Create DataLoader for batching
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

### Model

In [13]:
# class ImageDepthEncoder(nn.Module):
#     def __init__(self):
#         super(ImageDepthEncoder, self).__init__()
        
#         # Image Encoder (ResNet-like)
#         self.image_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
#         self.image_bn1 = nn.BatchNorm2d(64)
#         self.image_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
#         self.image_bn2 = nn.BatchNorm2d(128)
#         self.image_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
#         self.image_bn3 = nn.BatchNorm2d(256)
        
#     def forward(self, image):
#         img_feat = F.relu(self.image_bn1(self.image_conv1(image)))
#         img_feat = F.relu(self.image_bn2(self.image_conv2(img_feat)))
#         img_feat = F.relu(self.image_bn3(self.image_conv3(img_feat)))

#         return img_feat

In [14]:
# class PointCloudDecoder(nn.Module):
#     def __init__(self, in_channels, num_points=1024):
#         super(PointCloudDecoder, self).__init__()
        
#         # Decoder with ConvTranspose layers to upsample and decode into a point cloud
#         self.conv1 = nn.ConvTranspose2d(in_channels, 256, kernel_size=4, stride=2, padding=1)
#         self.conv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
#         self.conv3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
#         self.conv4 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        
#         # Final layer to output (num_points * 3) channels for the 3D coordinates
#         self.conv_final = nn.Conv2d(256, num_points * 3, kernel_size=3, padding=1)
        
#     def forward(self, x):
#         # Upsample with ConvTranspose layers
#         x = F.relu(self.conv1(x))
#         print("Decoding after first layer -->", x.shape)
#         x = F.relu(self.conv2(x))
#         print("Decoding after second layer -->", x.shape)
#         x = F.relu(self.conv3(x))
#         print("Decoding after third layer -->", x.shape)
#         x = F.relu(self.conv4(x))
#         print("Decoding after fourth layer -->", x.shape)
        
#         # Final point cloud prediction
#         point_cloud = self.conv_final(x)  # Output shape: (batch_size, num_points*3, H, W)
        
#         # Reshape to (batch_size, num_points, 3) for point cloud output
#         point_cloud = point_cloud.view(point_cloud.size(0), -1, 3)
        
#         return point_cloud

In [15]:
# class PointCloudPredictor(nn.Module):
#     def __init__(self, num_points=1024):
#         super(PointCloudPredictor, self).__init__()
#         self.encoder = ImageDepthEncoder()
#         self.decoder = PointCloudDecoder(in_channels=256, num_points=num_points)  # Combined feature dim from encoder
    
#     def forward(self, image, depth):
#         # Encode image and depth features
#         combined_features = self.encoder(image)
#         print("After encoding -->", combined_features.shape)
        
#         # Decode features into point cloud
#         point_cloud = self.decoder(combined_features)
#         print("After decoding -->", point_cloud.shape)
#         return point_cloud

In [16]:
# model = PointCloudPredictor()

In [17]:
# class Encoder(nn.Module):
#     def __init__(self):
#         super(Encoder, self).__init__()
#         # Convolutional layers for encoding
#         self.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=2, padding=1)  # Input: (4, 256, 256), Output: (64, 128, 128)
#         self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  # Output: (128, 64, 64)
#         self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  # Output: (256, 32, 32)
#         self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)  # Output: (256, 16, 16)
#         self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)  # Output: (256, 8, 8)
#         self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)  # Output: (256, 4, 4)
#         self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)  # Output: (256, 2, 2)
#         self.conv8 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)  # Output: (256, 1, 1)

#     def forward(self, x):
#         x = F.relu(self.conv1(x))
#         x = F.relu(self.conv2(x))
#         x = F.relu(self.conv3(x))
#         x = F.relu(self.conv4(x))
#         x = F.relu(self.conv5(x))
#         x = F.relu(self.conv6(x))
#         x = F.relu(self.conv7(x))
#         x = F.relu(self.conv8(x))
#         return x  # Shape: (B, 256, 1, 1)

In [18]:
# class Decoder(nn.Module):
#     def __init__(self):
#         super(Decoder, self).__init__()
#         # Fully connected layers to transform the encoded feature map to a point cloud
#         self.fc1 = nn.Linear(256, 1024 * 3)  # Transform from 256 -> 1024*3

#     def forward(self, x):
#         x = x.view(x.size(0), -1)  # Flatten (B, 256, 1, 1) -> (B, 256)
#         x = F.relu(self.fc1(x))  # Fully connected layer to get (B, 1024*3)
#         x = x.view(-1, 1024, 3)  # Reshape to (B, 1024, 3)
#         return x

In [19]:
# class PointCloudPredictionNet(nn.Module):
#     def __init__(self):
#         super(PointCloudPredictionNet, self).__init__()
#         self.encoder = Encoder()
#         self.decoder = Decoder()

#     def forward(self, image):
#         encoded = self.encoder(x)
#         point_cloud = self.decoder(encoded)
#         return point_cloud  # Output shape: (B, 1024, 3)

In [20]:
# model = PointCloudPredictionNet()

In [21]:
class ImageFeatureNet(nn.Module):
    """Convolutional encoder mapping a batch of RGB images to 256-dim feature vectors.

    Three conv blocks (5x5 conv, stride 2, padding 2 -> BatchNorm -> ReLU) each shrink a
    spatial side of length n to ceil(n / 2); the final map is flattened and projected by ``fc1``.

    Args:
        input_size (int or tuple of int): Input image size, either one side length for square
            images or ``(height, width)``. It determines the input width of ``fc1``; the
            default 224 reproduces the original ``256 * 28 * 28``.
    """

    def __init__(self, input_size=224):
        super(ImageFeatureNet, self).__init__()
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.bn3 = nn.BatchNorm2d(256)
        # Each conv maps a side n to floor((n + 2*2 - 5) / 2) + 1 == ceil(n / 2).
        for _ in range(3):
            height, width = (height + 1) // 2, (width + 1) // 2
        self.fc1 = nn.Linear(256 * height * width, 256)

    def forward(self, x):
        """Encode images of shape (batch, 3, H, W), H/W matching ``input_size``, to (batch, 256)."""
        x = F.relu(self.bn1(self.conv1(x)))  # (batch, 64, ceil(H/2), ceil(W/2)); 112x112 for 224 input
        x = F.relu(self.bn2(self.conv2(x)))  # (batch, 128, ...); 56x56 for 224 input
        x = F.relu(self.bn3(self.conv3(x)))  # (batch, 256, ...); 28x28 for 224 input
        x = torch.flatten(x, 1)              # (batch, 256 * h * w)
        return F.relu(self.fc1(x))           # (batch, 256)

In [22]:
class PointCloudDecoder(nn.Module):
    """MLP head turning a 256-dim feature vector into ``num_points`` xyz coordinates.

    Layout: 256 -> 512 -> 1024 -> num_points * 3, with ReLU after the first two layers and a
    linear output that is reshaped to (batch, num_points, 3).
    """

    def __init__(self, num_points=1024):
        super().__init__()
        self.num_points = num_points
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 1024)
        # unbounded output: one (x, y, z) triple per point
        self.fc3 = nn.Linear(1024, num_points * 3)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        coords = self.fc3(hidden)                       # (batch, num_points * 3)
        return coords.reshape(-1, self.num_points, 3)  # (batch, num_points, 3)

In [23]:
class FullModel(nn.Module):
    """Shared image encoder feeding two heads: category logits and a predicted point cloud.

    Args:
        num_classes (int): Number of logits produced by the classification head.
        num_points (int): Number of points produced by the reconstruction head.
    """

    def __init__(self, num_classes, num_points=1024):
        super().__init__()
        # shared backbone: image -> (batch, 256) features
        self.image_feature_net = ImageFeatureNet()
        # head 1: (batch, num_classes) category logits
        self.classification_fc = nn.Linear(256, num_classes)
        # head 2: (batch, num_points, 3) coordinates
        self.point_cloud_decoder = PointCloudDecoder(num_points=num_points)

    def forward(self, image):
        """Return ``(class_logits, point_cloud)`` for a batch of images."""
        features = self.image_feature_net(image)
        return self.classification_fc(features), self.point_cloud_decoder(features)

In [24]:
# Seed torch's global RNG right before construction so weight initialisation is reproducible.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)
# num_points must match the number of points sampled per mesh when loading the data (1024).
model = FullModel(num_classes=n_categories, num_points=1024)

### Train

In [25]:
# Cross-entropy over the category logits; Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [26]:
model  # displays the module tree with each layer's configuration

FullModel(
  (image_feature_net): ImageFeatureNet(
    (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc1): Linear(in_features=200704, out_features=256, bias=True)
  )
  (classification_fc): Linear(in_features=256, out_features=8, bias=True)
  (point_cloud_decoder): PointCloudDecoder(
    (fc1): Linear(in_features=256, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=1024, bias=True)
    (fc3): Linear(in_features=1024, out_features=3072, bias=True)
  )
)

In [27]:
def _chamfer_distance(pred, target):
    """Symmetric Chamfer distance between two batched, unordered point sets.

    Args:
        pred: (batch, n, 3) predicted points.
        target: (batch, m, 3) ground-truth points.

    Returns:
        Scalar tensor: batch mean of (mean squared distance from each predicted point to its
        nearest target point + mean squared distance from each target point to its nearest
        predicted point).
    """
    diff = pred.unsqueeze(2) - target.unsqueeze(1)           # (batch, n, m, 3)
    sq_dists = (diff ** 2).sum(dim=-1)                       # (batch, n, m)
    pred_to_target = sq_dists.min(dim=2).values.mean(dim=1)  # each prediction -> nearest target
    target_to_pred = sq_dists.min(dim=1).values.mean(dim=1)  # each target -> nearest prediction
    return (pred_to_target + target_to_pred).mean()


num_epochs = 2
POINT_LOSS_WEIGHT = 1.0  # weight of the reconstruction loss relative to the classification loss

for epoch in range(num_epochs):
    model.train()
    running_loss, n_batches = 0.0, 0
    for batch in data_loader:
        images_batch = batch['image']
        depth_maps_batch = batch['mask']
        point_clouds_batch = batch['model']
        category = batch['category']

        # Stored clouds are batched as (batch, 1, n_points, 3); flatten the singleton axis and
        # cast to float32 to match the decoder output (the stored dtype may differ).
        target_points = point_clouds_batch.reshape(point_clouds_batch.size(0), -1, 3).float()

        optimizer.zero_grad()
        class_output, point_cloud_output = model(images_batch)

        # Train both heads: with only the classification loss, the point-cloud decoder
        # never receives gradients and its output stays at random initialisation.
        class_loss = criterion(class_output, category)
        point_loss = _chamfer_distance(point_cloud_output, target_points)
        loss = class_loss + POINT_LOSS_WEIGHT * point_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    print(f"epoch {epoch + 1}/{num_epochs}: mean loss {running_loss / max(n_batches, 1):.4f}")

In [28]:
def visuzlie_point_cloud(data):
    """Open an Open3D viewer window showing the given point cloud.

    Args:
        data: numpy array, array-like or torch tensor of xyz coordinates. Leading singleton
            axes (e.g. the (1, N, 3) clouds stored in ``models``) are flattened to (N, 3).
    """
    # Draw `data` itself. The previous version ignored its argument and always drew the
    # global `point_cloud_example`.
    if hasattr(data, "detach"):  # torch tensor: drop autograd history and move to CPU
        data = data.detach().cpu().numpy()
    # Vector3dVector expects an (N, 3) float64 array
    points = np.asarray(data, dtype=np.float64).reshape(-1, 3)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    o3d.visualization.draw_geometries([pcd])


# Correctly spelled alias; the original (misspelled) name is kept for existing calls.
visualize_point_cloud = visuzlie_point_cloud

In [29]:
# First sample of the last training batch. The ground truth is stored as (1, N, 3), so the
# last (only) slice along axis 0 gives (N, 3); the prediction is already (N, 3).
point_cloud_example = point_clouds_batch[0][-1].cpu().numpy()
point_cloud_output_exampole = point_cloud_output.detach()[0].cpu().numpy()

In [32]:
# visuzlie_point_cloud(point_cloud_example)

In [33]:
# visuzlie_point_cloud(point_cloud_output_exampole)