In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch3d as p3d
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.io import load_obj, load_ply, save_ply
from pytorch3d.loss import chamfer_distance

In [2]:
### config

test_date = "10-02-1132"
# 定义点云数据的输入维度和输出维度
input_dim = 3  # 每个点的特征维度
output_dim = 3  # 输出点云的特征维度（可以与输入维度相同）

device = "cuda" if torch.cuda.is_available() else "cpu"
num_epochs = 200
batch_size = 15

sxxxxes = os.listdir("../../data/all_results/")
train_sxxxxes = sxxxxes[:372]
test_sxxxxes = sxxxxes[372:]

In [3]:
# 创建一个简单的全连接神经网络类
class PointCloudFCNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PointCloudFCNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)  # 输入层到隐藏层1
        self.fc2 = nn.Linear(64, 256)  # 隐藏层1到隐藏层2
        self.fc3 = nn.Linear(256, 512)  # 隐藏层2到隐藏层3
        self.fc4 = nn.Linear(512, 256)  # 隐藏层3到隐藏层4
        self.fc5 = nn.Linear(256, output_dim)  # 隐藏层4到输出层

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        x = self.fc5(x)  # 输出层不使用激活函数
        return x


# DataSet类
class TrainingSet(Dataset):
    def __init__(self, s_name_list, transform=None, sample_num=5000) -> None:
        self.input = [
            f"../../data/all_results/{sname}/0_sphere.obj" for sname in s_name_list
        ]
        self.target = [
            f"../../data/all_results/{sname}/{sname}.ply" for sname in s_name_list
        ]
        self.sample_num = sample_num
        self.transform = transform

    def __len__(self):
        return len(self.input)

    def __getitem__(self, index):
        #对input采样5000个点
        input_tensor_points, input_tensor_faces, _ = load_obj(self.input[index], load_textures=False)
        input_mesh = Meshes([input_tensor_points], [input_tensor_faces.verts_idx]) #注意这里有个.verts_idx
        input_tensor_points_sampled = sample_points_from_meshes(input_mesh,5000).squeeze(0)
        #对output采样5000个点
        target_tensor_points, target_tensor_faces = load_ply(self.target[index])
        target_mesh = Meshes([target_tensor_points],[target_tensor_faces])
        target_tensor_points_sampled = sample_points_from_meshes(target_mesh).squeeze(0)

        return input_tensor_points_sampled.to(device),target_tensor_points_sampled.to(device)

    def get_target(self, input):
        raise NotImplemented

In [4]:
#################################################################################
# pts,faces = load_ply(f"../../data/all_results/s0004_pulmonary_artery.nii.g_1/s0004_pulmonary_artery.nii.g_1.ply")
# mymesh = Meshes([pts],[faces])
# a = sample_points_from_meshes(mymesh,5)
# a = a.squeeze(0)
# a
# mymesh
#################################################################################

In [5]:
def plot_pointcloud(points, title=""):
    """Sample points uniformly from the surface of the mesh."""
    x, y, z = points.clone().detach().cpu().squeeze().unbind(1)
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter3D(x, z, -y)
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    ax.set_zlabel("y")
    ax.set_title(title)
    ax.view_init(190, 30)
    plt.show()

In [6]:
my_train_dataset = TrainingSet(train_sxxxxes)
data_loader = DataLoader(my_train_dataset, batch_size=batch_size, shuffle=False)


In [7]:
# 创建模型实例
model = PointCloudFCNet(input_dim, output_dim).to(device)
model

PointCloudFCNet(
  (fc1): Linear(in_features=3, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=512, bias=True)
  (fc4): Linear(in_features=512, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=3, bias=True)
)

In [8]:
# 定义损失函数和优化器
criterion = chamfer_distance
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用Adam优化器

In [9]:
next(iter(data_loader))

[tensor([[[-0.1819, -0.2157,  0.0700],
          [-0.1685,  0.2583,  0.1361],
          [ 0.4801, -0.2360, -0.0906],
          ...,
          [ 0.3722,  0.0026,  0.0917],
          [-0.2036,  0.3814,  0.0016],
          [-0.1303,  0.7276, -0.1171]],
 
         [[-0.0755,  0.6681, -0.3534],
          [-0.0879,  0.5676, -0.4760],
          [ 0.0429,  0.2966, -0.2191],
          ...,
          [-0.2823,  0.4366, -0.1093],
          [ 0.0848,  0.2530,  0.1054],
          [ 0.0589, -0.3290,  0.0875]],
 
         [[-0.1538, -0.3456,  0.0956],
          [-0.1062,  0.1477,  0.2380],
          [ 0.1707,  0.2002,  0.0335],
          ...,
          [ 0.8604, -0.0865, -0.0585],
          [-0.4789,  0.4589, -0.2089],
          [ 0.1806,  0.2717,  0.0115]],
 
         ...,
 
         [[-0.3271,  0.3688, -0.1441],
          [-0.1728,  0.1301, -0.0534],
          [-0.2358,  0.0214,  0.0781],
          ...,
          [-0.0597,  0.3891, -0.2774],
          [-0.0427,  0.1593, -0.1176],
          [ 0.5415

In [10]:
# 打开log文件准备写入训练结果
logf = open(f"./logs/{test_date}.txt","w")

# 训练模型
for epoch in range(num_epochs):
    for input_point_cloud,target_point_cloud in data_loader:
        optimizer.zero_grad()
        output_point_cloud = model(input_point_cloud)  # 前向传播
        loss,_ = criterion(output_point_cloud, target_point_cloud)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 优化模型参数

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}',file=logf)
    logf.flush()
    # if (epoch + 1) % 5 == 0:
    #     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

        # plot_pointcloud(output_point_cloud)# 绘制点云


logf.close()

In [11]:
# 保存模型

torch.save(model,f"./models/{test_date}.pt")

In [12]:
# 加载模型

model = torch.load(f"./models/{test_date}.pt")

In [13]:
# 使用训练好的模型生成新的点云数据
try_input_point_cloud,_,_ = load_obj("../../data/all_results/s1272_pulmonary_artery.nii.g_1/0_sphere.obj", load_textures=False)
new_point_cloud = model(try_input_point_cloud.to(device))
new_point_cloud
# 输出的new_point_cloud包含了经过神经网络处理后的新点云数据

tensor([[-0.0243,  0.0037,  0.3169],
        [-0.0324,  0.0678,  0.3108],
        [-0.0558,  0.0657,  0.3169],
        ...,
        [ 0.8447, -0.1887, -0.2681],
        [ 0.8248, -0.1783, -0.2700],
        [ 0.7911, -0.2224, -0.2758]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [14]:
#看一眼新的输出长啥样
save_ply(f"./plys/{test_date}.ply",new_point_cloud)