In [None]:
import itertools
import json
import os
import random
from collections import defaultdict

from tqdm import tqdm

def load_list(file_path):
    """Group image paths by person ID from an MSMT17-style list file.

    Each non-blank line is expected to be ``<relative_image_path> <pid>``;
    blank / whitespace-only lines are ignored.

    Args:
        file_path: Path to the list file (e.g. ``list_train.txt``).

    Returns:
        A ``defaultdict(list)`` mapping each integer pid to the image paths
        listed for it, in file order.

    Raises:
        ValueError: If a non-blank line does not consist of a path followed
            by an integer pid.
    """
    id_to_imgs = defaultdict(list)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. an extra newline at end of file)
                # instead of crashing on the 2-tuple unpack.
                continue
            # Split on the last whitespace run only, so the pid is always the
            # final token even if the path itself contains spaces.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected '<img_path> <pid>', got {line!r}"
                )
            img_path, pid = parts
            id_to_imgs[int(pid)].append(img_path)
    return id_to_imgs

def _format_sample(img1, img2, label, img_root):
    """Build one two-image conversation record asking whether both show the same person."""
    return {
        "system_prompt": "You are a helpful assistant.",
        "image": [os.path.join(img_root, img1), os.path.join(img_root, img2)],
        "conversation": [
            {
                "role": "user",
                "content": "Q: <image><image> Are these two people the same?"
            },
            {
                "role": "assistant",
                "content": label
            }
        ]
    }


def generate_pairs(id_to_imgs, img_root, num_negatives=1):
    """Create shuffled same/different-person pair samples.

    Positives: every unordered pair of images within one identity, labelled
    ``"yes"``. Negatives: for each identity, ``num_negatives`` pairs of one of
    its images with an image of a uniformly chosen *other* identity, labelled
    ``"no"``. All pairs are shuffled together and converted to conversation
    records whose image paths are joined onto ``img_root``.

    NOTE: positives grow quadratically with images per identity while
    negatives grow linearly with identities, so the default output is
    strongly skewed towards ``"yes"``; raise ``num_negatives`` to rebalance.

    Args:
        id_to_imgs: Mapping pid -> list of relative image paths (as produced
            by ``load_list``). Identities with no images are ignored.
        img_root: Directory prefix joined onto every image path.
        num_negatives: Negative pairs sampled per identity (default 1).

    Returns:
        A list of sample dicts in random order. Uses the global ``random``
        module state, so results depend on any prior ``random.seed`` call.
    """
    positive_pairs = []
    negative_pairs = []
    # Only identities that actually have images can anchor or serve as a
    # negative; otherwise random.choice([]) would raise IndexError.
    pids = [pid for pid, imgs in id_to_imgs.items() if imgs]
    n_pids = len(pids)

    for idx, pid in enumerate(tqdm(pids, desc="Generating pairs")):
        imgs = id_to_imgs[pid]
        # positive pairs: same (i < j) order as a nested index loop
        positive_pairs.extend(
            (img_a, img_b, "yes") for img_a, img_b in itertools.combinations(imgs, 2)
        )

        if n_pids < 2:
            continue  # no other identity to draw a negative from

        # negative pairs
        for _ in range(num_negatives):
            # Uniform pick among the other identities in O(1): draw an index
            # over n_pids - 1 slots and skip past our own position. This makes
            # the same RNG draw as random.choice() over the filtered list.
            k = random.randrange(n_pids - 1)
            neg_pid = pids[k + 1] if k >= idx else pids[k]
            img1 = random.choice(imgs)
            img2 = random.choice(id_to_imgs[neg_pid])
            negative_pairs.append((img1, img2, "no"))

    all_pairs = positive_pairs + negative_pairs
    random.shuffle(all_pairs)

    return [_format_sample(img1, img2, label, img_root) for img1, img2, label in all_pairs]

def main(data_root='MSMT17_V1', output_dir='./', seed=None):
    """Write pair-classification JSON files for the MSMT17 train and val splits.

    Args:
        data_root: MSMT17_V1 root containing ``list_train.txt``,
            ``list_val.txt`` and the ``train/`` image folder.
        output_dir: Directory the two JSON files are written to (created if
            it does not exist).
        seed: Optional seed for the ``random`` module so sampled pairs and
            shuffle order are reproducible; ``None`` leaves the RNG untouched.
    """
    if seed is not None:
        random.seed(seed)

    # Settings
    list_train_path = os.path.join(data_root, 'list_train.txt')
    list_val_path = os.path.join(data_root, 'list_val.txt')
    train_img_root = os.path.join(data_root, 'train/')  # train image root
    # MSMT17_V1's validation split is carved out of the training images, so
    # list_val.txt paths resolve under train/ (test/ holds query/gallery IDs).
    val_img_root = os.path.join(data_root, 'train/')

    os.makedirs(output_dir, exist_ok=True)

    # train data
    id_to_imgs_train = load_list(list_train_path)
    train_samples = generate_pairs(id_to_imgs_train, train_img_root)

    # val data
    id_to_imgs_val = load_list(list_val_path)
    val_samples = generate_pairs(id_to_imgs_val, val_img_root)

    # Save JSON files
    with open(os.path.join(output_dir, 'msmt17_train_llava_format.json'), 'w', encoding='utf-8') as f:
        json.dump(train_samples, f, indent=4)

    with open(os.path.join(output_dir, 'msmt17_val_llava_format.json'), 'w', encoding='utf-8') as f:
        json.dump(val_samples, f, indent=4)

    print(f"✅ 저장 완료! (train: {len(train_samples)}개, val: {len(val_samples)}개)")

if __name__ == "__main__":
    main()

Generating pairs: 100%|██████████| 1041/1041 [00:00<00:00, 5025.10it/s]
Generating pairs: 100%|██████████| 1041/1041 [00:00<00:00, 32814.54it/s]
