In [15]:
import os
import random

from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader  # PyG loader: collates Data objects into a Batch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import random_split
from torch.cuda.amp import GradScaler, autocast

import logging
# Log to sharpnet.log (basicConfig attaches a file handler to the root logger)
# and echo messages to the notebook through a StreamHandler on this logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", filename="sharpnet.log")
logger = logging.getLogger(__name__)
# Guard the echo handler: getLogger returns the same object on every cell re-run,
# so an unconditional addHandler would stack handlers and duplicate each message.
# (FileHandler subclasses StreamHandler, so exclude it explicitly.)
if not any(
    isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler)
    for handler in logger.handlers
):
    logger.addHandler(logging.StreamHandler())

from utils import plot_3d_shape, save_point_cloud_ply, ResamplePoints, save_model, chamfer_loss, mmd_cd_loss, chamfer_loss_eval
from models import ResnetGenerator, NLayerDiscriminator

# Path to the dataset
DATASET_PATH = "./shapenetcore_partanno_segmentation_benchmark_v0_normal"
SPLIT_RATIO = 0.8  # fraction of samples assigned to the training split
BATCH_SIZE = 16
EPOCHS = 1  # use 50 for a full training run

# Reproducibility: dataset.shuffle(), random_split and the random transforms
# draw from torch's global RNG; some PyG transforms (e.g. RandomFlip) use
# Python's `random`. Seed both so Restart & Run All gives the same split.
# NOTE(review): if utils.ResamplePoints samples with numpy, seed numpy as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Optimizers (Adam betas and learning rate)
BETA1 = 0.5
BETA2 = 0.999
LEARNING_RATE = 0.0002

# Initialize networks
INPUT_NC = 3   # presumably xyz coordinates in — confirm against models.py
OUTPUT_NC = 3  # presumably xyz coordinates out — confirm against models.py
NGF = 64
NDF = 64
NUM_BLOCKS = 6

# Training parameters
ACCUMULATION_STEPS = 1
# NOTE(review): the loaders use BATCH_SIZE (16); this lowercase value is not used in
# the visible cells and shadows the intent of BATCH_SIZE — confirm before relying on it.
batch_size = 1  # Start with a very small batch size
TARGET_SIZE = 64  # Reduced target size for smaller GPU usage

# Output directories (created if missing)
MODEL_PATH = "./saved_models"
os.makedirs(MODEL_PATH, exist_ok=True)
PLOT_PATH = "./plots"
os.makedirs(PLOT_PATH, exist_ok=True)



# Load a shuffled subset of the ShapeNet "Airplane" category; the slice keeps
# iteration fast while developing.
NUM_SAMPLES = 100  # raise (or drop the slice) for a full run
dataset = ShapeNet(root=DATASET_PATH, categories=["Airplane"]).shuffle()[:NUM_SAMPLES]
# Multi-category alternative:
# dataset = ShapeNet(root=DATASET_PATH, categories=["Table", "Lamp", "Guitar", "Motorbike"]).shuffle()[:1000]

sample = dataset[0]  # index once and reuse for the logging below
logger.info(f"Number of Samples: {len(dataset)}")
logger.info(f"Sample: {sample}")
logger.info(f"Number of points: {sample.pos.shape[0]}, Dimension of each point: {sample.pos.shape[1]}")

#%%
# Visualize a sample by exporting it to a .ply file. Clouds are not resampled at
# this point, so each one has its own point count (hence the disabled plot call).
sample_idx = 9
vis_sample = dataset[sample_idx]  # fetch once instead of re-indexing the dataset three times
# plot_3d_shape(vis_sample)  # NOTE: each data point has a different number of points
save_point_cloud_ply(vis_sample, os.path.join(PLOT_PATH, f"point_cloud_{sample_idx}.ply"))
logger.info(f"Number of points: {vis_sample.pos.shape[0]}, Dimension of each point: {vis_sample.pos.shape[1]}")
#%%
# Data Augmentation
# Every cloud is resampled to a fixed size so batches can be reshaped densely.
# Jitter / flip / shear are random and belong to training only; applying them to
# the evaluation split would make test metrics noisy and non-reproducible.
NUM_POINTS = 2048
augmentation = T.Compose([
    ResamplePoints(NUM_POINTS),
    T.RandomJitter(0.03),
    T.RandomFlip(axis=1),
    T.RandomShear(0.2),
])
# NOTE(review): assumes ResamplePoints only fixes the point count; if it samples
# randomly, evaluation is still slightly stochastic — consider a fixed seed there.
eval_transform = ResamplePoints(NUM_POINTS)

# Kept for code that indexes `dataset` directly; the splits below carry their own transforms.
dataset.transform = augmentation

#%%
# DataLoader

# Split the dataset into train and test sets (SPLIT_RATIO train, remainder test).
train_size = int(SPLIT_RATIO * len(dataset))
test_size = len(dataset) - train_size
# torch.utils.data.random_split returns Subsets that share `dataset.transform`, which
# leaked the training augmentation into the test split. Indexing the PyG dataset
# with a permutation instead yields shallow copies whose `transform` can differ.
perm = torch.randperm(len(dataset))
train_dataset = dataset[perm[:train_size]]
test_dataset = dataset[perm[train_size:]]
train_dataset.transform = augmentation
test_dataset.transform = eval_transform

# Create data loaders for train and test sets
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Verify the batch sizes by peeking at the first batch of each loader
for loader_name, loader in (("Train", train_loader), ("Test", test_loader)):
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(f"{loader_name} Batch: ", first_batch)

logger.info(f"Number of Training Samples: {len(train_dataset)}")
logger.info(f"Number of Test Samples: {len(test_dataset)}")
logger.info(f"Train dataset[0]: {train_dataset[0]}")
logger.info(f"Test dataset[0]: {test_dataset[0]}")


Number of Samples: 100
Number of Samples: 100
Sample: Data(x=[2483, 3], y=[2483], pos=[2483, 3], category=[1])
Sample: Data(x=[2483, 3], y=[2483], pos=[2483, 3], category=[1])
Number of points: 2483, Dimension of each point: 3
Number of points: 2483, Dimension of each point: 3
Number of points: 2635, Dimension of each point: 3
Number of points: 2635, Dimension of each point: 3
Number of Training Samples: 80
Number of Training Samples: 80
Number of Test Samples: 20
Number of Test Samples: 20
Train dataset[0]: Data(x=[2048, 3], y=[2048], pos=[2048, 3], category=[1])
Train dataset[0]: Data(x=[2048, 3], y=[2048], pos=[2048, 3], category=[1])
Test dataset[0]: Data(x=[2048, 3], y=[2048], pos=[2048, 3], category=[1])
Test dataset[0]: Data(x=[2048, 3], y=[2048], pos=[2048, 3], category=[1])


Point cloud saved to ./plots/point_cloud_9.ply
Train Batch:  DataBatch(x=[32768, 3], y=[32768], pos=[32768, 3], category=[16], batch=[32768], ptr=[17])
Test Batch:  DataBatch(x=[32768, 3], y=[32768], pos=[32768, 3], category=[16], batch=[32768], ptr=[17])


In [17]:
# Inspect one training batch and convert it to the dense (B, 3, 1, N) layout
# consumed by the image-style generator/discriminator.
train_batch = next(iter(train_loader))
print("Train Batch: ", train_batch)
print("Train Batch pos shape: ", train_batch.pos.shape)
print("Train Batch batch-vector shape: ", train_batch.batch.shape)

# The last batch can be smaller than BATCH_SIZE, so use the actual graph count.
num_graphs = train_batch.num_graphs
points_per_graph = train_batch.pos.size(0) // num_graphs
# A dense reshape is only valid when every cloud in the batch has the same size.
assert bool((train_batch.ptr.diff() == points_per_graph).all()), "point clouds in the batch differ in size"
# PyG stacks clouds along dim 0, so `pos` is laid out as (B, N, 3) in memory.
# Viewing it directly as (B, 3, N) would interleave x/y/z across different points;
# reshape to (B, N, 3) first and then move the coordinate axis.
pos_dense = train_batch.pos.view(num_graphs, points_per_graph, 3)
pos_image = pos_dense.permute(0, 2, 1).unsqueeze(2).contiguous()  # (B, 3, 1, N)
print(pos_image.shape)


Train Batch:  DataBatch(x=[32768, 3], y=[32768], pos=[32768, 3], category=[16], batch=[32768], ptr=[17])
Train Batch pos shape:  torch.Size([32768, 3])
Train Batch y shape:  torch.Size([32768])
torch.Size([16, 3, 1, 2048])
