In [1]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import sys
sys.path.append("/content/drive/My Drive/GoogleColab/pytorch3d_packages")

In [7]:
# Imports
import os
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from torch.utils.data import DataLoader, random_split

from pytorch3d.datasets import ShapeNetCore
from pytorch3d.structures import Meshes
from pytorch3d.ops import GraphConv

# Paths
zip_path = "/content/drive/MyDrive/GoogleColab/ShapeNetCore.zip"
extract_path = "/content/ShapeNetCore/ShapeNetCore"

# Extract zip if not already extracted
if not os.path.exists(extract_path):
    print(f"Extracting {zip_path} to {extract_path} ...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall("/content/ShapeNetCore")
    print("Extraction completed")

# ClassNames
SYNSET_TO_NAME = {
    "02808440": "bathtub",
    "03642806": "laptop",
    "02992529": "cellphone",
    "03211117": "display",
    "03046257": "clock"
}

# Verify extraction
top_level_ids = os.listdir(extract_path)
top_level_names = [SYNSET_TO_NAME.get(syn, syn) for syn in top_level_ids]
print("Top-level classes:", top_level_names)

# Removes warning about other classes
warnings.filterwarnings(
    "ignore",
    message="The following categories are included in ShapeNetCore ver.2's official mapping.*"
)

# Dataset
def build_shapenet_dataset(root_dir, synsets=None):
    """
    Use built-in PyTorch3D ShapeNetCore loader.
    """
    categories = ["03642806", "03211117", "03046257", "02992529", "02808440"]
    dataset = ShapeNetCore(
        data_dir=root_dir,
        synsets=categories,
        version=2,
        load_textures=False # Textures are not needed for classification
    )
    print(f"Detected {len(dataset)} meshes across {len(dataset.synset_ids)} classes.")
    return dataset

# Model
class SimpleGraphCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = GraphConv(3, 64) # Input: 3D vertex coordinates -> 64 features
        self.conv2 = GraphConv(64, 128) # 64 -> 128 features
        self.conv3 = GraphConv(128, 256) # 128 -> 256 features
        self.fc1 = nn.Linear(256, 128) # Fully connected Layer
        self.fc2 = nn.Linear(128, num_classes) # Output layer

    def forward(self, meshes: Meshes):
        x = meshes.verts_packed() # All vertices of all meshes
        edges = meshes.edges_packed() # All edges of all meshes

        # Apply 3 GCN layers with ReLU activations
        x = F.relu(self.conv1(x, edges))
        x = F.relu(self.conv2(x, edges))
        x = F.relu(self.conv3(x, edges))

        # Compute per-mesh global feature by averaging vertex features
        num_verts_per_mesh = meshes.num_verts_per_mesh()
        batch_index = torch.cat([torch.full((n,), i, device=x.device)
                                 for i, n in enumerate(num_verts_per_mesh)])

        sum_features = torch.zeros(len(num_verts_per_mesh), x.size(1), device=x.device)
        sum_features.index_add_(0, batch_index, x)
        avg_features = sum_features / num_verts_per_mesh.view(-1, 1).float()

        # Apply fully connected layers for final class prediction
        x = F.relu(self.fc1(avg_features))
        x = self.fc2(x)
        return x

# Training Loop
def train_model(zip_path=zip_path, extract_path=extract_path, epochs=5, batch_size=4, lr=1e-3, device=None):
    device = device or torch.device("cpu")

    dataset = build_shapenet_dataset(extract_path)
    if len(dataset) == 0:
        raise RuntimeError("Dataset is empty — check .obj files and folder structure!")

    # Build a consistent label map from synset_id to class index
    synset_to_idx = {sid: i for i, sid in enumerate(dataset.synset_ids)}
    idx_to_name = {i: SYNSET_TO_NAME.get(sid, sid) for sid, i in synset_to_idx.items()}

    # --- Collate function compatible with ShapeNetCore ---
    def collate_mesh_batch(batch):
        verts_list = [item["verts"] for item in batch]
        faces_list = [item["faces"] for item in batch]
        # Fix warning by converting to long on device instead of creating a new tensor
        labels = torch.as_tensor([synset_to_idx[item["synset_id"]] for item in batch], dtype=torch.long)
        return {"mesh": Meshes(verts=verts_list, faces=faces_list), "labels": labels}

    # Split the dataset into training and validation (80/20)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_mesh_batch)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_mesh_batch)

    # Initialize the graph classifier, optimizer, and loss function
    model = SimpleGraphCNN(num_classes=len(dataset.synset_ids)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            meshes = batch['mesh'].to(device)
            labels = batch['labels'].to(device, dtype=torch.long)  # Fix warning
            optimizer.zero_grad()
            outputs = model(meshes)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        model.eval()
        val_loss, correct, total = 0, 0, 0
        with torch.no_grad():
            for batch in val_loader:
                meshes = batch['mesh'].to(device)
                labels = batch['labels'].to(device, dtype=torch.long)
                outputs = model(meshes)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_acc = 100. * correct / total

        # Print only validation loss and accuracy
        print(f"Epoch {epoch+1}/{epochs} - ValLoss: {val_loss/len(val_loader.dataset):.4f}, ValAcc: {val_acc:.2f}%")

    return model

#  Run training classifying 3D figures
if __name__ == "__main__":
    model = train_model(epochs=5, batch_size=4, lr=1e-3)


Top-level classes: ['laptop', 'display', 'bathtub', 'clock', 'cellphone']
Detected 3891 meshes across 3891 classes.
Epoch 1/5 - ValLoss: 1.0349, ValAcc: 46.73%
Epoch 2/5 - ValLoss: 0.7785, ValAcc: 68.93%
Epoch 3/5 - ValLoss: 0.6985, ValAcc: 74.58%
Epoch 4/5 - ValLoss: 0.6159, ValAcc: 77.92%
Epoch 5/5 - ValLoss: 0.6734, ValAcc: 75.99%
