# V-AURA Demo

Install the V-AURA environment (see ./README.md) and run the notebook below to generate samples using the V-AURA model.

### Imports

In [None]:
from pathlib import Path
import warnings
from math import ceil
import time

import torch
from tqdm import tqdm
from omegaconf import OmegaConf

from models.vaura_model import VAURAModel
from utils.train_utils import get_datamodule_from_type, get_curr_time_w_random_shift
from utils.demo_utils import resolve_ckpt_demo, resolve_hparams_demo
from models.data.vggsound_dataset import EPS as EPS_VGGSOUND
from scripts.generate import save_results, get_original_data

### Parameters

```EXPERIMENT_DIR```: Path to the experiment directory. If it is not provided or does not exist, the default experiment will be downloaded.

```AVCLIP_CKPT```: If you have downloaded the Segment AVCLIP checkpoint for VGGSound, provide the path to it. If you have not downloaded the checkpoint, the code will download it for you.

```DURATION```: Duration (in seconds) of the audio to be generated. Must not exceed the duration of the conditioning video and must be a multiple of 0.64.

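```STRIDE```: When the context length of the model is exceeded (```DURATION``` > 2.56 seconds), the stride defines how far the generation window advances for each new chunk, and therefore how much of the old audio is preserved in the prompt (for a full 2.56-second chunk, 2.56 - ```STRIDE``` = amount of preserved audio) when generating audio past the 2.56 second mark. Must be a multiple of 0.64 and smaller than 2.56.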

```DEVICE```: Device to run the model on. CUDA is recommended.

```TEMP```: Temperature to be used during sampling. 0.95 - 1.0 is recommended.

```TOP_K```: Top_k amount for top_k sampling. 128 is recommended.

```CFG_SCALE```: Classifier-Free Guidance scale. 6.0 is recommended.

In [None]:
EXPERIMENT_DIR = "./logs/24-08-01T08-34-26"
AVCLIP_CKPT = "./segment_avclip/vggsound/best.pt"
OUTPUT_DIR = "./demo_output"

DURATION = 2.56  # n * 0.64s
STRIDE = 1.28   # n * 0.64s && < 2.56
assert DURATION % 0.64 == 0
assert STRIDE % 0.64 == 0

DEVICE = "cuda"
TEMP = 0.95
TOP_K = 128
CFG_SCALE = 6.0


In [None]:
# Create a fresh output directory for this run, then resolve the checkpoint
# and hparams files (per the helpers' purpose, downloading them if needed).
run_dir_name = f"generated_samples_{get_curr_time_w_random_shift()}"
output_path = Path(OUTPUT_DIR) / run_dir_name
output_path.mkdir(parents=True, exist_ok=True)
checkpoint_path = resolve_ckpt_demo(EXPERIMENT_DIR)
hparams_path = resolve_hparams_demo(checkpoint_path, AVCLIP_CKPT)

# Report what will be used, one item per line.
print(
    f"Using checkpoint: {checkpoint_path}",
    f"Using hparams: {hparams_path}",
    f"Using output path: {output_path}",
    sep="\n",
)

In [None]:
# Restore the model from the checkpoint onto DEVICE. Every warning raised
# while loading is silenced; the filter is undone when the `with` block exits.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    model = VAURAModel.load_from_checkpoint(
        checkpoint_path,
        map_location=DEVICE,
        hparams_file=hparams_path,
    )
# Put the model into evaluation mode (kept as the last expression of the cell).
model.eval()

In [None]:
# Build the test dataloader for the demo data.
dataloader_cfg_file = "./data/demo/dataloader_config.yaml"
dl_cfg = OmegaConf.load(dataloader_cfg_file)
# Override the sample duration first, then resolve the config's
# interpolations in place so that dependent duration values pick it up.
dl_cfg["sample_duration"] = DURATION
OmegaConf.resolve(dl_cfg)

datamodule_type = "motionformer_gen"
datamodule = get_datamodule_from_type(datamodule_type, dl_cfg)
datamodule.setup("test")
dataloader = datamodule.test_dataloader()

In [None]:
# Translate the user-facing durations (seconds) into audio-token counts.
MODEL_MAX_DURATION = 2.56  # do not modify
COMPRESSION_MODEL_FRAME_RATE = 86  # do not modify

# The stride is only checked when chunked generation will be used
# (DURATION longer than one model window).
assert (not DURATION > MODEL_MAX_DURATION) or STRIDE < MODEL_MAX_DURATION

# Both lengths are truncated towards zero to a whole number of tokens.
total_gen_len, stride_tokens = (
    int(seconds * COMPRESSION_MODEL_FRAME_RATE) for seconds in (DURATION, STRIDE)
)
model.sampler.audio_tokens_per_video_frame = 7

In [None]:
# Generate: for every batch of the dataloader, sample audio conditioned on the
# video frames (in one pass, or chunk by chunk when DURATION exceeds the model
# window), then save each generated audio together with the original frames.
for sample in tqdm(dataloader):
    assert sample["meta"]["duration"] >= DURATION, "Sample duration can not exceed conditional video duration"
    frames = sample["frames"].to(DEVICE)
    current_gen_offset: int = 0  # start of the current chunk, in audio tokens
    prompt_length: int = 0  # number of tokens carried over as the prompt
    all_tokens = []  # newly generated tokens of every chunk, in order
    prompt_tokens = None  # the first chunk is generated without an audio prompt

    # get original data without transformations
    original_frames, _ = get_original_data(
        sample["meta"],
        0.0,
        EPS_VGGSOUND,
        DURATION,
        0,
    )

    start_time = time.time()
    # Inference only: make sure no autograd graph is built while sampling and
    # decoding (the decode call below is not otherwise guarded here).
    with torch.no_grad():
        if DURATION <= MODEL_MAX_DURATION:  # single chunk generation
            item = model.generate(
                frames=frames,
                audio=prompt_tokens,
                max_new_tokens=total_gen_len,
                return_sampled_indices=True,
                use_sampling=True,
                temp=TEMP,
                top_k=TOP_K,
                cfg_scale=CFG_SCALE,
                remove_prompts=False,
                prompt_is_encoded=True,
            )
            generated_audios = item["generated_audio"]

        else:  # chunked generation
            # Slide a window of at most MODEL_MAX_DURATION seconds over the
            # target duration, advancing by `stride_tokens` each iteration.
            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / COMPRESSION_MODEL_FRAME_RATE
                chunk_duration = min(DURATION - time_offset, MODEL_MAX_DURATION)
                max_gen_len = ceil(chunk_duration * COMPRESSION_MODEL_FRAME_RATE)

                # Figure out the frames to use
                initial_position = ceil(time_offset * sample["meta"]["video_fps"])
                video_target_length = ceil(chunk_duration * sample["meta"]["video_fps"])
                # NOTE(review): 16 is presumably the number of video frames per
                # entry along dim 1 of `frames` — confirm against the dataset.
                positions = torch.arange(
                    initial_position // 16,
                    (initial_position + video_target_length) // 16,
                    device=DEVICE,
                )
                # The modulo wraps indices that run past the end of dim 1.
                selected_frames = frames[:, positions % frames.shape[1], ...]

                item = model.generate(
                    frames=selected_frames,
                    audio=prompt_tokens,
                    max_new_tokens=max_gen_len,
                    return_sampled_indices=True,
                    use_sampling=True,
                    temp=TEMP,
                    top_k=TOP_K,
                    cfg_scale=CFG_SCALE,
                    remove_prompts=False,
                    prompt_is_encoded=True,
                )
                gen_tokens = item["sampled_indices"]
                if prompt_tokens is None:
                    all_tokens.append(gen_tokens)
                else:
                    # Prompts are not removed from the output, so drop the
                    # leading prompt part and keep only the new tokens.
                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1] :])
                # Everything after the first `stride_tokens` tokens of this
                # chunk becomes the prompt of the next chunk.
                prompt_tokens = gen_tokens[:, :, stride_tokens:]
                prompt_length = prompt_tokens.shape[-1]
                current_gen_offset += stride_tokens

            # Gather outputs
            gen_tokens = torch.cat(all_tokens, dim=-1)
            # NOTE(review): the decoder appears to take a list of
            # (codes, scale) pairs; scale is None here — confirm.
            sampled_frames = [(gen_tokens[..., : model.num_codebooks, :], None)]
            generated_audios = model.audio_encoder.decode(sampled_frames)

    end_time = time.time()
    print(f"Generation took {end_time - start_time:.2f}s")

    # Save results
    # Use a distinct loop variable: reusing the name `generated_audios` here
    # would rebind the batch to a single item, and indexing that item with `i`
    # again would select along the wrong dimension (or go out of range).
    # NOTE(review): save_results is assumed to accept the per-sample tensor
    # obtained by indexing the batch dimension — confirm.
    for i, generated_audio in enumerate(generated_audios):
        save_results(
            generated_audio,
            original_frames[i],
            output_path,
            Path(sample["meta"]["filepath"][i]).name,
            sample["meta"]["video_fps"][i].item(),
            44100,
            sample["meta"]["audio_fps"][i].item(),
        )