In [None]:
# Colab only: condacolab.install() installs conda and then RESTARTS the kernel, so nothing
# placed after it in this cell would ever run. faiss is installed in the next cell instead.
%pip install -q condacolab
import condacolab
condacolab.install()

⏬ Downloading https://github.com/jaimergp/miniforge/releases/latest/download/Mambaforge-colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:29
🔁 Restarting kernel...


In [None]:
# -y: answer conda's "Proceed ([y]/n)?" prompt up front so the cell cannot block on input.
# Pinned to the faiss version this notebook was last run with.
!conda install -y -c pytorch faiss-gpu=1.7.1

Collecting package metadata (current_repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - done
Solving environment: | / - \ | / - \ | / - \ | / - \ | / - \ | / done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - faiss-gpu


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2021.10.8  |       ha878542_0         139 KB  conda-forge
    certifi-2021.10.8          |   py37h89c1867_1         145 KB  conda-forge
    conda-4.10.3               |   py37h89c1867_3         3.1 MB  conda-forge
    cudatoolkit-11.1.1         |       h6406543_8        1.20 GB  conda-forge
    faiss-1.7.1             

In [None]:
# Skip the clone when the checkout already exists, so the cell can be re-run.
# NOTE(review): git clones this repo into ./info-flow-matting, but the information-flow cell
# imports from ./info_flow (and the recorded output named yet another directory) — confirm
# the expected directory layout.
!test -d info-flow-matting || git clone https://github.com/jinczing/info-flow-matting.git

Cloning into 'AlphaMatting-Information-Flow'...
remote: Enumerating objects: 11, done.[K
remote: Total 11 (delta 0), reused 0 (delta 0), pack-reused 11[K
Unpacking objects: 100% (11/11), done.


In [1]:
import os
import sys
import time
import warnings

import cv2
import faiss
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sparse
import sklearn.neighbors as neighbors
import torch
import torchvision
from torchvision import transforms

In [2]:
def _knn_indices(data, n_neighbors, use_faiss, gpu_res=None):
    """Return, for every row of `data`, the indices of its `n_neighbors` nearest rows of `data`.

    With `use_faiss` this is an exact L2 search on GPU 0 (faiss IndexFlatL2, using the
    StandardGpuResources `gpu_res`); otherwise scikit-learn's NearestNeighbors with the L1
    (p=1) metric. NOTE(review): the two backends therefore use different metrics.
    """
    if use_faiss:
        points = np.ascontiguousarray(data, dtype='float32')  # faiss needs C-contiguous float32
        index = faiss.index_cpu_to_gpu(gpu_res, 0, faiss.IndexFlatL2(points.shape[1]))
        index.add(points)
        _, indices = index.search(points, n_neighbors)
        return indices
    model = neighbors.NearestNeighbors(n_neighbors=n_neighbors, p=1).fit(data)
    return model.kneighbors(data)[1]


def _vgg_features(image):
    """Per-pixel 64-channel features from the first three layers of torchvision's VGG16 features.

    `image` is an (h, w, 3) float array already divided by 255; returns an (h*w, 64) array.
    """
    # NOTE(review): the ImageNet mean/std are in RGB order, while images read with cv2.imread
    # are BGR; the caller's channel order is passed through unchanged — confirm this is intended.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torchvision.models.vgg16(pretrained=True).eval().to(device)
    batch = preprocess(image).unsqueeze(0).float().to(device)
    with torch.no_grad():  # inference only: no autograd graph needed
        feats = model.features[:3](batch)
    # [0] drops only the batch axis (a bare squeeze() would also drop h or w when they are 1).
    return feats.cpu().numpy()[0].transpose((1, 2, 0)).reshape(-1, 64)


def _solve_alpha(H, rhs, use_bicg):
    """Solve the sparse system ``H @ x = rhs`` and return x.

    Uses BiCG directly when `use_bicg` is set. Otherwise it first tries a direct sparse solve
    and falls back to BiCG when scipy reports the matrix as singular or the result is not finite.
    """
    # Import the submodule explicitly instead of relying on `sparse.linalg` being loaded elsewhere.
    from scipy.sparse.linalg import MatrixRankWarning, bicg, spsolve

    def run_bicg():
        solution, info = bicg(H, rhs, atol=1e-10, maxiter=6000)
        if info != 0:
            # info > 0: iteration limit reached; info < 0: illegal input or breakdown.
            print('bicg did not converge (info = %d)' % info)
        return solution

    if use_bicg:
        return run_bicg()

    # Escalate only scipy's singular-matrix warning, and only inside this block, so the
    # process-wide warning filters are left exactly as they were.
    with warnings.catch_warnings():
        warnings.simplefilter('error', MatrixRankWarning)
        try:
            solution = spsolve(H, rhs)
        except MatrixRankWarning:
            solution = None
    if solution is None or not np.all(np.isfinite(solution)):
        print('using bicg to solve instead')
        solution = run_bicg()
    return solution


def knn_matting(image, trimap, my_lambda=100, use_faiss=False, use_weak_spatial=False,
                use_vgg_feature=False, use_bicg=False):
    """KNN matting: estimate a per-pixel alpha matte from an image and a trimap.

    Each pixel is described by its colour (scaled to [0, 1]) plus its (x, y) position divided by
    the image diagonal. Pixels are linked to their nearest neighbours in that feature space with
    affinity ``1 - ||X_i - X_j||_1 / C``. Alpha then solves the sparse system
    ``(L + lambda * D) alpha = lambda * fg``, where L is the graph Laplacian of the affinities,
    D marks trimap-constrained pixels and fg marks known-foreground pixels.

    Args:
        image: (h, w, c) array with values in [0, 255].
        trimap: (h, w) array; 255 = foreground, 0 = background, anything else = unknown.
        my_lambda: weight of the trimap constraints (lambda above).
        use_faiss: run the neighbour searches with faiss on GPU 0 (L2 metric) instead of
            scikit-learn (L1 metric).
        use_weak_spatial: add a second search (5 neighbours) after shrinking the spatial
            coordinates 100x, so spatial distance counts for less.
        use_vgg_feature: only takes effect together with use_weak_spatial; adds a third search
            (5 neighbours) on VGG16 features.
        use_bicg: solve with BiCG directly instead of trying a direct sparse solve first.

    Returns:
        (h, w) float array of alpha values. It is not clipped, so values may fall slightly
        outside [0, 1].
    """
    h, w, c = image.shape
    image, trimap = image / 255.0, trimap / 255.0
    foreground = (trimap == 1.0).astype(int)
    background = (trimap == 0.0).astype(int)
    all_constraints = foreground + background

    k, k2, k3 = 11, 5, 5   # neighbours per search: base, weak-spatial, VGG
    lamb = my_lambda
    C = c + 2              # feature width: colour channels + (x, y)

    # Per-pixel features. The tiny coordinate jitter is presumably there to break exact
    # distance ties between equally coloured pixels.
    diag = np.sqrt(h * h + w * w)
    u, v = np.meshgrid(np.arange(w) / diag, np.arange(h) / diag)
    X = np.dstack([image,
                   u + np.random.rand(h, w) * 1e-6,
                   v + np.random.rand(h, w) * 1e-6]).reshape(-1, C).astype('float32')
    print(X.shape)

    gpu_res = faiss.StandardGpuResources() if use_faiss else None  # single GPU

    timer = time.time()
    knn_groups = [_knn_indices(X, k, use_faiss, gpu_res)]
    print('first knn time', time.time() - timer)

    if use_weak_spatial:
        timer = time.time()
        # In place: the affinities below are therefore also computed on the shrunk coordinates.
        X[:, -2:] = X[:, -2:] / 100
        knn_groups.append(_knn_indices(X, k2, use_faiss, gpu_res))
        print('second knn time', time.time() - timer)

        # VGG neighbours are only used together with the weak-spatial ones, so the network is
        # only run in that case.
        if use_vgg_feature:
            feats = _vgg_features(image)
            timer = time.time()
            knn_groups.append(_knn_indices(feats, k3, use_faiss, gpu_res))
            print('third knn time', time.time() - timer)

    knns = np.concatenate(knn_groups, axis=1).astype(int)
    print(knns.shape)

    # Sparse affinity matrix with one entry per (pixel, neighbour) pair; COO sums duplicates
    # (pairs found by more than one search).
    n_pix = h * w
    rows = np.repeat(np.arange(n_pix), knns.shape[1])
    cols = knns.ravel()
    weights = 1 - np.abs(X[rows] - X[cols]).sum(axis=1) / C
    A = sparse.coo_matrix((weights, (rows, cols)), shape=(n_pix, n_pix))

    # H = L + lambda * D with Laplacian L = diag(row sums of A) - A.
    L = sparse.diags(np.ravel(A.sum(axis=1))) - A
    D = sparse.diags(all_constraints.ravel())
    H = (L + lamb * D).tocsr()
    rhs = lamb * foreground.ravel().astype(float)

    print('start solving')
    timer = time.time()
    alpha = _solve_alpha(H, rhs, use_bicg)
    print('solving time:', time.time() - timer)
    alpha = alpha.reshape(h, w)
    print(alpha.min(), alpha.max())
    return alpha

In [9]:
# run knn matting
image = cv2.imread('./image/bear.png')
trimap = cv2.imread('./trimap/bear.png', cv2.IMREAD_GRAYSCALE)
background = cv2.imread('./image/forest.jpg')
# cv2.imread returns None (it does not raise) when a file is missing or unreadable.
if image is None or trimap is None or background is None:
    raise FileNotFoundError('could not read ./image/bear.png, ./trimap/bear.png or ./image/forest.jpg')
background = cv2.resize(background, (image.shape[1], image.shape[0]))  # dsize is (width, height)

alpha = knn_matting(image, trimap, use_faiss=True, use_weak_spatial=True, use_vgg_feature=False, use_bicg=True)
alpha = alpha[:, :, np.newaxis].clip(0, 1)

os.makedirs('./result', exist_ok=True)  # cv2.imwrite just returns False if the directory is missing
cv2.imwrite('./result/bear_alpha_faiss.png', np.round(alpha * 255).astype(np.uint8))

# Composite the original foreground over the new background using the estimated matte.
result = (1 - alpha) * background + alpha * image
cv2.imwrite('./result/bear_faiss.png', np.clip(np.round(result), 0, 255).astype(np.uint8))

(143640, 5)
first knn time 0.6080746650695801
second knn time 0.4530041217803955
(143640, 16)
start solving
solving time: 0.6991984844207764
-3.116766201295849e-06 1.001473862033135


True

In [10]:
# run information flow matting
# The info_flow directory is added to sys.path (only once, so re-running the cell does not keep
# growing it), presumably so modules inside it can import one another by bare name — confirm.
INFO_FLOW_DIR = './info_flow/'
if INFO_FLOW_DIR not in sys.path:
    sys.path.append(INFO_FLOW_DIR)
from info_flow.alpha import info_flow

image = cv2.imread('./image/bear.png')
trimap = cv2.imread('./trimap/bear.png', cv2.IMREAD_GRAYSCALE)
background = cv2.imread('./image/forest.jpg')
# cv2.imread returns None (it does not raise) when a file is missing or unreadable.
if image is None or trimap is None or background is None:
    raise FileNotFoundError('could not read ./image/bear.png, ./trimap/bear.png or ./image/forest.jpg')
background = cv2.resize(background, (image.shape[1], image.shape[0]))  # dsize is (width, height)

timer = time.time()
# Silence DeprecationWarnings only for this call rather than for the rest of the session.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    alpha = info_flow(image, trimap)
print('info flow time', time.time() - timer)
alpha = alpha[:, :, np.newaxis].clip(0, 1)

os.makedirs('./result', exist_ok=True)  # cv2.imwrite just returns False if the directory is missing
cv2.imwrite('./result/bear_alpha_info.png', np.round(alpha * 255).astype(np.uint8))

# Composite the original foreground over the new background using the estimated matte.
result = (1 - alpha) * background + alpha * image
cv2.imwrite('./result/bear_info.png', np.clip(np.round(result), 0, 255).astype(np.uint8))

134155
(9485, 14, 1)
(9485, 14, 3)
(9485, 5, 5)
1.3548182723431954 -0.20146884252984487
info flow time 26.96239733695984


True