# [SkyReels](https://www.skyreels.ai/) V1: Human-Centric Video Foundation Model
The first and most advanced open-source human-centric video foundation model.

Built by fine-tuning HunyuanVideo on 10 million high-quality film and television clips.

This file is [SkyReels For Colab](https://github.com/c1t1zen1/SkyReels-V1-For-Colab) - c1t1zen1 version

## Start by selecting an A100 GPU runtime.
It also works on an L4 GPU, but much more slowly.


---


Text-To-Video and Image-To-Video available.


---


Step 1 - Clones the repo and installs requirements.

Step 2 v1 - Loads the weights and starts a Gradio app.
  - click on the URL link to open Gradio in a new browser window.

Step 2 v2 - Gradio with extended settings.

Step 2 v3 - Gradio with built-in Claude API prompt enhancement.





In [None]:
# @title Step 1 - Clone Repo and Install Requirements
# Clone the repository
# NOTE(review): all output is discarded, so a failed clone is silent. Re-running
# this cell after the %cd below clones a second, nested copy inside SkyReels-V1/
# and then %cd's into it.
!git clone https://github.com/SkyworkAI/SkyReels-V1.git &> /dev/null
# !git clone https://github.com/c1t1zen1/SkyReels-V1-For-Colab.git &> /dev/null

# Install required packages (add any others as needed)
!pip install gradio -q

# Append the cloned repo to sys.path so that its modules can be imported
# (relative entry: it is resolved against the working directory at import time)
import sys
sys.path.append('SkyReels-V1')
# Work from inside the repo so requirements.txt and the relative paths used by
# later cells (./result, ./compiler_cache) resolve there.
%cd SkyReels-V1
!pip install -r requirements.txt -q

In [None]:
# @title Step 2 - Load Skyreels Weights and Start Gradio
# @markdown -  Select t2v for Text-To-Video or i2v for Image-To-Video Setup
#WORKING CODE

import torch
import torch.distributed as dist
import os
from PIL import Image
import random
import time
import gradio as gr
from diffusers.utils import export_to_video, load_image

from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from skyreelsinfer.offload import OffloadConfig

def init_process_group():
    """Create the default torch.distributed group for one local rank.

    Returns immediately when a default group already exists (for example when
    the cell is run again). Otherwise the single-rank rendezvous environment
    variables are exported and an NCCL group of world size 1 is created over
    TCP on the loopback address.
    """
    if dist.is_initialized():
        return

    single_rank_env = {
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '23456',
        'WORLD_SIZE': '1',
        'RANK': '0',
        'LOCAL_RANK': '0',
    }
    os.environ.update(single_rank_env)

    # The NCCL backend is GPU-only; port must match MASTER_PORT above.
    dist.init_process_group(
        rank=0,
        world_size=1,
        init_method="tcp://127.0.0.1:23456",
        backend="nccl",
    )

class SimpleSkyReelsInfer:
    """Single-GPU wrapper around ``SkyReelsVideoSingleGpuInfer``.

    Builds the SkyReels pipeline with CPU offloading enabled and exposes a
    dict-based ``inference`` helper used by ``generate_video``.
    """

    def __init__(
        self,
        task_type: "TaskType",
        model_id: str,
        quant_model: bool = True,
    ):
        """Load the pipeline.

        Args:
            task_type: ``TaskType.T2V`` or ``TaskType.I2V``.
            model_id: model repository id, e.g. "Skywork/SkyReels-V1-Hunyuan-T2V".
            quant_model: forwarded unchanged to ``SkyReelsVideoSingleGpuInfer``.
        """
        # Offload settings: high CPU memory, parameter-level offload and no
        # compilation of the transformer.
        offload_config = OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
            compiler_cache="./compiler_cache"
        )

        # SkyReelsVideoSingleGpuInfer makes its own init_process_group call, but
        # init_predictor() has already created the group, so that call is
        # replaced by a no-op while the pipeline is constructed.
        # NOTE(review): this only intercepts calls made through the
        # ``torch.distributed`` module attribute - confirm against the library.
        def dummy_init(*args, **kwargs):
            pass

        original_init = dist.init_process_group
        try:
            dist.init_process_group = dummy_init
            self.pipe = SkyReelsVideoSingleGpuInfer(
                task_type=task_type,
                model_id=model_id,
                quant_model=quant_model,
                local_rank=0,
                world_size=1,
                is_offload=True,
                offload_config=offload_config,
                enable_cfg_parallel=False
            )
        finally:
            # Always restore the real function, even if loading fails.
            dist.init_process_group = original_init

    def inference(self, kwargs):
        """Run the pipeline and return the frames of the first generated video.

        Args:
            kwargs: pipeline keyword arguments. An optional ``"seed"`` entry is
                converted (via ``int``) into a seeded ``torch.Generator`` passed
                as ``generator``. The caller's dict is NOT modified, so callers
                can still report the seed they used.

        Returns:
            ``self.pipe.pipe(**kwargs).frames[0]``.

        Raises:
            ValueError: if ``kwargs`` is not a dict.
        """
        if not isinstance(kwargs, dict):
            raise ValueError("kwargs must be a dictionary")

        print(f"Inference kwargs: {kwargs}")  # Debug print

        # Work on a copy: generate_video() returns its dict to the UI as the
        # parameters used, so it must keep "seed" and must not gain a
        # torch.Generator.
        call_kwargs = dict(kwargs)
        if "seed" in call_kwargs:
            # Gradio number fields may deliver floats; manual_seed needs an int.
            seed = int(call_kwargs.pop("seed"))
            device = "cuda" if torch.cuda.is_available() else "cpu"
            call_kwargs["generator"] = torch.Generator(device).manual_seed(seed)

        return self.pipe.pipe(**call_kwargs).frames[0]

def generate_video(prompt, seed, image=None):
    """Generate one video with the global ``predictor`` and save it as MP4.

    Reads the globals ``task_type`` and ``predictor`` set up by the launcher.

    Args:
        prompt: text prompt.
        seed: seed value; ``-1`` or ``None`` (empty number field) draws a random
            one. Other values are cast to ``int`` because ``gr.Number`` without
            ``precision`` delivers floats.
        image: required when ``task_type == "i2v"``; a file path is opened with
            PIL, any other object is passed through unchanged.

    Returns:
        ``(video_path, params)`` where ``params`` holds the generation settings
        actually used, including the resolved seed.

    Raises:
        ValueError: if ``task_type`` is "i2v" and no image was given.
    """
    global task_type, predictor
    print(f"image: {type(image)}")
    print(f"prompt: {prompt}")
    print(f"seed: {seed}")

    if seed is None or seed == -1:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
    else:
        seed = int(seed)

    kwargs = {
        "prompt": prompt,
        "height": 544,
        "width": 960,
        "num_frames": 97,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False
    }

    if task_type == "i2v":
        # Explicit check rather than assert, which is stripped under `python -O`.
        if image is None:
            raise ValueError("please input image")
        if isinstance(image, str):
            image = Image.open(image)
        kwargs["image"] = image

    print(f"Final kwargs: {kwargs}")  # Debug print

    # Snapshot the settings for the UI before inference, which may mutate the
    # dict it is given (e.g. replace "seed" with a generator).
    params = dict(kwargs)

    try:
        output = predictor.inference(kwargs)
        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)
        video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
        print(f"generate video, local path: {video_out_file}")
        export_to_video(output, video_out_file, fps=24)
        return video_out_file, params
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        import traceback
        print(traceback.format_exc())
        raise

def init_predictor(task_type: str):
    """Load the SkyReels pipeline for ``task_type`` into the global ``predictor``.

    Args:
        task_type: "i2v" selects the image-to-video weights; any other value
            selects text-to-video.

    Raises:
        Re-raises any exception from CUDA setup, process-group creation or
        model loading after printing its traceback.
    """
    global predictor
    try:
        if torch.cuda.is_available():
            # Bring CUDA up and pin device 0 before any distributed/model setup.
            torch.cuda.init()
            torch.cuda.empty_cache()
            torch.cuda.set_device(0)

        init_process_group()

        if task_type == "i2v":
            model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
            skyreels_task = TaskType.I2V
        else:
            model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
            skyreels_task = TaskType.T2V

        predictor = SimpleSkyReelsInfer(
            task_type=skyreels_task,
            model_id=model_id,
            quant_model=True
        )
        print("Predictor initialized successfully")

    except Exception as err:
        print(f"Initialization error: {str(err)}")
        import traceback
        print(traceback.format_exc())
        raise

def create_gradio_interface(task_type):
    """Create the basic Gradio interface for ``task_type``.

    Both layouts show a prompt box and a seed field in one row, followed by a
    "Generate Video" button, a video output and a parameters textbox. The
    "i2v" layout adds an image upload (passed as a file path) at the start of
    the row and as the last argument to ``generate_video``.

    Args:
        task_type: "i2v" or "t2v".

    Returns:
        The (not yet launched) ``gr.Blocks`` app.

    Raises:
        ValueError: for any other ``task_type``.
    """
    if task_type not in ("i2v", "t2v"):
        raise ValueError(f"task_type must be 'i2v' or 't2v', got {task_type!r}")

    with gr.Blocks() as demo:
        with gr.Row():
            image = None
            if task_type == "i2v":
                image = gr.Image(label="Upload Image", type="filepath")
            prompt = gr.Textbox(label="Input Prompt")
            seed = gr.Number(label="Random Seed", value=-1)
        submit_button = gr.Button("Generate Video")
        output_video = gr.Video(label="Generated Video")
        output_params = gr.Textbox(label="Output Parameters")

        # Button logic: the image (if any) maps to generate_video's `image` arg.
        inputs = [prompt, seed] if image is None else [prompt, seed, image]
        submit_button.click(
            fn=generate_video,
            inputs=inputs,
            outputs=[output_video, output_params]
        )
    return demo

if __name__ == '__main__':
    # Set environmental variables
    # NOTE(review): torch is already imported at this point, so
    # CUDA_VISIBLE_DEVICES only takes effect if CUDA has not yet been
    # initialised in this kernel.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # Synchronous kernel launches give precise CUDA error locations but slow
    # generation down; presumably left on for debugging.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # Print system info
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device count: {torch.cuda.device_count()}")

    # Module-level name: generate_video() reads it as a global.
    task_type = "i2v" # @param ["i2v","t2v"]
    print("Starting initialization...")
    init_predictor(task_type)

    # Create and launch the Gradio interface
    # share=True requests a public share link; debug=True keeps the cell
    # running (blocking) so errors are printed in the cell output.
    demo = create_gradio_interface(task_type)
    demo.launch(share=True, debug=True)

In [None]:
# @title Step 2 v2 - Load Skyreels Weights and Start Extended Gradio
# @markdown -  Select t2v for Text-To-Video or i2v for Image-To-Video Setup
#WORKING CODE

import torch
import torch.distributed as dist
import os
from PIL import Image
import random
import time
import gradio as gr
from diffusers.utils import export_to_video, load_image

from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from skyreelsinfer.offload import OffloadConfig

def init_process_group():
    """Create the default torch.distributed group for one local rank.

    Returns immediately when a default group already exists (for example when
    the cell is run again). Otherwise the single-rank rendezvous environment
    variables are exported and an NCCL group of world size 1 is created over
    TCP on the loopback address.
    """
    if dist.is_initialized():
        return

    single_rank_env = {
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '23456',
        'WORLD_SIZE': '1',
        'RANK': '0',
        'LOCAL_RANK': '0',
    }
    os.environ.update(single_rank_env)

    # The NCCL backend is GPU-only; port must match MASTER_PORT above.
    dist.init_process_group(
        rank=0,
        world_size=1,
        init_method="tcp://127.0.0.1:23456",
        backend="nccl",
    )

class SimpleSkyReelsInfer:
    """Single-GPU wrapper around ``SkyReelsVideoSingleGpuInfer``.

    Builds the SkyReels pipeline with CPU offloading enabled and exposes a
    dict-based ``inference`` helper used by ``generate_video``.
    """

    def __init__(
        self,
        task_type: "TaskType",
        model_id: str,
        quant_model: bool = True,
    ):
        """Load the pipeline.

        Args:
            task_type: ``TaskType.T2V`` or ``TaskType.I2V``.
            model_id: model repository id, e.g. "Skywork/SkyReels-V1-Hunyuan-T2V".
            quant_model: forwarded unchanged to ``SkyReelsVideoSingleGpuInfer``.
        """
        # Offload settings: high CPU memory, parameter-level offload and no
        # compilation of the transformer.
        offload_config = OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
            compiler_cache="./compiler_cache"
        )

        # SkyReelsVideoSingleGpuInfer makes its own init_process_group call, but
        # the group is created beforehand, so that call is replaced by a no-op
        # while the pipeline is constructed.
        # NOTE(review): this only intercepts calls made through the
        # ``torch.distributed`` module attribute - confirm against the library.
        def dummy_init(*args, **kwargs):
            pass

        original_init = dist.init_process_group
        try:
            dist.init_process_group = dummy_init
            self.pipe = SkyReelsVideoSingleGpuInfer(
                task_type=task_type,
                model_id=model_id,
                quant_model=quant_model,
                local_rank=0,
                world_size=1,
                is_offload=True,
                offload_config=offload_config,
                enable_cfg_parallel=False
            )
        finally:
            # Always restore the real function, even if loading fails.
            dist.init_process_group = original_init

    def inference(self, kwargs):
        """Run the pipeline and return the frames of the first generated video.

        Args:
            kwargs: pipeline keyword arguments. An optional ``"seed"`` entry is
                converted (via ``int``) into a seeded ``torch.Generator`` passed
                as ``generator``. The caller's dict is NOT modified, so callers
                can still report the seed they used.

        Returns:
            ``self.pipe.pipe(**kwargs).frames[0]``.

        Raises:
            ValueError: if ``kwargs`` is not a dict.
        """
        if not isinstance(kwargs, dict):
            raise ValueError("kwargs must be a dictionary")

        print(f"Inference kwargs: {kwargs}")  # Debug print

        # Work on a copy: generate_video() returns its dict to the UI (gr.JSON)
        # as the parameters used, so it must keep "seed" and must not gain a
        # non-serializable torch.Generator.
        call_kwargs = dict(kwargs)
        if "seed" in call_kwargs:
            # Gradio number fields may deliver floats; manual_seed needs an int.
            seed = int(call_kwargs.pop("seed"))
            device = "cuda" if torch.cuda.is_available() else "cpu"
            call_kwargs["generator"] = torch.Generator(device).manual_seed(seed)

        return self.pipe.pipe(**call_kwargs).frames[0]

def round_to_multiple_of_16(n):
    """Return the multiple of 16 closest to ``n``; exact halves round up."""
    quotient, remainder = divmod(n, 16)
    # Step up to the next multiple once the remainder reaches the halfway point.
    return 16 * (quotient + (remainder >= 8))

def generate_video(prompt, seed, negative_prompt, width, height, num_frames,
                   num_inference_steps, guidance_scale, embedded_guidance_scale, image=None):
    """Generate one video with the global ``predictor`` and save it as MP4.

    Reads the globals ``task_type`` and ``predictor`` set up by the launcher.

    Args:
        prompt: text prompt.
        seed: integer seed; ``-1`` or ``None`` (empty field) draws a random one.
        negative_prompt: negative text prompt.
        width, height: requested size, rounded to the nearest multiple of 16.
        num_frames, num_inference_steps: converted to ``int``.
        guidance_scale, embedded_guidance_scale: converted to ``float``.
        image: required when ``task_type == "i2v"``. A file path is opened and
            resized to fit within ``width`` x ``height`` keeping its aspect
            ratio; any other object is passed through unchanged.

    Returns:
        ``(video_path, params)`` where ``params`` is a JSON-serializable copy
        of the generation settings (including the resolved seed) for the
        parameters panel.

    Raises:
        ValueError: if the image is missing or unreadable, or generation fails.
    """
    global task_type, predictor

    # Debug prints
    print(f"Starting video generation with parameters:")
    print(f"- Image type: {type(image)}")
    print(f"- Prompt: {prompt}")
    print(f"- Original dimensions: {width}x{height}")

    # Round dimensions to nearest multiple of 16
    width = round_to_multiple_of_16(int(width))
    height = round_to_multiple_of_16(int(height))
    print(f"- Adjusted dimensions: {width}x{height}")

    num_frames = int(num_frames)
    print(f"- Frames: {num_frames}")
    print(f"- Steps: {num_inference_steps}")
    print(f"- Seed: {seed}")

    # Resolve the seed to a concrete int so it can be reported and reused.
    if seed is None or seed == -1:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
    else:
        seed = int(seed)

    # Build kwargs dictionary
    kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "num_inference_steps": int(num_inference_steps),
        "seed": seed,
        "guidance_scale": float(guidance_scale),
        "embedded_guidance_scale": float(embedded_guidance_scale),
        "negative_prompt": negative_prompt,
        "cfg_for": False
    }

    # Handle image for i2v mode
    image_source = None  # JSON-friendly stand-in for the image in `params`
    if task_type == "i2v":
        if image is None:
            raise ValueError("Image is required for image-to-video mode")

        if isinstance(image, str):
            image_source = image
            try:
                image = Image.open(image)
                # Fit inside width x height while keeping the aspect ratio.
                aspect_ratio = image.width / image.height
                new_width = width
                new_height = int(width / aspect_ratio)
                if new_height > height:
                    new_height = height
                    new_width = int(height * aspect_ratio)
                # Multiples of 16, never below 16: very wide or tall images
                # would otherwise round one side to 0 and make resize() fail.
                new_width = max(16, round_to_multiple_of_16(new_width))
                new_height = max(16, round_to_multiple_of_16(new_height))
                image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            except Exception as e:
                raise ValueError(f"Error processing image: {str(e)}") from e
        else:
            image_source = repr(image)

        kwargs["image"] = image

    print(f"Final generation parameters: {kwargs}")

    # Snapshot for the UI *before* inference: the result is shown in gr.JSON,
    # which cannot serialize a PIL image, and inference may mutate the dict it
    # receives (e.g. replace "seed" with a torch.Generator).
    params = dict(kwargs)
    if image_source is not None:
        params["image"] = image_source

    try:
        # Generate the video
        output = predictor.inference(kwargs)

        # Save the video
        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)

        # Create a more informative filename
        base_filename = f"{prompt[:50].replace(' ', '_').replace('/','')}_{width}x{height}_{num_frames}fr_{seed}"
        video_out_file = f"{save_dir}/{base_filename}.mp4"

        print(f"Saving video to: {video_out_file}")
        export_to_video(output, video_out_file, fps=24)

        # Return both the video file and the parameters used
        return video_out_file, params

    except Exception as e:
        print(f"Error during video generation: {str(e)}")
        import traceback
        print(traceback.format_exc())
        raise ValueError(f"Video generation failed: {str(e)}") from e

def create_gradio_interface(task_type):
    """Create the extended Gradio interface with prompt and advanced controls.

    Args:
        task_type: "i2v" adds an image upload whose file path is passed to
            generate_video as its last argument; any other value builds the
            text-only layout.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """

    # `.container` is not referenced by any component below; the other two
    # classes are applied via elem_classes.
    css = """
        .container { margin: 15px; }
        .output-panel { margin-top: 20px; }
        .control-panel { padding: 10px; border: 1px solid #ccc; border-radius: 8px; }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# SkyReels Video Generation")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Input Controls")
                with gr.Group(elem_classes="control-panel"):
                    # Image upload for i2v mode
                    # (`image` is only bound here, and only used below under the
                    # same task_type check.)
                    if task_type == "i2v":
                        image = gr.Image(
                            label="Upload Image",
                            type="filepath",
                            # tool="editor",
                            elem_id="input-image"
                        )

                    # Basic parameters
                    prompt = gr.Textbox(
                        label="Input Prompt",
                        placeholder="Describe what you want to generate...",
                        lines=3
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
                        lines=2
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        # precision=0 makes the field return an int.
                        seed = gr.Number(
                            label="Random Seed (-1 for random)",
                            value=-1,
                            minimum=-1,
                            precision=0
                        )
                        width = gr.Slider(
                            label="Width",
                            minimum=128,
                            maximum=1024,
                            value=512,
                            step=16  # Changed to 16 to ensure valid values
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=128,
                            maximum=1024,
                            value=512,
                            step=16  # Changed to 16 to ensure valid values
                        )
                        num_frames = gr.Slider(
                            label="Number of Frames",
                            minimum=16,
                            maximum=128,
                            value=97,
                            step=1
                        )
                        num_inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            value=30,
                            step=1
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1,
                            maximum=20,
                            value=6.0,
                            step=0.5
                        )
                        embedded_guidance_scale = gr.Slider(
                            label="Embedded Guidance Scale",
                            minimum=0.1,
                            maximum=5.0,
                            value=1.0,
                            step=0.1
                        )

                    submit_button = gr.Button(
                        "Generate Video",
                        variant="primary",
                        size="lg"
                    )

            with gr.Column(scale=1):
                gr.Markdown("### Output")
                with gr.Group(elem_classes="output-panel"):
                    output_video = gr.Video(label="Generated Video")
                    # Receives the params dict returned by generate_video.
                    output_params = gr.JSON(label="Generation Parameters")

        # Event handlers
        # Order must match generate_video's positional parameters.
        input_components = [
            prompt,
            seed,
            negative_prompt,
            width,
            height,
            num_frames,
            num_inference_steps,
            guidance_scale,
            embedded_guidance_scale
        ]

        if task_type == "i2v":
            input_components.append(image)

        submit_button.click(
            fn=generate_video,
            inputs=input_components,
            outputs=[output_video, output_params]
        )

        # Add help section
        with gr.Accordion("Help", open=False):
            gr.Markdown("""
                ### Tips for best results:
                - Use clear, descriptive prompts
                - Experiment with different guidance scales
                - Try different seeds for variation
                - Adjust frame count based on your needs

                ### Common issues:
                - If generation fails, try reducing the resolution
                - High guidance scales may produce stronger but potentially less stable results
                - Larger frame counts will take longer to generate
            """)

    return demo

def init_predictor(task_type: str):
    """Load the SkyReels pipeline for ``task_type`` into the global ``predictor``.

    Defined in this cell too, so "Step 2 v2" works without running the other
    Step 2 cell first.

    Args:
        task_type: "i2v" selects the image-to-video weights; any other value
            selects text-to-video.

    Raises:
        Re-raises any exception from CUDA setup, process-group creation or
        model loading after printing its traceback.
    """
    global predictor
    try:
        # Set CUDA device first
        if torch.cuda.is_available():
            torch.cuda.init()
            torch.cuda.empty_cache()
            torch.cuda.set_device(0)

        # Initialize process group first
        init_process_group()

        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V" if task_type == "i2v" else "Skywork/SkyReels-V1-Hunyuan-T2V"

        predictor = SimpleSkyReelsInfer(
            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
            model_id=model_id,
            quant_model=True
        )
        print("Predictor initialized successfully")

    except Exception as e:
        print(f"Initialization error: {str(e)}")
        import traceback
        print(traceback.format_exc())
        raise


if __name__ == '__main__':
    # Set environmental variables
    # NOTE(review): torch is already imported; CUDA_VISIBLE_DEVICES only takes
    # effect if CUDA has not been initialised yet, and CUDA_LAUNCH_BLOCKING=1
    # slows generation (useful mainly for debugging).
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # Print system info
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device count: {torch.cuda.device_count()}")

    # Module-level name: generate_video() reads it as a global.
    task_type = "t2v" # @param ["i2v","t2v"]
    print("Starting initialization...")
    init_predictor(task_type)

    # Create and launch the Gradio interface
    demo = create_gradio_interface(task_type)
    demo.launch(share=True, debug=True)

In [None]:
# @title Step 2 v3 - Enhance Prompts with Claude API

!pip install anthropic tiktoken -q

import torch
import torch.distributed as dist
import os
from PIL import Image
import random
import time
import gradio as gr
from diffusers.utils import export_to_video, load_image
import anthropic
import json
from functools import partial
import tiktoken


from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from skyreelsinfer.offload import OffloadConfig

# Initialize the Claude client
api_key = "" # @param {"type":"string","placeholder":"Your Claude API Key Here - Optional"}
# NOTE(review): a key typed into this form is saved in the notebook source;
# avoid sharing the notebook with a key filled in.
# With an empty key, enhance_prompt() requests are expected to fail; it then
# falls back to the (token-truncated) original prompts.
claude = anthropic.Anthropic(api_key=api_key)

# NOTE(review): cl100k_base is used only to approximate the 77-token limit; the
# video model's own text encoders presumably tokenize differently, so counts
# are estimates.
tokenizer = tiktoken.get_encoding("cl100k_base")

def init_process_group():
    """Create the default torch.distributed group for one local rank.

    Returns immediately when a default group already exists (for example when
    the cell is run again). Otherwise the single-rank rendezvous environment
    variables are exported and an NCCL group of world size 1 is created over
    TCP on the loopback address.
    """
    if dist.is_initialized():
        return

    single_rank_env = {
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '23456',
        'WORLD_SIZE': '1',
        'RANK': '0',
        'LOCAL_RANK': '0',
    }
    os.environ.update(single_rank_env)

    # The NCCL backend is GPU-only; port must match MASTER_PORT above.
    dist.init_process_group(
        rank=0,
        world_size=1,
        init_method="tcp://127.0.0.1:23456",
        backend="nccl",
    )

class SimpleSkyReelsInfer:
    """Single-GPU wrapper around ``SkyReelsVideoSingleGpuInfer``.

    Builds the SkyReels pipeline with CPU offloading enabled and exposes a
    dict-based ``inference`` helper used by ``generate_video``.
    """

    def __init__(
        self,
        task_type: "TaskType",
        model_id: str,
        quant_model: bool = True,
    ):
        """Load the pipeline.

        Args:
            task_type: ``TaskType.T2V`` or ``TaskType.I2V``.
            model_id: model repository id, e.g. "Skywork/SkyReels-V1-Hunyuan-T2V".
            quant_model: forwarded unchanged to ``SkyReelsVideoSingleGpuInfer``.
        """
        # Offload settings: high CPU memory, parameter-level offload and no
        # compilation of the transformer.
        offload_config = OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
            compiler_cache="./compiler_cache"
        )

        # SkyReelsVideoSingleGpuInfer makes its own init_process_group call, but
        # init_predictor() has already created the group, so that call is
        # replaced by a no-op while the pipeline is constructed.
        # NOTE(review): this only intercepts calls made through the
        # ``torch.distributed`` module attribute - confirm against the library.
        def dummy_init(*args, **kwargs):
            pass

        original_init = dist.init_process_group
        try:
            dist.init_process_group = dummy_init
            self.pipe = SkyReelsVideoSingleGpuInfer(
                task_type=task_type,
                model_id=model_id,
                quant_model=quant_model,
                local_rank=0,
                world_size=1,
                is_offload=True,
                offload_config=offload_config,
                enable_cfg_parallel=False
            )
        finally:
            # Always restore the real function, even if loading fails.
            dist.init_process_group = original_init

    def inference(self, kwargs):
        """Run the pipeline and return the frames of the first generated video.

        Args:
            kwargs: pipeline keyword arguments. An optional ``"seed"`` entry is
                converted (via ``int``) into a seeded ``torch.Generator`` passed
                as ``generator``. The caller's dict is NOT modified, so callers
                can still report the seed they used.

        Returns:
            ``self.pipe.pipe(**kwargs).frames[0]``.

        Raises:
            ValueError: if ``kwargs`` is not a dict.
        """
        if not isinstance(kwargs, dict):
            raise ValueError("kwargs must be a dictionary")

        print(f"Inference kwargs: {kwargs}")  # Debug print

        # Work on a copy: generate_video() returns its dict to the UI (gr.JSON)
        # as the parameters used, so it must keep "seed" and must not gain a
        # non-serializable torch.Generator.
        call_kwargs = dict(kwargs)
        if "seed" in call_kwargs:
            # Gradio number fields may deliver floats; manual_seed needs an int.
            seed = int(call_kwargs.pop("seed"))
            device = "cuda" if torch.cuda.is_available() else "cpu"
            call_kwargs["generator"] = torch.Generator(device).manual_seed(seed)

        return self.pipe.pipe(**call_kwargs).frames[0]

def count_tokens(text):
    """Return how many tokens the module-level ``tokenizer`` produces for ``text``."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def truncate_to_token_limit(text, limit):
    """Truncate ``text`` to at most ``limit`` tokens, preferring a clean break.

    Tokens are counted with the module-level ``tokenizer``. If the text is
    already within the limit it is returned unchanged. Otherwise the first
    ``limit`` tokens are decoded, and the result is shortened to its last
    period, or failing that its last comma, when that boundary lies in the
    final 30% of the truncated text.

    Args:
        text: text to shorten.
        limit: maximum number of tokens.

    Returns:
        The (possibly shortened) text.
    """
    tokens = tokenizer.encode(text)
    if len(tokens) <= limit:
        return text

    # A cut can land inside a multi-byte character; tiktoken's decode (which
    # uses errors="replace" by default) then ends the text with U+FFFD
    # replacement characters, so strip them.
    truncated_text = tokenizer.decode(tokens[:limit]).rstrip("\ufffd")

    # Prefer ending on a sentence, then on a clause, but only when that boundary
    # is in the latter part of the text so little content is thrown away.
    cutoff = len(truncated_text) * 0.7
    for delimiter in ('.', ','):
        position = truncated_text.rfind(delimiter)
        if position > cutoff:
            return truncated_text[:position + 1]

    # No good breaking point: return the plain token cut.
    return truncated_text

def enhance_prompt(prompt, negative_prompt, model="claude-3-5-sonnet-20241022"):
    """
    Use Claude to enhance both the positive and negative prompts,
    ensuring responses don't exceed 77 tokens each.

    Args:
        prompt: the user's prompt.
        negative_prompt: the user's negative prompt.
        model: Anthropic model id used for the request (defaults to the
            previously hard-coded snapshot). NOTE(review): dated model
            snapshots get retired; pass a current id if requests start failing.

    Returns:
        ``(enhanced_prompt, enhanced_negative)``, each cut to at most 77 tokens
        as counted by the module-level ``tokenizer``. If anything fails (API
        error, missing key, unparsable reply, missing JSON keys) the original
        prompts are returned, cut to the same limit.
    """
    system = """You are an expert at creating prompts for text-to-video AI models.
    Given an input prompt and negative prompt, enhance them to create better quality videos.
    Return only a JSON object with two keys: 'prompt' and 'negative_prompt'.
    Each prompt MUST be 77 tokens or less.
    Focus on adding cinematic details and quality specifications.
    Do not explain or add any other text."""

    user_message = f"""Original prompt: {prompt}
    Original negative prompt: {negative_prompt}

    Please enhance both prompts while keeping each under 77 tokens."""

    try:
        response = claude.messages.create(
            model=model,
            max_tokens=1000,
            temperature=0.7,
            system=system,
            messages=[
                {
                    "role": "user",
                    "content": user_message
                }
            ]
        )

        # Models may still wrap the JSON in a ```json fence or add a sentence
        # around it; parse only the outermost {...} span instead of failing.
        reply = response.content[0].text
        start, end = reply.find("{"), reply.rfind("}")
        if start == -1 or end < start:
            raise ValueError(f"No JSON object in model reply: {reply!r}")
        enhanced = json.loads(reply[start:end + 1])

        # Ensure prompts are within token limit
        enhanced_prompt = truncate_to_token_limit(enhanced['prompt'], 77)
        enhanced_negative = truncate_to_token_limit(enhanced['negative_prompt'], 77)

        # Double check token counts
        prompt_tokens = count_tokens(enhanced_prompt)
        negative_tokens = count_tokens(enhanced_negative)

        print(f"Enhanced prompt tokens: {prompt_tokens}")
        print(f"Enhanced negative prompt tokens: {negative_tokens}")

        return enhanced_prompt, enhanced_negative

    except Exception as e:
        # Best-effort feature: never break the UI, fall back to the inputs.
        print(f"Error enhancing prompt: {str(e)}")
        # Ensure original prompts are within token limit
        safe_prompt = truncate_to_token_limit(prompt, 77)
        safe_negative = truncate_to_token_limit(negative_prompt, 77)
        return safe_prompt, safe_negative

def round_to_multiple_of_16(n):
    """Return the multiple of 16 closest to ``n``; exact halves round up."""
    quotient, remainder = divmod(n, 16)
    # Step up to the next multiple once the remainder reaches the halfway point.
    return 16 * (quotient + (remainder >= 8))

def update_prompts(prompt, negative_prompt):
    """Gradio callback for the enhance button.

    Returns ``(prompt, negative_prompt)`` as enhanced by ``enhance_prompt``;
    if that raises, prints the error and returns the inputs unchanged.
    """
    try:
        new_prompt, new_negative = enhance_prompt(prompt, negative_prompt)
    except Exception as err:
        print(f"Error in update_prompts: {str(err)}")
        return prompt, negative_prompt
    return new_prompt, new_negative

def generate_video(prompt, seed, negative_prompt, width, height, num_frames,
                   num_inference_steps, guidance_scale, embedded_guidance_scale, image=None):
    """Generate one video with the global ``predictor`` and save it as MP4.

    Reads the globals ``task_type`` and ``predictor`` set up by the launcher.

    Args:
        prompt: text prompt.
        seed: integer seed; ``-1`` or ``None`` (empty field) draws a random one.
        negative_prompt: negative text prompt.
        width, height: requested size, rounded to the nearest multiple of 16.
        num_frames, num_inference_steps: converted to ``int``.
        guidance_scale, embedded_guidance_scale: converted to ``float``.
        image: required when ``task_type == "i2v"``. A file path is opened and
            resized to fit within ``width`` x ``height`` keeping its aspect
            ratio; any other object is passed through unchanged.

    Returns:
        ``(video_path, params)`` where ``params`` is a JSON-serializable copy
        of the generation settings (including the resolved seed) for the
        parameters panel.

    Raises:
        ValueError: if the image is missing or unreadable, or generation fails.
    """
    global task_type, predictor

    # Debug prints
    print(f"Starting video generation with parameters:")
    print(f"- Image type: {type(image)}")
    print(f"- Prompt: {prompt}")
    print(f"- Original dimensions: {width}x{height}")

    # Round dimensions to nearest multiple of 16
    width = round_to_multiple_of_16(int(width))
    height = round_to_multiple_of_16(int(height))
    print(f"- Adjusted dimensions: {width}x{height}")

    num_frames = int(num_frames)
    print(f"- Frames: {num_frames}")
    print(f"- Steps: {num_inference_steps}")
    print(f"- Seed: {seed}")

    # Resolve the seed to a concrete int so it can be reported and reused.
    if seed is None or seed == -1:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
    else:
        seed = int(seed)

    # Build kwargs dictionary
    kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "num_inference_steps": int(num_inference_steps),
        "seed": seed,
        "guidance_scale": float(guidance_scale),
        "embedded_guidance_scale": float(embedded_guidance_scale),
        "negative_prompt": negative_prompt,
        "cfg_for": False
    }

    # Handle image for i2v mode
    image_source = None  # JSON-friendly stand-in for the image in `params`
    if task_type == "i2v":
        if image is None:
            raise ValueError("Image is required for image-to-video mode")

        if isinstance(image, str):
            image_source = image
            try:
                image = Image.open(image)
                # Fit inside width x height while keeping the aspect ratio.
                aspect_ratio = image.width / image.height
                new_width = width
                new_height = int(width / aspect_ratio)
                if new_height > height:
                    new_height = height
                    new_width = int(height * aspect_ratio)
                # Multiples of 16, never below 16: very wide or tall images
                # would otherwise round one side to 0 and make resize() fail.
                new_width = max(16, round_to_multiple_of_16(new_width))
                new_height = max(16, round_to_multiple_of_16(new_height))
                image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            except Exception as e:
                raise ValueError(f"Error processing image: {str(e)}") from e
        else:
            image_source = repr(image)

        kwargs["image"] = image

    print(f"Final generation parameters: {kwargs}")

    # Snapshot for the UI *before* inference: the result is shown in gr.JSON,
    # which cannot serialize a PIL image, and inference may mutate the dict it
    # receives (e.g. replace "seed" with a torch.Generator).
    params = dict(kwargs)
    if image_source is not None:
        params["image"] = image_source

    try:
        # Generate the video
        output = predictor.inference(kwargs)

        # Save the video
        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)

        base_filename = f"{prompt[:50].replace(' ', '_').replace('/','')}_{width}x{height}_{num_frames}fr_{seed}"
        video_out_file = f"{save_dir}/{base_filename}.mp4"

        print(f"Saving video to: {video_out_file}")
        export_to_video(output, video_out_file, fps=24)

        return video_out_file, params

    except Exception as e:
        print(f"Error during video generation: {str(e)}")
        import traceback
        print(traceback.format_exc())
        raise ValueError(f"Video generation failed: {str(e)}") from e

def init_predictor(task_type: str):
    """Load the SkyReels pipeline for ``task_type`` into the global ``predictor``.

    Args:
        task_type: "i2v" selects the image-to-video weights; any other value
            selects text-to-video.

    Raises:
        Re-raises any exception from CUDA setup, process-group creation or
        model loading after printing its traceback.
    """
    global predictor
    try:
        if torch.cuda.is_available():
            # Bring CUDA up and pin device 0 before any distributed/model setup.
            torch.cuda.init()
            torch.cuda.empty_cache()
            torch.cuda.set_device(0)

        init_process_group()

        if task_type == "i2v":
            model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
            skyreels_task = TaskType.I2V
        else:
            model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
            skyreels_task = TaskType.T2V

        predictor = SimpleSkyReelsInfer(
            task_type=skyreels_task,
            model_id=model_id,
            quant_model=True
        )
        print("Predictor initialized successfully")

    except Exception as err:
        print(f"Initialization error: {str(err)}")
        import traceback
        print(traceback.format_exc())
        raise

def create_gradio_interface(task_type):
    """Build (but do not launch) the SkyReels Gradio Blocks UI.

    The layout has two side-by-side columns:

    - An input column with an optional image upload, prompt and
      negative-prompt boxes, an "enhance" button, and a collapsed
      "Advanced Settings" accordion of generation parameters.
    - An output column with a video player and a JSON view of the
      parameters returned by ``generate_video``.

    Args:
        task_type: When ``"i2v"``, an image-upload component is added and
            passed to ``generate_video`` as its last input. Any other value
            builds the text-only layout.

    Returns:
        The ``gr.Blocks`` app. Launching it is left to the caller.
    """

    # Custom CSS; the .control-panel, .output-panel and .enhance-button
    # classes are attached to components below via ``elem_classes``.
    css = """
        .container { margin: 15px; }
        .output-panel { margin-top: 20px; }
        .control-panel { padding: 10px; border: 1px solid #ccc; border-radius: 8px; }
        .enhance-button { margin: 10px 0; }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# SkyReels Video Generation")

        # Gradio places components in the order they are created inside these
        # context managers, so statement order below is the on-screen order.
        with gr.Row():
            # Left column: all inputs plus the generate button.
            with gr.Column(scale=1):
                gr.Markdown("### Input Controls")
                with gr.Group(elem_classes="control-panel"):
                    # Image upload for i2v mode only. `image` is bound only in
                    # this branch and is read again below under the same check.
                    if task_type == "i2v":
                        image = gr.Image(
                            label="Upload Image",
                            # "filepath": the handler receives a path string
                            # rather than an array or PIL image.
                            type="filepath",
                            # NOTE(review): `tool=` appears to have been removed
                            # from gr.Image in Gradio 4.x, presumably why this is
                            # disabled — confirm before re-enabling.
                            # tool="editor",
                            elem_id="input-image"
                        )

                    # Prompt inputs with enhancement
                    with gr.Group():
                        prompt = gr.Textbox(
                            label="Input Prompt",
                            placeholder="Describe what you want to generate...",
                            lines=3
                        )
                        # Pre-filled with a default negative prompt the user can edit.
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
                            lines=2
                        )
                        enhance_button = gr.Button(
                            "✨ Enhance Prompts with AI",
                            elem_classes="enhance-button"
                        )

                    # Generation parameters, collapsed by default.
                    with gr.Accordion("Advanced Settings", open=False):
                        # The label advertises -1 as "random"; the actual seed
                        # handling lives in generate_video (outside this chunk).
                        seed = gr.Number(
                            label="Random Seed (-1 for random)",
                            value=-1,
                            minimum=-1,
                            # precision=0: the value is rounded to an integer.
                            precision=0
                        )
                        # Width/height move in 16-px steps, consistent with the
                        # multiple-of-16 rounding applied to uploaded images.
                        width = gr.Slider(
                            label="Width",
                            minimum=128,
                            maximum=1024,
                            value=512,
                            step=16
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=128,
                            maximum=1024,
                            value=512,
                            step=16
                        )
                        num_frames = gr.Slider(
                            label="Number of Frames",
                            minimum=16,
                            maximum=128,
                            value=97,
                            step=1
                        )
                        num_inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            value=30,
                            step=1
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1,
                            maximum=20,
                            value=6.0,
                            step=0.5
                        )
                        embedded_guidance_scale = gr.Slider(
                            label="Embedded Guidance Scale",
                            minimum=0.1,
                            maximum=5.0,
                            value=1.0,
                            step=0.1
                        )

                    generate_button = gr.Button(
                        "Generate Video",
                        variant="primary",
                        size="lg"
                    )

            # Right column: generation results.
            with gr.Column(scale=1):
                gr.Markdown("### Output")
                with gr.Group(elem_classes="output-panel"):
                    output_video = gr.Video(label="Generated Video")
                    output_params = gr.JSON(label="Generation Parameters")

        # Set up event handlers
        # update_prompts (defined outside this chunk) reads both prompt boxes
        # and its return values are written back into the same two boxes.
        enhance_button.click(
            fn=update_prompts,
            inputs=[prompt, negative_prompt],
            outputs=[prompt, negative_prompt]
        )

        # Gradio passes input values to fn positionally, so this order must
        # match generate_video's parameter order (defined outside this chunk).
        input_components = [
            prompt,
            seed,
            negative_prompt,
            width,
            height,
            num_frames,
            num_inference_steps,
            guidance_scale,
            embedded_guidance_scale
        ]

        # In i2v mode the uploaded image is passed as the final argument.
        if task_type == "i2v":
            input_components.append(image)

        # generate_video's two return values fill the video player and the
        # JSON parameter view.
        generate_button.click(
            fn=generate_video,
            inputs=input_components,
            outputs=[output_video, output_params]
        )

        # Add help section (collapsible, rendered below the input/output row).
        with gr.Accordion("Help", open=False):
            gr.Markdown("""
                ### Tips for best results:
                - Use the "Enhance Prompts" button to improve your prompts with AI
                - Use clear, descriptive prompts
                - Experiment with different guidance scales
                - Try different seeds for variation
                - Adjust frame count based on your needs

                ### Common issues:
                - If generation fails, try reducing the resolution
                - High guidance scales may produce stronger but potentially less stable results
                - Larger frame counts will take longer to generate
            """)

    return demo

if __name__ == '__main__':
    # Restrict this process to the first GPU. CUDA only reads this variable
    # when the runtime is first initialised in the process, so re-running the
    # cell after a model is already loaded will not change the visible GPUs.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous. That gives
    # accurate error locations for CUDA failures but slows generation
    # considerably, so it is opt-in rather than always on. Like
    # CUDA_VISIBLE_DEVICES, it only takes effect before CUDA is initialised.
    cuda_debug = False # @param {type:"boolean"}
    if cuda_debug:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # Print system info
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device count: {torch.cuda.device_count()}")
# @markdown For Text-To-Video = t2v or for Image-To-Video = i2v
    task_type = "t2v" # @param ["i2v","t2v"]
    print("Starting initialization...")
    init_predictor(task_type)

    # Create and launch the Gradio interface. share=True creates a public
    # link (the usual way to reach the app from Colab); debug=True keeps the
    # cell running and prints errors into the cell output.
    demo = create_gradio_interface(task_type)
    demo.launch(share=True, debug=True)