In [12]:
import os
import xml.etree.ElementTree as ET

In [13]:
root_path = 'e:\\COMP\\COMP9444\\project\\dumpsite_data\\VOC2012'
train_path = os.path.join(root_path, "train")
test_path = os.path.join(root_path, "test")
file_Annotations = os.path.join(train_path, "Annotations")

In [14]:
object_class = []
for each_xml in os.listdir(file_Annotations):
    pic_xml = os.path.join(file_Annotations, each_xml)
    tree = ET.parse(pic_xml)
    root = tree.getroot()
    for object_elem in root.findall('object'):
        name_elem_value = object_elem.find('name').text
        object_class.append(name_elem_value)
print(set(object_class))


{'domestic garbage', 'mining waste', 'agriculture forestry', 'construction waste', 'industry waste', 'disposed garbage'}


##### Due to the problem of severe sample imbalance in the dumpsite dataset (Fig. 1a), we propose two training strategies, data augmentation (vertical flipping, horizontal
##### flipping, forward 90° rotation and reverse 90° rotation) and category balancing, to ensure the model’sefficiency during the training process

In [15]:
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T

In [16]:
class  VOCDataset(Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        txt_file = os.path.join(root, "train.txt")
        with open(txt_file, 'r') as f:
            self.image_ids = f.read().strip().split()
        self.image_folder = os.path.join(root, "JPEGImages")
        self.ann_folder = os.path.join(root, "Annotations")

    def __len__(self):
        return len(self.image_folder)
    
    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_folder, f"{img_id}.jpg")
        ann_path = os.path.join(self.ann_folder, f"{img_id}.xml")

        img = Image.open(img_path).convert("RGB")
        tree = ET.parse(ann_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.findall("object"):
            label = obj.find("name").text
            labels.append(label)
            bbox = obj.find("bndbox")
            box = [
                int(bbox.find("xmin").text),
                int(bbox.find("ymin").text),
                int(bbox.find("xmax").text),
                int(bbox.find("ymax").text),
            ]
            boxes.append(box)
        
        boxes = torch.tensor(boxes, dtype=torch.float32)
        # labels = torch.tensor(labels, dtype=torch.int64)
        target = {"boxes":boxes, "labels":labels}
        if self.transforms:
            img = self.transforms(img)
        return img, target

In [17]:
import random
from torchvision.transforms import functional as F

In [18]:
class DetectionTransforms:
    def __init__(self):
        self.augment = T.Compose([
            # T.RandomHorizontalFlip(),
            # T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.RandomRotation(30),
            T.RandomResizedCrop(size=(300, 300), scale=(0.0, 1.0)),
        ])
        self.normalize = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def __call__(self, image):
        image = self.augment(image)
        image = self.normalize(image)
        return image

In [19]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def unnormalize(image, mean, std):
    mean = torch.tensor(mean).view(3, 1, 1)
    std = torch.tensor(std).view(3, 1, 1)
    return image * std * mean

In [20]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    images = []
    targets = []
    for obj in batch:
        images.append(obj[0])
        targets.append(obj[1])
    images = torch.stack(images, dim=0)
    return images, targets


transforms = DetectionTransforms()
dataset = VOCDataset(root='./dumpsite_data/VOC2012/train', transforms=transforms)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

for images, targets in dataloader:
    PIL_image = unnormalize(images[0], mean, std)
    PIL_image = F.to_pil_image(PIL_image)
    PIL_image.show()
    print(targets[0]['boxes'])
    print(targets[0]['labels'])

tensor([[124., 667., 217., 783.],
        [256., 613., 347., 666.]])
['construction waste', 'construction waste']
tensor([[182., 614., 268., 728.]])
['domestic garbage']
tensor([[796., 605., 916., 689.]])
['construction waste']
