# Imports

In [None]:
import xml.etree.ElementTree as ET
import xmltodict
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from typing import Dict

# Check xml files

In [None]:
# Root of the dataset; the cells below read 'anno/' (XML annotations) and 'images/' from it.
datafolder = './data'
# Pick one sample annotation file to explore. os.listdir order is arbitrary, so sort
# to make the choice reproducible; build the path from `datafolder` so the data
# location only has to be changed in one place.
xmlfile = sorted(os.listdir(os.path.join(datafolder, 'anno')))[0]

In [None]:
# Load the sample annotation into an ElementTree and grab its root element.
tree = ET.ElementTree(file=os.path.join(datafolder, 'anno', xmlfile))
root = tree.getroot()

In [None]:
# Show every direct child of the root element: its tag and its XML attributes.
for node in list(root):
    print(node.tag, node.attrib)

In [None]:
# Drill into the tree purely by position: 6th child of the root -> its 5th child ->
# that element's 1st child, and display its text.
# NOTE(review): in a standard Pascal-VOC layout this is presumably the first
# object's bndbox/xmin; positional indexing breaks if child order differs — confirm
# against the child listing printed above.
root[5][4][0].text

# xml -> dict

In [None]:
# Read raw bytes and let the XML parser decode them using the document's own
# encoding declaration; text mode would decode with the platform's default locale
# encoding instead, which can garble non-ASCII content.
with open(os.path.join(datafolder, 'anno', xmlfile), 'rb') as fp:
    xmlcontent = fp.read()
d = xmltodict.parse(xmlcontent)
# Display the second <object> entry. NOTE: xmltodict returns a dict (not a list)
# when <object> occurs only once, in which case this lookup raises KeyError.
d['annotation']['object'][1]

# Functions

In [None]:
def set_bndbox(mask, bndbox):
    """Mark one bounding box as foreground (1) in ``mask``, in place.

    ``mask`` is indexed as ``mask[x, y]`` (first axis = x, second axis = y); the
    display helper transposes it before drawing. The filled region is
    ``xmin <= x < xmax`` and ``ymin <= y < ymax``.

    Args:
        mask: 2-D array, modified in place.
        bndbox: Mapping with ``'xmin'``, ``'xmax'``, ``'ymin'``, ``'ymax'`` entries,
            given as numbers or numeric strings (xmltodict yields strings).
    """
    # int(float(...)) also accepts decimal strings such as "12.0", which plain
    # int() rejects with ValueError. Clamping at 0 stops a negative coordinate from
    # being interpreted as a wrap-around index, which would silently drop the box.
    xmin = max(int(float(bndbox['xmin'])), 0)
    xmax = max(int(float(bndbox['xmax'])), 0)
    ymin = max(int(float(bndbox['ymin'])), 0)
    ymax = max(int(float(bndbox['ymax'])), 0)

    mask[xmin:xmax, ymin:ymax] = 1

def set_allbndbox(mask, objects):
    """Mark every object's bounding box as foreground in ``mask``, in place.

    Args:
        mask: 2-D array passed through to ``set_bndbox``.
        objects: The value of ``xmldict['annotation']['object']``: a list of object
            dicts, a single object dict (what xmltodict yields when ``<object>``
            occurs once), or ``None`` meaning there is nothing to draw.
    """
    if objects is None:
        return
    # Iterating a lone object dict would walk its string keys and then fail on
    # obj['bndbox'], so wrap it in a list first.
    if isinstance(objects, dict):
        objects = [objects]
    for obj in objects:
        set_bndbox(mask, obj['bndbox'])

def sets_describe(set_1, set_2):
    """Print an overlap report for two sets.

    One line each for: the size of each set, both one-sided differences, the
    symmetric difference, the intersection and the union.
    """
    # Each entry is computed lazily, right before its line is printed.
    report = (
        ("set_1 len", lambda: set_1),
        ("set_2 len", lambda: set_2),
        ("set_1 - set_2 len", lambda: set_1 - set_2),
        ("set_2 - set_1 len", lambda: set_2 - set_1),
        ("symmetric diff len", lambda: set_2.symmetric_difference(set_1)),
        ("intersection len", lambda: set_1.intersection(set_2)),
        ("union len", lambda: set_1.union(set_2)),
    )
    for label, compute in report:
        print(f"{label}: {len(compute())}")

def get_image(imgfile):
    """Load the image at ``imgfile`` into a NumPy array using matplotlib."""
    image = mpimg.imread(imgfile)
    return image

def get_xmldict(xmlfile):
    """Parse an XML file into nested dicts with xmltodict.

    The file is read as raw bytes so the XML parser decodes it according to the
    document's own encoding declaration. Text mode would decode it with the
    platform's locale encoding, which can garble non-ASCII content.

    Args:
        xmlfile: Path to the XML file.

    Returns:
        The mapping produced by ``xmltodict.parse``.
    """
    with open(xmlfile, 'rb') as fp:
        xmlcontent = fp.read()
    return xmltodict.parse(xmlcontent)

def get_mask(xmlfile, shape=(200, 200)):
    """Build a binary mask covering every bounding box in an annotation file.

    Args:
        xmlfile: Path to the XML annotation file.
        shape: Mask shape, indexed ``(x, y)`` like ``set_bndbox`` expects. Defaults
            to ``(200, 200)``, the size hard-coded elsewhere in this notebook.

    Returns:
        Float array of ``shape``: 1 inside any box, 0 elsewhere. An annotation with
        no ``<object>`` element yields an all-zero mask instead of raising KeyError.
    """
    xmldict = get_xmldict(xmlfile)

    mask = np.zeros(shape)
    objects = xmldict['annotation'].get('object')
    if objects is None:
        return mask
    # xmltodict returns a single dict (not a one-element list) when <object>
    # occurs only once.
    if isinstance(objects, dict):
        set_bndbox(mask, objects['bndbox'])
    else:
        set_allbndbox(mask, objects)

    return mask

def show_image(imgfile):
    """Draw an image on a new 14x14-inch figure with the axes hidden.

    Prints the loaded array's shape first. Returns ``(fig, ax)`` so overlays can
    be drawn on the same axes.
    """
    image = get_image(imgfile)
    print(image.shape)
    fig, ax = plt.subplots(1, 1, figsize=(14, 14))
    ax.imshow(image)
    ax.set_axis_off()
    return fig, ax

def show_mask(xmlfile, fig, ax):
    """Overlay the bounding-box mask built from ``xmlfile`` on ``ax`` at 30% opacity.

    ``fig`` is not used.
    """
    # The mask is indexed (x, y); transpose so imshow gets (row=y, col=x).
    ax.imshow(get_mask(xmlfile).T, alpha=0.3)

def show_obj(object_val, datafolder='./'):
    """Display the image for sample ``object_val`` with its box mask overlaid.

    Files are read from ``<datafolder>/images/<object_val>.jpg`` and
    ``<datafolder>/anno/<object_val>.xml``.
    """
    image_path = os.path.join(datafolder, 'images', object_val + '.jpg')
    annotation_path = os.path.join(datafolder, 'anno', object_val + '.xml')

    figure, axes = show_image(image_path)
    show_mask(annotation_path, figure, axes)

def get_stats(object_val, datafolder='./'):
    """Return ``(image_shape, mask_area)`` for one sample.

    ``image_shape`` is the loaded image array's shape; ``mask_area`` is the sum of
    the box mask built from the matching annotation file.
    """
    def _path(subdir, ext):
        return os.path.join(datafolder, subdir, object_val + ext)

    image_shape = get_image(_path('images', '.jpg')).shape
    mask_area = get_mask(_path('anno', '.xml')).sum()
    return image_shape, mask_area

# Some tests

In [None]:
# Sanity check: drawing the first object's box on a blank mask; the second print
# shows how many mask cells the box covers.
bndbox = d['annotation']['object'][0]['bndbox']
mask = np.zeros([200, 200])
print(np.sum(mask))
set_bndbox(mask, bndbox)
print(np.sum(mask))

In [None]:
# Sanity check for every object at once; the mask is binary, so overlapping boxes
# are counted only once in the second print.
objects = d['annotation']['object']
mask = np.zeros(shape=(200, 200))
print(np.sum(mask))
set_allbndbox(mask, objects)
print(np.sum(mask))

In [None]:
# Sample ids that have an image, an annotation, or both. os.path.splitext strips
# only the final extension, so an id containing a dot is kept whole
# (split('.')[0] would cut it at the first dot).
jpgfiles = sorted(os.listdir(os.path.join(datafolder, 'images')))
xmlfiles = sorted(os.listdir(os.path.join(datafolder, 'anno')))

set_1 = {os.path.splitext(f)[0] for f in jpgfiles}
set_2 = {os.path.splitext(f)[0] for f in xmlfiles}
# Only ids with both an image and an annotation are analysed below.
common_objects = set_1.intersection(set_2)
sets_describe(set_1, set_2)

# Get stats for images

In [None]:
# One row per sample id. 'type' is the id's prefix before the first '_', 'id' is the
# integer after it; rows are ordered by (type, id) with a fresh 0..n-1 index.
obj_dict = {obj: obj.split('_')[0] for obj in common_objects}
df_obj = pd.DataFrame(list(obj_dict.items()), columns=['obj', 'type'])
df_obj['id'] = df_obj['obj'].str.split('_').str[1].astype(int)
df_obj = df_obj.sort_values(by=['type', 'id'], ignore_index=True)

In [None]:
# Distinct sample types, in order of first appearance.
pd.unique(df_obj['type'])

In [None]:
# Number of distinct numeric ids across all types.
df_obj['id'].value_counts().size

In [None]:
# Visual spot check of one sample: row 7 of the sorted table.
i = 7
object_val = df_obj.at[i, 'obj']
show_obj(object_val, datafolder=datafolder)

In [None]:
# Collect per-sample image dimensions and mask area, then join them onto df_obj.
# NumPy image arrays are shaped (rows, cols[, channels]) = (height, width[, depth]),
# so 'h' comes from axis 0 and 'w' from axis 1. A 2-D (grayscale) array has no
# channel axis; its depth is recorded as 1 instead of raising IndexError.
records = {'obj': [], 'w': [], 'h': [], 'd': [], 'm_s': []}
for obj in df_obj['obj']:
    img_shape, mask_size = get_stats(obj, datafolder)
    records['obj'].append(obj)
    records['h'].append(img_shape[0])
    records['w'].append(img_shape[1])
    records['d'].append(img_shape[2] if len(img_shape) > 2 else 1)
    records['m_s'].append(mask_size)

df_merge = pd.DataFrame(records)

df_stats = df_obj.merge(df_merge, on='obj', how='left')

# Fraction of the image's pixels covered by at least one box (the mask is binary).
df_stats['m_s_p'] = df_stats['m_s'] / (df_stats['w'] * df_stats['h'])
df_stats.head()

In [None]:
# Distinct sample types present in the stats table.
pd.unique(df_stats['type'])

In [None]:
# Boolean row filter for 'inclusion' samples, then their summary statistics with
# the 'count' row and the 'id' column left out.
mask = df_stats['type'].eq('inclusion')
df_stats.loc[mask].describe().drop(index=['count'], columns=['id'])

In [None]:
# Overlay one histogram of box-coverage fraction per sample type. Plotting each
# type separately gives every series its own correct label; a two-way
# "inclusion vs. the rest" split would put all other types under a single label.
fig, ax = plt.subplots(1, 1, figsize=(14, 12))
for type_name, values in df_stats.groupby('type')['m_s_p']:
    values.hist(ax=ax, alpha=0.3, label=type_name)
ax.legend()