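# Extended version of the `scannet_openseg.py` script: extracts per-pixel OpenSeg features from ScanNet frames and fuses them onto the scene's 3D points.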

In [1]:
import os
import torch
import imageio
import numpy as np
from glob import glob
from tqdm import tqdm
import tensorflow as tf2
import tensorflow.compat.v1 as tf
from tensorflow import io
from os.path import join, exists
from utils.fusion_util import PointCloudToImageMapper, save_fused_feature_no_args, read_bytes

In [2]:
# --- Runtime setup ------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed the numpy and torch RNGs so random choices are reproducible across runs.
seed = 1457
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- Dataset-specific parameters (ScanNet) ------------------------------------
# (width, height) of the 2D frames. Raw ScanNet frames are (640, 480), but the
# preprocessing step resizes them to (320, 240).
img_dim = (320, 240)
depth_scale = 1000.0  # depth images are divided by this to convert them to metres

# --- Fusion parameters --------------------------------------------------------
visibility_threshold = 0.25   # threshold for the depth-based visibility check
cut_num_pixel_boundary = 10   # ignore features this many pixels from the image border
feat_dim = 768                # dimension of the OpenSeg (CLIP) per-pixel features

In [3]:
# Which ScanNet split to process and how the fused point clouds are chunked.
split = 'train'
if split == 'train':
    # training set: export several random chunks of the point cloud per scene
    n_split_points = 20000
    num_rand_file_per_scene = 5
else:
    # validation set: export the entire point cloud as a single file instead of chunks
    n_split_points = 2000000
    num_rand_file_per_scene = 1

DATA_DIR = "D:/AT3DCV_Data/Preprocessed_OpenScene/data"
DATA_ROOT = join(DATA_DIR, 'scannet_3d')     # preprocessed 3D scenes (<split>/*.pth)
DATA_ROOT_2D = join(DATA_DIR, 'scannet_2d')  # per-scene color/depth/pose frames and intrinsics.txt

data_paths = sorted(glob(join(DATA_ROOT, split, '*.pth')))
total_num = len(data_paths)  # total number of samples in the split
if total_num == 0:
    # Fail loudly instead of silently fusing nothing when the data path or split is wrong.
    raise FileNotFoundError("No '*.pth' scenes found in " + join(DATA_ROOT, split))

OUT_DIR = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_fused_features"
# The fused features are saved into OUT_DIR, so it must exist before the first save.
os.makedirs(OUT_DIR, exist_ok=True)

# Load the OpenSeg model (TensorFlow SavedModel, serving signature).
model_path = "C:/Users/aorhu/Masaüstü/AT3DCV/repo/openseg_model"
openseg_model = tf2.saved_model.load(model_path, tags=[tf.saved_model.tag_constants.SERVING],)
# OpenSeg takes a text-embedding input as well; only image features are used
# here, so a single all-zero embedding of size feat_dim is passed.
text_emb = tf.zeros([1, 1, feat_dim])

In [4]:
# load intrinsic parameter
intrinsics=np.loadtxt(os.path.join(DATA_ROOT_2D, 'intrinsics.txt'))

# calculate image pixel-3D points correspondances
point2img_mapper = PointCloudToImageMapper(
        image_dim=img_dim, intrinsics=intrinsics,
        visibility_threshold=visibility_threshold,
        cut_bound=cut_num_pixel_boundary)

process_id_range = ["0,100"]  # to only process samples in this range
id_range = None
if process_id_range is not None:
    id_range = [int(process_id_range[0].split(',')[0]), int(process_id_range[0].split(',')[1])]


In [5]:
# looping over samples
for i in tqdm(range(total_num)):
    # check the given range for the samples
    if id_range is not None and (i<id_range[0] or i>id_range[1]):
        print('skip ', i, data_paths[i])
        continue
    
    # extraction of the features starts here, not using the provided functions
    # ------------------------------------------------------------------------
    data_path = data_paths[i]
    #scene_id might be different depending on the path string, should be like "scene0000_00"
    scene_id = data_path.split('/')[-1].split('\\')[-1].split('_vh')[0]
    
    # load 3D data (point cloud)
    locs_in = torch.load(data_path)[0]
    n_points = locs_in.shape[0] # number of points
    
    n_interval = num_rand_file_per_scene
    n_finished = 0
    for n in range(n_interval):
        if exists(join(OUT_DIR, scene_id +'_%d.pt'%(n))):
            n_finished += 1
            print(scene_id +'_%d.pt'%(n) + ' already done!')
            continue
    if n_finished == n_interval:
        continue
    
    # short hand for processing 2D features
    scene = join(DATA_ROOT_2D, scene_id)
    img_dirs = sorted(glob(join(scene, 'color/*')), key=lambda x: int(os.path.basename(x)[:-4]))
    num_img = len(img_dirs) # number of images that the scene have
        
    # creating tensors to keep features per 3D point
    n_points_cur = n_points
    counter = torch.zeros((n_points_cur, 1), device = device)
    sum_features = torch.zeros((n_points_cur, feat_dim), device = device)

    vis_id = torch.zeros((n_points_cur, num_img), dtype=int, device=device)
    
    # process images per scene and fuse 2D-3D features
    for img_id, img_dir in enumerate(tqdm(img_dirs)):
        # load pose
        posepath = img_dir.replace('color', 'pose').replace('.jpg', '.txt')
        pose = np.loadtxt(posepath)
        # load depth and convert to meter
        depth = imageio.v2.imread(img_dir.replace('color', 'depth').replace('jpg', 'png')) / depth_scale # (240, 320)
        
        # calculate the 3d-2d mapping based on the depth
        mapping = np.ones([n_points, 4], dtype=int)
        """
        :pose: 4 x 4
        :locs_in: N x 3 format (point cloud)
        :depth: H x W format
        :return: mapping, N x 3 format, (H,W,mask)
        """
        mapping[:, 1:4] = point2img_mapper.compute_mapping(pose, locs_in, depth)
        if mapping[:, 3].sum() == 0: # no points corresponds to this image, skip
            continue     
        mapping = torch.from_numpy(mapping).to(device)
        mask = mapping[:, 3]    # [number of points]
        vis_id[:, img_id] = mask # masking the points corresponding to the image index of the scene, [number of points, num of images per scene]

        # extraction of 2D features with OpenSeg
        # load RGB image
        np_image_string = read_bytes(img_dir) #read_bytes is a simple function to read images as bytes for OpenSeg
        # run OpenSeg
        '''
        results is a dictionary that has = ['region_probs_', 'text_embedding', 'segm_proposal_feats',
                                            'image_embedding_feat', 'pixel_pred_confidence', 'segm_confidence', 
                                            'segm_prediction', 'images', 'region_embeddings', 'region_probs',
                                            'ppixel_ave_feat', 'ppixel_ave_feat_confidence', 'segm_confidence_rw',
                                            'image', 'segm_prediction_rw', 'image_info', 'ppixel_ave_feat_pred',
                                            'pixel_prediction', 'segm_proposal', 'region_logits']
        check OpenSeg repo for more information
        
        '''
        results = openseg_model.signatures['serving_default'](
                        inp_image_bytes = tf.convert_to_tensor(np_image_string),
                        inp_text_emb = text_emb)
        
        img_info = results['image_info']
        
        crop_sz = [
            int(img_info[0, 0] * img_info[2, 0]),
            int(img_info[0, 1] * img_info[2, 1])
        ]
        
        # this parameter is True by default in the original code, regardless of it the shape will be same
        regional_pool = True 
        if regional_pool:
            image_embedding_feat = results['ppixel_ave_feat'][:, :crop_sz[0], :crop_sz[1]] # shape will be : (1, 480, 640, 768)
        else:
            image_embedding_feat = results['image_embedding_feat'][:, :crop_sz[0], :crop_sz[1]] # shape will be : (1, 480, 640, 768)
        
        # resizing with nearest neighbors to 240x320
        img_size=[240, 320] # set to this in the original code
        if img_size is not None:
            feat_2d = tf.cast(tf.image.resize_nearest_neighbor(
                image_embedding_feat, img_size, align_corners=True)[0], dtype=tf.float16).numpy()
        else:
            feat_2d = tf.cast(image_embedding_feat[[0]], dtype=tf.float16).numpy()
        
        # reshaping for the fusion
        feat_2d = torch.from_numpy(feat_2d).permute(2, 0, 1)
        #without this conversion, it gives error while indexing
        mapping = mapping.to(int)
        # fusion
        feat_2d_3d = feat_2d[:, mapping[:, 1], mapping[:, 2]].permute(1, 0).to(device) #has the shape [81369, 768]
        
        # counting the image numbers and corresponding points
        counter[mask!=0]+= 1
        sum_features[mask!=0] += feat_2d_3d[mask!=0] #has the shape [81369, 768]    
        
    counter[counter==0] = 1e-5 # to prevent division by zero
    # dividing sum of the features per point by how many times the points had effect on the feature extraction
    feat_bank = sum_features/counter # [81369, 768]
    point_ids = torch.unique(vis_id.nonzero(as_tuple = False)[:, 0])
    
    #saving the fused features of randomly chosen n_split_points to total of num_rand_file_per_scene files
    save_fused_feature_no_args(feat_bank, point_ids, n_points, OUT_DIR, scene_id, num_rand_file_per_scene, n_split_points)


  0%|                                                                                         | 0/1201 [00:00<?, ?it/s]
  0%|                                                                                          | 0/279 [00:00<?, ?it/s][A
  0%|▎                                                                                 | 1/279 [00:05<27:00,  5.83s/it][A
  1%|▌                                                                                 | 2/279 [00:10<24:51,  5.38s/it][A
  1%|▉                                                                                 | 3/279 [00:16<24:42,  5.37s/it][A
  1%|█▏                                                                                | 4/279 [00:22<26:10,  5.71s/it][A
  2%|█▍                                                                                | 5/279 [00:28<27:13,  5.96s/it][A
  2%|█▊                                                                                | 6/279 [00:35<27:28,  6.04s/it][A
  3%|██            

KeyboardInterrupt: 