In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
BATCH_SIZE = 4
DATA_DIRS = [
    pathlib.Path('../uvg/Beauty_PNG_1024/'),
    pathlib.Path('../uvg/Bosphorus_PNG_1024/'),
    pathlib.Path('../uvg/HoneyBee_PNG_1024/'),
    pathlib.Path('../uvg/Jockey_PNG_1024/'),
    pathlib.Path('../uvg/ReadySetGo_PNG_1024/'),
    pathlib.Path('../uvg/ShakeNDry_PNG_1024/'),
    pathlib.Path('../uvg/YachtRide_PNG_1024/'),
]
DEVICE = torch.device('cuda')
DEVICE

device(type='cuda')

In [5]:
video_net = DCVC_net()

In [22]:
# chpt = torch.load('checkpoints/model_dcvc_quality_3_psnr.pth')
chpt = torch.load('reproduce-dcvc-lamb-2048-ft/dcvc_epoch=2_psnr=39.108_bpp=0.0807.pt')
video_net.load_state_dict(chpt['model'], strict=True)

# LOADS OLD WEIGHTS
# chpt = torch.load('dcvc-b-frame-convtranspose-fail/dcvc_epoch=4_int.pt')
# temporalPriorEncoder = video_net.temporalPriorEncoder
# del video_net.temporalPriorEncoder
# video_net.load_state_dict(chpt['model'], strict=False)
# video_net.temporalPriorEncoder = temporalPriorEncoder

video_net = video_net.to(DEVICE)
del chpt

In [23]:
class UVGDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, crop_size=None, deterministic=False):
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.deterministic = deterministic
        self.numframes = len(os.listdir(data_dir))
        if crop_size is not None:
            self.transforms = torch.nn.Sequential(
                transforms.CenterCrop(crop_size)
            )
        else:
            self.transforms = None
       
    def __getitem__(self, i):
        interval = 1
        # this is 1-indexed on disk
        i += 1
        int_as_str = f'{i}'.zfill(5)
        ref = plt.imread(self.data_dir / f'im{int_as_str}.png')
        int_as_str = f'{i + interval}'.zfill(5)
        im = plt.imread(self.data_dir / f'im{int_as_str}.png')
        imgs = [ref, im]

        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack(imgs, axis=0)
        # bring RGB channels in front
        imgs = imgs.transpose(0, 3, 1, 2)
        imgs = torch.FloatTensor(imgs)
        if self.transforms:
            imgs = self.transforms(imgs)
        return imgs

    def __len__(self):
        return self.numframes - 1

all_uvg_dl = []
for d in DATA_DIRS:
    ds = UVGDataset(d, crop_size=640)
    dl = torch.utils.data.DataLoader(
        ds,
        shuffle=True,
        batch_size=BATCH_SIZE,
        num_workers=6,
        prefetch_factor=5
    )
    all_uvg_dl.append(dl)


In [24]:
def count_parameters(model):
    """Return number of parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 7944448 trainable parameters


In [25]:
def test_epoch(model, dl, compress_type, max_iters=None):
    mse_criterion = torch.nn.MSELoss()
    model.train()
    
    mse_meter = meter.AverageValueMeter()
    bpp_meter = meter.AverageValueMeter()
    bpp_mv_y_meter = meter.AverageValueMeter()
    bpp_mv_z_meter = meter.AverageValueMeter()
    bpp_y_meter = meter.AverageValueMeter()
    bpp_z_meter = meter.AverageValueMeter()
    
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref = x[:,0]
        im = x[:,1]
        with torch.no_grad():
            preds_dict = model(ref, im, compress_type=compress_type)
        preds = preds_dict['recon_image']
        bpp = preds_dict['bpp']
        
        mse_ls = mse_criterion(preds, im)
        mse_meter.add(mse_ls.item())
        bpp_meter.add(bpp.item())
        bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
        bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
        bpp_y_meter.add(preds_dict['bpp_y'].item())
        bpp_z_meter.add(preds_dict['bpp_z'].item())
        if i % 1 == 0:
            avg_psnr = round(-10.0*np.log10(mse_meter.value()[0]), 6)
            avg_bpp = round(bpp_meter.value()[0], 6)
            avg_bpp_mv_y = round(bpp_mv_y_meter.value()[0], 6)
            avg_bpp_mv_z = round(bpp_mv_z_meter.value()[0], 6)
            avg_bpp_y = round(bpp_y_meter.value()[0], 6)
            avg_bpp_z = round(bpp_z_meter.value()[0], 6)
            msg = (
                f'Avg PSNR: {avg_psnr}, bpp: {avg_bpp}, bpp_mv_y: {avg_bpp_mv_y}, avg_bpp_mv_z: {avg_bpp_mv_z} '
                f'avg_bpp_y: {avg_bpp_y}, avg_bpp_z: {avg_bpp_z}'
            )
            pbar.set_description(msg)
            
        if max_iters is not None and i == max_iters:
            break
    
    avg_psnr = -10.0*np.log10(mse_meter.value()[0])
    avg_bpp = bpp_meter.value()[0]                     
    return avg_psnr, avg_bpp
            

In [26]:
# US
psnrs = []
bpps = []
for p, dl in zip(DATA_DIRS, all_uvg_dl):
    print("RUNNING:", p)
    psnr, bpp = test_epoch(video_net, dl, compress_type='full', max_iters=10)
    psnrs.append(psnr)
    bpps.append(bpp)

RUNNING: ../uvg/Beauty_PNG_1024


Avg PSNR: 34.863215, bpp: 0.063536, bpp_mv_y: 0.024459, avg_bpp_mv_z: 0.000602 avg_bpp_y: 0.037669, avg_bpp_z: 0.000805:   7%|▋         | 10/150 [00:07<01:46,  1.32it/s]


RUNNING: ../uvg/Bosphorus_PNG_1024


Avg PSNR: 40.716065, bpp: 0.026453, bpp_mv_y: 0.010586, avg_bpp_mv_z: 0.000608 avg_bpp_y: 0.01434, avg_bpp_z: 0.000918:   7%|▋         | 10/150 [00:07<01:46,  1.31it/s] 


RUNNING: ../uvg/HoneyBee_PNG_1024


Avg PSNR: 37.39349, bpp: 0.01756, bpp_mv_y: 0.011401, avg_bpp_mv_z: 0.000514 avg_bpp_y: 0.005114, avg_bpp_z: 0.000532:   7%|▋         | 10/150 [00:07<01:48,  1.29it/s] 


RUNNING: ../uvg/Jockey_PNG_1024


Avg PSNR: 33.287528, bpp: 0.121161, bpp_mv_y: 0.044937, avg_bpp_mv_z: 0.001129 avg_bpp_y: 0.073995, avg_bpp_z: 0.0011:   7%|▋         | 10/150 [00:07<01:46,  1.31it/s]  


RUNNING: ../uvg/ReadySetGo_PNG_1024


Avg PSNR: 37.343446, bpp: 0.117792, bpp_mv_y: 0.029257, avg_bpp_mv_z: 0.00127 avg_bpp_y: 0.085918, avg_bpp_z: 0.001347:   7%|▋         | 10/150 [00:07<01:46,  1.32it/s] 


RUNNING: ../uvg/ShakeNDry_PNG_1024


Avg PSNR: 35.433665, bpp: 0.199458, bpp_mv_y: 0.052192, avg_bpp_mv_z: 0.001112 avg_bpp_y: 0.144622, avg_bpp_z: 0.001532:  13%|█▎        | 10/75 [00:07<00:50,  1.29it/s]


RUNNING: ../uvg/YachtRide_PNG_1024


Avg PSNR: 38.318923, bpp: 0.099404, bpp_mv_y: 0.026996, avg_bpp_mv_z: 0.000942 avg_bpp_y: 0.069715, avg_bpp_z: 0.001751:   7%|▋         | 10/150 [00:07<01:48,  1.29it/s]


In [27]:
np.mean(psnrs), np.mean(bpps)

(36.76519039425703, 0.09219490400479212)

In [19]:
psnrs = []
bpps = []
for p, dl in zip(DATA_DIRS, all_uvg_dl):
    print("RUNNING:", p)
    psnr, bpp = test_epoch(video_net, dl, compress_type='full', max_iters=10)
    psnrs.append(psnr)
    bpps.append(bpp)

RUNNING: ../uvg/Beauty_PNG_1024


Avg PSNR: 35.721022, bpp: 0.072026, bpp_mv_y: 0.026973, avg_bpp_mv_z: 0.000756 avg_bpp_y: 0.04308, avg_bpp_z: 0.001217:   7%|▋         | 10/150 [00:07<01:49,  1.28it/s]


RUNNING: ../uvg/Bosphorus_PNG_1024


Avg PSNR: 41.192818, bpp: 0.024123, bpp_mv_y: 0.009814, avg_bpp_mv_z: 0.000766 avg_bpp_y: 0.012507, avg_bpp_z: 0.001035:   7%|▋         | 10/150 [00:07<01:45,  1.33it/s]


RUNNING: ../uvg/HoneyBee_PNG_1024


Avg PSNR: 38.025866, bpp: 0.019324, bpp_mv_y: 0.011802, avg_bpp_mv_z: 0.000674 avg_bpp_y: 0.005931, avg_bpp_z: 0.000916:   7%|▋         | 10/150 [00:07<01:47,  1.30it/s]


RUNNING: ../uvg/Jockey_PNG_1024


Avg PSNR: 38.360225, bpp: 0.07145, bpp_mv_y: 0.021946, avg_bpp_mv_z: 0.000815 avg_bpp_y: 0.047631, avg_bpp_z: 0.001057:   7%|▋         | 10/150 [00:07<01:45,  1.32it/s] 


RUNNING: ../uvg/ReadySetGo_PNG_1024


Avg PSNR: 39.914104, bpp: 0.077235, bpp_mv_y: 0.019839, avg_bpp_mv_z: 0.001058 avg_bpp_y: 0.054994, avg_bpp_z: 0.001344:   7%|▋         | 10/150 [00:07<01:46,  1.32it/s]


RUNNING: ../uvg/ShakeNDry_PNG_1024


Avg PSNR: 36.971907, bpp: 0.223654, bpp_mv_y: 0.042442, avg_bpp_mv_z: 0.001081 avg_bpp_y: 0.178626, avg_bpp_z: 0.001505:  13%|█▎        | 10/75 [00:07<00:50,  1.28it/s]


RUNNING: ../uvg/YachtRide_PNG_1024


Avg PSNR: 39.980588, bpp: 0.122437, bpp_mv_y: 0.024519, avg_bpp_mv_z: 0.001016 avg_bpp_y: 0.09509, avg_bpp_z: 0.001812:   7%|▋         | 10/150 [00:07<01:46,  1.32it/s]


In [20]:
# THEM
psnrs = [35.721022, 41.192818, 38.025866, 38.360225, 39.914104, 36.971907, 39.980588]

bpps = [0.072026, 0.024123, 0.019324, 0.07145, 0.077235, 0.223654, 0.122437]

np.mean(psnrs), np.mean(bpps)

(38.595218571428575, 0.08717842857142857)