In [31]:
import os
import sys
import ffmpeg
import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.transforms.functional import resize
from pytorch_msssim import ms_ssim, ssim
from tqdm import tqdm
from PIL import Image
import lpips

# Locations of the external repository checkouts this notebook imports from.
dcvc_path = os.path.abspath("/h/lkcai/code/video-perception/DCVC_HEM")
hific_path = os.path.abspath("/h/lkcai/code/video-perception/hific")
thirdparty_repo_path = os.path.abspath("/h/lkcai/code/video-perception/video_quality_metrics")

# Prepend each checkout to the import search path. Paths already present are
# skipped, so re-running the cell does not create duplicate entries.
for repo_path in (dcvc_path, hific_path, thirdparty_repo_path):
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)

from hific.compress import prepare_model, prepare_dataloader, \
    compress_and_save, load_and_decompress, compress_and_decompress

from UVG1 import UVG
from ssf_model import ScaleSpaceFlow
from DCVC_HEM.src.models.video_model import DMC
from DCVC_HEM.src.utils.stream_helper import get_padding_size, get_state_dict

from video_quality_metrics.calculate_fvd import calculate_fvd


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# and the fallback has to be the string 'cpu', not an undefined name.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# IPython shell escape: record the GPU/driver status in the notebook output.
!nvidia-smi

Wed Mar 26 15:43:58 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            On   | 00000000:86:00.0 Off |                    0 |
| N/A   42C    P0    28W /  70W |  14519MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [2]:
def load_ssf_model(model, pre_path):
    """Restore the five per-component checkpoints of a model from a directory.

    Args:
        model: Object exposing ``motion_encoder``, ``motion_decoder``,
            ``P_encoder``, ``res_encoder`` and ``res_decoder`` sub-modules.
        pre_path: Directory containing ``m_enc.pth``, ``m_dec.pth``,
            ``p_enc.pth``, ``r_enc.pth`` and ``r_dec.pth``. A trailing slash
            is accepted.

    Returns:
        The same ``model`` object; its sub-modules are updated in place.
    """
    component_files = (
        ('motion_encoder', 'm_enc.pth'),
        ('motion_decoder', 'm_dec.pth'),
        ('P_encoder', 'p_enc.pth'),
        ('res_encoder', 'r_enc.pth'),
        ('res_decoder', 'r_dec.pth'),
    )
    for attr_name, file_name in component_files:
        # Deserialize onto the CPU: load_state_dict copies the values into the
        # sub-module's existing parameters, so the model stays on its current
        # device, while GPU-saved checkpoints still load on CPU-only hosts and
        # no temporary copy is staged in GPU memory.
        state = torch.load(os.path.join(pre_path, file_name), map_location='cpu')
        getattr(model, attr_name).load_state_dict(state)
    return model

def hwc_tonp(x):
    """Convert a 4-D tensor to a NumPy array with axis 1 moved to the end.

    The tensor is detached and copied to host memory, then its axes are
    reordered as (0, 2, 3, 1).
    """
    host_array = x.detach().cpu().numpy()
    return np.transpose(host_array, (0, 2, 3, 1))

In [3]:
def PSNR(input1, input2):
    """Peak signal-to-noise ratio in dB, computed for a peak value of 1.

    The mean squared error is taken over every element of the two tensors
    and the result is returned as a Python float.
    """
    squared_error = (input1 - input2).pow(2)
    rmse = torch.sqrt(squared_error.mean())
    return (20 * torch.log10(1 / rmse)).item()

def MS_SSIM(v1, v2):
    """Batch-averaged structural similarity of two inputs, as a Python float.

    Both inputs are shifted with ``(v + 1) * 0.5`` (i.e. from [-1, 1] to
    [0, 1]) before being compared with ``data_range=1``.

    NOTE(review): despite the name, this calls the single-scale ``ssim``
    rather than ``ms_ssim`` — confirm which metric is intended.
    """
    rescaled = [(frames + 1) * 0.5 for frames in (v1, v2)]
    score = ssim(*rescaled, data_range=1, size_average=True)
    return score.item()

# NOTE: the variable keeps its historical name, but the backbone is AlexNet.
lpips_vgg = lpips.LPIPS(net='alex').to(device)


def LPIPS(img1, img2):
    """Mean LPIPS distance between two image batches given in [-1, 1].

    Args:
        img1, img2: Image tensors scaled to [-1, 1], which is the input range
            the LPIPS network expects when ``normalize`` is left at its
            default of ``False``.

    Returns:
        The distance averaged over all returned elements, as a Python float.
    """
    # The inputs are passed through unchanged. Rescaling them to [0, 1] first
    # (as was done previously) feeds the network a range it only supports
    # with normalize=True and therefore skews the reported distance.
    with torch.no_grad():
        # A distinct local name avoids shadowing the imported `lpips` module.
        distance = lpips_vgg(img1, img2)

    return distance.mean().item()

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /h/lkcai/anaconda3/envs/video_perc/lib/python3.6/site-packages/lpips/weights/v0.1/alex.pth


In [4]:
# Checkpoint of the DCVC-HEM model used for the evaluation below.
model_path = './DCVC_HEM/checkpoints/acmmm2022_video_ssim.pth.tar'

# Quantization-scale tables obtained from the checkpoint; they are indexed by
# `rate_idx` in run_dcvc_test.
p_frame_y_q_scales, p_frame_mv_y_q_scales = DMC.get_q_scales_from_ckpt(model_path)

# Build the network, restore its weights, then move it to the target device
# and switch it to inference mode.
p_state_dict = get_state_dict(model_path)
video_net = DMC()
video_net.load_state_dict(p_state_dict)
video_net = video_net.to(device).eval()

def run_dcvc_test(x, rate_idx=0):
    """Code frames 1 and 2 of ``x`` with ``video_net``, seeded by frame 3.

    Args:
        x: Tensor indexed as ``x[:, i, ...]``. Index 1 and 2 are the frames
            to code, index 3 is the initial reference frame.
        rate_idx: Index into the module-level q-scale tables.

    Returns:
        ``(x2_hat, x3_hat, avg_bbp_2, avg_bbp_3)``: the two reconstructions
        clamped to [0, 1], followed by the ``'bit'`` value reported for each
        of the two coded frames.
    """
    y_q_scale = p_frame_y_q_scales[rate_idx]
    mv_y_q_scale = p_frame_mv_y_q_scales[rate_idx]

    targets = (x[:, 1, ...], x[:, 2, ...])
    reconstructions = []
    bit_logs = []

    with torch.no_grad():
        height, width = targets[0].shape[2], targets[0].shape[3]
        # Reference buffer handed to the codec, seeded with the supplied frame.
        dpb = {
            "ref_frame": x[:, 3, ...], "ref_feature": None, "ref_y": None, "ref_mv_y": None,
        }
        for frame in targets:
            result = video_net.encode_decode(
                frame, dpb, None,
                pic_height=height, pic_width=width,
                mv_y_q_scale=mv_y_q_scale,
                y_q_scale=y_q_scale
            )
            # The returned buffer becomes the reference for the next frame.
            dpb = result["dpb"]
            # In-place clamp: the reference stored in `dpb` is clamped as well.
            reconstructions.append(dpb["ref_frame"].clamp_(0, 1))
            bit_logs.append([result['bit']])

    x2_hat, x3_hat = reconstructions
    avg_bbp_2, avg_bbp_3 = (sum(bits) / len(bits) for bits in bit_logs)

    return x2_hat, x3_hat, avg_bbp_2, avg_bbp_3

In [17]:
# Preprocessing handed to the dataset: tensor conversion, then a random
# 256-pixel crop (the evaluation loop below hard-codes 256x256 frames).
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.RandomCrop(256)])

uvg_dataset = UVG("./data/uvg/", train_transforms)

# NOTE(review): shuffling and random cropping are not seeded in this cell, so
# the sampled batches presumably differ between runs — confirm this is intended.
loader_options = dict(batch_size=5, num_workers=0, shuffle=True, pin_memory=True)
uvg_dataloader = DataLoader(uvg_dataset, **loader_options)

In [6]:
# Selects which AR checkpoint directory (``AR_<l_AR>``) is loaded below.
l_AR = 0.08

ssf_JD = ScaleSpaceFlow().to(device)
# map_location='cpu' lets a GPU-saved checkpoint load on any host;
# load_state_dict then copies the values into the model on `device`.
ssf_JD.load_state_dict(
    torch.load('./saved_models/vimeo-90k/JD/ssf_uvg_JD.pth', map_location='cpu')
)

ssf_AR = ScaleSpaceFlow().to(device)
ssf_AR = load_ssf_model(ssf_AR, f'./saved_models/vimeo-90k/AR_{l_AR}/')

ssf_MSE = ScaleSpaceFlow().to(device)
# Plain string: this path has no placeholders, so no f-prefix is needed.
ssf_MSE = load_ssf_model(ssf_MSE, './saved_models/vimeo-90k/mse/')

In [7]:
# Checkpoint and log locations passed to the HiFiC `prepare_model` helper.
hific_model_path = "./saved_models/vimeo-90k/FMD/hific_hi.pt"
hific_log_path = "./hific/log"

# `prepare_model` returns two values: the model, used below as `fmd`, and a
# second object kept as `args`.
fmd, args = prepare_model(hific_model_path, hific_log_path)

15:01:11 INFO - logger_setup: /fs01/home/lkcai/code/video-perception/hific/compress.py


Building prior probability tables...


100%|██████████| 64/64 [00:00<00:00, 250.98it/s]


Setting up Perceptual loss...


15:01:17 INFO - load_model: Loading model ...
15:01:17 INFO - load_model: Estimated model size (under fp32): 725.903 MB
15:01:17 INFO - load_model: Model init 5.602s


Loading model from: /h/lkcai/code/video-perception/hific/src/loss/perceptual_similarity/weights/v0.1/alex.pth
...[net-lin [alex]] initialized
...Done


15:01:17 INFO - prepare_model: Model loaded from disk.
15:01:17 INFO - prepare_model: Building hyperprior probability tables...
100%|██████████| 320/320 [00:00<00:00, 1215.16it/s]
15:01:28 INFO - prepare_model: All tables built.


In [37]:
### Evaluate FVD
# One FVD value is appended per evaluated batch for each codec.
mse_eval = {'fvd': []}
jd_eval = {'fvd': []}
fmd_eval = {'fvd': []}
ar_eval = {'fvd': []}
dcvc_eval = {'fvd': []}

dcvc_bbp_x2 = []
dcvc_bbp_x3 = []

# Only the first `length` batches of the loader are evaluated.
length = 10
test_loader = itertools.islice(uvg_dataloader, length)

for net in (ssf_JD, ssf_AR, ssf_MSE, fmd):
    net.eval()


def _as_fvd_clip(second, third):
    """Stack two frame batches on a new axis 1, then fold batch and frame axes
    into a single clip of shape (1, 2 * B, 3, 256, 256)."""
    return torch.stack([second, third], dim=1).view(1, -1, 3, 256, 256)


for data in tqdm(test_loader):
    with torch.no_grad():
        x = data[:, :4, ...].to(device)
        # Index 0 (the original first frame) is not used here. Frames are
        # mapped from [0, 1] to [-1, 1] for the SSF and HiFiC models.
        x2 = 2 * (data[:, 1, ...] - 0.5).to(device)
        x3 = 2 * (data[:, 2, ...] - 0.5).to(device)
        x1_hat = 2 * (data[:, 3, ...] - 0.5).to(device)  # low-rate first frame

        # Reference clip: the two ground-truth frames in [-1, 1].
        # NOTE(review): every clip passed to calculate_fvd is in [-1, 1] —
        # confirm that this is the value range the metric expects.
        x_fvd = (2 * (x[:, 1:3, ...] - 0.5)).view(1, -1, 3, 256, 256)

        # Each SSF model codes frame 2 from the low-rate frame, then frame 3
        # from its own reconstruction of frame 2.
        x2_hat_JD = ssf_JD([x1_hat, x2])
        x3_hat_JD = ssf_JD([x2_hat_JD, x3])
        x_fvd_JD = _as_fvd_clip(x2_hat_JD, x3_hat_JD)

        x2_hat_AR = ssf_AR([x1_hat, x2])
        x3_hat_AR = ssf_AR([x2_hat_AR, x3])
        x_fvd_AR = _as_fvd_clip(x2_hat_AR, x3_hat_AR)

        x2_hat_MSE = ssf_MSE([x1_hat, x2])
        x3_hat_MSE = ssf_MSE([x2_hat_MSE, x3])
        x_fvd_MSE = _as_fvd_clip(x2_hat_MSE, x3_hat_MSE)

        # HiFiC codes each frame independently.
        x2_hat_FMD, _ = fmd(x2)
        x3_hat_FMD, _ = fmd(x3)
        x_fvd_FMD = _as_fvd_clip(x2_hat_FMD, x3_hat_FMD)

        # DCVC receives the un-rescaled batch and returns frames clamped to [0, 1].
        x2_hat_dcvc, x3_hat_dcvc, bbp_x2, bbp_x3 = run_dcvc_test(x, rate_idx=3)
        dcvc_bbp_x2.append(bbp_x2)
        dcvc_bbp_x3.append(bbp_x3)
        x_fvd_dcvc = _as_fvd_clip(x2_hat_dcvc, x3_hat_dcvc)

        # FMD and DCVC clips are rescaled with 2 * x - 1 to match the reference.
        for store, clip in (
            (mse_eval, x_fvd_MSE),
            (jd_eval, x_fvd_JD),
            (fmd_eval, 2 * x_fvd_FMD - 1.),
            (ar_eval, x_fvd_AR),
            (dcvc_eval, 2 * x_fvd_dcvc - 1.),
        ):
            store['fvd'].append(calculate_fvd(x_fvd, clip, device, only_final=True)['value'])


def compute_stats(eval_dict):
    """Summarize every metric's samples as a ``(mean, std, max, min)`` tuple.

    Args:
        eval_dict: Mapping from metric name to a sequence of numeric samples.

    Returns:
        Dict with the same keys, each mapped to the NumPy mean, standard
        deviation, maximum and minimum of its samples.
    """
    def _summarize(samples):
        arr = np.array(samples)
        return (np.mean(arr), np.std(arr), np.max(arr), np.min(arr))

    return {metric: _summarize(samples) for metric, samples in eval_dict.items()}

# Summary statistics for each codec's collected FVD values.
mse_stats, jd_stats, fmd_stats, ar_stats, dcvc_stats = map(
    compute_stats, (mse_eval, jd_eval, fmd_eval, ar_eval, dcvc_eval)
)

# One-line report of the mean FVD per codec.
fvd_columns = (
    ('MSE', mse_stats),
    ('JD', jd_stats),
    ('FMD', fmd_stats),
    ('AR', ar_stats),
    ('DCVC', dcvc_stats),
)
print('FVD:    ' + ' | '.join(f'{label} {stats["fvd"][0]:.4f}' for label, stats in fvd_columns))

# print('Second frame MS-SSIM: ' + # f'MSE {mse_stats["2_ssim"][0]:.4f}±{mse_stats["2_ssim"][1]:.4f} | ' +
#       f'JD {jd_stats["2_ssim"][0]:.4f} | ' +
#       f'FMD {fmd_stats["2_ssim"][0]:.4f} | ' +
#       f'AR {ar_stats["2_ssim"][0]:.4f} | ' +
#       f'DCVC {dcvc_stats["2_ssim"][0]:.4f}')

# print('Third frame PSNR:     ' + # f'MSE {mse_stats["3_psnr"][0]:.4f}±{mse_stats["3_psnr"][1]:.4f} | ' +
#       f'JD {jd_stats["3_psnr"][0]:.4f} | ' +
#       f'FMD {fmd_stats["3_psnr"][0]:.4f} | ' +
#       f'AR {ar_stats["3_psnr"][0]:.4f} | ' +
#       f'DCVC {dcvc_stats["3_psnr"][0]:.4f}')

# print('Third frame MS-SSIM:  ' + # f'MSE {mse_stats["3_ssim"][0]:.4f}±{mse_stats["3_ssim"][1]:.4f} | ' +
#       f'JD {jd_stats["3_ssim"][0]:.4f} | ' +
#       f'FMD {fmd_stats["3_ssim"][0]:.4f} | ' +
#       f'AR {ar_stats["3_ssim"][0]:.4f} | ' +
#       f'DCVC {dcvc_stats["3_ssim"][0]:.4f}')


10it [03:23, 20.37s/it]

FVD:    MSE 56.4959 | JD 690.8981 | FMD 65.5926 | AR 35.1941 | DCVC 76.0244



