# In [None]:
from os.path import basename, isfile, join, splitext

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from insightface_func.face_detect_crop_single import Face_detect_crop
from models.models import create_model
from options.test_options import TestOptions

import os
import shutil
from os.path import basename, exists, isfile, join, splitext

import cv2
import numpy as np
import torch
from tqdm import tqdm
import onnxruntime
from util.videoswap import lower_video_resolution, extract_audio, get_frames_n, _totensor

import warnings
# Quiet the notebook: hide Python warnings and onnxruntime log messages below
# severity 3 (onnxruntime levels: 0=verbose, 1=info, 2=warning, 3=error, 4=fatal).
warnings.filterwarnings('ignore')
onnxruntime.set_default_logger_severity(3)
# Class-level torch flag consulted when unpickling whole nn.Module objects whose
# source has changed: torch then also writes a reverse *.patch file.
# NOTE(review): presumably set for loading legacy pickled checkpoints — confirm.
torch.nn.Module.dump_patches = True

from face_parsing.bisenet import BiSeNet


# --- Face-parsing model (BiSeNet, 19 classes); used to build the blend mask ---
seg_model = BiSeNet(n_classes=19)
seg_model.cuda()
save_pth = os.path.join('weights', '79999_iter.pth')
seg_model.load_state_dict(torch.load(save_pth))
seg_model.eval()

# Preprocessing for the ArcFace identity network: to tensor, then per-channel
# ImageNet mean/std normalisation.
transformer_Arcface = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# --- Face-swap model ---
opt = TestOptions()
opt.initialize()
# Dummy option, presumably to absorb the `-f <connection-file>` argument that
# the Jupyter kernel leaves in sys.argv (otherwise argparse rejects it).
opt.parser.add_argument('-f')
opt = opt.parse()
opt.Arc_path = './weights/arcface_checkpoint.tar'
opt.isTrain = False
# Module-level `global` statements were removed: they are no-ops at module
# scope, and after an assignment they are a SyntaxError in a plain script.
model = create_model(opt)
model.eval()

# --- Face detector / aligner ---
app = Face_detect_crop(name='antelope', root='./insightface_func/models')
app.prepare(ctx_id=0, det_thresh=0.6, det_size=(256, 256))

# Source identity image and target video (earlier test inputs kept for reference).
# source = '../reference_videos/gen_0.jpg'
# target = '../reference_videos/stocks/man_2.mp4'
source = '../reference_videos/gen_1.jpg'
target = 'IMG_1269.MOV'
result_dir = './output'
crop_size = 224  # side length (px) of the aligned face crop
from skimage.exposure import rescale_intensity


# Explicit checks rather than `assert`, so they are not stripped under `python -O`.
if not isfile(source):
    raise FileNotFoundError(f'Can\'t find source at {source}')
if not isfile(target):
    raise FileNotFoundError(f'Can\'t find target at {target}')
output_filename = f'infer-{splitext(basename(source))[0]}-{splitext(basename(target))[0]}.mp4'
output_path = join(result_dir, output_filename)

# Sanity checks that the setup cell has run.
assert model is not None
assert app is not None

# Detect and align the face in the source image.
img_a_whole = cv2.imread(source)
if img_a_whole is None:  # cv2.imread returns None instead of raising
    raise ValueError(f'Could not decode source image {source}')
source_detection = app.get(img_a_whole, crop_size)
if source_detection is None:  # the detector returns None when no face is found
    raise ValueError(f'No face detected in source image {source}')
img_a_align_crop, _ = source_detection
img_a_align_crop_pil = Image.fromarray(
    cv2.cvtColor(img_a_align_crop[0], cv2.COLOR_BGR2RGB))
img_a = transformer_Arcface(img_a_align_crop_pil)
img_id = img_a.unsqueeze(0).cuda()  # add batch dimension

# Identity embedding from model.netArc (the ArcFace net configured via
# opt.Arc_path) on the half-resolution crop, L2-normalised per row on the GPU.
with torch.no_grad():
    img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
    latend_id = F.normalize(model.netArc(img_id_downsample), p=2, dim=1)

import torchvision.transforms as transforms
from fsr.models.SRGAN_model import SRGANModel
import easydict

# Input preprocessing for the ESRGAN face super-resolution generator: resize the
# crop to 128x128, then x -> (x - 0.5) / 0.5 per channel.
esrgan_fsr_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# SRGANModel options. Many are training hyper-parameters; the model is built
# with is_train=False below, so presumably only the network and checkpoint
# entries matter here — confirm against fsr.models.SRGAN_model.
args = easydict.EasyDict(dict(
    gpu_ids=None,
    batch_size=32,
    # optimiser: generator
    lr_G=1e-4, weight_decay_G=0, beta1_G=0.9, beta2_G=0.99,
    # optimiser: discriminator
    lr_D=1e-4, weight_decay_D=0, beta1_D=0.9, beta2_D=0.99,
    # learning-rate schedule
    lr_scheme='MultiStepLR', niter=100000, warmup_iter=-1,
    lr_steps=[50000], lr_gamma=0.5,
    # losses
    pixel_criterion='l1', pixel_weight=1e-2,
    feature_criterion='l1', feature_weight=1,
    gan_type='ragan', gan_weight=5e-3,
    D_update_ratio=1, D_init_iters=0,
    # logging / data
    print_freq=100, val_freq=1000, save_freq=10000,
    crop_size=0.85, lr_size=128, hr_size=512,
    # network G
    which_model_G='RRDBNet', G_in_nc=3, out_nc=3, G_nf=64, nb=16,
    # network D
    which_model_D='discriminator_vgg_128', D_in_nc=3, D_nf=64,
    # checkpoints
    pretrain_model_G='weights/90000_G.pth', pretrain_model_D=None,
))


esrgan_fsr_model = SRGANModel(args, is_train=False)
esrgan_fsr_model.load()
# Module.to() returns the module itself, so the calls chain; the trailing `;`
# keeps Jupyter from echoing the network repr.
esrgan_fsr_model.netG.to('cuda').eval();

# In [None]:
import matplotlib.pyplot as plt

# atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
# 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']


def _debug_show(title, image, with_stats=False):
    """Print *title*, display *image* with matplotlib, optionally print shape/min/max."""
    print(title)
    plt.imshow(image)
    plt.show()
    if with_stats:
        print(image.shape, image.min(), image.max())


def _segment_face(swaped_img, crop_size):
    """Return the (crop_size, crop_size) uint8 BiSeNet label map of one swapped crop."""
    # NOTE(review): BiSeNet face parsers are commonly fed ImageNet-normalised
    # input; the raw swap output is passed here — confirm this is intended.
    with torch.no_grad():
        seg_input = F.interpolate(swaped_img.unsqueeze(0), size=(512, 512))
        logits = seg_model(seg_input)[0]
        logits = F.interpolate(logits, size=(crop_size, crop_size))
    # (1, n_classes, H, W) -> (n_classes, H, W) -> per-pixel class id
    return logits.squeeze().cpu().detach().numpy().argmax(0).astype(np.uint8)


def _face_blend_mask(seg_mask, face_part_ids, mouth_ids, debug):
    """Build a feathered blend mask (float, 0..255) covering the face minus the mouth.

    Keeps ``face_part_ids`` pixels, fills only the largest contour, cuts out
    ``mouth_ids`` pixels, median-blurs, dilates, then Gaussian-blurs and
    rescales so the weight fades out around the edge of the dilated region.
    """
    img_mask = np.zeros_like(seg_mask)
    img_mask[np.isin(seg_mask, face_part_ids)] = 255
    if debug:
        _debug_show('img_mask:', img_mask, with_stats=True)
        img_mouth = np.zeros(seg_mask.shape)
        img_mouth[np.isin(seg_mask, mouth_ids)] = 255
        _debug_show('img_mouth:', img_mouth, with_stats=True)

    # Select and fill the biggest contour (in case facial hair splits the face).
    # [-2] is the contour list under both the OpenCV 3 and OpenCV 4 signatures.
    contours = cv2.findContours(img_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)[-2]
    img_mask = np.zeros_like(img_mask)
    if len(contours) > 0:
        cv2.drawContours(img_mask, [max(contours, key=cv2.contourArea)], 0, 255, -1)
    # else: no face pixels were found -> empty mask, so the frame stays unchanged.
    img_mask[np.isin(seg_mask, mouth_ids)] = 0
    if debug:
        _debug_show('img_mask:', img_mask, with_stats=True)

    img_mask = cv2.medianBlur(img_mask, 15)  # smooth sharp edges
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    img_mask = cv2.dilate(img_mask, kernel, iterations=1)  # grow the face region
    blur = cv2.GaussianBlur(img_mask, (35, 35), 0, 0)
    # Values <= 127.5 become 0 and 127.5..255 is stretched to 0..255, so the
    # weight ramps down to zero around the boundary of the dilated region.
    return rescale_intensity(blur, in_range=(127.5, 255), out_range=(0, 255))


def _super_resolve(swaped_img):
    """Upscale one swapped crop with the ESRGAN face-SR generator.

    Returns an (H, W, 3) float array clipped to [0, 1].
    """
    with torch.no_grad():
        lr_img = esrgan_fsr_transform(torch.clone(swaped_img))
        sr_img = esrgan_fsr_model.netG(lr_img.unsqueeze(0))
    sr_img = sr_img.squeeze(0).cpu().detach().numpy().transpose((1, 2, 0))
    # The input transform normalises with mean = std = 0.5; invert that mapping.
    return np.clip(sr_img / 2.0 + 0.5, 0, 1)


def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, save_path='', *, debug=True):
    """Paste swapped face crops back into the full frame and optionally save it.

    For each swapped crop:
      1. segment it with the global BiSeNet ``seg_model`` and build a feathered
         mask over the face parts, excluding mouth and lips;
      2. upscale the crop with the global ESRGAN face-SR model;
      3. warp crop and mask back into ``oriimg`` with the inverse of its
         alignment matrix and alpha-blend the crop over the frame.

    Args:
        swaped_imgs: iterable of swapped-face tensors (C, H, W), one per face.
            Presumably RGB in [0, 1] on the GPU — TODO confirm against the
            swap model's output.
        mats: iterable of 2x3 affine matrices mapping ``oriimg`` onto the
            aligned crop, as returned by the face detector's ``get``.
        crop_size: side length of the aligned crop in pixels.
        oriimg: full frame as read by OpenCV (H, W, 3), BGR.
        save_path: output image path; nothing is written when empty.
        debug: when True (the default), print and plot intermediate images.

    Returns:
        The blended frame as a uint8 array in ``oriimg``'s channel order.
    """
    face_part_ids = [1, 2, 3, 4, 5, 6, 10]  # skin, l/r brow, l/r eye, eye_g, nose
    mouth_ids = [11, 12, 13]  # mouth, u_lip, l_lip
    orisize = (oriimg.shape[1], oriimg.shape[0])  # OpenCV wants (width, height)

    target_image_list = []
    img_mask_list = []
    for swaped_img, mat in zip(swaped_imgs, mats):
        if debug:
            _debug_show('swaped_img:', swaped_img.cpu().detach().numpy().transpose((1, 2, 0)))

        seg_mask = _segment_face(swaped_img, crop_size)
        img_mask = _face_blend_mask(seg_mask, face_part_ids, mouth_ids, debug)

        # SR: ESRGAN (https://github.com/ewrfcas/Face-Super-Resolution)
        sr_img = _super_resolve(swaped_img)

        # `mat` maps the frame onto the crop_size crop. The SR image is larger,
        # so shrink the linear part of the inverse by (SR size / crop_size).
        mat_rev = cv2.invertAffineTransform(mat)
        mat_rev_face = np.array(mat_rev)
        mat_rev_face[:2, :2] = mat_rev_face[:2, :2] / (sr_img.shape[0] / crop_size)

        target_image = cv2.warpAffine(sr_img, mat_rev_face, orisize)
        # np.float (removed in NumPy 1.24) was an alias of float64.
        # [..., ::-1] flips channel order, presumably RGB -> BGR to match oriimg.
        target_image = np.array(target_image, dtype=np.float64)[..., ::-1] * 255
        target_image_list.append(target_image)

        img_mask = cv2.warpAffine(img_mask / 255, mat_rev, orisize)
        img_mask_list.append(img_mask[..., np.newaxis])  # (H, W, 1) for broadcasting

    img = np.array(oriimg, dtype=np.float64)
    for img_mask, target_image in zip(img_mask_list, target_image_list):
        img = img_mask * target_image + (1 - img_mask) * img

    final_img = np.clip(img, 0, 255).astype(np.uint8)
    if debug:
        _debug_show('final_img-RGB:', cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB))

    if save_path:
        cv2.imwrite(save_path, final_img)
    return final_img


# Inputs for the per-frame swap loop.
video_path = target
temp_results_dir = './temp_results'
swap_model = model
detect_model = app
# ASCII name: the previous `id_veсtor` contained a Cyrillic 'с' (U+0441).
id_vector = latend_id
# Debug filter: only these frame indices are swapped (169 - forehead, 216 - mouth).
# Use an empty set to process the whole video.
debug_frames = {216}

lower_video_resolution(video_path)
print(f'=> Swapping face in "{video_path}"...')
if exists(temp_results_dir):
    shutil.rmtree(temp_results_dir)
os.makedirs(temp_results_dir)

audio_path = join(temp_results_dir, splitext(basename(video_path))[0] + '.wav')
extract_audio(video_path, audio_path)

frame_count = get_frames_n(video_path)
# Frames after the last debug frame are never used, so don't decode them.
frames_to_read = min(frame_count, max(debug_frames) + 1) if debug_frames else frame_count

video = cv2.VideoCapture(video_path)
fps = video.get(cv2.CAP_PROP_FPS)
try:
    for frame_index in tqdm(range(frames_to_read)):
        ok, frame = video.read()
        if not ok:
            # Fewer decodable frames than get_frames_n() reported: stop here
            # instead of passing a None frame on.
            print(f'=> Could not read frame {frame_index}, stopping.')
            break
        if debug_frames and frame_index not in debug_frames:
            continue

        save_path = join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index))
        detect_results = detect_model.get(frame, crop_size)
        if detect_results is None:
            # No face detected: keep the frame unchanged.
            cv2.imwrite(save_path, frame)
            continue

        frame_align_crop_list, frame_mat_list = detect_results
        # Inference only; no autograd graph is needed.
        with torch.no_grad():
            swap_result_list = []
            for frame_align_crop in frame_align_crop_list:
                frame_align_crop_tensor = _totensor(
                    cv2.cvtColor(frame_align_crop, cv2.COLOR_BGR2RGB))[None, ...].cuda()
                swap_result_list.append(
                    swap_model(None, frame_align_crop_tensor, id_vector, None, True)[0])
            reverse2wholeimage(swap_result_list, frame_mat_list, crop_size, frame, save_path)
finally:
    video.release()