In [87]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
from PIL import Image
import os
import pandas as pd
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
from tqdm import tqdm
from torchinfo import summary as tsummary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


# Prep DATASET

In [88]:
class SeqDataset(Dataset):
    def __init__(self, csv_path, seq_root, seq_len=30, transform=None):

        self.df = pd.read_csv(csv_path)
        self.seq_root = seq_root
        self.seq_len = seq_len
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        seq_name = row["sequence"] 
        label = row["label"]

        seq_folder = os.path.join(self.seq_root, seq_name)

        frames = []
        for i in range(1, self.seq_len + 1):
            img_path = os.path.join(seq_folder, f"{i:05d}.jpg")

            img = Image.open(img_path).convert("RGB")

            if self.transform:
                img = self.transform(img)

            frames.append(img)

        frames = torch.stack(frames)

        return frames, torch.tensor(label, dtype=torch.float32)

In [89]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ## imgnet stats
])


train_dataset = SeqDataset(
    csv_path = "Data/train/train.csv",
    seq_root="Data/train/sequences",
    seq_len=30,
    transform=transform
)
val_dataset = SeqDataset(
    csv_path = "Data/validation/validation.csv",
    seq_root="Data/validation/sequences",
    seq_len=30,
    transform=transform
)

test_dataset = SeqDataset(
    csv_path = "Data/test/test.csv",
    seq_root="Data/test/sequences",
    seq_len=30,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False) 
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

Number of training samples: 618
Number of validation samples: 256
Number of test samples: 116


# DEFINE MODEL

### ResNet + LSTM + FC

In [90]:
class RES_LSTM(nn.Module):
    def __init__(self, hidden_size=128, num_layers=2):  # <-- num_layers 증가
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        RES = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        for p in RES.parameters():
            p.requires_grad = False

        modules = list(RES.children())[:-1]
        self.cnn = nn.Sequential(*modules)
        
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3  # <-- Dropout 추가
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )
    
    def forward(self, x):
        B, S, C, H, W = x.shape
        features = []

        for frame_idx in range(S):
            feat = self.cnn(x[:, frame_idx])
            feat = feat.view(B, -1)
            features.append(feat)
            
        features = torch.stack(features, dim=1)  # (B, S, 512)
        lstm_out, (h_n, c_n) = self.lstm(features)  # (B, S, hidden_size)
        
        # 방법 1: 모든 timestep의 평균 사용 (더 좋음)
        lstm_out = lstm_out.mean(dim=1)  # (B, hidden_size)
        
        # 방법 2 (선택): 마지막만 사용하려면
        # lstm_out = lstm_out[:, -1, :]
        
        out = self.fc(lstm_out)
        
        return out

    

# Define func Train / Evaludate

In [91]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for (x, y) in tqdm(loader, desc="Training", leave=True):
        x = x.to(device)
        y = y.to(device).float() # label is int so convert to float

        optimizer.zero_grad()
        logits = model(x).squeeze(1)

        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        total_loss += loss.item() * batch_size

        preds = (torch.sigmoid(logits) >= 0.5).long()
        correct += (preds == y.long()).sum().item()
        total += batch_size

    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0

    for (x, y) in tqdm(loader, desc="Evaluating", leave=True):
        x = x.to(device)
        y = y.to(device).float()

        logits = model(x).squeeze(1)

        loss = criterion(logits, y)

        batch_size = y.size(0)
        total_loss += loss.item() * batch_size

        preds = (torch.sigmoid(logits) >= 0.5).long()
        correct += (preds == y.long()).sum().item()
        total += batch_size

    return total_loss / total, correct / total

In [95]:
model = RES_LSTM(hidden_size=256, num_layers=2)
model = model.to(device)

tsummary(model, input_size=(1, 30, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
RES_LSTM                                      [1, 1]                    --
├─Sequential: 1-1                             [1, 512, 1, 1]            --
│    └─Conv2d: 2-1                            [1, 64, 112, 112]         (9,408)
│    └─BatchNorm2d: 2-2                       [1, 64, 112, 112]         (128)
│    └─ReLU: 2-3                              [1, 64, 112, 112]         --
│    └─MaxPool2d: 2-4                         [1, 64, 56, 56]           --
│    └─Sequential: 2-5                        [1, 64, 56, 56]           --
│    │    └─BasicBlock: 3-1                   [1, 64, 56, 56]           (73,984)
│    │    └─BasicBlock: 3-2                   [1, 64, 56, 56]           (73,984)
│    └─Sequential: 2-6                        [1, 128, 28, 28]          --
│    │    └─BasicBlock: 3-3                   [1, 128, 28, 28]          (230,144)
│    │    └─BasicBlock: 3-4                   [1, 128, 28, 28]      

In [None]:
# 학습 셀 수정
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # <-- learning rate 낮춤
num_epochs = 20  # <-- epoch 증가

best_val_loss = float('inf')
patience = 5
patience_counter = 0

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    
    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")  # 최고 성능 모델 저장
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

# 테스트
model.load_state_dict(torch.load("best_model.pth"))
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"\nTest Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

Training: 100%|██████████| 78/78 [07:19<00:00,  5.64s/it]
Evaluating: 100%|██████████| 32/32 [02:43<00:00,  5.12s/it]


Epoch 1/20 | Train Loss: 0.5149, Acc: 0.7524 | Val Loss: 0.6374, Acc: 0.6211


Training: 100%|██████████| 78/78 [07:15<00:00,  5.58s/it]
Evaluating: 100%|██████████| 32/32 [02:43<00:00,  5.12s/it]


Epoch 2/20 | Train Loss: 0.5078, Acc: 0.7443 | Val Loss: 0.4949, Acc: 0.7344


Training: 100%|██████████| 78/78 [07:17<00:00,  5.61s/it]
Evaluating: 100%|██████████| 32/32 [02:41<00:00,  5.05s/it]


Epoch 3/20 | Train Loss: 0.4675, Acc: 0.7783 | Val Loss: 0.4803, Acc: 0.7422


Training:   0%|          | 0/78 [00:00<?, ?it/s]