In [3]:
import os
import json
import numpy as np
from PIL import Image

In [7]:
def train_validation_test_split(raw_data, annotated_data, train_percent=0.8, validation_percent=0.1, test_percent=0.1, seed=None):
  """Randomly partition a dataset's numeric file indices into train/validation/test sets.

  Every file name in both directories must start with an integer followed by
  a dot (e.g. ``12.png``); that integer is the file's index.

  Args:
    raw_data: Directory holding the raw files.
    annotated_data: Directory holding the annotated files. It must contain
      the same number of files as ``raw_data``, with matching indices.
    train_percent: Fraction of the indices drawn for the training set.
    validation_percent: Fraction of the indices drawn for the validation set.
    test_percent: Fraction for the test set. It only takes part in the
      sum-to-one check; the test set receives every index left over once the
      other two sets have been drawn.
    seed: Optional seed for a private ``np.random.RandomState``. With the
      default ``None`` NumPy's global random state is used.

  Returns:
    Tuple ``(train_indices, validation_indices, test_indices)`` of NumPy
    integer arrays. The arrays are pairwise disjoint and together contain
    exactly the indices found in ``raw_data``.

  Raises:
    AssertionError: If the fractions do not sum to 1, the directories hold a
      different number of files, or the sorted file names do not pair up by
      index.
    ValueError: If the directories are empty, a file name does not start with
      an integer, or an index ends up in no split.
  """
  # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to
  # 0.9999999999999999 in binary floating point and would fail an exact check.
  total_percent = train_percent + validation_percent + test_percent
  assert np.isclose(total_percent, 1.0, rtol=0.0, atol=1e-9)

  raw_files = os.listdir(raw_data)
  annotated_files = os.listdir(annotated_data)
  assert len(raw_files) == len(annotated_files)
  if not raw_files:
    raise ValueError(f'No files found in {raw_data}')

  # Draw from the indices that actually exist on disk (sorted, de-duplicated)
  # rather than from the full min..max range, so gaps in the numbering never
  # yield indices without a corresponding file.
  all_indices = np.unique([int(f.split('.')[0]) for f in raw_files])
  num_indices = all_indices.size

  train_size = int(num_indices * train_percent)
  validation_size = int(num_indices * validation_percent)

  # The module-level np.random functions and RandomState share the same
  # ``choice`` API, so either can serve as the generator here.
  rng = np.random if seed is None else np.random.RandomState(seed)

  train_indices = rng.choice(all_indices, size=train_size, replace=False)
  remaining_indices = np.setdiff1d(all_indices, train_indices)
  validation_indices = rng.choice(remaining_indices, size=validation_size, replace=False)
  test_indices = np.setdiff1d(remaining_indices, validation_indices)

  assigned = set(train_indices.tolist()) | set(validation_indices.tolist()) | set(test_indices.tolist())

  # Sanity check: raw/annotated files pair up by index and each one is assigned.
  for raw, ann in zip(sorted(raw_files), sorted(annotated_files)):
    raw_index = int(raw.split('.')[0])
    ann_index = int(ann.split('.')[0])
    assert raw_index == ann_index
    if raw_index not in assigned:
      print(f"Index {raw_index} not found in any set")
      raise ValueError('Index not found in any set')

  return train_indices, validation_indices, test_indices

# Raw images are read from <root>/rgb and their annotated counterparts from
# <root>/rgb_seg; the split is computed over the indices found there.
data_root_dir = '/Data/dataLIDAR/0221-1817_seed_9876/clear_night'
raw_data_dir, annotated_data_dir = (
  os.path.join(data_root_dir, subdir) for subdir in ('rgb', 'rgb_seg')
)
train, validate, test = train_validation_test_split(
  raw_data=raw_data_dir,
  annotated_data=annotated_data_dir,
)

In [8]:
def make_odgt(raw_data, annotated_data, train_idx, validate_idx, test_idx, out_dir, append=True):
  """Write ``train.odgt``, ``validate.odgt`` and ``test.odgt`` listing files for a dataset.

  Each output line is a JSON object with the keys ``fpath_img``,
  ``fpath_segm``, ``width`` and ``height`` describing one raw/annotated image
  pair. File names in both directories must start with an integer followed by
  a dot (e.g. ``12.png``); that integer selects the split the pair goes to.

  Args:
    raw_data: Directory holding the raw images.
    annotated_data: Directory holding the annotated images; same file count
      and matching indices as ``raw_data``.
    train_idx: Collection of indices written to ``train.odgt``.
    validate_idx: Collection of indices written to ``validate.odgt``.
    test_idx: Collection of indices written to ``test.odgt``.
    out_dir: Output directory; created if it does not exist.
    append: When True (the default) records are appended to any existing
      ``.odgt`` files, so calling this twice with the same arguments
      duplicates every record. Pass False to overwrite all three files
      instead, which makes re-runs idempotent.

  Raises:
    AssertionError: If the directories hold a different number of files, the
      sorted file names do not pair up by index, or an image pair differs in
      width or height.
    ValueError: If a file name does not start with an integer.
  """
  raw_files = os.listdir(raw_data)
  annotated_files = os.listdir(annotated_data)
  assert len(raw_files) == len(annotated_files)

  os.makedirs(out_dir, exist_ok=True)

  # Checked in this order, so an index present in several splits goes to the
  # first one that contains it.
  splits = (
    ('train.odgt', train_idx),
    ('validate.odgt', validate_idx),
    ('test.odgt', test_idx),
  )
  records = {name: [] for name, _ in splits}

  for raw, ann in zip(sorted(raw_files), sorted(annotated_files)):
    raw_index = int(raw.split('.')[0])
    ann_index = int(ann.split('.')[0])
    assert raw_index == ann_index

    raw_path = os.path.join(raw_data, raw)
    ann_path = os.path.join(annotated_data, ann)
    # Only the dimensions are needed; the context managers release both file
    # handles immediately instead of leaving them open.
    with Image.open(raw_path) as raw_img, Image.open(ann_path) as ann_img:
      width, height = raw_img.size
      ann_width, ann_height = ann_img.size
    assert width == ann_width
    assert height == ann_height

    odgt_line = {"fpath_img": raw_path, "fpath_segm": ann_path, "width": width, "height": height}
    for name, indices in splits:
      if raw_index in indices:
        records[name].append(json.dumps(odgt_line) + '\n')
        break

  # Records are buffered so each output file is opened once, not once per image.
  mode = 'a' if append else 'w'
  for name, lines in records.items():
    # Append mode leaves a split with no records untouched; overwrite mode
    # always rewrites the file so stale content is cleared.
    if lines or not append:
      with open(os.path.join(out_dir, name), mode, encoding='utf-8') as f:
        f.writelines(lines)
    
# Emit the train/validate/test .odgt listings into ./odgt (relative to the
# current working directory) using the split computed earlier.
make_odgt(
  raw_data=raw_data_dir,
  annotated_data=annotated_data_dir,
  train_idx=train,
  validate_idx=validate,
  test_idx=test,
  out_dir='odgt',
)