In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
BATCH_SIZE = 4
DATA_DIR = pathlib.Path('../uvg/Beauty_PNG_1024/')
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Instantiate the DCVC model with its default constructor arguments; its
# weights are restored from a checkpoint in the following cell.
video_net = DCVC_net()

In [5]:
# Restore fine-tuned weights. The checkpoint is a dict whose 'model' entry is the
# model state dict saved during training.
# Other checkpoints used in earlier runs:
#   checkpoints/model_dcvc_quality_3_psnr.pth  (a plain state dict: load it directly)
#   reproduce-dcvc-lamb-2048/dcvc_epoch=5_avg_loss=0.3576670109304032.pt
CHECKPOINT_PATH = 'reproduce-dcvc-lamb-2048-ft/dcvc_epoch=2_psnr=39.108_bpp=0.0807.pt'
# map_location='cpu' avoids materialising a second copy of the weights on the GPU the
# checkpoint was saved from (and lets it load on CPU-only machines); the model itself
# is moved to DEVICE in a later cell.
chpt = torch.load(CHECKPOINT_PATH, map_location='cpu')
video_net.load_state_dict(chpt['model'])
# The weights now live in video_net; drop the checkpoint's copy.
del chpt

In [6]:
class UVGDataset(torch.utils.data.Dataset):
    """Pairs of frames from one UVG sequence stored as ``im00001.png``, ``im00002.png``, ...

    Item ``i`` is a float32 tensor of shape ``(2, C, H, W)`` holding frame ``i + 1``
    (the reference) and frame ``i + 1 + interval`` (the frame to code). Pixel values
    are whatever ``plt.imread`` returns (in [0, 1] for PNG files).

    Args:
        data_dir: directory containing the ``im*.png`` frames (str or ``pathlib.Path``).
        crop_size: if given, center-crop both frames with ``transforms.CenterCrop``.
        deterministic: stored on the instance; not used by this class.
        interval: distance in frames between the reference and the coded frame.

    Raises:
        ValueError: if ``interval`` is smaller than 1.
    """

    def __init__(self, data_dir, crop_size=None, deterministic=False, interval=1):
        if interval < 1:
            raise ValueError(f'interval must be >= 1, got {interval}')
        # Normalise to Path: __getitem__ builds file names with the `/` operator.
        self.data_dir = pathlib.Path(data_dir)
        self.crop_size = crop_size
        self.deterministic = deterministic
        self.interval = interval
        # Count only frame files; any other entry in the directory would make
        # __len__ point past the last frame and crash the loader.
        self.numframes = sum(
            1 for name in os.listdir(self.data_dir)
            if name.startswith('im') and name.endswith('.png')
        )
        if crop_size is not None:
            self.transforms = torch.nn.Sequential(
                transforms.CenterCrop(crop_size)
            )
        else:
            self.transforms = None

    def _load_frame(self, frame_number):
        """Read 1-indexed frame ``frame_number`` as an (H, W, C) array, dropping any alpha."""
        frame = plt.imread(self.data_dir / f'im{frame_number:05d}.png')
        # PNGs saved with an alpha channel come back as RGBA; keep RGB only.
        if frame.ndim == 3 and frame.shape[-1] == 4:
            frame = frame[..., :3]
        return frame

    def __getitem__(self, i):
        # Frames are 1-indexed on disk.
        ref_number = i + 1
        ref = self._load_frame(ref_number)
        im = self._load_frame(ref_number + self.interval)
        # (2, H, W, C) -> (2, C, H, W): bring the colour channels in front.
        imgs = np.stack([ref, im], axis=0).transpose(0, 3, 1, 2)
        imgs = torch.from_numpy(np.ascontiguousarray(imgs, dtype=np.float32))
        if self.transforms is not None:
            imgs = self.transforms(imgs)
        return imgs

    def __len__(self):
        # The last `interval` frames have no partner frame far enough ahead.
        return max(self.numframes - self.interval, 0)

# Evaluation pairs (reference frame, next frame), center-cropped to 640 px.
ds = UVGDataset(DATA_DIR, crop_size=640)
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=6,      # worker processes decoding PNGs in parallel
    prefetch_factor=5,  # batches each worker loads ahead of time
)

In [7]:
# Move the (already checkpoint-loaded) model onto the evaluation device.
video_net = video_net.to(device=DEVICE)

In [8]:
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print('The model has', count_parameters(video_net), 'trainable parameters')

The model has 7944448 trainable parameters


In [9]:
def test_epoch(model, dl, compress):
    model.eval()
    mse_criterion = torch.nn.MSELoss()
    
    mse_meter = meter.AverageValueMeter()
    bpp_meter = meter.AverageValueMeter()
    bpp_mv_y_meter = meter.AverageValueMeter()
    bpp_mv_z_meter = meter.AverageValueMeter()
    bpp_y_meter = meter.AverageValueMeter()
    bpp_z_meter = meter.AverageValueMeter()
    
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref = x[:,0]
        im = x[:,1]
        with torch.no_grad():
            preds_dict = model(ref, im, compress_type=compress)
        preds = preds_dict['recon_image']

        mse_ls = mse_criterion(preds, im)
        mse_meter.add(mse_ls.item())
        bpp_meter.add(preds_dict['bpp'].item())
        bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
        bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
        bpp_y_meter.add(preds_dict['bpp_y'].item())
        bpp_z_meter.add(preds_dict['bpp_z'].item())
        if i % 1 == 0:
            avg_psnr = round(-10.0*np.log10(mse_meter.value()[0]), 6)
            avg_bpp = round(bpp_meter.value()[0], 6)
            avg_bpp_mv_y = round(bpp_mv_y_meter.value()[0], 6)
            avg_bpp_mv_z = round(bpp_mv_z_meter.value()[0], 6)
            avg_bpp_y = round(bpp_y_meter.value()[0], 6)
            avg_bpp_z = round(bpp_z_meter.value()[0], 6)
            msg = (
                f'Avg PSNR: {avg_psnr}, bpp: {avg_bpp}, bpp_mv_y: {avg_bpp_mv_y}, avg_bpp_mv_z: {avg_bpp_mv_z} '
                f'avg_bpp_y: {avg_bpp_y}, avg_bpp_z: {avg_bpp_z}'
            )
            pbar.set_description(msg)
            
    return avg_loss
            

In [10]:
# Evaluate with compress_type='train_compress' (passed straight through to the model).
test_epoch(model=video_net, dl=dl, compress='train_compress')

Avg PSNR: 34.979436, bpp: 0.09029, bpp_mv_y: 0.036449, avg_bpp_mv_z: 0.00088 avg_bpp_y: 0.051925, avg_bpp_z: 0.001035:  17%|█▋        | 26/150 [00:18<01:26,  1.43it/s]  


KeyboardInterrupt: 

In [11]:
# Evaluate with compress_type='full' (passed straight through to the model).
test_epoch(model=video_net, dl=dl, compress='full')

Avg PSNR: 34.865312, bpp: 0.062524, bpp_mv_y: 0.024114, avg_bpp_mv_z: 0.000605 avg_bpp_y: 0.037002, avg_bpp_z: 0.000804:  31%|███       | 46/150 [00:29<01:07,  1.54it/s]


KeyboardInterrupt: 

In [18]:
# Evaluate with compress_type=True (passed straight through to the model).
test_epoch(model=video_net, dl=dl, compress=True)

  0%|          | 0/150 [00:00<?, ?it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.685461, bpp: 0.076247, bpp_mv_y: 0.028536, avg_bpp_mv_z: 0.000756 avg_bpp_y: 0.045754, avg_bpp_z: 0.001201:   1%|          | 1/150 [00:01<03:19,  1.34s/it]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.856819, bpp: 0.071894, bpp_mv_y: 0.027246, avg_bpp_mv_z: 0.000764 avg_bpp_y: 0.04269, avg_bpp_z: 0.001193:   1%|▏         | 2/150 [00:01<02:15,  1.09it/s] 

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.852304, bpp: 0.072634, bpp_mv_y: 0.027474, avg_bpp_mv_z: 0.000759 avg_bpp_y: 0.043209, avg_bpp_z: 0.001193:   2%|▏         | 3/150 [00:02<01:54,  1.28it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.824212, bpp: 0.072619, bpp_mv_y: 0.027499, avg_bpp_mv_z: 0.000757 avg_bpp_y: 0.043156, avg_bpp_z: 0.001207:   3%|▎         | 4/150 [00:03<01:44,  1.39it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.837625, bpp: 0.071832, bpp_mv_y: 0.027266, avg_bpp_mv_z: 0.000755 avg_bpp_y: 0.042601, avg_bpp_z: 0.00121:   3%|▎         | 5/150 [00:03<01:38,  1.47it/s] 

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.85089, bpp: 0.071543, bpp_mv_y: 0.027262, avg_bpp_mv_z: 0.000754 avg_bpp_y: 0.042318, avg_bpp_z: 0.001209:   4%|▍         | 6/150 [00:04<01:34,  1.52it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.830057, bpp: 0.072321, bpp_mv_y: 0.027411, avg_bpp_mv_z: 0.000754 avg_bpp_y: 0.042948, avg_bpp_z: 0.001208:   5%|▍         | 7/150 [00:05<01:32,  1.55it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.81649, bpp: 0.071677, bpp_mv_y: 0.027163, avg_bpp_mv_z: 0.000757 avg_bpp_y: 0.042542, avg_bpp_z: 0.001216:   5%|▌         | 8/150 [00:05<01:30,  1.57it/s] 

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.827356, bpp: 0.07162, bpp_mv_y: 0.027128, avg_bpp_mv_z: 0.000758 avg_bpp_y: 0.042515, avg_bpp_z: 0.001219:   6%|▌         | 9/150 [00:06<01:28,  1.59it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.831966, bpp: 0.071406, bpp_mv_y: 0.027117, avg_bpp_mv_z: 0.000758 avg_bpp_y: 0.042317, avg_bpp_z: 0.001214:   7%|▋         | 10/150 [00:06<01:27,  1.60it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.823271, bpp: 0.071729, bpp_mv_y: 0.027201, avg_bpp_mv_z: 0.000758 avg_bpp_y: 0.042557, avg_bpp_z: 0.001213:   7%|▋         | 11/150 [00:07<01:26,  1.61it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.825312, bpp: 0.071761, bpp_mv_y: 0.027219, avg_bpp_mv_z: 0.000758 avg_bpp_y: 0.042569, avg_bpp_z: 0.001215:   8%|▊         | 12/150 [00:08<01:25,  1.61it/s]

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.813511, bpp: 0.071872, bpp_mv_y: 0.027224, avg_bpp_mv_z: 0.00076 avg_bpp_y: 0.042669, avg_bpp_z: 0.001219:   9%|▊         | 13/150 [00:08<01:24,  1.61it/s] 

torch.Size([4, 96, 40, 40])


Avg PSNR: 35.806097, bpp: 0.07213, bpp_mv_y: 0.027322, avg_bpp_mv_z: 0.000759 avg_bpp_y: 0.042831, avg_bpp_z: 0.001218:   9%|▉         | 14/150 [00:09<01:24,  1.61it/s]

torch.Size([4, 96, 40, 40])


Traceback (most recent call last):
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Avg PSNR: 35.806097, bpp: 0.07213, bpp_mv_y: 0.027322, avg_bpp_mv_z: 0.000759 avg_bpp_y: 0.042831, avg_bpp_z: 0.001218:   9%|▉         | 14/150 [00:09<01:37,  1.40it/s]


KeyboardInterrupt: 

In [19]:
# Evaluate with compress_type=False (passed straight through to the model).
test_epoch(model=video_net, dl=dl, compress=False)

Avg PSNR: 35.857873, bpp: 0.068798, bpp_mv_y: 0.025938, avg_bpp_mv_z: 0.000754 avg_bpp_y: 0.040869, avg_bpp_z: 0.001236:  14%|█▍        | 21/150 [00:14<01:28,  1.46it/s]


KeyboardInterrupt: 