File này: Huấn luyện model dựa trên dataset mới

In [1]:
import torch
from torch.utils.data import Dataset
import os
import numpy as np

class PointCloudDataset(Dataset):
    """Dataset of point-cloud samples stored as ``.npz`` files in one directory.

    Every ``.npz`` file directly inside ``data_dir`` is one sample. Each file
    must contain the arrays ``points``, ``center_gt`` and ``normal_gt``;
    ``__getitem__`` returns them as float32 tensors in that order.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # os.listdir order is arbitrary and platform-dependent; sorting keeps
        # sample indices stable across runs.
        self.samples = sorted(
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if f.endswith(".npz")
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(points, center_gt, normal_gt)``."""
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it once the arrays have been copied
        # into tensors (torch.tensor always copies).
        with np.load(self.samples[idx]) as data:
            points = torch.tensor(data["points"], dtype=torch.float32)        # Nx6
            center_gt = torch.tensor(data["center_gt"], dtype=torch.float32)  # (3,)
            normal_gt = torch.tensor(data["normal_gt"], dtype=torch.float32)  # (3,)

        return points, center_gt, normal_gt

In [2]:
import torch
import torch.nn as nn

class SimplePointNet(nn.Module):
    """Minimal PointNet-style regressor for a center point and a normal vector.

    A per-point MLP lifts each 6-dim point to 256 features, a max-pool over
    the point axis collapses the cloud into one global feature (independent
    of point order), and a small head regresses 6 numbers that are split
    into a 3-dim center and a 3-dim normal.
    """

    def __init__(self):
        super().__init__()

        # Per-point feature extractor; nn.Linear acts on the last dimension,
        # so the same weights are applied to every point.
        self.mlp = nn.Sequential(
            nn.Linear(6, 64), nn.ReLU(),
            nn.Linear(64, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU(),
        )

        # Head on the pooled feature: 3 center + 3 normal outputs.
        self.fc = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 6),
        )

    def forward(self, x):
        """Map ``x`` of shape [B, N, 6] to ``(center_pred, normal_pred)``, each [B, 3].

        The normal prediction is returned as-is (not normalized to unit length).
        """
        per_point = self.mlp(x)                    # [B, N, 256]
        global_feat = per_point.max(dim=1).values  # global max pool -> [B, 256]
        center_pred, normal_pred = self.fc(global_feat).split(3, dim=1)
        return center_pred, normal_pred

In [3]:
import torch.nn.functional as F

def compute_loss(center_pred, center_gt, normal_pred, normal_gt,
                 center_weight=1.0, normal_weight=1.0):
    """Combined loss for center regression and normal-direction prediction.

    Args:
        center_pred, center_gt: predicted / ground-truth centers, compared by MSE.
        normal_pred, normal_gt: predicted / ground-truth normals, compared by
            ``1 - cosine_similarity`` along the last dim (averaged over the batch).
        center_weight, normal_weight: multipliers for the two terms in the
            total. The terms have different units (squared distance vs. a
            value in [0, 2]), so equal weighting is not always appropriate.
            Defaults reproduce a plain sum.

    Returns:
        ``(total, loss_center, loss_normal)`` where ``total`` is
        ``center_weight * loss_center + normal_weight * loss_normal``.
    """
    loss_center = F.mse_loss(center_pred, center_gt)
    # Cosine similarity is scale-invariant, so normal_pred need not be unit length.
    # NOTE(review): n and -n yield loss 2 here; if ground-truth normals are not
    # consistently oriented, 1 - |cos| may be the intended measure — confirm.
    loss_normal = 1 - F.cosine_similarity(normal_pred, normal_gt, dim=-1).mean()

    total = center_weight * loss_center + normal_weight * loss_normal
    return total, loss_center, loss_normal

In [None]:
import torch
from torch.utils.data import DataLoader
from module6_dataset import PointCloudDataset
from module6_model import SimplePointNet
from module6_loss import compute_loss

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Samples produced by Module 5, served in shuffled mini-batches of 8.
# NOTE(review): the default collate stacks per-sample tensors, so every .npz
# presumably has the same number of points N — confirm, or pass a padding
# collate_fn.
dataset = PointCloudDataset("dataset")
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Model lives on `device`; Adam optimizes all of its parameters.
model = SimplePointNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# Train for a fixed number of epochs and report per-epoch mean losses.
NUM_EPOCHS = 10

# Explicit train mode: harmless for Linear/ReLU-only models, and required if
# dropout or batch-norm layers are ever added.
model.train()
for epoch in range(NUM_EPOCHS):
    total_sum = center_sum = normal_sum = 0.0
    num_samples = 0

    for points, center_gt, normal_gt in loader:
        points = points.to(device)      # [B, N, 6]
        center_gt = center_gt.to(device)
        normal_gt = normal_gt.to(device)

        center_pred, normal_pred = model(points)

        loss, lc, ln = compute_loss(center_pred, center_gt, normal_pred, normal_gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The losses are batch means; weight by batch size so the epoch figure
        # is a per-sample mean even when the last batch is smaller.
        batch_size = points.shape[0]
        total_sum += loss.item() * batch_size
        center_sum += lc.item() * batch_size
        normal_sum += ln.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        # Original code raised NameError on `loss` here when the loader was empty.
        print(f"Epoch {epoch}: no batches (empty dataset)")
        continue

    print(
        f"Epoch {epoch}: total={total_sum / num_samples:.4f}, "
        f"center={center_sum / num_samples:.4f}, "
        f"normal={normal_sum / num_samples:.4f}"
    )