In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Preprocessing pipeline applied to every image before it is fed to the model.
IMAGE_SIZE = 224
# Per-channel normalisation constants; these match the ImageNet statistics
# documented for torchvision's pretrained models.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    # Force an exact IMAGE_SIZE x IMAGE_SIZE output (aspect ratio is not preserved).
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    # The input is already IMAGE_SIZE x IMAGE_SIZE here, so this crop leaves the size unchanged.
    transforms.CenterCrop(IMAGE_SIZE),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)

In [6]:
# Definition of the CustomDataset class
class CustomDataset(Dataset):
    """Image-classification dataset of ``.jpg`` files labelled by XML annotations.

    Every ``*.jpg`` file in ``img_dir`` is one sample. Its label is read from
    the XML file with the same base name in ``annotations_dir``: the text of
    the first ``<name>`` element inside the first ``<object>`` element is
    looked up in ``class_to_idx``.

    Args:
        img_dir: Directory containing the ``.jpg`` images.
        annotations_dir: Directory containing the matching ``.xml`` files.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, img_dir, annotations_dir, transform=None):
        self.img_dir = img_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        self.img_names = sorted(x for x in os.listdir(img_dir) if x.endswith('.jpg'))
        # Example mapping; adjust it to the classes actually present in the data.
        self.class_to_idx = {'apple': 0, 'banana': 1, 'orange': 2}
        self.classes = list(self.class_to_idx.keys())

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one was given; the label is the integer class index.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        annotation_path = self._annotation_path(img_name)
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        label = self.extract_label(annotation_path)
        if self.transform:
            image = self.transform(image)
        return image, label

    def _annotation_path(self, img_name):
        """Return the path of the XML annotation that belongs to ``img_name``."""
        # Swap only the extension: str.replace('.jpg', '.xml') would also
        # rewrite a '.jpg' occurring in the middle of the file name.
        stem, _ = os.path.splitext(img_name)
        return os.path.join(self.annotations_dir, stem + '.xml')

    def extract_label(self, annotation_path):
        """Read the class index from an XML annotation file.

        Args:
            annotation_path: Path of the XML file to parse.

        Returns:
            The integer index of the annotated class in ``class_to_idx``.

        Raises:
            ValueError: If the file has no ``<object>`` element, or that
                element has no non-empty ``<name>``.
            KeyError: If the class name is not a key of ``class_to_idx``.
        """
        root = ET.parse(annotation_path).getroot()
        object_node = root.find('object')
        name_node = object_node.find('name') if object_node is not None else None
        raw_name = name_node.text if name_node is not None else None
        if raw_name is None or not raw_name.strip():
            raise ValueError(
                f"No <object>/<name> entry found in annotation: {annotation_path}"
            )
        # Tolerate whitespace/newlines around the class name in the XML.
        label = raw_name.strip()
        try:
            return self.class_to_idx[label]
        except KeyError:
            raise KeyError(
                f"Unknown class {label!r} in {annotation_path}; "
                f"expected one of {list(self.class_to_idx)}"
            ) from None

# Instantiate the datasets and their data loaders.
# Images and XML annotations sit in the same folder, so each directory is
# passed as both the image directory and the annotation directory.
TRAIN_DIR = '../../dataset/train_zip/train/'
TEST_DIR = '../../dataset/test_zip/test/'
BATCH_SIZE = 64

train_dataset = CustomDataset(img_dir=TRAIN_DIR, annotations_dir=TRAIN_DIR, transform=transform)
test_dataset = CustomDataset(img_dir=TEST_DIR, annotations_dir=TEST_DIR, transform=transform)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [10]:
# Region-proposal demo: run Selective Search on one training image and draw
# every proposed box on a copy of it.
cv2.setUseOptimized(True)  # enable OpenCV's optimized code paths
# NOTE(review): cv2.ximgproc ships with the opencv-contrib build — confirm it is installed.
selective_search = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()

image_path = '../../dataset/train_zip/train/apple_1.jpg'
image = cv2.imread(image_path)  # returns None instead of raising when the file cannot be read

BOX_COLOR = (0, 255, 0)  # green in OpenCV's BGR channel order
BOX_THICKNESS = 1

if image is not None:
    selective_search.setBaseImage(image)
    selective_search.switchToSelectiveSearchFast()  # select the fast mode
    rects = selective_search.process()  # proposed boxes, each unpacked below as (x, y, w, h)

    # Draw on a copy so the loaded image stays untouched.
    out = image.copy()
    for x, y, w, h in rects:
        top_left = (x, y)
        bottom_right = (x + w, y + h)
        cv2.rectangle(out, top_left, bottom_right, BOX_COLOR, BOX_THICKNESS)

    # OpenCV stores pixels as BGR; matplotlib expects RGB.
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    plt.imshow(out_rgb)
    plt.show()
else:
    print(f"이미지를 읽어오지 못했습니다. 경로를 확인하세요: {image_path}")


이미지를 읽어오지 못했습니다. 경로를 확인하세요: ../../dataset/train_zip/train/apple_1.jpg
