In [1]:
# Reload imported project modules (e.g. src.models) automatically before each
# cell executes, so edits to the .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [18]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

In [5]:
# Number of sequences per batch handed to the model.
BATCH_SIZE = 4
# Root directory holding <sequence>/<subsequence>/im{1..7}.png folders.
# The historical location stays the default; set VIMEO_SEPTUPLET_DIR to run
# the notebook on another machine without editing this cell.
DATA_DIR = pathlib.Path(
    os.environ.get('VIMEO_SEPTUPLET_DIR', '/data2/jatin/vimeo_septuplet/sequences')
)
# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [6]:
# Instantiate the DCVC model with its default constructor arguments; the
# pretrained weights are loaded into it separately (see the load_state_dict cell).
video_net = DCVC_net()

In [8]:
# Deserialize the checkpoint onto CPU first: without map_location, torch.load
# tries to restore every tensor onto the device it was saved from, which fails
# if that GPU index does not exist on this machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('checkpoints/model_dcvc_quality_3_psnr.pth', map_location='cpu')
video_net.load_state_dict(state_dict)
# Move the fully-initialised model to the evaluation device in one step.
video_net = video_net.to(DEVICE)

In [9]:
class VideoDataset(torch.utils.data.Dataset):
    """Dataset over `<data_dir>/<seq>/<subseq>/im{1..7}.png` frame folders.

    Each item is a float tensor of stacked frames with layout
    (frames, channels, height, width), cropped to `crop_size`:

    * `make_b_cut=True`  -> 3 frames: [ref1, ref2, B-frame between them]
    * `make_p_cut=True`  -> 2 frames: [im1 (reference), im2]
    * both False         -> all 7 frames, im1..im7 in order

    Args:
        data_dir: root directory (str or pathlib.Path) of the sequences.
        crop_size: side length passed to the torchvision crop transform.
        make_p_cut: return a (reference, next frame) pair.
        make_b_cut: return two references plus the frame in the middle.
        deterministic: when True, the B-frame interval is derived from the
            item index and a centre crop is used instead of a random crop, so
            the same index always yields the same tensor.

    Raises:
        ValueError: if both `make_b_cut` and `make_p_cut` are set.
        AssertionError: if the directory does not contain exactly 91701
            sub-sequences.
    """

    def __init__(self, data_dir, crop_size=256, make_p_cut=False, make_b_cut=True, deterministic=False):
        if make_b_cut and make_p_cut:
            raise ValueError('Can only choose one of B-frames or P-frames')
        # Accept plain strings as well as Path objects.
        self.data_dir = pathlib.Path(data_dir)
        self.crop_size = crop_size
        self.make_p_cut = make_p_cut
        self.make_b_cut = make_b_cut
        self.deterministic = deterministic
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same sequence on every run / machine.
        self.all_paths = []
        for seq in sorted(os.listdir(self.data_dir)):
            for s in sorted(os.listdir(self.data_dir / seq)):
                self.all_paths.append(self.data_dir / seq / s)
        assert len(self.all_paths) == 91701, (
            f'expected 91701 sub-sequences under {self.data_dir}, '
            f'found {len(self.all_paths)}'
        )

        # A centre crop keeps the spatial window fixed when determinism is requested.
        if deterministic:
            crop = transforms.CenterCrop(crop_size)
        else:
            crop = transforms.RandomCrop(crop_size)
        self.transforms = torch.nn.Sequential(crop)

    def __getitem__(self, i):
        path = self.all_paths[i]
        if self.make_b_cut:
            # Two reference frames `2 * interval` apart and the B-frame in the middle.
            if self.deterministic:
                # Cycle through the three possible intervals based on the index.
                interval = 1 + i % 3
            else:
                # torch's RNG rather than np.random: numpy's global state can be
                # duplicated across DataLoader workers, which would make every
                # worker draw the same interval sequence.
                interval = int(torch.randint(1, 4, (1,)).item())  # 1, 2, or 3
            frame_ids = [1, 1 + interval * 2, 1 + interval]
        elif self.make_p_cut:
            frame_ids = [1, 2]
        else:
            # load full sequence
            frame_ids = range(1, 8)

        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack([plt.imread(path / f'im{k}.png') for k in frame_ids], axis=0)
        # bring RGB channels in front: (frames, H, W, C) -> (frames, C, H, W)
        imgs = imgs.transpose(0, 3, 1, 2)
        return self.transforms(torch.FloatTensor(imgs))

    def __len__(self):
        return len(self.all_paths)

# P-frame evaluation data: every item is an (im1, im2) pair, so B-frame
# sampling is explicitly switched off.
ds = VideoDataset(DATA_DIR, make_b_cut=False, make_p_cut=True)

# Loader settings kept together so they are easy to tweak in one place.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=6,
    prefetch_factor=5,
)
dl = torch.utils.data.DataLoader(ds, **_loader_options)

In [11]:
def count_parameters(model):
    """Count the scalar entries of all parameters that require gradients.

    Parameters with `requires_grad=False` (frozen weights) are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Report the size of the loaded network (only parameters with requires_grad=True are counted).
print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 7944448 trainable parameters


In [19]:
def test_epoch(model, dl, compress):
    mse_criterion = torch.nn.MSELoss()
    model.eval() # this seems to make a difference
    mse_meter = meter.AverageValueMeter()
    bpp_meter = meter.AverageValueMeter()
    bpp_mv_y_meter = meter.AverageValueMeter()
    bpp_mv_z_meter = meter.AverageValueMeter()
    bpp_y_meter = meter.AverageValueMeter()
    bpp_z_meter = meter.AverageValueMeter()
    
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref = x[:,0]
        im = x[:,1]
        with torch.no_grad():
            preds_dict = model(ref, im, compress=compress)
        preds = preds_dict['recon_image']
        bpp = preds_dict['bpp']
        mse_loss = mse_criterion(preds, im)

        mse_ls = mse_criterion(preds, im)
        mse_meter.add(mse_ls.item())
        bpp_meter.add(bpp.item())
        bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
        bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
        bpp_y_meter.add(preds_dict['bpp_y'].item())
        bpp_z_meter.add(preds_dict['bpp_z'].item())
        if i % 1 == 0:
            avg_psnr = round(-10.0*np.log10(mse_meter.value()[0]), 6)
            avg_bpp = round(bpp_meter.value()[0], 6)
            avg_bpp_mv_y = round(bpp_mv_y_meter.value()[0], 6)
            avg_bpp_mv_z = round(bpp_mv_z_meter.value()[0], 6)
            avg_bpp_y = round(bpp_y_meter.value()[0], 6)
            avg_bpp_z = round(bpp_z_meter.value()[0], 6)
            msg = (
                f'Avg PSNR: {avg_psnr}, bpp: {avg_bpp}, bpp_mv_y: {avg_bpp_mv_y}, avg_bpp_mv_z: {avg_bpp_mv_z} '
                f'avg_bpp_y: {avg_bpp_y}, avg_bpp_z: {avg_bpp_z}'
            )
            pbar.set_description(msg)
            

In [20]:
# Evaluate the pretrained model over the P-frame loader with compress=True.
# NOTE(review): compress is forwarded verbatim to DCVC_net's forward call; presumably it
# enables the real entropy-coding path — confirm in src/models/DCVC_net.py.
test_epoch(video_net, dl, compress=True)

Avg PSNR: 41.433961, bpp: 0.061885, bpp_mv_y: 0.015289, avg_bpp_mv_z: 0.00101 avg_bpp_y: 0.044273, avg_bpp_z: 0.001313:   1%|          | 158/22926 [00:26<1:04:13,  5.91it/s] 


KeyboardInterrupt: 

In [15]:
# Same evaluation with compress=False, for comparison against the compress=True run.
# NOTE(review): the saved output of this cell uses the "Avg PSNR/Bitrate" message format,
# which the current test_epoch no longer prints, and its execution count precedes the
# function definition's — the output is from an older version; re-run before comparing.
test_epoch(video_net, dl, compress=False)

Avg PSNR/Bitrate: (41.826272, 0.057603):   4%|▍         | 1004/22926 [02:00<43:56,  8.31it/s] 


KeyboardInterrupt: 

In [10]:
# NOTE(review): this repeats the earlier compress=True evaluation; the saved output
# below it was produced by an older test_epoch (different message format, earlier
# execution count). Consider deleting this cell or re-running it.
test_epoch(video_net, dl, compress=True)

Avg PSNR/Bitrate: (41.423817, 0.062252):   7%|▋         | 1566/22926 [03:00<41:08,  8.65it/s]  


KeyboardInterrupt: 

In [None]:
# Put the model in evaluation mode for any further interactive use
# (test_epoch already calls model.eval() on entry, so this is only needed outside it).
video_net.eval()