In [22]:
import os
import sys

import torch.utils
import torch.utils.data
sys.path.append('./')
sys.path.append('../')
sys.path.insert(0, '../6DoF/')
import argparse
import json
import glob

import numpy as np
import torch
from PIL import Image
import open3d as o3d
import trimesh
import logging
from logging import getLogger as get_logger

from koolai_dataset import KoolAIPanoData, KoolAIPersData
from utils.typing import *
from utils.misc import get_device, todevice

logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

transformation matrix for visualization in opengl

In [19]:
R_gl_cv = np.asarray([
            [1.0,  0.0,  0.0],
            [0.0, -1.0,  0.0],
            [0.0,  0.0, -1.0],
        ])
R_cv_gl = np.linalg.inv(R_gl_cv)

In [20]:
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model on KoolAID dataset')
    parser.add_argument('--root_data_dir', type=str, default='/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/')
    parser.add_argument('--train_split_file', type=str, default='/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/train.txt')
    parser.add_argument('--test_split_file', type=str, default='/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/test.txt')
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--dataloader_num_workers', type=int, default=4)
    parser.add_argument('--skip_calc_scene_scale', action='store_true')
    parser.add_argument('--skip_save_scene_scale', action='store_true')
    
    return parser.parse_args()


def merge_split_files(file_path_lst: List[str], merge_file_path: str):
    with open(merge_file_path, 'w') as f:
        for file_path in file_path_lst:
            with open(file_path, 'r') as f_split:
                for line in f_split.readlines():
                    f.write(line)
                f.write('\n')

In [21]:

if __name__ == '__main__':    
    
    root_data_dir = '/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/'
    train_split_file = '/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/perspective_train.txt'
    test_split_file = '/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/perspective_test.txt'
    train_batch_size = 8
    dataloader_num_workers = 4
    skip_calc_scene_scale = False
    skip_save_scene_scale = False
    
    device = get_device()
    
    # in case the data was processed by kfp, it will results in many train.txt-0000xx files and test.txt-0000xx files
    train_file_paths = glob.glob(train_split_file + '*')
    test_file_paths = glob.glob(test_split_file + '*')
    
    assert len(train_file_paths) >= 1, f'found train split files: {train_file_paths}'
    if len(train_file_paths) > 1 and len(test_file_paths) > 1:
        if 'paorama_train' in train_split_file:
            save_train_split_file = train_split_file.replace('paorama_train.txt', 'panorama_train.txt')
        else:
            save_train_split_file = train_split_file
        save_test_split_file = test_split_file
        # merge split files into a single one
        merge_split_files(file_path_lst=train_file_paths, merge_file_path=save_train_split_file)
        merge_split_files(file_path_lst=test_file_paths, merge_file_path=save_test_split_file)
        # remove the original split files
        os.system(f'rm {train_split_file}-0000*')
        os.system(f'rm {test_split_file}-0000*')
        
        train_split_file = save_train_split_file
        test_split_file = save_test_split_file
    else:
        train_split_file = train_file_paths[0]
        test_split_file = test_file_paths[0]
    
    # prepare dataset
    T_in = 8
    T_out = 8
    train_dataset = KoolAIPersData(root_dir=root_data_dir,
                                   split_filepath=train_split_file,
                                    image_height=256, 
                                    image_width=256,
                                    total_view=16,
                                    validation=False,
                                    T_in=T_in,
                                    T_out=T_out,
                                    fix_sample=False,)
    
    # for training
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=train_batch_size,
        num_workers=dataloader_num_workers,
        collate_fn=train_dataset.collate
    )
    
    if not skip_calc_scene_scale:
    # calculate all rooms' scale
        room_scale_lst = []
        for idx in range(len(train_dataset)):
            room_uid = train_dataset.room_ids[idx]
            room_distance_scale = train_dataset.compute_room_scale(idx)
            room_scale_lst.append(torch.tensor([room_distance_scale], dtype=torch.float32))
            logger.info(f' room {room_uid} distance scale: {room_distance_scale}')

            
        room_scales = torch.cat(room_scale_lst, dim=0)    
        # draw room scale histogram
        import matplotlib.pyplot as plt
        plt.hist(room_scales.cpu().numpy(), bins=100)
        plt.show()
        
        # filter room_scales 
        import scipy.stats
        zscore = scipy.stats.zscore(room_scales.cpu().numpy())
        filtered_room_entries = np.abs(zscore) < 3
        logger.info(f'max zscore: {np.max(zscore)}')
        
        # we take 95% quantile as the new scene scale, to remove some rooms with invaliid scale
        # new_scene_scale = torch.quantile(room_scales, 0.99).item()    
        new_scene_scale = torch.max(room_scales[filtered_room_entries]).item()
        logger.info(f'training dataset room scale: {new_scene_scale}')
    else:
        # TODO: manually set the new_scene_scale
        new_scene_scale = 7.609875202178955
    new_scene_scale = 1.0 / new_scene_scale
    
    if not skip_save_scene_scale:
        for idx in range(len(train_dataset)):
            room_uid = train_dataset.room_ids[idx]
            room_folderpath = os.path.join(root_data_dir, room_uid)
            room_meta_filepath = os.path.join(room_folderpath, 'room_meta.json')
            logger.info(f'correct scene scale for room: {room_meta_filepath}')
            
            with open(room_meta_filepath, 'r') as f:
                room_meta = json.load(f)
                
            scale_mat = np.array(room_meta['scale_mat']).reshape(4,4).astype(np.float32)
            original_scale = float(room_meta['scale'])
            
            # savee corrected scale_mat
            camera_center = scale_mat[:3, 3] / original_scale
            
            new_scale_mat = np.eye(4).astype(np.float32)
            new_scale_mat[:3, 3] = camera_center
            new_scale_mat[:3] *= new_scene_scale
            
            room_meta['new_scale_mat'] = new_scale_mat.flatten().tolist()
            room_meta['new_scale'] = new_scene_scale
            json.dump(room_meta, open(room_meta_filepath, 'w'), indent=4)
            
            # visualize normalized camera poses
            pose_mesh = o3d.geometry.TriangleMesh()
            
            camera_metas = room_meta['cameras']
            for cam_idx in range(len(camera_metas)):     
                cam_meta = camera_metas[str(cam_idx)]             
                # w2c pose
                pose = np.array(cam_meta['camera_transform']).reshape(4, 4)
                c2w_pose = np.linalg.inv(pose)
                # scale pose_c2w
                c2w_pose = new_scale_mat @ c2w_pose
                R_c2w = c2w_pose[:3, :3]
                q_c2w = trimesh.transformations.quaternion_from_matrix(R_c2w)
                q_c2w = trimesh.transformations.unit_vector(q_c2w)
                R_c2w = trimesh.transformations.quaternion_matrix(q_c2w)[:3, :3]
                c2w_pose[:3, :3] = R_c2w
                
                T = c2w_pose
                T[:3, :3] = T[:3, :3] @ R_cv_gl
                pose_mesh += o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.01).transform(T)
                o3d.io.write_triangle_mesh(os.path.join(room_folderpath, 'pose_mesh.ply'), pose_mesh)
            

    # # test on a specific room
    # room_folderpath = os.path.join(root_data_dir, '3FO4K5FWN1Q2/panorama/room_824')
    # room_meta_filepath = os.path.join(room_folderpath, 'room_meta.json')
    # logger.info(f'correct scene scale for room: {room_meta_filepath}')
        
    # with open(room_meta_filepath, 'r') as f:
    #     room_meta = json.load(f)
        
    # scale_mat = np.array(room_meta['scale_mat']).reshape(4,4).astype(np.float32)
    # original_scale = float(room_meta['scale'])
    
    # # savee corrected scale_mat
    # camera_center = scale_mat[:3, 3] / original_scale
    # logger.info(f'original camera_center: {-camera_center}')
    
    # new_scale_mat = np.eye(4).astype(np.float32)
    # new_scale_mat[:3, 3] = camera_center
    # new_scale_mat[:3] *= new_scene_scale
    
    # room_meta['new_scale_mat'] = new_scale_mat.flatten().tolist()
    # room_meta['new_scale'] = new_scene_scale
    # json.dump(room_meta, open(room_meta_filepath, 'w'), indent=4)
    
    # # visualize normalized camera poses
    # pose_mesh = o3d.geometry.TriangleMesh()
    
    # camera_metas = room_meta['cameras']
    # for cam_idx in range(len(camera_metas)):     
    #     cam_meta = camera_metas[str(cam_idx)]       
    #     # w2c pose
    #     pose = np.array(cam_meta['camera_transform']).reshape(4, 4)
    #     c2w_pose = np.linalg.inv(pose)
    #     # scale pose_c2w
    #     c2w_pose = new_scale_mat @ c2w_pose
    #     R_c2w = c2w_pose[:3, :3]
    #     q_c2w = trimesh.transformations.quaternion_from_matrix(R_c2w)
    #     q_c2w = trimesh.transformations.unit_vector(q_c2w)
    #     R_c2w = trimesh.transformations.quaternion_matrix(q_c2w)[:3, :3]
    #     c2w_pose[:3, :3] = R_c2w
        
    #     T = c2w_pose
    #     T[:3, :3] = T[:3, :3] @ R_cv_gl
    #     pose_mesh += o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.01).transform(T)
        
    #     o3d.io.write_triangle_mesh(os.path.join(room_folderpath, 'pose_mesh.ply'), pose_mesh)





ic| self.num_all_rooms: 202, self.num_all_views: 81272


FileNotFoundError: [Errno 2] No such file or directory: '/seaweedfs/training/dataset/qunhe/PanoRoom/processed_data_20240312/3FO4K5FWGG13/perspective/room_702/depth'