In [11]:
import os
import json
import numpy as np
from PIL import Image
from datasets import load_dataset
from diffusers import AutoencoderKL
import torch
from torchvision import transforms
from tqdm import tqdm

In [3]:
import os
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T


class Subject200KDateset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        image_size: int = 512,
        padding: int = 0,
        condition_type: str = "subject",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset) * 2

    def __getitem__(self, idx):
        # If target is 0, left image is target, right image is condition
        # target = idx % 2
        item = self.base_dataset[idx // 2]

        # Crop the image to target and condition
        image = item["image"]
        left_img = image.crop(
            (
                self.padding,
                self.padding,
                self.image_size + self.padding,
                self.image_size + self.padding,
            )
        )
        right_img = image.crop(
            (
                self.image_size + self.padding * 2,
                self.padding,
                self.image_size * 2 + self.padding * 2,
                self.image_size + self.padding,
            )
        )

        # Resize the images
        left_img = left_img.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        right_img = right_img.resize(
            (self.target_size, self.target_size)
        ).convert("RGB")


        description_0 = item["description"]["description_0"]
        description_1 = item["description"]["description_1"]


        return {
            "left_image": left_img,
            "right_image": right_img,
            "condition_type": self.condition_type,
            "description_0": description_0,
            "description_1": description_1,
            # **({"pil_image": image} if self.return_pil_image else {}),
        }

In [6]:
# Load the dataset
dataset = load_dataset('Yuanshi/Subjects200K')
def filter_func(item):
    if not item.get("quality_assessment"):
        return False
    return all(
        item["quality_assessment"].get(key, 0) >= 5
        for key in ["compositeStructure", "objectConsistency", "imageQuality"]
    )

data_valid = dataset["train"].filter(
    filter_func,
    num_proc=16,
    cache_file_name="./cache/dataset/data_valid.arrow",
)

Filter (num_proc=16): 100%|██████████| 206841/206841 [00:33<00:00, 6205.68 examples/s] 


In [13]:
# Initialize the dataset
training_config = {
    "dataset": {
        "condition_size": 512,
        "target_size": 512,
        "image_size": 512,
        "padding": 8,
        "drop_text_prob": 0.1,
        "drop_image_prob": 0.1,
    },
    "condition_type": "subject",
}

subject_dataset = Subject200KDateset(
    data_valid,
    condition_size=training_config["dataset"]["condition_size"],
    target_size=training_config["dataset"]["target_size"],
    image_size=training_config["dataset"]["image_size"],
    padding=training_config["dataset"]["padding"],
    condition_type=training_config["condition_type"],
    drop_text_prob=training_config["dataset"]["drop_text_prob"],
    drop_image_prob=training_config["dataset"]["drop_image_prob"],
)

# Create directories to save images
os.makedirs("output/left_images", exist_ok=True)
os.makedirs("output/right_images", exist_ok=True)
os.makedirs("output/metadata", exist_ok=True)
# Save target and condition images
for idx in tqdm(range(len(subject_dataset))):
    item = subject_dataset[idx]

    # Save target image
    left_images_path = f"output/left_images/left_{idx}.png"
    item["left_image"].save(left_images_path)

    # Save condition image
    right_images_path = f"output/condition_images/condition_{idx}.png"
    item["right_image"].save(right_images_path)
    
    # Extract metadata
    metadata = {
        "description_0": item["description_0"],
        "description_1": item["description_1"],
        "collection": data_valid[idx // 2]["collection"],
        "quality_assessment": data_valid[idx // 2]["quality_assessment"],
        "target_image_path": left_images_path,
        "condition_image_path": right_images_path,
    }

    # Save metadata as JSON
    metadata_path = f"output/metadata/meta_{idx}.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f)
    break
    # if idx % 100 == 0:
    #     print(f"Saved {idx} images.")

print("✅ Image saving completed!")

  1%|          | 2026/239934 [03:44<7:19:45,  9.02it/s]


KeyboardInterrupt: 

In [None]:
# Process and save images, latents, and metadata
for idx, item in enumerate(filtered_dataset):
    # Save the original image
    image = item['image']
    image_path = f"output/images/image_{idx}.png"
    image.save(image_path)

    # Preprocess image
    image_tensor = preprocess_image(image)

    # # Extract VAE latents
    # latents = extract_latents(image_tensor, vae)
    # np.save(f"output/latents/latent_{idx}.npy", latents)

    # Save metadata
    metadata = {
        "collection": item['collection'],
        "quality_assessment": item['quality_assessment'],
        "description": item['description'],
        "image_path": image_path,
    }
    with open(f"output/metadata/meta_{idx}.json", "w") as f:
        json.dump(metadata, f)

    if idx % 100 == 0:
        print(f"Processed {idx} samples.")

print("✅ Extraction completed!")

: 

In [4]:
import os
import json
import cv2
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from diffusers import AutoencoderKLCogVideoX
# import decord
# from decord import VideoReader
from multiprocessing import Process, Queue, Value

In [8]:
def process_video(queue, progress_queue, vae_model_path, max_frames, width, height, gpu_id, output_dir, fps):
    """
    Process videos assigned to a specific GPU.
    """
    device = f"cuda:{gpu_id}"

    # Load the VAE model
    vae = AutoencoderKLCogVideoX.from_pretrained(vae_model_path, subfolder="vae")
    vae.to(device)
    vae.eval()

    while True:
        video_path = queue.get()
        if video_path is None:  # End signal
            break

        try:
            # Load video using Decord
            # decord.bridge.set_bridge("native")
            # vr = VideoReader(video_path, ctx=decord.cpu(0))

            # Calculate frame interval
            # original_fps = 8
            # frame_interval = int(original_fps / fps)

            # # Extract frames
            # # frames = vr.get_batch(range(0, min(len(vr), max_frames * frame_interval), frame_interval)).asnumpy()
            # frames = vr.get_batch(range(0, min(len(vr), max_frames * frame_interval), frame_interval)).asnumpy()

            # # Ensure exact number of frames
            # if frames.shape[0] < max_frames:
            #     pad_frames = max_frames - frames.shape[0]
            #     print('>> shorter than max_frames : doing padding')
            #     frames = np.pad(frames, ((0, pad_frames), (0, 0), (0, 0), (0, 0)), mode="constant")
            # elif frames.shape[0] > max_frames:
            #     frames = frames[:max_frames]
            # # Resize frames using OpenCV
            # frames = np.array([cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR) for frame in frames])
            frames = cv2.imread(video_path) 
            frames = cv2.resize(frames, (width, height), interpolation=cv2.INTER_LINEAR)
            frames = np.expand_dims(frames, axis=0)  # Add frame dimension# single frame so open with cv2
            # frames = np.array([cv2.resize(frames, (width, height), interpolation=cv2.INTER_LINEAR)])
            frames = np.array(frames)
            # Convert to torch tensor and preprocess
            frames = torch.from_numpy(frames).float() / 255.0 * 2.0 - 1.0  # Normalize [-1, 1]
            frames = frames.permute(0, 3, 1, 2)  # [F, H, W, C] -> [F, C, H, W]
            
            # Add batch dimension and permute for VAE
            frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4).to(device)  # [B, C, F, H, W]

            # Encode video to latent space
            with torch.no_grad():
                latent_dist = vae.encode(frames).latent_dist
                latents = latent_dist.sample() * vae.config.scaling_factor

            # Save latents
            output_path = os.path.join(output_dir, Path(video_path).stem + "_vae_latents.npy")
            np.save(output_path, latents.cpu().numpy())

        except Exception as e:
            print(f"Error processing video {video_path}: {e}")

        # Clear GPU memory
        torch.cuda.empty_cache()

        # Notify progress
        progress_queue.put(1)

In [9]:
def extract_vae_latents(
    video_paths, vae_model_path, output_dir, height=480, width=720, max_frames=49, fps=8
):
    """
    Extract VAE latents using multiple GPUs with controlled processes.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Detect available GPUs
    available_gpus = list(range(torch.cuda.device_count()))
    if not available_gpus:
        raise RuntimeError("No GPUs are available!")

    print(f"Using GPUs: {available_gpus}")

    # Create a process for each GPU
    queues = [Queue() for _ in available_gpus]
    progress_queue = Queue()
    processes = []

    for gpu_id, queue in zip(available_gpus, queues):
        process = Process(target=process_video, args=(queue, progress_queue, vae_model_path, max_frames, width, height, gpu_id, output_dir, fps))
        process.start()
        processes.append(process)

    # Distribute videos to queues
    for i, video_path in enumerate(video_paths):
        queues[i % len(available_gpus)].put(video_path)

    # Send termination signals
    for queue in queues:
        queue.put(None)

    # Track progress using tqdm
    with tqdm(total=len(video_paths), desc="Extracting VAE latents") as pbar:
        completed = 0
        while completed < len(video_paths):
            progress_queue.get()  # Wait for progress notification
            completed += 1
            pbar.update(1)

    # Wait for all processes to finish
    for process in processes:
        process.join()

In [10]:
video_dir1 = "output/left_images"
video_dir2 = "output/right_images"
video_paths = [os.path.join(video_dir1, f) for f in os.listdir(video_dir1) if f.endswith(".png")] # single frame video (image)
video_paths += [os.path.join(video_dir2, f) for f in os.listdir(video_dir2) if f.endswith(".png")][:10] # single frame video (image)

# Extract VAE latents
extract_vae_latents(
    video_paths,
    vae_model_path="THUDM/CogVideoX-5b",
    output_dir="output/latents",
    height=512,
    width=512,
    max_frames=1,
    # fps=8,
)

Using GPUs: [0, 1, 2, 3]


Extracting VAE latents:   0%|          | 54/239944 [00:11<14:00:16,  4.76it/s]


KeyboardInterrupt: 