<a href="https://colab.research.google.com/github/ivlucky/freelance_steel_defects/blob/master/2_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Magic

In [1]:
# Load (or reload, so re-running this cell is harmless) the autoreload extension.
%reload_ext autoreload
# Mode 2: reload imported modules before each cell runs, so edits to helper
# modules are picked up without restarting the kernel.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Installation

In [None]:
!pip install xmltodict==0.12.0 --quiet
!pip install pytorch_lightning==1.4.9 --quiet
!pip install segmentation-models-pytorch==0.2.0 --quiet
!pip install albumentations==1.1.0 --quiet
!pip install tensorboard==2.7.0 # for monitoring training

# Imports

In [3]:
import os
import xmltodict
import numpy as np
from typing import Dict
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import albumentations as albu

from sklearn.model_selection import train_test_split
import itertools

# Mount drive

In [4]:
# Mount Google Drive at /content/drive so the dataset folder configured below
# (under /content/drive/MyDrive) is reachable. If Drive is already mounted,
# this only prints a notice instead of remounting.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Global variables

In [9]:
# Dataset layout on the mounted Drive.
DATAFOLDER = "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset"
INPUTFOLDER = os.path.join(DATAFOLDER, "input")
ANNOFOLDER = os.path.join(INPUTFOLDER, "anno")          # XML annotations
IMAGESFOLDER = os.path.join(INPUTFOLDER, "images")      # .jpg images
AUGMENTEDFOLDER = os.path.join(DATAFOLDER, "augmented")
PROCESSEDFOLDER = os.path.join(DATAFOLDER, "processed")
MASKFOLDER = os.path.join(PROCESSEDFOLDER, "masks")     # presumably the .npy masks
SPLITFOLDER = os.path.join(DATAFOLDER, "split")
OUTPUTFOLDER = os.path.join(DATAFOLDER, "output")
LOGFOLDER = os.path.join(OUTPUTFOLDER, "logs")

# Seed and split fractions.
RANDOM_STATE = 42
VALSIZE = 0.2   # validation fraction
TESTSIZE = 0.1  # test fraction

# Segmentation model settings.
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
NUMCLASSES = 1
# No activation at the model level (it outputs raw logits) for numerical
# stability. Alternatives: 'sigmoid' for probabilities, or 'softmax2d' for
# multiclass segmentation.
ACTIVATION = None

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

# Functions

In [6]:
def get_xmldict(xmlfile):
    """Parse an XML annotation file into nested mappings using ``xmltodict``.

    Args:
        xmlfile (str): path to the XML file.

    Returns:
        The object returned by ``xmltodict.parse`` for the file's text.
    """
    with open(xmlfile, 'r') as handle:
        content = handle.read()
    return xmltodict.parse(content)

def set_bndbox(mask, bndbox):
    """Fill a single bounding box of ``mask`` with ones, in place.

    The mask is indexed as ``mask[x, y]`` (``get_mask`` transposes it
    afterwards). ``xmax``/``ymax`` are exclusive, as in Python slicing.

    Args:
        mask (np.ndarray): 2-D array modified in place.
        bndbox (dict): 'xmin', 'xmax', 'ymin', 'ymax' values accepted by
            ``int`` (e.g. strings parsed from the XML annotation).
    """
    x_span = slice(int(bndbox['xmin']), int(bndbox['xmax']))
    y_span = slice(int(bndbox['ymin']), int(bndbox['ymax']))
    mask[x_span, y_span] = 1

def set_allbndbox(mask, objects):
    """Fill every annotated object's bounding box into ``mask`` via ``set_bndbox``.

    Args:
        mask (np.ndarray): 2-D array modified in place.
        objects (iterable of dict): annotation objects, each with a 'bndbox' entry.
    """
    boxes = (obj['bndbox'] for obj in objects)
    for box in boxes:
        set_bndbox(mask, box)

def get_mask(xmlfile):
    """Build a binary mask from an XML annotation by filling each object's bndbox.

    The mask size is read from the annotation's ``<size>`` element
    (``width``/``height``) when present and falls back to 200x200 otherwise.

    Args:
        xmlfile (str): path to the XML annotation file.

    Returns:
        np.ndarray: float array of shape (height, width), 1 inside boxes and 0
        elsewhere. An annotation without any ``<object>`` gives an all-zero mask.
    """
    annotation = get_xmldict(xmlfile)['annotation']

    size = annotation.get('size')
    if not isinstance(size, dict):
        size = {}
    width = int(size.get('width', 200))
    height = int(size.get('height', 200))

    # Boxes are written as mask[x, y]; the final transpose yields (height, width).
    mask = np.zeros((width, height))

    objects = annotation.get('object') or []
    # xmltodict yields a single mapping for one <object> and a list for several.
    if isinstance(objects, dict):
        objects = [objects]
    set_allbndbox(mask, objects)

    return mask.T

def show_image(imgfile, figsize=(14, 14)):
    """Load an image and draw it on a new single-axes figure with axes hidden.

    Args:
        imgfile (str): path to the image file.
        figsize (tuple): figure size in inches, passed to ``plt.subplots``.

    Returns:
        tuple: ``(fig, ax)`` so callers can draw overlays such as a mask.
    """
    img = get_image(imgfile)
    print(img.shape)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    # A 2-D (single-channel) array would otherwise be rendered with matplotlib's
    # default false-color colormap; color images keep the default (cmap=None).
    ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
    ax.axis('off')
    return fig, ax

def show_mask(xmlfile, fig, ax):
    """Overlay the mask built from ``xmlfile`` on ``ax`` at 30% opacity.

    Args:
        xmlfile (str): path to the XML annotation.
        fig: figure owning ``ax``; accepted for symmetry with ``show_image``, unused.
        ax: matplotlib axes to draw on.
    """
    ax.imshow(get_mask(xmlfile), alpha=0.3)

def show_obj(object_val, datafolder='./'):
    """Show one object's image with its annotation mask overlaid.

    Reads ``<datafolder>/images/<object_val>.jpg`` and
    ``<datafolder>/anno/<object_val>.xml``.
    """
    imgfile, xmlfile = (
        os.path.join(datafolder, subdir, object_val + ext)
        for subdir, ext in (('images', '.jpg'), ('anno', '.xml'))
    )
    fig, ax = show_image(imgfile)
    show_mask(xmlfile, fig, ax)

def sets_describe(set_1, set_2):
    """Print the sizes of two sets and of their differences, intersection and union."""
    # Each entry is computed only when printed, line by line.
    report = [
        ("set_1 len", lambda: set_1),
        ("set_2 len", lambda: set_2),
        ("set_1 - set_2 len", lambda: set_1 - set_2),
        ("set_2 - set_1 len", lambda: set_2 - set_1),
        ("symmetric diff len", lambda: set_2.symmetric_difference(set_1)),
        ("intersection len", lambda: set_1.intersection(set_2)),
        ("union len", lambda: set_1.union(set_2)),
    ]
    for label, compute in report:
        print(f"{label}: {len(compute())}")

def get_image(imgfile):
    """Read an image file into an array with ``matplotlib.image.imread``.

    Args:
        imgfile (str): path to the image.

    Returns:
        The array returned by ``imread``.
    """
    image = mpimg.imread(imgfile)
    return image

def save_mask(maskfile, mask):
    """Write ``mask`` to ``maskfile`` in NumPy ``.npy`` format via ``np.save``."""
    np.save(file=maskfile, arr=mask)

def load_mask(maskfile):
    """Read a ``.npy`` mask (as written by ``save_mask``) back into an array."""
    mask = np.load(maskfile)
    return mask

# Classes

In [14]:
class SegDataset(torch.utils.data.Dataset):
  """Map-style dataset of ``(image, mask, obj)`` triples for segmentation.

  Object ``obj`` resolves to ``<img_dir>/<obj>.jpg`` and ``<mask_dir>/<obj>.npy``;
  files are read lazily in ``__getitem__``.

  Args:
      obj_list (list): object names (file stems) in this dataset.
      img_dir (str): path to the images folder.
      mask_dir (str): path to the segmentation masks folder.
  """

  class IndexOutOfRange(IndexError, KeyError):
    """Raised for an index outside ``[0, len(dataset))``.

    Subclasses KeyError (what this dataset has always raised) and IndexError
    (what Python's sequence-iteration protocol stops on).
    """

  def __init__(self, obj_list, img_dir, mask_dir):
    self.obj_list = obj_list
    self.img_dir = img_dir
    self.mask_dir = mask_dir

  def __getitem__(self, i):
    """Return ``(image, mask, obj)`` for the ``i``-th object.

    Raises:
        SegDataset.IndexOutOfRange (a KeyError and IndexError): if ``i`` is
            negative or ``i >= len(self)``.
    """
    # ``>=``: index len(obj_list) is already one past the last element.
    if i >= len(self.obj_list):
      raise SegDataset.IndexOutOfRange('i not in self.obj_list')
    if i < 0:
      raise SegDataset.IndexOutOfRange('i is negative')

    obj = self.obj_list[i]
    imgfile = os.path.join(self.img_dir, f'{obj}.jpg')
    maskfile = os.path.join(self.mask_dir, f'{obj}.npy')

    # NOTE(review): the image is returned exactly as read (no channel axis or
    # float conversion); confirm the model input pipeline adds these.
    image = get_image(imgfile)
    mask = load_mask(maskfile)
    return image, mask, obj

  def __len__(self):
    return len(self.obj_list)

In [15]:
class SegDataModule(pl.LightningDataModule):
  """LightningDataModule for the segmentation images.

  ``prepare_data`` renders one ``.npy`` mask per annotated image into ``mask_dir``.
  ``setup`` splits the objects into train/val/test separately for each defect
  type and wraps each split in a ``SegDataset``. Objects are file stems; the
  defect type is assumed to be the stem without its trailing ``_<suffix>``.

  Args:
      anno_dir (str): dir with XML annotations.
      img_dir (str): dir with .jpg images.
      mask_dir (str): dir where masks are written and read.
      val_size (float): fraction of each defect type used for validation.
      test_size (float): fraction of each defect type used for testing.
      batch_size (int): batch size of the train/val/test dataloaders.
      random_state (int): seed passed to ``train_test_split``.
  """
  def __init__(self,
               anno_dir: str = "./anno",
               img_dir: str = "./images",
               mask_dir: str = "./masks",
               val_size=0.2,
               test_size=0.1,
               batch_size=32,
               random_state=42):

    super().__init__()
    self.anno_dir = anno_dir
    self.img_dir = img_dir
    self.mask_dir = mask_dir
    self.val_size = val_size
    self.test_size = test_size
    self.batch_size = batch_size
    self.random_state = random_state

  def _common_objects(self):
    """Return the sorted file stems present in both ``anno_dir`` and ``img_dir``.

    Sorting keeps the later split independent of set iteration order, which
    varies between interpreter runs for strings.
    """
    anno_stems = {f.split('.')[0] for f in os.listdir(self.anno_dir)}
    img_stems = {f.split('.')[0] for f in os.listdir(self.img_dir)}
    return sorted(anno_stems & img_stems)

  def prepare_data(self):
    """Write ``<mask_dir>/<obj>.npy`` for every object having an image and an annotation."""
    os.makedirs(self.mask_dir, exist_ok=True)

    for obj in tqdm(self._common_objects()):
      xmlfile = os.path.join(self.anno_dir, f"{obj}.xml")
      # Masks belong in mask_dir: that is where SegDataset reads them from.
      maskfile = os.path.join(self.mask_dir, f"{obj}.npy")
      save_mask(maskfile, get_mask(xmlfile))

  def _split(self, objects):
    """Split one defect type's objects into ``(train, val, test)`` lists."""
    holdout_size = self.val_size + self.test_size
    if holdout_size <= 0:
      return list(objects), [], []

    train, holdout = train_test_split(objects,
                                      test_size=holdout_size,
                                      random_state=self.random_state)
    # sklearn rejects a fractional test_size of exactly 0 or 1, so an empty
    # val or test split is handled explicitly.
    if self.test_size <= 0:
      return train, holdout, []
    if self.val_size <= 0:
      return train, [], holdout

    val, test = train_test_split(holdout,
                                 test_size=self.test_size / holdout_size,
                                 random_state=self.random_state)
    return train, val, test

  def setup(self, stage=None):
    """Build the train/val/test/all ``SegDataset``s with a per-defect-type split.

    Args:
        stage (str, optional): stage name passed by the Lightning Trainer
            (e.g. 'fit' or 'test'); the same splits are built for every stage.
    """
    by_defect = {}
    for obj in self._common_objects():
      by_defect.setdefault(obj.rsplit('_', 1)[0], []).append(obj)

    self.train, self.val, self.test = [], [], []
    for defect_type in sorted(by_defect):
      train, val, test = self._split(by_defect[defect_type])
      self.train.extend(train)
      self.val.extend(val)
      self.test.extend(test)

    self.traindataset = SegDataset(self.train, self.img_dir, self.mask_dir)
    self.valdataset = SegDataset(self.val, self.img_dir, self.mask_dir)
    self.testdataset = SegDataset(self.test, self.img_dir, self.mask_dir)
    self.alldataset = SegDataset(self.train + self.val + self.test,
                                 self.img_dir, self.mask_dir)

  def train_dataloader(self):
    # Only the training split is shuffled; evaluation order stays fixed.
    return torch.utils.data.DataLoader(self.traindataset,
                                       batch_size=self.batch_size,
                                       shuffle=True)

  def val_dataloader(self):
    return torch.utils.data.DataLoader(self.valdataset, batch_size=self.batch_size)

  def test_dataloader(self):
    return torch.utils.data.DataLoader(self.testdataset, batch_size=self.batch_size)

  def predict_dataloader(self):
    return torch.utils.data.DataLoader(self.alldataset)


In [16]:
class PlSegModel(pl.LightningModule):
  """LightningModule wrapping a segmentation network, its loss, metrics and optimizer.

  Some code is from https://github.com/Shreeyak/pytorch-lightning-segmentation-template

  Args:
      model (torch.nn.Module): network applied in ``forward``.
      loss (callable): called as ``loss(pred, target)`` and returned as the training loss.
      train_metric, val_metric, test_metric: called as ``metric(pred, target)`` and
          passed to ``self.log`` (presumably torchmetrics Metric objects — confirm).
      sample: optional value stored as ``sample`` and ``monitor_imgs``.
  """

  def __init__(self, model, loss,
               train_metric, val_metric, test_metric,
               sample=None):

    super().__init__()
    self.model = model

    self.loss = loss

    self.train_metric = train_metric
    self.val_metric = val_metric
    self.test_metric = test_metric
    self.sample = sample
    self.monitor_imgs = self.sample

  def forward(self, x):
    """Run the wrapped model; ``self(inputs)`` dispatches here."""
    return self.model(x)

  def _shared_step(self, batch, stage, metric):
    """Compute the loss, update ``metric`` and log ``<stage>_metric``/``<stage>_loss``.

    ``batch`` is ``(inputs, mask, objects)``: ``SegDataset`` yields 3-tuples.
    Returns the loss tensor.
    """
    inputs, target, _objects = batch
    pred = self(inputs)

    # NOTE(review): target is cast to pred's dtype but not reshaped; confirm the
    # loss accepts the mask layout (the mask may lack pred's channel axis).
    loss = self.loss(pred, target.type_as(pred))
    metric(pred, target)

    log_kwargs = dict(on_step=True, on_epoch=True,
                      prog_bar=False, logger=True, sync_dist=True)
    self.log(f'{stage}_metric', metric, **log_kwargs)
    self.log(f'{stage}_loss', loss, **log_kwargs)
    return loss

  def training_step(self, batch, batch_idx):
    """One training batch. Lightning moves tensors to the right device, so no
    ``.to(device)`` calls are needed here."""
    return self._shared_step(batch, 'train', self.train_metric)

  def validation_step(self, batch, batch_idx):
    self._shared_step(batch, 'val', self.val_metric)

  def test_step(self, batch, batch_idx):
    self._shared_step(batch, 'test', self.test_metric)

  def configure_optimizers(self):
    """Adam (lr=1e-4) plus ReduceLROnPlateau (patience 3) monitoring ``val_metric``."""
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    # NOTE(review): ReduceLROnPlateau defaults to mode='min'. If val_metric is a
    # higher-is-better score (IoU, F1), mode='max' is needed — confirm metric type.
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    sched = {'scheduler': lr_scheduler,
             'monitor': 'val_metric'}

    return [optimizer], [sched]

# Model

In [17]:
# U-Net whose encoder is initialised from pretrained weights (ENCODER_WEIGHTS).
# in_channels=1 expects single-channel input images; the network produces
# NUMCLASSES output channels, passed through ACTIVATION (None -> raw logits).
model = smp.Unet(
    encoder_name=ENCODER,             # e.g. 'resnet34', 'mobilenet_v2', 'efficientnet-b7'
    encoder_weights=ENCODER_WEIGHTS,  # 'imagenet' for pretrained initialisation
    in_channels=1,
    classes=NUMCLASSES,
    activation=ACTIVATION,
)