In [9]:
import os
import json
import numpy as np
from PIL import Image

In [10]:
def train_validation_test_split(raw_data, annotated_data, train_percent=0.8, validation_percent=0.1, test_percent=0.1, seed=None):
  """Randomly split paired image files into train/validation/test index sets.

  File names in both directories must look like ``<integer>.<ext>``; the
  integer before the first dot is the index that identifies a pair.

  Args:
    raw_data: Directory containing the raw images.
    annotated_data: Directory containing the annotated images.
    train_percent: Fraction of the files drawn for the training set.
    validation_percent: Fraction of the files drawn for the validation set.
    test_percent: Fraction for the test set. It is only used to check that
      the three fractions sum to 1; the test set receives every index that
      was not drawn for training or validation.
    seed: Optional seed for a private ``np.random.RandomState``. When None
      (the default) NumPy's global random state is used.

  Returns:
    Tuple ``(train_indices, validation_indices, test_indices)`` of integer
    NumPy arrays. Train and validation are in draw order; test is sorted.

  Raises:
    AssertionError: If the fractions do not sum to 1, the two directories
      hold a different number of files, or the sorted file lists do not
      pair up index by index.
    ValueError: If a file name has no integer stem, or an index ends up in
      none of the three sets.
  """
  # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 is not exactly 1.0 in
  # floating point, so a strict equality check rejects valid input.
  assert np.isclose(train_percent + validation_percent + test_percent, 1.0)
  raw_files = os.listdir(raw_data)
  annotated_files = os.listdir(annotated_data)
  assert len(raw_files) == len(annotated_files)
  num_files = len(raw_files)

  # Read the indices from the file names instead of assuming a fixed range,
  # so the split always covers exactly the files that are on disk.
  file_indices = []
  for raw, ann in zip(sorted(raw_files), sorted(annotated_files)):
    raw_index = int(raw.split('.')[0])
    ann_index = int(ann.split('.')[0])
    assert raw_index == ann_index
    file_indices.append(raw_index)
  all_indices = np.array(sorted(file_indices), dtype=int)

  train_size = int(num_files * train_percent)
  validation_size = int(num_files * validation_percent)

  # The np.random module and a RandomState instance both expose choice().
  rng = np.random if seed is None else np.random.RandomState(seed)
  train_indices = rng.choice(all_indices, size=train_size, replace=False)
  remaining_indices = np.setdiff1d(all_indices, train_indices)
  validation_indices = rng.choice(remaining_indices, size=validation_size, replace=False)
  # Whatever was not drawn for train or validation becomes the test set.
  test_indices = np.setdiff1d(remaining_indices, validation_indices)

  print(f"Train indices: {train_indices}")
  print(f"Validation indices: {validation_indices}")
  print(f"Test indices: {test_indices}")

  # Sanity check that every file landed in one of the sets; a set gives a
  # constant-time lookup instead of scanning the arrays once per file.
  assigned = set(train_indices.tolist())
  assigned.update(validation_indices.tolist())
  assigned.update(test_indices.tolist())
  for index in file_indices:
    if index not in assigned:
      print(f"Index {index} not found in any set")
      raise ValueError('Index not found in any set')

  return train_indices, validation_indices, test_indices

# Seed NumPy's global generator so the random split is identical on every
# Restart & Run All; change SPLIT_SEED to draw a different split.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)
train, validate, test = train_validation_test_split('_outRaw', '_outSeg')

Train indices: [2029 2053 1499 1625 1505 1986 1389 2355 1719 1382 1766 2101 2181 1365
 2287 1497 2009 2222 2176 1823 1832 1966 1929 2095 2157 1671 2044 2229
 2285 1675 2063 1301 1847 1950 2085 2004 2162 1291 1969 1530 1926 1759
 2302 2072 2305 1346 2066 1713 1296 1651 2313 1466 1456 2385 2244 2266
 1778 1286 1454 1692 1399 2121 1963 2401 1890 1476 1641 1288 1811 1922
 1822 1297 1894 1621 1405 1691 2352 1733 1643 1886 2346 1867 1659 1520
 1404 2348 2413 2361 1725 1282 1806 1941 2188 1752 1341 2146 1687 2215
 1870 2142 2137 1632 1934 2128 1852 2048 2069 1794 2164 1810 1787 1280
 2052 2040 2321 2193 1739 1703 1819 1997 1546 2325 1250 2297 1374 1570
 2135 1948 1644 2171 1287 2196 1907 1377 2086 1732 1561 1313 2104 2235
 1770 2295 1568 1710 2015 2068 1517 1831 2129 2376 2013 1269 1633 1281
 1441 1348 1585 1707 2230 1800 2125 2245 2079 1783 2299 1680 1565 2336
 1962 1440 1335 1532 1519 1381 1375 2189 2271 2074 1987 2247 1861 1745
 1805 1995 2309 2132 2103 2255 2265 1803 2221 1576 1697 1755 1

In [11]:
def make_odgt(raw_data, annotated_data, train_idx, validate_idx, test_idx, out_dir, append=False):
  """Write train/validate/test ``.odgt`` files describing the image pairs.

  Each output line is a JSON object with the keys ``fpath_img``,
  ``fpath_segm``, ``width`` and ``height``. File names in both directories
  must look like ``<integer>.<ext>``; the integer before the first dot
  selects the split. Pairs whose index is in none of the three splits are
  skipped.

  Args:
    raw_data: Directory containing the raw images.
    annotated_data: Directory containing the annotated images.
    train_idx: Iterable of integer indices written to ``train.odgt``.
    validate_idx: Iterable of integer indices written to ``validate.odgt``.
    test_idx: Iterable of integer indices written to ``test.odgt``.
    out_dir: Output directory; created if it does not exist.
    append: When False (the default) the three files are rewritten, so
      calling the function again does not duplicate records. Pass True to
      add records to existing files instead.

  Raises:
    AssertionError: If the directories hold a different number of files,
      the sorted file lists do not pair up index by index, or an image and
      its annotation differ in size.
    ValueError: If a file name has no integer stem.
  """
  raw_files = os.listdir(raw_data)
  annotated_files = os.listdir(annotated_data)
  assert len(raw_files) == len(annotated_files)

  os.makedirs(out_dir, exist_ok=True)

  # Sets give constant-time membership tests; checked in this order so an
  # index listed in several splits goes to the first match.
  splits = (
    ('train.odgt', set(map(int, train_idx))),
    ('validate.odgt', set(map(int, validate_idx))),
    ('test.odgt', set(map(int, test_idx))),
  )
  lines_by_file = {name: [] for name, _ in splits}

  for raw, ann in zip(sorted(raw_files), sorted(annotated_files)):
    raw_index = int(raw.split('.')[0])
    ann_index = int(ann.split('.')[0])
    assert raw_index == ann_index

    raw_path = os.path.join(raw_data, raw)
    ann_path = os.path.join(annotated_data, ann)
    # Only the dimensions are needed; the with-block closes both file
    # handles instead of leaking two per pair.
    with Image.open(raw_path) as raw_img, Image.open(ann_path) as ann_img:
      width, height = raw_img.size
      ann_width, ann_height = ann_img.size
    assert width == ann_width
    assert height == ann_height

    # json.dumps emits the double-quoted strings the .odgt lines need.
    odgt_line = {"fpath_img": raw_path, "fpath_segm": ann_path, "width": width, "height": height}
    for name, members in splits:
      if raw_index in members:
        lines_by_file[name].append(json.dumps(odgt_line) + '\n')
        break

  # Write each file once, after all pairs have been validated.
  mode = 'a' if append else 'w'
  for name, lines in lines_by_file.items():
    if append and not lines:
      # Nothing to add; do not create or touch the file.
      continue
    with open(os.path.join(out_dir, name), mode, encoding='utf-8') as f:
      f.writelines(lines)
    
# Emit the .odgt manifests for the split computed above into ./odgt.
make_odgt(
  raw_data='_outRaw',
  annotated_data='_outSeg',
  train_idx=train,
  validate_idx=validate,
  test_idx=test,
  out_dir='odgt',
)