In [38]:
import os
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes
import xml.etree.ElementTree as ET
import torch
from torchvision.transforms import ToPILImage

In [4]:
# Constants: data locations and the file extensions that pair an image
# with its annotation (same base name, different extension).
image_folder, annotation_folder = "./data/images", "./data/annotations"
image_type, annotation_type = ".jpg", ".xml"

In [62]:
# read images and annotations 
# logic: different folders/same file name/different file type
def get_files(image_folder, annotation_folder, image_type=".jpg", annotation_type=".xml"):
    files = [
    (os.path.join(image_folder, image_file), os.path.join(annotation_folder, image_file.replace(image_type, annotation_type))) \
        for image_file in os.listdir(image_folder)
    ]
    return files

# parse xml file
def get_annotation(file_path):
    tree = ET.parse(file_path) 
    root = tree.getroot()
    bbox_coordinates = []
    class_name = []
    for member in root.findall('object'):
        class_name.append(member[0].text) # class name
            
        # bbox coordinates
        xmin = int(member[4][0].text)
        ymin = int(member[4][1].text)
        xmax = int(member[4][2].text)
        ymax = int(member[4][3].text)
        # store data in list
        bbox_coordinates.append([xmin, ymin, xmax, ymax])

    return {"class_name": class_name,
    "bbox": torch.Tensor(bbox_coordinates)}

# visualize annotation
def get_visualization(image_tensor, annotations, out_path = None):
    transformer = ToPILImage()
    annotated_tensor = draw_bounding_boxes(image_tensor, annotations)
    annotated_image = transformer(annotated_tensor)
    if out_path is None:
        annotated_image.show()
    else:
        annotated_image.save(out_path)
    

In [84]:
# define data module
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms


class ObjectDetectionDataSet(Dataset):
    """Map-style dataset yielding ``(image, annotations)`` pairs.

    Args:
        data_tuples: Sequence whose items hold an image path at index 0 and
            an annotation path at index 1, e.g. the output of ``get_files``.
        tranformations: Optional callable applied to the image tensor only
            when truthy. The parameter name keeps its original spelling for
            compatibility.

    NOTE(review): the transform is not applied to the bounding boxes. A
    geometric transform (resize, crop, flip) would therefore leave the boxes
    out of sync with the image. TODO confirm only photometric transforms are
    used.
    """

    def __init__(self, data_tuples, tranformations = None) -> None:
        super().__init__()
        self.transformations = tranformations
        self.data_tuples = data_tuples

    def __len__(self):
        return len(self.data_tuples)

    def __getitem__(self, index):
        entry = self.data_tuples[index]
        image = read_image(entry[0])
        target = get_annotation(entry[1])
        if not self.transformations:
            return image, target
        return self.transformations(image), target

class ObjectDetectionDataModule(LightningDataModule):
    """Lightning data module serving ``(image, annotations)`` pairs.

    Images and annotations are paired by file name via ``get_files``. The
    main folders are split into train/val. When ``test_suffix`` is given,
    ``image_folder + test_suffix`` and ``annotation_folder + test_suffix``
    provide the test set.

    Args:
        image_folder: Directory with the training/validation images.
        annotation_folder: Directory with the matching annotation files.
        transformations: Optional image transform, applied to the train,
            val and test datasets alike.
        test_suffix: Optional suffix appended to both folder paths to locate
            the test data. A falsy value means an empty test set.
        batch_size: Batch size for every data loader (default 32).
        train_val_split: Fraction of samples used for training (default 0.8).
    """

    def __init__(self, image_folder, annotation_folder, transformations=None, test_suffix=None,
                 batch_size=32, train_val_split=0.8):
        super().__init__()
        self.image_folder = image_folder
        self.annotation_folder = annotation_folder
        self.transform = transformations
        self.train_val_split = train_val_split
        self.batch_size = batch_size
        self.test_suffix = test_suffix
        if test_suffix:
            self.image_folder_test = image_folder + test_suffix
            self.annotation_folder_test = annotation_folder + test_suffix
        # Filled by _index_files(), either from prepare_data() or lazily from setup().
        self.data_tuples = None

    def _index_files(self) -> None:
        """List the data files and compute the train/val split lengths."""
        self.data_tuples = get_files(self.image_folder, self.annotation_folder)
        train_len = int(len(self.data_tuples) * self.train_val_split)
        val_len = len(self.data_tuples) - train_len
        self.seq_lengths = [train_len, val_len]
        if self.test_suffix:
            self.data_tuples_test = get_files(self.image_folder_test, self.annotation_folder_test)
        else:
            self.data_tuples_test = []

    def prepare_data(self) -> None:
        """Index the data files. Kept for callers that invoke it directly."""
        self._index_files()

    def setup(self, stage: str = "fit"):
        """Build the datasets needed for ``stage``.

        ``stage`` may be ``"fit"``, ``"validate"``, ``"test"`` or ``None``;
        ``None`` builds all datasets. If the files have not been indexed yet,
        indexing happens here. Per the Lightning docs, ``prepare_data`` runs
        on a single process, so attributes it sets are not visible on the
        other processes.
        """
        if self.data_tuples is None:
            self._index_files()

        # Train/val datasets for the fit and validate loops.
        if stage in (None, "fit", "validate"):
            dataset = ObjectDetectionDataSet(self.data_tuples, self.transform)
            # Fixed seed so the train/val split is identical on every run.
            self.train_set, self.val_set = random_split(
                dataset, self.seq_lengths, torch.Generator().manual_seed(42)
            )

        # Test dataset, using the same transform as train/val.
        if stage in (None, "test"):
            self.test_set = ObjectDetectionDataSet(self.data_tuples_test, self.transform)

    @staticmethod
    def _collate(batch):
        """Collate samples into ``(list_of_images, list_of_annotations)``.

        The default collate tries to stack tensors. That fails when images
        differ in size or annotations differ in their number of boxes.
        """
        images = [image for image, _ in batch]
        annotations = [annotation for _, annotation in batch]
        return images, annotations

    def train_dataloader(self):
        # shuffle=True: reshuffle the training samples every epoch.
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True,
                          collate_fn=self._collate)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, collate_fn=self._collate)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, collate_fn=self._collate)

In [85]:
# Data module over the folders configured in the constants cell.
object_det = ObjectDetectionDataModule(
    image_folder=image_folder,
    annotation_folder=annotation_folder,
)

In [None]:
# Training utilities: logger, checkpoint callbacks and profiler.
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
try:
    from pytorch_lightning.profilers import SimpleProfiler
except ImportError:  # Lightning versions that only ship the old module name
    from pytorch_lightning.profiler import SimpleProfiler
# A notebook has no parent package, so the relative `from .config import`
# always fails. Import the module directly instead.
# NOTE(review): assumes config.py is importable from the notebook's working
# directory. Confirm the project layout.
from config import config_file

log_dir = "./logs"

# TensorBoardLogger requires a save_dir argument.
ObjectDetectionLogger = TensorBoardLogger(save_dir=log_dir)
ObjectDetectionCheckpoints = []
# Backward-compatible alias for the original misspelled name.
ObjectDetectionCheckoints = ObjectDetectionCheckpoints
ObjectDetectionProfiler = SimpleProfiler()