In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Evaluation configuration.
BATCH_SIZE = 4  # number of clips per batch handed to the DataLoader
# Dataset root. VideoDataset walks it as <DATA_DIR>/<seq>/<subseq>/im{1..7}.png.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
DATA_DIR = pathlib.Path('/data2/jatin/vimeo_septuplet/sequences')
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook fails early and clearly instead of at the first `.to(DEVICE)` call.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Instantiate the DCVC network with up_strategy='default'. Trained weights are
# loaded into it from a checkpoint in a later cell.
video_net = DCVC_net(up_strategy='default')

In [15]:
# Restore trained weights from a checkpoint, then move the network to DEVICE.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# Alternative checkpoint kept for reference:
# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-from_epoch4/dcvc_epoch=1_int.pt')
chpt = torch.load(
    'dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=2_int_allquantize.pt',
    # Deserialize onto the CPU regardless of the device the checkpoint was saved
    # from; load_state_dict copies the values into the model, which is moved to
    # DEVICE right after. This avoids errors when the saving GPU is unavailable.
    map_location='cpu',
)
# The checkpoint is a dict; the model's state_dict lives under the 'model' key.
video_net.load_state_dict(chpt['model'])
video_net = video_net.to(DEVICE)
# Drop the reference so the loaded checkpoint can be garbage-collected.
del chpt

In [16]:
class VideoDataset(torch.utils.data.Dataset):
    """Dataset of clips stored as ``<data_dir>/<seq>/<subseq>/im{1..7}.png``.

    Every item is a float tensor laid out as ``(frames, channels, height, width)``
    after a random crop of ``crop_size`` x ``crop_size`` pixels. The crop is
    applied to the stacked frames in one call, so all frames of an item share
    the same window.

    Args:
        data_dir: ``pathlib.Path`` of the dataset root (joined with ``/``).
        crop_size: side length passed to ``transforms.RandomCrop``.
        make_b_cut: if True, an item holds three frames in the order
            ``[ref1, ref2, b_frame]`` where ``ref1`` is ``im1``, ``ref2`` is
            ``im{1 + 2 * interval}`` and ``b_frame`` is the frame halfway
            between them, with ``interval`` drawn uniformly from {1, 2, 3}.
            If False, an item holds all seven frames ``im1`` .. ``im7`` in order.
        deterministic: stored on the instance but not consulted yet; the frame
            interval and the crop are always random (see TODO in __getitem__).
    """

    # Every clip directory holds frames im1.png .. im7.png.
    NUM_FRAMES = 7

    def __init__(self, data_dir, crop_size=256, make_b_cut=True, deterministic=False):
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.make_b_cut = make_b_cut
        self.deterministic = deterministic
        # One entry per clip directory: <data_dir>/<seq>/<subseq>.
        self.all_paths = [
            self.data_dir / seq / subseq
            for seq in os.listdir(self.data_dir)
            for subseq in os.listdir(self.data_dir / seq)
        ]
        # Sanity check against the clip count expected for the dataset this
        # notebook points at (hard-coded).
        assert len(self.all_paths) == 91701, (
            f'expected 91701 clips under {self.data_dir}, found {len(self.all_paths)}'
        )

        self.transforms = torch.nn.Sequential(
            transforms.RandomCrop(crop_size)
        )

    def _read_frame(self, frame_path):
        """Read a single image file into an array with ``plt.imread``."""
        return plt.imread(frame_path)

    def __getitem__(self, i):
        """Load clip ``i`` and return its cropped frames as a float tensor."""
        clip_dir = self.all_paths[i]
        if self.make_b_cut:
            # Two reference frames plus the B-frame exactly in the middle.
            # TODO: implement making this deterministic
            interval = np.random.randint(1, 4)  # can be 1, 2, or 3
            frame_ids = [1, 1 + interval * 2, 1 + interval]
        else:
            # Full sequence. (Previously the frames were read but never added
            # to the list, so np.stack failed on an empty list.)
            frame_ids = range(1, self.NUM_FRAMES + 1)

        # plt.imread should make inputs in [0, 1] for us
        imgs = [self._read_frame(clip_dir / f'im{frame_id}.png') for frame_id in frame_ids]
        imgs = np.stack(imgs, axis=0)
        # bring the colour channels in front: (N, H, W, C) -> (N, C, H, W)
        imgs = imgs.transpose(0, 3, 1, 2)
        return self.transforms(torch.FloatTensor(imgs))

    def __len__(self):
        """Number of clip directories found under ``data_dir``."""
        return len(self.all_paths)

# Dataset over every clip under DATA_DIR, using the class defaults
# (256-pixel crop, three-frame B-cut items).
ds = VideoDataset(DATA_DIR)
# Shuffled loader; 6 worker processes, each keeping 5 batches prefetched.
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=6,
    prefetch_factor=5,
)

In [17]:
def count_parameters(model):
    """Count the scalar entries of all parameters of ``model`` that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are left out of the count.
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many parameters of the loaded network require gradients.
print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 10622976 trainable parameters


In [22]:
def test_epoch(model, dl, device=None, log_every=1):
    """Evaluate ``model`` over ``dl`` and return the running-average metrics.

    Each batch ``x`` is split into ``x[:, 0]`` and ``x[:, 1]`` (the two
    reference frames) and ``x[:, 2]`` (the frame to reconstruct). The model is
    called as ``model(ref1, ref2, im, compress_type='train_compress')`` under
    ``torch.no_grad()`` and must return a dict with the keys ``'recon_image'``,
    ``'bpp'``, ``'bpp_mv_y'``, ``'bpp_mv_z'``, ``'bpp_y'`` and ``'bpp_z'``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        dl: iterable of batches (e.g. a DataLoader).
        device: device the batches are moved to. When None, the device of the
            model's first parameter is used; if the model has no parameters
            the batches are left where they are.
        log_every: refresh the progress-bar text every this many batches
            (values below 1 are treated as 1).

    Returns:
        dict with the averages over all processed batches: ``'mse'``,
        ``'psnr'``, ``'bitrate'``, ``'bpp_mv_y'``, ``'bpp_mv_z'``, ``'bpp_y'``,
        ``'bpp_z'``, plus ``'num_batches'``. If ``dl`` yields nothing, only
        ``'num_batches'`` (0) is present.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else None
    log_every = max(1, int(log_every))

    mse_criterion = torch.nn.MSELoss()
    model.eval() # this seems to make a difference

    # Running sums of the per-batch values; divided by the batch count below.
    bpp_keys = ('bpp', 'bpp_mv_y', 'bpp_mv_z', 'bpp_y', 'bpp_z')
    sums = dict.fromkeys(('mse',) + bpp_keys, 0.0)
    total_count = 0
    metrics = {'num_batches': 0}

    pbar = tqdm.tqdm(dl)
    with torch.no_grad():
        for i, x in enumerate(pbar):
            if device is not None:
                x = x.to(device)
            ref1 = x[:, 0]
            ref2 = x[:, 1]
            im = x[:, 2]
            preds_dict = model(ref1, ref2, im, compress_type='train_compress')

            sums['mse'] += mse_criterion(preds_dict['recon_image'], im).item()
            for key in bpp_keys:
                sums[key] += preds_dict[key].item()
            total_count += 1

            avg_mse = sums['mse'] / total_count
            # PSNR of the mean per-batch MSE; the formula takes the peak value
            # as 1, i.e. it presumes images scaled to [0, 1]. A perfect
            # reconstruction (MSE == 0) is reported as +inf instead of
            # evaluating log10(0).
            avg_psnr = float(-10.0 * np.log10(avg_mse)) if avg_mse > 0 else float('inf')
            metrics = {
                'mse': avg_mse,
                'psnr': avg_psnr,
                'bitrate': sums['bpp'] / total_count,
                'bpp_mv_y': sums['bpp_mv_y'] / total_count,
                'bpp_mv_z': sums['bpp_mv_z'] / total_count,
                'bpp_y': sums['bpp_y'] / total_count,
                'bpp_z': sums['bpp_z'] / total_count,
                'num_batches': total_count,
            }

            if i % log_every == 0:
                # Values are rounded for display only; the returned dict keeps
                # full precision.
                avg_psnr = round(metrics['psnr'], 6)
                avg_bitrate = round(metrics['bitrate'], 6)
                avg_bpp_mv_y = round(metrics['bpp_mv_y'], 6)
                avg_bpp_mv_z = round(metrics['bpp_mv_z'], 6)
                avg_bpp_y = round(metrics['bpp_y'], 6)
                avg_bpp_z = round(metrics['bpp_z'], 6)
                pbar.set_description(f'Avg PSNR/Bitrate/MV Y/MV Z/Y/Z: {avg_psnr, avg_bitrate, avg_bpp_mv_y, avg_bpp_mv_z, avg_bpp_y, avg_bpp_z}')

    return metrics
            
            

In [23]:
# Evaluate the currently loaded checkpoint. A full pass is long; the run shown
# below was stopped by hand (KeyboardInterrupt) after a few dozen batches.
test_epoch(video_net, dl)

Avg PSNR/Bitrate/MV Y/MV Z/Y/Z: (40.078138, 0.075693, 0.02275, 0.001542, 0.051352, 5e-05):   0%|          | 68/22926 [00:14<1:19:14,  4.81it/s]   


KeyboardInterrupt: 

In [9]:
# NOTE(review): this cell's execution count (In [9]) is lower than the cells
# above it, so the output below comes from an earlier run -- presumably with a
# different checkpoint loaded; confirm which one, or re-run / delete the cell.
test_epoch(video_net, dl)

Avg PSNR/Bitrate/MV Y/MV Z/Y/Z: (42.983657, 0.021004, 0.020142, 0.000788, 0.0, 7.3e-05):   0%|          | 42/22926 [00:09<1:23:36,  4.56it/s]


KeyboardInterrupt: 

In [12]:
# NOTE(review): out-of-order execution count (In [12]); the output below is
# from an earlier run -- presumably a different checkpoint. Confirm which one,
# or re-run / delete the cell.
test_epoch(video_net, dl)

Avg PSNR/Bitrate/MV Y/MV Z/Y/Z: (39.240036, 0.192805, 0.061298, 0.004048, 0.124781, 0.002678):   0%|          | 55/22926 [00:12<1:27:43,  4.35it/s]


KeyboardInterrupt: 

In [11]:
# NOTE(review): out-of-order execution count (In [11]). The output below only
# shows "Avg PSNR/Bitrate", which test_epoch as defined in this notebook does
# not print, so it was produced by an older version of the function. Re-run or
# delete the cell.
test_epoch(video_net, dl)

Avg PSNR/Bitrate: (44.391822, 0.025431):   0%|          | 13/22926 [00:06<2:56:29,  2.16it/s]


KeyboardInterrupt: 

In [10]:
# NOTE(review): out-of-order execution count (In [10]). The output below only
# shows "Avg PSNR/Bitrate", which test_epoch as defined in this notebook does
# not print, so it was produced by an older version of the function. Re-run or
# delete the cell.
test_epoch(video_net, dl)

Avg PSNR/Bitrate: (29.594852, 0.012332):   0%|          | 54/22926 [00:15<1:52:01,  3.40it/s]


KeyboardInterrupt: 