In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
BATCH_SIZE = 4
DATA_DIR = pathlib.Path('../uvg/Beauty_PNG_1024/')
DEVICE = torch.device('cuda')
DEVICE

device(type='cuda')

In [4]:
video_net = DCVC_net(up_strategy='default')

In [5]:
chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=2_int_allquantize.pt')
video_net.load_state_dict(chpt['model'], strict=True)

# LOADS OLD WEIGHTS
# chpt = torch.load('dcvc-b-frame-convtranspose-fail/dcvc_epoch=4_int.pt')
# temporalPriorEncoder = video_net.temporalPriorEncoder
# del video_net.temporalPriorEncoder
# video_net.load_state_dict(chpt['model'], strict=False)
# video_net.temporalPriorEncoder = temporalPriorEncoder

video_net = video_net.to(DEVICE)
del chpt

In [6]:
class UVGDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, crop_size=None, deterministic=False):
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.deterministic = deterministic
        self.numframes = len(os.listdir(data_dir))
        if crop_size is not None:
            self.transforms = torch.nn.Sequential(
                transforms.CenterCrop(crop_size)
            )
        else:
            self.transforms = None
       
    def __getitem__(self, i):
        # load two reference frames and the B-frame in the middle
        #TODO: implement making this deterministic
        max_interval = min((len(self) + 2 - i)//2, 7)
        if max_interval == 1:
            interval = 1
        else:
            interval = np.random.randint(1, max_interval)
        
        # this is 1-indexed on disk
        i += 1
        int_as_str = f'{i}'.zfill(5)
        ref1 = plt.imread(self.data_dir / f'im{int_as_str}.png')
        int_as_str = f'{i + interval}'.zfill(5)
        im = plt.imread(self.data_dir / f'im{int_as_str}.png')
        int_as_str = f'{i + 2*interval}'.zfill(5)
        ref2 = plt.imread(self.data_dir / f'im{int_as_str}.png')
        imgs = [ref1, ref2, im]

        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack(imgs, axis=0)
        # bring RGB channels in front
        imgs = imgs.transpose(0, 3, 1, 2)
        imgs = torch.FloatTensor(imgs)
        if self.transforms:
            imgs = self.transforms(imgs)
        return imgs

    def __len__(self):
        return self.numframes - 2

ds = UVGDataset(DATA_DIR, crop_size=640)
dl = torch.utils.data.DataLoader(
    ds,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=6,
    prefetch_factor=5
)

In [7]:
def count_parameters(model):
    """Return number of parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 10622976 trainable parameters


In [11]:
def test_epoch(model, dl, compress_type):
    mse_criterion = torch.nn.MSELoss()
    model.train()
    epoch_mse = 0
    epoch_bitrate = 0
    total_count = 0
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref1 = x[:,0]
        ref2 = x[:,1]
        im = x[:,2]
        with torch.no_grad():
            preds_dict = model(ref1, ref2, im, compress_type=compress_type)
        preds = preds_dict['recon_image']
        bpp = preds_dict['bpp']
        mse_loss = mse_criterion(preds, im)
        epoch_mse += mse_loss.item()
        epoch_bitrate += bpp.item()
        total_count += 1
        if i % 1 == 0:
            avg_mse = epoch_mse / total_count
            avg_psnr = -10.0*np.log10(avg_mse)
            avg_bitrate = epoch_bitrate / total_count
            avg_mse = round(avg_mse, 6)
            avg_psnr = round(avg_psnr, 6)
            avg_bitrate = round(avg_bitrate, 6)
            pbar.set_description(f'Avg PSNR/MSE/Bitrate: {avg_psnr, avg_mse, avg_bitrate}')
            

In [9]:
test_epoch(video_net, dl)

Avg PSNR/MSE/Bitrate: (35.344456, 0.000292, 0.120522):   7%|▋         | 10/150 [00:14<03:21,  1.44s/it]


KeyboardInterrupt: 

In [12]:
test_epoch(video_net, dl, compress_type='full')

Avg PSNR/MSE/Bitrate: (35.180333, 0.000303, 0.066224):   3%|▎         | 5/150 [00:08<03:52,  1.60s/it]


RuntimeError: CUDA out of memory. Tried to allocate 400.00 MiB (GPU 0; 9.78 GiB total capacity; 7.40 GiB already allocated; 175.62 MiB free; 8.14 GiB reserved in total by PyTorch)