In [1]:
# Automatically reload edited modules (e.g. src.models.DCVC_net) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Experiment configuration.
BATCH_SIZE = 4
# Root of the Vimeo septuplet `sequences` folder; override via $VIMEO_SEPTUPLET_DIR
# instead of editing the hard-coded default.
DATA_DIR = pathlib.Path(
    os.environ.get('VIMEO_SEPTUPLET_DIR', '/data2/jatin/vimeo_septuplet/sequences')
)
DEVICE = torch.device('cuda')
# Seed torch and numpy so that loader shuffling, random frame intervals and random
# crops are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
DEVICE

device(type='cuda')

In [4]:
# Build the DCVC model (weights are loaded from a checkpoint in a later cell).
video_net = DCVC_net(up_strategy='default')

In [5]:
# Load pretrained weights into video_net and move it to DEVICE.
# map_location='cpu' makes loading independent of the device the checkpoint was saved
# from (and avoids a transient GPU copy); the model is moved to DEVICE afterwards.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
CHECKPOINT_PATH = 'dcvc-b-frame-with-bitrate/dcvc_epoch=2_int.pt'
chpt = torch.load(CHECKPOINT_PATH, map_location='cpu')
video_net.load_state_dict(chpt['model'])
video_net = video_net.to(DEVICE)
del chpt  # drop the checkpoint dict so its tensors can be freed

In [6]:
class VideoDataset(torch.utils.data.Dataset):
    """Clips from the Vimeo septuplet ``sequences`` directory as float tensors.

    ``data_dir`` is expected to contain ``<sequence>/<sub-sequence>/`` folders, each
    holding frames ``im1.png`` ... ``im7.png``. An item is a tensor of shape
    (frames, C, crop_size, crop_size) with values in [0, 1]; all frames of one item
    share the same spatial crop.

    Args:
        data_dir: path (``str`` or ``pathlib.Path``) of the ``sequences`` directory.
        crop_size: side length of the square crop applied to every frame.
        make_b_cut: if True, an item is ``[ref1, ref2, b_frame]`` = frames
            ``im1``, ``im{1 + 2*interval}`` and the frame halfway between them, with
            ``interval`` in {1, 2, 3}. If False, an item is all 7 frames in order.
        deterministic: if True, ``interval`` is ``1 + i % 3`` and a center crop is
            used, so ``ds[i]`` is reproducible; if False both are random.
        expected_len: number of sub-sequences the directory must contain (default
            matches the full septuplet set); ``None`` skips the check.
    """

    NUM_FRAMES = 7  # frames per sub-sequence: im1.png ... im7.png

    def __init__(self, data_dir, crop_size=256, make_b_cut=True, deterministic=False,
                 expected_len=91701):
        self.data_dir = pathlib.Path(data_dir)
        self.crop_size = crop_size
        self.make_b_cut = make_b_cut
        self.deterministic = deterministic
        # Sorted so the index -> clip mapping does not depend on filesystem order.
        self.all_paths = [
            self.data_dir / seq / sub
            for seq in sorted(os.listdir(self.data_dir))
            for sub in sorted(os.listdir(self.data_dir / seq))
        ]
        if expected_len is not None:
            assert len(self.all_paths) == expected_len, (
                f'expected {expected_len} sub-sequences in {self.data_dir}, '
                f'found {len(self.all_paths)}'
            )

        crop = (transforms.CenterCrop(crop_size) if deterministic
                else transforms.RandomCrop(crop_size))
        self.transforms = torch.nn.Sequential(crop)

    def __getitem__(self, i):
        """Return clip ``i`` as a float32 tensor (frames, C, crop_size, crop_size)."""
        path = self.all_paths[i]
        if self.make_b_cut:
            # Two reference frames and the B-frame exactly halfway between them.
            if self.deterministic:
                interval = 1 + i % 3  # cycles through 1, 2, 3
            else:
                interval = np.random.randint(1, 4)  # can be 1, 2, or 3
            frame_ids = [1, 1 + 2 * interval, 1 + interval]  # ref1, ref2, B-frame
        else:
            # Full sequence in temporal order.
            frame_ids = range(1, self.NUM_FRAMES + 1)

        # plt.imread returns PNGs as float arrays in [0, 1] with shape (H, W, C).
        imgs = np.stack([plt.imread(path / f'im{k}.png') for k in frame_ids], axis=0)
        # (frames, H, W, C) -> (frames, C, H, W)
        imgs = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2), dtype=np.float32)
        return self.transforms(torch.from_numpy(imgs))

    def __len__(self):
        return len(self.all_paths)

ds = VideoDataset(DATA_DIR)
# Shuffled loader over the dataset's items; 6 worker processes, each keeping
# 5 batches prefetched.
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=6,
    prefetch_factor=5,
)

In [7]:
def count_parameters(model):
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` contribute; frozen parameters
    are skipped. Returns 0 for a model without trainable parameters.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Report model size (counts only parameters with requires_grad=True).
print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 10622976 trainable parameters


In [9]:
# Grab a single batch for a quick sanity check.
# NOTE(review): iter(dl) starts a fresh set of worker processes just for this batch.
o = next(iter(dl))

In [11]:
# Run the model once on the sanity-check batch; axis 1 of `o` holds ref1, ref2 and
# the B-frame `im`. These globals (ref1, ref2, im, preds_dict) are reused below.
# NOTE(review): nothing above this cell calls video_net.eval(), whereas test_epoch
# does -- confirm whether train-mode behavior is intended for this check.
with torch.no_grad():
    x = o.to(DEVICE)
    ref1 = x[:,0]
    ref2 = x[:,1]
    im = x[:,2]
    preds_dict = video_net(ref1, ref2, im, compress_type='train_compress')

In [12]:
# Reconstructed B-frame for the sanity-check batch (displays the full tensor).
preds_dict['recon_image']

tensor([[[[ 1.6533e-01,  1.4435e-01,  1.4015e-01,  ...,  5.1220e-01,
            4.5511e-01,  4.3836e-01],
          [ 1.9250e-01,  1.7274e-01,  1.6534e-01,  ...,  5.2217e-01,
            4.8226e-01,  4.7115e-01],
          [ 2.0468e-01,  1.8371e-01,  1.8046e-01,  ...,  5.2086e-01,
            5.0545e-01,  4.9344e-01],
          ...,
          [ 5.4512e-01,  5.3954e-01,  5.4066e-01,  ...,  5.9998e-01,
            5.9744e-01,  6.2171e-01],
          [ 5.4208e-01,  5.3724e-01,  5.3748e-01,  ...,  6.0074e-01,
            5.9800e-01,  6.2137e-01],
          [ 5.2892e-01,  5.2702e-01,  5.2891e-01,  ...,  5.8546e-01,
            5.8178e-01,  6.0670e-01]],

         [[ 9.9622e-02,  7.7275e-02,  7.0123e-02,  ...,  4.7495e-01,
            4.2000e-01,  3.9010e-01],
          [ 1.0420e-01,  8.7383e-02,  7.7790e-02,  ...,  4.7109e-01,
            4.2957e-01,  4.1782e-01],
          [ 1.0545e-01,  8.9938e-02,  8.0619e-02,  ...,  4.4922e-01,
            4.3528e-01,  4.2729e-01],
          ...,
     

In [13]:
# Mean squared error between the reconstruction and the original B-frame `im`.
torch.square(preds_dict['recon_image'] - im).mean()

tensor(7.8456e-05, device='cuda:0')

In [14]:
# Persist the sanity-check batch so it can be reloaded later with torch.load.
torch.save(o, 'good_input.pt')

In [10]:
def test_epoch(model, dl, device=None, log_every=1):
    """Evaluate ``model`` over ``dl`` and report running average PSNR and bitrate.

    Each batch is expected to be a tensor whose axis 1 holds ``[ref1, ref2, im]``
    (reference frames and the B-frame to reconstruct). The model is switched to
    eval mode and run under ``torch.no_grad``.

    Args:
        model: called as ``model(ref1, ref2, im, compress_type='train_compress')``;
            must return a dict with ``'recon_image'`` and a scalar ``'bpp'`` tensor.
        dl: iterable of input batches.
        device: device batches are moved to; defaults to the global ``DEVICE``.
        log_every: refresh the progress-bar description every this many batches.

    Returns:
        ``(avg_psnr, avg_bitrate)`` as floats rounded to 6 decimals, averaged over
        all samples seen, or ``(None, None)`` if ``dl`` yields no batches.
        PSNR assumes a peak signal value of 1 (images in [0, 1]).
    """
    if device is None:
        device = DEVICE
    mse_criterion = torch.nn.MSELoss()
    model.eval()  # this seems to make a difference
    # Sample-weighted sums so a smaller final batch does not skew the averages.
    # NOTE(review): weighting bpp by batch size assumes it is a per-batch mean.
    sum_mse = 0.0
    sum_bpp = 0.0
    n_samples = 0
    avg_psnr = avg_bitrate = None
    pbar = tqdm.tqdm(dl)
    with torch.no_grad():
        for step, x in enumerate(pbar):
            x = x.to(device)
            ref1, ref2, im = x[:, 0], x[:, 1], x[:, 2]
            preds_dict = model(ref1, ref2, im, compress_type='train_compress')
            batch_size = im.shape[0]
            sum_mse += mse_criterion(preds_dict['recon_image'], im).item() * batch_size
            sum_bpp += preds_dict['bpp'].item() * batch_size
            n_samples += batch_size

            # float() keeps the progress text free of numpy scalar reprs.
            avg_psnr = round(float(-10.0 * np.log10(sum_mse / n_samples)), 6)
            avg_bitrate = round(sum_bpp / n_samples, 6)
            if step % log_every == 0:
                pbar.set_description(f'Avg PSNR/Bitrate: {avg_psnr, avg_bitrate}')
    return avg_psnr, avg_bitrate
            

In [11]:
# Evaluate the loaded model over the (shuffled) dataloader.
test_epoch(video_net, dl)

Avg PSNR/Bitrate: (44.391822, 0.025431):   0%|          | 13/22926 [00:06<2:56:29,  2.16it/s]


KeyboardInterrupt: 

In [10]:
# NOTE(review): duplicate of the evaluation call above; since dl uses shuffle=True,
# each run starts a new shuffled pass over the data.
test_epoch(video_net, dl)

Avg PSNR/Bitrate: (29.594852, 0.012332):   0%|          | 54/22926 [00:15<1:52:01,  3.40it/s]


KeyboardInterrupt: 