In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Evaluation configuration shared by the cells below.
BATCH_SIZE = 4  # number of frame triplets per batch
DATA_DIR = pathlib.Path('../uvg/Beauty_PNG_1024/')  # directory holding the PNG frames
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Build the network with its 'default' up_strategy; the weights are loaded
# from a checkpoint in the next cell.
video_net = DCVC_net(up_strategy='default')

In [5]:
# Checkpoints evaluated earlier, kept so a run can be switched back quickly:
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate-ft/dcvc_epoch=1_int_psnr=0.4147_bpp=0.0425.pt')
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=3_int_0.05bpp_40.65psnr.pt')
# chpt = torch.load('dcvc-b-frame-with-bitrate/dcvc_epoch=2_int_really_good_missed_one.pt')
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=2_int_allquantize.pt')
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-ft-from-quant-epoch2/dcvc_epoch=1_int_psnr=0.41104_bpp=0.073_ft.pt')
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-ft-from-quant-epoch2/dcvc_epoch=1_batch_3999_avg_psnr=40.991327.pt')

# map_location='cpu' restores the tensors on the CPU regardless of the device
# they were saved from. The model is still on the CPU at this point, so this
# avoids a throw-away GPU copy of the weights and keeps the load working on a
# host without the original CUDA device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
chpt = torch.load(
    'dcvc-b-frame-from-epoch4-from-no-lamb-lambda-2048/dcvc_epoch=2_psnr=0.4061_bpp=0.1407.pt',
    map_location='cpu',
)
# strict=True: every key in the checkpoint must match the model exactly.
video_net.load_state_dict(chpt['model'], strict=True)

# LOADS OLD WEIGHTS
# chpt = torch.load('dcvc-b-frame-convtranspose-fail/dcvc_epoch=4_int.pt')
# temporalPriorEncoder = video_net.temporalPriorEncoder
# del video_net.temporalPriorEncoder
# video_net.load_state_dict(chpt['model'], strict=False)
# video_net.temporalPriorEncoder = temporalPriorEncoder

video_net = video_net.to(DEVICE)
# The state dict has been copied into the model; drop the checkpoint to free memory.
del chpt


In [6]:
class UVGDataset(torch.utils.data.Dataset):
    """B-frame triplets read from a directory of PNG frames.

    Frames are expected on disk as ``im00001.png``, ``im00002.png``, ...
    (1-indexed, zero-padded to five digits). Item ``i`` stacks three frames in
    the order ``[ref1, ref2, im]``: ``ref1`` is frame ``i``, ``ref2`` is frame
    ``i + 2 * interval`` and ``im`` is the frame halfway between them.

    Args:
        data_dir: directory holding the frames; must support the ``/``
            operator (e.g. a ``pathlib.Path``).
        crop_size: if given, every frame is center-cropped to this size.
        deterministic: if True, the interval used for a given index is the
            same on every access; otherwise it is drawn from torch's global
            random number generator.
    """

    # Exclusive upper bound on the sampled interval (intervals are 1..6).
    _INTERVAL_CAP = 7

    def __init__(self, data_dir, crop_size=None, deterministic=False):
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.deterministic = deterministic
        # Count only frame files so that stray entries in the directory do not
        # inflate the length and cause reads of frames that do not exist.
        self.numframes = sum(
            1
            for name in os.listdir(data_dir)
            if name.startswith('im') and name.endswith('.png')
        )
        if crop_size is not None:
            self.transforms = torch.nn.Sequential(
                transforms.CenterCrop(crop_size)
            )
        else:
            self.transforms = None

    def _frame_path(self, frame_number):
        """Return the path of the 1-indexed frame ``frame_number``."""
        return self.data_dir / f'im{int(frame_number):05d}.png'

    def _choose_interval(self, i):
        """Pick the distance between the B-frame and each of its references.

        The bound keeps ``i + 1 + 2 * interval`` within the available frames.
        """
        max_interval = min((self.numframes - i) // 2, self._INTERVAL_CAP)
        if max_interval <= 1:
            return 1
        generator = None
        if self.deterministic:
            # Seeding from the index makes the choice a pure function of ``i``.
            generator = torch.Generator()
            generator.manual_seed(int(i))
        # torch's RNG is seeded differently in every DataLoader worker, whereas
        # NumPy's global state is duplicated across workers and would make
        # them draw identical interval sequences.
        drawn = torch.randint(1, max_interval, (1,), generator=generator)
        return int(drawn.item())

    def __getitem__(self, i):
        """Return a float tensor stacking ``[ref1, ref2, im]``, channels first.

        Raises:
            IndexError: if ``i`` is outside ``[0, len(self))``.
        """
        if not 0 <= i < len(self):
            raise IndexError(f'index {i} out of range for {len(self)} items')
        interval = self._choose_interval(i)

        # frames are 1-indexed on disk
        first = i + 1
        frame_numbers = (first, first + 2 * interval, first + interval)

        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack(
            [plt.imread(self._frame_path(number)) for number in frame_numbers],
            axis=0,
        )
        # bring RGB channels in front: (3, H, W, C) -> (3, C, H, W)
        imgs = imgs.transpose(0, 3, 1, 2)
        imgs = torch.FloatTensor(imgs)
        if self.transforms is not None:
            imgs = self.transforms(imgs)
        return imgs

    def __len__(self):
        # The last two frames cannot start a triplet; never report a negative
        # length when fewer than three frames are present.
        return max(self.numframes - 2, 0)

# Frame triplets center-cropped to 640, served in shuffled batches by six
# worker processes that each prefetch five batches.
ds = UVGDataset(DATA_DIR, crop_size=640)
dl = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=6, prefetch_factor=5)

In [7]:
def count_parameters(model):
    """Count the scalar entries of every parameter that requires a gradient."""
    trainable_total = 0
    for param in model.parameters():
        # frozen parameters are excluded from the count
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many parameters of the loaded network require gradients.
print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 10622976 trainable parameters


In [8]:
def test_epoch(model, dl, compress_type):
    mse_criterion = torch.nn.MSELoss()
    model.eval()
    
    mse_meter = meter.AverageValueMeter()
    bpp_meter = meter.AverageValueMeter()
    bpp_mv_y_meter = meter.AverageValueMeter()
    bpp_mv_z_meter = meter.AverageValueMeter()
    bpp_y_meter = meter.AverageValueMeter()
    bpp_z_meter = meter.AverageValueMeter()
    
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref1 = x[:,0]
        ref2 = x[:,1]
        im = x[:,2]
        with torch.no_grad():
            preds_dict = model(ref1, ref2, im, compress_type=compress_type)
        preds = preds_dict['recon_image']
        bpp = preds_dict['bpp']
        
        mse_ls = mse_criterion(preds, im)
        mse_meter.add(mse_ls.item())
        bpp_meter.add(bpp.item())
        bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
        bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
        bpp_y_meter.add(preds_dict['bpp_y'].item())
        bpp_z_meter.add(preds_dict['bpp_z'].item())
        if i % 1 == 0:
            avg_psnr = round(-10.0*np.log10(mse_meter.value()[0]), 6)
            avg_bpp = round(bpp_meter.value()[0], 6)
            avg_bpp_mv_y = round(bpp_mv_y_meter.value()[0], 6)
            avg_bpp_mv_z = round(bpp_mv_z_meter.value()[0], 6)
            avg_bpp_y = round(bpp_y_meter.value()[0], 6)
            avg_bpp_z = round(bpp_z_meter.value()[0], 6)
            msg = (
                f'Avg PSNR: {avg_psnr}, bpp: {avg_bpp}, bpp_mv_y: {avg_bpp_mv_y}, avg_bpp_mv_z: {avg_bpp_mv_z} '
                f'avg_bpp_y: {avg_bpp_y}, avg_bpp_z: {avg_bpp_z}'
            )
            pbar.set_description(msg)
            

In [9]:
# Evaluate with compress_type='train_compress'. The meaning of this mode is
# defined inside DCVC_net and is not visible in this notebook -- TODO confirm.
test_epoch(video_net, dl, compress_type='train_compress')

Avg PSNR: 35.534677, bpp: 0.114048, bpp_mv_y: 0.04983, avg_bpp_mv_z: 0.000902 avg_bpp_y: 0.061805, avg_bpp_z: 0.00151:   1%|          | 1/150 [00:03<07:46,  3.13s/it]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.539254, bpp: 0.11583, bpp_mv_y: 0.050209, avg_bpp_mv_z: 0.000901 avg_bpp_y: 0.063236, avg_bpp_z: 0.001484:   1%|▏         | 2/150 [00:03<04:19,  1.76s/it]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.564295, bpp: 0.11334, bpp_mv_y: 0.049195, avg_bpp_mv_z: 0.000884 avg_bpp_y: 0.061793, avg_bpp_z: 0.001468:   2%|▏         | 3/150 [00:04<03:12,  1.31s/it]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.571661, bpp: 0.113216, bpp_mv_y: 0.049169, avg_bpp_mv_z: 0.000888 avg_bpp_y: 0.061693, avg_bpp_z: 0.001466:   3%|▎         | 4/150 [00:05<02:41,  1.10s/it]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.59496, bpp: 0.115606, bpp_mv_y: 0.049872, avg_bpp_mv_z: 0.000881 avg_bpp_y: 0.063393, avg_bpp_z: 0.00146:   3%|▎         | 5/150 [00:06<02:23,  1.01it/s]  

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.597098, bpp: 0.114997, bpp_mv_y: 0.049761, avg_bpp_mv_z: 0.000885 avg_bpp_y: 0.062883, avg_bpp_z: 0.001469:   4%|▍         | 6/150 [00:07<02:12,  1.09it/s]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.57922, bpp: 0.117666, bpp_mv_y: 0.050499, avg_bpp_mv_z: 0.00089 avg_bpp_y: 0.064804, avg_bpp_z: 0.001473:   5%|▍         | 7/150 [00:07<02:05,  1.14it/s]  

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.57428, bpp: 0.118437, bpp_mv_y: 0.050875, avg_bpp_mv_z: 0.000894 avg_bpp_y: 0.065202, avg_bpp_z: 0.001466:   5%|▌         | 8/150 [00:08<02:00,  1.18it/s]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.556846, bpp: 0.12113, bpp_mv_y: 0.051532, avg_bpp_mv_z: 0.000893 avg_bpp_y: 0.067241, avg_bpp_z: 0.001464:   6%|▌         | 9/150 [00:09<01:56,  1.21it/s]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.547601, bpp: 0.121557, bpp_mv_y: 0.051679, avg_bpp_mv_z: 0.000893 avg_bpp_y: 0.067525, avg_bpp_z: 0.00146:   7%|▋         | 10/150 [00:10<01:53,  1.23it/s]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.556732, bpp: 0.122792, bpp_mv_y: 0.052278, avg_bpp_mv_z: 0.000896 avg_bpp_y: 0.068163, avg_bpp_z: 0.001456:   7%|▋         | 11/150 [00:10<01:51,  1.25it/s]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.556402, bpp: 0.123454, bpp_mv_y: 0.052345, avg_bpp_mv_z: 0.000896 avg_bpp_y: 0.068752, avg_bpp_z: 0.001461:   8%|▊         | 12/150 [00:11<01:49,  1.26it/s]

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.560596, bpp: 0.123625, bpp_mv_y: 0.05241, avg_bpp_mv_z: 0.000896 avg_bpp_y: 0.06886, avg_bpp_z: 0.001459:   9%|▊         | 13/150 [00:12<01:48,  1.26it/s]  

CYR torch.Size([4, 96, 40, 40])


Avg PSNR: 35.560596, bpp: 0.123625, bpp_mv_y: 0.05241, avg_bpp_mv_z: 0.000896 avg_bpp_y: 0.06886, avg_bpp_z: 0.001459:   9%|▊         | 13/150 [00:13<02:20,  1.03s/it]


KeyboardInterrupt: 

In [10]:
# Evaluate the same network and loader with compress_type='full'. The meaning
# of this mode is defined inside DCVC_net and is not visible here -- TODO confirm.
test_epoch(video_net, dl, compress_type='full')

Avg PSNR: 35.432825, bpp: 0.090584, bpp_mv_y: 0.052081, avg_bpp_mv_z: 0.000658 avg_bpp_y: 0.036806, avg_bpp_z: 0.001039:  38%|███▊      | 57/150 [00:46<01:15,  1.23it/s]


KeyboardInterrupt: 