In [None]:
import pandas as pd
import os
import io
import random
import ast
from PIL import Image
from tqdm import tqdm

import sys
sys.path.append('../src')

from color_palette_completion.utils.image_label_detector import detect_labels
from color_palette_completion.text_color_model.model_config import Config

# Cap on the number of labels kept per row, taken from the model config.
MAX_LABELS = Config['Max_Image_Labels_Length']

# Point GOOGLE_APPLICATION_CREDENTIALS at the JSON key file.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/workspace/text2palette/profound-actor-347102-b0e7c4a8632c.json"

# Input / output locations.
data_path = "/workspace/text2palette/color_palette_completion/data/t2p/text_miridih"
csv_path = "/workspace/text2palette/color_palette_completion/data/colors/data_colors_labels/text_palette_extracted.csv"
image_root_path = "/data/shared/xml_image_csv_converter/image/group_true/total_images"

# Output file name pattern; the remaining "{}" placeholder is filled with the split name.
lang_type = "_en"
filename_template = f"image_labels_imagemust_{{}}{lang_type}.txt"
splits = ['train', 'val', 'test']
split_ratios = [0.8, 0.1, 0.1]

# Load the table and shuffle it reproducibly (fixed seed) before slicing.
df = pd.read_csv(csv_path)
df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Row index boundaries: train gets the first chunk, val the next one,
# and test whatever remains (so rounding leftovers land in test).
total_len = len(df_shuffled)
train_ratio, val_ratio, _ = split_ratios
train_end = int(total_len * train_ratio)
val_end = train_end + int(total_len * val_ratio)

split_bounds = [(0, train_end), (train_end, val_end), (val_end, total_len)]
df_split = {
    split_name: df_shuffled.iloc[start:stop]
    for split_name, (start, stop) in zip(splits, split_bounds)
}

def extract_labels_and_save(df_part, split):
    """Detect labels for the images of each row and write one line per row.

    For every row of ``df_part`` the ``image_file_name`` value is read as
    either a single file name or the string form of a list of file names.
    Each image is loaded from
    ``image_root_path/<group_id>/<sub_id>/<image_file>``, where ``group_id``
    and ``sub_id`` come from splitting ``group_key`` on ``'_'``. The label
    descriptions returned by ``detect_labels`` are concatenated across the
    row's images and truncated to ``MAX_LABELS`` entries.

    The output file (``filename_template`` formatted with ``split``, inside
    ``data_path``) receives exactly one line per input row, in row order:
    the ``str()`` of that row's label list. Rows whose images all fail are
    written as ``[]`` so the line count always matches ``len(df_part)``.

    Args:
        df_part: DataFrame with ``image_file_name`` and ``group_key`` columns.
        split: Split name used in progress messages and the output file name.

    Raises:
        ValueError: If a ``group_key`` does not split into exactly two parts.
        OSError: If the output directory or file cannot be created/written.
    """
    # Create the output directory up front: a missing folder should surface
    # before the slow per-image label detection, not after it has finished.
    os.makedirs(data_path, exist_ok=True)
    output_path = os.path.join(data_path, filename_template.format(split))

    label_lines = []

    print(f"[{split}] Processing {len(df_part)} rows...")
    for _, row in tqdm(df_part.iterrows(), total=len(df_part), desc=f"Extracting labels ({split})"):
        image_files = row['image_file_name']
        group_key = row['group_key']
        group_id, sub_id = group_key.split('_')

        # Parse the value when it is the string form of a list; otherwise
        # treat it as a single file name.
        try:
            image_files = ast.literal_eval(image_files) if isinstance(image_files, str) and image_files.startswith('[') else [image_files]
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            print(f"Could not parse image_file_name: {image_files}. Error: {e}")
            image_files = []

        labels_collected = []
        for image_file in image_files:
            # A missing cell (NaN) or other non-string entry cannot form a
            # path; skip it instead of letting os.path.join abort the run.
            if not isinstance(image_file, str):
                print(f"Skipping invalid image file name: {image_file!r}")
                continue

            image_path = os.path.join(image_root_path, group_id, sub_id, image_file)
            # Best-effort per image: any read or detection failure is
            # reported and the remaining images are still processed.
            try:
                with open(image_path, 'rb') as img_file:
                    image_bytes = img_file.read()
                labels = detect_labels(image_bytes)
                labels_collected.extend(label.description for label in labels)
            except Exception as e:
                print(f"Failed to process {image_path}: {e}")

        # Limit the number of labels kept for this row.
        label_lines.append(str(labels_collected[:MAX_LABELS]))

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in label_lines)
    print(f"[{split}] Saved to {output_path}")

# Run label extraction for each split in turn, in the order listed in `splits`.
for split in splits:
    split_frame = df_split[split]
    extract_labels_and_save(split_frame, split)


[train] Processing 835 rows...


Extracting labels (train):  10%|█         | 86/835 [08:34<1:14:44,  5.99s/it]


KeyboardInterrupt: 