In [1]:
import numpy as np
import cv2
from collections import namedtuple
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import igraph as ig
import logging
import sys
from tqdm.auto import tqdm as tq
from PIL import Image as im
import os

In [2]:
# Route log records to stdout and only emit WARNING and above.
logging.basicConfig(level=logging.WARNING, stream=sys.stdout)

In [3]:
# Label values written into the two per-pixel masks built in run():
#   FIX / UNK : pixel label is known (user-marked) / unknown   -> `types` mask
#   FG  / BG  : pixel is foreground / background               -> `alphas` mask
con = namedtuple('_', ('FIX', 'UNK', 'FG', 'BG'))(FIX=1, UNK=0, FG=1, BG=0)

# Defaults for run()'s keyword parameters
# (presumably GrabCut hyper-parameters -- confirm against the segmentation code).
NUM_GMM_COMP = 5
GAMMA = 50
LAMDA = GAMMA * 9   # tied to GAMMA: always nine times its value
NUM_ITERS = 3
TOL = 0.001

In [4]:
class EventHandler:
    """
    Mouse-callback state holder for painting refinement strokes.

    The instance keeps references to the displayed image and to the two
    per-pixel masks (`types` and `alphas`).  While the left mouse button is
    held down, `handler` stamps filled circles at the cursor position into
    all three arrays.

    Parameters
    ----------
    flags : dict
        Shared drawing state.  Must contain the keys 'DRAW_STROKE' (bool,
        True while a stroke is in progress) and 'value' (dict with keys
        'color' and 'val' describing the current brush).
    img : array
        Image the strokes are drawn on; a copy is kept in `img2`.
    _types : array
        Mask that receives `con.FIX` wherever the user paints.
    _alphas : array
        Mask that receives the current brush's 'val' wherever the user paints.
    colors : dict
        Colour table, stored as `COLORS` (not read by this class itself).
    """

    # Radius, in pixels, of every circle stamped along a stroke.
    _STROKE_RADIUS = 3

    def __init__(self, flags, img, _types, _alphas, colors):
        self.FLAGS = flags
        self.ix = -1
        self.iy = -1
        self.img = img
        self.img2 = self.img.copy()
        self._types = _types
        self._alphas = _alphas
        self.COLORS = colors

    @property
    def image(self):
        """Image the strokes are drawn on."""
        return self.img

    @image.setter
    def image(self, img):
        self.img = img

    @property
    def types(self):
        """Mask marking which pixels have been painted (set to `con.FIX`)."""
        return self._types

    @types.setter
    def types(self, _types):
        self._types = _types

    @property
    def alphas(self):
        """Mask holding the brush value ('val') of every painted pixel."""
        return self._alphas

    @alphas.setter
    def alphas(self, _alphas):
        self._alphas = _alphas

    @property
    def flags(self):
        """Shared drawing-state dictionary."""
        return self.FLAGS

    @flags.setter
    def flags(self, flags):
        self.FLAGS = flags

    def _paint(self, x, y):
        """Stamp one filled circle at (x, y) into the image and both masks."""
        brush = self.FLAGS['value']
        center = (x, y)
        radius = self._STROKE_RADIUS
        # thickness -1 asks OpenCV for a filled circle
        cv2.circle(self.img, center, radius, brush['color'], -1)
        cv2.circle(self._alphas, center, radius, brush['val'], -1)
        cv2.circle(self._types, center, radius, con.FIX, -1)

    def handler(self, event, x, y, flags, param):
        """
        Mouse callback: draw strokes for refinement.

        A left-button press starts a stroke, mouse movement extends it and
        releasing the button ends it; each of those events paints one circle
        at (x, y).  Movement or release without an active stroke is ignored.
        `flags` and `param` belong to the callback signature and are unused.
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            self.FLAGS['DRAW_STROKE'] = True
            self._paint(x, y)

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_STROKE']:
                self._paint(x, y)

        elif event == cv2.EVENT_LBUTTONUP:
            if self.FLAGS['DRAW_STROKE']:
                self.FLAGS['DRAW_STROKE'] = False
                self._paint(x, y)

In [5]:
from numpy import savetxt
def run(filepath, filename, n_components=NUM_GMM_COMP, gamma=GAMMA, lamda=LAMDA,
        num_iters=NUM_ITERS, tol=TOL, connect_diag=True):
    """
    Interactive stroke-annotation loop for one image.

    Shows the image in an OpenCV window, lets the user paint foreground /
    background strokes with the mouse and, on Enter, saves the two resulting
    masks as images.  No segmentation is computed in this function.

    Key bindings
    ------------
    '0'   : subsequent strokes mark background
    '1'   : subsequent strokes mark foreground (initial mode)
    'r'   : reset the image and both masks
    Enter : save the masks and return

    Input
    -----
    filepath (str) : Path to the image to load
    filename (str) : File name under which the masks are saved, inside the
                     'FG' (alphas mask) and 'FGBG' (types mask) directories
                     of the current working directory
    n_components, gamma, lamda, num_iters, tol, connect_diag :
                     accepted for interface compatibility; not used by the
                     code in this function

    Raises
    ------
    OSError : if the image at `filepath` cannot be read
    """

    COLORS = {
        'BLACK' : [0, 0, 0],
        'RED'   : [0, 0, 255],
        'GREEN' : [0, 255, 0],
        'BLUE'  : [255, 0, 0],
        'WHITE' : [255, 255, 255]
    }

    DRAW_BG = {'color' : COLORS['BLACK'], 'val' : con.BG}
    DRAW_FG = {'color' : COLORS['WHITE'], 'val' : con.FG}

    FLAGS = {
        'DRAW_STROKE': False,         # flag for drawing strokes
        'value' : DRAW_FG,            # drawing strokes initialized to mark foreground
    }

    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError('Could not read image: {}'.format(filepath))
    img2 = img.copy()                                   # pristine copy used by reset
    types = np.zeros(img.shape[:2], dtype=np.uint8)     # whether a pixel is known or unknown
    alphas = np.zeros(img.shape[:2], dtype=np.uint8)    # binary mask: 0 - background pixels
                                                        #              1 - foreground pixels
    # Input window and mouse callback
    cv2.namedWindow('Input Image')

    EventObj = EventHandler(FLAGS, img, types, alphas, COLORS)
    cv2.setMouseCallback('Input Image', EventObj.handler)
    cv2.moveWindow('Input Image', img.shape[1] + 10, 90)

    while True:

        # Re-read the state every frame: a reset replaces these objects.
        img = EventObj.image
        types = EventObj.types
        alphas = EventObj.alphas
        FLAGS = EventObj.flags
        cv2.imshow('Input Image', img)

        # Mask to the low byte: some HighGUI backends return key codes with
        # extra high bits set, which would defeat the comparisons below.
        k = cv2.waitKey(1) & 0xFF

        # key bindings
        if k == ord('0'):
            # Strokes for background
            FLAGS['value'] = DRAW_BG

        elif k == ord('1'):
            # FG drawing
            FLAGS['value'] = DRAW_FG

        elif k == ord('r'):
            # Reset: drop all strokes and start again from the original image
            FLAGS['DRAW_STROKE'] = False
            FLAGS['value'] = DRAW_FG
            img = img2.copy()
            types = np.zeros(img.shape[:2], dtype=np.uint8)
            alphas = np.zeros(img.shape[:2], dtype=np.uint8)
            EventObj.image = img
            EventObj.types = types
            EventObj.alphas = alphas

        elif k == 13:
            # Carriage return: save the stroke masks and stop.
            # Masks hold 0/1, so scale to 0/255 to get visible images.
            curdir = os.getcwd()
            outputs = (('FG', alphas * 255), ('FGBG', types * 255))
            for dirname, mask in outputs:
                out_dir = os.path.join(curdir, dirname)
                os.makedirs(out_dir, exist_ok=True)   # do not assume the folder exists
                im.fromarray(mask).save(os.path.join(out_dir, filename))
            break

        EventObj.flags = FLAGS

In [6]:
if __name__ == '__main__':
    # Annotate every file found in ./test, one window per image.
    # The trailing '' keeps a path separator at the end of the directory path.
    path = os.path.join(os.getcwd(), 'test', '')
    print(path)
    # sorted(): os.listdir returns entries in arbitrary order
    for filename in sorted(os.listdir(path)):
        # Path to image file
        filepath = os.path.join(path, filename)
        if not os.path.isfile(filepath):
            # sub-directories cannot be loaded as images
            continue
        print(filename)
        try:
            run(filepath, filename)
        finally:
            # close the window even if run() raises
            cv2.destroyAllWindows()

C:\Users\JBSCHOLOR-2020-18\Desktop\My files\ComputerVisionProject\src\test\
banana1-resize.jpg
banana1.jpg
banana2.jpg
banana3.jpg
book.jpg
bool.jpg
bush.jpg
ceramic.jpg
cross.jpg
doll.jpg
elefant.jpg
flower.jpg
fullmoon.jpg
grave.jpg
stone1.jpg
stone2.jpg
teddy.jpg
tennis.jpg
