In [None]:
# Optional: uncomment to work inside a custom directory (e.g. a mounted Google Drive folder).
# %cd /content/drive/MyDrive/colab/gen

## Note:
The following steps include:
- Setup Environment  
    - ***If pip reports a numpy 2.x dependency conflict during installation, you can ignore it.***
- Restart Colab Runtime  
    - ***Important: the runtime must be restarted for the newly installed packages to take effect.***
- Prepare Code and Models  
- Inference and Display  

## Setup Environment

### show gpu info

In [None]:
# Print the GPU name plus total and free memory as one CSV line per GPU (no header row)
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader
# Show the current working directory and its contents
!pwd
!ls

### check torch

In [None]:
# Display the torch version that is available before any packages are installed
import torch
torch.__version__

### install packages

In [None]:
# about 0.5~1min
!pip install tensorrt==8.6.1 librosa tqdm filetype imageio opencv_python_headless scikit-image cython cuda-python imageio-ffmpeg colored polygraphy numpy==2.0.1

### restart runtime

In [None]:
# !!!
# You need to restart the runtime to ensure that the newly installed environment takes effect.
# This cell does so by sending signal 9 to the kernel's own process; the runtime is
# expected to come back as a fresh session, after which you continue with the next cell.
# !!!
import os
os.kill(os.getpid(), 9)  # terminate the current kernel process (never returns)

In [None]:
# Optional: uncomment to switch back to a custom working directory after the restart.
# %cd /content/drive/MyDrive/colab/gen
# List the current directory to confirm where the notebook is running
!ls

### check environment

In [None]:
import numpy as np
import torch
import tensorrt as trt

# Print the version of each core dependency (numpy, torch, tensorrt, in that
# order) so the restarted runtime can be checked against the installed packages.
for module in (np, torch, trt):
    print(module.__version__)

## Prepare Code and Models

### prepare code

In [None]:
# about 2s
import os
if not os.path.isdir("ditto-talkinghead"):
    !git clone https://github.com/antgroup/ditto-talkinghead.git

%cd ditto-talkinghead
!git pull
!ls

### prepare model

In [None]:
# about 1~2min
# NOTE: this cell uses `os`, which is imported by the "prepare code" cell above.
!git lfs install
# Clone the model repository into ./checkpoints only if it is not there yet
if not os.path.isdir("checkpoints"):
    !git clone https://huggingface.co/digital-avatar/ditto-talkinghead checkpoints

# Enter the checkpoints directory, update it and list its contents
%cd checkpoints
!git pull
!ls

# Go back to the parent directory and list it
%cd ..
!ls

### check GPU architecture

In [None]:
# about 1~2min
import os
import torch

def cvt_custom_trt():
    """Run the repository's ONNX-to-TensorRT conversion script locally.

    Calls ``scripts.cvt_onnx_to_trt.main`` with the ONNX checkpoint directory,
    the output directory and the path of the grid-sample plugin library that
    sits inside the ONNX directory. The output directory is created if needed.

    Returns:
        str: the output directory, "./checkpoints/ditto_trt_custom".

    Raises:
        AssertionError: if "./checkpoints/ditto_onnx" is not a directory.
    """
    # Imported lazily: the module only resolves from inside the repository checkout.
    from scripts.cvt_onnx_to_trt import main as convert_onnx_to_trt

    src_dir = "./checkpoints/ditto_onnx"
    dst_dir = "./checkpoints/ditto_trt_custom"

    assert os.path.isdir(src_dir)
    os.makedirs(dst_dir, exist_ok=True)

    convert_onnx_to_trt(
        src_dir,
        dst_dir,
        os.path.join(src_dir, "libgrid_sample_3d_plugin.so"),
    )
    return dst_dir


def download_Non_Ampere_trt():
    !pip install --upgrade --no-cache-dir gdown
    !gdown https://drive.google.com/drive/folders/1-1qnqy0D9ICgRh8iNY_22j9ieNRC0-zf?usp=sharing -O ./checkpoints/ditto_trt --folder
    trt_dir = "./checkpoints/ditto_trt"
    return trt_dir


# Pick the model directory from the major compute capability of the current GPU:
# 8 and above uses the bundled "Ampere_Plus" directory, anything lower needs
# separately prepared files.
gpu_major = torch.cuda.get_device_capability()[0]
if gpu_major >= 8:
    data_root = "./checkpoints/ditto_trt_Ampere_Plus"
else:
    # data_root = cvt_custom_trt()    # cvt
    # The conversion is slow, so you can download pre-converted files.
    data_root = download_Non_Ampere_trt()

## Inference

### run inference

In [None]:
# init, about 10s
from inference import StreamSDK, run

# `data_root` comes from the GPU-architecture cell above; uncomment to override it.
# data_root = "./checkpoints/ditto_trt_custom"   # model dir
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"     # cfg pkl

# Echo the model directory and the config file, one per line, before building the SDK
print(data_root, cfg_pkl, sep="\n")
SDK = StreamSDK(cfg_pkl, data_root)

In [None]:
# run inference, about 1~2min
import os

audio_path = "./example/audio.wav"    # .wav
source_path = "./example/image.png"   # video|image
output_path = "./tmp/result.mp4"    # .mp4

# Make sure the output directory exists so writing the result cannot fail on a
# missing folder (a no-op when it is already there).
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

run(SDK, audio_path, source_path, output_path)

### display result

In [None]:
# display, about 5~10s
from IPython.display import HTML, display
from base64 import b64encode
import os, sys
import glob

# `output_path` is set by the "run inference" cell
mp4_name = output_path

# Read the video through a context manager so the file handle is closed right away
with open(mp4_name, 'rb') as video_file:
    mp4 = video_file.read()
# Embed the whole file as a base64 data URL so the player needs no separate file access
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

print('Display animation: {}'.format(mp4_name), file=sys.stderr)
display(HTML("""
  <video width=256 controls>
        <source src="%s" type="video/mp4">
  </video>
  """ % data_url))