In [1]:
# * Download the pretrained tokenizer checkpoints; only the continuous 4x8x8 video tokenizer (CV4x8x8) is evaluated here.

from huggingface_hub import login, snapshot_download
import os

# login(token="<YOUR-HF-TOKEN>", add_to_git_credential=True)
# Root directory that receives one sub-directory per downloaded checkpoint.
# Override with the COSMOS_PRETRAINED_ROOT environment variable instead of editing the path.
PRETRAINED_ROOT = os.environ.get(
    "COSMOS_PRETRAINED_ROOT",
    "/home/maij/fall_2024/sora3r/Pollux/tvae_bench/pretrained_models",
)

# Uncomment the tokenizers to download; each name is a repo under the "nvidia" HF org.
model_names = [
        # "Cosmos-Tokenizer-CI8x8",
        # "Cosmos-Tokenizer-CI16x16",
        "Cosmos-Tokenizer-CV4x8x8",
        # "Cosmos-Tokenizer-CV8x8x8",
        # "Cosmos-Tokenizer-CV8x16x16",
        # "Cosmos-Tokenizer-DI8x8",
        # "Cosmos-Tokenizer-DI16x16",
        # "Cosmos-Tokenizer-DV4x8x8",
        # "Cosmos-Tokenizer-DV8x8x8",
        # "Cosmos-Tokenizer-DV8x16x16",
]
for model_name in model_names:
    hf_repo = "nvidia/" + model_name
    local_dir = os.path.join(PRETRAINED_ROOT, model_name)
    os.makedirs(local_dir, exist_ok=True)
    print(f"downloading {model_name}...")
    snapshot_download(repo_id=hf_repo, local_dir=local_dir)


downloading Cosmos-Tokenizer-CV4x8x8...


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

model_config.yaml:   0%|          | 0.00/92.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/21.8k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.67k [00:00<?, ?B/s]

decoder.jit:   0%|          | 0.00/88.1M [00:00<?, ?B/s]

encoder.jit:   0%|          | 0.00/61.0M [00:00<?, ?B/s]

autoencoder.jit:   0%|          | 0.00/149M [00:00<?, ?B/s]

In [None]:
%%bash
# Run as a single bash script: bare shell lines are a SyntaxError in a Python cell,
# and `cd` would not persist across separate `!` lines.
set -e
# Clone only when missing so the cell can be re-run safely.
[ -d Cosmos-Tokenizer ] || git clone https://github.com/NVIDIA/Cosmos-Tokenizer.git
cd Cosmos-Tokenizer
# System package install; requires root (prefix with sudo otherwise).
apt-get install -y ffmpeg
# NOTE: pip3 here is whichever is first on PATH. If it is not the kernel's interpreter,
# run `%pip install -e Cosmos-Tokenizer` in a separate cell instead.
pip3 install -e .

In [None]:
import torch
from cosmos_tokenizer.video_lib import CausalVideoTokenizer

model_name = "Cosmos-Tokenizer-CV4x8x8"
model_root = "/home/maij/fall_2024/sora3r/Pollux/tvae_bench/pretrained_models"

# Shape smoke test on a random clip: 9 frames at 512x512 are expected to encode to a
# (1, 16, 3, 64, 64) latent, i.e. 8x spatial downsampling and 3 latent frames
# (consistent with 1 + (T - 1) / 4 for the 4x temporal factor).
input_tensor = torch.randn(1, 3, 9, 512, 512).to('cuda').to(torch.bfloat16)  # [B, C, T, H, W]
expected_latent_shape = (1, 16, 3, 64, 64)

with torch.no_grad():  # inference only; no autograd graph needed
    encoder = CausalVideoTokenizer(checkpoint_enc=f'{model_root}/{model_name}/encoder.jit')
    (latent,) = encoder.encode(input_tensor)
    assert tuple(latent.shape) == expected_latent_shape, (
        f"latent shape {tuple(latent.shape)} != expected {expected_latent_shape}")

    # The input tensor can be reconstructed by the decoder, which must restore the input shape.
    decoder = CausalVideoTokenizer(checkpoint_dec=f'{model_root}/{model_name}/decoder.jit')
    reconstructed_tensor = decoder.decode(latent)
    assert reconstructed_tensor.shape == input_tensor.shape, (
        f"reconstruction shape {tuple(reconstructed_tensor.shape)} != input {tuple(input_tensor.shape)}")

In [1]:
import os
import torch
import imageio
import numpy as np
from torchvision import transforms
from typing import Tuple, List, Dict, Type
import logging

# Drop log records below ERROR from the 'imageio_ffmpeg' logger, so its warnings are not shown.
# NOTE(review): imageio-ffmpeg appears to report, as a warning, when it resizes frames on write
# (e.g. to a multiple of its macro_block_size). Silencing warnings would hide that, and
# save_videos only checks the tensor resolution, not the written file's; confirm output sizes are unaffected.
logging.getLogger('imageio_ffmpeg').setLevel(logging.ERROR)

# Configuration Class
class Config:
    """Settings for one batch video-reconstruction run.

    Args:
        model_class: Class exposing ``from_pretrained``; used when ``model_type == "general"``.
            May be ``None`` for ``"cosmos"``.
        model_path: For ``"general"``, passed to ``model_class.from_pretrained``. For ``"cosmos"``,
            the directory containing ``encoder.jit`` and ``decoder.jit``.
        device: Torch device string, e.g. ``"cuda"`` or ``"cuda:1"``.
        dtype: Compute dtype, as a ``torch.dtype`` or one of the names in ``_DTYPE_ALIASES``
            (``"float16"``, ``"bfloat16"``, ``"float32"`` and short aliases).
        batch_size: Default number of videos per forward pass (must be >= 1).
        custom_batch_size: Per-dataset batch-size overrides. A key applies when it is a
            substring of any video path in the list being reconstructed.
        source_base: Root directory of the input videos.
        output_base: Root directory for the reconstructed videos.
        model_type: ``"general"`` or ``"cosmos"``.

    Raises:
        ValueError: If ``dtype``, ``model_type`` or ``batch_size`` is invalid.
    """

    # Explicit dtype spellings. Previously any string other than "float16" silently became
    # bfloat16 (e.g. "float32" -> bfloat16).
    _DTYPE_ALIASES = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
    }
    _MODEL_TYPES = ("general", "cosmos")

    def __init__(
        self,
        model_class: Type,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16",
        batch_size: int = 1,
        custom_batch_size: Dict[str, int] = None,
        source_base: str = "../resources/videos/",
        output_base: str = "./output_videos/",
        model_type: str = "general"  # "general" or "cosmos"
    ):
        if model_type not in self._MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {self._MODEL_TYPES}.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}.")

        self.model_class = model_class
        self.model_path = model_path
        self.device = device
        self.dtype = self._resolve_dtype(dtype)
        self.batch_size = batch_size
        # Copy so later mutation of the caller's dict cannot change this config.
        self.custom_batch_size = dict(custom_batch_size or {})
        self.source_base = source_base
        self.output_base = output_base
        self.model_type = model_type

    @classmethod
    def _resolve_dtype(cls, dtype) -> torch.dtype:
        """Map a ``torch.dtype`` or a dtype name to a ``torch.dtype``; raise ValueError if unknown."""
        if isinstance(dtype, torch.dtype):
            return dtype
        try:
            return cls._DTYPE_ALIASES[str(dtype).lower()]
        except KeyError:
            raise ValueError(
                f"Unsupported dtype {dtype!r}; expected one of {sorted(cls._DTYPE_ALIASES)}."
            ) from None

# GeneralAutoEncoderKL Class
class GeneralAutoEncoderKL:
    """Reconstruct videos with either a diffusers-style VAE or a Cosmos causal video tokenizer.

    ``config.model_type == "general"`` loads ``config.model_class.from_pretrained(config.model_path)``
    and enables slicing and tiling. ``"cosmos"`` loads ``encoder.jit`` and ``decoder.jit`` from the
    ``config.model_path`` directory through ``CausalVideoTokenizer``.

    Pixel tensors produced by ``preprocess_videos`` and consumed by ``save_videos`` have layout
    ``[B, C, T, H, W]`` with values in [0, 1] (from ``transforms.ToTensor``).
    """

    # The Cosmos tokenizer documents its encoder input and decoder output as [-1, 1]. Diffusers
    # VAEs are normally fed [-1, 1] as well (confirm per model). When True, reconstruct_videos maps
    # [0, 1] -> [-1, 1] before encoding and back afterwards. Set to False to reproduce the old
    # behaviour of feeding raw [0, 1] tensors.
    rescale_to_model_range: bool = True

    def __init__(self, config: "Config"):
        """Load the model(s) described by ``config`` and create ``config.output_base``.

        Raises:
            ValueError: If ``config.model_type`` is neither ``"general"`` nor ``"cosmos"``.
        """
        self.config = config
        self.device = torch.device(self.config.device)
        self.dtype = self.config.dtype

        model_type = self.config.model_type
        if model_type not in ("general", "cosmos"):
            # Fail early: otherwise no model attribute is set and encode/decode break later.
            raise ValueError(f"Unsupported model_type {model_type!r}; expected 'general' or 'cosmos'.")

        os.makedirs(self.config.output_base, exist_ok=True)

        if model_type == "general":
            self.model = self.config.model_class.from_pretrained(
                self.config.model_path,
                torch_dtype=self.dtype
            ).to(self.device)
            self.model.enable_slicing()
            self.model.enable_tiling()
        else:
            from cosmos_tokenizer.video_lib import CausalVideoTokenizer
            # NOTE(review): CausalVideoTokenizer may also accept device/dtype constructor args;
            # loading and then calling .to() may first place weights on its default device. Confirm.
            self.encoder = CausalVideoTokenizer(
                checkpoint_enc=f'{self.config.model_path}/encoder.jit'
            ).to(self.device)
            self.decoder = CausalVideoTokenizer(
                checkpoint_dec=f'{self.config.model_path}/decoder.jit'
            ).to(self.device)

        print(f"Model loaded successfully from {self.config.model_path}.")

        self.transform = transforms.ToTensor()

    def preprocess_videos(self, video_paths: List[str]) -> Tuple[torch.Tensor, List[float], List[int], List[Tuple[int, int]]]:
        """Read videos into one ``[B, C, T, H, W]`` tensor in [0, 1], on ``self.device`` in ``self.dtype``.

        Returns:
            ``(batch_tensor, fps_list, num_frames_list, resolutions)``, where each resolution is ``(H, W)``.

        Raises:
            ValueError: If ``video_paths`` is empty, a video has no frames, or the videos differ
                in frame count or resolution (they cannot be stacked into one batch).
        """
        if not video_paths:
            raise ValueError("video_paths must not be empty")

        batch_frames = []
        fps_list, num_frames_list, resolutions = [], [], []

        for video_path in video_paths:
            video_reader = imageio.get_reader(video_path, "ffmpeg")
            try:
                # `or 30` also covers metadata that carries fps=None.
                fps = video_reader.get_meta_data().get('fps') or 30
                frames = [self.transform(frame) for frame in video_reader]
            finally:
                video_reader.close()  # release the ffmpeg process even if decoding fails

            if not frames:
                raise ValueError(f"No frames found in video: {video_path}")

            # ToTensor yields (C, H, W) per frame: stack to (T, C, H, W), then permute to (C, T, H, W).
            frames_tensor = torch.stack(frames).permute(1, 0, 2, 3)
            if batch_frames and frames_tensor.shape != batch_frames[0].shape:
                raise ValueError(
                    f"Cannot batch videos with different shapes: {video_paths[0]} is "
                    f"{tuple(batch_frames[0].shape)} but {video_path} is {tuple(frames_tensor.shape)}; "
                    "use Config.custom_batch_size to process this dataset one video at a time.")

            fps_list.append(fps)
            num_frames_list.append(frames_tensor.shape[1])
            resolutions.append((frames_tensor.shape[2], frames_tensor.shape[3]))
            # Cast before the transfer so the float32 copy never lands on the device.
            batch_frames.append(frames_tensor.to(device=self.device, dtype=self.dtype))

        batch_tensor = torch.stack(batch_frames)
        return batch_tensor, fps_list, num_frames_list, resolutions

    def encode(self, frames_tensor: torch.Tensor) -> torch.Tensor:
        """Encode a ``[B, C, T, H, W]`` tensor (in the model's input range) to latents, without autograd."""
        with torch.no_grad():
            if self.config.model_type == "general":
                # NOTE(review): `.sample()` presumably draws from the posterior, so reconstructions
                # are stochastic; `.mode()` would be deterministic. Confirm which the benchmark wants.
                return self.model.encode(frames_tensor)[0].sample()
            if self.config.model_type == "cosmos":
                (encoded_frames,) = self.encoder.encode(frames_tensor)
                return encoded_frames
        raise ValueError(f"Unsupported model_type {self.config.model_type!r}")

    def decode(self, encoded_tensor: torch.Tensor) -> torch.Tensor:
        """Decode latents back to a pixel tensor (in the model's output range), without autograd."""
        with torch.no_grad():
            if self.config.model_type == "general":
                return self.model.decode(encoded_tensor).sample
            if self.config.model_type == "cosmos":
                return self.decoder.decode(encoded_tensor)
        raise ValueError(f"Unsupported model_type {self.config.model_type!r}")

    def save_videos(self, tensor: torch.Tensor, output_paths: List[str], fps_list: List[float],
                    num_frames_list: List[int], resolutions: List[Tuple[int, int]]):
        """Write each ``[C, T, H, W]`` clip of ``tensor`` (values in [0, 1]) as an H.264 video.

        Raises:
            AssertionError: If a clip's frame count or ``(H, W)`` differs from the recorded input.
        """
        tensor = tensor.to(dtype=torch.float32)
        for i, (fps, num_frames, resolution, output_path) in enumerate(zip(fps_list, num_frames_list, resolutions, output_paths)):
            frames = tensor[i].permute(1, 2, 3, 0).cpu().numpy()  # (C, T, H, W) -> (T, H, W, C)
            # Round instead of truncating so quantization is unbiased.
            frames = np.round(np.clip(frames, 0.0, 1.0) * 255.0).astype(np.uint8)

            # Explicit raises (not `assert`) so the checks survive `python -O`.
            num_output_frames = frames.shape[0]
            if num_output_frames != num_frames:
                raise AssertionError(
                    f"Frame count mismatch: input {num_frames} vs output {num_output_frames}")

            output_resolution = frames.shape[1], frames.shape[2]
            if output_resolution != tuple(resolution):
                raise AssertionError(
                    f"Resolution mismatch: input {resolution} vs output {output_resolution}")

            writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                writer.close()  # finalize/flush the file even if a write fails

            print(f"Saved video to {output_path} with {num_output_frames} frames at {output_resolution} resolution and {fps} fps.")

    def reconstruct_videos(self, video_paths: List[str], output_paths: List[str]):
        """Encode and decode every video in ``video_paths``, writing results to the matching ``output_paths``.

        The batch size is ``config.batch_size`` unless a ``config.custom_batch_size`` key is a
        substring of any path, in which case the first such override wins.

        Raises:
            ValueError: If the two path lists differ in length.
        """
        if len(video_paths) != len(output_paths):
            raise ValueError(
                f"Got {len(video_paths)} video paths but {len(output_paths)} output paths.")

        batch_size = next(
            (size for dataset_name, size in self.config.custom_batch_size.items()
             if any(dataset_name in path for path in video_paths)),
            self.config.batch_size,
        )

        for start in range(0, len(video_paths), batch_size):
            batch_video_paths = video_paths[start:start + batch_size]
            batch_output_paths = output_paths[start:start + batch_size]

            frames_tensor, fps_list, num_frames_list, resolutions = self.preprocess_videos(batch_video_paths)
            if self.rescale_to_model_range:
                frames_tensor = frames_tensor * 2 - 1  # [0, 1] -> [-1, 1]
            encoded = self.encode(frames_tensor)
            decoded = self.decode(encoded)
            if self.rescale_to_model_range:
                decoded = (decoded.float() + 1) / 2  # [-1, 1] -> [0, 1] for save_videos
            self.save_videos(decoded, batch_output_paths, fps_list, num_frames_list, resolutions)
            del frames_tensor, encoded, decoded

In [2]:
# * Testing code on H100 server
import os
import time
from datetime import datetime, timedelta

# Set the Hugging Face home directory (cache root) to the shared /jfs volume.
# NOTE(review): huggingface_hub appears to read HF_HOME when it is first imported. If it (or a
# library that imports it) was already imported in this kernel, this assignment likely has no effect. Confirm.
os.environ["HF_HOME"] = "/jfs/jinjie/huggingface"
def format_timedelta(td):
    """Format a ``timedelta`` as an ``HH:MM:SS`` string.

    Fractional seconds are truncated toward zero. Hours are not wrapped at 24
    (1 day 2 h -> ``"26:00:00"``). Negative durations get a leading ``-``; previously
    ``divmod`` on a negative count produced garbled output such as ``"-1:59:59"``.
    """
    total_seconds = int(td.total_seconds())
    sign = "-" if total_seconds < 0 else ""
    hours, remainder = divmod(abs(total_seconds), 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{sign}{hours:02d}:{minutes:02d}:{seconds:02d}"

from cosmos_tokenizer.video_lib import CausalVideoTokenizer

# Processing Script
if __name__ == "__main__":

    data_root = "/jfs/jinjie"

    config = Config(
        model_class=None,  # Not needed for Cosmos
        # Directory holding encoder.jit / decoder.jit (required for the Cosmos path).
        model_path=f"{data_root}/huggingface/pretrained_models/Cosmos-Tokenizer-CV4x8x8",
        device="cuda:1",
        dtype='bfloat16',
        batch_size=2,
        # * Variable-resolution datasets: a batch is stacked into one tensor, so these
        #   datasets must be processed one video at a time.
        custom_batch_size={'imagenet_val': 1, 'textocr': 1, 'bridgedata_v2': 1, 'panda_70m': 1, 'real10k': 1},
        source_base=f"{data_root}/data/vae_eval_bench/processed_gt_v3",
        output_base=f"{data_root}/data/vae_eval_bench/model_recon/cosmos",
        model_type="cosmos",  # Specify Cosmos as the model type
    )

    autoencoder = GeneralAutoEncoderKL(config)

    # os.listdir order is arbitrary; sort so runs (and their logs) are reproducible.
    for dataset in sorted(os.listdir(config.source_base)):

        dataset_source_path = os.path.join(config.source_base, dataset)
        dataset_output_path = os.path.join(config.output_base, dataset)

        if not os.path.isdir(dataset_source_path):
            continue

        print(f"Processing dataset: {dataset}")
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
        start_time = time.perf_counter()

        os.makedirs(dataset_output_path, exist_ok=True)

        video_files = sorted(f for f in os.listdir(dataset_source_path) if f.endswith('.mp4'))
        video_paths = [os.path.join(dataset_source_path, f) for f in video_files]
        output_paths = [os.path.join(dataset_output_path, f) for f in video_files]

        autoencoder.reconstruct_videos(video_paths, output_paths)

        formatted_time = format_timedelta(timedelta(seconds=time.perf_counter() - start_time))

        print(f"Finished processing {dataset}. Time taken: {formatted_time}")

    print("All datasets have been processed successfully.")


Model loaded successfully from /home/maij/fall_2024/sora3r/Pollux/tvae_bench/pretrained_models/Cosmos-Tokenizer-CV4x8x8.
Processing dataset: ego-exo-4d-ego
Saved video to /home/maij/fall_2024/sora3r/Open-Sora/data/vae_eval_bench/model_recon/cosmos/ego-exo-4d-ego/0000.mp4 with 105 frames at (448, 448) resolution and 30.0 fps.
Saved video to /home/maij/fall_2024/sora3r/Open-Sora/data/vae_eval_bench/model_recon/cosmos/ego-exo-4d-ego/0002.mp4 with 105 frames at (448, 448) resolution and 30.0 fps.
Saved video to /home/maij/fall_2024/sora3r/Open-Sora/data/vae_eval_bench/model_recon/cosmos/ego-exo-4d-ego/0004.mp4 with 105 frames at (448, 448) resolution and 30.0 fps.
Saved video to /home/maij/fall_2024/sora3r/Open-Sora/data/vae_eval_bench/model_recon/cosmos/ego-exo-4d-ego/0013.mp4 with 105 frames at (448, 448) resolution and 30.0 fps.
Saved video to /home/maij/fall_2024/sora3r/Open-Sora/data/vae_eval_bench/model_recon/cosmos/ego-exo-4d-ego/0015.mp4 with 105 frames at (448, 448) resolution and

KeyboardInterrupt: 