[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/SyncTalk-jupyter/blob/main/SyncTalk_jupyter_test.ipynb)

In [None]:
# Environment setup cell: fetch the code, system packages, demo assets and
# Python dependencies. Uses IPython magics (%cd) and shell escapes (!).

# Clone the `dev` branch of camenduru/SyncTalk into /content and work from it.
%cd /content
!git clone -b dev https://github.com/camenduru/SyncTalk
%cd /content/SyncTalk

# System packages: ALSA / PortAudio libraries, ffmpeg, and aria2 (the
# downloader used just below).
!apt -y install -qq libasound2-dev portaudio19-dev libportaudio2 libportaudiocpp0 ffmpeg aria2

# Download May.zip into data/ and trial_may.zip into model/.
# aria2c flags: -c resume a partial download, -x/-s 16 connections/splits,
# -k 1M minimum split size; only errors are printed.
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/SyncTalk/resolve/main/May.zip -d /content/SyncTalk/data -o May.zip
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/SyncTalk/resolve/main/trial_may.zip -d /content/SyncTalk/model -o trial_may.zip
# Unpack each archive inside the directory it was downloaded to.
%cd /content/SyncTalk/data
!unzip May.zip
%cd /content/SyncTalk/model
!unzip trial_may.zip

# Python dependencies. The wheel URLs are prebuilt cp310 / linux_x86_64
# binaries (pytorch3d and the freqencoder, shencoder, gridencoder and
# raymarching_face extensions).
# NOTE(review): cp310 wheels presumably require a Python 3.10 runtime --
# confirm against the current Colab image.
!pip install -q https://github.com/camenduru/wheels/releases/download/colab2/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl
!pip install -q torch-ema ninja trimesh tensorboardX PyMCubes dearpygui scikit-learn face_alignment python_speech_features
!pip install -q resampy pyaudio einops configargparse lpips onnxruntime-gpu facenet_pytorch fvcore iopath ffmpeg-python
!pip install -q https://github.com/camenduru/wheels/releases/download/colab2/freqencoder-0.0.0-cp310-cp310-linux_x86_64.whl
!pip install -q https://github.com/camenduru/wheels/releases/download/colab2/shencoder-0.0.0-cp310-cp310-linux_x86_64.whl
!pip install -q https://github.com/camenduru/wheels/releases/download/colab2/gridencoder-0.0.0-cp310-cp310-linux_x86_64.whl
!pip install -q https://github.com/camenduru/wheels/releases/download/colab2/raymarching_face-0.0.0-cp310-cp310-linux_x86_64.whl

In [None]:
%cd /content/SyncTalk

from nerf_triplane.provider import NeRFDataset
from nerf_triplane.utils import *
from nerf_triplane.network import NeRFNetwork

# Switch TF32 off for both CUDA matmul and cuDNN. PyTorch builds that do not
# expose these attributes raise AttributeError, which is reported and ignored.
try:
    for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        _tf32_backend.allow_tf32 = False
except AttributeError:
    print('Info. This pytorch version is not support with tf32.')

class Options:
    """Plain attribute container holding the notebook's default settings.

    Every attribute is an ordinary instance attribute initialised to a
    default; later cells override individual values by assignment
    (e.g. ``opt.test = True``) before handing the object to the dataset,
    network and trainer.

    ``repr()`` / ``str()`` list every attribute and its current value, so
    writing ``str(opt)`` to a log file records the actual configuration
    rather than an object address.
    """

    def __init__(self):
        self.path = None
        self.O = False
        self.test = False
        self.test_train = False
        self.data_range = [0, -1]
        self.workspace = 'workspace'
        self.seed = 0

        ### training options
        self.iters = 200000
        self.lr = 1e-2
        self.lr_net = 1e-3
        self.ckpt = 'latest'
        self.num_rays = 4096 * 16
        self.cuda_ray = False
        self.max_steps = 16
        self.num_steps = 16
        self.upsample_steps = 0
        self.update_extra_interval = 16
        self.max_ray_batch = 4096

        ### loss set
        self.warmup_step = 10000
        self.amb_aud_loss = 1
        self.amb_eye_loss = 1
        self.unc_loss = 1
        self.lambda_amb = 1e-1
        self.pyramid_loss = 0

        ### network backbone options
        self.fp16 = False
        self.bg_img = ''
        self.fbg = False
        self.exp_eye = False
        self.fix_eye = -1
        self.smooth_eye = False
        self.bs_area = "upper"
        self.torso_shrink = 0.8

        ### dataset options
        self.color_space = 'srgb'
        self.preload = 0
        self.bound = 1
        self.scale = 4
        self.offset = [0, 0, 0]
        self.dt_gamma = 1/256
        self.min_near = 0.05
        self.density_thresh = 10
        self.density_thresh_torso = 0.01
        self.patch_size = 1

        self.init_lips = False
        self.finetune_lips = False
        self.smooth_lips = False

        self.torso = False
        self.head_ckpt = ''

        ### GUI options
        self.gui = False
        self.W = 450
        self.H = 450
        self.radius = 3.35
        self.fovy = 21.24
        self.max_spp = 1

        ### else
        self.att = 2
        self.aud = ''
        self.emb = False
        self.portrait = False
        self.ind_dim = 4
        self.ind_num = 20000
        self.ind_dim_torso = 8
        self.amb_dim = 2
        self.part = False
        self.part2 = False
        self.train_camera = False
        self.smooth_path = False
        self.smooth_path_window = 7

        self.asr = False
        self.asr_wav = ''
        self.asr_play = False
        self.asr_model = 'deepspeech'
        self.asr_save_feats = False
        self.fps = 50
        self.l = 10
        self.m = 50
        self.r = 10

    def __repr__(self):
        """Return ``Options(name=value, ...)`` covering every attribute."""
        # vars() reflects the live instance, so overrides made after
        # construction show up as well.
        settings = ', '.join(
            f'{name}={value!r}' for name, value in vars(self).items()
        )
        return f'{type(self).__name__}({settings})'


opt = Options()

In [None]:
# Inference cell: render the test audio with a checkpoint loaded from
# opt.workspace, iterating over the training-split dataset.
opt.test = True
opt.test_train = True
opt.portrait = True
opt.aud = '/content/SyncTalk/demo/test.wav'

# Deliberately disabled: `and False` makes this branch unreachable, so the
# smoothing options keep their defaults.
if opt.test and False:
    opt.smooth_path = True
    opt.smooth_eye = True
    opt.smooth_lips = True

# NOTE(review): opt.path is still None at this point and opt.workspace is the
# default 'workspace'. NeRFDataset / Trainer presumably need the extracted
# dataset and checkpoint directories -- confirm and set them before running.

# These three names are used below (and by the training cell) but were never
# defined anywhere in the notebook, which raised NameError. Build them here,
# after the option overrides, so the network sees the final settings.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeRFNetwork(opt)
# reduction='none' keeps per-element losses.
# NOTE(review): presumably what Trainer expects -- confirm against
# nerf_triplane.utils.Trainer.
criterion = torch.nn.MSELoss(reduction='none')

metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')]
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

# Build the dataset with type='train', then clear its `training` flag and set
# num_rays to -1 before creating the loader.
# NOTE(review): presumably this renders whole frames instead of sampled rays
# -- confirm in NeRFDataset.
test_set = NeRFDataset(opt, device=device, type='train')
test_set.training = False
test_set.num_rays = -1
test_loader = test_set.dataloader()

# Hand the loader's audio features and eye-area data to the model.
# NOTE(review): this cell assigns `model.eye_areas` while the training cell
# assigns `model.eye_area` -- verify which attribute the network reads.
model.aud_features = test_loader._data.auds
model.eye_areas = test_loader._data.eye_area
trainer.test(test_loader)

In [None]:
# Training cell: lip fine-tuning run followed by a pass over the test split.
opt.iters = 100000
opt.finetune_lips = True
opt.patch_size = 64


def optimizer(model):
    """Build the AdamW optimizer over the model's parameter groups."""
    return torch.optim.AdamW(
        model.get_params(opt.lr, opt.lr_net),
        betas=(0, 0.99),
        eps=1e-8,
    )


train_loader = NeRFDataset(opt, device=device, type='train').dataloader()
assert len(train_loader) < opt.ind_num, f"[ERROR] dataset too many frames: {len(train_loader)}, please increase --ind_num to this number!"

# Give the model the loader's audio features, eye-area data and poses.
model.aud_features = train_loader._data.auds
model.eye_area = train_loader._data.eye_area
model.poses = train_loader._data.poses


def scheduler(optimizer, decay_base=0.05 if opt.finetune_lips else 0.5):
    """Build a LambdaLR whose factor is decay_base ** (step / opt.iters).

    The base is fixed when this function is defined: 0.05 when fine-tuning
    lips, 0.5 otherwise.
    """
    return optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: decay_base ** (step / opt.iters),
    )


metrics = [
    PSNRMeter(),
    LPIPSMeter(device=device),
    LMDMeter(backend='fan'),
]

# Evaluate roughly every 5000 iterations, but never less often than once
# per epoch.
eval_interval = max(1, 5000 // len(train_loader))
trainer = Trainer(
    'ngp',
    opt,
    model,
    device=device,
    workspace=opt.workspace,
    optimizer=optimizer,
    criterion=criterion,
    ema_decay=0.95,
    fp16=opt.fp16,
    lr_scheduler=scheduler,
    scheduler_update_every_step=True,
    metrics=metrics,
    use_checkpoint=opt.ckpt,
    eval_interval=eval_interval,
)

# Append the textual form of the options to <workspace>/opt.txt.
options_log = os.path.join(opt.workspace, 'opt.txt')
with open(options_log, 'a') as f:
    f.write(str(opt))

valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader()
# Number of epochs needed to reach opt.iters iterations, rounded up.
max_epochs = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
print(f'[INFO] max_epoch = {max_epochs}')
trainer.train(train_loader, valid_loader, max_epochs)

# Drop the training/validation loaders and release cached CUDA memory before
# building the test loader.
del train_loader, valid_loader
torch.cuda.empty_cache()

test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
trainer.test(test_loader)