# CLS 토큰 only, fine-tuned 모델 사용

In [None]:
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from decord import VideoReader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import json
from tqdm import tqdm
import sys

# === Run parameters (set directly in the script) ===
ROOT = Path(r'D:/golfDataset/dataset')
FUSION_DIR = Path(r'D:/Jabez/golf/fusion')
MODEL_PATH = Path(r'D:/Jabez/golf/Timesformer_finetune/timesformer_finetuned.pth')
PRETRAINED = Path(r'D:/timesformer/pretrained/TimeSformer_divST_96x4_224_K600.pyth')

# One embedding file per video is written here; create the folder up front.
PER_VIDEO_DIR = FUSION_DIR.joinpath('embedding_data', 'timesformer', 'per_video')
PER_VIDEO_DIR.mkdir(parents=True, exist_ok=True)

# Number of frames sampled per clip; kept at 96 to match the pretrained model.
NUM_FRAMES = 96
# A swing usually lasts 2+ seconds (the JPGs were joined at 30 fps, so about
# 120 frames are needed to cover it).
CLIPS_PER_VID = 2
IMG_SIZE = 224

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Extend the import search path before importing from the `timesformer`
# package on the next line; the order of these two statements matters.
# NOTE(review): D:/timesformer is presumably a local checkout of the
# TimeSformer repository - confirm.
sys.path.append(r'D:/timesformer')
from timesformer.models.vit import TimeSformer

# Normalise with the ImageNet channel mean/std (the officially recommended values).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Per-frame evaluation preprocessing: resize, centre-crop to IMG_SIZE,
# convert to a (C, H, W) tensor in [0, 1], then normalise.
eval_transform = transforms.Compose(
    [
        transforms.Resize(size=256, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(size=IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

def uniform_sample(length, num):
    """Pick `num` indices from a sequence of `length` items.

    When the sequence has at least `num` items the indices are evenly spaced
    from 0 to length-1 (truncated to int). When it is shorter, every index
    0..length-1 is returned once and the last index is repeated to fill up
    to `num` entries.
    """
    shortfall = num - length
    if shortfall > 0:
        # Too few items: take them all, then repeat the final index.
        available = np.arange(length)
        return np.pad(available, (0, shortfall), mode='edge')
    return np.linspace(0, length - 1, num, dtype=int)

def load_clip(path: Path):
    """Decode a video and return CLIPS_PER_VID preprocessed clips.

    The frame range is cut into CLIPS_PER_VID consecutive segments. From each
    segment NUM_FRAMES frame indices are drawn with uniform_sample, and every
    selected frame is passed through eval_transform.

    Args:
        path: Location of the video file to read.

    Returns:
        A list of CLIPS_PER_VID tensors. Each one is the per-frame tensors
        stacked along dim 1, i.e. laid out as (C, T, H, W).

    Raises:
        ValueError: If the video contains no frames.
    """
    vr = VideoReader(str(path))
    total = len(vr)
    if total == 0:
        # Fail with a clear message instead of an obscure padding error.
        raise ValueError(f'video has no frames: {path}')

    bounds = np.linspace(0, total, CLIPS_PER_VID + 1, dtype=int)
    to_pil = transforms.ToPILImage()  # built once, reused for every frame
    clips = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        # A video with fewer frames than CLIPS_PER_VID yields empty segments.
        # Fall back to a one-frame segment so sampling still works; regular
        # videos are unaffected because their segments are never empty.
        start = min(int(start), total - 1)
        span = max(int(stop) - start, 1)
        idx = uniform_sample(span, NUM_FRAMES) + start
        arr = vr.get_batch(idx).asnumpy()  # (T, H, W, 3)
        proc = [eval_transform(to_pil(frame)) for frame in arr]
        clips.append(torch.stack(proc, dim=1))  # (C, T, H, W)
    return clips

# Collect every <split>/<category>/crop_video/*.mp4 under ROOT for the
# train and test splits, together with its label, category and split.
mapping = {'balanced_true': 1, 'false': 0}
all_mp4s = []

for split in ['train', 'test']:
    split_root = ROOT / split
    for cat, lbl in mapping.items():
        vd = split_root / cat / 'crop_video'
        if vd.exists():
            # Category folders that are missing are skipped.
            all_mp4s.extend((mp4, lbl, cat, split) for mp4 in vd.glob('*.mp4'))

print(f'총 {len(all_mp4s)}개 mp4 처리')

# === 파인튜닝된 모델로 임베딩 추출 (timesformer_finetuned.pth) ===

class TimeSformerEmbed(nn.Module):
    """TimeSformer loaded with fine-tuned weights and stripped of its classifier.

    Args:
        model_path: Checkpoint file whose "model" entry holds the fine-tuned
            state dict.
        img_size: Input frame size passed to TimeSformer.
        num_frames: Number of frames per clip passed to TimeSformer.
        num_classes: Number of classes the checkpoint was fine-tuned with;
            needed so the state dict loads before the head is removed.
        pretrained_path: Pretrained weights file passed to TimeSformer as
            `pretrained_model`.
    """

    def __init__(self, model_path, img_size, num_frames, num_classes, pretrained_path):
        super().__init__()
        self.base = TimeSformer(
            img_size=img_size,
            num_frames=num_frames,
            num_classes=num_classes,
            attention_type='divided_space_time',
            pretrained_model=str(pretrained_path)
        )
        # Load on CPU first; the caller decides which device to move to.
        ckpt = torch.load(model_path, map_location="cpu")
        self.base.load_state_dict(ckpt["model"])

        # Remove the classification head(s).
        self.base.head = nn.Identity()
        self.base.cls_head = nn.Identity()
        # The wrapper delegates to an inner `model`, and that is where the
        # real classifier lives. Replacing only the wrapper-level attributes
        # above leaves it in place, so forward() would still return logits.
        inner = getattr(self.base, "model", None)
        if isinstance(inner, nn.Module) and hasattr(inner, "head"):
            inner.head = nn.Identity()

    def forward(self, x):  # x: (B, 3, T, H, W)
        """Return the backbone output for a batch of clips, head removed."""
        return self.base(x)  # (B, embed_dim)

# Build the extractor from the fine-tuned checkpoint (replaces the previous
# model), then move it to DEVICE and switch it to evaluation mode.
embed_model = TimeSformerEmbed(
    pretrained_path=PRETRAINED,
    num_classes=2,
    num_frames=NUM_FRAMES,
    img_size=IMG_SIZE,
    model_path=MODEL_PATH,
)
embed_model = embed_model.to(DEVICE).eval()

# Extract one embedding per video and save it next to a JSON metadata file.
# Existing outputs are always regenerated (there is no skip-if-exists check).
written_by = {}  # video stem -> source path that last produced <stem>.npy
for mp4, lbl, cat, split in tqdm(all_mp4s, desc='Extracting', ncols=80):
    vid = mp4.stem
    out_path = PER_VIDEO_DIR / f'{vid}.npy'
    meta_path = PER_VIDEO_DIR / f'{vid}.json'

    # Outputs are keyed by file stem only, so two videos with the same stem
    # in different split/category folders overwrite each other. Keep the
    # original overwrite behaviour but make it visible.
    if vid in written_by:
        tqdm.write(
            f'[warn] duplicate video id {vid!r}: {mp4} overwrites the output of {written_by[vid]}'
        )
    written_by[vid] = mp4

    clips = load_clip(mp4)
    feats = []
    with torch.no_grad():
        for clip in clips:
            c = clip.unsqueeze(0).to(DEVICE)  # add the batch dimension
            out = embed_model.base.model.forward_features(c)
            # A 3-D output is indexed at position 0 of dim 1 (the CLS token);
            # a 2-D output is used as is.
            cls = out[:, 0, :] if out.ndim == 3 else out
            feats.append(cls.squeeze(0).cpu().numpy())

    # The video embedding is the mean of its clip embeddings.
    emb = np.stack(feats, 0).mean(0)
    np.save(out_path, emb)
    meta = {
        'video_id': vid, 'label': lbl, 'category': cat, 'split': split,
        'mp4_path': str(mp4)
    }
    with open(meta_path, 'w', encoding='utf-8') as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)


총 2072개 mp4 처리


Extracting: 100%|███████████████████████████| 2072/2072 [40:35<00:00,  1.18s/it]
