In [17]:
import os
import glob
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm

## Binary mask

In [10]:
data_dir = "/mnt/localssd/VideoMatte240K"

In [15]:
split = "test"
threshold = 0.5
alpha_dir = os.path.join(data_dir, split, "pha")
out_dir = os.path.join(data_dir, split, "bin")
for alpha_path in tqdm(glob.glob(alpha_dir + "/*/*")):
    alpha = np.array(Image.open(alpha_path).convert("L"))
    binmask = (alpha > threshold * 255) * 255
    out_path = alpha_path.replace(alpha_dir, out_dir)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    cv2.imwrite(out_path, binmask)

100%|██████████| 2720/2720 [03:19<00:00, 13.62it/s]


In [16]:
split = "valid"
threshold = 0.5
alpha_dir = os.path.join(data_dir, split, "pha")
out_dir = os.path.join(data_dir, split, "bin")
for alpha_path in tqdm(glob.glob(alpha_dir + "/*/*")):
    alpha = np.array(Image.open(alpha_path).convert("L"))
    binmask = (alpha > threshold * 255) * 255
    out_path = alpha_path.replace(alpha_dir, out_dir)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    cv2.imwrite(out_path, binmask)

100%|██████████| 2773/2773 [02:01<00:00, 22.87it/s]


## Propagate the binary mask

## Tri-map

In [34]:
def gen_trimap(alpha, eval_kernel=25):
    k_size = eval_kernel
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                       (k_size, k_size))
    dilated = cv2.dilate(alpha, kernel)
    eroded = cv2.erode(alpha, kernel)
    trimap = np.zeros(alpha.shape)
    trimap.fill(128)
    # cv2.imwrite("dilated.png", dilated)
    # cv2.imwrite("eroded.png", eroded)
    # print((eroded > 254.5).sum())
    # print((dilated < 10).sum())
    trimap[eroded > 254.5] = 255
    trimap[dilated < 10] = 0
    return trimap, eroded, dilated

In [35]:
split = "test"
threshold = 0.5
alpha_dir = os.path.join(data_dir, split, "pha")
out_dir = os.path.join(data_dir, split, "tri")
for alpha_path in tqdm(glob.glob(alpha_dir + "/*/*")):
    alpha = np.array(Image.open(alpha_path).convert("L"))
    # binmask = (alpha > threshold * 255) * 255
    trimap, _, _ = gen_trimap(alpha)
    out_path = alpha_path.replace(alpha_dir, out_dir)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    cv2.imwrite(out_path, trimap)

100%|██████████| 2720/2720 [09:49<00:00,  4.61it/s]


In [36]:
split = "valid"
threshold = 0.5
alpha_dir = os.path.join(data_dir, split, "pha")
out_dir = os.path.join(data_dir, split, "tri")
for alpha_path in tqdm(glob.glob(alpha_dir + "/*/*")):
    alpha = np.array(Image.open(alpha_path).convert("L"))
    # binmask = (alpha > threshold * 255) * 255
    trimap, _, _ = gen_trimap(alpha)
    out_path = alpha_path.replace(alpha_dir, out_dir)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    cv2.imwrite(out_path, trimap)

100%|██████████| 2773/2773 [05:58<00:00,  7.73it/s]


## Propagate the tri-map

## Compute flow

In [3]:
import sys
sys.path.append('RAFT/core')

In [5]:
from raft import RAFT
from argparse import Namespace

In [32]:
small=False
mixed_precision=False
alternate_corr=False
args = Namespace(small=small, mixed_precision=mixed_precision, alternate_corr=alternate_corr)
raft_model = RAFT(args).cuda()
state_dict = torch.load('RAFT/models/raft-sintel.pth')
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
raft_model.load_state_dict(state_dict)

<All keys matched successfully>

In [33]:
_ = raft_model.eval()

In [14]:
split = "test"
img_dir = os.path.join(data_dir, split, "fgr")
out_dir = os.path.join(data_dir, split, "flow")

image_paths = sorted(glob.glob(img_dir + "/*/*"))
image_pairs = [(image_paths[i], image_paths[i + 1]) for i in range(len(image_paths) - 1)]


In [24]:
for l_img_path, r_img_path in image_pairs:
    left = np.array(cv2.imread(l_img_path)[:, :, ::-1])
    right = np.array(cv2.imread(r_img_path)[:, :, ::-1])
    alpha_left = cv2.imread(l_img_path].replace("fgr", "pha"))
    if len(alpha_left.shape) > 2:
        alpha_left = alpha_left[:, :, 0]
    alpha_right = cv2.imread(r_img_path.replace("fgr", "pha"))
    if len(alpha_right.shape) > 2:
        alpha_right = alpha_right[:, :, 0]
    
    # Check ROI
    x1, y1, x2, y2 = 0, 0, left.shape
    ys, xs = np.where(alpha_left > 10)
    ys.min(), ys.max(), xs.min(), xs.max()

In [27]:
left = torch.from_numpy(left).permute(2, 0, 1)[None]

In [29]:
right = torch.from_numpy(right).permute(2, 0, 1)[None]

In [34]:
with torch.no_grad():
    _, flow1 = raft_model(left.cuda(), right.cuda())

RuntimeError: CUDA out of memory. Tried to allocate 62.57 GiB (GPU 0; 22.02 GiB total capacity; 808.62 MiB already allocated; 18.43 GiB free; 1.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [37]:
right.shape

torch.Size([1, 3, 2160, 3840])

In [41]:
alpha_left = cv2.imread(image_pairs[0][0].replace("fgr", "pha"))[:, :, 0]
alpha_right = cv2.imread(image_pairs[0][1].replace("fgr", "pha"))[:, :, 0]

In [50]:
ys, xs = np.where(alpha_left > 10)
ys.min(), ys.max(), xs.min(), xs.max()

(88, 2159, 1985, 3663)

In [51]:
ys, xs = np.where(alpha_right > 10)
ys.min(), ys.max(), xs.min(), xs.max()

(88, 2159, 1985, 3660)