In [1]:
from deepstab.model_gatingconvolution import GatingConvolutionUNet
%matplotlib inline

In [2]:
import torch

from matplotlib import  rc
rc('animation', html='html5')

from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from deepstab.load import ObjectMaskVideoDataset, SquareMaskVideoDataset
from deepstab.visualize import animate_sequence


In [3]:
# Input data: one DAVIS clip (JPEG frames + its annotation masks) and the trained
# generator checkpoint. Change SEQUENCE_NAME / RESOLUTION here to try another clip
# instead of editing each path by hand.
DAVIS_ROOT = '../data/raw/video/DAVIS'
SEQUENCE_NAME = 'breakdance'
RESOLUTION = '480p'

frame_dir = '/'.join([DAVIS_ROOT, 'JPEGImages', RESOLUTION, SEQUENCE_NAME])
mask_dir = '/'.join([DAVIS_ROOT, 'Annotations_unsupervised', RESOLUTION, SEQUENCE_NAME])
checkpoint_path = '../models/model_epoch_400_lr_0.0001.pth'

In [4]:
# Passed to the datasets below as both frame_transform and mask_transform:
# a 512x512 center crop.
plain_transforms = transforms.Compose([
    transforms.CenterCrop((512, 512)),
])
# Passed to the datasets below as `transform`: resize to 256x256, then convert
# to a tensor.
further_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Generator in eval mode on the (default) CUDA device.
model = GatingConvolutionUNet().eval().cuda()

# Deserialize onto the CPU so loading works no matter which device the checkpoint
# was saved from (e.g. 'cuda:1' on a single-GPU machine); load_state_dict then
# copies the weights into the model's CUDA parameters.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['generator'])

def animate_inpainting_sequence(sequence, inpainting_model=None):
    """Inpaint every frame of a dataset sequence and animate the results.

    Parameters
    ----------
    sequence : indexable
        Dataset item read as ``sequence[0]`` = ground-truth frames,
        ``sequence[1]`` = masks, ``sequence[2]`` = masked frames. Each element
        is assumed to be a tensor whose last two dims are (H, W): 3 channels
        for frames, 1 for masks, values presumably in [0, 1] from ``ToTensor``
        -- TODO confirm against the dataset classes.
    inpainting_model : torch.nn.Module, optional
        Model called as ``inpainting_model(masked_frame, mask)`` on batched
        inputs. Defaults to the notebook-level ``model``.

    Returns
    -------
    Whatever ``animate_sequence`` returns when given three lists of PIL images:
    masked inputs, inpainted outputs and ground-truth frames, in that order.
    """
    net = model if inpainting_model is None else inpainting_model
    # Send inputs to wherever the model's weights live instead of assuming .cuda().
    device = next(net.parameters()).device

    masked_frames = []
    inpainted_frames = []
    frames = []
    # Inference only: no_grad avoids building (and keeping alive) an autograd
    # graph for every frame, which otherwise wastes GPU memory.
    with torch.no_grad():
        for masked_frame, mask, frame in zip(sequence[2], sequence[1], sequence[0]):
            # Take the spatial size from the data rather than hard-coding 256x256.
            height, width = masked_frame.shape[-2:]
            masked_batch = masked_frame.to(device).reshape(1, 3, height, width)
            mask_batch = mask.to(device).reshape(1, 1, height, width)
            inpainted_batch = net(masked_batch, mask_batch)

            masked_frames.append(to_pil_image(masked_batch.cpu().reshape(3, height, width)))
            # to_pil_image scales floats by 255 and casts to byte, so values
            # outside [0, 1] would wrap around; clamp them first.
            inpainted = inpainted_batch.cpu().reshape(3, height, width).clamp(0, 1)
            inpainted_frames.append(to_pil_image(inpainted))
            frames.append(to_pil_image(frame))
    return animate_sequence(masked_frames, inpainted_frames, frames)

In [5]:
# Object masks: masks are read from the DAVIS annotation directory (mask_dir).
object_mask_ds = ObjectMaskVideoDataset(
    [frame_dir],
    [mask_dir],
    frame_transform=plain_transforms,
    mask_transform=plain_transforms,
    transform=further_transforms,
)
# Index 0 is presumably the single sequence, since only one directory is passed.
sequence = object_mask_ds[0]
animate_inpainting_sequence(sequence)

In [6]:
# Square masks: same frames, but no mask_dir is passed -- the masks presumably
# come from SquareMaskVideoDataset itself.
square_mask_ds = SquareMaskVideoDataset(
    [frame_dir],
    frame_transform=plain_transforms,
    mask_transform=plain_transforms,
    transform=further_transforms,
)
# Index 0 is presumably the single sequence, since only one directory is passed.
sequence = square_mask_ds[0]
animate_inpainting_sequence(sequence)