In [None]:
import cv2 
import numpy as np 
import torch
import torchvision.models as models
from torchvision import transforms as T

import os
from PIL import Image, ImageFilter 
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import scipy
import scipy.ndimage
from scipy import ndimage
from shutil import copyfile

In [None]:
def feature_mask(mask, i=None, dilate_size=50):
    """
    Build the keep-mask that restricts SIFT feature detection to the background.

    Because the edges of a segmentation mask may be inaccurate, the object mask is
    dilated with a ``dilate_size`` x ``dilate_size`` square kernel before being
    inverted, so the feature matcher only looks for correspondence points well
    outside the object(s).

    Args:
        mask: 2-D array; every non-zero pixel is treated as belonging to an object.
        i: frame index. Not used in the computation; kept so existing calls
            ``feature_mask(mask, i)`` keep working.
        dilate_size: side length (pixels) of the square dilation kernel. Increase
            it to exclude a wider margin around the objects, decrease it to keep
            more of the background. Defaults to 50.

    Returns:
        uint8 array with ``mask.shape`` containing only 0 and 255: 255 where
        features may be detected, 0 on the dilated objects. If ``mask`` has no
        non-zero pixel (e.g. the object is fully occluded), the whole array is 255.

    Raises:
        ValueError: if ``mask`` is not 2-D.
    """
    if mask.ndim != 2:
        raise ValueError(f"feature_mask expects a 2-D mask, got shape {mask.shape}")
    # Same notion of "object pixel" as np.nonzero: any value != 0.
    obj_pixels = mask != 0
    if not obj_pixels.any():
        # occlusion: nothing to exclude, detect everywhere
        print('occlusion!')
        return np.full(mask.shape, 255, np.uint8)
    kernel = np.ones((dilate_size, dilate_size), np.uint8)
    # Binarise before dilating. Without this, object pixels whose value is below
    # 255 (JPEG noise, anti-aliased edges, uint8 wrap-around of summed masks such
    # as 255 + 255 -> 254) end up as small *non-zero* values after inversion, and
    # SIFT treats any non-zero mask pixel as "detect here".
    dilated = cv2.dilate(obj_pixels.astype(np.uint8), kernel, iterations=1)
    return np.where(dilated > 0, 0, 255).astype(np.uint8)

In [None]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Layout read below:
#   <DATASET_ROOT>/<dataset>/rgb/<frame files>
#   <DATASET_ROOT>/<dataset>/mask/<object id, e.g. "01">/<one mask file per frame>
# Output: <DATASET_ROOT>/<dataset>/homographies_raw.txt with one line per rgb
# frame holding the 9 row-major entries of the homography that maps frame 0
# into that frame.
DATASET_ROOT = "../datasets/"
dataset = "composited/cloth/cloth_grail_5152"
N_BEST_MATCHES = 50            # lowest-ratio matches handed to RANSAC per frame pair
RANSAC_REPROJ_THRESHOLD = 5.0  # max reprojection error (pixels) for a RANSAC inlier
SHOW_DEBUG_PLOTS = True        # set False to skip the per-frame figures

dataset_dir = os.path.join(DATASET_ROOT, dataset)
rgb_dir = os.path.join(dataset_dir, "rgb")
mask_dir = os.path.join(dataset_dir, "mask")

# List every folder once rather than once per frame. Frame i is paired with the
# i-th (sorted) mask file of every object folder; stray non-folder entries in
# mask/ are ignored.
rgb_files = sorted(os.listdir(rgb_dir))
mask_files = {
    obj_id: sorted(os.listdir(os.path.join(mask_dir, obj_id)))
    for obj_id in sorted(os.listdir(mask_dir))
    if os.path.isdir(os.path.join(mask_dir, obj_id))
}
for obj_id, files in mask_files.items():
    if len(files) < len(rgb_files):
        raise ValueError(
            f"mask folder {obj_id!r} has {len(files)} files but there are "
            f"{len(rgb_files)} rgb frames")


def _read_bgr(path):
    """Load an image with cv2.imread, raising instead of silently returning None."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read {path!r}")
    return img


def _combined_mask_gray(frame_idx, shape):
    """Union of every object's mask for frame ``frame_idx`` as a 1-channel uint8 image.

    Masks are merged with a pixel-wise maximum: adding uint8 images (``+=``)
    wraps around where objects overlap (255 + 255 -> 254), so overlapping pixels
    would not stay at full intensity. ``shape`` (the frame's (H, W, 3)) is the
    starting canvas, which stays all-zero when there are no object folders.
    """
    union = np.zeros(shape, np.uint8)
    for obj_id, files in mask_files.items():
        mask_path = os.path.join(mask_dir, obj_id, files[frame_idx])
        union = np.maximum(union, _read_bgr(mask_path))
    return cv2.cvtColor(union, cv2.COLOR_BGR2GRAY)


def _detect(detector, gray, keep_mask, frame_idx):
    """SIFT keypoints/descriptors of ``gray``, restricted to non-zero ``keep_mask`` pixels."""
    kp, des = detector.detectAndCompute(gray, keep_mask)
    if des is None:  # detectAndCompute yields None descriptors when nothing is found
        raise RuntimeError(f"frame {frame_idx}: no SIFT keypoints outside the masks")
    return kp, des


def _ranked_matches(matcher, des_prev, des_cur, n_best):
    """Return up to ``n_best`` prev->cur matches, best (lowest distance ratio) first.

    Each query descriptor is scored by nearest / second-nearest distance.
    Queries with fewer than two neighbours are skipped, and a zero
    second-nearest distance is scored 1.0 (maximally ambiguous) instead of
    dividing by zero. The sort is stable, so ties keep knnMatch's order.
    """
    scored = []
    for pair in matcher.knnMatch(des_prev, des_cur, k=2):
        if len(pair) < 2:
            continue
        best, second = pair
        ratio = best.distance / second.distance if second.distance > 0 else 1.0
        scored.append((ratio, best))
    scored.sort(key=lambda item: item[0])
    return [m for _, m in scored[:n_best]]


def _show(img, title, bgr=False):
    """Display one debug image; BGR images are converted so matplotlib shows true colours."""
    if bgr:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(img, cmap="gray")  # cmap only affects single-channel images
    plt.title(title)
    plt.show()


# ---------------------------------------------------------------------------
# Reference frame
# ---------------------------------------------------------------------------
img_f = rgb_files[0]
print(img_f)
img1 = _read_bgr(os.path.join(rgb_dir, img_f))
old_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
h, w, _ = img1.shape
img1_mask = feature_mask(_combined_mask_gray(0, img1.shape), 0)
if SHOW_DEBUG_PLOTS:
    _show(img1_mask, "frame 0: feature mask")

sift = cv2.SIFT_create()
# In OpenCV's FLANN enum, algorithm 0 is FLANN_INDEX_LINEAR (exact brute-force
# search) and 1 is FLANN_INDEX_KDTREE. Algorithm 0 is used here, so matching is
# exact; use algorithm=1 with trees=5 for faster, approximate matching.
FLANN_INDEX_LINEAR = 0
index_params = dict(algorithm=FLANN_INDEX_LINEAR)
search_params = dict(checks=50)  # not used by exact linear search
flann = cv2.FlannBasedMatcher(index_params, search_params)

# Frame 0 is described once; afterwards each frame reuses the previous frame's
# keypoints/descriptors instead of running SIFT on it a second time.
kp_prev, des_prev = _detect(sift, old_gray, img1_mask, 0)

# ---------------------------------------------------------------------------
# Chain frame-to-frame homographies
# ---------------------------------------------------------------------------
# findHomography(prev_pts, cur_pts) maps the previous frame into the current one,
# so left-multiplying keeps start_matrix mapping frame 0 into the current frame.
# Frame 0 is matched against itself, so its line is (close to) the identity.
start_matrix = np.identity(3)
with open(os.path.join(dataset_dir, "homographies_raw.txt"), "w") as f:
    for i, img_f in enumerate(rgb_files):
        frame = _read_bgr(os.path.join(rgb_dir, img_f))
        frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # If there are multiple objects their masks are merged, so features are
        # only taken from pixels outside all of them.
        frame_mask = feature_mask(_combined_mask_gray(i, frame.shape), i)
        kp_cur, des_cur = _detect(sift, frame_gray, frame_mask, i)
        if SHOW_DEBUG_PLOTS:
            _show(frame_mask, f"frame {i}: feature mask")
            _show(cv2.drawKeypoints(old_gray, kp_prev, None),
                  f"frame {max(i - 1, 0)}: keypoints", bgr=True)
            _show(cv2.drawKeypoints(frame_gray, kp_cur, None),
                  f"frame {i}: keypoints", bgr=True)

        # Correspondences between the two frames, most confident first.
        good = _ranked_matches(flann, des_prev, des_cur, N_BEST_MATCHES)
        query_pts = np.float32([kp_prev[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        train_pts = np.float32([kp_cur[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        print('len(query_pts)', len(query_pts))
        if len(good) < 4:  # a homography needs at least 4 correspondences
            raise RuntimeError(f"frame {i}: only {len(good)} usable matches, need >= 4")

        matrix, matrix_mask = cv2.findHomography(
            query_pts, train_pts, cv2.RANSAC, RANSAC_REPROJ_THRESHOLD)
        if matrix is None:
            raise RuntimeError(f"frame {i}: cv2.findHomography found no homography")
        inliers = matrix_mask.sum()
        print(i, inliers, matrix)
        start_matrix = matrix @ start_matrix
        f.write(' '.join(str(v) for v in start_matrix.flatten()) + '\n')

        if SHOW_DEBUG_PLOTS:
            # If the estimates are right, both warps line up with the current frame.
            _show(cv2.warpPerspective(img1, start_matrix, (w, h), flags=cv2.INTER_LINEAR),
                  f"frame 0 warped into frame {i}", bgr=True)
            _show(cv2.warpPerspective(old_gray, matrix, (w, h), flags=cv2.INTER_LINEAR),
                  f"previous frame warped into frame {i}")
            _show(frame_gray, f"frame {i}")

        # The current frame becomes the "previous" frame of the next iteration.
        old_gray, img1_mask = frame_gray, frame_mask
        kp_prev, des_prev = kp_cur, des_cur