[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Float-jupyter/blob/main/Float_jupyter.ipynb)

In [None]:
# Notebook setup (IPython magics, not plain Python): shallow-clone ComfyUI and
# the ComfyUI-FLOAT custom node from their `master` branches into /content,
# then pip-install each project's requirements.txt.
# NOTE(review): `git clone` errors if the target directory already exists, so
# re-running this cell will fail on the clone lines.
%cd /content
!git clone --depth 1 --branch master https://github.com/comfyanonymous/ComfyUI /content/ComfyUI
!git clone --depth 1 --branch master https://github.com/yuvraj108c/ComfyUI-FLOAT /content/ComfyUI/custom_nodes/ComfyUI-FLOAT
%cd /content/ComfyUI 
!pip install -r requirements.txt
%cd /content/ComfyUI/custom_nodes/ComfyUI-FLOAT
!pip install -r requirements.txt
# Return to the ComfyUI root (presumably so `from nodes import ...` resolves to
# ComfyUI's top-level modules) and install ffmpeg-python, which provides the
# `ffmpeg` module used for video encoding.
%cd /content/ComfyUI
!pip install ffmpeg-python

import os, json, requests, random, time, cv2, ffmpeg, shutil, subprocess, re
from urllib.parse import urlsplit

import torch
from PIL import Image
import numpy as np

from nodes import NODE_CLASS_MAPPINGS, load_custom_node

# Load the ComfyUI-FLOAT package from its custom_nodes directory. Presumably this
# registers its node classes in NODE_CLASS_MAPPINGS, which the lookups below rely on.
# NOTE(review): recent ComfyUI revisions may define load_custom_node as a coroutine
# (async def); if so, this call must be awaited or nothing gets registered —
# confirm against the cloned revision.
load_custom_node("/content/ComfyUI/custom_nodes/ComfyUI-FLOAT")

# Instantiate the node objects used by generate(): the FLOAT model loader, the
# FLOAT inference node and the image loader.
LoadFloatModels = NODE_CLASS_MAPPINGS["LoadFloatModels"]()
FloatProcess = NODE_CLASS_MAPPINGS["FloatProcess"]()
LoadImage = NODE_CLASS_MAPPINGS["LoadImage"]()

# Load the FLOAT weights once at startup with autograd disabled. Element [0] of
# the returned value is the pipeline later passed to FloatProcess.floatprocess.
with torch.inference_mode():
    float_pipe = LoadFloatModels.loadmodel("float.pth")[0]

def download_file(url, save_dir, file_name, timeout=60, chunk_size=1 << 20):
    """Download ``url`` into ``save_dir`` and return the local file path.

    The saved file is named ``file_name`` plus the extension of the URL's path
    component, so query strings are ignored (``.../x.webp?sig=1`` ->
    ``<file_name>.webp``). If the path has no extension, neither does the file.
    An existing file with the same name is overwritten.

    Args:
        url: HTTP(S) URL to fetch.
        save_dir: Directory to save into; created if missing.
        file_name: Base name (without extension) for the saved file.
        timeout: Seconds to wait for the connection or for each read before
            giving up. Without a timeout, a stalled server blocks forever.
        chunk_size: Bytes per chunk when streaming the body to disk.

    Returns:
        Path of the written file.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if connecting or reading stalls for ``timeout`` seconds.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_suffix = os.path.splitext(urlsplit(url).path)[1]
    file_path = os.path.join(save_dir, file_name + file_suffix)
    # Stream the body so large media files are never held fully in memory.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
    return file_path

def _frame_to_uint8(image):
    """Convert one frame tensor with values in [0, 1] to a channels-last uint8 array.

    NOTE(review): a frame whose first dimension is 1, 3 or 4 is assumed to be
    channels-first and is transposed. This also fires for a channels-last image
    that is only 1, 3 or 4 pixels tall. A 4th (alpha) channel is dropped.
    """
    pixels = np.clip(255.0 * image.cpu().numpy(), 0, 255).astype(np.uint8)
    if pixels.shape[0] in (1, 3, 4):
        pixels = np.transpose(pixels, (1, 2, 0))
    if pixels.shape[-1] == 4:
        pixels = pixels[:, :, :3]
    return pixels


def images_to_mp4(images, output_path, fps=24, audio=None):
    """Encode a sequence of frame tensors as an H.264 MP4, optionally with audio.

    Args:
        images: Iterable of per-frame tensors (anything with ``.cpu().numpy()``)
            holding RGB(A) values in [0, 1]; see ``_frame_to_uint8``.
        output_path: Destination ``.mp4`` path; overwritten if it exists.
        fps: Frame rate of the output video.
        audio: Optional dict with ``'waveform'``, a tensor of shape
            ``(1, channels, samples)``, and ``'sample_rate'``, an int. When
            given, the audio is AAC-encoded and muxed in, and the output is cut
            to the shorter of the two streams.

    Raises:
        ValueError: if ``images`` is empty or a frame PNG cannot be written.
        Any error from OpenCV, ffmpeg or soundfile. Every error is printed and
        then re-raised, so callers never mistake a failed encode for success.
    """
    try:
        frames = list(images)
        if not frames:
            raise ValueError("images_to_mp4 needs at least one frame")

        import shutil
        import tempfile
        import cv2
        import ffmpeg

        # Keep every intermediate in a private temp dir. The PNG sequence then
        # cannot collide with another run or pick up its stale frames, and all
        # of it is removed even when encoding fails part-way.
        with tempfile.TemporaryDirectory() as work_dir:
            frame_pattern = os.path.join(work_dir, "frame_%06d.png")
            for index, image in enumerate(frames):
                frame_path = frame_pattern % index
                # OpenCV writes BGR, so reverse the RGB channel order.
                if not cv2.imwrite(frame_path, _frame_to_uint8(image)[:, :, ::-1]):
                    raise ValueError(f"Failed to write {frame_path}")

            silent_path = os.path.join(work_dir, "video.mp4")
            stream = ffmpeg.input(frame_pattern, framerate=fps)
            stream = ffmpeg.output(stream, silent_path, vcodec='libx264', pix_fmt='yuv420p')
            ffmpeg.run(stream, overwrite_output=True)

            if audio:
                # Only needed on this path, so a missing soundfile does not
                # break silent encodes.
                import soundfile as sf

                wav_path = os.path.join(work_dir, "audio.wav")
                # (1, channels, samples) -> (samples, channels), as soundfile expects.
                samples = audio['waveform'].squeeze(0).transpose(0, 1).cpu().numpy()
                sf.write(wav_path, samples, audio['sample_rate'])
                (
                    ffmpeg
                    .output(ffmpeg.input(silent_path), ffmpeg.input(wav_path), output_path,
                            vcodec='copy', acodec='aac', shortest=None)  # None -> bare -shortest flag
                    .run(overwrite_output=True)
                )
            else:
                # shutil.move also works when the temp dir sits on a different
                # filesystem than output_path, where os.rename would fail.
                shutil.move(silent_path, output_path)
    except Exception as e:
        print(f"Error: {e}")
        raise

# (encoding, errors) arguments for bytes.decode() on ffmpeg output: undecodable
# bytes become backslash escapes instead of raising UnicodeDecodeError.
ENCODE_ARGS = ("utf-8", 'backslashreplace')

def ffmpeg_suitability(path, timeout=10):
    """Heuristically score how capable the ffmpeg binary at ``path`` is.

    Runs ``<path> -version``. The score is the sum of weights for selected
    codec/library names found in the output, plus the last three digits of the
    copyright year (e.g. 23 for ``2000-2023``) as a rough proxy for build
    recency. Higher is better.

    Args:
        path: Filesystem path of a candidate ffmpeg executable.
        timeout: Seconds to allow the ``-version`` probe before giving up.

    Returns:
        The integer score, or 0 if the binary cannot be started, exits
        non-zero, or times out.
    """
    try:
        version = subprocess.run(
            [path, "-version"], check=True, capture_output=True, timeout=timeout,
        ).stdout.decode(*ENCODE_ARGS)
    except (OSError, ValueError, TypeError, subprocess.SubprocessError):
        # Narrower than a bare except: covers missing/non-executable files,
        # invalid paths, non-zero exit and timeout, without trapping
        # KeyboardInterrupt/SystemExit.
        return 0
    # Rough relative importance of various features.
    feature_weights = (("libvpx", 20), ("264", 10), ("265", 3),
                       ("svtav1", 5), ("libopus", 1))
    score = sum(weight for needle, weight in feature_weights if needle in version)
    # Rough compile year from the copyright notice: "2000-2023" -> "023" -> 23.
    copyright_index = version.find('2000-2')
    if copyright_index >= 0:
        copyright_year = version[copyright_index + 6:copyright_index + 9]
        # isdecimal (not isnumeric) guarantees int() can parse it.
        if copyright_year.isdecimal():
            score += int(copyright_year)
    return score

import logging

# Logger for the ffmpeg discovery below; warnings/errors reach stderr even
# without explicit logging configuration.
logger = logging.getLogger(__name__)

# Collect candidate ffmpeg executables and choose one as `ffmpeg_path` (None if
# none is found). Setting VHS_USE_IMAGEIO_FFMPEG forces the binary bundled with
# imageio-ffmpeg and makes its absence a hard error.
ffmpeg_paths = []
try:
    from imageio_ffmpeg import get_ffmpeg_exe
    imageio_ffmpeg_path = get_ffmpeg_exe()
    ffmpeg_paths.append(imageio_ffmpeg_path)
except Exception:
    if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
        raise
    logger.warning("Failed to import imageio_ffmpeg")
if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
    ffmpeg_path = imageio_ffmpeg_path
else:
    system_ffmpeg = shutil.which("ffmpeg")
    if system_ffmpeg is not None:
        ffmpeg_paths.append(system_ffmpeg)
    # Binaries placed directly in the current working directory.
    if os.path.isfile("ffmpeg"):
        ffmpeg_paths.append(os.path.abspath("ffmpeg"))
    if os.path.isfile("ffmpeg.exe"):
        ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
    if not ffmpeg_paths:
        logger.error("No valid ffmpeg found.")
        ffmpeg_path = None
    elif len(ffmpeg_paths) == 1:
        # A single candidate needs no suitability probe; skipping it saves
        # startup time.
        ffmpeg_path = ffmpeg_paths[0]
    else:
        ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)

# Channel counts for ffmpeg channel-layout names; a suffix such as the
# "(side)" in "5.1(side)" is ignored.
_LAYOUT_CHANNELS = {
    "mono": 1, "stereo": 2, "2.1": 3, "3.0": 3, "3.1": 4, "quad": 4,
    "4.0": 4, "4.1": 5, "5.0": 5, "5.1": 6, "6.0": 6, "6.1": 7,
    "7.0": 7, "7.1": 8,
}

def _channel_count(layout):
    """Return the channel count for an ffmpeg layout name such as "5.1(side)".

    Also accepts ffmpeg's generic "<N> channels" form. Unknown layouts raise
    ValueError rather than being guessed, because a wrong count would silently
    scramble the interleaved samples.
    """
    name = layout.strip().split("(", 1)[0]
    if name in _LAYOUT_CHANNELS:
        return _LAYOUT_CHANNELS[name]
    generic = re.fullmatch(r"(\d+) channels", name)
    if generic:
        return int(generic.group(1))
    raise ValueError(f"Unsupported audio channel layout: {layout!r}")

def get_audio(file, start_time=0, duration=0):
    """Decode an audio (or video) file with ffmpeg into a float32 waveform.

    Args:
        file: Path to any file ffmpeg can read audio from.
        start_time: Seconds to skip from the start; ignored when <= 0.
        duration: Maximum seconds to keep; ignored when <= 0.

    Returns:
        ``{'waveform': tensor (1, channels, samples) float32, 'sample_rate': int}``.
        Sample rate and channel layout are parsed from ffmpeg's log output. If
        they cannot be found, 44100 Hz stereo is assumed.

    Raises:
        Exception: if no ffmpeg was found, ffmpeg fails, or no samples are
            decoded (e.g. start_time lies beyond the end of the file).
        ValueError: if ffmpeg reports a channel layout of unknown size.
    """
    if ffmpeg_path is None:
        raise Exception(f"VHS failed to extract audio from {file}: no ffmpeg executable was found")
    args = [ffmpeg_path, "-i", file]
    # -ss / -t come after -i, i.e. they are applied as output options.
    if start_time > 0:
        args += ["-ss", str(start_time)]
    if duration > 0:
        args += ["-t", str(duration)]
    try:
        #TODO: scan for sample rate and maintain
        res = subprocess.run(args + ["-f", "f32le", "-"],
                             capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        raise Exception(f"VHS failed to extract audio from {file}:\n"
                        + e.stderr.decode(*ENCODE_ARGS)) from e
    if not res.stdout:
        # torch.frombuffer rejects an empty buffer; report the real cause instead.
        raise Exception(f"VHS failed to extract audio from {file}: no samples decoded")
    audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
    # ffmpeg logs stream info like ", 44100 Hz, stereo, fltp". The layout is
    # captured up to the next comma so names such as "5.1(side)" are recognized
    # instead of silently falling back to stereo.
    match = re.search(r', (\d+) Hz, ([^,]+), ', res.stderr.decode(*ENCODE_ARGS))
    if match:
        ar = int(match.group(1))
        ac = _channel_count(match.group(2))
    else:
        ar = 44100
        ac = 2
    # f32le output is interleaved: flat samples -> (1, channels, samples).
    audio = audio.reshape((-1, ac)).transpose(0, 1).unsqueeze(0)
    return {'waveform': audio, 'sample_rate': ar}

@torch.inference_mode()
def generate(input):
    """Run one FLOAT job: animate a portrait image with an audio clip into an MP4.

    Args:
        input: Request dict whose ``"input"`` entry holds:
            ``input_image`` / ``input_audio``: URLs, downloaded into
            /content/ComfyUI/input.
            ``start_time`` / ``duration``: seconds, forwarded to get_audio.
            ``a_cfg_scale``, ``r_cfg_scale``, ``e_cfg_scale``, ``fps``,
            ``emotion``, ``crop``: forwarded to FloatProcess.floatprocess.
            ``seed``: 0 means pick a random seed.
            ``filename_prefix``: output goes to /content/<filename_prefix>.mp4.

    Returns:
        ``{"status": "DONE", "result": <mp4 path>}`` on success, otherwise
        ``{"status": "ERROR", "result": <error message>}``, with the failure's
        traceback printed.
    """
    try:
        values = input["input"]

        input_image = values['input_image']
        input_image = download_file(url=input_image, save_dir='/content/ComfyUI/input', file_name='input_image')
        input_audio = values['input_audio']
        input_audio = download_file(url=input_audio, save_dir='/content/ComfyUI/input', file_name='input_audio')

        start_time = values['start_time']
        duration = values['duration']
        a_cfg_scale = values['a_cfg_scale']
        r_cfg_scale = values['r_cfg_scale']
        e_cfg_scale = values['e_cfg_scale']
        fps = values['fps']
        emotion = values['emotion']
        crop = values['crop']
        seed = values['seed']
        if seed == 0:
            random.seed(int(time.time()))
            seed = random.randint(0, 18446744073709551615)  # full unsigned 64-bit range
        filename_prefix = values['filename_prefix'] # float
        result = f"/content/{filename_prefix}.mp4"

        ref_image = LoadImage.load_image(input_image)[0]
        ref_audio = get_audio(input_audio, start_time=start_time, duration=duration)

        images_bhwc = FloatProcess.floatprocess(ref_image, ref_audio, float_pipe, a_cfg_scale, r_cfg_scale, e_cfg_scale, fps, emotion, crop, seed)[0]

        # Clear any stale output first, then confirm the encoder actually wrote
        # the file, so a failed encode is never reported as DONE.
        if os.path.exists(result):
            os.remove(result)
        images_to_mp4(images_bhwc, result, fps, audio=ref_audio)
        if not os.path.isfile(result):
            raise RuntimeError(f"Video encoding failed: {result} was not written")

        return {"status": "DONE", "result": result}
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {"status": "ERROR", "result": str(e)}

In [None]:
# Example job: animate the portrait using 30 s of the audio clip starting 6 s in,
# at 25 fps, with a random seed (seed 0). Named `job` rather than `input` so the
# builtin input() is not shadowed for the rest of the notebook session.
job = {
    "input": {
        "input_image": "https://s3.tost.ai/input/cc461064-0682-4649-a0d7-c8a1ceee7dcf.webp",
        "input_audio": "https://s3.tost.ai/input/cc9d9301-faf9-44dc-9044-2dba04e4ce8e.mp3",
        "start_time": 6,
        "duration": 30,
        "a_cfg_scale": 2.0,
        "r_cfg_scale": 1.0,
        "e_cfg_scale": 1.0,
        "fps": 25,
        "emotion": "neutral",
        "crop": True,
        "seed": 0,
        "filename_prefix": "float"
    }
}

generate(job)