# MakeItTalk Quick Demo (natural human face animation)

- includes project setup + pretrained model download
- provides step-by-step details
- TODO: TL;DR version

## Preparations
- Check GPU

In [None]:
!pip uninstall librosa
!pip install librosa==0.8.1

import librosa

print(librosa.__version__)

In [None]:
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
import subprocess
print(subprocess.getoutput('nvidia-smi'))

- Check ffmpeg

In [None]:
# Print whatever `ffmpeg` emits when run without arguments, to confirm the binary is installed.
# Relies on `subprocess` having been imported by the GPU-check cell above.
print(subprocess.getoutput('ffmpeg'))

- Clone the GitHub repository https://github.com/yzhou359/MakeItTalk

In [None]:
# Clone the MakeItTalk sources into ./MakeItTalk (git refuses if that folder already exists and is non-empty).
!git clone https://github.com/yzhou359/MakeItTalk

- Install requirements

In [None]:
%cd MakeItTalk/
!export PYTHONPATH=/content/MakeItTalk:$PYTHONPATH
!pip install -r requirements.txt
!pip install tensorboardX

In [None]:
# Patch the cloned repo in place: rewrite the positional call `mel(16000, 1024, ...)` in
# extract_f0_func.py to keyword form `mel(sr=16000, n_fft=1024, ...)`.
# NOTE(review): presumably needed for librosa releases where these arguments are keyword-only -- confirm.
!sed -i 's/mel(16000, 1024, fmin=90, fmax=7600, n_mels=80).T/mel(sr=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T/' /content/MakeItTalk/src/autovc/retrain_version/vocoder_spec/extract_f0_func.py

In [None]:
# Mount Google Drive at /content/drive so files stored there can be copied into the runtime.
# Colab-only: `google.colab` is not importable elsewhere.
from google.colab import drive
drive.mount('/content/drive')

- Download pretrained models

In [None]:
# !gdown -O examples/ckpt/ckpt_autovc.pth https://drive.google.com/uc?id=1ZiwPp_h62LtjU0DwpelLUoodKPR85K7x 
!cp /content/drive/MyDrive/AI/ckpt_autovc.pth examples/ckpt/ckpt_autovc.pth
!cp /content/drive/MyDrive/AI/ckpt_116_i2i_comb.pth examples/ckpt/ckpt_116_i2i_comb.pth

In [None]:
!mkdir examples/dump
!mkdir examples/ckpt
!pip install gdown
!gdown -O examples/ckpt/ckpt_autovc.pth https://drive.google.com/uc?id=1ZiwPp_h62LtjU0DwpelLUoodKPR85K7x
!gdown -O examples/ckpt/ckpt_content_branch.pth https://drive.google.com/uc?id=1r3bfEvTVl6pCNw5xwUhEglwDHjWtAqQp
!gdown -O examples/ckpt/ckpt_speaker_branch.pth https://drive.google.com/uc?id=1rV0jkyDqPW-aDJcj7xSO6Zt1zSXqn1mu
!gdown -O examples/ckpt/ckpt_116_i2i_comb.pth https://drive.google.com/uc?id=1i2LJXKp-yWKIEEgJ7C6cE3_2NirfY_0a
!gdown -O examples/dump/emb.pickle https://drive.google.com/uc?id=18-0CYl5E6ungS3H4rRSHjfYvvm-WwjTI

- Prepare your images/audio (or use the existing examples)
  - An image to animate: upload it to the `MakeItTalk/examples` folder; the image size should be 256x256
  - An audio clip of speech (ideally free of noise): upload it to the `MakeItTalk/examples` folder as well

## Step 0: import necessary packages

In [None]:
import sys
sys.path.append("thirdparty/AdaptiveWingLoss")
import os, glob
import numpy as np
import cv2
import argparse
from src.approaches.train_image_translation import Image_translation_block
import torch
import pickle
import face_alignment
from src.autovc.AutoVC_mel_Convertor_retrain_version import AutoVC_mel_Convertor
import shutil
import time
import util.utils as util
from scipy.signal import savgol_filter
from src.approaches.train_audio2landmark import Audio2landmark_model

## Step 1: Basic setup for the animation

In [None]:
# Image to animate: file name inside `examples/`, without the .jpg extension.
default_head_name = 'girl_detail'

# Whether to add a naive eye blink.
ADD_NAIVE_EYE = True

# Set to True if the input image has an opened mouth, else False.
CLOSE_INPUT_FACE_MOUTH = False

# Lip-motion amplification, horizontal (X) and vertical (Y).
AMP_LIP_SHAPE_X = AMP_LIP_SHAPE_Y = 2.0

# Head-pose motion amplification (usually smaller than 1.0; use 0. for a static head pose).
AMP_HEAD_POSE_MOTION = 0.7

In [None]:
# Install the `face_recognition` package used by the face-cropping cell below (version not pinned).
!pip install face_recognition


In [None]:
from PIL import Image

# Open the original picture
image = Image.open("/content/MakeItTalk/examples/girl.jpg")

# Scale factor
scale = 0.7  # shrink to 70% of the original size

# Compute the target size
new_width = int(image.width * scale)
new_height = int(image.height * scale)

# Shrink the picture; thumbnail() resizes `image` in place rather than returning a copy
image.thumbnail((new_width, new_height))

# Save the down-scaled picture
image.save("/content/MakeItTalk/examples/girl_0.jpg")


In [None]:
# NOTE(review): duplicate of the earlier `face_recognition` install cell; safe to delete.
!pip install face_recognition

In [None]:
import face_recognition
from PIL import Image

# Load the down-scaled picture as an array.
image = face_recognition.load_image_file("/content/MakeItTalk/examples/girl_0.jpg")

# Locate faces. The indexing below treats each entry as (top, right, bottom, left).
face_locations = face_recognition.face_locations(image)
if not face_locations:
    # Fail with a readable message instead of an IndexError on face_locations[0].
    raise RuntimeError("No face detected in /content/MakeItTalk/examples/girl_0.jpg")

# Fixed 256x256 window anchored on the first face's top-left corner:
# 100 px of margin on the left and 50 px above.
left=face_locations[0][3]-100
top=face_locations[0][0]-50
right=face_locations[0][3]+156
bottom=face_locations[0][0]+206

# Crop the face window out of the picture.
face_image = Image.fromarray(image).crop((left, top, right, bottom))
# Debug output: detected box (left, top, right, bottom), crop window, crop size.
print(face_locations[0][3], face_locations[0][0], face_locations[0][1], face_locations[0][2])
print(right, top, left, bottom)
print(right-left, bottom-top)
# !identify /content/MakeItTalk/examples/girl_0.jpg
# Save the crop; this is the image animated later (see `default_head_name`).
face_image.save("/content/MakeItTalk/examples/girl_detail.jpg")

# Paste the crop back onto a copy of the picture.
# NOTE(review): the paste origin is the face box corner, not the crop origin (left, top),
# so the pasted patch is offset by the margins -- confirm this is intended.
image_pil = Image.fromarray(image)
image_pil.paste(face_image, (face_locations[0][3], face_locations[0][0]))

# Save the composited picture.
image_pil.save("/content/MakeItTalk/examples/girl_1.jpg")


In [None]:
# WARNING: discards every local modification to tracked files in the clone, including the
# `sed` patch applied to extract_f0_func.py earlier in this notebook. Untracked files are left alone.
!git reset --hard

In [None]:
!identify /content/MakeItTalk/examples/girl_0.jpg

In [None]:
!apt-get install imagemagick
!identify /content/MakeItTalk/examples/girl.jpg

Default hyper-parameters for the model.

In [None]:
# Build the option namespace the MakeItTalk modules expect. Nothing is passed on a command
# line here; the values all come from the defaults below.
parser = argparse.ArgumentParser()
# Input image and mouth handling (defaults come from the "Basic setup" cell).
parser.add_argument('--jpg', type=str, default='{}.jpg'.format(default_head_name))
parser.add_argument('--close_input_face_mouth', default=CLOSE_INPUT_FACE_MOUTH, action='store_true')

# Checkpoint locations.
parser.add_argument('--load_AUTOVC_name', type=str, default='examples/ckpt/ckpt_autovc.pth')
parser.add_argument('--load_a2l_G_name', type=str, default='examples/ckpt/ckpt_speaker_branch.pth')
parser.add_argument('--load_a2l_C_name', type=str, default='examples/ckpt/ckpt_content_branch.pth') #ckpt_audio2landmark_c.pth')
parser.add_argument('--load_G_name', type=str, default='examples/ckpt/ckpt_116_i2i_comb.pth') #ckpt_image2image.pth') #ckpt_i2i_finetune_150.pth') #c

# Motion amplification and output options.
parser.add_argument('--amp_lip_x', type=float, default=AMP_LIP_SHAPE_X)
parser.add_argument('--amp_lip_y', type=float, default=AMP_LIP_SHAPE_Y)
parser.add_argument('--amp_pos', type=float, default=AMP_HEAD_POSE_MOTION)
parser.add_argument('--reuse_train_emb_list', type=str, nargs='+', default=[]) #  ['iWeklsXc0H8']) #['45hn7-LXDX8']) #['E_kmpT-EfOg']) #'iWeklsXc0H8', '29k8RtSUjE0', '45hn7-LXDX8',
parser.add_argument('--add_audio_in', default=False, action='store_true')
parser.add_argument('--comb_fan_awing', default=False, action='store_true')
parser.add_argument('--output_folder', type=str, default='examples')

# Model hyper-parameters.
# NOTE: `store_true` flags whose default is already True can never be switched off from argv.
parser.add_argument('--test_end2end', default=True, action='store_true')
parser.add_argument('--dump_dir', type=str, default='', help='')
parser.add_argument('--pos_dim', default=7, type=int)
parser.add_argument('--use_prior_net', default=True, action='store_true')
parser.add_argument('--transformer_d_model', default=32, type=int)
parser.add_argument('--transformer_N', default=2, type=int)
parser.add_argument('--transformer_heads', default=2, type=int)
parser.add_argument('--spk_emb_enc_size', default=16, type=int)
parser.add_argument('--init_content_encoder', type=str, default='')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--reg_lr', type=float, default=1e-6, help='weight decay')
parser.add_argument('--write', default=False, action='store_true')
parser.add_argument('--segment_batch_size', type=int, default=1, help='batch size')
parser.add_argument('--emb_coef', default=3.0, type=float)
parser.add_argument('--lambda_laplacian_smooth_loss', default=1.0, type=float)
parser.add_argument('--use_11spk_only', default=False, action='store_true')
# Absorbs the `-f <connection file>` argument the Jupyter kernel is launched with.
parser.add_argument('-f')

# parse_known_args() tolerates any other launcher-specific argv entries instead of aborting
# with SystemExit; the resulting namespace is the same as parse_args() would produce.
opt_parser, _unknown_args = parser.parse_known_args()

In [None]:
!ffmpeg -i /content/MakeItTalk/examples/girl.png /content/MakeItTalk/examples/girl.jpg

In [None]:
# Print format and dimensions of examples/hermione.jpg (requires ImageMagick's `identify`).
!identify /content/MakeItTalk/examples/hermione.jpg

## Step 2: load the image and detect its landmark

In [None]:
# 获取 LandmarksType 枚举的所有成员
# Collect the names of all members of the LandmarksType enum
# (they differ between face_alignment releases, so check which spelling is installed).
landmarks_type_enum = face_alignment.LandmarksType
members = [member.name for member in landmarks_type_enum]

# Print all member names
print(members)

In [None]:
# Load the image to animate and detect its 3D facial landmarks.
img_path = 'examples/' + opt_parser.jpg
img = cv2.imread(img_path)
if img is None:
    # cv2.imread returns None instead of raising when the file is missing or unreadable.
    raise FileNotFoundError('Cannot read image {}'.format(img_path))
predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType.THREE_D, device='cpu', flip_input=True)
shapes = predictor.get_landmarks(img)
if (not shapes or len(shapes) != 1):
    # Raise rather than call exit(-1): inside a Jupyter kernel `exit` is IPython's exit
    # helper, which cannot be relied on to stop the cell before `shapes[0]` is evaluated.
    raise RuntimeError('Cannot detect face landmarks. Exit.')
# Exactly one face was found; keep its landmark array.
shape_3d = shapes[0]

if(opt_parser.close_input_face_mouth):
    util.close_input_face_mouth(shape_3d)

## (Optional) Simple manual adjustment to landmarks in case FAN is not accurate, e.g.
- slimmer lips
- wider eyes
- wider mouth

In [72]:
# NOTE: every statement here edits `shape_3d` in place, so re-running the cell compounds the change.

# Wider lips: stretch the x-coordinates of landmarks 48+ by 5% about their mean.
lip_center_x = np.mean(shape_3d[48:, 0])
shape_3d[48:, 0] = (shape_3d[48:, 0] - lip_center_x) * 1.05 + lip_center_x

# Thinner upper lip (an offset of 0. leaves it unchanged as written).
shape_3d[49:54, 1] += 0.
# Thinner lower lip.
shape_3d[55:60, 1] -= 1.

# Larger eyes: move one set of eye landmarks by -2 in y and the opposite set by +2.
eye_idx_minus = [37, 38, 43, 44]
eye_idx_plus = [40, 41, 46, 47]
shape_3d[eye_idx_minus, 1] -= 2.
shape_3d[eye_idx_plus, 1] += 2.

Normalize face as input to audio branch

In [73]:
# Normalise the landmarks and keep the returned scale/shift.
# Note: this rebinds `scale`, which the image-resize cell earlier used for its 0.7 factor.
shape_3d, scale, shift = util.norm_input_face(shape_3d)

## Step 3: Generate input data for inference based on uploaded audio `MakeItTalk/examples/*.wav`

In [74]:
# Imported once up front instead of on every loop iteration.
from thirdparty.resemblyer_util.speaker_emb import get_spk_emb

au_data = []
au_emb = []
ains = glob.glob1('examples', '*.wav')
# Skip a leftover scratch file. Must be `!=`: `is not` compares object identity, which is
# true for a freshly globbed string, so the old filter never excluded anything.
ains = [item for item in ains if item != 'tmp.wav']
ains.sort()
for ain in ains:
    # Resample to 16 kHz via a scratch file, then overwrite the original wav with the result.
    os.system('ffmpeg -y -loglevel error -i examples/{} -ar 16000 examples/tmp.wav'.format(ain))
    shutil.copyfile('examples/tmp.wav', 'examples/{}'.format(ain))

    # au embedding
    me, ae = get_spk_emb('examples/{}'.format(ain))
    au_emb.append(me.reshape(-1))

    print('Processing audio file', ain)
    c = AutoVC_mel_Convertor('examples')

    au_data_i = c.convert_single_wav_to_autovc_input(audio_filename=os.path.join('examples', ain),
           autovc_model_path=opt_parser.load_AUTOVC_name)
    au_data += au_data_i
if(os.path.isfile('examples/tmp.wav')):
    os.remove('examples/tmp.wav')

# landmark fake placeholder: zero arrays, one row per audio frame
fl_data = []
rot_tran, rot_quat, anchor_t_shape = [], [], []
for au, info in au_data:
    au_length = au.shape[0]
    fl = np.zeros(shape=(au_length, 68 * 3))
    fl_data.append((fl, info))
    rot_tran.append(np.zeros(shape=(au_length, 3, 4)))
    rot_quat.append(np.zeros(shape=(au_length, 4)))
    anchor_t_shape.append(np.zeros(shape=(au_length, 68 * 3)))

# Remove dumps left by a previous run (the *_interp file is removed but not rewritten here).
dump_dir = os.path.join('examples', 'dump')
for stale_name in ('random_val_fl.pickle', 'random_val_fl_interp.pickle',
                   'random_val_au.pickle', 'random_val_gaze.pickle'):
    stale_path = os.path.join(dump_dir, stale_name)
    if(os.path.exists(stale_path)):
        os.remove(stale_path)

# Write the fresh dumps.
with open(os.path.join('examples', 'dump', 'random_val_fl.pickle'), 'wb') as fp:
    pickle.dump(fl_data, fp)
with open(os.path.join('examples', 'dump', 'random_val_au.pickle'), 'wb') as fp:
    pickle.dump(au_data, fp)
with open(os.path.join('examples', 'dump', 'random_val_gaze.pickle'), 'wb') as fp:
    gaze = {'rot_trans':rot_tran, 'rot_quat':rot_quat, 'anchor_t_shape':anchor_t_shape}
    pickle.dump(gaze, fp)

In [94]:
"""
 # Copyright 2020 Adobe
 # All Rights Reserved.
 
 # NOTICE: Adobe permits you to use, modify, and distribute this file in
 # accordance with the terms of the Adobe license agreement accompanying
 # it.
 
"""

from src.models.model_image_translation import ResUnetGenerator, VGGLoss
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import time
import numpy as np
import cv2
import os, glob
from src.dataset.image_translation.image_translation_dataset import vis_landmark_on_img, vis_landmark_on_img98, vis_landmark_on_img74


from thirdparty.AdaptiveWingLoss.core import models
from thirdparty.AdaptiveWingLoss.utils.utils import get_preds_fromhm

import face_alignment

# Module-level torch device used by the class below: CUDA when torch reports one, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Image_translation_block1():

    def __init__(self, opt_parser, single_test=False):
        """Build the image-translation generator and, unless ``single_test`` is set, the training setup.

        Args:
            opt_parser: option namespace. Always read: ``add_audio_in``, ``load_G_name``.
                When ``single_test`` is False it must also provide ``use_vox_dataset``,
                ``comb_fan_awing``, ``num_frames``, ``batch_size``, ``num_workers``, ``lr``,
                ``write`` (plus ``log_dir``/``name`` when writing) and, with
                ``comb_fan_awing``, ``fan_2or3D``.
            single_test: when True only the generator is created and loaded; the dataset,
                losses, optimizer and landmark detectors are skipped.
        """
        print('Run on device {}'.format(device))

        # for key in vars(opt_parser).keys():
        #     print(key, ':', vars(opt_parser)[key])
        self.opt_parser = opt_parser

        # model: one extra input channel when an audio map is concatenated to the input
        if(opt_parser.add_audio_in):
            self.G = ResUnetGenerator(input_nc=7, output_nc=3, num_downs=6, use_dropout=False)
        else:
            self.G = ResUnetGenerator(input_nc=6, output_nc=3, num_downs=6, use_dropout=False)

        if (opt_parser.load_G_name != ''):
            # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
            ckpt = torch.load(opt_parser.load_G_name, map_location=device)
            try:
                self.G.load_state_dict(ckpt['G'])
            except RuntimeError:
                # load_state_dict raises RuntimeError on mismatched keys; retry through a
                # DataParallel wrapper (presumably the weights were saved from one -- confirm).
                tmp = nn.DataParallel(self.G)
                tmp.load_state_dict(ckpt['G'])
                self.G.load_state_dict(tmp.module.state_dict())
                del tmp

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs in G mode!")
            self.G = nn.DataParallel(self.G)

        self.G.to(device)

        if(not single_test):
            # dataset
            if(opt_parser.use_vox_dataset == 'raw'):
                if(opt_parser.comb_fan_awing):
                    from src.dataset.image_translation.image_translation_dataset import \
                        image_translation_raw74_dataset as image_translation_dataset
                elif(opt_parser.add_audio_in):
                    from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_with_audio_dataset as \
                        image_translation_dataset
                else:
                    from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_dataset as \
                    image_translation_dataset
            else:
                from src.dataset.image_translation.image_translation_dataset import image_translation_preprocessed98_dataset as \
                    image_translation_dataset

            self.dataset = image_translation_dataset(num_frames=opt_parser.num_frames)
            self.dataloader = torch.utils.data.DataLoader(self.dataset,
                                                          batch_size=opt_parser.batch_size,
                                                          shuffle=True,
                                                          num_workers=opt_parser.num_workers)

            # criterion
            self.criterionL1 = nn.L1Loss()
            self.criterionVGG = VGGLoss()
            if torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs in VGG model!")
                self.criterionVGG = nn.DataParallel(self.criterionVGG)
            self.criterionVGG.to(device)

            # optimizer
            self.optimizer = torch.optim.Adam(self.G.parameters(), lr=opt_parser.lr, betas=(0.5, 0.999))

            # writer
            if(opt_parser.write):
                self.writer = SummaryWriter(log_dir=os.path.join(opt_parser.log_dir, opt_parser.name))
                self.count = 0

            # ===========================================================
            #       online landmark alignment : Awing
            # ===========================================================
            PRETRAINED_WEIGHTS = 'thirdparty/AdaptiveWingLoss/ckpt/WFLW_4HG.pth'
            GRAY_SCALE = False
            HG_BLOCKS = 4
            END_RELU = False
            NUM_LANDMARKS = 98

            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            model_ft = models.FAN(HG_BLOCKS, END_RELU, GRAY_SCALE, NUM_LANDMARKS)

            # Same reasoning as above: keep the weights loadable without a GPU.
            checkpoint = torch.load(PRETRAINED_WEIGHTS, map_location=self.device)
            if 'state_dict' not in checkpoint:
                model_ft.load_state_dict(checkpoint)
            else:
                # Keep only the entries whose names exist in the freshly built model.
                pretrained_weights = checkpoint['state_dict']
                model_weights = model_ft.state_dict()
                pretrained_weights = {k: v for k, v in pretrained_weights.items() \
                                      if k in model_weights}
                model_weights.update(pretrained_weights)
                model_ft.load_state_dict(model_weights)
            print('Load AWing model sucessfully')
            if torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs for AWing!")
                self.fa_model = nn.DataParallel(model_ft).to(self.device).eval()
            else:
                self.fa_model = model_ft.to(self.device).eval()

            # ===========================================================
            #       online landmark alignment : FAN
            # ===========================================================
            if(opt_parser.comb_fan_awing):
                # The enum members are spelled TWO_D/THREE_D in the face_alignment release the
                # notebook's own detection cell uses, and _2D/_3D in the spelling this class was
                # written against; accept whichever the installed package defines.
                lm_types = face_alignment.LandmarksType
                if(opt_parser.fan_2or3D == '2D'):
                    lm_names = ('TWO_D', '_2D')
                else:
                    lm_names = ('THREE_D', '_3D')
                lm_type = next(getattr(lm_types, n) for n in lm_names if hasattr(lm_types, n))
                self.predictor = face_alignment.FaceAlignment(lm_type,
                                                              device='cuda' if torch.cuda.is_available() else "cpu",
                                                              flip_input=True)

    def __train_pass__(self, epoch, is_training=True):
        """Run one pass over ``self.dataloader``, optimising the generator when ``is_training``.

        For each batch: landmarks are predicted on the target frames with the AWing model,
        drawn on a white canvas, concatenated with the source frames (and the audio map if
        ``add_audio_in``), and fed to the generator. The loss is L1 + VGG + style.

        Args:
            epoch: epoch index, used in log lines and in image/checkpoint file names.
            is_training: selects train/eval mode for the generator and whether the optimizer steps.
        """
        st_epoch = time.time()
        if(is_training):
            self.G.train()
            status = 'TRAIN'
        else:
            self.G.eval()
            status = 'EVAL'

        g_time = 0.0
        for i, batch in enumerate(self.dataloader):
            # NOTE(review): stops two batches before the end of the loader -- reason not visible here.
            if(i >= len(self.dataloader)-2):
                break
            st_batch = time.time()

            if(self.opt_parser.comb_fan_awing):
                image_in, image_out, fan_pred_landmarks = batch
                fan_pred_landmarks = fan_pred_landmarks.reshape(-1, 68, 3).detach().cpu().numpy()
            elif(self.opt_parser.add_audio_in):
                image_in, image_out, audio_in = batch
                audio_in = audio_in.reshape(-1, 1, 256, 256).to(device)
            else:
                image_in, image_out = batch

            with torch.no_grad():
                # # online landmark (AwingNet)
                image_in, image_out = \
                    image_in.reshape(-1, 3, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)
                inputs = image_out
                outputs, boundary_channels = self.fa_model(inputs)
                pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
                pred_landmarks, _ = get_preds_fromhm(pred_heatmap)
                # scale heatmap coordinates by 4 before drawing on the 256x256 canvas
                pred_landmarks = pred_landmarks.numpy() * 4

                # online landmark (FAN) -> replace jaw + eye brow in AwingNet
                if(self.opt_parser.comb_fan_awing):
                    fl_jaw_eyebrow = fan_pred_landmarks[:, 0:27, 0:2]
                    fl_rest = pred_landmarks[:, 51:, :]
                    # `np.int` was removed from NumPy; the builtin `int` is the same dtype it aliased.
                    pred_landmarks = np.concatenate([fl_jaw_eyebrow, fl_rest], axis=1).astype(int)

            # draw landmark on white bg
            img_fls = []
            for pred_fl in pred_landmarks:
                img_fl = np.ones(shape=(256, 256, 3)) * 255.0
                if(self.opt_parser.comb_fan_awing):
                    img_fl = vis_landmark_on_img74(img_fl, pred_fl)  # 74x2
                else:
                    img_fl = vis_landmark_on_img98(img_fl, pred_fl)  # 98x2
                img_fls.append(img_fl.transpose((2, 0, 1)))
            img_fls = np.stack(img_fls, axis=0).astype(np.float32) / 255.0
            image_fls_in = torch.tensor(img_fls, requires_grad=False).to(device)

            if(self.opt_parser.add_audio_in):
                # print(image_fls_in.shape, image_in.shape, audio_in.shape)
                image_in = torch.cat([image_fls_in, image_in, audio_in], dim=1)
            else:
                image_in = torch.cat([image_fls_in, image_in], dim=1)

            # image_in, image_out = \
            #     image_in.reshape(-1, 6, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)

            # image2image net fp
            g_out = self.G(image_in)
            g_out = torch.tanh(g_out)

            loss_l1 = self.criterionL1(g_out, image_out)
            loss_vgg, loss_style = self.criterionVGG(g_out, image_out, style=True)

            loss_vgg, loss_style = torch.mean(loss_vgg), torch.mean(loss_style)

            loss = loss_l1  + loss_vgg + loss_style
            if(is_training):
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # log
            if(self.opt_parser.write):
                self.writer.add_scalar('loss', loss.cpu().detach().numpy(), self.count)
                self.writer.add_scalar('loss_l1', loss_l1.cpu().detach().numpy(), self.count)
                self.writer.add_scalar('loss_vgg', loss_vgg.cpu().detach().numpy(), self.count)
                self.count += 1

            # save image to track training process
            if (i % self.opt_parser.jpg_freq == 0):
                # 2x2 grid: [source | landmarks] on top, [target | generated] below
                vis_in = np.concatenate([image_in[0, 3:6].cpu().detach().numpy().transpose((1, 2, 0)),
                                         image_in[0, 0:3].cpu().detach().numpy().transpose((1, 2, 0))], axis=1)
                vis_out = np.concatenate([image_out[0].cpu().detach().numpy().transpose((1, 2, 0)),
                                          g_out[0].cpu().detach().numpy().transpose((1, 2, 0))], axis=1)
                vis = np.concatenate([vis_in, vis_out], axis=0)
                # exist_ok covers the "already there" case without hiding real failures
                # (the former bare `except: pass` swallowed everything).
                os.makedirs(os.path.join(self.opt_parser.jpg_dir, self.opt_parser.name), exist_ok=True)
                cv2.imwrite(os.path.join(self.opt_parser.jpg_dir, self.opt_parser.name, 'e{:03d}_b{:04d}.jpg'.format(epoch, i)), vis * 255.0)
            # save ckpt
            if (i % self.opt_parser.ckpt_last_freq == 0):
                self.__save_model__('last', epoch)

            print("Epoch {}, Batch {}/{}, loss {:.4f}, l1 {:.4f}, vggloss {:.4f}, styleloss {:.4f} time {:.4f}".format(
                epoch, i, len(self.dataset) // self.opt_parser.batch_size,
                loss.cpu().detach().numpy(),
                loss_l1.cpu().detach().numpy(),
                loss_vgg.cpu().detach().numpy(),
                loss_style.cpu().detach().numpy(),
                          time.time() - st_batch))

            g_time += time.time() - st_batch


            # speed-test mode: stop after batch index 100
            if(self.opt_parser.test_speed):
                if(i >= 100):
                    break

        print('Epoch time usage:', time.time() - st_epoch, 'I/O time usage:', time.time() - st_epoch - g_time, '\n=========================')
        if(self.opt_parser.test_speed):
            exit(0)
        if(epoch % self.opt_parser.ckpt_epoch_freq == 0):
            self.__save_model__('{:02d}'.format(epoch), epoch)


    def __save_model__(self, save_type, epoch):
        try:
            os.makedirs(os.path.join(self.opt_parser.ckpt_dir, self.opt_parser.name))
        except:
            pass
        if (self.opt_parser.write):
            torch.save({
            'G': self.G.state_dict(),
            'opt': self.optimizer,
            'epoch': epoch
        }, os.path.join(self.opt_parser.ckpt_dir, self.opt_parser.name, 'ckpt_{}.pth'.format(save_type)))

    def train(self):
        for epoch in range(self.opt_parser.nepoch):
            self.__train_pass__(epoch, is_training=True)

    def test(self):
        """Render comparison videos for up to 51 randomly drawn test samples.

        For each sample the frames are processed in chunks of 16 (a trailing partial chunk
        is dropped): landmarks are predicted on the target frames, drawn on a white canvas,
        and fed with the source frames to the generator. Each output frame is a 1024x256
        strip [source | generated | landmarks | target]. The raw video ``tmp_XXXX.mp4`` is
        re-encoded by ffmpeg into ``random_XXXX.mp4`` in the working directory and then removed.

        Requires ``self.fa_model``, i.e. an instance built with ``single_test=False``.
        """
        if (self.opt_parser.use_vox_dataset == 'raw'):
            if(self.opt_parser.add_audio_in):
                from src.dataset.image_translation.image_translation_dataset import \
                    image_translation_raw98_with_audio_test_dataset as image_translation_test_dataset
            else:
                from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_test_dataset as image_translation_test_dataset
        else:
            from src.dataset.image_translation.image_translation_dataset import image_translation_preprocessed98_test_dataset as image_translation_test_dataset
        self.dataset = image_translation_test_dataset(num_frames=self.opt_parser.num_frames)
        self.dataloader = torch.utils.data.DataLoader(self.dataset,
                                                      batch_size=1,
                                                      shuffle=True,
                                                      num_workers=self.opt_parser.num_workers)

        self.G.eval()
        for i, batch in enumerate(self.dataloader):
            print(i, 50)
            if (i > 50):
                break

            if (self.opt_parser.add_audio_in):
                image_in, image_out, audio_in = batch
                audio_in = audio_in.reshape(-1, 1, 256, 256).to(device)
            else:
                image_in, image_out = batch

            # # online landmark (AwingNet)
            with torch.no_grad():
                image_in, image_out = \
                    image_in.reshape(-1, 3, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)

                pred_landmarks = []
                for j in range(image_in.shape[0] // 16):
                    inputs = image_out[j*16:j*16+16]
                    outputs, boundary_channels = self.fa_model(inputs)
                    pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
                    pred_landmark, _ = get_preds_fromhm(pred_heatmap)
                    pred_landmarks.append(pred_landmark.numpy() * 4)
                pred_landmarks = np.concatenate(pred_landmarks, axis=0)

            # draw landmark on white bg
            img_fls = []
            for pred_fl in pred_landmarks:
                img_fl = np.ones(shape=(256, 256, 3)) * 255.0
                img_fl = vis_landmark_on_img98(img_fl, pred_fl)  # 98x2
                img_fls.append(img_fl.transpose((2, 0, 1)))
            img_fls = np.stack(img_fls, axis=0).astype(np.float32) / 255.0
            image_fls_in = torch.tensor(img_fls, requires_grad=False).to(device)

            # Landmarks exist only for whole chunks of 16, so trim the frames to the same count.
            if (self.opt_parser.add_audio_in):
                # print(image_fls_in.shape, image_in.shape, audio_in.shape)
                image_in = torch.cat([image_fls_in,
                                      image_in[0:image_fls_in.shape[0]],
                                      audio_in[0:image_fls_in.shape[0]]], dim=1)
            else:
                image_in = torch.cat([image_fls_in, image_in[0:image_fls_in.shape[0]]], dim=1)

            # normal 68 test dataset
            # image_in, image_out = image_in.reshape(-1, 6, 256, 256), image_out.reshape(-1, 3, 256, 256)

            # random single frame
            # cv2.imwrite('random_img_{}.jpg'.format(i), np.swapaxes(image_out[5].numpy(),0, 2)*255.0)

            image_in, image_out = image_in.to(device), image_out.to(device)

            writer = cv2.VideoWriter('tmp_{:04d}.mp4'.format(i), cv2.VideoWriter_fourcc(*'mjpg'), 25, (256*4, 256))

            for j in range(image_in.shape[0] // 16):
                # Inference only: skip autograd bookkeeping (the output is detached right after anyway).
                with torch.no_grad():
                    g_out = self.G(image_in[j*16:j*16+16])
                    g_out = torch.tanh(g_out)

                # norm 68 pts
                # g_out = np.swapaxes(g_out.cpu().detach().numpy(), 1, 3)
                # ref_out = np.swapaxes(image_out[j*16:j*16+16].cpu().detach().numpy(), 1, 3)
                # ref_in = np.swapaxes(image_in[j*16:j*16+16, 3:6, :, :].cpu().detach().numpy(), 1, 3)
                # fls_in = np.swapaxes(image_in[j * 16:j * 16 + 16, 0:3, :, :].cpu().detach().numpy(), 1, 3)
                g_out = g_out.cpu().detach().numpy().transpose((0, 2, 3, 1))
                # tanh output lies in [-1, 1]; clamp negatives before scaling to 0..255
                g_out[g_out < 0] = 0
                ref_out = image_out[j * 16:j * 16 + 16].cpu().detach().numpy().transpose((0, 2, 3, 1))
                ref_in = image_in[j * 16:j * 16 + 16, 3:6, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
                fls_in = image_in[j * 16:j * 16 + 16, 0:3, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))

                for k in range(g_out.shape[0]):
                    frame = np.concatenate((ref_in[k], g_out[k], fls_in[k], ref_out[k]), axis=1) * 255.0
                    writer.write(frame.astype(np.uint8))

            writer.release()

            os.system('ffmpeg -y -i tmp_{:04d}.mp4 -pix_fmt yuv420p random_{:04d}.mp4'.format(i, i))
            # Portable replacement for shelling out to `rm`; tolerate a missing file as `rm` effectively did.
            tmp_video = 'tmp_{:04d}.mp4'.format(i)
            if os.path.exists(tmp_video):
                os.remove(tmp_video)


    def single_test(self, jpg=None, fls=None, filename=None, prefix='', grey_only=False):
        import time
        st = time.time()
        self.G.eval()

        if(jpg is None):
            jpg = glob.glob1(self.opt_parser.single_test, '*.jpg')[0]
            jpg = cv2.imread(os.path.join(self.opt_parser.single_test, jpg))

        if(fls is None):
            fls = glob.glob1(self.opt_parser.single_test, '*.txt')[0]
            fls = np.loadtxt(os.path.join(self.opt_parser.single_test, fls))
            fls = fls * 95
            fls[:, 0::3] += 130
            fls[:, 1::3] += 80

        # writer = cv2.VideoWriter('out.mp4', cv2.VideoWriter_fourcc(*'mjpg'), 62.5, (540 * 1, 624))

        # for i, frame in enumerate(fls):

        #     img_fl = np.ones(shape=(256, 256, 3)) * 255
        #     fl = frame.astype(int)
        #     img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
        #     frame = np.concatenate((img_fl, jpg), axis=2).astype(np.float32)/255.0

        #     image_in, image_out = frame.transpose((2, 0, 1)), np.zeros(shape=(3, 256, 256))
        #     # image_in, image_out = frame.transpose((2, 1, 0)), np.zeros(shape=(3, 256, 256))
        #     image_in, image_out = torch.tensor(image_in, requires_grad=False), \
        #                           torch.tensor(image_out, requires_grad=False)

        #     image_in, image_out = image_in.reshape(-1, 6, 256, 256), image_out.reshape(-1, 3, 256, 256)
        #     image_in, image_out = image_in.to(device), image_out.to(device)

        #     g_out = self.G(image_in)
        #     g_out = torch.tanh(g_out)

        #     g_out = g_out.cpu().detach().numpy().transpose((0, 2, 3, 1))
        #     g_out[g_out < 0] = 0
        #     ref_in = image_in[:, 3:6, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
        #     fls_in = image_in[:, 0:3, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
        #     # g_out = g_out.cpu().detach().numpy().transpose((0, 3, 2, 1))
        #     # g_out[g_out < 0] = 0
        #     # ref_in = image_in[:, 3:6, :, :].cpu().detach().numpy().transpose((0, 3, 2, 1))
        #     # fls_in = image_in[:, 0:3, :, :].cpu().detach().numpy().transpose((0, 3, 2, 1))

        #     if(grey_only):
        #         g_out_grey =np.mean(g_out, axis=3, keepdims=True)
        #         g_out[:, :, :, 0:1] = g_out[:, :, :, 1:2] = g_out[:, :, :, 2:3] = g_out_grey


        #     for i in range(g_out.shape[0]):
        #         # frame = np.concatenate((ref_out[i], g_out[i], fls_in[i]), axis=1) * 255.0
        #         frame = g_out[i] * 255.0

        #         # import numpy as np
        #         from PIL import Image

        #         image = cv2.imread("/content/MakeItTalk/examples/girl_0.jpg")
        #         # numpy_array = np.random.rand(256, 256, 3)
        #         image_pil = Image.fromarray(image.astype(np.uint8))
        #         pil_image = Image.fromarray(frame.astype(np.uint8))

        #         # new_image_pil = Image.new('RGB', (540, 624), (255, 255, 255))
        #         # new_image_pil.paste(image_pil, (0, 0))

        #         # image_pil = Image.fromarray(image)
        #         image_pil.paste(pil_image, (face_locations[0][3]-50, face_locations[0][0]-50))

        #         # pil_image = Image.open('example.jpg')
        #         numpy_array = np.array(image_pil)
                
        #         writer.write(frame.astype(np.uint8))

        #         # img = cv2.imread("/content/MakeItTalk/examples/girl_0.jpg")
        #         # writer.write(img)

        # writer.release()
        # print('Time - only video:', time.time() - st)

        # if(filename is None):
        #     filename = 'v'
        # os.system('ffmpeg -loglevel error -y -i out.mp4 -i {} -pix_fmt yuv420p -strict -2 examples/{}_{}.mp4'.format(
        #     'examples/'+filename[9:-16]+'.wav',
        #     prefix, filename[:-4]))
        # # os.system('rm out.mp4')

        # print('Time - ffmpeg add audio:', time.time() - st)






        writer = cv2.VideoWriter('out.mp4', cv2.VideoWriter_fourcc(*'mjpg'), 62.5, (440 * 1, 509))

        for i, frame in enumerate(fls):

            img_fl = np.ones(shape=(256, 256, 3)) * 255
            fl = frame.astype(int)
            img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
            frame = np.concatenate((img_fl, jpg), axis=2).astype(np.float32)/255.0

            image_in, image_out = frame.transpose((2, 0, 1)), np.zeros(shape=(3, 256, 256))
            # image_in, image_out = frame.transpose((2, 1, 0)), np.zeros(shape=(3, 256, 256))
            image_in, image_out = torch.tensor(image_in, requires_grad=False), \
                                  torch.tensor(image_out, requires_grad=False)

            image_in, image_out = image_in.reshape(-1, 6, 256, 256), image_out.reshape(-1, 3, 256, 256)
            image_in, image_out = image_in.to(device), image_out.to(device)

            g_out = self.G(image_in)
            g_out = torch.tanh(g_out)

            g_out = g_out.cpu().detach().numpy().transpose((0, 2, 3, 1))
            g_out[g_out < 0] = 0
            ref_in = image_in[:, 3:6, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
            fls_in = image_in[:, 0:3, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
            # g_out = g_out.cpu().detach().numpy().transpose((0, 3, 2, 1))
            # g_out[g_out < 0] = 0
            # ref_in = image_in[:, 3:6, :, :].cpu().detach().numpy().transpose((0, 3, 2, 1))
            # fls_in = image_in[:, 0:3, :, :].cpu().detach().numpy().transpose((0, 3, 2, 1))

            if(grey_only):
                g_out_grey =np.mean(g_out, axis=3, keepdims=True)
                g_out[:, :, :, 0:1] = g_out[:, :, :, 1:2] = g_out[:, :, :, 2:3] = g_out_grey


            for i in range(g_out.shape[0]):
                # # frame = np.concatenate((ref_in[i], g_out[i], fls_in[i]), axis=1) * 255.0
                # frame = g_out[i] * 255.0
                # writer.write(frame.astype(np.uint8))


                # frame = np.concatenate((ref_out[i], g_out[i], fls_in[i]), axis=1) * 255.0
                frame = g_out[i] * 255.0

                # import numpy as np
                from PIL import Image

                image = cv2.imread("/content/MakeItTalk/examples/girl_0.jpg")
                # numpy_array = np.random.rand(256, 256, 3)
                image_pil = Image.fromarray(image.astype(np.uint8))
                pil_image = Image.fromarray(frame.astype(np.uint8))

                # new_image_pil = Image.new('RGB', (540, 624), (255, 255, 255))
                # new_image_pil.paste(image_pil, (0, 0))

                # image_pil = Image.fromarray(image)
                image_pil.paste(pil_image, (face_locations[0][3]-50, face_locations[0][0]-50))

                # pil_image = Image.open('example.jpg')
                numpy_array = np.array(image_pil)
                
                writer.write(numpy_array.astype(np.uint8))







        writer.release()
        print('Time - only video:', time.time() - st)

        if(filename is None):
            filename = 'v'
        os.system('ffmpeg -loglevel error -y -i out.mp4 -i {} -pix_fmt yuv420p -strict -2 examples/{}_{}.mp4'.format(
            'examples/'+filename[9:-16]+'.wav',
            prefix, filename[:-4]))
        # os.system('rm out.mp4')

        print('Time - ffmpeg add audio:', time.time() - st)







In [93]:
# Inspect the source portrait with the `identify` CLI (ImageMagick) to check its format and pixel dimensions.
!identify /content/MakeItTalk/examples/girl_0.jpg

## Step 4: Audio-to-Landmarks prediction

In [75]:
import face_recognition
from PIL import Image

# Crop geometry: keep FACE_MARGIN px above/left of the detected face box and cut a
# FACE_CROP_SIZE x FACE_CROP_SIZE window (identical to the previous -50 / +206 offsets).
FACE_MARGIN = 50
FACE_CROP_SIZE = 256

# Load the source image as a numpy array.
image = face_recognition.load_image_file("/content/MakeItTalk/examples/girl_0.jpg")

# Locate the faces; face_recognition reports each box as (top, right, bottom, left).
face_locations = face_recognition.face_locations(image)
if not face_locations:
    # Fail with a clear message instead of an opaque IndexError on face_locations[0].
    raise ValueError("No face detected in /content/MakeItTalk/examples/girl_0.jpg")

# Only the first detected face is used.
face_top, face_right, face_bottom, face_left = face_locations[0]

# Crop the padded face window out of the full image.
crop_left = face_left - FACE_MARGIN
crop_top = face_top - FACE_MARGIN
face_image = Image.fromarray(image).crop(
    (crop_left, crop_top, crop_left + FACE_CROP_SIZE, crop_top + FACE_CROP_SIZE)
)
print(face_left, face_top, face_right, face_bottom)
# Save the cropped face image.
face_image.save("/content/MakeItTalk/examples/girl_detail.jpg")

# Paste the cropped face back onto a copy of the original image.
# NOTE(review): the crop origin is (left - 50, top - 50) but the paste target is
# (left, top), so the patch lands 50 px down/right of where it was cut — confirm
# this offset is intended.
image_pil = Image.fromarray(image)
image_pil.paste(face_image, (face_left, face_top))

# Save the composited image.
image_pil.save("/content/MakeItTalk/examples/girl_1.jpg")

In [76]:
!pwd
model = Audio2landmark_model(opt_parser, jpg_shape=shape_3d)
if(len(opt_parser.reuse_train_emb_list) == 0):
    model.test(au_emb=au_emb)
else:
    model.test(au_emb=None)

In [27]:
# Inspect the source portrait with the `identify` CLI (ImageMagick) to check its format and pixel dimensions.
!identify /content/MakeItTalk/examples/girl_0.jpg

In [28]:
import numpy as np
from PIL import Image

# Scratch check: turn a random 256x256x3 float array in [0, 1) into an 8-bit PIL image.
numpy_array = np.random.rand(256, 256, 3)
pil_image = Image.fromarray(
    (255 * numpy_array).astype(np.uint8)
)

## Step 5: Natural face animation via Image-to-image translation 

In [95]:
# Names of every predicted-landmark file in examples/, in sorted order.
# glob.glob() replaces the undocumented glob.glob1() helper, which is deprecated
# in recent Python releases; basename() keeps the bare file names glob1 returned.
fls = sorted(
    os.path.basename(path)
    for path in glob.glob(os.path.join('examples', 'pred_fls_*.txt'))
)

for fl_name in fls:
    # Each file is reshaped to (frames, 68, 3).
    fl = np.loadtxt(os.path.join('examples', fl_name)).reshape((-1, 68, 3))
    # Negate the first two coordinates, then apply the notebook-level `scale` and `shift`.
    fl[:, :, 0:2] = -fl[:, :, 0:2]
    fl[:, :, 0:2] = fl[:, :, 0:2] / scale - shift

    if ADD_NAIVE_EYE:
        fl = util.add_naive_eye(fl)

    # Additional smoothing along the frame axis on the flattened (frames, 204) array:
    # window 15 for the first 48 landmarks, window 5 for the remaining 20, both cubic.
    fl = fl.reshape((-1, 204))
    fl[:, :48 * 3] = savgol_filter(fl[:, :48 * 3], 15, 3, axis=0)
    fl[:, 48 * 3:] = savgol_filter(fl[:, 48 * 3:], 5, 3, axis=0)
    fl = fl.reshape((-1, 68, 3))

    # Image-to-image translation for this landmark sequence.
    # NOTE(review): the model is rebuilt for every file; it could probably be
    # constructed once before the loop if the constructor is stateless — confirm.
    model = Image_translation_block1(opt_parser, single_test=True)
    with torch.no_grad():
        model.single_test(jpg=img, fls=fl, filename=fl_name, prefix=opt_parser.jpg.split('.')[0])
        print('finish image2image gen')
    # os.remove(os.path.join('examples', fl_name))

## Visualize your animation!

In [96]:
from IPython.display import HTML
from base64 import b64encode

# For every input audio file, load the matching rendered clip from examples/ and show it inline.
for ain in ains:
  # Clip name: <portrait stem>_pred_fls_<audio stem>_audio_embed.mp4
  OUTPUT_MP4_NAME = '{}_pred_fls_{}_audio_embed.mp4'.format(
    opt_parser.jpg.split('.')[0],
    ain.split('.')[0]
    )
  # Read through a context manager so the file handle is closed right away
  # (the previous bare open(...).read() left it to the garbage collector).
  with open('examples/{}'.format(OUTPUT_MP4_NAME), 'rb') as mp4_file:
    mp4 = mp4_file.read()
  # Embed the raw bytes as a base64 data URL so the <video> tag needs no file server.
  data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

  print('Display animation: examples/{}'.format(OUTPUT_MP4_NAME))
  display(HTML("""
  <video width=300 controls>
        <source src="%s" type="video/mp4">
  </video>
  """ % data_url))

In [None]:
!pwd
model = Audio2landmark_model(opt_parser, jpg_shape=shape_3d)
if(len(opt_parser.reuse_train_emb_list) == 0):
    model.test(au_emb=au_emb)
else:
    model.test(au_emb=None)