In [1]:
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from decord import VideoReader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import json
from tqdm import tqdm
import sys

# === Hard-coded run parameters (edit here) ===

# Input / output locations.
ROOT = Path(r'D:/golfDataset/dataset')  # holds <split>/<category>/crop_video/*.mp4
FUSION_DIR = Path(r'D:/Jabez/golf/fusion')
# NOTE(review): 'embbeding_data' is misspelled, but it is a real on-disk path that
# other code may read; do not rename it without migrating the folder.
PER_VIDEO_DIR = FUSION_DIR.joinpath('embbeding_data', 'timesformer', 'per_video')
PRETRAINED = Path(r'D:/timesformer/pretrained/TimeSformer_divST_96x4_224_K600.pyth')

# Sampling / model input geometry.
NUM_FRAMES = 32     # frames per clip fed to TimeSformer
CLIPS_PER_VID = 5   # clips sampled per video; their embeddings are averaged
IMG_SIZE = 224      # square crop size in pixels

# Run on the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Ensure the per-video output folder exists before extraction writes to it.
PER_VIDEO_DIR.mkdir(parents=True, exist_ok=True)

sys.path.append(r'D:/timesformer')
from timesformer.models.vit import TimeSformer

# Evaluation-time frame preprocessing: bicubic Resize(256) (an int size scales
# the shorter edge), centre crop to IMG_SIZE, convert to a float tensor, then
# normalise every channel with the same mean/std.
mean = [0.45] * 3
std = [0.225] * 3
eval_transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

def uniform_sample(length, num):
    """Return ``num`` indices spread evenly over ``range(length)``.

    When ``length >= num`` the indices are evenly spaced from 0 to
    ``length - 1`` inclusive (truncated to int). When fewer than ``num``
    positions exist, every index is used once and the last one is repeated
    to pad the result to ``num`` entries.

    Args:
        length: number of available positions (frames); must be >= 1.
        num: number of indices to return.

    Returns:
        1-D integer ndarray of ``num`` indices, each in ``[0, length - 1]``.

    Raises:
        ValueError: if ``length < 1`` (there is nothing to sample or repeat).
    """
    if length < 1:
        # np.pad(mode='edge') cannot extend an empty array and fails with an
        # opaque message; report the real problem instead.
        raise ValueError(f'uniform_sample: length must be >= 1, got {length}')
    if length >= num:
        return np.linspace(0, length - 1, num, dtype=int)
    # Short range: take every index, then repeat the last one.
    return np.pad(np.arange(length), (0, num - length), mode='edge')

def load_clip(path: Path):
    """Decode ``path`` and return ``CLIPS_PER_VID`` preprocessed clips.

    The video is cut into ``CLIPS_PER_VID`` contiguous segments of (nearly)
    equal length. ``NUM_FRAMES`` indices are drawn from each segment with
    ``uniform_sample``, and every frame goes through ``eval_transform``.

    Args:
        path: video file readable by decord's ``VideoReader``.

    Returns:
        list of ``CLIPS_PER_VID`` tensors laid out as (C, T, H, W), with
        T == NUM_FRAMES.

    Raises:
        ValueError: if the video reports zero frames.
    """
    vr = VideoReader(str(path))
    total = len(vr)
    if total == 0:
        raise ValueError(f'{path}: video has no frames')

    to_pil = transforms.ToPILImage()  # stateless; build once, not per frame
    bounds = np.linspace(0, total, CLIPS_PER_VID + 1, dtype=int)
    clips = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        if end <= start:
            # Videos with fewer frames than CLIPS_PER_VID produce empty
            # segments, which uniform_sample cannot sample from. Fall back to
            # a one-frame segment at the nearest valid frame.
            start = min(start, total - 1)
            end = start + 1
        idx = uniform_sample(end - start, NUM_FRAMES) + start
        frames = vr.get_batch(idx).asnumpy()  # (T, H, W, 3)
        clip = torch.stack([eval_transform(to_pil(f)) for f in frames], dim=1)  # (C, T, H, W)
        clips.append(clip)
    return clips

# Build the TimeSformer backbone. Weights come from PRETRAINED via the
# constructor's `pretrained_model` argument (loading details live inside
# TimeSformer -- TODO confirm how the 2-class head mismatch is handled).
model = TimeSformer(
    img_size=IMG_SIZE,
    num_frames=NUM_FRAMES,
    num_classes=2,
    attention_type='divided_space_time',
    pretrained_model=str(PRETRAINED),
)
model = model.to(DEVICE)

# Replace any classification head (`head` / `cls_head`) with Identity, both on
# the wrapper and on its inner `.model` if it has one, so the network yields
# pre-head features instead of 2-way logits.
for attr in ('head', 'cls_head'):
    owners = [model]
    if hasattr(model, 'model'):
        owners.append(model.model)
    for owner in owners:
        if hasattr(owner, attr):
            setattr(owner, attr, nn.Identity())

model.eval()  # inference mode (disables dropout etc.)

# Collect every <split>/<category>/crop_video/*.mp4 for the train and test
# splits. Category 'balanced_true' is labelled 1 and 'false' is labelled 0.
mapping = {'balanced_true': 1, 'false': 0}
all_mp4s = []

for split in ['train', 'test']:
    split_root = ROOT / split
    for cat, lbl in mapping.items():
        vd = split_root / cat / 'crop_video'
        if not vd.exists():
            continue
        # glob() order depends on the filesystem; sort for a reproducible run.
        for mp4 in sorted(vd.glob('*.mp4')):
            all_mp4s.append((mp4, lbl, cat, split))

print(f'총 {len(all_mp4s)}개 mp4 처리')

# Outputs are named by file stem only, so two videos sharing a stem (e.g. the
# same name in train and test) would map to the same .npy/.json and one would
# be silently skipped. Report such collisions up front.
stem_sources = {}
for mp4, _lbl, cat, split in all_mp4s:
    stem_sources.setdefault(mp4.stem, []).append(f'{split}/{cat}')
dup_stems = {stem: srcs for stem, srcs in stem_sources.items() if len(srcs) > 1}
if dup_stems:
    print(f'WARNING: {len(dup_stems)} video stem(s) appear more than once; '
          'only one embedding per stem will be kept:')
    for stem, srcs in sorted(dup_stems.items()):
        print(f'  {stem}: {srcs}')

# For each video, encode every sampled clip with the TimeSformer backbone, take
# the CLS feature, average over clips, and write <stem>.npy plus <stem>.json.
for mp4, lbl, cat, split in tqdm(all_mp4s, desc='Extracting', ncols=80):
    vid = mp4.stem
    out_path = PER_VIDEO_DIR / f'{vid}.npy'
    meta_path = PER_VIDEO_DIR / f'{vid}.json'

    # Resume support: a video is done only when BOTH outputs exist. Checking the
    # .npy alone left a video permanently without metadata if a previous run
    # stopped between the two writes.
    if out_path.exists() and meta_path.exists():
        try:
            prev_src = json.loads(meta_path.read_text(encoding='utf-8')).get('mp4_path')
        except (OSError, ValueError):
            prev_src = None  # unreadable metadata -> re-extract below
        if prev_src is not None:
            if prev_src != str(mp4):
                tqdm.write(f'WARNING: outputs for {vid!r} were made from {prev_src}, '
                           f'not {mp4}; skipping (possible duplicate stem)')
            continue

    clips = load_clip(mp4)
    feats = []
    with torch.no_grad():
        for clip in clips:
            out = model.model.forward_features(clip.unsqueeze(0).to(DEVICE))
            # Accept either a per-sample vector (B, D) or a token sequence
            # (B, N, D), in which case token 0 is taken as the CLS feature.
            cls = out[:, 0, :] if out.ndim == 3 else out
            feats.append(cls.squeeze(0).cpu().numpy())
    emb = np.stack(feats, 0).mean(0)

    meta = {
        'video_id': vid, 'label': lbl, 'category': cat, 'split': split,
        'mp4_path': str(mp4)
    }

    # Write to temporary files and rename into place, so an interrupted run
    # never leaves a truncated file that the resume check would accept.
    tmp_npy = out_path.with_name(out_path.name + '.tmp')
    with open(tmp_npy, 'wb') as f:  # file object: np.save adds no suffix
        np.save(f, emb)
    tmp_json = meta_path.with_name(meta_path.name + '.tmp')
    with open(tmp_json, 'w', encoding='utf-8') as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    tmp_npy.replace(out_path)
    tmp_json.replace(meta_path)


총 275개 mp4 처리


Extracting: 100%|█████████████████████████████| 275/275 [08:08<00:00,  1.78s/it]
