In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
BATCH_SIZE = 4
DATA_DIRS = [
    pathlib.Path('../uvg/Beauty_PNG_1024/'),
    pathlib.Path('../uvg/Bosphorus_PNG_1024/'),
    pathlib.Path('../uvg/HoneyBee_PNG_1024/'),
    pathlib.Path('../uvg/Jockey_PNG_1024/'),
    pathlib.Path('../uvg/ReadySetGo_PNG_1024/'),
    pathlib.Path('../uvg/ShakeNDry_PNG_1024/'),
    pathlib.Path('../uvg/YachtRide_PNG_1024/'),
]
DEVICE = torch.device('cuda')
DEVICE

device(type='cuda')

In [4]:
video_net = DCVC_net(up_strategy='default')

In [5]:
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=2_int_allquantize.pt')
# chpt = torch.load('b_frame_lamba=2048-old-continue/dcvc_epoch=5_int.pt')
chpt = torch.load('b_frame_lamba=2048-old-continue/dcvc_epoch=2_avg_loss=0.28391063538771005.pt')
video_net.load_state_dict(chpt['model'], strict=True)

# LOADS OLD WEIGHTS
# chpt = torch.load('dcvc-b-frame-convtranspose-fail/dcvc_epoch=4_int.pt')
# temporalPriorEncoder = video_net.temporalPriorEncoder
# del video_net.temporalPriorEncoder
# video_net.load_state_dict(chpt['model'], strict=False)
# video_net.temporalPriorEncoder = temporalPriorEncoder

video_net = video_net.to(DEVICE)
del chpt

In [6]:
class UVGDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, crop_size=None, deterministic=False):
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.deterministic = deterministic
        self.numframes = len(os.listdir(data_dir))
        if crop_size is not None:
            self.transforms = torch.nn.Sequential(
                transforms.CenterCrop(crop_size)
            )
        else:
            self.transforms = None
       
    def __getitem__(self, i):
        # load two reference frames and the B-frame in the middle
        #TODO: implement making this deterministic
        max_interval = min((len(self) + 2 - i)//2, 7)
        if max_interval == 1:
            interval = 1
        else:
            interval = np.random.randint(1, max_interval)
        
        # this is 1-indexed on disk
        i += 1
        int_as_str = f'{i}'.zfill(5)
        ref1 = plt.imread(self.data_dir / f'im{int_as_str}.png')
        int_as_str = f'{i + interval}'.zfill(5)
        im = plt.imread(self.data_dir / f'im{int_as_str}.png')
        int_as_str = f'{i + 2*interval}'.zfill(5)
        ref2 = plt.imread(self.data_dir / f'im{int_as_str}.png')
        imgs = [ref1, ref2, im]

        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack(imgs, axis=0)
        # bring RGB channels in front
        imgs = imgs.transpose(0, 3, 1, 2)
        imgs = torch.FloatTensor(imgs)
        if self.transforms:
            imgs = self.transforms(imgs)
        return imgs

    def __len__(self):
        return self.numframes - 2

all_uvg_dl = []
for d in DATA_DIRS:
    ds = UVGDataset(d, crop_size=640)
    dl = torch.utils.data.DataLoader(
        ds,
        shuffle=True,
        batch_size=BATCH_SIZE,
        num_workers=6,
        prefetch_factor=5
    )
    all_uvg_dl.append(dl)


In [7]:
def count_parameters(model):
    """Return number of parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 10622976 trainable parameters


In [8]:
def test_epoch(model, dl, compress_type, max_iters=None):
    mse_criterion = torch.nn.MSELoss()
    model.train()
    
    mse_meter = meter.AverageValueMeter()
    bpp_meter = meter.AverageValueMeter()
    bpp_mv_y_meter = meter.AverageValueMeter()
    bpp_mv_z_meter = meter.AverageValueMeter()
    bpp_y_meter = meter.AverageValueMeter()
    bpp_z_meter = meter.AverageValueMeter()
    
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref1 = x[:,0]
        ref2 = x[:,1]
        im = x[:,2]
        with torch.no_grad():
            preds_dict = model(ref1, ref2, im, compress_type=compress_type)
        preds = preds_dict['recon_image']
        bpp = preds_dict['bpp']
        
        mse_ls = mse_criterion(preds, im)
        mse_meter.add(mse_ls.item())
        bpp_meter.add(bpp.item())
        bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
        bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
        bpp_y_meter.add(preds_dict['bpp_y'].item())
        bpp_z_meter.add(preds_dict['bpp_z'].item())
        if i % 1 == 0:
            avg_psnr = round(-10.0*np.log10(mse_meter.value()[0]), 6)
            avg_bpp = round(bpp_meter.value()[0], 6)
            avg_bpp_mv_y = round(bpp_mv_y_meter.value()[0], 6)
            avg_bpp_mv_z = round(bpp_mv_z_meter.value()[0], 6)
            avg_bpp_y = round(bpp_y_meter.value()[0], 6)
            avg_bpp_z = round(bpp_z_meter.value()[0], 6)
            msg = (
                f'Avg PSNR: {avg_psnr}, bpp: {avg_bpp}, bpp_mv_y: {avg_bpp_mv_y}, avg_bpp_mv_z: {avg_bpp_mv_z} '
                f'avg_bpp_y: {avg_bpp_y}, avg_bpp_z: {avg_bpp_z}'
            )
            pbar.set_description(msg)
            
        if max_iters is not None and i == max_iters:
            break
    
    avg_psnr = -10.0*np.log10(mse_meter.value()[0])
    avg_bpp = bpp_meter.value()[0]                     
    return avg_psnr, avg_bpp
            

In [9]:
psnrs = []
bpps = []
for p, dl in zip(DATA_DIRS, all_uvg_dl):
    print("RUNNING:", p)
    psnr, bpp = test_epoch(video_net, dl, compress_type='full', max_iters=10)
    psnrs.append(psnr)
    bpps.append(bpp)

RUNNING: ../uvg/Beauty_PNG_1024


Avg PSNR: 35.559249, bpp: 0.065763, bpp_mv_y: 0.027618, avg_bpp_mv_z: 0.000859 avg_bpp_y: 0.037287, avg_bpp_z: 0.0:   7%|▋         | 10/150 [00:15<03:35,  1.54s/it]


RUNNING: ../uvg/Bosphorus_PNG_1024


Avg PSNR: 40.365234, bpp: 0.046341, bpp_mv_y: 0.018727, avg_bpp_mv_z: 0.001083 avg_bpp_y: 0.026531, avg_bpp_z: 0.0:   7%|▋         | 10/150 [00:15<03:30,  1.51s/it]


RUNNING: ../uvg/HoneyBee_PNG_1024


Avg PSNR: 38.318393, bpp: 0.013457, bpp_mv_y: 0.009858, avg_bpp_mv_z: 0.00074 avg_bpp_y: 0.002859, avg_bpp_z: 0.0:   7%|▋         | 10/150 [00:15<03:31,  1.51s/it] 


RUNNING: ../uvg/Jockey_PNG_1024


Avg PSNR: 37.393785, bpp: 0.133821, bpp_mv_y: 0.043777, avg_bpp_mv_z: 0.001152 avg_bpp_y: 0.088892, avg_bpp_z: 0.0:   7%|▋         | 10/150 [00:14<03:26,  1.47s/it]


RUNNING: ../uvg/ReadySetGo_PNG_1024


Avg PSNR: 38.071369, bpp: 0.18464, bpp_mv_y: 0.039889, avg_bpp_mv_z: 0.001481 avg_bpp_y: 0.14327, avg_bpp_z: 0.0:   7%|▋         | 10/150 [00:14<03:25,  1.47s/it]  


RUNNING: ../uvg/ShakeNDry_PNG_1024


Avg PSNR: 37.10105, bpp: 0.174979, bpp_mv_y: 0.035349, avg_bpp_mv_z: 0.001266 avg_bpp_y: 0.138364, avg_bpp_z: 0.0:  13%|█▎        | 10/75 [00:14<01:35,  1.48s/it]


RUNNING: ../uvg/YachtRide_PNG_1024


Avg PSNR: 39.00695, bpp: 0.179313, bpp_mv_y: 0.042162, avg_bpp_mv_z: 0.001369 avg_bpp_y: 0.135781, avg_bpp_z: 0.0:   7%|▋         | 10/150 [00:14<03:28,  1.49s/it] 


In [10]:
# more training
np.mean(psnrs), np.mean(bpps)

(37.97371873251983, 0.11404493726886712)

In [10]:
# old
np.mean(psnrs), np.mean(bpps)

(38.00560484809568, 0.11046517900396866)