In [None]:
#!/usr/bin/env python3
"""
combined_extract_embedding_auto.py

csv_to_pkl.ipynb, extract_embbeding_stgcn.ipynb, extract_embbeding_timesformer.ipynb
세 노트북에 정의된 변수(경로)를 자동으로 읽어오고, Conda 환경 이름은 환경 변수
TS_ENV / STGCN_ENV(미설정 시 기본값)에서 가져와 CSV→PKL 변환 및
ST-GCN(mmAction2)/TimeSformer 임베딩 추출을 일괄 처리합니다.
"""
import nbformat
import re
import subprocess
import os
from pathlib import Path

# Conda environment names. Each one can be overridden by exporting the
# environment variable of the same name; otherwise the default is used.
TS_ENV = os.getenv('TS_ENV', 'timesformer_env')      # Conda env used for TimeSformer
STGCN_ENV = os.getenv('STGCN_ENV', 'mmaction2_env')  # Conda env used for MMAction2 (ST-GCN)


def extract_var(nb_path: Path, var_name: str):
    """Return the value assigned to ``var_name`` in a notebook's code cells.

    Every code cell of *nb_path* is scanned line by line. Each line is stripped
    first, so indented assignments match too. The first single-line assignment
    of the form ``<var_name> = <expression>`` wins, and ``<expression>`` is
    evaluated.

    The expression is evaluated in a namespace that exposes only ``Path`` and
    ``os`` (plus builtins). Literals and ``Path("...")`` values work.
    Expressions that refer to other notebook variables do not.

    SECURITY: ``eval`` executes arbitrary code. Only use this on trusted
    notebooks.

    Args:
        nb_path: Path to the ``.ipynb`` file.
        var_name: Exact variable name to look for (matched literally).

    Returns:
        The evaluated right-hand side of the first matching assignment.

    Raises:
        KeyError: If no matching assignment exists in any code cell.
        Exception: Whatever ``eval`` raises for an expression it cannot evaluate.
    """
    # Open explicitly as UTF-8: notebooks are UTF-8 JSON, and the platform
    # default encoding (e.g. cp949 on Korean Windows) may not be.
    with open(nb_path, encoding='utf-8') as fh:
        nb = nbformat.read(fh, as_version=4)

    # re.escape: treat var_name literally. (?!=) keeps `name == ...`
    # comparisons from being mistaken for assignments.
    pattern = re.compile(rf'^{re.escape(var_name)}\s*=(?!=)\s*(.+)')
    namespace = {'Path': Path, 'os': os}

    for cell in nb.cells:
        if cell.cell_type != 'code':
            continue
        for line in cell.source.splitlines():
            match = pattern.match(line.strip())
            if match:
                return eval(match.group(1), namespace)
    raise KeyError(f"{var_name} not found in {Path(nb_path).name}")


def run(cmd, env_name=None):
    """Echo and execute a command, optionally inside a Conda environment.

    Args:
        cmd: Command and its arguments as any sequence (list or tuple). Items
            may be ``str`` or path-like objects such as ``Path``. Every item is
            converted with ``str()`` so the echoed command line can always be
            built.
        env_name: If truthy, the command is prefixed with
            ``conda run -n <env_name> --no-capture-output``. It then runs in
            that environment, and conda does not capture its stdout/stderr.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    args = [str(part) for part in cmd]
    if env_name:
        args = ['conda', 'run', '-n', str(env_name), '--no-capture-output', *args]
    print("[RUN]", " ".join(args))
    subprocess.run(args, check=True)


def _collect_pair_ids(root: Path) -> list:
    """Collect IDs of keypoint CSVs under *root* that have a matching cropped video.

    For each sub-folder ``balanced_true``, ``false`` and ``true`` that has a
    ``crop_keypoint`` directory, ``<id>.csv`` is kept when some
    ``<id>*_crop.mp4`` exists in that sub-folder's ``crop_video`` or ``video``
    directory. Unmatched CSVs are reported with a warning and skipped.

    CSVs are visited in sorted order. This makes the later seeded shuffle
    reproducible, independent of filesystem enumeration order.
    """
    ids = []
    for sub in ('balanced_true', 'false', 'true'):
        kp_dir = root / sub / 'crop_keypoint'
        vid_dirs = [root / sub / 'crop_video', root / sub / 'video']
        if not kp_dir.exists():
            continue
        for csv_path in sorted(kp_dir.glob('*.csv')):
            vid_id = csv_path.stem
            has_video = any(
                next(vd.glob(f"{vid_id}*_crop.mp4"), None) is not None
                for vd in vid_dirs
            )
            if has_video:
                ids.append(vid_id)
            else:
                print(f"[WARN] 매칭되는 비디오 없음 → {vid_dirs} / {vid_id}*_crop.mp4")
    return ids


def _find_keypoint_csv(root: Path, vid_id: str):
    """Return the first existing ``<sub>/crop_keypoint/<vid_id>.csv`` under *root*, else None.

    Sub-folders are searched in the same order that IDs are collected in:
    ``balanced_true``, ``false``, ``true``.
    """
    for sub in ('balanced_true', 'false', 'true'):
        candidate = root / sub / 'crop_keypoint' / f"{vid_id}.csv"
        if candidate.exists():
            return candidate
    return None


def _load_keypoints(csv_path: Path) -> dict:
    """Load one keypoint CSV into a minimal skeleton annotation dict.

    All CSV values are reshaped into rows of three (``reshape(-1, 3)``). The
    first two columns become ``keypoint`` and the third becomes
    ``keypoint_score``, each with a leading axis of length 1 added.

    NOTE(review): this assumes the CSV holds (x, y, score) triples, and it
    flattens frames and joints into one axis. No other conversion is done,
    e.g. no ``label`` / ``total_frames`` / ``img_shape`` keys are written.
    Confirm this matches what the ST-GCN pipeline expects.
    """
    import pandas as pd  # lazy import: only needed when PKLs are built

    vals = pd.read_csv(csv_path).values.reshape(-1, 3)
    return {
        'frame_dir': csv_path.stem,
        'keypoint': vals[:, :2][None, ...],
        'keypoint_score': vals[:, 2][None, ...],
    }


def _write_skeleton_pkl(root: Path, id_list, out_path: Path) -> None:
    """Write a skeleton PKL for the IDs in *id_list* to *out_path*.

    The pickle holds ``{'annotations': [...], 'split': {'xsub_val': [...]}}``.
    IDs whose CSV cannot be found are skipped and are also left out of the
    split, so every split entry has a matching annotation.
    """
    import pickle

    annotations = []
    for vid in id_list:
        csv_file = _find_keypoint_csv(root, vid)
        if csv_file is None:
            continue
        annotations.append(_load_keypoints(csv_file))
    data = {
        'annotations': annotations,
        'split': {'xsub_val': [ann['frame_dir'] for ann in annotations]},
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(data, f, protocol=4)


def _run_split_pipeline(root: Path, seed, test_ratio: float) -> None:
    """Stage 1: split the CSV↔MP4 pairs under *root*, build PKLs, and run both extractors.

    Raises:
        RuntimeError: If no CSV↔MP4 pairs are found under *root*.
        subprocess.CalledProcessError: If an extractor script fails.
    """
    import random

    # 1) Collect IDs that have both a keypoint CSV and a cropped video, then split.
    ids = _collect_pair_ids(root)
    if not ids:
        raise RuntimeError(f"No matching CSV↔MP4 pairs under {root}")

    # A private RNG gives the same permutation as `random.seed(seed);
    # random.shuffle(ids)` without resetting the global random state.
    random.Random(seed).shuffle(ids)
    split_idx = int(len(ids) * (1 - test_ratio))
    train_ids = ids[:split_idx]
    valid_ids = ids[split_idx:]

    # 2) Save the ID lists (paths are relative to the current working directory).
    train_list = Path('train_ids.txt')
    valid_list = Path('valid_ids.txt')
    train_list.write_text('\n'.join(train_ids))
    valid_list.write_text('\n'.join(valid_ids))
    print(f"▶ Saved {len(train_ids)} train  / {len(valid_ids)} valid IDs")

    # 3) Build the PKLs. Both splits are stored under the 'xsub_val' key.
    base_pkl = root / 'crop_pkl'
    _write_skeleton_pkl(root, train_ids, base_pkl / 'skeleton_dataset_train.pkl')
    _write_skeleton_pkl(root, valid_ids, base_pkl / 'skeleton_dataset_valid.pkl')

    # 4) TimeSformer.
    # NOTE(review): stage 2 calls the same script with a different flag set
    # (--input-root/--pretrained/--output-dir). A recorded run of this exact
    # call exited with status 2, which is argparse's usage-error status.
    # Check these flags against the script's argparse definition.
    run([
        'python', 'extract_embedding_timesformer.py',
        '--root',          str(root),
        '--train-list',    str(train_list),
        '--valid-list',    str(valid_list),
        '--batch-size',    '8',
        '--num-frames',    '32',
        '--clips-per-vid', '5',
        '--img-size',      '224',
    ], TS_ENV)

    # 5) ST-GCN.
    # NOTE(review): stage 2 calls this script with --input/--output instead;
    # confirm which flag set the script accepts.
    run([
        'python', 'extract_embedding_stgcn.py',
        '--csv-root',   str(root),
        '--train-list', str(train_list),
        '--valid-list', str(valid_list),
        '--cfg',        'configs/skeleton/stgcnpp/my_stgcnpp.py',
        '--ckpt',       'checkpoints/stgcnpp_best.pth',
    ], STGCN_ENV)

    print("✅ All done.")


def _script_dir() -> Path:
    """Return this script's directory, or the CWD when ``__file__`` is undefined (e.g. in a notebook)."""
    try:
        return Path(__file__).resolve().parent
    except NameError:
        return Path.cwd()


def _run_notebook_pipeline(here: Path) -> None:
    """Stage 2: read paths from the three notebooks in *here*, then run CSV→PKL, ST-GCN and TimeSformer.

    Raises:
        KeyError: If a required variable is missing from a notebook.
        subprocess.CalledProcessError: If any external script fails.
    """
    # 1) Notebook paths.
    nb_csv = here / 'csv_to_pkl.ipynb'
    nb_stgcn = here / 'extract_embbeding_stgcn.ipynb'
    nb_ts = here / 'extract_embbeding_timesformer.ipynb'

    # 2) Extract variables. Notebook values may be plain strings or Path
    #    objects, so normalise them before using Path methods or building argv.
    train_csv_root = Path(extract_var(nb_csv, 'base'))
    valid_csv_root = Path(extract_var(nb_csv, 'BASE'))
    default_out_pkl = Path(extract_var(nb_csv, 'out_pkl'))

    stgcn_cfg = str(extract_var(nb_stgcn, 'CFG_PATH'))
    stgcn_ckpt = str(extract_var(nb_stgcn, 'CKPT_PATH'))

    ts_root = Path(extract_var(nb_ts, 'ROOT'))
    try:
        ts_pretrained = extract_var(nb_ts, 'PRE_PTH')
    except KeyError:
        ts_pretrained = extract_var(nb_ts, 'PRETRAIN_PYTH')

    # 3) ID lists, one per CSV folder. When the CWD equals `here`, these
    #    overwrite the lists written by stage 1.
    train_ids_file = here / 'train_ids.txt'
    valid_ids_file = here / 'valid_ids.txt'
    train_ids_file.write_text('\n'.join(p.stem for p in train_csv_root.glob('*.csv')))
    valid_ids_file.write_text('\n'.join(p.stem for p in valid_csv_root.glob('*.csv')))

    # 4) PKL output paths, derived from the notebook's out_pkl.
    out_folder = default_out_pkl.parent
    train_pkl = out_folder / (default_out_pkl.stem + '_train.pkl')
    valid_pkl = out_folder / (default_out_pkl.stem + '_valid.pkl')
    out_folder.mkdir(parents=True, exist_ok=True)

    # 5) CSV→PKL conversion (in the ST-GCN environment).
    for split_name, csv_root, ids_file, out_pkl in [
        ('xsub_train', train_csv_root, train_ids_file, train_pkl),
        ('xsub_val',   valid_csv_root, valid_ids_file, valid_pkl),
    ]:
        run([
            'python', 'csv_to_pkl.py',
            '--csv-root', str(csv_root),
            '--ids-list', str(ids_file),
            '--output-file', str(out_pkl),
            '--split-name', split_name,
        ], env_name=STGCN_ENV)

    # 6) ST-GCN embedding extraction.
    for in_pkl, out_npy in [
        (train_pkl, here / 'train_stgcn_embeddings.npy'),
        (valid_pkl, here / 'valid_stgcn_embeddings.npy'),
    ]:
        run([
            'python', 'extract_embedding_stgcn.py',
            '--input', str(in_pkl),
            '--output', str(out_npy),
            '--cfg', stgcn_cfg,
            '--ckpt', stgcn_ckpt,
        ], env_name=STGCN_ENV)

    # 7) TimeSformer embedding extraction.
    # NOTE(review): the validation video root is assumed to be a sibling of
    # ts_root named like the validation CSV folder. Confirm the folder layout.
    for root_dir, out_dir in [
        (ts_root, here / 'train_timesformer_embeddings'),
        (ts_root.parent / valid_csv_root.name, here / 'valid_timesformer_embeddings'),
    ]:
        out_dir.mkdir(exist_ok=True)
        run([
            'python', 'extract_embedding_timesformer.py',
            '--input-root', str(root_dir),
            '--pretrained', str(ts_pretrained),
            '--output-dir', str(out_dir),
        ], env_name=TS_ENV)

    print("✅ All embeddings extracted successfully.")


def main(root=None, *, seed=None, test_ratio=None):
    """Run the embedding-extraction pipeline in two consecutive stages.

    Stage 1 splits the CSV↔MP4 pairs under *root* into train/valid ID lists,
    writes skeleton PKLs and runs both extractor scripts. Stage 2 reads paths
    from the three notebooks next to this script, or from the CWD in a
    notebook, and runs csv_to_pkl.py plus both extractors again with the
    notebook-derived paths.

    Args:
        root: Dataset root for stage 1. ``None`` means the module-level ``ROOT``.
        seed: Shuffle seed for the split. ``None`` means the module-level ``SEED``.
        test_ratio: Fraction of IDs for validation. ``None`` means the
            module-level ``TEST_RATIO``.

    NOTE(review): ``ROOT``, ``SEED`` and ``TEST_RATIO`` are not defined in this
    file. Presumably they come from an earlier notebook cell; otherwise pass
    them explicitly, or a NameError is raised.

    Raises:
        RuntimeError: If no CSV↔MP4 pairs are found under *root*.
        KeyError: If a required notebook variable is missing.
        subprocess.CalledProcessError: If any external script fails.
    """
    root = Path(ROOT if root is None else root)
    seed = SEED if seed is None else seed
    test_ratio = TEST_RATIO if test_ratio is None else test_ratio

    _run_split_pipeline(root, seed, test_ratio)
    _run_notebook_pipeline(_script_dir())

if __name__ == "__main__":
    # Entry point. Inside a Jupyter/IPython kernel __name__ is also "__main__",
    # so executing this notebook cell calls main() as well.
    main()


▶ Saved 392 train  / 44 valid IDs
[RUN] conda run -n timesformer --no-capture-output python extract_embedding_timesformer.py --root D:\golfDataset\dataset\train --train-list train_ids.txt --valid-list valid_ids.txt --batch-size 8 --num-frames 32 --clips-per-vid 5 --img-size 224


CalledProcessError: Command '['conda', 'run', '-n', 'timesformer', '--no-capture-output', 'python', 'extract_embedding_timesformer.py', '--root', 'D:\\golfDataset\\dataset\\train', '--train-list', 'train_ids.txt', '--valid-list', 'valid_ids.txt', '--batch-size', '8', '--num-frames', '32', '--clips-per-vid', '5', '--img-size', '224']' returned non-zero exit status 2.