In [35]:
import torch
import torchvision
from torchvision import models

import xml.etree.ElementTree as ET
from PIL import Image
import os

In [25]:
AlexNet = models.alexnet(weights="IMAGENET1K_V1", progress=True)
AlexNet

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

# Custom Dataset

In [63]:
VOC_CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 
    'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]

In [64]:
class CustomDataset():
    def __init__(self, root, transform, img_set='trainval', num_classes=20):
        self.root = root
        self.transform = transform
        self.img_set = img_set
        self.num_classes = num_classes
        
        self.annotation_path = os.path.join(self.root, f'PASCAL_VOC_{self.img_set}', 'VOCdevkit', 'VOC2007', 'Annotations')
        self.img_path = os.path.join(self.root, f'PASCAL_VOC_{self.img_set}', 'VOCdevkit', 'VOC2007', 'JPEGimages')

        self.annotations = [os.path.join(self.annotation_path, xml) for xml in os.listdir(self.annotation_path)]
        self.images = [os.path.join(self.img_path, xml) for xml in os.listdir(self.img_path)]


    def __len__(self):
        return len(self.images)

    
    def parse_xml_boxes(self, root):
        bboxes = []
        for obj in root.findall('object'):
            cls_name = obj.find('name').text
            if cls_name in VOC_CLASSES:
                cls_idx = VOC_CLASSES.index(cls_name)
                xmin = int(obj.find('bndbox').find('xmin').text)
                xmax = int(obj.find('bndbox').find('xmax').text)
                ymin = int(obj.find('bndbox').find('ymin').text)
                ymax = int(obj.find('bndbox').find('ymax').text)
            bboxes.append([xmin, xmax, ymin, ymax, cls_idx])

        return bboxes

    
    def __getitem__(self, idx):
        image_path = self.images[idx]
        annotation_path = self.annotations[idx]

        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        tree = ET.parse(annotation_path)
        root = tree.getroot()
        bboxes = parse_xml_boxes(root)

        image_width = int(root.find('size/width').text)
        image_height = int(root.find('size/height').text)
        
            

In [65]:
from torchvision.transforms import v2

VOC_VAL = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]

img_transform = v2.Compose([
    v2.Resize((448, 448)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=VOC_VAL[0], std=VOC_VAL[1])
])

In [66]:
root = '/Users/h383kim/pytorch/data'
train_dataset = CustomDataset(root, img_transform, 'trainval', 20)
test_dataset = CustomDataset(root, img_transform, 'test', 20)
print(len(train_dataset))
print(len(test_dataset))

5011
4952


In [67]:
print(train_dataset[2])

None


In [36]:


train_dataset.transform = transforamtion
test_dataset.transform = transforamtion

In [37]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_loader = DataLoader(train_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True)
test_loader = DataLoader(test_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)