In [None]:
#!/usr/bin/env python3
"""
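parent_split_and_run.py

1) Collect IDs from the CSV file names under D:\\golfDataset\\dataset\\train → 90:10 train/valid split
2) Write train_ids.txt / valid_ids.txt
3) Build a PKL (annotations + split) for train and for valid
4) Call extract_embedding_timesformer.py and extract_embedding_stgcn.py to extract embeddings
   (the TimeSformer step is currently skipped in main() because it takes too long)

Must be run in the mmaction conda environment.
TimeSformer receives its inputs through the txt ID lists, so it is unaffected by the
numpy / pandas / pickle module versions, but mmaction reads the pkl file directly,
so a numpy._core ↔ numpy.core compatibility patch is needed.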
"""



# ── NumPy pickle-compat patch: numpy._core ↔ numpy.core ─────────────
# NumPy 2.x moved the private package `numpy.core` to `numpy._core`, so
# pickles written under NumPy 2.x refer to `numpy._core.*` module paths that
# do not exist on NumPy 1.x. On 1.x, alias those paths to `numpy.core`.
# NumPy 2.x already registers `numpy._core` when numpy is imported, so the
# guard makes this a no-op there and avoids touching the deprecated
# `np.core` alias.
# NOTE(review): this only affects unpickling inside *this* process; the
# ST-GCN extractor is launched as a separate `conda run` subprocess and
# would need its own alias — confirm where the patch is actually needed.
import sys
import numpy as np

if 'numpy._core' not in sys.modules:
    sys.modules['numpy._core'] = np.core
    # pickle imports the full dotted path (e.g. `numpy._core.multiarray`
    # for ndarray reconstruction), so alias that submodule explicitly too.
    sys.modules.setdefault('numpy._core.multiarray', np.core.multiarray)
# ────────────────────────────────────────────────────────────────

import random
import subprocess
import pickle
import pandas as pd
from pathlib import Path

# ───────────────────────────────────────────────────────────────
# Settings
# Dataset root. Each category folder under it (balanced_true / false) is
# expected to contain crop_keypoint/, crop_video/ and video/ subfolders.
ROOT = Path(r"D:\golfDataset\dataset\train")
TEST_RATIO = 0.1   # fraction of the shuffled IDs held out for validation
SEED = 42          # seed for the reproducible train/valid shuffle

# conda environments in which run() launches each embedding extractor
TS_ENV = 'timesformer'   # TimeSformer extractor
STGCN_ENV = 'mmaction'   # ST-GCN++ (mmaction2) extractor
# ───────────────────────────────────────────────────────────────

# Body25 → COCO17 index mapping: entry k is the OpenPose BODY_25 joint index
# that supplies COCO-17 joint k (standard BODY_25 / COCO-17 orderings).
MAPPING_BODY25_TO_COCO17 = [
    0,   # nose
    16,  # left eye
    15,  # right eye
    18,  # left ear
    17,  # right ear
    5,   # left shoulder
    2,   # right shoulder
    6,   # left elbow
    3,   # right elbow
    7,   # left wrist
    4,   # right wrist
    12,  # left hip
    9,   # right hip
    13,  # left knee
    10,  # right knee
    14,  # left ankle
    11,  # right ankle
]


def run(cmd, env):
    """Echo and execute *cmd* inside the conda environment named *env*.

    The command list is prefixed with
    ``conda run -n <env> --no-capture-output`` (so conda does not capture the
    child's stdout/stderr) and printed before it is executed.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    conda_prefix = ['conda', 'run', '-n', env, '--no-capture-output']
    command = conda_prefix + cmd
    print('[RUN]', *command)
    subprocess.run(command, check=True)


def load_and_process(csv_path: Path,
                     img_shape=(1080, 1920),
                     confidence_threshold=0.1,
                     normalize_method='0to1',
                     joint_map=None) -> dict:
    """Load one BODY_25 keypoint CSV and convert it to a COCO-17 skeleton dict.

    Each CSV row is one frame of 25 joints x (x, y, score) = 75 numeric
    columns. ``pd.read_csv`` consumes the first line as a header
    (NOTE(review): confirm the keypoint CSVs really have a header row,
    otherwise the first frame is silently dropped).

    Args:
        csv_path: path to the keypoint CSV.
        img_shape: ``(height, width)`` used for normalisation; also stored
            unchanged as ``img_shape`` / ``original_shape``.
        confidence_threshold: joints whose score is below this value get both
            coordinates and the score set to 0.
        normalize_method: ``'0to1'`` divides x by width and y by height; any
            other value leaves the pixel coordinates unchanged.
        joint_map: BODY_25 index for each output joint, in output order.
            Defaults to ``MAPPING_BODY25_TO_COCO17``.

    Returns:
        dict with ``total_frames``, ``img_shape``, ``original_shape``,
        ``keypoint`` (float32, shape ``(1, T, J, 2)``) and ``keypoint_score``
        (float32, shape ``(1, T, J)``), where T is the number of rows and
        J = ``len(joint_map)``.

    Raises:
        ValueError: if the CSV does not have exactly 75 columns.
    """
    if joint_map is None:
        joint_map = MAPPING_BODY25_TO_COCO17
    num_joints = 25
    df = pd.read_csv(csv_path)
    T, cols = df.shape
    if cols != num_joints * 3:
        raise ValueError(
            f"{csv_path}: expected {num_joints * 3} columns "
            f"({num_joints} joints x (x, y, score)), got {cols}")

    # (T, 75) -> (1, T, 25, 3) in one step. The leading 1 is presumably the
    # person axis expected by mmaction's pose pipeline — confirm.
    data = df.to_numpy(dtype=np.float32).reshape(1, T, num_joints, 3)
    kp25 = data[..., :2].copy()
    score25 = data[..., 2].copy()

    # Zero out low-confidence joints (both coordinates and score).
    mask = score25 < confidence_threshold
    kp25[mask] = 0
    score25[mask] = 0

    h, w = img_shape
    if normalize_method == '0to1':
        kp25[..., 0] /= w
        kp25[..., 1] /= h

    # Fancy indexing copies, so the outputs do not alias kp25/score25.
    kp17 = kp25[:, :, joint_map, :]
    score17 = score25[:, :, joint_map]
    return {
        'total_frames': T,
        'img_shape': img_shape,
        'original_shape': img_shape,
        'keypoint': kp17,
        'keypoint_score': score17
    }


def make_pkl(id_list, out_path: Path):
    """Build a skeleton PKL for the given clip IDs and write it to *out_path*.

    For each ID the keypoint CSV is looked up as
    ``ROOT/<category>/crop_keypoint/<id>.csv``, trying ``balanced_true``
    (label 1) before ``false`` (label 0). IDs with no CSV are reported with a
    warning and skipped.

    The pickle holds ``{'annotations': [...], 'split': {'xsub_val': ids}}``,
    where ``ids`` lists only the IDs that actually produced an annotation, so
    the split never references a missing sample. The same ``'xsub_val'`` key
    is written for both the train and the valid file (NOTE(review):
    presumably the extractor loads each file with split='xsub_val' — confirm).

    Args:
        id_list: iterable of clip IDs (CSV file stems).
        out_path: destination ``.pkl``; parent directories are created.

    Returns:
        Number of annotations written.
    """
    # (category folder, label) in lookup priority order.
    categories = (('balanced_true', 1), ('false', 0))
    annotations = []
    kept_ids = []
    for vid in id_list:
        for cat, label in categories:
            csv_file = ROOT / cat / 'crop_keypoint' / f"{vid}.csv"
            if csv_file.exists():
                break
        else:  # no category contained this ID
            print(f"[WARN] CSV not found for id={vid}")
            continue
        info = load_and_process(csv_file)
        info['frame_dir'] = vid
        info['label'] = label
        info['metainfo'] = {'frame_dir': vid, 'img_shape': info['img_shape']}
        annotations.append(info)
        kept_ids.append(vid)
    if not annotations:
        print(f"[WARN] no annotations written to {out_path}")
    data = {'annotations': annotations, 'split': {'xsub_val': kept_ids}}
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # protocol 4 kept for compatibility with older Python/mmaction readers
    with open(out_path, 'wb') as f:
        pickle.dump(data, f, protocol=4)
    return len(annotations)


def _collect_ids():
    """Return IDs of keypoint CSVs that have at least one matching cropped MP4.

    For each category, scans ``ROOT/<cat>/crop_keypoint/*.csv`` and keeps a CSV
    stem only if a file matching ``<stem>*_crop.mp4`` exists in that
    category's ``crop_video`` or ``video`` folder. Files are visited in sorted
    order so the seeded shuffle in ``main`` depends only on directory contents,
    not on the filesystem's listing order.
    """
    ids = []
    for cat in ['balanced_true', 'false']:
        kp_dir = ROOT / cat / 'crop_keypoint'
        if not kp_dir.exists():
            continue
        vid_dirs = [ROOT / cat / 'crop_video', ROOT / cat / 'video']
        for csv_path in sorted(kp_dir.glob('*.csv')):
            vid_id = csv_path.stem
            pattern = f"{vid_id}*_crop.mp4"
            # Path.glob() returns a lazy iterator, which is always truthy, so
            # pull one element to test whether anything actually matched.
            # NOTE(review): the '*' also matches longer IDs sharing this
            # prefix (e.g. '12' matches '123_crop.mp4') — confirm naming.
            if any(next(iter(vd.glob(pattern)), None) is not None
                   for vd in vid_dirs):
                ids.append(vid_id)
    return ids


def main(*, run_timesformer=False):
    """Split IDs, write ID lists and skeleton PKLs, then extract embeddings.

    Steps:
      1) collect IDs (see ``_collect_ids``), shuffle them with ``SEED`` and
         hold out ``TEST_RATIO`` of them for validation;
      2) write ``train_ids.txt`` / ``valid_ids.txt`` in the current directory;
      3) write the train/valid skeleton PKLs under ``ROOT/crop_pkl``;
      4) optionally run the TimeSformer extractor in ``TS_ENV``;
      5) run the ST-GCN++ extractor in ``STGCN_ENV``.

    IDs are gathered in sorted order and kept only when a matching MP4
    exists, so the split is determined by the directory contents alone.

    Args:
        run_timesformer: also run the TimeSformer extractor. Off by default
            because it takes too long.

    Raises:
        RuntimeError: if no CSV with a matching MP4 is found under ``ROOT``.
        subprocess.CalledProcessError: if an extractor exits non-zero.
    """
    # 1) Collect IDs and split
    ids = _collect_ids()
    if not ids:
        raise RuntimeError(f"No matching CSV↔MP4 under {ROOT}")
    # A private Random seeded with SEED gives the same shuffle as
    # random.seed(SEED) + random.shuffle(), without mutating global state.
    random.Random(SEED).shuffle(ids)
    idx = int(len(ids) * (1 - TEST_RATIO))
    train_ids, valid_ids = ids[:idx], ids[idx:]

    # 2) Save ID lists (relative to the current working directory)
    train_list = Path('train_ids.txt').resolve()
    valid_list = Path('valid_ids.txt').resolve()
    train_list.write_text('\n'.join(train_ids), encoding='utf-8')
    valid_list.write_text('\n'.join(valid_ids), encoding='utf-8')
    print(f"▶ {len(train_ids)} train / {len(valid_ids)} valid IDs saved")

    # 3) Build PKLs
    base_pkl = ROOT / 'crop_pkl'
    train_pkl = base_pkl / 'skeleton_dataset_train.pkl'
    valid_pkl = base_pkl / 'skeleton_dataset_valid.pkl'
    tcnt = make_pkl(train_ids, train_pkl)
    vcnt = make_pkl(valid_ids, valid_pkl)
    print(f"▶ PKL created: train={tcnt}, valid={vcnt}")

    # 4) TimeSformer embeddings — skipped by default because it takes too long.
    if run_timesformer:
        run([
            'python', '-u', 'extract_embedding_timesformer.py',
            '--root',          str(ROOT),
            '--train-list',    str(train_list),
            '--valid-list',    str(valid_list),
            '--num-frames',    '32',
            '--clips-per-vid', '5',
            '--img-size',      '224',
            '--batch-size',    '1',
            '--num-workers',   '0',
            '--pretrained',    r"D:\TimeSformer\pretrained\TimeSformer_divST_96x4_224_K600.pyth",
            # raw string: a plain 'embbeding_data\timesformer' turns '\t' into a TAB
            '--output-dir',    r'embbeding_data\timesformer',
        ], TS_ENV)

    print(str(train_pkl), str(valid_pkl))
    # 5) ST-GCN embeddings (PKL paths passed as arguments). Absolute paths are
    #    required here because the ST-GCN script rewrites paths internally.
    run([
        'python', '-u', 'D:\\Jabez\\golf\\fusion\\extract_embedding_stgcn.py',
        '--cfg',        'D:\\mmaction2\\configs\\skeleton\\stgcnpp\\my_stgcnpp.py',
        '--ckpt',       'D:\\mmaction2\\checkpoints\\stgcnpp_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d_20221228-86e1e77a.pth',
        '--device',     'cuda:0',
        '--out-dir',    'D:\\Jabez\\golf\\fusion\\embbeding_data\\stgcnpp',
        '--train-pkl',  str(train_pkl),
        '--valid-pkl',  str(valid_pkl),
        '--num-workers','0'
    ], STGCN_ENV)

    print("✅ All done.")


if __name__ == '__main__':
    main()


▶ 392 train / 44 valid IDs saved
▶ PKL created: train=392, valid=44
D:\golfDataset\dataset\train\crop_pkl\skeleton_dataset_train.pkl D:\golfDataset\dataset\train\crop_pkl\skeleton_dataset_valid.pkl
[RUN] conda run -n mmaction --no-capture-output python -u D:\Jabez\golf\fusion\extract_embedding_stgcn.py --cfg D:\mmaction2\configs\skeleton\stgcnpp\my_stgcnpp.py --ckpt D:\mmaction2\checkpoints\stgcnpp_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d_20221228-86e1e77a.pth --device cuda:0 --out-dir D:\Jabez\golf\fusion\embbeding_data\stgcnpp --train-pkl D:\golfDataset\dataset\train\crop_pkl\skeleton_dataset_train.pkl --valid-pkl D:\golfDataset\dataset\train\crop_pkl\skeleton_dataset_valid.pkl --num-workers 0
✅ All done.
