In [None]:
import sys
sys.path.append(r"D:\timesformer")  # timesformer 경로를 python 모듈 경로에 추가

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from pathlib import Path
import numpy as np
import random
from decord import VideoReader
from tqdm import tqdm
from timesformer.models.vit import TimeSformer

from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

def preprocess_tensor(img_tensor):  # img_tensor: (3, H, W)
    """Resize, center-crop and normalize a single frame tensor.

    The frame is resized to 256 with bicubic interpolation, center-cropped
    to IMG_SIZE, and normalized with mean 0.45 and std 0.225 on each of the
    three channels.

    Args:
        img_tensor: one frame; the caller passes float tensors laid out as
            (3, H, W).

    Returns:
        The resized, cropped and normalized frame tensor.
    """
    channel_mean = [0.45] * 3
    channel_std = [0.225] * 3
    resized = F.resize(img_tensor, 256, interpolation=InterpolationMode.BICUBIC)
    cropped = F.center_crop(resized, IMG_SIZE)
    return F.normalize(cropped, channel_mean, channel_std)



# ----------------- Hyperparameters ----------------------------
ROOT          = Path(r"D:\golfDataset\스포츠 사람 동작 영상(골프)\Training\Public\male\train")
PRETRAIN_PYTH = Path(r"D:\timesformer\pretrained\TimeSformer_divST_96x4_224_K600.pyth")
NUM_FRAMES    = 32      # frames sampled per clip (also passed to the model as num_frames)
CLIPS_PER_VID = 5       # number of segments/clips cut from each video
IMG_SIZE      = 224     # center-crop size and model input size
BATCH         = 4       # NOTE(review): not referenced in the visible code; the loaders use batch_size=1
LR            = 1e-3    # learning rate handed to AdamW
EPOCHS        = 10      # upper bound of the epoch loop
SEED          = 42      # seed shared by torch, random and numpy
TEST_RATIO    = 0.1     # fraction of the samples held out as the test set

# ----------------- Reproducibility ----------------------------
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
device = "cuda"  # always use the GPU; if it is unavailable, let the run fail with an error

# ----------------- Preprocessing ----------------------------
# Transform pipeline: resize to 256 with bicubic interpolation, center-crop to
# IMG_SIZE, convert to a tensor, then normalize each channel with mean 0.45
# and std 0.225.
# NOTE(review): base_tf is not referenced anywhere else in the visible code;
# load_clip goes through preprocess_tensor instead — confirm whether it is still needed.
base_tf = T.Compose([
    T.Resize(256, interpolation=InterpolationMode.BICUBIC),
    T.CenterCrop(IMG_SIZE),
    T.ToTensor(),
    T.Normalize([0.45]*3, [0.225]*3),
])

def uniform_sample(L, N):
    """Return N frame indices covering a segment that is L frames long.

    When the segment has at least N frames, the indices are spread evenly
    from 0 to L-1 (end points included, fractional positions truncated).
    When it is shorter, every index 0..L-1 is used once and the last index
    is repeated until N indices have been produced.

    Args:
        L: number of frames available in the segment.
        N: number of indices to return.

    Returns:
        A 1-D integer numpy array of length N.

    Raises:
        ValueError: if the segment is empty (L < 1) while N indices are
            requested; there is no frame that could be repeated.
    """
    if L >= N:
        return np.linspace(0, L - 1, N).astype(int)
    if L < 1:
        # Padding with mode='edge' needs at least one element to repeat;
        # fail with a message that names the real problem.
        raise ValueError(
            f"uniform_sample cannot draw {N} indices from an empty segment (L={L})"
        )
    return np.pad(np.arange(L), (0, N - L), mode='edge')

def load_clip(path: Path):
    """Decode a video and cut it into CLIPS_PER_VID preprocessed clips.

    The video is divided into CLIPS_PER_VID contiguous segments. From each
    segment NUM_FRAMES frame indices are chosen with uniform_sample, the
    frames are decoded, scaled to [0, 1] and passed one by one through
    preprocess_tensor.

    Args:
        path: location of the video file.

    Returns:
        A list with one float tensor per segment, each laid out as
        (3, T, H, W) with T == NUM_FRAMES.

    Raises:
        ValueError: if the video contains no frames.
    """
    vr = VideoReader(str(path))
    L = len(vr)
    if L == 0:
        raise ValueError(f"video has no frames: {path}")
    seg_edges = np.linspace(0, L, CLIPS_PER_VID + 1, dtype=int)
    clips = []
    for s0, s1 in zip(seg_edges[:-1], seg_edges[1:]):
        # A video with fewer than CLIPS_PER_VID frames yields empty segments.
        # Treat such a segment as the single frame at s0 (s0 < L always holds
        # for a segment start) so that frame is repeated instead of crashing.
        seg_len = max(int(s1 - s0), 1)
        idx = uniform_sample(seg_len, NUM_FRAMES) + s0
        # copy=False skips a redundant copy when the decoder already returns uint8
        arr = vr.get_batch(idx).asnumpy().astype(np.uint8, copy=False)    # (T, H, W, 3)
        clip = torch.from_numpy(arr).permute(0, 3, 1, 2).float() / 255.0  # (T, 3, H, W)
        clip = torch.stack([preprocess_tensor(f) for f in clip])         # (T, 3, H, W)
        clips.append(clip.permute(1, 0, 2, 3))                            # (3, T, H, W)
    return clips  # list of (3, T, H, W)


# ----------------- Dataset ----------------------------
class SwingDataset(Dataset):
    """Dataset of swing videos labelled by the folder they are stored in.

    Videos are collected from ``<root>/balanced_true/crop_video/*.mp4``
    (label 0) and ``<root>/false/crop_video/*.mp4`` (label 1).
    """

    def __init__(self, root: Path):
        """Index every ``*.mp4`` file under the two class folders of *root*.

        Args:
            root: directory containing the ``balanced_true`` and ``false``
                sub-folders.
        """
        self.samples = []
        for lbl, sub in enumerate(("balanced_true", "false")):
            # glob() yields files in a filesystem-dependent order; sorting
            # keeps the sample indices (and any seeded split built on top
            # of them) stable across runs and machines.
            for p in sorted((root/sub/"crop_video").glob("*.mp4")):
                self.samples.append((p, lbl))
        print(f"\u2705 {len(self.samples)} samples found in {root}")

    def __len__(self):
        """Return the number of indexed videos."""
        return len(self.samples)

    def __getitem__(self, i):
        """Load video *i* and return ``(clips, label)``.

        ``clips`` stacks the list returned by load_clip along a new first
        dimension; ``label`` is the class index wrapped in a tensor.
        """
        path, lbl = self.samples[i]
        clips = load_clip(path)
        return torch.stack(clips), torch.tensor(lbl)

# ----------------- Data loaders ----------------------------
ds_full = SwingDataset(ROOT)
# Hold out TEST_RATIO of the samples for testing; the rest is used for training.
n_test = int(len(ds_full)*TEST_RATIO)
n_train = len(ds_full) - n_test
# No generator is passed, so the split draws from torch's global RNG (seeded above).
train_ds, test_ds = random_split(ds_full, [n_train, n_test])
# batch_size=1: each batch is a single video, which itself carries CLIPS_PER_VID clips.
train_ld = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
test_ld  = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
# num_workers is tied to the number of CPU cores; using 4 caused an error, so it is set to 0 for now

# ----------------- Checkpoint settings ----------------------------
checkpoint_path = Path("checkpoint.pth")  # where the per-epoch checkpoint is saved

start_epoch = 0  # default; overwritten when a checkpoint is loaded

# ----------------- Model ----------------------------
# Two-class TimeSformer initialised from the pretrained weights file and moved to the GPU.
model = TimeSformer(img_size=IMG_SIZE, num_frames=NUM_FRAMES,
                    num_classes=2, attention_type='divided_space_time',
                    pretrained_model=str(PRETRAIN_PYTH)).to(device)

# Train only the classification layer
def get_trainable_params(model):
    """Freeze everything except the classification head and return the head parameters.

    A parameter counts as part of the head when its name contains 'head' or
    'cls_head'. Those parameters get ``requires_grad = True``; all others
    get ``requires_grad = False``.

    Args:
        model: module whose ``named_parameters()`` are inspected and updated
            in place.

    Returns:
        List of the head parameters, in ``named_parameters()`` order.
    """
    head_markers = ('head', 'cls_head')
    selected = []
    for param_name, tensor in model.named_parameters():
        is_head = any(marker in param_name for marker in head_markers)
        tensor.requires_grad = is_head
        if is_head:
            selected.append(tensor)
    return selected

# Optimize only the parameters left trainable by get_trainable_params.
opt = optim.AdamW(get_trainable_params(model), lr=LR, weight_decay=0.02)
crit = nn.CrossEntropyLoss()

# Load the checkpoint if one exists
if checkpoint_path.exists():
    print(f"🔁 Loading checkpoint from {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["opt"])
    # The checkpoint stores the last finished epoch, so continue with the next one.
    start_epoch = ckpt["epoch"] + 1
    print(f"✅ Resuming from epoch {start_epoch}")

# ----------------- Training ----------------------------
for ep in range(start_epoch, EPOCHS):
    model.train()
    tot = correct = 0
    for clips, lab in tqdm(train_ld, desc=f"Epoch {ep}", ncols=70):
        # The loader batch size is 1, so dropping that dimension turns one
        # video's clips into the batch that is fed to the model.
        vids = clips.squeeze(0).to(device)  # (5,3,T,H,W)
        # Every clip inherits the label of the video it was cut from.
        labs = lab.repeat(CLIPS_PER_VID).to(device)
        outs = model(vids)  # (5,2)
        loss = crit(outs, labs)
        opt.zero_grad(); loss.backward(); opt.step()
        # Accuracy is counted per clip, not per video.
        tot += labs.size(0); correct += (outs.argmax(1) == labs).sum().item()
    print(f"  train acc: {correct/tot:.3%}")

    # Save a checkpoint at the end of every epoch (overwrites the previous one)
    torch.save({
        "epoch": ep,
        "model": model.state_dict(),
        "opt": opt.state_dict(),
    }, checkpoint_path)
    print(f"💾 Checkpoint saved to {checkpoint_path}")

# ----------------- Evaluation ----------------------------
# Video-level accuracy on the held-out split: one prediction per video.
model.eval(); tot = correct = 0
with torch.no_grad():
    for clips, lab in tqdm(test_ld, desc="[Test]", ncols=70):
        vids = clips.squeeze(0).to(device)
        probs = model(vids).softmax(1).mean(0, keepdim=True)  # average ensemble over the video's clips
        pred = probs.argmax(1)
        tot += 1; correct += (pred.cpu() == lab).item()
print(f"\n✅ Test Video Accuracy : {correct/tot:.3%}")

# Save the final trained model (same layout as the per-epoch checkpoint)
final_model_path = Path("timesformer_finetuned.pth")
torch.save({
    "epoch": EPOCHS - 1,
    "model": model.state_dict(),
    "opt": opt.state_dict(),
}, final_model_path)
print(f"\n✅ Final model saved to {final_model_path}")




✅ 436 samples found in D:\golfDataset\스포츠 사람 동작 영상(골프)\Training\Public\male\train
🔁 Loading checkpoint from checkpoint.pth
✅ Resuming from epoch 5


Epoch 5: 100%|██████████████████████| 393/393 [10:46<00:00,  1.64s/it]


  train acc: 90.941%
💾 Checkpoint saved to checkpoint.pth


Epoch 6: 100%|██████████████████████| 393/393 [10:45<00:00,  1.64s/it]


  train acc: 89.567%
💾 Checkpoint saved to checkpoint.pth


Epoch 7: 100%|██████████████████████| 393/393 [10:45<00:00,  1.64s/it]


  train acc: 92.265%
💾 Checkpoint saved to checkpoint.pth


Epoch 8: 100%|██████████████████████| 393/393 [10:46<00:00,  1.64s/it]


  train acc: 92.163%
💾 Checkpoint saved to checkpoint.pth


Epoch 9: 100%|██████████████████████| 393/393 [10:48<00:00,  1.65s/it]


  train acc: 92.774%
💾 Checkpoint saved to checkpoint.pth


[Test]: 100%|█████████████████████████| 43/43 [01:09<00:00,  1.62s/it]


✅ Test Video Accuracy : 74.419%



