In [1]:
import os

from mmseg.registry import DATASETS, VISUALIZERS


  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from mmseg.datasets.basesegdataset import BaseSegDataset


In [2]:
class XRayDataset(BaseSegDataset):
    """X-ray segmentation dataset with 29 classes.

    Images are ``.png`` files and annotations are ``.json`` polygon files that
    share the image's file stem; locating and pairing the files is left to
    ``BaseSegDataset``.
    """

    # 19 'finger-N' classes followed by 10 individually named classes.
    METAINFO = dict(
        classes=tuple(f'finger-{n}' for n in range(1, 20)) + (
            'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid',
            'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
        ),
        # One RGB colour per class, in the same order as `classes`.
        palette=[
            (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
            (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
            (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
            (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
            (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
            (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
        ],
    )

    def __init__(self, img_suffix='.png', seg_map_suffix='.json', **kwargs) -> None:
        # Only the file suffixes differ from the base class defaults.
        super().__init__(img_suffix=img_suffix,
                         seg_map_suffix=seg_map_suffix,
                         **kwargs)


In [5]:
# Dataset location. Set the XRAY_DATA_ROOT environment variable to point
# elsewhere instead of editing this absolute path.
data_root = os.environ.get('XRAY_DATA_ROOT', '/data/ephemeral/home/data')
# Fold index: selects the `fold_<FOLD>` sub-directory under data_root.
FOLD = 0
data_prefix = dict(img_path=f'fold_{FOLD}/images',
                   seg_map_path=f'fold_{FOLD}/annos')


In [6]:
# Build the dataset without a pipeline, so indexing returns raw info dicts.
dataset = XRayDataset(
    data_root=data_root,
    data_prefix=data_prefix,
    test_mode=False,
)


In [7]:
# Display the dataset object; its default repr only shows class name and address.
dataset

<__main__.XRayDataset at 0x7f6dda880ca0>

In [8]:
# Number of samples found under data_root/data_prefix (shown via rich display).
len(dataset)


160


In [9]:
# Info dict for sample 0: file paths and label settings only, no pixel data yet.
dataset.get_data_info(0)


{'img_path': '/data/ephemeral/home/data/fold_0/images/ID001_image1661130828152_R.png', 'seg_map_path': '/data/ephemeral/home/data/fold_0/annos/ID001_image1661130828152_R.json', 'label_map': None, 'reduce_zero_label': False, 'seg_fields': [], 'sample_idx': 0}


In [10]:
# Class names and palette from XRayDataset.METAINFO plus dataset-level label settings.
dataset.metainfo


{'classes': ('finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5', 'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10', 'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15', 'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna'), 'palette': [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176)], 'label_map': None, 'reduce_zero_label': False}


In [12]:
# Path of the JSON annotation file paired with sample 0.
dataset[0]['seg_map_path']

'/data/ephemeral/home/data/fold_0/annos/ID001_image1661130828152_R.json'

In [13]:
import json

In [19]:
from mmcv.transforms import LoadImageFromFile
# mmcv transform that reads `img_path` from a results dict and adds the loaded
# image and its `img_shape`; used below to load sample 0 by hand.
tran = LoadImageFromFile()


In [20]:
# With no pipeline attached, indexing returns the raw info dict (paths only).
dataset[0]

{'img_path': '/data/ephemeral/home/data/fold_0/images/ID001_image1661130828152_R.png',
 'seg_map_path': '/data/ephemeral/home/data/fold_0/annos/ID001_image1661130828152_R.json',
 'label_map': None,
 'reduce_zero_label': False,
 'seg_fields': [],
 'sample_idx': 0}

In [23]:
# Load sample 0's image by hand; `res` is the info dict extended with the image
# and its shape.
first_sample = dataset[0]
res = tran.transform(first_sample)

In [25]:
# Shape of the loaded image (a 2-tuple here); used to size the annotation masks.
res['img_shape']

(2048, 2048)

In [28]:
import numpy as np
import cv2


In [27]:
# Number of segmentation classes (the length of the CLASSES list defined next).
# Kept under its own name so it does not share a name with that list.
NUM_CLASSES = 29

In [29]:
# Ordered class names; a class's index is its channel in the dense mask.
CLASSES = [f'finger-{n}' for n in range(1, 20)] + [
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid',
    'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]
# Reverse lookup: class name -> channel index.
CLASS2IND = dict(zip(CLASSES, range(len(CLASSES))))

In [None]:
# Prototype of the polygon -> dense-mask conversion on sample 0.
# Produces `label`: a channel-first (num_classes, H, W) uint8 mask.
with open(dataset[0]['seg_map_path'], "r", encoding="utf-8") as f:
    annotations = json.load(f)["annotations"]

height, width = res['img_shape'][:2]
label = np.zeros((len(CLASSES), height, width), dtype=np.uint8)

for ann in annotations:
    class_ind = CLASS2IND[ann["label"]]
    # cv2.fillPoly requires int32 vertex coordinates.
    points = np.asarray(ann["points"], dtype=np.int32)
    # Convert the polygon to a dense mask by drawing straight into this class's
    # C-contiguous (H, W) channel; repeated polygons of one class accumulate
    # instead of overwriting each other.
    cv2.fillPoly(label[class_ind], [points], 1)


In [None]:
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations

In [48]:
# Not registered in a TRANSFORMS registry (decorator left disabled), so use it
# by passing an instance in a pipeline list rather than via dict(type=...).
# @TRANSFORMS.register_module()
class LoadXRayAnnotations(MMCV_LoadAnnotations):
    """Load JSON polygon annotations as a multi-channel binary mask.

    Reads ``results['seg_map_path']``, a JSON file whose ``"annotations"`` list
    holds items with a ``"label"`` (class name) and ``"points"`` (polygon
    vertices). Each polygon is rasterized with value 1 into the channel of its
    class. The result is stored as ``results['gt_seg_map']`` with shape
    ``(num_classes, H, W)`` and dtype uint8, where ``(H, W)`` comes from
    ``results['img_shape']``.

    Args:
        backend_args: Forwarded to the mmcv ``LoadAnnotations`` base class.
            The JSON file itself is read with the built-in ``open``.
        classes: Ordered class names; mask channel ``i`` corresponds to
            ``classes[i]``. Defaults to ``DEFAULT_CLASSES``.
    """

    DEFAULT_CLASSES = (
        'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
        'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
        'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
        'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
        'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
        'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
    )

    def __init__(self, backend_args=None, classes=None) -> None:
        # Only segmentation maps are loaded; boxes/labels/keypoints are off.
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            backend_args=backend_args)
        self.CLASSES = list(self.DEFAULT_CLASSES if classes is None else classes)
        self.CLASS2IND = {name: i for i, name in enumerate(self.CLASSES)}

    def _load_seg_map(self, results: dict) -> None:
        """Rasterize the polygons in ``results['seg_map_path']`` into a mask.

        Args:
            results: Pipeline results dict. Must contain ``seg_map_path``,
                ``img_shape`` (e.g. set by ``LoadImageFromFile``) and a
                ``seg_fields`` list.

        Raises:
            KeyError: If an annotation's label is not in ``self.CLASSES``.

        Side effects: sets ``results['gt_seg_map']`` and appends
        ``'gt_seg_map'`` to ``results['seg_fields']``.
        """
        with open(results['seg_map_path'], "r", encoding="utf-8") as f:
            annotations = json.load(f)["annotations"]

        # Size the mask from this sample's own image, not from any global.
        height, width = results['img_shape'][:2]
        label = np.zeros((len(self.CLASSES), height, width), dtype=np.uint8)

        for ann in annotations:
            class_ind = self.CLASS2IND[ann["label"]]
            # cv2.fillPoly requires int32 vertex coordinates.
            points = np.asarray(ann["points"], dtype=np.int32)
            # label[class_ind] is a C-contiguous (H, W) view, so fillPoly draws
            # in place; repeated polygons of one class accumulate rather than
            # overwrite each other.
            cv2.fillPoly(label[class_ind], [points], 1)

        results['gt_seg_map'] = label
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'backend_args={self.backend_args})')


In [49]:
# Instance of the JSON-polygon annotation loader defined above.
tran2 = LoadXRayAnnotations()

In [52]:
# Run the annotation loader on `res` and check the resulting mask shape.
# NOTE: transform() mutates `res` in place (sets 'gt_seg_map' and appends to
# res['seg_fields']), so re-running this cell appends a duplicate entry.
tran2.transform(res)['gt_seg_map'].shape

(29, 2048, 2048)

In [None]:
# Inspect the results dict after the image and annotation loaders modified it.
res

In [43]:
# Shape of the prototype mask built in the annotation cell. That cell produces
# a channel-first (num_classes, H, W) mask; the saved (2048, 2048, 29) output
# appears to predate that (out-of-order execution) — re-run to refresh.
label.shape

(2048, 2048, 29)

In [42]:
# Total count of filled (value 1) pixels, summed over all class channels.
label.sum()

458681

In [None]:
# Local import: mmseg is already a dependency of this notebook.
from mmseg.datasets.transforms import PackSegInputs

# The annotations are JSON polygon files, so the stock LoadAnnotations (which
# decodes segmentation maps as images) cannot read them; use the
# LoadXRayAnnotations transform defined above. It is not registered in a
# TRANSFORMS registry, so transform instances are passed directly
# (mmengine's pipeline Compose accepts callables as well as config dicts).
train_pipeline = [
    LoadImageFromFile(),
    LoadXRayAnnotations(),
    # NOTE(review): these augmentations were written for 2-D (H, W) seg maps,
    # but the X-ray masks are channel-first (C, H, W); verify their axis
    # handling before enabling them.
    # dict(type='RandomCrop', crop_size=(512, 1024), cat_max_ratio=0.75),
    # dict(type='RandomFlip', prob=0.5),
    PackSegInputs(),
]


In [None]:
# Rebuild the dataset with the training pipeline attached, so indexing now runs
# the transforms instead of returning raw path dicts.
dataset = XRayDataset(
    data_root=data_root,
    data_prefix=data_prefix,
    test_mode=False,
    pipeline=train_pipeline,
)
