In [1]:
from deepstab.flow import estimate_flow, warp_tensor
from deepstab.utils import cutout_mask
from deepstab.visualize import animate_sequence
%matplotlib inline

In [2]:
from matplotlib import  rc
# Set matplotlib's `animation.html` rc parameter to 'html5' so that animation
# objects shown as cell output are embedded as HTML5 video.
rc('animation', html='html5')
import flowiz as fz
from torchvision import transforms
from deepstab.load import ObjectMaskVideoDataset, SquareMaskVideoDataset
from torchvision.transforms.functional import to_pil_image
from deepstab.pwcnet import PWCNet
from PIL.ImageOps import invert

In [3]:
# Input locations, given relative to the directory the notebook is run from.
# Frames and annotation masks of the DAVIS 480p "breakdance" sequence.
frame_dir = '../data/raw/video/DAVIS/JPEGImages/480p/breakdance'
mask_dir = '../data/raw/video/DAVIS/Annotations_unsupervised/480p/breakdance'
# File handed to the PWCNet constructor in the next cell
# (presumably the pretrained weights -- confirm against deepstab.pwcnet).
model_path = '../models/pwcnet/network-default.pytorch'

In [4]:
# Transform passed as both `frame_transform` and `mask_transform` to the
# datasets below: a 512x512 center crop.
plain_transforms = transforms.Compose([transforms.CenterCrop((512, 512))])

# Transform passed as `transform` to the datasets below: convert to a tensor,
# then reverse it along dimension 2.
# NOTE(review): ToTensor yields C x H x W for an image input, so flip(2) mirrors
# the width axis; if an RGB<->BGR channel swap was intended that would be
# flip(0) -- confirm. animate_flow_sequence applies the same flip(2) again
# before display.
further_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda tensor: tensor.flip(2)),
    ]
)

# Flow network shared by every call to animate_flow_sequence: constructed from
# `model_path`, moved to the GPU and put into evaluation mode.
# NOTE(review): `.cuda()` needs a CUDA device, so this cell fails on CPU-only hosts.
model = PWCNet(model_path).cuda().eval()

def animate_flow_sequence(sequence):
    """Animate a sequence's frames alongside two masked views of their optical flow.

    For each pair of consecutive entries of ``sequence[2]`` the flow is
    estimated with the module-level ``model`` and rendered as a color image
    with flowiz. Two variants of that image are produced with ``cutout_mask``:
    one using the inverted mask without dilation, and one using the mask
    itself with dilation.

    Args:
        sequence: Indexable dataset sample. ``sequence[1]`` holds the
            per-frame masks and ``sequence[2]`` the frames given to the flow
            estimator (presumably the masked frames -- TODO confirm against
            the dataset classes). ``sequence[0]`` is not used here.

    Returns:
        The result of ``animate_sequence`` for the three frame lists. The
        last frame has no successor, so each list holds
        ``len(sequence[2]) - 1`` images.
    """
    masks, frames = sequence[1], sequence[2]

    result_masked_frames = []
    inverted_mask_flow_frames = []
    dilated_mask_flow_frames = []

    for i in range(len(frames) - 1):
        flow = estimate_flow(model, frames[i], frames[i + 1]).detach()

        # Tensor.numpy() only accepts CPU tensors and the model runs on the
        # GPU; .cpu() returns the tensor unchanged when it is already there.
        # The transpose moves axis 0 last (presumably channels-last for flowiz).
        flow_array = flow.flip(2).cpu().numpy().transpose(1, 2, 0)
        flow_frame = to_pil_image(fz.convert_from_flow(flow_array))

        # flip(2) presumably undoes the flip applied by the dataset transform
        # before the tensors are turned into PIL images -- confirm.
        mask = to_pil_image(masks[i].flip(2))
        masked_frame = to_pil_image(frames[i].flip(2))

        result_masked_frames.append(masked_frame)
        inverted_mask_flow_frames.append(
            cutout_mask(flow_frame, invert(mask), dilate_mask=False)
        )
        dilated_mask_flow_frames.append(
            cutout_mask(flow_frame, mask, dilate_mask=True)
        )

    return animate_sequence(
        result_masked_frames, inverted_mask_flow_frames, dilated_mask_flow_frames
    )

In [5]:
# Flow animation for sample 0 of the sequence, using the annotation masks
# read from `mask_dir`.
object_mask_ds = ObjectMaskVideoDataset(
    [frame_dir],
    [mask_dir],
    frame_transform=plain_transforms,
    mask_transform=plain_transforms,
    transform=further_transforms,
)
sequence = object_mask_ds[0]
animate_flow_sequence(sequence)

In [6]:
# Same animation for sample 0 of the sequence, this time from the square-mask
# dataset, which takes no mask directory.
square_mask_ds = SquareMaskVideoDataset(
    [frame_dir],
    frame_transform=plain_transforms,
    mask_transform=plain_transforms,
    transform=further_transforms,
)
sequence = square_mask_ds[0]
animate_flow_sequence(sequence)