In [4]:
import os
import json
import numpy as np
from PIL import Image


In [5]:
def train_validation_test_split(files_per_weather, train_percent=0.8, validation_percent=0.1, test_percent=0.1, seed=None):
    """Randomly split each weather condition's files into train, validation, and test sets.

    Args:
        files_per_weather: Mapping of weather condition -> sequence of file entries.
        train_percent: Fraction of each condition's files placed in the train split.
        validation_percent: Fraction placed in the validation split.
        test_percent: Fraction intended for the test split. It is only used to
            validate that the three fractions sum to 1; the test split receives
            whatever remains after the train and validation slices, so rounding
            leftovers end up there.
        seed: Optional seed for a reproducible shuffle. When None (default) the
            global NumPy random state is used, so an earlier ``np.random.seed``
            call still takes effect.

    Returns:
        Tuple ``(train, validation, test)`` of dicts keyed by weather condition.
        Each value is a NumPy array holding the shuffled *file entries* assigned
        to that split (not positional indices).

    Raises:
        AssertionError: If the three fractions do not sum to 1.
    """
    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999 and would fail an exact equality check.
    total_percent = train_percent + validation_percent + test_percent
    assert np.isclose(total_percent, 1.0), "Splits must sum to 1."

    # A dedicated generator keeps seeded runs independent of global state.
    rng = np.random.default_rng(seed) if seed is not None else None

    train_indices, validation_indices, test_indices = {}, {}, {}

    for weather, files in files_per_weather.items():
        num_files = len(files)
        train_size = int(num_files * train_percent)
        validation_size = int(num_files * validation_percent)
        validation_end = train_size + validation_size

        if rng is None:
            shuffled_files = np.random.permutation(files)
        else:
            shuffled_files = rng.permutation(files)

        train_indices[weather] = shuffled_files[:train_size]
        validation_indices[weather] = shuffled_files[train_size:validation_end]
        # Everything left over (including rounding remainders) goes to test.
        test_indices[weather] = shuffled_files[validation_end:]

    return train_indices, validation_indices, test_indices

In [6]:
def make_odgt(raw_folders, seg_folders, raw_data, annotated_data, train, validate, test, output_dir):
    """Write ``train.odgt``, ``validate.odgt`` and ``test.odgt`` manifests into ``output_dir``.

    Each manifest line is a JSON object with the keys ``fpath_img``,
    ``fpath_segm``, ``width``, ``height`` and ``weather``. A raw image is
    considered only if its name ends in ``.png`` and the text before the first
    ``.`` is numeric; that number is the image's index. The annotation is
    looked up under the same file name in the matching segmentation folder.
    Images whose raw or annotation file is missing are reported and skipped.

    Args:
        raw_folders: Mapping of weather condition -> folder containing raw images.
        seg_folders: Mapping of weather condition -> folder containing annotations.
        raw_data: Mapping of weather condition -> iterable of raw image file names.
        annotated_data: Unused; kept so existing callers keep working.
        train: Mapping of weather condition -> collection of image indices for
            the train manifest.
        validate: Same as ``train`` for the validation manifest.
        test: Same as ``train`` for the test manifest.
        output_dir: Directory receiving the manifests; created if missing.

    Raises:
        AssertionError: If a raw image and its annotation differ in size.
        KeyError: If a split has no entry for a weather condition that
            contains a usable image.
    """
    datasets = {
        'train': train,
        'validate': validate,
        'test': test
    }
    # Sets give constant-time membership tests; testing against the original
    # lists for every image would be quadratic in the dataset size.
    members = {
        key: {weather: set(ids) for weather, ids in split.items()}
        for key, split in datasets.items()
    }

    # An empty output_dir means "current directory"; makedirs('') would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    odgt_files = {}
    try:
        for key in datasets:
            odgt_files[key] = open(os.path.join(output_dir, f'{key}.odgt'), 'w')

        for weather, files in raw_data.items():
            for raw in files:
                stem = raw.split('.')[0]
                if not (raw.endswith('.png') and stem.isdigit()):
                    continue
                try:
                    # isdigit() accepts some characters int() rejects (e.g. superscripts).
                    raw_index = int(stem)
                except ValueError:
                    print(f"Skipping {raw} in {weather}")
                    continue
                raw_path = os.path.abspath(os.path.join(raw_folders[weather], raw))
                ann_path = os.path.abspath(os.path.join(seg_folders[weather], raw))

                if not os.path.exists(raw_path):
                    print(f"Raw file not found: {raw_path}")
                    continue
                if not os.path.exists(ann_path):
                    print(f"Annotation file not found: {ann_path}")
                    continue

                # Context managers release the image file handles right away
                # instead of keeping one open per processed image.
                with Image.open(raw_path) as raw_img, Image.open(ann_path) as ann_img:
                    assert raw_img.size == ann_img.size, f"Size mismatch for {raw} in {weather}"
                    width, height = raw_img.size

                odgt_line = json.dumps({
                    "fpath_img": raw_path,
                    "fpath_segm": ann_path,
                    "width": width,
                    "height": height,
                    "weather": weather  # Store weather condition in metadata
                })

                for key, split_members in members.items():
                    if raw_index in split_members[weather]:
                        odgt_files[key].write(odgt_line + '\n')
    finally:
        # Always close the manifests, even when an image check fails mid-run.
        for f in odgt_files.values():
            f.close()

# Dataset root and the per-weather raw folders found beneath it
data_root_dir = '/home/zhaob/Desktop/semantic-segmentation-pytorch/1_17_clear_day_mixed'
weather_conditions = ["_outRaw", "_outRaw_foggy", "_outRaw_night"]

# Each raw folder has a sibling segmentation folder: "_outRaw*" -> "_outSeg*"
raw_folders = {}
seg_folders = {}
for condition in weather_conditions:
    raw_folders[condition] = os.path.join(data_root_dir, condition)
    seg_folders[condition] = os.path.join(data_root_dir, condition.replace("_outRaw", "_outSeg"))

# Every directory entry of each raw folder, sorted by name
files_per_weather = {}
for condition in weather_conditions:
    files_per_weather[condition] = sorted(os.listdir(raw_folders[condition]))


def _png_files_in(folder):
    """Return the names of regular '.png' files directly inside folder, in os.listdir order."""
    names = []
    for entry in os.listdir(folder):
        if entry.endswith('.png') and os.path.isfile(os.path.join(folder, entry)):
            names.append(entry)
    return names


# PNG file names per weather condition, for raw images and for annotations
raw_data = {condition: _png_files_in(raw_folders[condition]) for condition in weather_conditions}
annotated_data = {condition: _png_files_in(seg_folders[condition]) for condition in weather_conditions}

# Split dataset
def train_validation_test_split(files_per_weather):
    """Split each weather condition's image indices 70% / 15% / 15%.

    NOTE: this redefines the function of the same name defined earlier in the
    notebook; after this cell runs, this deterministic version is the one in use.

    The returned indices are the integers parsed from the file names (the text
    before the first '.', for names ending in '.png'), because that is what
    ``make_odgt`` compares against. Returning plain positions ``0..N-1`` only
    lines up when the files happen to be named ``0.png .. N-1.png`` and the
    folder holds nothing else; otherwise images are silently left out of every
    split. Entries that are not numerically named PNGs are ignored.

    Args:
        files_per_weather: Mapping of weather condition -> iterable of file names.

    Returns:
        Tuple ``(train, validate, test)`` of dicts keyed by weather condition,
        each value a list of integer image indices in ascending order. The
        first 70% go to train, the next 15% to validate, the rest to test.
    """
    train_fraction = 0.7
    validate_cutoff = 0.85  # train + validate share

    train, validate, test = {}, {}, {}
    for weather, files in files_per_weather.items():
        # Same acceptance rule make_odgt applies to file names.
        image_ids = set()
        for name in files:
            stem = name.split('.')[0]
            if not (name.endswith('.png') and stem.isdigit()):
                continue
            try:
                image_ids.add(int(stem))
            except ValueError:
                # isdigit() accepts a few characters int() cannot parse.
                continue

        ordered_ids = sorted(image_ids)
        total_files = len(ordered_ids)
        train_end = int(train_fraction * total_files)
        validate_end = int(validate_cutoff * total_files)

        train[weather] = ordered_ids[:train_end]
        validate[weather] = ordered_ids[train_end:validate_end]
        test[weather] = ordered_ids[validate_end:]
    return train, validate, test

# Compute the per-weather splits, then write one .odgt manifest per split
# into the 'odgt' directory.
splits = train_validation_test_split(files_per_weather)
train, validate, test = splits

make_odgt(
    raw_folders=raw_folders,
    seg_folders=seg_folders,
    raw_data=raw_data,
    annotated_data=annotated_data,
    train=train,
    validate=validate,
    test=test,
    output_dir='odgt',
)