# Datasets
**Description:** This file contains code for defining custom datasets for FathomNet\
**Author:** Garðar Ingvarsson\
**Email:** gi241@cam.ac.uk
---

In [1]:
import torch
import os
import numpy as np
import json
from PIL import Image
import xml.etree.ElementTree as ET
import sys
from pathlib import Path
import matplotlib.pyplot as plt

# Add the 'utils' directory that sits next to the current working directory
# (i.e. <parent of cwd>/utils) to the module search path.
# NOTE(review): presumably this is where `fathomnethelper` (imported below) lives — confirm.
sys.path.append(os.path.join(Path(os.path.abspath('')).parent, 'utils'))

from fathomnethelper import Taxonomicon

In [39]:
class FathomNetDataset(torch.utils.data.Dataset):
    """Detection dataset pairing images with XML bounding-box annotations.

    The dataset reads two sub-directories of ``root``:

    * ``images/``      -- image files, opened with PIL and converted to RGB.
    * ``annotations/`` -- XML files whose ``<object>`` elements each contain
      ``bndbox/xmin``, ``bndbox/ymin``, ``bndbox/xmax`` and ``bndbox/ymax``.

    Images and annotations are paired by their position in the sorted
    directory listings, so both directories must contain matching files in
    a one-to-one correspondence; otherwise pairs will be misaligned.

    Args:
        root: Directory containing the ``images`` and ``annotations`` folders.
        transforms: Optional callable ``(img, target) -> (img, target)``
            applied to every sample, or ``None`` to return samples untouched.
    """

    # Sub-elements of each <object> that hold the box corners, in output order.
    _BOX_TAGS = ('bndbox/xmin', 'bndbox/ymin', 'bndbox/xmax', 'bndbox/ymax')

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(os.listdir(os.path.join(root, 'images')))
        self.anns = sorted(os.listdir(os.path.join(root, 'annotations')))

    def _load_target(self, idx):
        """Parse the annotation file at position ``idx`` into a target dict.

        Returns:
            A dict with

            * ``'boxes'``    -- float32 tensor of shape ``(N, 4)`` holding
              ``[xmin, ymin, xmax, ymax]`` per object (``(0, 4)`` when the
              annotation has no objects),
            * ``'labels'``   -- int64 tensor of ``N`` ones (every box gets
              the same label ``1``),
            * ``'image_id'`` -- tensor ``[idx]``.
        """
        ann_path = os.path.join(self.root, 'annotations', self.anns[idx])
        boxes = [
            [int(obj.find(tag).text) for tag in self._BOX_TAGS]
            for obj in ET.parse(ann_path).getroot().iter('object')
        ]
        # reshape keeps the (N, 4) layout even when there are no objects;
        # without it an empty annotation would yield a tensor of shape (0,).
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        return {
            'boxes': boxes,
            'labels': torch.ones((len(boxes),), dtype=torch.int64),
            'image_id': torch.tensor([idx]),
        }

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the sample at position ``idx``.

        ``img`` is an RGB PIL image (or whatever ``transforms`` turns it
        into) and ``target`` is the dict described in ``_load_target``.
        """
        img_path = os.path.join(self.root, 'images', self.imgs[idx])
        # convert() returns an independent in-memory copy, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        target = self._load_target(idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of image files found under ``root/images``."""
        return len(self.imgs)
    
    
class FathomNetCroppedDataset(torch.utils.data.Dataset):
    """Classification dataset made of bounding-box crops.

    Every annotated object whose name falls inside the taxonomic subtree of
    one of the requested classes becomes one sample: the crop of the image
    at that box, labelled with the integer index of the class.

    Args:
        root: Directory containing ``images`` and ``annotations`` folders;
            the annotation of ``images/<stem>.<ext>`` is read from
            ``annotations/<stem>.xml``.
        rank: Passed to ``Taxonomicon.get_concepts_at_rank`` to obtain the
            class list when ``classes`` is empty or ``None``.
        top_n: Keep only the ``top_n`` classes with the most boxes
            (``None`` keeps all of them).
        classes: Optional explicit collection of class (concept) names.
        transforms: Optional callable applied to each cropped image.

    Attributes set by the constructor:
        classes: Kept class names, ordered by descending box count (ties
            broken by descending name); a label is an index into this list.
        boxes: One ``(image_filename, (xmin, ymin, xmax, ymax))`` per sample.
        labels: Integer class index per sample, aligned with ``boxes``.
    """

    def __init__(self, root, rank=None, top_n=None, classes=None, transforms=None):
        self.root = root
        self.transforms = transforms
        image_files = sorted(os.listdir(os.path.join(root, 'images')))
        taxonomy = Taxonomicon()

        if not classes:
            classes = taxonomy.get_concepts_at_rank(rank)

        # Collect the subtree of every class first ...
        subtree_of = {}
        for concept in classes:
            subtree_of[concept] = set(taxonomy.get_subtree_nodes(concept))

        # ... then map each node in a subtree back to the class owning it.
        node_to_class = {}
        for concept in classes:
            node_to_class.update(dict.fromkeys(subtree_of[concept], concept))

        samples_per_class = {}
        for concept in classes:
            samples_per_class[concept] = []

        corner_paths = ('bndbox/xmin', 'bndbox/ymin', 'bndbox/xmax', 'bndbox/ymax')
        for image_file in image_files:
            stem = os.path.splitext(image_file)[0]
            annotation = ET.parse(os.path.join(root, 'annotations', stem + '.xml'))
            for obj in annotation.getroot().iter('object'):
                concept_name = obj.find('name').text
                if concept_name not in node_to_class:
                    # Object is outside every requested class subtree.
                    continue
                corners = tuple(int(obj.find(path).text) for path in corner_paths)
                samples_per_class[node_to_class[concept_name]].append((image_file, corners))

        # Most populated classes first; equal counts fall back to the name.
        ranked = list(samples_per_class.items())
        ranked.sort(key=lambda entry: (len(entry[1]), entry[0]), reverse=True)
        ranked = ranked[:top_n]

        self.boxes = []
        self.labels = []
        self.classes = []
        for class_index, (class_name, samples) in enumerate(ranked):
            self.classes.append(class_name)
            self.boxes.extend(samples)
            # The position in self.classes serves as the integer label.
            self.labels.extend([class_index] * len(samples))

    def __len__(self):
        """Return the number of cropped samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(crop, label)`` for the sample at position ``idx``."""
        image_file, corners = self.boxes[idx]
        label = self.labels[idx]

        full_image = Image.open(os.path.join(self.root, 'images', image_file))
        crop = full_image.convert('RGB').crop(corners)

        if self.transforms is None:
            return crop, label
        return self.transforms(crop), label

