# tf비율 맞춰서 학습

In [1]:
import os
import json
from PIL import Image
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from collections import defaultdict

In [2]:
class GolfSwingSequenceDataset(Dataset):
    """Sliding-window frame sequences of golf swings with a true/false label.

    Frames are discovered by listing ``<root_dir>/tf/true/json`` and
    ``<root_dir>/tf/false/json``. The swing ID is the part of a file name
    before the first underscore (e.g. ``001_0001.json`` -> ``001``). The image
    of a frame is looked up in the sibling ``jpg`` directory under the same
    base name with a ``.jpg`` extension.

    Both classes contribute the same number of *swings* (the smaller class
    count, capped by ``max_swings``). The number of windows per class can
    still differ, because swings may have different frame counts.

    Args:
        root_dir: Directory that contains the ``tf`` folder.
        transform: Optional callable applied to every RGB ``PIL.Image``.
            It has to return tensors of equal shape, because the frames of a
            sequence are combined with ``torch.stack``.
        sequence_length: Number of consecutive frames per sample (>= 1).
        max_swings: Upper bound on the number of swings used per class.

    Attributes:
        samples: List of ``{"label": int, "files": [(json_path, jpg_path), ...]}``
            entries; all true (label 1) samples come before false (label 0).

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, root_dir, transform=None, sequence_length=5, max_swings=20000):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

        self.transform = transform
        self.sequence_length = sequence_length
        self.samples = []

        # label -> (json directory, jpg directory); true is processed first.
        class_dirs = {
            1: (os.path.join(root_dir, "tf", "true", "json"),
                os.path.join(root_dir, "tf", "true", "jpg")),
            0: (os.path.join(root_dir, "tf", "false", "json"),
                os.path.join(root_dir, "tf", "false", "jpg")),
        }
        swings_by_label = {
            label: self._group_by_swing(json_dir)
            for label, (json_dir, _) in class_dirs.items()
        }

        # Use the same number of swings from each class.
        swing_limit = max(
            0, min(len(swings_by_label[1]), len(swings_by_label[0]), max_swings)
        )

        for label, (json_dir, jpg_dir) in class_dirs.items():
            swings = swings_by_label[label]
            # Sorted IDs make the selected subset independent of the
            # arbitrary order in which os.listdir returns entries.
            for swing_id in sorted(swings)[:swing_limit]:
                self.samples.extend(
                    self._windows(swings[swing_id], json_dir, jpg_dir, label)
                )

    @staticmethod
    def _group_by_swing(json_dir):
        """Return ``{swing_id: sorted .json file names}`` for ``json_dir``."""
        swings = defaultdict(list)
        for fname in os.listdir(json_dir):
            if fname.endswith(".json"):
                swings[fname.split("_")[0]].append(fname)
        return {swing_id: sorted(names) for swing_id, names in swings.items()}

    def _windows(self, frame_names, json_dir, jpg_dir, label):
        """Build every stride-1 window of ``sequence_length`` frames of one swing."""
        length = self.sequence_length
        paths = [
            (
                os.path.join(json_dir, name),
                # Swap only the extension, not every ".json" inside the name.
                os.path.join(jpg_dir, os.path.splitext(name)[0] + ".jpg"),
            )
            for name in frame_names
        ]
        # Swings shorter than the window give an empty range and no samples.
        return [
            {"label": label, "files": paths[start:start + length]}
            for start in range(len(paths) - length + 1)
        ]

    def __len__(self):
        """Return the number of windows (samples)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one window.

        Returns:
            ``(sequence, label)`` where ``sequence`` is the stacked frames in
            temporal order and ``label`` is a float32 scalar tensor (1.0 for
            true, 0.0 for false).
        """
        sample = self.samples[idx]
        frames = []
        for _json_path, img_path in sample["files"]:
            # The context manager closes the file even if decoding fails.
            with Image.open(img_path) as img:
                frame = img.convert("RGB")
            if self.transform is not None:
                frame = self.transform(frame)
            frames.append(frame)
        sequence = torch.stack(frames)
        label = torch.tensor(sample["label"], dtype=torch.float32)
        return sequence, label
    
# ✅ Image preprocessing
transform = transforms.Compose([
    # Every frame is resized to 224x224, the size CNN_GRU_Classifier assumes.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# ✅ The root directory is the parent of the "tf" folder.
root_path = "D:/golfDataset/스포츠 사람 동작 영상(골프)/Training/Public/male"

# ✅ Each sample is a window of 5 consecutive frames.
dataset = GolfSwingSequenceDataset(root_dir=root_path, transform=transform, sequence_length=5)

# ✅ Dataset split (train 80%, val 20%). The generator is seeded so that the
# split is identical on every run.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
# NOTE(review): samples are overlapping stride-1 windows of the same swing, so
# a per-window random split can put near-identical windows of one swing into
# both train and val. Consider splitting by swing ID before trusting val
# metrics.

# ✅ DataLoaders: only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

import torch.nn as nn

class CNN_GRU_Classifier(nn.Module):
    """Per-frame CNN encoder followed by a GRU and a single-logit head.

    Every frame of a sequence goes through two conv/ReLU/max-pool stages and
    is flattened; the resulting feature sequence is fed to a GRU, and the GRU
    output of the last time step is mapped to one value by a linear layer.

    Args:
        hidden_size: GRU hidden state size.
        num_layers: Number of stacked GRU layers.
        input_size: ``(height, width)`` of the input frames. Defaults to
            ``(224, 224)``, which gives the previous fixed feature size of
            ``32 * 56 * 56``.

    Raises:
        ValueError: If a side of ``input_size`` is smaller than 4 (the two
            2x2 poolings would reduce it to zero).
    """

    def __init__(self, hidden_size=128, num_layers=1, input_size=(224, 224)):
        super().__init__()
        height, width = input_size
        if height < 4 or width < 4:
            raise ValueError(f"input_size must be at least 4x4, got {input_size}")

        # CNN encoder
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        # The 3x3 convolutions with padding=1 keep H and W; each MaxPool2d(2)
        # halves them (rounding down), so the spatial size shrinks by 4.
        self.flattened_size = 32 * (height // 4) * (width // 4)
        self.gru = nn.GRU(
            input_size=self.flattened_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B, 1)`` holding raw outputs of the linear head
            (no sigmoid is applied here).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # time-major batch, are accepted as well.
        frames = x.reshape(B * T, C, H, W)
        features = self.cnn(frames).reshape(B, T, -1)  # back to (B, T, F) for the GRU
        out, _ = self.gru(features)
        out = out[:, -1, :]  # output of the last time step
        return self.fc(out)


In [6]:
# Number of passes over the training set; used by the loop and the log line.
NUM_EPOCHS = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_GRU_Classifier().to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the model's raw outputs are
# passed to it directly.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Fail before training starts instead of with a ZeroDivisionError when the
# average loss is computed.
if len(train_loader) == 0:
    raise RuntimeError("train_loader is empty; check the dataset path and sequence_length")

# NOTE(review): val_loader is created but never evaluated here, so only the
# training loss is reported.
for epoch in range(NUM_EPOCHS):
    print(f"🌀 Epoch {epoch+1}/{NUM_EPOCHS}")
    model.train()
    train_loss = 0.0

    for x_batch, y_batch in tqdm(train_loader, desc="Training"):
        # Labels arrive as shape (B,); add a trailing dim to match the (B, 1)
        # model output.
        x_batch, y_batch = x_batch.to(device), y_batch.to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = train_loss / len(train_loader)
    print(f"🟢 Train Loss: {avg_loss:.4f}")


🌀 Epoch 1/5


Training: 100%|██████████| 30675/30675 [5:29:58<00:00,  1.55it/s]  


🟢 Train Loss: 0.0291
🌀 Epoch 2/5


Training: 100%|██████████| 30675/30675 [5:31:12<00:00,  1.54it/s]  


🟢 Train Loss: 0.0005
🌀 Epoch 3/5


Training: 100%|██████████| 30675/30675 [5:30:01<00:00,  1.55it/s]  


🟢 Train Loss: 0.0009
🌀 Epoch 4/5


Training: 100%|██████████| 30675/30675 [5:32:28<00:00,  1.54it/s]  


🟢 Train Loss: 0.0006
🌀 Epoch 5/5


Training: 100%|██████████| 30675/30675 [5:05:13<00:00,  1.67it/s]  

🟢 Train Loss: 0.0007





In [8]:
# ✅ Destination of the trained weights
save_path = r"D:\golfDataset\CNN\cnn_gru_model.pth"

# Create the target folder first so that a missing directory cannot make the
# save fail after training has finished.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# ✅ Save the model's state_dict
torch.save(model.state_dict(), save_path)
print(f"✅ 모델이 저장되었습니다: {save_path}")

✅ 모델이 저장되었습니다: D:\golfDataset\CNN\cnn_gru_model.pth


# 일부 디렉토리 학습