# In[35]:
import parameter
import argparse
import sys
import torch
import os
from plyfile import PlyData
import numpy as np
import cv2
from utils.camera_utils import focal2fov

from utils.colmap_utils import (
    read_extrinsics_binary, 
    read_intrinsics_binary, 
    read_points3D_binary, 
    read_points3D_text,
    qvec2rotmat
)

def ParamReader(argv=None):
    if argv is None:
        argv = sys.argv[1:]
    
    parser = argparse.ArgumentParser()
    data = parameter.data
    for key, value in data.items():
        parser.add_argument(f'--{key}', default=value, help=f'{key} parameter')
    
    parser.add_argument("--detect_anomaly", action="store_true", default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])    
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
    parser.add_argument("--start_checkpoint", type=str, default=None)
    
    args = parser.parse_args(argv)
    return args

class PointClond:
    """Plain container for a point cloud's per-point attributes.

    ``positions``, ``colors`` and ``normals`` are stored exactly as passed in;
    nothing is copied, converted or validated.
    """

    def __init__(self, positions, colors, normals):
        self.positions, self.colors, self.normals = positions, colors, normals

class ImageInfo:
    """One registered image together with its camera matrices.

    The view matrix is built as [[R, T], [0, 0, 0, 1]] (translation in the
    last column), i.e. points are treated as column vectors.  A homogeneous
    world point ``p`` therefore reaches clip space as
    ``projMatrix @ viewMatrix @ p``, which is what ``viewProjMatrix`` stores.
    """

    def __init__(self, image_name, image_path, image, image_width, image_height,
                 R, T, fov_x, fov_y, device, znear=0.01, zfar=100.0):
        """Store the image and precompute view, projection and camera-center data.

        Args:
            image_name: name of the image.
            image_path: path the image was read from.
            image: numpy array; converted with ``torch.from_numpy`` and moved
                to ``device``.
            image_width, image_height: image size, stored as given.
            R: 3x3 rotation written into the upper-left of the view matrix.
            T: length-3 translation written into the last column of the view
                matrix.
            fov_x, fov_y: horizontal / vertical field of view in radians
                (``np.tan(fov / 2)`` is taken in ``getProjMatrix``).
            device: torch device for the image tensor and all matrices.
            znear, zfar: near / far clipping distances; the defaults are the
                values this class previously hard-coded.
        """
        self.image_name = image_name
        self.image_path = image_path
        self.image = torch.from_numpy(image).to(device)
        self.image_width = image_width
        self.image_height = image_height
        self.R = R
        self.T = T
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.zfar = zfar    # far clipping plane
        self.znear = znear  # near clipping plane

        self.viewMatrix = self.getViewMatrix(self.R, self.T).to(device)
        self.projMatrix = self.getProjMatrix(self.znear, self.zfar, self.fov_x, self.fov_y).to(device)
        # Column-vector convention: the view transform is applied first, so it
        # is the right-hand factor of the product.
        self.viewProjMatrix = self.projMatrix @ self.viewMatrix
        # For view = [R | T], inverse(view) = [R^T | -R^T T]; the top three
        # entries of its last column (index 3) are the camera center in world space.
        self.cameraCenter = torch.inverse(self.viewMatrix)[:3, 3]

    def getViewMatrix(self, R, T):
        """Return the 4x4 float32 matrix [[R, T], [0, 0, 0, 1]] (on the CPU)."""
        viewMatrix = torch.eye(4, dtype=torch.float32)
        # as_tensor accepts numpy arrays, lists or tensors without a copy warning.
        viewMatrix[:3, :3] = torch.as_tensor(R, dtype=torch.float32)
        viewMatrix[:3, 3] = torch.as_tensor(T, dtype=torch.float32)
        return viewMatrix

    def getProjMatrix(self, znear, zfar, fov_x, fov_y):
        """Return a symmetric 4x4 float32 perspective projection (column vectors).

        Camera-space depth ``z`` maps to NDC depth -1 at ``z = znear`` and +1
        at ``z = zfar``; the last row copies ``z`` into ``w``.
        """
        # left/right/bottom/top are the bounds of the near plane, obtained from
        # the tangent of half the field of view.
        top = np.tan(fov_y / 2) * znear
        bottom = -top
        right = np.tan(fov_x / 2) * znear
        left = -right

        projMatrix = torch.zeros((4, 4), dtype=torch.float32)
        projMatrix[0, 0] = 2 * znear / (right - left)
        projMatrix[1, 1] = 2 * znear / (top - bottom)
        projMatrix[2, 2] = (zfar + znear) / (zfar - znear)
        projMatrix[2, 3] = 2 * znear * zfar / (znear - zfar)
        projMatrix[3, 2] = 1.0
        return projMatrix
        
        
        
        
class GSDataLoader:
    """Load a COLMAP scene: per-image cameras (ImageInfo) and a PLY point cloud.

    Files read, relative to ``data_path``:
        sparse/0/images.bin    -- extrinsics, one entry per registered image
        sparse/0/cameras.bin   -- intrinsics, looked up by each image's camera_id
        sparse/0/points3D.ply  -- vertex element with x/y/z, red/green/blue, nx/ny/nz
        <reading_dir>/<image name>
    """

    def __init__(self, data_path, reading_dir, device):
        """Store the configuration and immediately load the scene.

        Args:
            data_path: dataset root directory (used exactly as given).
            reading_dir: image folder relative to ``data_path``; ``None``
                means "images".
            device: torch device forwarded to every ImageInfo.
        """
        self.data_path = data_path
        self.reading_dir = reading_dir
        self.device = device
        self.cameras = []   # ImageInfo for every image that loaded successfully
        self.points = None  # PointClond, filled in by loadColmap
        self.loadColmap(reading_dir)

    def loadColmap(self, reading_dir):
        """Populate ``self.points`` and append to ``self.cameras``.

        Images that cv2 cannot read are reported and skipped.

        Args:
            reading_dir: image folder relative to ``self.data_path``; ``None``
                means "images".

        Raises:
            ValueError: if a camera model is neither PINHOLE nor SIMPLE_PINHOLE.
        """
        if reading_dir is None:
            reading_dir = "images"
        sparse_dir = os.path.join(self.data_path, "sparse", "0")

        cam_extrinsics = read_extrinsics_binary(os.path.join(sparse_dir, "images.bin"))
        cam_intrinsics = read_intrinsics_binary(os.path.join(sparse_dir, "cameras.bin"))

        self.points = self._loadPointCloud(os.path.join(sparse_dir, "points3D.ply"))

        images_folder = os.path.join(self.data_path, reading_dir)
        for key in cam_extrinsics:
            extr = cam_extrinsics[key]
            intr = cam_intrinsics[extr.camera_id]
            fov_x, fov_y = self._computeFov(intr)

            image_path = os.path.join(images_folder, extr.name)
            image = cv2.imread(image_path)
            if image is None:
                # Best effort: skip unreadable images instead of aborting the load.
                print(f"Error loading image: {image_path}")
                continue
            # cv2.imread returns BGR channel order; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            self.cameras.append(ImageInfo(
                image_name=extr.name,
                image_path=image_path,
                image=image,
                image_width=intr.width,
                image_height=intr.height,
                # The qvec2rotmat result is passed through without transposing;
                # ImageInfo.getViewMatrix puts R and tvec into the view matrix as-is.
                R=qvec2rotmat(extr.qvec),
                T=extr.tvec,
                fov_x=fov_x,
                fov_y=fov_y,
                device=self.device,
            ))

    @staticmethod
    def _loadPointCloud(ply_path):
        """Read positions, colors (divided by 255.0) and normals from a PLY vertex element."""
        vertices = PlyData.read(ply_path)['vertex']
        positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
        normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
        return PointClond(positions, colors, normals)

    @staticmethod
    def _computeFov(intr):
        """Return ``(fov_x, fov_y)`` for an undistorted COLMAP camera using focal2fov."""
        if intr.model == "SIMPLE_PINHOLE":
            # A single focal length shared by both axes.
            focal_x = focal_y = intr.params[0]
        elif intr.model == "PINHOLE":
            focal_x, focal_y = intr.params[0], intr.params[1]
        else:
            # An explicit raise (not assert) so the check survives `python -O`.
            raise ValueError("Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!")
        return focal2fov(focal_x, intr.width), focal2fov(focal_y, intr.height)
    
# Smoke test: parse a fixed argument list, then load the dataset onto the CPU.
device = torch.device('cpu')
argv = [
    '--detect_anomaly',
    '--test_iterations', '100', '2000',
    '--save_iterations', '100', '2000',
]
args = ParamReader(argv)
data = GSDataLoader(args.source_path, args.images, device)
