In [1]:
# Enable IPython's autoreload extension. Mode 2 re-imports modules before each
# cell is executed, so edits to project code (e.g. src/) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Evaluation configuration: batch size, location of the frame PNGs, compute device.
BATCH_SIZE = 4
DATA_DIR = pathlib.Path('../uvg/Beauty_PNG_1024/')
# Use the GPU when one is visible and fall back to the CPU otherwise, so the
# notebook still runs (slowly) on a machine without CUDA instead of failing on
# the first `.to(DEVICE)` call.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Last expression of the cell: display the device that was actually selected.
DEVICE

device(type='cuda')

In [4]:
# Instantiate the network with up_strategy='default'; its weights are filled in
# from a checkpoint in the next cell.
video_net = DCVC_net(up_strategy='default')

In [5]:
# Checkpoint to evaluate. Alternative checkpoint used previously:
#   'dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=2_int_allquantize.pt'
CHECKPOINT_PATH = 'b_frame_lamba=2048-old-continue/dcvc_epoch=5_int.pt'

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' keeps the checkpoint tensors in host memory: the weights are
# copied into the (still CPU-resident) model and only the model itself is moved
# to DEVICE, instead of holding a second copy of every tensor on the GPU.
chpt = torch.load(CHECKPOINT_PATH, map_location='cpu')
# strict=True: fail loudly if the checkpoint and the model disagree on any key.
video_net.load_state_dict(chpt['model'], strict=True)

# LOADS OLD WEIGHTS
# (checkpoint loaded non-strictly while temporalPriorEncoder is detached, so the
#  freshly initialised temporalPriorEncoder is kept as-is)
# chpt = torch.load('dcvc-b-frame-convtranspose-fail/dcvc_epoch=4_int.pt', map_location='cpu')
# temporalPriorEncoder = video_net.temporalPriorEncoder
# del video_net.temporalPriorEncoder
# video_net.load_state_dict(chpt['model'], strict=False)
# video_net.temporalPriorEncoder = temporalPriorEncoder

video_net = video_net.to(DEVICE)
# Drop the checkpoint dict so its tensors can be garbage-collected.
del chpt

In [6]:
class UVGDataset(torch.utils.data.Dataset):
    """Frame triplets (two references plus the B-frame between them) from a PNG directory.

    Frames are read from ``data_dir`` as ``im00001.png``, ``im00002.png``, ...
    (1-indexed, zero-padded to five digits). Item ``i`` uses frame ``i + 1`` as
    the first reference, frame ``i + 1 + 2 * interval`` as the second reference
    and frame ``i + 1 + interval`` as the frame in the middle.

    Args:
        data_dir: directory holding the frames (``str`` or ``pathlib.Path``).
        crop_size: if not ``None``, each triplet is center-cropped to this size.
        deterministic: if ``True``, the interval for an index is drawn from a
            generator seeded with that index, so the same index always yields
            the same triplet. If ``False`` (default) the interval is drawn from
            NumPy's global random state.
    """

    # Upper bound passed to ``randint``. The bound is exclusive, so the largest
    # interval that can be drawn is ``MAX_INTERVAL - 1``.
    MAX_INTERVAL = 7

    def __init__(self, data_dir, crop_size=None, deterministic=False):
        self.data_dir = pathlib.Path(data_dir)
        self.crop_size = crop_size
        self.deterministic = deterministic
        # Count only frame files: any other directory entry (hidden files,
        # notes, sub-directories) would otherwise inflate the dataset length and
        # make the last indices point at frames that do not exist.
        self.numframes = sum(
            1 for name in os.listdir(self.data_dir) if self._is_frame_name(name)
        )
        if crop_size is not None:
            self.transforms = torch.nn.Sequential(
                transforms.CenterCrop(crop_size)
            )
        else:
            self.transforms = None

    @staticmethod
    def _is_frame_name(name):
        """Return True for file names of the form ``im<digits>.png``."""
        return name.startswith('im') and name.endswith('.png') and name[2:-4].isdigit()

    def _pick_interval(self, i):
        """Choose the distance between the middle frame and each reference for item ``i``."""
        # Same bound as len(self) + 2 - i: the number of frames from item i to the end.
        max_interval = min((self.numframes - i) // 2, self.MAX_INTERVAL)
        if max_interval <= 1:
            return 1
        if self.deterministic:
            # Per-index seed: reproducible regardless of call order or worker.
            rng = np.random.RandomState(int(i))
            return int(rng.randint(1, max_interval))
        return int(np.random.randint(1, max_interval))

    def _read_frame(self, frame_number):
        """Read the frame with the given 1-indexed number from disk."""
        return plt.imread(self.data_dir / f'im{frame_number:05d}.png')

    def __getitem__(self, i):
        """Return a float tensor of shape (3, C, H, W) ordered as [ref1, ref2, middle frame].

        Raises:
            IndexError: if ``i`` is outside ``[0, len(self))``.
        """
        if not 0 <= i < len(self):
            raise IndexError(f'index {i} is out of range for {len(self)} items')

        # load two reference frames and the B-frame in the middle
        interval = self._pick_interval(i)

        # this is 1-indexed on disk
        first = i + 1
        ref1 = self._read_frame(first)
        im = self._read_frame(first + interval)
        ref2 = self._read_frame(first + 2 * interval)

        # plt.imread is expected to return PNG data as floats in [0, 1]
        imgs = np.stack([ref1, ref2, im], axis=0)
        # bring the channel axis in front of the spatial axes
        imgs = imgs.transpose(0, 3, 1, 2)
        imgs = torch.FloatTensor(imgs)
        if self.transforms:
            imgs = self.transforms(imgs)
        return imgs

    def __len__(self):
        # Each item needs at least three frames starting at its own index, so the
        # last two frames cannot start a triplet. Clamp at 0 so that len() does
        # not raise on a directory with fewer than two frames.
        return max(self.numframes - 2, 0)

# Evaluation data: 640-pixel center crops of the frames in DATA_DIR, served in
# shuffled batches by background worker processes.
loader_options = dict(
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=6,
    prefetch_factor=5,
)
ds = UVGDataset(DATA_DIR, crop_size=640)
dl = torch.utils.data.DataLoader(ds, **loader_options)

In [7]:
def count_parameters(model):
    """Count the trainable scalar parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the size of the loaded network (parameters with requires_grad=True only).
print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 10622976 trainable parameters


In [8]:
def test_epoch(model, dl, compress_type, log_every=1, device=None):
    """Run one pass over ``dl`` and report average distortion and bitrate.

    Args:
        model: network called as ``model(ref1, ref2, im, compress_type=...)``
            that returns a dict with the keys ``'recon_image'``, ``'bpp'``,
            ``'bpp_mv_y'``, ``'bpp_mv_z'``, ``'bpp_y'`` and ``'bpp_z'``.
        dl: iterable of batches ``x`` where ``x[:, 0]`` and ``x[:, 1]`` are the
            two reference frames and ``x[:, 2]`` is the frame to reconstruct.
        compress_type: forwarded unchanged to ``model``.
        log_every: refresh the progress-bar description every ``log_every``
            batches (values below 1 are treated as 1).
        device: device the batches are moved to; defaults to the notebook-level
            ``DEVICE``.

    Returns:
        dict mapping ``'psnr'``, ``'mse'``, ``'bpp'``, ``'bpp_mv_y'``,
        ``'bpp_mv_z'``, ``'bpp_y'`` and ``'bpp_z'`` to their average over all
        batches (each batch weighted equally). PSNR is computed from the
        averaged MSE as ``-10 * log10(mse)``.
    """
    if device is None:
        device = DEVICE
    log_every = max(1, int(log_every))

    mse_criterion = torch.nn.MSELoss()
    # NOTE(review): the model is put in training mode although this is an
    # evaluation pass. Kept as in the original because switching to eval() could
    # change the measured numbers -- confirm whether this is intentional.
    model.train()

    bpp_keys = ('bpp', 'bpp_mv_y', 'bpp_mv_z', 'bpp_y', 'bpp_z')
    meters = {key: meter.AverageValueMeter() for key in ('mse',) + bpp_keys}

    def current_averages():
        """Snapshot of the running means, plus the PSNR derived from the mean MSE."""
        averages = {key: m.value()[0] for key, m in meters.items()}
        averages['psnr'] = -10.0 * np.log10(averages['mse'])
        return averages

    pbar = tqdm.tqdm(dl)
    with torch.no_grad():
        for i, x in enumerate(pbar):
            x = x.to(device)
            ref1, ref2, im = x[:, 0], x[:, 1], x[:, 2]
            preds_dict = model(ref1, ref2, im, compress_type=compress_type)

            meters['mse'].add(mse_criterion(preds_dict['recon_image'], im).item())
            for key in bpp_keys:
                meters[key].add(preds_dict[key].item())

            # Only Python floats are kept from this batch. Drop the tensor
            # references now so the memory can be reused by the next forward
            # pass instead of staying alive until the names are rebound.
            del preds_dict, ref1, ref2, im, x

            if i % log_every == 0:
                averages = current_averages()
                avg_psnr = round(averages['psnr'], 6)
                avg_bpp = round(averages['bpp'], 6)
                avg_bpp_mv_y = round(averages['bpp_mv_y'], 6)
                avg_bpp_mv_z = round(averages['bpp_mv_z'], 6)
                avg_bpp_y = round(averages['bpp_y'], 6)
                avg_bpp_z = round(averages['bpp_z'], 6)
                msg = (
                    f'Avg PSNR: {avg_psnr}, bpp: {avg_bpp}, bpp_mv_y: {avg_bpp_mv_y}, avg_bpp_mv_z: {avg_bpp_mv_z} '
                    f'avg_bpp_y: {avg_bpp_y}, avg_bpp_z: {avg_bpp_z}'
                )
                pbar.set_description(msg)

    # Return the final numbers so callers can store or compare them instead of
    # reading them off the progress bar.
    return {key: float(value) for key, value in current_averages().items()}
            

In [13]:
# Evaluate the loaded checkpoint with compress_type='train_compress'; the running
# averages are shown in the progress-bar description.
test_epoch(video_net, dl, compress_type='train_compress')

Avg PSNR: 35.790386, bpp: 0.11227, bpp_mv_y: 0.046325, avg_bpp_mv_z: 0.001131 avg_bpp_y: 0.064795, avg_bpp_z: 1.9e-05:   7%|▋         | 11/150 [00:16<03:30,  1.52s/it]
Traceback (most recent call last):
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/jatin/miniconda3/envs/DCVC/

KeyboardInterrupt: 

In [9]:
# Same evaluation with compress_type='full', for comparison with the
# 'train_compress' run above.
test_epoch(video_net, dl, compress_type='full')

Avg PSNR: 35.649426, bpp: 0.063976, bpp_mv_y: 0.02787, avg_bpp_mv_z: 0.000829 avg_bpp_y: 0.035277, avg_bpp_z: 0.0:  10%|█         | 15/150 [00:20<03:07,  1.39s/it] 
Traceback (most recent call last):
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/jatin/miniconda3/envs/DCVC/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/jatin/miniconda3/envs/DCVC/lib

KeyboardInterrupt: 

In [9]:
# NOTE(review): repeats the earlier 'train_compress' evaluation. The saved output
# below uses a different progress-message format than the current test_epoch, so
# it appears to come from an older version of the function -- consider removing
# this cell or re-running it.
test_epoch(video_net, dl, compress_type='train_compress')

Avg PSNR/MSE/Bitrate: (35.344456, 0.000292, 0.120522):   7%|▋         | 10/150 [00:14<03:21,  1.44s/it]


KeyboardInterrupt: 

In [12]:
# NOTE(review): repeats the earlier 'full' evaluation. The saved output below uses
# an older progress-message format and ends in a CUDA out-of-memory error --
# consider removing this cell or re-running it.
test_epoch(video_net, dl, compress_type='full')

Avg PSNR/MSE/Bitrate: (35.180333, 0.000303, 0.066224):   3%|▎         | 5/150 [00:08<03:52,  1.60s/it]


RuntimeError: CUDA out of memory. Tried to allocate 400.00 MiB (GPU 0; 9.78 GiB total capacity; 7.40 GiB already allocated; 175.62 MiB free; 8.14 GiB reserved in total by PyTorch)