In [None]:
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm

from pointnet_utils import PointNetEncoder, feature_transform_reguliarzer


## Write a Dataset Class to generate Cylinder Dataset, with different poses and augmentations. 



class BarrelNet(nn.Module):
    """Regress a cylinder's radius and axis direction from a point cloud.

    Two ``PointNetEncoder`` backbones (one without and one with a T-Net) each
    produce a global feature; the two features are averaged and passed through
    a 1024 -> 512 -> 256 -> ``k`` fully connected head.

    Args:
        k (int): Width of the final linear layer. ``forward`` reads outputs
            0..2 as the axis and output ``-1`` as the radius, so ``k`` should
            be at least 4 for the two predictions to use distinct outputs.
        normal_channel (bool): If True the encoders are built for 6 input
            channels, otherwise for 3.
    """

    def __init__(self, k=40, normal_channel=True):
        super(BarrelNet, self).__init__()
        channel = 6 if normal_channel else 3
        # Submodules are registered in the same order as before so that the
        # parameter initialisation order and state_dict layout are unchanged.
        self.feat_normal = PointNetEncoder(global_feat=True, use_Tnet=False, channel=channel)
        self.feat_radius = PointNetEncoder(global_feat=True, use_Tnet=True, channel=channel)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Predict ``(radius, normal)`` for a batch of point clouds.

        Args:
            x: Batch of point clouds, passed unchanged to both encoders.
                Presumably shaped (batch, channel, num_points) -- TODO confirm
                against ``PointNetEncoder``.

        Returns:
            tuple: ``radius`` of shape (batch,) with values in (0, 1), and
            ``normal`` of shape (batch, 3) scaled to unit length. The third
            component of ``normal`` comes from a sigmoid, so it is never
            negative.
        """
        # Only the first element of each encoder's 3-tuple is used; it must be
        # 1024-wide to match fc1.
        xn, _, _ = self.feat_normal(x)
        xr, _, _ = self.feat_radius(x)
        x = (xn + xr) / 2
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        radius = torch.sigmoid(x[:, -1])
        # tanh lets the first two components take either sign; sigmoid keeps
        # the third one positive.
        normal = torch.cat([torch.tanh(x[:, :2]), torch.sigmoid(x[:, 2:3])], dim=1)
        # F.normalize clamps the norm with a small eps, so a vanishing vector
        # cannot produce inf/nan the way a raw division would.
        normal = F.normalize(normal, p=2, dim=-1)
        return radius, normal
    
    

In [None]:
import trimesh

def create_capped_cylinder(radius, height, sections=32):
    """
    Build a cylinder mesh through ``trimesh.creation.cylinder`` with ``cap=True``.

    Args:
        radius (float): Cylinder radius.
        height (float): Cylinder height.
        sections (int): Number of sections used to approximate the round profile.

    Returns:
        trimesh.Trimesh: Whatever ``trimesh.creation.cylinder`` returns for these settings.
    """
    mesh_kwargs = {
        "radius": radius,
        "height": height,
        "sections": sections,
        "cap": True,  # request end caps
    }
    return trimesh.creation.cylinder(**mesh_kwargs)

# Example usage: a cylinder of radius 1 and height 4 with the default section count.
radius, height = 1.0, 4.0
capped_cylinder = create_capped_cylinder(radius=radius, height=height)

# Visualization (optional)
capped_cylinder.show()

In [None]:
import os
import torch
import torch.optim as optim
import numpy as np
import roma
from data import generate_cylinder_pts, prepare_point_cloud, normalize_pc, CylinderData
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import Dataset, DataLoader

    
# Cylinder datasets: 5000 poses for fitting, 1000 for evaluation.
# Construction order is kept (train first) in case CylinderData draws random numbers.
train_data = CylinderData(num_poses=5000)
test_data = CylinderData(num_poses=1000)

# Shuffled mini-batches of 32 for training; one pose at a time, in fixed order, for testing.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

# 3-channel input (normal_channel=False) and a 4-wide output head, placed on the GPU.
# NOTE(review): the training loop passes sample['pts'] to the model without moving it,
# so CylinderData presumably already yields CUDA tensors -- confirm.
pointnet = BarrelNet(k=4, normal_channel=False)
pointnet = pointnet.cuda()
pointnet.train()

# Adam with a small learning rate; the scheduler multiplies it by 0.9 every 2000 scheduler steps.
optimizer = optim.Adam(params=pointnet.parameters(), lr=5e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=2000, gamma=0.9)

def compute_loss(sample, radius_pred, axis_pred, use_radius_loss=False, use_axis_loss=True):
    """Compute loss on the predictions of pointnet.

    Only the enabled terms are evaluated, so a disabled term neither costs
    compute nor requires its ground-truth keys to be present in ``sample``.

    Args:
        sample (dict): Batch dictionary. ``'axis_vec'`` is read when
            ``use_axis_loss`` is set; ``'radius_gt'`` and ``'scale_gt'`` are
            read when ``use_radius_loss`` is set.
        radius_pred: Predicted radius, compared against
            ``radius_gt / scale_gt`` with a mean squared error.
        axis_pred: Predicted axis, compared against ``axis_vec`` with
            ``1 - cosine_similarity`` along ``dim=1``, averaged over the batch.
        use_radius_loss (bool): Include the radius term.
        use_axis_loss (bool): Include the axis term.

    Returns:
        Scalar tensor: the sum of the enabled terms.

    Raises:
        AssertionError: If both terms are disabled.
    """
    assert use_axis_loss or use_radius_loss, "Atleast one of the losses should be used"
    loss = 0.0
    if use_radius_loss:
        # The ground-truth radius is divided by the sample's scale before comparison.
        radius_target = sample['radius_gt'] / sample['scale_gt']
        loss = loss + F.mse_loss(radius_pred, radius_target)
    if use_axis_loss:
        axis_cos = F.cosine_similarity(sample['axis_vec'], axis_pred, dim=1)
        loss = loss + (1 - axis_cos).mean()
    return loss

def train(model, train_loader, optimizer, scheduler, num_epochs=10000, save_epoch=1000, save_dir='weights/r0'):
    """Train ``model`` on the axis loss and periodically checkpoint it.

    Args:
        model: Network called as ``model(sample['pts'])`` and expected to
            return ``(radius_pred, axis_pred)``.
        train_loader: Iterable of batch dictionaries; must support ``len()``
            for the average-loss report.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs (int): Number of passes over ``train_loader``.
        save_epoch (int): A checkpoint is written whenever
            ``epoch % save_epoch == 0`` (this includes epoch 0).
        save_dir (str): Directory for checkpoints; created if missing.

    Returns:
        None. Checkpoints are written to
        ``<save_dir>/pointnet_iter<epoch>.pth``.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.train()  # Set model to training mode
    for epoch in range(num_epochs):
        running_loss = 0.0
        for sample in train_loader:
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass; only the axis term is optimised here.
            radius_pred, axis_pred = model(sample['pts'])
            loss = compute_loss(sample, radius_pred, axis_pred, use_radius_loss=False, use_axis_loss=True)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Step the scheduler at the end of each epoch
        scheduler.step()
        if epoch % 50 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, LR: {scheduler.get_last_lr()[0]}')

        if epoch % save_epoch == 0:
            # Checkpoint the model that was passed in, not a notebook-level global.
            torch.save(model.state_dict(), os.path.join(save_dir, f'pointnet_iter{epoch}.pth'))

# Run training with train()'s defaults (10000 epochs, a checkpoint every 1000 epochs
# under 'weights/r0'), then store the final weights next to the checkpoint folder.
train(pointnet, train_loader, optimizer, scheduler)
torch.save(pointnet.state_dict(), 'weights/pointnet.pth')

# # Example usage

# points = generate_cylinder_pts(1.0, 1.0)
# axis = torch.tensor([0.0, 0.0, 1.0])
# points_valid, radius, burial_offset = prepare_point_cloud(points, axis)
# pts, scale = normalize_pc(points_valid)
# point_cloud_np = points.cpu().numpy()
# point_cloud_rotated_np = points_valid.cpu().numpy()

# fig = plt.figure()
# ax1 = fig.add_subplot(121, projection='3d')
# ax1.scatter(point_cloud_np[:, 0], point_cloud_np[:, 1], point_cloud_np[:, 2], s=1)
# ax1.set_title("Original Point Cloud")

# ax2 = fig.add_subplot(122, projection='3d')
# ax2.scatter(point_cloud_rotated_np[:, 0], point_cloud_rotated_np[:, 1], point_cloud_rotated_np[:, 2], s=1)
# ax2.set_title("Rotated Point Cloud")

In [36]:
from tqdm import tqdm
def test(model, test_loader):
    model.eval()
    running_acc_axis = 0.0
    accs_axis = []
    criterion_cosine = nn.CosineSimilarity(dim=1)
    for sample in tqdm(test_loader):
        with torch.no_grad():
            radius_pred, axis_pred = model(sample['pts'])
            radius_pred = radius_pred * sample['scale_gt'].cuda()
            acc_axis = (1 + criterion_cosine(sample['axis_vec'], axis_pred).mean())/2
            running_acc_axis += acc_axis.item()
            accs_axis.append(acc_axis)
    accs = torch.tensor(accs_axis)
    print(f"Best Accuracy: {torch.max(accs)}")
    print(f"Worst Accuracy: {torch.min(accs)}")
    print(f'Average Accuracy: {running_acc_axis / len(test_loader)}')


# Score the trained network on the evaluation loader (batch_size=1, so one score per pose).
test(pointnet, test_loader)


100%|██████████| 1000/1000 [00:03<00:00, 253.10it/s]

Best Accuracy: 0.9999985694885254
Worst Accuracy: 0.0026373863220214844
Average Accuracy: 0.9705951552689075



