In [2]:
from PIL import Image
import numpy as np
from collections import defaultdict
import json
import numpy as np
import cv2
import os
import csv
import matplotlib.pyplot as plt
import uuid


In [3]:
class Mapping:
    """One row of a dataset's ``mapping.csv``: a numeric label plus its name.

    Both values are stored exactly as given; ``get_mapping`` passes the raw
    CSV strings.
    """

    def __init__(self, annotation_map, annotation_name):
        # Later converted with int() to become the numeric class label.
        self.annotation_map = annotation_map
        # Human-readable label; also used as the output folder name when saving crops.
        self.annotation_name = annotation_name

    def __str__(self):
        template = "Mapping(annotation_map={}, annotation_name={})"
        return template.format(self.annotation_map, self.annotation_name)

class COCOParser:
    """Index one COCO-format annotation file and extract per-annotation crops.

    Parameters
    ----------
    anns_file : str
        Path to the COCO JSON file. It must contain ``images``, ``annotations``
        and ``categories``; ``licenses`` is optional.
    imgs_dir : str
        Directory that each image record's ``file_name`` is relative to.
    path_root : str
        Dataset root. ``<path_root>/mapping.csv`` supplies the
        category-id -> ``Mapping`` table used by :meth:`process`.
    """

    def __init__(self, anns_file, imgs_dir, path_root):
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(anns_file, 'r', encoding='utf-8') as f:
            coco = json.load(f)

        self.mapping = self.get_mapping(os.path.join(path_root, "mapping.csv"))
        self.imgs_dir = imgs_dir

        # image_id -> [annotation, ...] and annotation_id -> annotation
        self.annIm_dict = defaultdict(list)
        self.annId_dict = {}
        for ann in coco['annotations']:
            self.annIm_dict[ann['image_id']].append(ann)
            self.annId_dict[ann['id']] = ann

        self.im_dict = {img['id']: img for img in coco['images']}
        self.cat_dict = {cat['id']: cat for cat in coco['categories']}
        # Some exports omit the optional "licenses" section entirely.
        self.licenses_dict = {lic['id']: lic for lic in coco.get('licenses', [])}

    def get_mapping(self, csv_path):
        """Read ``mapping.csv`` into ``{annotation: Mapping}``.

        The CSV needs ``annotation``, ``annotation_map`` and ``annotation_name``
        columns. Keys are kept as strings because :meth:`process` looks them up
        with ``str(category_id)``. When an ``annotation`` value repeats, the
        first row wins.

        A missing or unreadable file is reported, and whatever was parsed so far
        is returned (possibly ``{}``). A dataset without a mapping therefore
        yields no samples instead of aborting the whole run.
        """
        annotation_dict = {}
        try:
            # newline='' is what the csv module requires. utf-8-sig strips an Excel
            # BOM that would otherwise corrupt the first header name ("annotation").
            with open(csv_path, 'r', newline='', encoding='utf-8-sig') as f:
                for row in csv.DictReader(f):
                    annotation = row['annotation']
                    if annotation in annotation_dict:
                        continue
                    annotation_dict[annotation] = Mapping(row['annotation_map'], row['annotation_name'])
        except (OSError, KeyError, csv.Error, UnicodeDecodeError) as e:
            print(f"Warning: could not read mapping file {csv_path}: {e!r}")
        return annotation_dict

    def get_imgIds(self):
        """Return all image ids in file order."""
        return list(self.im_dict.keys())

    def get_annIds(self, im_ids):
        """Return annotation ids for one image id or a list of image ids."""
        im_ids = im_ids if isinstance(im_ids, list) else [im_ids]
        # .get avoids inserting empty entries into the defaultdict for unknown ids.
        return [ann['id'] for im_id in im_ids for ann in self.annIm_dict.get(im_id, [])]

    def load_anns(self, ann_ids):
        """Return annotation dicts for one annotation id or a list of them."""
        ann_ids = ann_ids if isinstance(ann_ids, list) else [ann_ids]
        return [self.annId_dict[ann_id] for ann_id in ann_ids]

    def load_cats(self, class_ids=None):
        """Look up category dicts, or print all categories when called without ids.

        With ``class_ids`` (one id or a list), return the matching category
        dicts. Called with no argument, print every category and return None.
        The file previously defined two ``load_cats`` methods and the second
        silently replaced the first; this one supports both call styles.
        """
        if class_ids is None:
            print("Danh sách các loại nhãn:")
            for cat_id, category in self.cat_dict.items():
                print(f"ID: {cat_id}, Name: {category['name']}")
            return None
        class_ids = class_ids if isinstance(class_ids, list) else [class_ids]
        return [self.cat_dict[class_id] for class_id in class_ids]

    def get_imgLicenses(self, im_ids):
        """Return the license dict of each given image id (one id or a list)."""
        im_ids = im_ids if isinstance(im_ids, list) else [im_ids]
        lic_ids = [self.im_dict[im_id]["license"] for im_id in im_ids]
        return [self.licenses_dict[lic_id] for lic_id in lic_ids]

    def get_path_images(self, im_ids):
        """Return the image record dicts (``id``, ``file_name``, ...) for the given ids."""
        im_ids = im_ids if isinstance(im_ids, list) else [im_ids]
        return [self.im_dict[im_id] for im_id in im_ids]

    def process(self, size, save=False, root="data-process", max=1500, sample_limit=10):
        """Crop every mapped annotation out of its JPEG image and resize it.

        Only annotations whose ``str(category_id)`` is a key of ``self.mapping``
        are used. Non-JPEG images are skipped. Crops are converted to RGB so all
        samples share the same channel count.

        Parameters
        ----------
        size : tuple[int, int]
            Target ``(width, height)`` passed to ``Image.resize``.
        save : bool
            If True, write each crop to ``<root>/<annotation_name>/<uuid>.jpg``
            and collect nothing in memory; the returned lists stay empty.
        root : str
            Output directory used when ``save`` is True.
        max : int
            Per-label cap on crops written when ``save`` is True. The name
            shadows the builtin but is kept for keyword callers.
        sample_limit : int or None
            Stop scanning at the start of the next image once more than this
            many in-memory samples have been collected. The default of 10
            reproduces a cap that was previously hard-coded; pass None to
            process every image.

        Returns
        -------
        (x_data, y_data, label) : tuple of three index-aligned lists
            Crops as numpy arrays, ``int(annotation_map)`` labels, and
            ``annotation_name`` strings.
        """
        x_data = []
        y_data = []
        label = []
        count_label = {}
        skipped = 0

        for im in self.get_path_images(self.get_imgIds()):
            if sample_limit is not None and len(x_data) > sample_limit:
                break
            try:
                # The context manager closes the file even when the image is skipped.
                with Image.open(os.path.join(self.imgs_dir, im['file_name'])) as image:
                    if image.format != 'JPEG':
                        continue
                    for ann in self.load_anns(self.get_annIds(im['id'])):
                        try:
                            mapping = self.mapping.get(str(ann["category_id"]))
                            if mapping is None:
                                continue
                            name = mapping.annotation_name
                            # Check the per-label quota before doing any pixel work.
                            if save and count_label.get(name, 0) >= max:
                                continue
                            x, y, w, h = (int(b) for b in ann['bbox'])
                            if w <= 0 or h <= 0:
                                skipped += 1
                                continue
                            # Use a separate name: the source image must stay intact
                            # for the remaining annotations of this image.
                            cropped = image.crop((x, y, x + w, y + h)).resize(size).convert("RGB")
                            if save:
                                save_path = os.path.join(root, name, f"{uuid.uuid1()}.jpg")
                                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                                cropped.save(save_path)
                                count_label[name] = count_label.get(name, 0) + 1
                                continue
                            # Convert the label first, so a bad annotation_map cannot
                            # leave the three output lists with different lengths.
                            class_index = int(mapping.annotation_map)
                            x_data.append(np.array(cropped))
                            y_data.append(class_index)
                            label.append(name)
                        except (KeyError, TypeError, ValueError, OSError):
                            # Malformed annotation/bbox, bad label value, or unreadable pixels.
                            skipped += 1
            except (KeyError, OSError):
                # Missing file_name/id, missing file, or not a readable image.
                skipped += 1

        if skipped:
            print(f"{self.imgs_dir}: skipped {skipped} unreadable images/annotations")
        return x_data, y_data, label

In [12]:
class JoinDataset:
    """Combine every COCO sub-dataset found under ``root`` for one split.

    Each immediate subdirectory ``<root>/<name>`` that contains
    ``<type>/_annotations.coco.json`` becomes one ``COCOParser``. Its label
    mapping is read from ``<root>/<name>/mapping.csv``.
    """

    def __init__(self, type='train', root='dataset_root'):
        self.datasets = self.get_all_dataset(root, type)
        self.type = type

    def get_all_dataset(self, root, type):
        """Build one ``COCOParser`` per sub-dataset of ``root`` for split ``type``.

        Subdirectories are visited in sorted order so results are reproducible,
        because ``os.listdir`` order is arbitrary. Subdirectories without the
        split's annotation file are reported and skipped rather than raising.
        """
        parsers = []
        for subdir in sorted(os.listdir(root)):
            dataset_dir = os.path.join(root, subdir)
            if not os.path.isdir(dataset_dir):
                continue
            split_dir = os.path.join(dataset_dir, type)
            coco_annotations_file = os.path.join(split_dir, "_annotations.coco.json")
            if not os.path.isfile(coco_annotations_file):
                print(f"Skipping {subdir}: {coco_annotations_file} not found")
                continue
            coco = COCOParser(coco_annotations_file, split_dir, dataset_dir)
            parsers.append(coco)
            print("Dataset: " + subdir)
            coco.load_cats()
            for key, mapping in coco.mapping.items():
                print(f"{key}: {mapping}")
        return parsers

    def mapping_all_dataset(self, size, save=True, **process_kwargs):
        """Run ``process`` on every dataset and concatenate the results.

        Parameters
        ----------
        size : tuple[int, int]
            Crop size passed to each parser's ``process``.
        save : bool
            Passed through to ``process``. With the default ``True``, crops are
            written to disk and the returned arrays are empty.
        **process_kwargs
            Extra keyword arguments forwarded to ``process`` (for example
            ``root``, ``max``, ``sample_limit``).

        Returns
        -------
        (x, y, labels) : tuple of numpy arrays built from the concatenated lists.
        """
        x_data_all = []
        y_data_all = []
        labels_all = []

        for coco in self.datasets:
            x_data, y_data, label = coco.process(size, save, **process_kwargs)
            x_data_all.extend(x_data)
            y_data_all.extend(y_data)
            labels_all.extend(label)

        return np.array(x_data_all), np.array(y_data_all), np.array(labels_all)


In [13]:
# Extract crops for the training split of every dataset under dataset_root.
# mapping_all_dataset defaults to save=True: crops are written to disk, so the
# displayed arrays are empty.
join_dataset = JoinDataset(type='train', root='dataset_root')
join_dataset.mapping_all_dataset(size=(224, 224))


(array([], dtype=float64), array([], dtype=float64), array([], dtype=float64))