<a href="https://colab.research.google.com/github/ivlucky/freelance_steel_defects/blob/master/2_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installation

In [1]:
!pip install xmltodict==0.12.0 --quiet
!pip install pytorch_lightning==1.4.9 --quiet
!pip install segmentation-models-pytorch==0.2.0 --quiet
!pip install albumentations==1.1.0 --quiet

# Imports

In [2]:
import os
import xmltodict
import numpy as np
from typing import Dict
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import albumentations as albu

# Mount drive

In [3]:
from google.colab import drive
# Colab-only step: mount Google Drive at /content/drive. DATAFOLDER in the
# "Global variables" cell is defined under this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Global variables

In [4]:
# Root folder of the dataset on the mounted Google Drive.
DATAFOLDER = "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset"

# Input data: annotation and image folders live under "input".
INPUTFOLDER = os.path.join(DATAFOLDER, "input")
ANNOFOLDER = os.path.join(DATAFOLDER, "input", "anno")
IMAGESFOLDER = os.path.join(DATAFOLDER, "input", "images")

# Derived data folders, all siblings of "input" under the dataset root.
AUGMENTEDFOLDER = os.path.join(DATAFOLDER, "augmented")
PROCESSEDFOLDER = os.path.join(DATAFOLDER, "processed")
MASKFOLDER = os.path.join(DATAFOLDER, "processed", "masks")
SPLITFOLDER = os.path.join(DATAFOLDER, "split")

# Presumably the seed for random operations; not used in the visible cells.
RANDOM_STATE = 42

# Functions

In [5]:
def get_xmldict(xmlfile):
    """Read an XML file and return its content parsed by ``xmltodict``.

    Args:
        xmlfile: Path of the XML file to read (opened in text mode).

    Returns:
        The object produced by ``xmltodict.parse`` for the file's text.
    """
    with open(xmlfile, 'r') as handle:
        return xmltodict.parse(handle.read())

def set_bndbox(mask, bndbox):
    """Set the region of one bounding box to 1 in ``mask`` (in place).

    Args:
        mask: Array indexed as ``mask[x, y]``; it is modified in place.
        bndbox: Mapping with 'xmin', 'xmax', 'ymin' and 'ymax' entries whose
            values can be converted with ``int()``.
    """
    x_lo, x_hi, y_lo, y_hi = (
        int(bndbox[key]) for key in ('xmin', 'xmax', 'ymin', 'ymax')
    )

    # Upper bounds are exclusive, as with any slice.
    mask[x_lo:x_hi, y_lo:y_hi] = 1

def set_allbndbox(mask, objects):
    """Mark the bounding box of every entry of ``objects`` in ``mask``.

    Args:
        mask: Array passed on to ``set_bndbox``; it is modified in place.
        objects: Iterable of mappings, each with a 'bndbox' entry.
    """
    boxes = (entry['bndbox'] for entry in objects)
    for box in boxes:
        set_bndbox(mask, box)

def get_mask(xmlfile, shape=(200, 200)):
    """Build a binary mask from the bounding boxes of an annotation file.

    Args:
        xmlfile: Path of the XML annotation file.
        shape: ``(height, width)`` of the returned mask. The default keeps
            the 200x200 size that used to be hard-coded.

    Returns:
        Float array of the given shape, indexed as ``mask[y, x]``, holding 1
        inside every annotated bounding box and 0 elsewhere. An annotation
        without any ``object`` entry yields an all-zero mask.
    """
    height, width = shape
    annotation = get_xmldict(xmlfile)['annotation']

    # set_bndbox indexes the array as [x, y], so it is built width-first and
    # transposed on return.
    mask = np.zeros((width, height))

    # xmltodict omits the key when there is no <object> element, returns a
    # single mapping for one element and a list for several.
    objects = annotation.get('object')
    if objects is None:
        return mask.T

    if isinstance(objects, dict):
        set_bndbox(mask, objects['bndbox'])
    else:
        set_allbndbox(mask, objects)

    return mask.T

def show_image(imgfile):
    """Display an image file on a new 14x14 figure with the axes hidden.

    The shape of the loaded image array is printed as a side effect.

    Args:
        imgfile: Path of the image file, loaded through ``get_image``.

    Returns:
        Tuple ``(fig, ax)`` of the created figure and its single axes.
    """
    pixels = get_image(imgfile)
    print(pixels.shape)

    figure, axes = plt.subplots(nrows=1, ncols=1, figsize=(14, 14))
    axes.imshow(pixels)
    axes.axis('off')

    return figure, axes

def show_mask(xmlfile, fig, ax):
    """Overlay the mask of an annotation file on existing axes.

    Args:
        xmlfile: Path of the XML annotation file passed to ``get_mask``.
        fig: Figure owning ``ax``; accepted but not used here.
        ax: Axes on which the mask is drawn with 30% opacity.
    """
    overlay = get_mask(xmlfile)
    ax.imshow(overlay, alpha=0.3)

def show_obj(object_val, datafolder='./'):
    """Show an image with its annotation mask drawn on top.

    Args:
        object_val: Base file name (without extension) shared by the image
            and its annotation.
        datafolder: Folder containing the 'images' and 'anno' sub-folders.
    """
    def _path(subfolder, extension):
        # <datafolder>/<subfolder>/<object_val><extension>
        return os.path.join(datafolder, subfolder, object_val + extension)

    fig, ax = show_image(_path('images', '.jpg'))
    show_mask(_path('anno', '.xml'), fig, ax)

def sets_describe(set_1, set_2):
    """Print size statistics comparing two sets.

    One line is printed for each of: the size of each set, both differences,
    the symmetric difference, the intersection and the union.

    Args:
        set_1: First set.
        set_2: Second set.
    """
    def _report_lines():
        # Lazily produced so each value is computed right before it is printed.
        yield f"set_1 len: {len(set_1)}"
        yield f"set_2 len: {len(set_2)}"
        yield f"set_1 - set_2 len: {len(set_1 - set_2)}"
        yield f"set_2 - set_1 len: {len(set_2 - set_1)}"
        yield f"symmetric diff len: {len(set_2.symmetric_difference(set_1))}"
        yield f"intersection len: {len(set_1.intersection(set_2))}"
        yield f"union len: {len(set_1.union(set_2))}"

    for line in _report_lines():
        print(line)

def get_image(imgfile):
    """Load an image file with ``matplotlib.image.imread``.

    Args:
        imgfile: Path of the image file.

    Returns:
        Whatever ``imread`` returns for the file (the image data).
    """
    pixels = mpimg.imread(imgfile)
    return pixels

def save_mask(maskfile, mask):
    """Write a mask array to disk with ``np.save``.

    Args:
        maskfile: Destination passed to ``np.save`` as the file argument.
        mask: Array to store.
    """
    np.save(file=maskfile, arr=mask)

def load_mask(maskfile):
    """Read a mask back from disk with ``np.load``.

    Args:
        maskfile: File passed to ``np.load``.

    Returns:
        The object returned by ``np.load`` for that file.
    """
    stored = np.load(file=maskfile)
    return stored

# Classes

In [None]:
class SegDataModule(pl.LightningDataModule):
    """Lightning data module that generates segmentation masks from annotations.

    Args:
        anno_dir: Directory holding the ``.xml`` annotation files.
        img_dir: Directory holding the image files.
        mask_dir: Directory the generated ``.npy`` masks are written to; it
            is created if it does not exist.
    """

    def __init__(self,
                 anno_dir: str = "./anno",
                 img_dir: str = "./images",
                 mask_dir: str = "./masks"):

        super().__init__()
        self.anno_dir = anno_dir
        self.img_dir = img_dir
        self.mask_dir = mask_dir

    def prepare_data(self):
        """Write one mask file for every object with an annotation and an image.

        An object is identified by its file name without extension. For each
        name present in both ``anno_dir`` (as ``<name>.xml``) and ``img_dir``,
        the mask returned by ``get_mask`` is saved as
        ``<mask_dir>/<name>.npy``.
        """
        # splitext strips only the final extension, so base names that
        # themselves contain dots stay intact.
        anno_names = {
            stem
            for stem, ext in map(os.path.splitext, os.listdir(self.anno_dir))
            # Only real annotation files: the path below is rebuilt as
            # "<name>.xml", so anything else would point to a missing file.
            if ext == ".xml"
        }
        img_names = {
            os.path.splitext(fname)[0] for fname in os.listdir(self.img_dir)
        }

        # Sorted for a reproducible processing order.
        common_objects = sorted(anno_names & img_names)

        os.makedirs(self.mask_dir, exist_ok=True)

        for obj in tqdm(common_objects):

            xmlfile = os.path.join(self.anno_dir, obj + ".xml")
            mask = get_mask(xmlfile)
            # Masks belong in mask_dir (previously written into img_dir).
            maskfile = os.path.join(self.mask_dir, obj + ".npy")
            save_mask(maskfile, mask)

# Test check