In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# DeepLabv3+ 모델 정의 (ResNet-50 백본 사용)
class DeepLabv3Plus(nn.Module):
    def __init__(self, num_classes):
        super(DeepLabv3Plus, self).__init__()
        # ResNet-50 백본 로드
        self.backbone = models.resnet50(pretrained=True)
        # DeepLabv3+ 헤드 부분 추가 (PyTorch 내장 함수 활용)
        self.classifier = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        # ResNet-50을 통해 특징 추출
        features = self.backbone(x)
        # DeepLabv3+ 헤드를 통해 분할 결과 생성
        output = self.classifier(features)
        return output

# 하이퍼파라미터 설정
num_classes = 21  # 클래스 개수
batch_size = 16
learning_rate = 0.001
num_epochs = 10

# 데이터 변환
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 데이터셋 로드
train_dataset = ImageFolder(root='/home/mira/Desktop/KistAIRobot/david/autonomous-driving/project/2DSS/training/', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 모델, 손실 함수, 옵티마이저 생성
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabv3Plus(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 학습
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# 모델 저장
torch.save(model.state_dict(), "deeplabv3plus.pth")
print("Saved PyTorch Model State to deeplabv3plus.pth")

FileNotFoundError: Found no valid file for the classes labels. Supported extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp