In [None]:
import sys
import os 
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import warnings
warnings.filterwarnings("ignore")
import cv2
import numpy as np
from PIL import Image
from lora_w2w import LoRAw2w, LoRAModule
from utils import load_models, load_controlnet_models, inference, save_model_w2w, save_model_for_diffusers, unflatten, tensor_to_pil
from inversion_utils_experimental import invert
from diffusers import ControlNetModel
# Every model and tensor created in this notebook is placed on this device.
device = "cuda:0"

# ControlNet models; a later cell indexes the result with the keys
# "openpose", "canny" and "depth".
controlnet_model = load_controlnet_models(device)

# Base pipeline components shared by the inversion and generation cells.
(
    unet,
    vae,
    text_encoder,
    tokenizer,
    noise_scheduler,
) = load_models(device)

In [None]:
# Load the three basis tensors, cast each to bfloat16 and move it to `device`.
mean, std, v = (
    torch.load(path).bfloat16().to(device)
    for path in ("../files/mean.pt", "../files/std.pt", "../files/V.pt")
)
weight_dimensions = torch.load("../files/weight_dimensions.pt")

# All-zero starting projection of shape (1, 10000). The same 10000 is used
# below for the number of leading V columns handed to the network.
proj = torch.zeros(1, 10000, dtype=torch.bfloat16, device=device)

network = LoRAw2w(
    proj,
    mean,
    std,
    v[:, :10000],
    unet,
    rank=1,
    multiplier=1.0,
    alpha=27.0,
    train_method="xattn-strict",
).to(device, torch.bfloat16)

In [None]:
# --- Random state -----------------------------------------------------------
seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

# --- Text conditioning ------------------------------------------------------
# Short prompts used for inversion; the video-generation cell later assigns
# longer values to the same two names.
prompt, negative_prompt = "sks person", "low quality, blurry, unfinished"

# --- Output geometry --------------------------------------------------------
batch_size = 1
height = width = 512

# --- Sampling ---------------------------------------------------------------
ddim_steps = 50
guidance_scale = 7
controlnet_conditioning_scale = 0.8

In [None]:
# Run inversion on the images (and matching masks) of the target identity.
# NOTE: `network` is both an argument and the assignment target, so re-running
# this cell feeds the previously returned network back into `invert`.
network = invert(
    network=network,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    prompt=prompt,
    noise_scheduler=noise_scheduler,
    epochs=50,
    image_path="../inversion/images/deepfake",
    mask_path="../inversion/images/deepfake_masks",
    device=device,
    lr=0.2,
    weight_decay=0,
)

# Optional sanity check: sample the inverted identity at a few fixed seeds and
# write each result to img_<seed>.png. Disabled by default.
RUN_INVERSION_PREVIEW = False

if RUN_INVERSION_PREVIEW:
    # Dedicated names are used for the loop so that the notebook-wide `seed`
    # and `generator` from the configuration cell are not overwritten.
    for preview_seed in (1, 42, 360, 1111):
        preview_generator = torch.Generator(device=device).manual_seed(preview_seed)

        image = inference(
            network,
            unet,
            vae,
            text_encoder,
            tokenizer,
            prompt,
            negative_prompt,
            guidance_scale,
            noise_scheduler,
            ddim_steps,
            preview_seed,
            preview_generator,
            device,
        )

        # Take the first item of the batch, move channels last and convert to 8-bit.
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))
        image.save(f"img_{preview_seed}.png")

In [None]:

def scale_image_and_add_black_background(input_image_instance: Image.Image, scale_factor: float = 0.7) -> Image.Image:
    """Shrink an image and centre it on an opaque black canvas of the original size.

    Args:
        input_image_instance: Source image; it is converted to RGBA before resizing.
        scale_factor: Multiplier applied to both width and height. The scaled
            dimensions are truncated to integers.

    Returns:
        A new RGBA image with the same dimensions as the input, holding the
        resized image centred over a black, fully opaque background.
    """
    source_rgba = input_image_instance.convert("RGBA")
    canvas_w, canvas_h = source_rgba.size

    target_size = (int(canvas_w * scale_factor), int(canvas_h * scale_factor))
    shrunk = source_rgba.resize(target_size, Image.Resampling.LANCZOS)

    canvas = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 255))
    top_left = (
        (canvas_w - target_size[0]) // 2,
        (canvas_h - target_size[1]) // 2,
    )
    # The resized image is also passed as the paste mask, so its own alpha
    # channel controls which pixels are written onto the canvas.
    canvas.paste(shrunk, top_left, shrunk)
    return canvas

In [None]:
# Canny hysteresis thresholds; identical for every frame, so they are set once.
low_threshold = 100
high_threshold = 200

# NOTE(review): despite its name, this list is filled with Canny edge maps of
# the reference frames, not pose renderings. The name is kept because later
# cells read it.
openpose_sequence = []
for i in range(128):
    file = f"deepfake_{i}.png"

    # Load the reference frame, then shrink it onto a black canvas.
    input_image_for_canny = scale_image_and_add_black_background(
        Image.open(f"references/{file}").convert("RGB")
    )

    # Edge-detect and convert the single-channel result back to an RGB image.
    image_np = np.array(input_image_for_canny)
    canny_edges_np = cv2.Canny(image_np, low_threshold, high_threshold)
    control_image_pil = Image.fromarray(canny_edges_np).convert("RGB")

    openpose_sequence.append(control_image_pil)

In [None]:
from controlnet_aux import CannyDetector, MidasDetector
from utils import generate_video_sequence

# These two assignments replace the short prompts set in the configuration cell.
prompt = "(sks person:1.25); portrait of a middle-aged woman with a serious expression, pale skin, slightly sunken cheeks, sharp jawline, subtle makeup, wearing a dark sleeveless top, platinum blonde hair tied back, neutral black background, realistic lighting, slightly stylized, soft shadows, high detail, photorealistic"
negative_prompt = "blurry, distorted face, extra limbs, deformed eyes, cartoon, painting, unrealistic skin texture, low quality, glitch, waxy"

# Annotators used to derive the extra control signals.
canny_processor = CannyDetector()
depth_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")

# NOTE(review): both annotators are run on `openpose_sequence`, which an
# earlier cell fills with Canny edge maps — confirm this is the intended input.
canny_sequence = list(map(canny_processor, openpose_sequence))
depth_sequence = list(map(depth_processor, openpose_sequence))

frames = generate_video_sequence(
    # Models
    network=network,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    noise_scheduler=noise_scheduler,
    # Text conditioning and sampling
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=guidance_scale,
    ddim_steps=ddim_steps,
    device=device,
    # Control images, one list per ControlNet
    openpose_images=openpose_sequence,
    canny_images=canny_sequence,
    depth_images=depth_sequence,
    # ControlNets and their conditioning scales
    controlnet_openpose=controlnet_model["openpose"],
    openpose_scale=0.7,
    controlnet_canny=controlnet_model["canny"],
    canny_scale=0.1,
    controlnet_depth=controlnet_model["depth"],
    depth_scale=0.1,
    skip_first_frame_controlnet=True,
)

In [None]:
# Per-file ControlNet pass over everything in references/, disabled via `if False`.
# NOTE(review): `inference_with_controlnet` is not imported anywhere in the
# visible notebook, so enabling this block as-is would raise NameError —
# confirm where that function lives before turning it on.
# NOTE(review): the video-generation cell indexes `controlnet_model` by key
# ("openpose", "canny", "depth"), while this block passes it as a single
# object — confirm which form `inference_with_controlnet` expects.
if False:
    for file in os.listdir("references"):
        # Load the reference frame and shrink it onto a black canvas.
        input_image_for_canny = Image.open(f"references/{file}").convert("RGB")
        input_image_for_canny = scale_image_and_add_black_background(input_image_for_canny)
    
        # Canny edge map of the frame, converted back to an RGB image.
        image_np = np.array(input_image_for_canny)
        low_threshold = 100
        high_threshold = 200
        canny_edges_np = cv2.Canny(image_np, low_threshold, high_threshold)
        control_image_pil = Image.fromarray(canny_edges_np).convert("RGB")
    
        # The same `seed` and `generator` objects are passed for every file.
        img = inference_with_controlnet(
            network=network,
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            controlnet_model=controlnet_model,
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image_pil=control_image_pil,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            noise_scheduler=noise_scheduler,
            ddim_steps=ddim_steps,
            seed=seed,
            generator=generator,
            device=device,
        )
        
        # Take the first item of the batch, move channels last, convert to 8-bit.
        image = img.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))
        # Nothing in this block creates the deepfake/ directory; it must already exist.
        image.save(f"deepfake/{file}")


In [None]:
def save_frames_as_gif(frames, output_path, duration=100, loop=0):
    """
    Save a sequence of frames as an animated GIF.

    Args:
        frames: Iterable of frames. A frame that already exposes a ``save``
            method (for example a PIL image) is used as-is; any other frame
            is converted with ``tensor_to_pil`` (imported from ``utils`` at
            the top of the notebook).
        output_path: Path to save the GIF file (e.g., "output.gif")
        duration: Duration of each frame in milliseconds (default: 100ms = 10 FPS)
        loop: Number of loops (0 = infinite loop)

    Returns:
        None. When ``frames`` is empty nothing is written and a message is printed.
    """
    # Accept both image objects and tensors, matching what
    # save_frames_as_images is called with elsewhere in the notebook.
    pil_frames = [
        frame if hasattr(frame, "save") else tensor_to_pil(frame)
        for frame in frames
    ]

    if not pil_frames:
        print("No frames to save!")
        return

    # The first frame drives the save call; the rest are appended to it.
    first_frame, *remaining_frames = pil_frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=loop,
        optimize=True,
    )
    print(f"GIF saved to: {output_path}")
    print(f"Frames: {len(pil_frames)}, Duration: {duration}ms per frame")

def save_frames_as_images(frames, output_dir, prefix="frame", format="PNG"):
    """
    Save individual frames as separate image files.

    Files are named ``<prefix>_<index>.<format lower-cased>`` inside
    ``output_dir``, with the index starting at 0.

    Args:
        frames: Iterable of frames. A frame that already exposes a ``save``
            method (for example a PIL image) is saved directly; any other
            frame is first converted with ``tensor_to_pil`` (imported from
            ``utils`` at the top of the notebook).
        output_dir: Directory to save images; created if it does not exist.
        prefix: Prefix for image filenames
        format: Image format (PNG, JPEG, etc.), passed to each frame's ``save``.
    """
    import os

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Counted explicitly so that generators (which have no len()) also work.
    saved_count = 0
    for i, frame in enumerate(frames):
        pil_frame = frame if hasattr(frame, "save") else tensor_to_pil(frame)
        filepath = os.path.join(output_dir, f"{prefix}_{i}.{format.lower()}")
        pil_frame.save(filepath, format=format)
        print(f"Saved: {filepath}")
        saved_count += 1

    print(f"Saved {saved_count} frames to {output_dir}")

In [None]:
# Write every generated frame to deepfake/deepfake_<index>.png.
save_frames_as_images(frames, output_dir="deepfake", prefix="deepfake")