In [1]:
import sys
sys.path.append(r"D:\timesformer")
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from decord import VideoReader
from torchvision.transforms import functional as F, InterpolationMode
import random

# ----------------- Configuration ----------------------------
# Seed every RNG used in this script (PyTorch, NumPy, Python's `random`).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is available
IMG_SIZE = 224  # side length of the center crop taken in preprocess_tensor
NUM_FRAMES = 32  # frames sampled per clip
TRAIN_CLIPS_PER_VID = 2  # clips cut from each training video
TEST_CLIPS_PER_VID = 5  # clips cut from each test video
BATCH_SIZE = 1  # NOTE(review): not referenced in the visible code; the DataLoaders hard-code batch_size=1
EMBED_DIM = 768  # TimeSformer base; NOTE(review): not referenced in the visible code (768 is repeated as a literal)
TRAIN_N = 300  # maximum number of training videos kept by VideoDataset
TEST_N = 100  # maximum number of test videos kept by VideoDataset

# Checkpoints: the pretrained weights handed to the TimeSformer constructor and
# the fine-tuned checkpoint loaded on top of them.
PRETRAIN_PYTH = Path(r"D:\timesformer\pretrained\TimeSformer_divST_96x4_224_K600.pyth")
MODEL_PATH = Path("timesformer_finetuned.pth")
# Dataset roots; each is expected to contain <label>/crop_video/*.mp4.
TRAIN_ROOT = Path(r"D:\golfDataset\dataset\train")
TEST_ROOT = Path(r"D:\golfDataset\dataset\test")

# ----------------- Video preprocessing ----------------------------
def preprocess_tensor(img_tensor):
    """Resize, center-crop and normalize a single frame tensor.

    The frame is resized with ``F.resize(..., 256)`` using bicubic
    interpolation, center-cropped to ``IMG_SIZE`` x ``IMG_SIZE`` and then
    normalized with mean 0.45 and std 0.225 on each of three channels.

    Args:
        img_tensor: Image tensor; ``load_clip`` passes float frames laid out
            as (3, H, W) and scaled to [0, 1].

    Returns:
        The preprocessed tensor.
    """
    channel_mean = [0.45, 0.45, 0.45]
    channel_std = [0.225, 0.225, 0.225]
    resized = F.resize(
        img_tensor,
        256,
        interpolation=InterpolationMode.BICUBIC,
    )
    cropped = F.center_crop(resized, IMG_SIZE)
    return F.normalize(cropped, channel_mean, channel_std)

def uniform_sample(L, N):
    """Pick ``N`` frame indices out of a sequence of length ``L``.

    Args:
        L: Number of available frames.
        N: Number of indices to return.

    Returns:
        An integer NumPy array of length ``N``. When ``L >= N`` the indices
        are evenly spaced between ``0`` and ``L - 1`` (truncated to ints);
        otherwise every index ``0..L-1`` is used once and the last one is
        repeated to fill the remaining slots.
    """
    if L < N:
        # Not enough frames: keep them all, then repeat the final index.
        available = np.arange(L)
        missing = N - L
        return np.pad(available, (0, missing), mode='edge')
    evenly_spaced = np.linspace(0, L - 1, N)
    return evenly_spaced.astype(int)

def load_clip(path: Path, clips_per_vid):
    """Decode a video and cut it into ``clips_per_vid`` preprocessed clips.

    The video is divided into ``clips_per_vid`` consecutive segments and
    ``NUM_FRAMES`` frame indices are drawn from each segment with
    ``uniform_sample``. Every selected frame is scaled to [0, 1] and passed
    through ``preprocess_tensor``.

    Args:
        path: Location of the video file.
        clips_per_vid: Number of segments (and therefore clips) to produce.

    Returns:
        A list of ``clips_per_vid`` float tensors, each laid out as
        (3, T, H, W) with ``T == NUM_FRAMES``.

    Raises:
        ValueError: If the video contains no frames.
    """
    vr = VideoReader(str(path))
    L = len(vr)
    if L == 0:
        # Fail with a clear message instead of an obscure np.pad error later.
        raise ValueError(f"Video has no frames: {path}")
    seg_edges = np.linspace(0, L, clips_per_vid + 1, dtype=int)
    clips = []
    for s0, s1 in zip(seg_edges[:-1], seg_edges[1:]):
        # With fewer frames than segments the integer edges collapse and a
        # segment can be empty; widen it to a single valid frame so that
        # uniform_sample always has something to repeat.
        start = min(int(s0), L - 1)
        seg_len = max(int(s1) - start, 1)
        idx = uniform_sample(seg_len, NUM_FRAMES) + start
        arr = vr.get_batch(idx).asnumpy().astype(np.uint8)
        # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1]
        clip = torch.from_numpy(arr).permute(0, 3, 1, 2).float() / 255.0
        clip = torch.stack([preprocess_tensor(f) for f in clip])
        clips.append(clip.permute(1, 0, 2, 3))  # (3, T, H, W)
    return clips

class VideoDataset(Dataset):
    """Binary-labelled video dataset read from a fixed folder layout.

    Videos are collected from ``root/balanced_true/crop_video/*.mp4``
    (label 1) and ``root/false/crop_video/*.mp4`` (label 0), shuffled with
    Python's global ``random`` state and truncated to ``n_samples`` entries.

    Args:
        root: Dataset root directory.
        n_samples: Maximum number of videos to keep after shuffling.
        clips_per_vid: Number of clips ``load_clip`` cuts from each video.
    """

    def __init__(self, root: Path, n_samples=100, clips_per_vid=1):
        self.samples = []
        self.clips_per_vid = clips_per_vid
        # balanced_true -> 1, false -> 0
        for lbl, sub in zip((1, 0), ("balanced_true", "false")):
            folder = root / sub / "crop_video"
            if folder.exists():
                # glob() yields files in filesystem order, which is not
                # guaranteed to be stable; sort first so the seeded shuffle
                # below always selects the same subset.
                self.samples.extend(
                    (p, lbl) for p in sorted(folder.glob("*.mp4"))
                )
        random.shuffle(self.samples)
        self.samples = self.samples[:n_samples]
        print(f"🔎 {len(self.samples)} samples loaded from {root}")

    def __len__(self):
        """Return the number of videos kept."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(clips, label, file_name)`` for the video at ``idx``.

        ``clips`` stacks the per-video clips along a new leading dimension,
        ``label`` is a 0-d tensor and ``file_name`` is the bare file name.
        """
        path, lbl = self.samples[idx]
        clips = load_clip(path, self.clips_per_vid)
        return torch.stack(clips), torch.tensor(lbl), path.name

# ----------------- TimeSformer embedding extractor ----------------------------
from timesformer.models.vit import TimeSformer

class TimeSformerEmbed(nn.Module):
    """TimeSformer backbone with its classification head replaced by identity.

    The network is built with ``divided_space_time`` attention, initialised
    from ``pretrained_path`` by the TimeSformer constructor, and then the
    fine-tuned weights stored under the ``"model"`` key of ``model_path`` are
    loaded on top.

    Args:
        model_path: Fine-tuned checkpoint; must contain a ``"model"`` entry
            holding a state dict compatible with the constructed network.
        img_size: Input frame size passed to TimeSformer.
        num_frames: Number of frames per clip passed to TimeSformer.
        num_classes: Number of classes the checkpoint's head was trained
            with (needed so the state dict loads before the head is dropped).
        pretrained_path: Pretrained weights passed to TimeSformer as
            ``pretrained_model``.
    """

    def __init__(self, model_path, img_size, num_frames, num_classes, pretrained_path):
        super().__init__()
        self.base = TimeSformer(
            img_size=img_size,
            num_frames=num_frames,
            num_classes=num_classes,
            attention_type='divided_space_time',
            pretrained_model=str(pretrained_path)
        )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location="cpu")
        self.base.load_state_dict(ckpt["model"])
        # Remove the classification head.
        self.base.head = nn.Identity()
        self.base.cls_head = nn.Identity()
        # The wrapper may keep the actual transformer (and thus the real
        # head) in a nested `.model` attribute; replacing the head only on
        # the wrapper would then leave forward() returning logits. Replace
        # it there as well so forward() yields embeddings.
        inner = getattr(self.base, "model", None)
        if isinstance(inner, nn.Module):
            for attr in ("head", "cls_head"):
                if hasattr(inner, attr):
                    setattr(inner, attr, nn.Identity())

    def forward(self, x):  # x: (B, 3, T, H, W)
        """Run the head-less backbone on a batch of clips."""
        return self.base(x)  # (B, embed_dim)

# ----------------- Embedding extraction ----------------------------
def extract_embeddings(model, loader):
    """Compute one embedding per clip for every video yielded by ``loader``.

    Args:
        model: Module exposing ``base`` whose ``forward_features`` (or
            ``base.model.forward_features`` when that nested attribute
            exists) maps a (B, 3, T, H, W) batch to features.
        loader: Iterable of ``(clips, labels, names)`` batches where
            ``clips`` is (B, CLIPS_PER_VID, 3, T, H, W), ``labels`` is a
            tensor of length B and ``names`` is a sequence of B strings.
            Any batch size is supported.

    Returns:
        ``(feats, labels, names)``: ``feats`` is an (N, D) array, ``labels``
        an (N,) array and ``names`` a list of N strings, with
        N = videos x clips per video. Clips of the same video are adjacent
        and keep their clip order.

    Raises:
        ValueError: If ``loader`` produced no clips.
    """
    model.eval()
    # Feed inputs to whatever device the model lives on; fall back to the
    # module-level DEVICE only for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)
    # Resolve the feature extractor once instead of on every iteration.
    if hasattr(model, "base") and hasattr(model.base, "model"):
        forward_features = model.base.model.forward_features
    else:
        forward_features = model.base.forward_features

    feats, labels, names = [], [], []
    with torch.no_grad():
        for clips, lbl, vname in loader:
            # clips: (B, CLIPS_PER_VID, 3, T, H, W)
            n_clips = clips.size(1)
            if n_clips == 0:
                continue
            per_clip = []
            for i in range(n_clips):
                out = forward_features(clips[:, i].to(device))  # input: (B, 3, T, H, W)
                # Keep only the CLS token when a token sequence is returned.
                per_clip.append(out[:, 0, :] if out.ndim == 3 else out)
            # (B, n_clips, D) -> (B * n_clips, D): clips of a video stay adjacent.
            batch_emb = torch.stack(per_clip, dim=1)
            feats.append(batch_emb.flatten(0, 1).cpu().numpy())
            for label, name in zip(lbl.tolist(), vname):
                labels.extend([label] * n_clips)
                names.extend([name] * n_clips)
    if not feats:
        raise ValueError("extract_embeddings: loader produced no clips")
    feats = np.concatenate(feats, axis=0)
    labels = np.array(labels)
    # Sanity-check the embedding shape.
    print(f"[DEBUG] 임베딩 shape: {feats.shape}")  # (N, D)
    if feats.shape[1] != 768:
        print("[WARNING] 임베딩 차원이 768이 아닙니다!")
    return feats, labels, names

# ----------------- MLP classifier ----------------------------
class MLPClassifier(nn.Module):
    """Two-layer perceptron: Linear(in_dim, 256) -> ReLU -> Linear(256, num_classes).

    Args:
        in_dim: Size of each input feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, in_dim, num_classes=2):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for ``x``."""
        logits = self.mlp(x)
        return logits

# ----------------- Main pipeline ----------------------------
if __name__ == "__main__":
    # 1. Build the datasets and loaders (one video per batch, fixed order).
    train_ds = VideoDataset(
        TRAIN_ROOT, n_samples=TRAIN_N, clips_per_vid=TRAIN_CLIPS_PER_VID
    )
    test_ds = VideoDataset(
        TEST_ROOT, n_samples=TEST_N, clips_per_vid=TEST_CLIPS_PER_VID
    )
    train_ld = DataLoader(train_ds, batch_size=1, shuffle=False)
    test_ld = DataLoader(test_ds, batch_size=1, shuffle=False)

    # 2. Extract one embedding per clip with the fine-tuned backbone.
    print("⏳ TimeSformer 임베딩 추출 중...")
    embed_model = TimeSformerEmbed(
        model_path=MODEL_PATH,
        img_size=IMG_SIZE,
        num_frames=NUM_FRAMES,
        num_classes=2,
        pretrained_path=PRETRAIN_PYTH,
    )
    embed_model = embed_model.to(DEVICE)
    train_feats, train_labels, _ = extract_embeddings(embed_model, train_ld)
    test_feats, test_labels, test_names = extract_embeddings(embed_model, test_ld)
    print(f"✅ 임베딩 추출 완료: train {train_feats.shape}, test {test_feats.shape}")

    # 3. Train the MLP on the cached embeddings (full batch, one step per epoch).
    EPOCHS = 20
    mlp = MLPClassifier(in_dim=train_feats.shape[1]).to(DEVICE)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    X_train = torch.tensor(train_feats, dtype=torch.float32, device=DEVICE)
    y_train = torch.tensor(train_labels, dtype=torch.long, device=DEVICE)
    X_test = torch.tensor(test_feats, dtype=torch.float32, device=DEVICE)
    y_test = torch.tensor(test_labels, dtype=torch.long, device=DEVICE)

    for epoch in range(1, EPOCHS + 1):
        mlp.train()
        optimizer.zero_grad()
        out = mlp(X_train)
        loss = criterion(out, y_train)
        loss.backward()
        optimizer.step()
        # Accuracy is measured on the logits computed before this step.
        pred = out.argmax(dim=1)
        acc = pred.eq(y_train).float().mean().item()
        print(f"Epoch {epoch}/{EPOCHS} | Train Loss: {loss.item():.4f} | Train Acc: {acc:.3%}")

    # 4. Evaluate on the test embeddings (one prediction per clip).
    mlp.eval()
    with torch.no_grad():
        out = mlp(X_test)
        pred = out.argmax(dim=1)
        hits = pred.eq(y_test)
        acc = hits.float().mean().item()
        print(f"\n🎯 Test Accuracy: {acc:.3%} ({hits.sum().item()}/{len(y_test)})")
        # Per-clip breakdown.
        pred_list = pred.cpu().tolist()
        gt_list = y_test.cpu().tolist()
        for name, p, gt in zip(test_names, pred_list, gt_list):
            print(f"{name}: pred={p}, gt={gt}, correct={p==gt}")

🔎 300 samples loaded from D:\golfDataset\dataset\train
🔎 100 samples loaded from D:\golfDataset\dataset\test
⏳ TimeSformer 임베딩 추출 중...
[DEBUG] 임베딩 shape: (600, 768)
[DEBUG] 임베딩 shape: (500, 768)
✅ 임베딩 추출 완료: train (600, 768), test (500, 768)
Epoch 1/20 | Train Loss: 0.6402 | Train Acc: 68.833%
Epoch 2/20 | Train Loss: 0.6680 | Train Acc: 76.000%
Epoch 3/20 | Train Loss: 0.5892 | Train Acc: 76.000%
Epoch 4/20 | Train Loss: 0.4900 | Train Acc: 76.000%
Epoch 5/20 | Train Loss: 0.4639 | Train Acc: 77.833%
Epoch 6/20 | Train Loss: 0.4747 | Train Acc: 82.333%
Epoch 7/20 | Train Loss: 0.4682 | Train Acc: 83.333%
Epoch 8/20 | Train Loss: 0.4373 | Train Acc: 83.833%
Epoch 9/20 | Train Loss: 0.4051 | Train Acc: 82.167%
Epoch 10/20 | Train Loss: 0.3888 | Train Acc: 80.000%
Epoch 11/20 | Train Loss: 0.3837 | Train Acc: 79.333%
Epoch 12/20 | Train Loss: 0.3725 | Train Acc: 79.833%
Epoch 13/20 | Train Loss: 0.3491 | Train Acc: 81.667%
Epoch 14/20 | Train Loss: 0.3253 | Train Acc: 86.167%
Epoch 15/20