In [3]:
# Install required packages if not already installed
# !pip install termcolor tqdm decord pytorchvideo

# Import necessary libraries
import os
import torch
import imageio
import numpy as np
from torchvision import transforms
from typing import Tuple, List, Dict
import logging
import json
from termcolor import colored
from tqdm import tqdm
from torch.utils.data import Dataset, Subset, DataLoader
from decord import VideoReader, cpu
from torchvision.transforms import Compose, Lambda
from pytorchvideo.transforms import ShortSideScale
from torchvision.transforms._transforms_video import CenterCropVideo

from cal_flolpips import calculate_flolpips
from cal_lpips import calculate_lpips
from cal_psnr import calculate_psnr
from cal_ssim import calculate_ssim

# Raise the threshold of the 'imageio_ffmpeg' logger so that only ERROR (and
# more severe) records from it are emitted; warnings are silenced.
logging.getLogger(name="imageio_ffmpeg").setLevel(level=logging.ERROR)

# Configuration Class
class Config:
    """
    Configuration class to manage parameters for evaluation operations.
    """

    # Accepted ``dtype`` names and the torch dtypes they map to.
    _DTYPES = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    # Metrics used when the caller does not pass ``metrics``.
    _DEFAULT_METRICS = ("ssim", "psnr", "lpips", "flolpips")

    def __init__(
        self,
        data_root: str,
        model_name: str = "cogvideox",
        device: str = "cuda",
        dtype: str = "float16",
        metrics: List[str] = None,
        batch_size: int = 2,
        num_workers: int = 4,
        num_frames: int = 100,
        sample_rate: int = 1,
        resolution: int = 128,
        crop_size: int = None,
        subset_size: int = None,
        fvd_method: str = "styleganv",
        output_json: str = "result.json"
    ):
        """
        Initializes the configuration with default or specified parameters.

        Parameters:
        - data_root (str): Root directory containing 'processed_gt_v2' and 'model_recon/<model_name>'.
        - model_name (str): Name of the VAE model (e.g., 'cogvideox').
        - device (str): Computation device ('cuda', 'cuda:0', 'cpu', etc.).
        - dtype (str | torch.dtype): Data type for computation ('float16', 'bfloat16'
          or 'float32'); a torch.dtype is stored as-is.
        - metrics (List[str]): Metrics to calculate. Defaults to
          ['ssim', 'psnr', 'lpips', 'flolpips']. The list is copied, so neither the
          caller's list nor the default is shared between Config instances.
        - batch_size (int): Batch size for DataLoader.
        - num_workers (int): Number of worker processes for DataLoader.
        - num_frames (int): Number of frames to sample from each video.
        - sample_rate (int): Sampling rate for frames.
        - resolution (int): Short side size for scaling videos.
        - crop_size (int): Size for center cropping. If None, no cropping is applied.
        - subset_size (int): If specified, process only a subset of the dataset.
        - fvd_method (str): Method for FVD calculation.
        - output_json (str): Filename for saving the evaluation results.

        Raises:
        - ValueError: If ``dtype`` is a string that is not a supported name.
        """
        self.data_root = data_root
        self.model_name = model_name
        self.device = device

        # Previously any unrecognised string silently became bfloat16; map explicitly.
        if isinstance(dtype, torch.dtype):
            self.dtype = dtype
        elif dtype in self._DTYPES:
            self.dtype = self._DTYPES[dtype]
        else:
            raise ValueError(
                f"Unsupported dtype {dtype!r}; expected one of {sorted(self._DTYPES)}"
            )

        # Copy so each instance owns its list (avoids the shared mutable default).
        self.metrics = list(metrics) if metrics is not None else list(self._DEFAULT_METRICS)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.resolution = resolution
        self.crop_size = crop_size
        self.subset_size = subset_size
        self.fvd_method = fvd_method
        self.output_json = output_json

        # Define paths
        self.real_video_base = os.path.join(self.data_root, "processed_gt_v2")
        self.generated_video_base = os.path.join(self.data_root, "model_recon", self.model_name)
        self.result_json_path = os.path.join(self.generated_video_base, self.output_json)

# VideoDataset Class
class VideoDataset(Dataset):
    """
    Map-style dataset yielding paired real / generated video clips.

    Both directories are scanned for ``.mp4`` files, each list is sorted by
    file name, and items are paired by position in those sorted lists.
    Each item is ``{"real": Tensor, "generated": Tensor}``; each tensor has
    layout (T, C, H, W) and is float32, scaled to [0, 1] before resizing.
    """

    def __init__(
        self,
        real_video_dir: str,
        generated_video_dir: str,
        num_frames: int,
        sample_rate: int = 1,
        crop_size: int = None,
        resolution: int = 128,
    ) -> None:
        super().__init__()
        self.real_video_files = self._get_sorted_videos(real_video_dir)
        self.generated_video_files = self._get_sorted_videos(generated_video_dir)
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.crop_size = crop_size
        self.short_size = resolution

        # Pairing is positional, so unequal counts would otherwise raise
        # IndexError part-way through evaluation; warn and truncate instead.
        num_real = len(self.real_video_files)
        num_generated = len(self.generated_video_files)
        if num_real != num_generated:
            print(
                colored(
                    f"Found {num_real} real videos in {real_video_dir} but "
                    f"{num_generated} generated videos in {generated_video_dir}; "
                    f"only the first {min(num_real, num_generated)} sorted pairs will be used.",
                    "yellow",
                )
            )

    def __len__(self):
        """Number of (real, generated) pairs available."""
        return min(len(self.real_video_files), len(self.generated_video_files))

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError
        real_video_tensor = self._load_video(self.real_video_files[index])
        generated_video_tensor = self._load_video(self.generated_video_files[index])
        return {"real": real_video_tensor, "generated": generated_video_tensor}

    def _load_video(self, video_path: str) -> torch.Tensor:
        """
        Decode frames from ``video_path`` and preprocess them.

        Frames are taken from the first ``num_frames * sample_rate`` frames of
        the video. If the video is shorter, the frame count is scaled down
        proportionally (but never below one frame).

        Raises:
        - ValueError: If the video contains no frames.
        """
        num_frames = self.num_frames
        sample_rate = self.sample_rate
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        if total_frames == 0:
            raise ValueError(f"No frames could be read from {video_path}")
        sample_frames_len = sample_rate * num_frames

        s = 0
        if total_frames >= sample_frames_len:
            e = s + sample_frames_len
        else:
            e = total_frames
            # Scale the frame count down with the available length; max(1, ...)
            # avoids producing an empty clip for very short videos.
            num_frames = max(1, int(total_frames / sample_frames_len * num_frames))
            print(
                colored(f"Sample_frames_len {sample_frames_len}, can only sample {num_frames * sample_rate} frames from {video_path}, total frames: {total_frames}", "yellow")
            )

        frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
        video_data = vr.get_batch(frame_id_list).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return self._preprocess(video_data)

    def _preprocess(self, video_data: torch.Tensor) -> torch.Tensor:
        """
        Scale pixel values to [0, 1], resize the short side, and optionally center-crop.

        CenterCropVideo cannot take ``crop_size=None`` (it unpacks the size), so
        the crop is applied only when a crop size is configured.
        """
        video = video_data / 255.0  # uint8 -> float32
        video = ShortSideScale(size=self.short_size)(video)
        if self.crop_size is not None:
            video = CenterCropVideo(crop_size=self.crop_size)(video)
        return video  # (T, C, H, W)

    def _get_sorted_videos(self, folder_path: str) -> List[str]:
        """
        Retrieves and sorts video files from a directory.

        Parameters:
        - folder_path (str): Path to the video directory.

        Returns:
        - List[str]: Full paths of the ``.mp4`` files (case-insensitive extension),
          sorted by file name. Full paths are required because the reader is
          opened with this value directly.
        """
        names = sorted(
            name for name in os.listdir(folder_path)
            if os.path.splitext(name)[1].lower() == ".mp4"
            and os.path.isfile(os.path.join(folder_path, name))
        )
        return [os.path.join(folder_path, name) for name in names]

# Custom Print Function with Colored Output
def print_metric_result(dataset: str, metrics: Dict[str, float]):
    """
    Print one dataset's metric scores, colour-coded by metric name.

    Parameters:
    - dataset (str): Dataset name shown in the bold cyan header line.
    - metrics (Dict[str, float]): Metric name -> score; each score is printed
      upper-cased with four decimal places. Unknown metric names are shown in white.
    """
    palette = {
        "ssim": "green",
        "psnr": "yellow",
        "lpips": "magenta",
        "flolpips": "magenta",
    }
    print(colored(f"\nDataset: {dataset}", "cyan", attrs=["bold"]))
    for name, value in metrics.items():
        shade = palette.get(name.lower(), "white")
        line = f"  {name.upper()}: {value:.4f}"
        print(colored(line, shade, attrs=["bold"]))

# calculate_common_metric Function
def calculate_common_metric(
    metrics: List[str],
    dataloader: DataLoader,
    device: torch.device
) -> Dict[str, float]:
    """
    Calculates the specified metrics for the given DataLoader.

    All supported metrics are computed in a single pass over the dataloader,
    so each video is decoded once rather than once per metric. Unsupported
    metric names are reported once and skipped without touching the data.

    Parameters:
    - metrics (List[str]): List of metrics to calculate (case-insensitive:
      'ssim', 'psnr', 'lpips', 'flolpips').
    - dataloader (DataLoader): Yields dicts with 'real' and 'generated' tensors.
      The dataset in this file produces per-item (T, C, H, W), i.e. batches of
      (B, T, C, H, W).
    - device (torch.device): Device to perform calculations on.

    Returns:
    - Dict[str, float]: Metric name (as given) -> mean of all per-sample values
      returned by the metric helper. Metrics that produced no values are omitted.

    Raises:
    - ValueError: If a real/generated batch pair does not have identical shapes.
    """
    print(colored(f"Calculating Metrics: {', '.join(metrics)}", "blue", attrs=["bold"]))

    # Lower-cased metric name -> callable(real, generated) returning the helper's result dict.
    calculators = {
        "ssim": lambda real, gen: calculate_ssim(real, gen),
        "psnr": lambda real, gen: calculate_psnr(real, gen),
        "lpips": lambda real, gen: calculate_lpips(real, gen, device),
        "flolpips": lambda real, gen: calculate_flolpips(real, gen, device),
    }

    supported = []
    for metric in dict.fromkeys(metrics):  # de-duplicate, keep order
        if metric.lower() in calculators:
            supported.append(metric)
        else:
            print(colored(f"Metric '{metric}' is not supported. Skipping.", "red"))
    if not supported:
        return {}

    score_lists = {metric: [] for metric in supported}
    progress_desc = f"Calculating {', '.join(m.upper() for m in supported)}"
    for batch_data in tqdm(dataloader, desc=progress_desc):
        real_videos = batch_data["real"].to(device)
        generated_videos = batch_data["generated"].to(device)
        # Pixel-wise metrics need identical (B, T, C, H, W); the old check compared
        # dim 2 (channels) while claiming to check the frame count.
        if real_videos.shape != generated_videos.shape:
            raise ValueError(
                "Shape mismatch between real and generated videos (frame count or "
                f"resolution): {tuple(real_videos.shape)} vs {tuple(generated_videos.shape)}"
            )

        for metric in supported:
            result = calculators[metric.lower()](real_videos, generated_videos)
            score_lists[metric].extend(result["value"].values())

    return {
        metric: float(np.mean(scores))
        for metric, scores in score_lists.items()
        if scores
    }

# Evaluator Class
class Evaluator:
    """
    Evaluator class to handle the calculation of metrics for video datasets.

    Each sub-directory of ``config.real_video_base`` is treated as a dataset and
    paired with the same-named sub-directory of ``config.generated_video_base``.
    Per-dataset scores accumulate in ``self.result``; their averages are kept in
    ``self.avg_result``.
    """
    def __init__(self, config: "Config"):
        self.config = config
        self.device = torch.device(self.config.device)
        self.metrics = config.metrics
        self.result = {}
        self.avg_result = {}

    def get_datasets(self) -> List[str]:
        """
        Retrieves the list of dataset names from the real_video_base directory.

        Returns:
        - List[str]: Sorted list of dataset directory names.
        """
        base = self.config.real_video_base
        datasets = [d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d))]
        return sorted(datasets)

    def calculate_metrics_for_dataset(self, dataset: str):
        """
        Calculates metrics for a single dataset and stores them in ``self.result``.

        Parameters:
        - dataset (str): Name of the dataset.
        """
        real_video_dir = os.path.join(self.config.real_video_base, dataset)
        generated_video_dir = os.path.join(self.config.generated_video_base, dataset)

        # A plain file with the dataset's name is not usable either, hence isdir.
        if not os.path.isdir(generated_video_dir):
            print(colored(f"Generated video directory does not exist for dataset: {dataset}. Skipping...", "red"))
            return

        dataset_obj = VideoDataset(
            real_video_dir=real_video_dir,
            generated_video_dir=generated_video_dir,
            num_frames=self.config.num_frames,
            sample_rate=self.config.sample_rate,
            crop_size=self.config.crop_size,
            resolution=self.config.resolution
        )

        if self.config.subset_size:
            # Clamp so a subset larger than the dataset does not raise IndexError.
            subset_len = min(self.config.subset_size, len(dataset_obj))
            dataset_obj = Subset(dataset_obj, indices=range(subset_len))

        dataloader = DataLoader(
            dataset_obj,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
            # Pinned host memory only helps host->GPU copies.
            pin_memory=self.device.type == "cuda"
        )

        metric_scores = calculate_common_metric(
            metrics=self.metrics,
            dataloader=dataloader,
            device=self.device
        )

        self.result[dataset] = metric_scores
        print_metric_result(dataset, metric_scores)

    def evaluate_all_datasets(self):
        """
        Iterates through all datasets, calculates metrics, then averages, saves and prints them.
        """
        datasets = self.get_datasets()
        if not datasets:
            print(colored("No datasets found in the real video base directory.", "red"))
            return

        for dataset in tqdm(datasets, desc="Processing Datasets"):
            self.calculate_metrics_for_dataset(dataset)

        self.calculate_average_metrics()
        self.save_results()
        self.print_average_results()

    def calculate_average_metrics(self):
        """
        Calculates the average of each metric across all datasets that reported it.

        Metrics with no score in any dataset are omitted rather than reported
        as 0.0 (a fake 0.0 SSIM/PSNR would be indistinguishable from a real score).
        Metric keys present in results but not in ``self.metrics`` are averaged too.
        """
        if not self.result:
            print(colored("No results to average.", "red"))
            self.avg_result = {}
            return

        # Seed with configured metrics so their order is preserved in the output.
        collected = {metric: [] for metric in self.metrics}
        for dataset_metrics in self.result.values():
            for metric, score in dataset_metrics.items():
                collected.setdefault(metric, []).append(score)

        self.avg_result = {
            metric: sum(scores) / len(scores)
            for metric, scores in collected.items()
            if scores
        }

    def print_average_results(self):
        """
        Prints the average metrics across all datasets with colored output.
        """
        print(colored("\nAverage Metrics Across All Datasets:", "cyan", attrs=["bold"]))
        for metric, score in self.avg_result.items():
            if metric.lower() in ["ssim", "psnr"]:
                color = "green" if metric.lower() == "ssim" else "yellow"
            elif metric.lower() in ["lpips", "flolpips"]:
                color = "magenta"
            else:
                color = "white"
            print(colored(f"  {metric.upper()}: {score:.4f}", color, attrs=["bold"]))

    def save_results(self):
        """
        Saves the per-dataset and average metrics to ``config.result_json_path`` as JSON.
        """
        output_dict = {
            "per_dataset": self.result,
            "average": self.avg_result
        }

        # os.makedirs('') raises, so only create a directory when the path has one.
        output_dir = os.path.dirname(self.config.result_json_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(self.config.result_json_path, "w", encoding="utf-8") as f:
            json.dump(output_dict, f, indent=4)

        print(colored(f"\nResults saved to {self.config.result_json_path}", "green"))


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [on]
Loading model from: /home/maij/miniforge3/envs/sora/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth


In [None]:
# Evaluation settings. data_root and device are machine-specific: point
# data_root at the benchmark directory and choose a device that exists here.
config = Config(
    data_root="/home/maij/fall_2024/sora3r/Open-Sora/data/vae_eval_bench",
    model_name="cogvideox",
    output_json="result.json",
    # Hardware / numerics
    device="cuda:1",   # e.g. "cuda", "cuda:0" or "cpu"
    dtype="float16",   # or "bfloat16"
    # What to measure
    metrics=["ssim", "psnr", "lpips", "flolpips"],
    fvd_method="styleganv",
    # Data loading and frame sampling
    batch_size=2,
    num_workers=4,
    num_frames=100,
    sample_rate=1,
    resolution=128,
    crop_size=None,
    subset_size=None,
)

# Evaluate every dataset directory, then save the JSON report and print averages.
evaluator = Evaluator(config)
evaluator.evaluate_all_datasets()
