In [13]:
import os
import json
import numpy as np
import re
from PIL import Image
from collections import defaultdict


In [14]:
def extract_numbers(filename):
    """Return every run of digits in ``filename`` as a tuple of ints.

    The tuple preserves left-to-right order, so it can be used as a natural
    sort key (``'img_2_10'`` -> ``(2, 10)``). A name that contains no digits
    yields ``(inf,)`` so that it sorts after every numbered name.
    """
    digit_runs = re.findall(r'\d+', filename)
    if not digit_runs:
        return (float('inf'),)
    return tuple(int(run) for run in digit_runs)
def get_base_name(filepath):
    """Return the file name of ``filepath`` without directory or last extension.

    Only the final extension is removed: ``'/a/b/x.tar.gz'`` -> ``'x.tar'``.
    """
    filename = os.path.basename(filepath)
    stem, _extension = os.path.splitext(filename)
    return stem

In [15]:
def _segm_path_from_line(line):
    """Return the ``fpath_segm`` value stored in one odgt line, or ``None``.

    The line is decoded as JSON first so that escaped characters in the path
    (``\\uXXXX``, ``\\\\``, ``\\"``) come back as the real path; otherwise they
    would be escaped a second time when the path is re-serialised. Lines that
    are not a JSON object with a string ``fpath_segm`` fall back to a plain
    textual match, so malformed lines are still handled on a best-effort basis.
    """
    try:
        record = json.loads(line)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        record = None
    if isinstance(record, dict) and isinstance(record.get('fpath_segm'), str):
        return record['fpath_segm']

    match = re.search(r'"fpath_segm":\s*"(.*?)"', line)
    return match.group(1) if match else None


def make_odgt(data_dir, out_file, original_odgt):
    """Write an odgt file pairing images in ``data_dir`` with known label paths.

    Segmentation paths are read from ``original_odgt`` and grouped by the file
    stem of each path. Every ``.png`` found under ``data_dir`` (recursively) is
    then visited in natural numeric order of its file name; an image is matched
    when the leading run of digits in its stem equals one of those stems. When
    several segmentation paths share a stem they are handed out one per image,
    in the order they appeared in ``original_odgt``; images left over once the
    paths for their key are used up are skipped, as are images without a
    leading number or without a matching key.

    Args:
        data_dir: Directory that is walked recursively for ``.png`` images.
        out_file: Destination odgt path. Its parent directory is created when
            missing; a bare file name (no directory part) is written to the
            current working directory.
        original_odgt: Existing odgt file (one JSON record per line) supplying
            the ``fpath_segm`` values.

    Each output line is a JSON object with the keys ``fpath_img``,
    ``fpath_segm``, ``width`` and ``height``.
    """
    # Group the known segmentation paths by file stem, preserving file order.
    seg_files_dict = defaultdict(list)
    with open(original_odgt, encoding='utf-8') as f:
        for line in f:
            full_path = _segm_path_from_line(line)
            if full_path is not None:
                seg_files_dict[get_base_name(full_path)].append(full_path)

    img_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(data_dir)
        for file in files
        if file.endswith('.png')
    ]
    img_files.sort(key=lambda x: extract_numbers(os.path.basename(x)))

    # dirname is '' for a bare file name and os.makedirs('') raises, so only
    # create a directory when there is one to create. exist_ok avoids the
    # check-then-create race.
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    matched_entries = []
    seg_counters = defaultdict(int)  # how many seg paths each key has used

    for img_path in img_files:
        leading_digits = re.match(r'(\d+)', get_base_name(img_path))
        if not leading_digits:
            continue
        key = leading_digits.group(1)

        candidates = seg_files_dict.get(key)
        if not candidates:
            continue
        seg_index = seg_counters[key]
        if seg_index >= len(candidates):
            continue  # every segmentation path for this key is already taken

        seg_path = candidates[seg_index]
        seg_counters[key] += 1

        # Only the size is needed; close the file handle right away instead of
        # leaking one per matched image.
        with Image.open(img_path) as img:
            width, height = img.width, img.height

        matched_entries.append({
            'fpath_img': img_path,
            'fpath_segm': seg_path,
            'width': width,
            'height': height,
        })

    with open(out_file, 'w', encoding='utf-8') as f:
        for entry in matched_entries:
            f.write(json.dumps(entry) + '\n')

    

# Build the test-split odgt for the images under `data_dir`, reusing the
# segmentation paths already listed in the existing rainy-day odgt.
original_odgt = '/home/zhaob/Desktop/semantic-segmentation-pytorch/new_data/odgt_rainy_day/test.odgt'
data_dir = '/home/zhaob/Desktop/semantic-segmentation-pytorch/new_data/rainy_deweathered'
out_file = '/home/zhaob/Desktop/semantic-segmentation-pytorch/new_data/rainy_deweatheredJulian/test.odgt'
make_odgt(
    data_dir=data_dir,
    out_file=out_file,
    original_odgt=original_odgt,
)
