# Synchformer: Efficient Synchronization from Sparse Cues

<figure>
  <img src="https://github.com/v-iashin/Synchformer/raw/main/_repo_assets/main.png" width="700" />
</figure>

This notebook demonstrates a minimal working example of audio-visual synchronisation on a sample video with a sparse synchronisation signal.

[Project Page](https://www.robots.ox.ac.uk/~vgg/research/synchformer/) | [Code & Models](https://github.com/v-iashin/Synchformer)

Uncomment the lines in the following cell if you are on Google Colab

In [1]:
# !git clone https://github.com/v-iashin/Synchformer.git
# !pip install omegaconf==2.0.6 av==10.0 einops timm==0.6.12
# %cd Synchformer

In [2]:
import subprocess
from pathlib import Path

import torch
import torchaudio
import torchvision
from omegaconf import OmegaConf

from dataset.dataset_utils import get_video_and_audio
from dataset.transforms import make_class_grid, quantize_offset
from utils.utils import check_if_file_exists_else_download, which_ffmpeg
from scripts.train_utils import get_model, get_transforms, prepare_inputs


def reencode_video(path, vfps=25, afps=16000, in_size=256):
    assert which_ffmpeg() != '', 'Is ffmpeg installed? Check if the conda environment is activated.'
    new_path = Path.cwd() / 'vis' / f'{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4'
    new_path.parent.mkdir(exist_ok=True)
    new_path = str(new_path)
    cmd = f'{which_ffmpeg()}'
    # no info/error printing
    cmd += ' -hide_banner -loglevel panic'
    cmd += f' -y -i {path}'
    # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate
    cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
    cmd += f" -ar {afps}"
    cmd += f' {new_path}'
    subprocess.call(cmd.split())
    cmd = f'{which_ffmpeg()}'
    cmd += ' -hide_banner -loglevel panic'
    cmd += f' -y -i {new_path}'
    cmd += f' -acodec pcm_s16le -ac 1'
    cmd += f' {new_path.replace(".mp4", ".wav")}'
    subprocess.call(cmd.split())
    return new_path


def decode_single_video_prediction(off_logits, grid, item):
    label = item['targets']['offset_label'].item()
    print('Ground Truth offset (sec):', f'{label:.2f} ({quantize_offset(grid, label)[-1].item()})')
    print('Prediction Results:')
    off_probs = torch.softmax(off_logits, dim=-1)

    filename = "./result.csv"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    offsets = torch.arange(-0.2,0.21,0.02).to(device).half()   
    offset_prediction = torch.dot(off_probs.squeeze(), offsets)
    with open(filename, 'a') as f:
        f.write(f"{offset_prediction:.4f},")
        
    k = min(off_probs.shape[-1], 5)
    topk_logits, topk_preds = torch.topk(off_logits, k)
    # remove batch dimension
    assert len(topk_logits) == 1, 'batch is larger than 1'
    topk_logits = topk_logits[0]
    topk_preds = topk_preds[0]
    off_logits = off_logits[0]
    off_probs = off_probs[0]
    temp = 0
    for target_hat in topk_preds:
        print(
            f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})')
        if temp == 0:
            temp = 1
            with open(filename, 'a') as f:
                #f.write(f"{off_probs[target_hat]:.4f},({off_logits[target_hat]:.4f}),{grid[target_hat]:.2f},({target_hat})\n")     
                f.write(f"{off_probs[target_hat]:.4f},({off_logits[target_hat]:.4f}),{grid[target_hat]:.2f},({target_hat})\n")  
    return off_probs


def patch_config(cfg):
    # the FE ckpts are already in the model ckpt
    cfg.model.params.afeat_extractor.params.ckpt_path = None
    cfg.model.params.vfeat_extractor.params.ckpt_path = None
    # old checkpoints have different names
    cfg.model.params.transformer.target = cfg.model.params.transformer.target\
                                             .replace('.modules.feature_selector.', '.sync_model.')
    return cfg


In [3]:
vfps = 25
afps = 16000
in_size = 256
exp_name = '24-12-25T22-20-32'

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the model
cfg_path = f'./logs/sync_models/{exp_name}/cfg-{exp_name}.yaml'
ckpt_path = f'./logs/sync_models/{exp_name}/{exp_name}.pt'

# if the model does not exist try to download it from the server 24-11-30T20-48-52​
check_if_file_exists_else_download(cfg_path)
check_if_file_exists_else_download(ckpt_path)

# load config
cfg = OmegaConf.load(cfg_path)

# patch config
cfg = patch_config(cfg)

_, model = get_model(cfg, device)
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
model.load_state_dict(ckpt['model'])
model.eval()
print('Model loaded.')

  from .autonotebook import tqdm as notebook_tqdm


Model loaded.


In [5]:
'''import os
import random
import pandas as pd
to_process = []
#x = '/home/gnivedita/Synchformer/data/vggsound/h264_video_25fps_256side_16000hz_aac'
x = '/home/gnivedita/Wav2Lip/videos/LRS2_shifted_LGset2offset'
for file in os.listdir(x):
    if file.endswith(".mp4"):
        #v_start_sec = round(random.uniform(2, 3), 2)
        #print(float(file[:-4].split('_')[-1]))      
        #audioset
        audioset_df = pd.read_csv("v_start_sec.csv")
        if file[:-4] in audioset_df['vidname'].values:
            pass
            # Get the corresponding v_start_sec
            # v_start_sec = audioset_df.loc[audioset_df['vidname'] == file[:-4], 'v_start_sec'].values[0]
            # print(file, v_start_sec)
        else:
            #lrs2
            if float(file[:-4].split('_')[-1]) == 2:
                continue
            elif float(file[:-4].split('_')[-1]) > 0:
                v_start_sec = 0
            else:
                v_start_sec = 1
        # offset_in_sync = random.choice([-0.2, -0.1, 0, 0.1, 0.2])
        # offset_out_of_sync = random.choice([-1, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
        # if offset_in_sync > 0:
        #     v_start_sec = 0
        # else: 
        #     v_start_sec = 1
        # to_process.append((x+'/'+file, offset_in_sync, v_start_sec))      
        # if offset_out_of_sync > 0:
        #     v_start_sec = 0
        # else: 
        #     v_start_sec = 1
        # to_process.append((x+'/'+file, offset_out_of_sync, v_start_sec))
        to_process.append((x+'/'+file, 0, v_start_sec))'''

'import os\nimport random\nimport pandas as pd\nto_process = []\n#x = \'/home/gnivedita/Synchformer/data/vggsound/h264_video_25fps_256side_16000hz_aac\'\nx = \'/home/gnivedita/Wav2Lip/videos/LRS2_shifted_LGset2offset\'\nfor file in os.listdir(x):\n    if file.endswith(".mp4"):\n        #v_start_sec = round(random.uniform(2, 3), 2)\n        #print(float(file[:-4].split(\'_\')[-1]))      \n        #audioset\n        audioset_df = pd.read_csv("v_start_sec.csv")\n        if file[:-4] in audioset_df[\'vidname\'].values:\n            pass\n            # Get the corresponding v_start_sec\n            # v_start_sec = audioset_df.loc[audioset_df[\'vidname\'] == file[:-4], \'v_start_sec\'].values[0]\n            # print(file, v_start_sec)\n        else:\n            #lrs2\n            if float(file[:-4].split(\'_\')[-1]) == 2:\n                continue\n            elif float(file[:-4].split(\'_\')[-1]) > 0:\n                v_start_sec = 0\n            else:\n                v_start_sec = 1\n  

In [6]:
import os
import random
import pandas as pd
import numpy as np

x = '/home/gnivedita/Synchformer/data/vggsound/h264_video_25fps_256side_16000hz_aac' 
# Load the CSV file
csv_path = '/home/gnivedita/Synchformer/data/vggsound.csv'  # Replace with the actual path to vggsound.csv
vggsound_data = pd.read_csv(csv_path)

# Extract column 1 values as a set for efficient lookup
file_names_in_csv = set(vggsound_data.iloc[:, 0])  # Assuming column 1 is the first column (0-based index)

# Initialize a list to process
to_process = []
#x = '/home/gnivedita/Synchformer/data_-0.2to0.2/vggsound/h264_video_25fps_256side_16000hz_aac'
# for training videos present in vggsound:
'''for file in os.listdir(x):
    if file.endswith((".mp4",".mov")):
        file_name_no_ext = os.path.splitext(file)[0]
        if file_name_no_ext in file_names_in_csv:
            v_start_sec = 0
            to_process.append((x+'/'+file, 0, v_start_sec))'''

for file in os.listdir(x):
    if file.endswith((".mp4",".mov")):
        v_start_sec = 0.5
        offset_sec = np.round(float(file.split('_')[-2]), 2)
        #offset_sec = 0
        to_process.append((x+'/'+file, offset_sec, v_start_sec)) 

In [7]:
print(len(to_process))

2583


In [8]:
'''# list of items to process. Mind the order: (video_path, offset_sec, v_start_i_sec)
to_process = [
    ('./data/vggsound/h264_video_25fps_256side_16000hz_aac/3qesirWAGt4_20000_30000.mp4', 1.6, 0.0),
    ('./data/vggsound/h264_video_25fps_256side_16000hz_aac/ZYc410CE4Rg_0_10000.mp4', -2.0, 4.0),
]'''

"# list of items to process. Mind the order: (video_path, offset_sec, v_start_i_sec)\nto_process = [\n    ('./data/vggsound/h264_video_25fps_256side_16000hz_aac/3qesirWAGt4_20000_30000.mp4', 1.6, 0.0),\n    ('./data/vggsound/h264_video_25fps_256side_16000hz_aac/ZYc410CE4Rg_0_10000.mp4', -2.0, 4.0),\n]"

In [9]:
for vid_path, offset_sec, v_start_i_sec in to_process:
    # (optional) checking if the provided video has the correct frame rates
    print(f'Using video: {vid_path}')
    with open("./result.csv", 'a') as f:
        f.write(f"{vid_path.split('/')[-1][:-4]},{v_start_i_sec},")

    v, _, info = torchvision.io.read_video(vid_path, pts_unit='sec')
    _, H, W, _ = v.shape
    if info['video_fps'] != vfps or info['audio_fps'] != afps or min(H, W) != in_size:
        print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=' ')
        print(f'afps: {info["audio_fps"]} -> {afps};', end=' ')
        print(f'{(H, W)} -> min(H, W)={in_size}')
        vid_path = reencode_video(vid_path, vfps, afps, in_size)
    else:
        print(
            f'No need to reencode: vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')

    # load visual and audio streams
    # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1]
    try:
        rgb, audio, meta = get_video_and_audio(vid_path, get_meta=True)

    except Exception as e:
        # Print an error message and continue
        print(f"Skipping video: {vid_path} due to error: {e}")
        continue

    # making an item (dict) to apply transformations
    # NOTE: here is how it works:
    # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
    # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
    # track by `offset_sec` seconds. It means that if `offset_sec` > 0, the audio will
    # start by `offset_sec` earlier than the rgb track.
    # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
    item = dict(
        video=rgb, audio=audio, meta=meta, path=vid_path, split='test',
        targets={'v_start_i_sec': v_start_i_sec, 'offset_sec': offset_sec, },
    )
    print(f"offset_sec: {offset_sec}")
    # making the offset class grid similar to the one used in transforms
    max_off_sec = cfg.data.max_off_sec
    num_cls = cfg.model.params.transformer.params.off_head_cfg.params.out_features
    grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
    if not (min(grid) <= item['targets']['offset_sec'] <= max(grid)):
        print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')

    # applying the test-time transform
    item = get_transforms(cfg, ['test'])['test'](item)

    # prepare inputs for inference
    batch = torch.utils.data.default_collate([item])
    aud, vid, targets = prepare_inputs(batch, device)

    # TODO:
    # sanity check: we will take the input to the `model` and recontruct make a video from it.
    # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
    # reconstruct_video_from_input(aud, vid, batch['meta'], vid_path, v_start_i_sec, offset_sec,
    #                              vfps, afps)

    # forward pass
    with torch.set_grad_enabled(False):
        with torch.autocast('cuda', enabled=cfg.training.use_half_precision):
            _, logits = model(vid, aud)

    # simply prints the results of the prediction
    decode_single_video_prediction(logits, grid, item)
    print()

Using video: /home/gnivedita/Synchformer/data/vggsound/h264_video_25fps_256side_16000hz_aac/0447Rl_Gh3g_40.0_50.0_-0.14_4s.mp4
No need to reencode: vfps: 25.0; afps: 16000; min(H, W)=256
offset_sec: -0.14
Current Working Directory (full path): /home/gnivedita/Synchformer
here1: True True
here2
here4:-0.14
nivi: 0.26
nivi: 0.4
Ground Truth offset (sec): -0.14 (3)
Prediction Results:
p=0.0737 (5.1523), "-0.06" (7)
p=0.0682 (5.0742), "-0.04" (8)
p=0.0656 (5.0352), "0.10" (15)
p=0.0640 (5.0117), "-0.08" (6)
p=0.0630 (4.9961), "0.06" (13)

Using video: /home/gnivedita/Synchformer/data/vggsound/h264_video_25fps_256side_16000hz_aac/0J7NMewwdKU_30.0_40.0_-0.14_4s.mp4
No need to reencode: vfps: 25.0; afps: 16000; min(H, W)=256
offset_sec: -0.14
Current Working Directory (full path): /home/gnivedita/Synchformer
here1: True True
here2
here4:-0.14
nivi: 0.26
nivi: 0.4
Ground Truth offset (sec): -0.14 (3)
Prediction Results:
p=0.1259 (6.1367), "-0.02" (9)
p=0.1224 (6.1094), "-0.04" (8)
p=0.0999 (5.

KeyboardInterrupt: 

In [10]:
# !pip freeze