# Part-swap demo for the paper "Motion Supervised co-part Segmentation"

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/DeepFake

/content/drive/MyDrive/DeepFake


**Clone repository**

In [None]:
!git clone https://github.com/AliaksandrSiarohin/motion-cosegmentation motion-co-seg

fatal: destination path 'motion-co-seg' already exists and is not an empty directory.


In [None]:
cd motion-co-seg/

/content/drive/MyDrive/DeepFake/motion-co-seg


**Mount your Google drive folder on Colab**

**Add a shortcut to https://drive.google.com/open?id=1SsBifjoM_qO0iFzb8wLlsz_4qW2j8dZe to your Google Drive.**


**Load target video and source image**

In [None]:
# Pin PyYAML to 5.4.1; in the recorded run this replaces a preinstalled PyYAML 6.0.
# NOTE(review): presumably the repository's config loading depends on pre-6.0
# behaviour -- confirm before relaxing the pin.
!pip install pyyaml==5.4.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyyaml==5.4.1
  Downloading PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m662.4/662.4 KB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyyaml
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 6.0
    Uninstalling PyYAML-6.0:
      Successfully uninstalled PyYAML-6.0
Successfully installed pyyaml-5.4.1


In [None]:
!pip install imageio-ffmpeg
!pip install ffmpeg

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting imageio-ffmpeg
  Downloading imageio_ffmpeg-0.4.8-py3-none-manylinux2010_x86_64.whl (26.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.9/26.9 MB[0m [31m48.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: imageio-ffmpeg
Successfully installed imageio-ffmpeg-0.4.8
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ffmpeg
  Downloading ffmpeg-1.4.tar.gz (5.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ffmpeg
  Building wheel for ffmpeg (setup.py) ... [?25l[?25hdone
  Created wheel for ffmpeg: filename=ffmpeg-1.4-py3-none-any.whl size=6084 sha256=39160e76ffeb833414699de0adf81f57c655c8143c8fae2e5fa8b9ec7b018714
  Stored in directory: /root/.cache/pip/wheels/30/33/46/5ab7eca55b9490dddbf3441c68a29535996270ef1ce8b9b6d7
Success

In [None]:
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML
import warnings
warnings.filterwarnings("ignore")

# Read the source image and the target video from Google Drive.
# `target_video` is iterated frame by frame below.
source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/26.png')
target_video = imageio.mimread('/content/drive/MyDrive/DeepFake/samples/04.mp4')

# Resize the image and every video frame to 256x256.
# `[..., :3]` keeps only the first three channels (e.g. drops an alpha channel if present).
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

def display(source, target, generated=None):
    """Build an animation of the source image next to each target frame.

    For every index of ``target`` one animation frame is produced by
    concatenating, along axis 1, the source image, the target frame and --
    when ``generated`` is given -- the generated frame at the same index.

    Args:
        source: image array shown unchanged in every animation frame.
        target: sequence of frames; its length sets the number of animation frames.
        generated: optional sequence indexed in step with ``target``.

    Returns:
        A ``matplotlib.animation.ArtistAnimation`` (50 ms between frames,
        1000 ms pause before repeating).
    """
    has_generated = generated is not None
    # The figure is 4 inches wider when a third (generated) panel is shown.
    fig = plt.figure(figsize=(8 + 4 * has_generated, 6))

    frames = []
    for idx, target_frame in enumerate(target):
        panels = [source, target_frame]
        if has_generated:
            panels.append(generated[idx])
        artist = plt.imshow(np.concatenate(panels, axis=1), animated=True)
        plt.axis('off')
        frames.append([artist])

    ani = animation.ArtistAnimation(fig, frames, interval=50, repeat_delay=1000)
    # Close the current figure; presumably so the notebook does not also render
    # it as a static image next to the video.
    plt.close()
    return ani
    

HTML(display(source_image, target_video).to_html5_video())

**Loading checkpoints with 10 parts**

In [None]:
%cd /content/drive/MyDrive/DeepFake/motion-co-seg

/content/drive/MyDrive/DeepFake/motion-co-seg


In [None]:
%pwd

'/content/drive/MyDrive/DeepFake/motion-co-seg'

In [None]:
from part_swap import load_checkpoints

# Load the two modules returned by load_checkpoints for the 10-segment config.
# NOTE(review): the checkpoint filename below starts with a space
# (' vox-10segments.pth.tar'), unlike the other checkpoint paths used in this
# notebook -- confirm the file on Drive is really named that way, otherwise
# remove the space.
reconstruction_module, segmentation_module = load_checkpoints(config='config/vox-256-sem-10segments.yaml', 
                                               checkpoint='/content/drive/MyDrive/DeepFake/samples/ vox-10segments.pth.tar',
                                               blend_scale=1)

**Visualizing the segmentation**

In [None]:
import torch
import torch.nn.functional as F

import matplotlib.patches as mpatches

def visualize_segmentation(image, network, supervised=False, hard=True, colormap='gist_rainbow'):
    """Plot the segmentation that ``network`` predicts for ``image``.

    Draws a new two-panel figure: the colour-coded mask on the left and the
    mask blended over the image (0.3 * image + 0.7 * mask) with a legend of
    segment indices on the right. Nothing is returned.

    Args:
        image: channels-last image array (height, width, channels); it is
            blended with a 3-channel colour mask.
        network: segmentation model. With ``supervised=False`` its output is
            indexed with ``'segmentation'``; with ``supervised=True`` it must
            expose ``mean`` and ``std`` and the first element of its output is
            passed through a softmax over dim 1.
        supervised: selects the supervised branch, in which the input is
            resized to 512x512 and normalised with ``network.mean``/``network.std``.
        hard: if true, every pixel keeps only its maximal segment(s) as a 0/1 mask.
        colormap: name of the matplotlib colormap for segments 1..N-1;
            segment 0 is always drawn black.

    Note:
        The input tensor is moved with ``.cuda()``, so a CUDA device is required.
    """
    out_size = image.shape[:2]
    with torch.no_grad():
        # Add a batch axis and go from channels-last to channels-first.
        batch = torch.tensor(image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()
        if not supervised:
            mask = network(batch)['segmentation']
            mask = F.interpolate(mask, size=out_size, mode='bilinear')
        else:
            batch = F.interpolate(batch, size=(512, 512))
            batch = (batch - network.mean) / network.std
            mask = torch.softmax(network(batch)[0], dim=1)
            mask = F.interpolate(mask, size=out_size)

    if hard:
        # One-hot style mask: 1 where a segment attains the per-pixel maximum.
        mask = (mask.max(dim=1, keepdim=True)[0] == mask).float()

    cmap = plt.get_cmap(colormap)
    num_segments = mask.shape[1]
    mask_hwc = mask.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # Segment 0 is black; the remaining segments are spread evenly over the colormap.
    segment_colors = [
        np.array((0, 0, 0)) if idx == 0
        else np.array(cmap((idx - 1) / (num_segments - 1)))[:3]
        for idx in range(num_segments)
    ]
    patches = [
        mpatches.Patch(color=segment_color, label=str(idx))
        for idx, segment_color in enumerate(segment_colors)
    ]
    # Weighted sum of per-segment colours gives the RGB mask.
    color_mask = sum(
        mask_hwc[..., idx:(idx + 1)] * segment_color.reshape(1, 1, 3)
        for idx, segment_color in enumerate(segment_colors)
    )

    fig, (mask_ax, overlay_ax) = plt.subplots(1, 2, figsize=(12, 6))

    mask_ax.imshow(color_mask)
    overlay_ax.imshow(0.3 * image + 0.7 * color_mask)
    overlay_ax.legend(handles=patches)
    for panel in (mask_ax, overlay_ax):
        panel.axis('off')

visualize_segmentation(source_image, segmentation_module, hard=True)
plt.show()

**Identify the index of the part that you want to swap. For example, to make Trump with red lips, part 2 should be used.**

In [None]:
from part_swap import make_video

predictions = make_video(swap_index=[2], source_image = source_image, target_video = target_video,
                             segmentation_module=segmentation_module, reconstruction_module=reconstruction_module)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|██████████| 211/211 [00:09<00:00, 21.88it/s]


In [None]:
# Saving result video
from skimage import img_as_ubyte
imageio.mimsave('../result.mp4', [img_as_ubyte(frame) for frame in predictions], fps=30)

**Changing eye color**

In [None]:
source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/26.png')
target_video = imageio.mimread('/content/drive/MyDrive/DeepFake/samples/04.mp4')
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(swap_index=[7,9], source_image = source_image, target_video = target_video,
                             segmentation_module=segmentation_module, reconstruction_module=reconstruction_module)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|██████████| 211/211 [00:09<00:00, 22.14it/s]


In [None]:
# Drop the results, inputs and both loaded modules in one statement
# (presumably to release memory before the next model is loaded).
del predictions, source_image, target_video, reconstruction_module, segmentation_module

**Examples with 5-segments model**

In [None]:
%cd /content/drive/MyDrive/DeepFake/motion-co-seg

/content/drive/MyDrive/DeepFake/motion-co-seg


In [None]:
!pip install imageio-ffmpeg
!pip install ffmpeg

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from part_swap import load_checkpoints

reconstruction_module, segmentation_module = load_checkpoints(config='config/vox-256-sem-5segments.yaml', 
                                               checkpoint='/content/drive/MyDrive/DeepFake/samples/vox-5segments.pth.tar',
                                               blend_scale=1)

In [None]:
source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/K.png')
source_image = resize(source_image, (256, 256))[..., :3]
visualize_segmentation(source_image, segmentation_module, hard=True)
plt.show()

**Changing hair**

In [None]:
from part_swap import make_video

source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/26.png')
target_video = imageio.mimread('/content/drive/MyDrive/DeepFake/samples/04.mp4')
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(swap_index=[3, 4, 5], source_image = source_image, target_video = target_video,
                             segmentation_module=segmentation_module, reconstruction_module=reconstruction_module)
HTML(display(source_image, target_video, predictions).to_html5_video())

NameError: ignored

**Source segmentation can be used if the warped source region would fall outside the target mask**

In [None]:
from part_swap import make_video

source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/K.png')
target_video = imageio.mimread('/content/drive/MyDrive/DeepFake/samples/J1.mp4')
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(swap_index=[3, 4,5], source_image = source_image, target_video = target_video, use_source_segmentation=True,
                             segmentation_module=segmentation_module, reconstruction_module=reconstruction_module)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|██████████| 168/168 [00:07<00:00, 21.14it/s]


**Adding Beard**

In [None]:
source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/K.png')
target_video = imageio.mimread('/content/drive/MyDrive/DeepFake/samples/J1.mp4')
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(swap_index=[1], source_image = source_image, target_video = target_video,
                             segmentation_module=segmentation_module, reconstruction_module=reconstruction_module)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|██████████| 168/168 [00:07<00:00, 21.16it/s]


In [None]:
# Drop the results, inputs and both loaded modules in one statement
# (presumably to release memory before the supervised example below).
del predictions, source_image, target_video, reconstruction_module, segmentation_module

**For reference, we provide a method for supervised part swaps**

**Download @zllrunning's model for face parsing**

In [None]:
%cd /content/drive/MyDrive/DeepFake

/content/drive/MyDrive/DeepFake


In [None]:
!git clone https://github.com/AliaksandrSiarohin/face-makeup.PyTorch face_parsing

fatal: destination path 'face_parsing' already exists and is not an empty directory.


In [None]:
from part_swap import load_face_parser
face_parser = load_face_parser(cpu=False)

source_image = imageio.imread('/content/drive/MyDrive/DeepFake/samples/K.png')
target_video = imageio.mimread('/content/drive/MyDrive/DeepFake/samples/J1.mp4')

#Resize image and video to 256x256

source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

visualize_segmentation(source_image, face_parser, supervised=True, hard=True, colormap='tab20')
plt.show()

In [None]:
from part_swap import load_checkpoints

reconstruction_module, segmentation_module = load_checkpoints(config='/content/drive/MyDrive/DeepFake/motion-co-seg/config/vox-256-sem-10segments.yaml', 
                                               checkpoint='/content/drive/MyDrive/DeepFake/samples/vox-first-order.pth.tar',
                                               blend_scale=0.125, first_order_motion_model=True)


Segmentation part initialized at random.


In [None]:
predictions = make_video(swap_index=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], source_image = source_image,
                         target_video = target_video, use_source_segmentation=True, segmentation_module=segmentation_module,
                         reconstruction_module=reconstruction_module, face_parser=face_parser)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|██████████| 168/168 [00:12<00:00, 13.82it/s]
