In [None]:
!pip install kagglehub

In [None]:
import kagglehub

# Fetch the Kaggle "andrewmvd/dog-cat-detection" dataset through kagglehub.
# data_dir holds the path returned by dataset_download — presumably the local
# directory containing the downloaded files (TODO confirm against kagglehub docs).
data_dir = kagglehub.dataset_download("andrewmvd/dog-cat-detection")
# NOTE(review): the label says "file", but the value printed is the path returned above.
print("Path to dataset file: ", data_dir)

In [None]:
import os
import torch
import random
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET 

from PIL import Image
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torchvision.models.resnet import ResNet18_Weights

In [None]:
class MyDataset(Dataset):
    """Dog/cat dataset that stitches two single-object images side by side.

    Annotations are VOC-style XML files (``size/width``, ``size/height``,
    ``object/name``, ``object/bndbox/{xmin,ymin,xmax,ymax}``) named
    ``<image stem>.xml``. Only images whose annotation contains exactly one
    ``<object>`` are kept. Each sample pairs image ``idx`` with a randomly
    chosen partner (possibly itself), pastes the partner to the right of it,
    and returns the merged image plus one annotation row per object.

    Args:
        annotations_dir: Directory containing the XML annotation files.
        image_dir: Directory containing the image files.
        transform: Optional callable applied to the merged PIL image. When it
            is ``None``, ``transforms.ToTensor()`` is applied instead.
    """

    def __init__(self, annotations_dir, image_dir, transform=None):
        self.annotations_dir = annotations_dir
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = self.filter_images_with_multiple_objects()

    def _annotation_path(self, image_name):
        """Return the path of the XML annotation matching ``image_name``."""
        stem = os.path.splitext(image_name)[0]
        return os.path.join(self.annotations_dir, stem + ".xml")

    def filter_images_with_multiple_objects(self):
        """Return the sorted names of images annotated with exactly one object.

        Despite the name, this *removes* multi-object images. Images with zero
        objects, several objects, or a missing/unparseable annotation are
        dropped, as are non-file entries of ``image_dir``. Sorting makes the
        order independent of the filesystem's listing order.
        """
        return [
            name
            for name in sorted(os.listdir(self.image_dir))
            if os.path.isfile(os.path.join(self.image_dir, name))
            # Keep images that have a single object.
            and self.count_objects_in_annotation(self._annotation_path(name)) == 1
        ]

    def count_objects_in_annotation(self, annotation_path):
        """Return the number of top-level ``<object>`` elements in an annotation.

        Returns 0 when the file does not exist or is not well-formed XML, so
        such images are filtered out instead of aborting dataset construction.
        """
        try:
            root = ET.parse(annotation_path).getroot()
        except (FileNotFoundError, ET.ParseError):
            return 0
        return len(root.findall("object"))

    def __len__(self):
        return len(self.image_files)

    def _load_rgb(self, image_name):
        """Open ``image_name`` from ``image_dir`` as RGB and close the file handle."""
        with Image.open(os.path.join(self.image_dir, image_name)) as img:
            # convert() returns a new image, so it stays valid after the file closes.
            return img.convert("RGB")

    @staticmethod
    def _to_merged_bbox(bbox, x_offset, img_w, img_h, merged_w, merged_h):
        """Re-normalize a bbox from its source image onto the merged canvas.

        Args:
            bbox: ``[xmin, ymin, xmax, ymax]`` normalized to the source image
                (a tensor or any sequence of four numbers).
            x_offset: Horizontal pixel offset at which the source image was pasted.
            img_w, img_h: Source image size in pixels.
            merged_w, merged_h: Merged canvas size in pixels.

        Returns:
            A list ``[xmin, ymin, xmax, ymax]`` normalized to the merged canvas.
        """
        xmin, ymin, xmax, ymax = (float(v) for v in bbox)
        return [
            (xmin * img_w + x_offset) / merged_w,
            ymin * img_h / merged_h,
            (xmax * img_w + x_offset) / merged_w,
            ymax * img_h / merged_h,
        ]

    def __getitem__(self, idx):
        """Return ``(image, annotations)`` for a horizontally merged image pair.

        ``image`` is the merged picture after ``self.transform`` (or
        ``transforms.ToTensor()`` when no transform was given).
        ``annotations`` is a float32 tensor of shape ``(2, 5)``. Each row is
        ``[xmin, ymin, xmax, ymax, label]`` with coordinates normalized to the
        merged canvas. Row 0 describes image ``idx`` and row 1 describes the
        randomly drawn partner. Labels follow ``parse_annotation``
        (0 = cat, 1 = dog, -1 = other).
        """
        img1_name = self.image_files[idx]
        img1_label, img1_bbox = self.parse_annotation(self._annotation_path(img1_name))

        # The partner is drawn uniformly at random; it may be image idx itself.
        idx2 = random.randint(0, len(self.image_files) - 1)
        img2_name = self.image_files[idx2]
        img2_label, img2_bbox = self.parse_annotation(self._annotation_path(img2_name))

        img1 = self._load_rgb(img1_name)
        img2 = self._load_rgb(img2_name)

        merged_w = img1.width + img2.width
        merged_h = max(img1.height, img2.height)
        merged_image = Image.new("RGB", (merged_w, merged_h))
        merged_image.paste(img1, (0, 0))
        merged_image.paste(img2, (img1.width, 0))

        # Both boxes were normalized to their own image (assuming the XML
        # <size> matches the real pixel size — TODO confirm for this dataset).
        # Map them into the wider merged canvas; img2 is shifted right by img1.width.
        bbox1 = self._to_merged_bbox(
            img1_bbox, 0, img1.width, img1.height, merged_w, merged_h
        )
        bbox2 = self._to_merged_bbox(
            img2_bbox, img1.width, img2.width, img2.height, merged_w, merged_h
        )

        if self.transform is not None:
            merged_image = self.transform(merged_image)
        else:
            merged_image = transforms.ToTensor()(merged_image)

        annotations = torch.tensor(
            [bbox1 + [float(img1_label)], bbox2 + [float(img2_label)]],
            dtype=torch.float32,
        )
        return merged_image, annotations

    def parse_annotation(self, annotation_path):
        """Parse the first ``<object>`` of a VOC-style XML annotation.

        Returns:
            ``(label_num, bbox)``. ``label_num`` is 0 for "cat", 1 for "dog"
            and -1 otherwise. ``bbox`` is a float32 tensor
            ``[xmin, ymin, xmax, ymax]`` divided by the ``<size>`` width/height.

        Raises:
            ValueError: If the annotation contains no ``<object>`` element.
        """
        root = ET.parse(annotation_path).getroot()

        image_width = float(root.find("size/width").text)
        image_height = float(root.find("size/height").text)

        # Only the first object is used (images are pre-filtered to one object).
        obj = root.find("object")
        if obj is None:
            raise ValueError(f"No <object> element in annotation {annotation_path!r}")

        label = obj.find("name").text
        # float() also accepts fractional coordinates that int() would reject.
        xmin, ymin, xmax, ymax = (
            float(obj.find(f"bndbox/{tag}").text)
            for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        bbox = [
            xmin / image_width,
            ymin / image_height,
            xmax / image_width,
            ymax / image_height,
        ]

        label_num = {"cat": 0, "dog": 1}.get(label, -1)
        return label_num, torch.tensor(bbox, dtype=torch.float32)