> In this project, binary label images are called *masks* and soft labels (values in the range 0–1) are called *alpha* mattes.

In [1]:
import numpy as np
import sklearn.neighbors
import scipy.sparse
import warnings

import matplotlib.pyplot as plt
import scipy.misc

import cv2
import os

In [2]:
import cv2 as cv
from time import time
from PIL import Image

import torch
import torch.nn as nn


In [3]:
import cv2 as cv
from time import time
from PIL import Image

import torch
import torch.nn as nn
from indexnet.hlmobilenetv2 import hlmobilenetv2

# ignore warnings
import warnings
warnings.filterwarnings("ignore")

IMG_SCALE = 1./255
# Per-channel normalisation of the 4-channel network input (colour + trimap).
# The first three entries match the usual ImageNet RGB mean/std; the trimap
# channel uses mean 0 / std 1, so it is only scaled by IMG_SCALE.
IMG_MEAN = np.array([0.485, 0.456, 0.406, 0]).reshape((1, 1, 4))
IMG_STD = np.array([0.229, 0.224, 0.225, 1]).reshape((1, 1, 4))

# Output stride of the encoder; inputs are resized to a multiple of it.
STRIDE = 32
RESTORE_FROM = './pretrained/indexnet_matting.pth.tar'
RESULT_DIR = './examples/mattes'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(RESULT_DIR, exist_ok=True)

# load pretrained model
net = hlmobilenetv2(
        pretrained=False,
        freeze_bn=True,
        output_stride=STRIDE,
        apply_aspp=True,
        conv_operator='std_conv',
        decoder='indexnet',
        decoder_kernel_size=5,
        indexnet='depthwise',
        index_mode='m2o',
        use_nonlinear=True,
        use_context=True
    )
# Wrapped before loading: the checkpoint keys presumably carry DataParallel's
# "module." prefix — TODO confirm against the checkpoint.
net = nn.DataParallel(net)
try:
    checkpoint = torch.load(RESTORE_FROM, map_location=torch.device('cpu'))
    pretrained_dict = checkpoint['state_dict']
except Exception as err:
    # Missing/unreadable file or unexpected checkpoint layout. Point the user
    # at the download step, but keep the real cause in the traceback and do
    # not intercept KeyboardInterrupt/SystemExit as the bare except did.
    raise Exception('Please download the pretrained model!') from err
net.load_state_dict(pretrained_dict)
net.to(device)

# switch to eval mode
net.eval()

DataParallel(
  (module): hlMobileNetV2UNetDecoderIndexLearning(
    (layer0): Sequential(
      (0): Conv2d(4, 32, kernel_size=(3, 3), stride=1, padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (layer1): Sequential(
      (0): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
          (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (layer2): Sequential(
      (0): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05,

In [18]:
def load_img_mask_pair(img_path, size=(1024, 1024)):
    """Load an image and its binary mask, both resized to ``size``.

    The mask path is derived from the image path. The file name is cut at its
    first ``".p"`` (e.g. the ``.png`` extension) and ``_matte.png`` is
    appended. For Supervisely data the mask is looked up in
    ``masks_machine/`` instead of ``img/``, and any ``.jpeg`` left in the
    name is removed.

    Args:
        img_path: Path to the colour image.
        size: ``(width, height)`` passed to ``cv2.resize``. The default is the
            previously hard-coded 1024x1024.

    Returns:
        ``(img, mask)``:
        - ``img`` is the ``uint8`` image as loaded by
          ``cv2.imread(..., IMREAD_COLOR)`` (BGR channel order).
        - ``mask`` is a single-channel ``uint8`` array whose stored values
          were multiplied by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image or the mask.
    """
    img_name = os.path.basename(img_path)
    # Only strip the extension from the file name itself, so a ".p" inside a
    # directory name cannot truncate the path. The prefix is kept verbatim
    # to preserve the original separators for the "/img/" replacement below.
    prefix = img_path[:len(img_path) - len(img_name)]
    mask_path = prefix + img_name.split(".p")[0] + "_matte.png"

    if 'Supervisely' in mask_path:
        mask_path = mask_path.replace("/img/", "/masks_machine/")
        mask_path = mask_path.replace(".jpeg", "")

    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    if img is None:
        raise FileNotFoundError("cannot read image: {}".format(img_path))
    if mask is None:
        raise FileNotFoundError("cannot read mask: {}".format(mask_path))

    # NOTE(review): assumes the stored mask is 0/1. A mask that is already
    # 0/255 would wrap around in uint8 here — confirm for non-Supervisely data.
    mask *= 255

    img = cv2.resize(img, size)
    # Nearest-neighbour keeps the mask strictly binary. The default bilinear
    # interpolation creates intermediate values along edges, which
    # make_trimap rejects.
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

    return img, mask

In [5]:
def make_trimap(mask, size=(10, 10)):
    """Build a trimap (0 / 128 / 255) from a binary 0/255 mask.

    The mask is dilated and eroded with an elliptical structuring element of
    ``size``. Pixels inside the dilated region but outside the eroded core
    become the unknown band (128). All other pixels keep the mask value
    (0 or 255).

    Args:
        mask: 2-D array containing only the values 0 and 255.
        size: ``(width, height)`` of the elliptical kernel. A larger kernel
            produces a wider unknown band.

    Returns:
        float64 array with the same shape as ``mask``, with values in
        {0, 128, 255}.

    Raises:
        ValueError: if ``mask`` contains values other than 0 and 255.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size)
    mask = mask / 255.

    # Validate once, explicitly. The previous asserts were stripped under
    # ``python -O``. Dilating/eroding a binary image keeps it binary, so the
    # former extra checks on the dilated/eroded images are implied by this one.
    if not np.isin(mask, (0., 1.)).all():
        raise ValueError("make_trimap expects a binary mask with values 0 and 255")

    dilated = cv2.dilate(mask, kernel, iterations=1) * 255
    eroded = cv2.erode(mask, kernel, iterations=1) * 255

    trimap = dilated.copy()
    # Inside the dilated region but outside the eroded core -> unknown.
    trimap[(dilated == 255) & (eroded == 0)] = 128

    return trimap

In [6]:
# from INdexNet
def read_image(x):
    """Read the image file at path ``x`` with PIL and return it as a NumPy array.

    Channels come in whatever order PIL decodes them (e.g. RGB for colour
    JPEGs), unlike ``cv2.imread``, which returns BGR.
    """
    # Context manager closes the file handle. ``np.array`` forces PIL to load
    # the pixel data before the file is closed.
    with Image.open(x) as img:
        return np.array(img)

def save_alpha(target_name, alpha):
    """Write an alpha matte to ``target_name`` via ``cv2.imwrite``.

    ``alpha`` is multiplied by 255 before writing. It is presumably a soft
    label in [0, 1] — confirm at call sites.

    Returns:
        1 on success. The value is kept for backward compatibility.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (e.g. a missing directory
            or unsupported extension). The original silently ignored this.
    """
    if not cv2.imwrite(target_name, alpha*255):
        raise OSError("cv2.imwrite failed for {}".format(target_name))

    return 1

In [7]:
def image_alignment(x, output_stride, odd=False):
    """Resize a 4-channel H x W x 4 array so H and W fit the network stride.

    Each spatial size is rounded up to the next multiple of ``output_stride``,
    plus one when ``odd`` is true.

    Args:
        x: Array with at least 4 channels along axis 2. Channels 0-2 are
            resized with bicubic interpolation. Channel 3 (the trimap in
            this notebook) is resized with nearest neighbour so its discrete
            levels are not blended.
        output_stride: Required divisor of the output height and width.
        odd: If true, produce sizes of ``k * output_stride + 1``.

    Returns:
        The resized array of shape ``(h, w, 4)``.
    """
    # ``np.float`` was removed in NumPy 1.24; the builtin float is the same dtype.
    imsize = np.asarray(x.shape[:2], dtype=float)
    new_imsize = np.ceil(imsize / output_stride) * output_stride
    if odd:
        new_imsize += 1
    h, w = int(new_imsize[0]), int(new_imsize[1])

    x1 = x[:, :, 0:3]
    x2 = x[:, :, 3]
    # Note: cv.resize takes dsize as (width, height).
    new_x1 = cv.resize(x1, dsize=(w, h), interpolation=cv.INTER_CUBIC)
    new_x2 = cv.resize(x2, dsize=(w, h), interpolation=cv.INTER_NEAREST)

    # cv.resize drops the singleton channel axis of a 2-D input; restore it.
    new_x2 = np.expand_dims(new_x2, axis=2)
    new_x = np.concatenate((new_x1, new_x2), axis=2)

    return new_x

# Actual Logic

In [8]:
def inference(image_path):
    """Predict an alpha matte for one dataset image and save it.

    Steps:
    1. Load the image and its mask with ``load_img_mask_pair``.
    2. Derive a trimap from the mask.
    3. Normalise image + trimap and run them through the global ``net``.
    4. Write the result to ``image_path`` with ``/img/`` replaced by
       ``/alpha/``. That directory must already exist.

    Args:
        image_path: Path to an image inside an ``img`` directory.

    Returns:
        The alpha matte as a ``uint8`` array at the resolution returned by
        ``load_img_mask_pair``. Known pixels of the trimap (0 / 255) are
        copied through; only the unknown band (128) comes from the network.
    """
    alpha_path = image_path.replace("/img/", "/alpha/")
    # Use the argument: the original read the notebook-global ``img_path``
    # here, so it only worked by accident inside the processing loop.
    image, mask = load_img_mask_pair(image_path)
    # cv2 loads BGR, but IMG_MEAN/IMG_STD are RGB statistics and the original
    # IndexNet pipeline reads images with PIL (RGB).
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    trimap = make_trimap(mask, size=(20, 20))

    with torch.no_grad():
        # Network input is 4 channels: RGB + trimap.
        trimap = np.expand_dims(trimap, axis=2)
        image = np.concatenate((image, trimap), axis=2)

        h, w = image.shape[:2]

        image = image.astype('float32')
        image = (IMG_SCALE * image - IMG_MEAN) / IMG_STD
        image = image.astype('float32')

        # Resize so H and W are multiples of the network output stride.
        image = image_alignment(image, STRIDE)
        inputs = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), axis=0))
        inputs = inputs.to(device)

        # inference
        start = time()
        outputs = net(inputs)
        end = time()

        outputs = outputs.squeeze().cpu().numpy()
        # Back to the pre-alignment resolution, then map [0, 1] -> [0, 255].
        alpha = cv.resize(outputs, dsize=(w, h), interpolation=cv.INTER_CUBIC)
        alpha = np.clip(alpha, 0, 1) * 255.
        trimap = trimap.squeeze()
        unknown = np.equal(trimap, 128).astype(np.float32)
        alpha = (1 - unknown) * trimap + unknown * alpha

        Image.fromarray(alpha.astype(np.uint8)).save(alpha_path)

        # batch_size = 1; guard against a zero-length interval on coarse clocks.
        elapsed = end - start
        running_frame_rate = 1.0 / elapsed if elapsed > 0 else float('inf')
        print('framerate: {0:.2f}Hz'.format(running_frame_rate))
        return alpha.astype(np.uint8)

In [9]:
# Root of the Supervisely person dataset. Its sub-folders (e.g. ds9) each hold
# an img/ directory of input images; predicted mattes go to a sibling alpha/.
DATASET_BASE = "./dataset/Supervisely_person_dataset"

In [10]:
# Sub-dataset folder names: every entry of DATASET_BASE whose name contains
# no "." (names with an extension, i.e. plain files, are skipped).
dslist = list(filter(lambda name: "." not in name, os.listdir(DATASET_BASE)))

In [11]:
# Sub-datasets from index 5 onward. os.listdir order is arbitrary, so which
# folders land here can differ between machines/filesystems.
dslist[5:]

['ds9', 'ds7', 'ds2', 'ds5', 'ds13', 'ds12', 'ds4', 'ds3']

In [16]:
# test
# Test run over a single sub-dataset (index 5 of dslist).
for ds in dslist[5:6]:

    img_dir = os.path.join(DATASET_BASE, ds, "img")
    alpha_dir = os.path.join(DATASET_BASE, ds, "alpha")

    # inference() saves each matte under .../alpha/; create the folder once.
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(alpha_dir, exist_ok=True)

    img_list = [os.path.join(img_dir, i) for i in os.listdir(img_dir)]

#     for img_path in img_list:


#         alpha = inference(img_path)
#         print(img_path, "is done")

# #         plt.imshow(alpha, cmap='gray')
# #         plt.show()
# #         break
# #     break


In [17]:
# Display the image paths collected for the tested sub-dataset.
img_list

['./dataset/Supervisely_person_dataset/ds9/img/pexels-photo-878783.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-873417.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-269920.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-878782.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-867852.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-867846.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-877699.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-881669.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-866019.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-838574.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-878346.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-114794.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexels-photo-865713.png',
 './dataset/Supervisely_person_dataset/ds9/img/pexe

## legacy

In [None]:
# Legacy: too slow for full-size images.
# Borrowed from 'https://github.com/MarcoForte/knn-matting'
def knn_matte(img, trimap, mylambda=100, n_neighbors=10):
    """Estimate an alpha matte with KNN matting.

    Each pixel gets a feature vector: its colour channels followed by its
    (row, col) position scaled by the image diagonal. Pixels are linked to
    their ``n_neighbors`` nearest neighbours in that feature space to build
    a graph Laplacian ``L``. The method then solves
    ``2 (L + mylambda * D) alpha = 2 * mylambda * v``, where:
    - ``D`` marks the trimap-constrained pixels;
    - ``v`` marks the known-foreground pixels.

    Args:
        img: H x W x C image with values in 0..255.
        trimap: H x W trimap with values in 0..255. Values > 0.99*255 are
            foreground and values < 0.01*255 are background. A multi-channel
            H x W x K trimap is also accepted; only channel 0 is used.
        mylambda: Weight of the trimap constraints.
        n_neighbors: Neighbours per pixel in the affinity graph.

    Returns:
        H x W float array of alpha values clipped to [0, 1].
    """
    # ``import scipy.sparse`` alone does not guarantee that the linalg
    # submodule is loaded.
    import scipy.sparse.linalg

    m, n, channels = img.shape

    img, trimap = img / 255.0, trimap / 255.0
    if trimap.ndim == 3:
        # Previously a 3-D trimap crashed with a shape mismatch below.
        trimap = trimap[:, :, 0]

    foreground = (trimap > 0.99).astype(int)
    background = (trimap < 0.01).astype(int)
    all_constraints = foreground + background

    print('Finding nearest neighbors')

    a, b = np.unravel_index(np.arange(m*n), (m, n))
    # One row per pixel: [colour channels..., row, col] with coordinates
    # normalised by the image diagonal.
    feature_vec = np.append(np.transpose(img.reshape(m*n, channels)),
                            [a, b]/np.sqrt(m*m + n*n), axis=0).T

    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(feature_vec)
    knns = nbrs.kneighbors(feature_vec)[1]

    # Compute Sparse A
    print('Computing sparse A')
    row_inds = np.repeat(np.arange(m*n), n_neighbors)
    col_inds = knns.reshape(m*n*n_neighbors)

    # Affinity = 1 - feature distance normalised by the feature dimension.
    vals = 1 - np.linalg.norm(feature_vec[row_inds] - feature_vec[col_inds], axis=1)/(channels+2)
    A = scipy.sparse.coo_matrix((vals, (row_inds, col_inds)), shape=(m*n, m*n))

    D_script = scipy.sparse.diags(np.ravel(A.sum(axis=1)))

    L = D_script - A
    D = scipy.sparse.diags(np.ravel(all_constraints))
    v = np.ravel(foreground)
    # Renamed from ``c``, which previously shadowed the channel count.
    rhs = 2*mylambda*np.transpose(v)
    H = 2*(L + mylambda*D)

    print('Solving linear system for alpha')
    try:
        # Escalate warnings (e.g. a singular-matrix warning from spsolve) to
        # errors only for this call. The original filterwarnings('error')
        # changed the filter for the whole process and never restored it.
        with warnings.catch_warnings():
            warnings.simplefilter('error')
            solution = scipy.sparse.linalg.spsolve(H, rhs)
    except Warning:
        # Least-squares fallback when the direct solve is ill-conditioned.
        solution = scipy.sparse.linalg.lsqr(H, rhs)[0]
    return np.minimum(np.maximum(solution, 0), 1).reshape(m, n)