In [57]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
import torchnet.meter as meter

In [58]:
LAMBDA = 2048
BATCH_SIZE = 4
DATA_DIR1 = pathlib.Path('.//vimeo_septuplet/sequences')  #pathlib.Path('PATH')
DEVICE = torch.device('cpu')


video_net = DCVC_net()

# Freeze parameters
# for p in video_net.mvEncoder.parameters():
#     p.requires_grad = False
# for p in video_net.contextualEncoder.parameters():
#     p.requires_grad = False
# for p in video_net.priorEncoder.parameters():
#     p.requires_grad = False
# for p in video_net.mvpriorEncoder.parameters():
#     p.requires_grad = False

chpt = torch.load('./checkpoints/model_dcvc_quality_3_psnr.pth')
video_net.load_state_dict(chpt)
video_net = video_net.to(DEVICE)

expirement_name = f"experiment-variable-quantization-{LAMBDA}-local"

wandb.init(
    project="DCVC-Variable-Quantization", 
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=expirement_name, 
    # Track hyperparameters and run metadata
    config={
    "learning_rate": 1e-4,
    "architecture": "DCVC",
    "dataset": "Vimeo-90k",
    "epochs": 20,
    "resume": True
})

optimizer = torch.optim.Adam(video_net.parameters(), lr=wandb.config.learning_rate)
del chpt


In [63]:
class VideoDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir1, crop_size=256, make_b_cut=True, deterministic=False):
        self.data_dir1 = data_dir1
        self.crop_size = crop_size
        self.make_b_cut = make_b_cut
        self.deterministic = deterministic
        self.all_paths = []
        for seq in os.listdir(self.data_dir1):
            subseq = os.listdir(self.data_dir1 / seq)
            for s in subseq:
                print(self.data_dir1 / seq / s)
                self.all_paths.append(self.data_dir1 / seq / s)
        # assert len(self.all_paths) == 91701
        
        self.transforms = torch.nn.Sequential(
            transforms.RandomCrop(crop_size)
        )
       
    def __getitem__(self, i):
        path = self.all_paths[i]
        imgs = []
        '''
        if self.make_b_cut:
            # load two reference frames and the B-frame in the middle
            #TODO: implement making this deterministic
            interval = np.random.randint(1, 4) # can be 1, 2, or 3
            ref1 = plt.imread(path / f'im{1}.png')
            ref2 = plt.imread(path / f'im{1 + interval*2}.png')
            # this is the B-frame, in the middle
            im = plt.imread(path / f'im{1 + interval}.png')
            imgs = [ref1, ref2, im]
        else:
        '''
        '''
        # load full sequence
        for i in range(1, 8):
            # should be between [0, 1]
            img = plt.imread(path / f'im{i}.png')
            imgs.append(img)
        '''
        
        ref = plt.imread(path / f'im1.png')
        im = plt.imread(path /	f'im2.png')
        imgs = [ref, im]        

        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack(imgs, axis=0)
        # bring RGB channels in front
        imgs = imgs.transpose(0, 3, 1, 2)
        return self.transforms(torch.FloatTensor(imgs))

    def __len__(self):
        return len(self.all_paths)

ds = VideoDataset(DATA_DIR1)
dl = torch.utils.data.DataLoader(
    ds,
    shuffle=True
)

criterion = torch.nn.MSELoss()

def count_parameters(model):
    """Return number of parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(video_net)} trainable parameters')




vimeo_septuplet/sequences/00001/0291
vimeo_septuplet/sequences/00001/0622
vimeo_septuplet/sequences/00001/0278
vimeo_septuplet/sequences/00001/0285
vimeo_septuplet/sequences/00001/0268
vimeo_septuplet/sequences/00001/0266
vimeo_septuplet/sequences/00001/0287
vimeo_septuplet/sequences/00001/0619
vimeo_septuplet/sequences/00001/0275
The model has 7944448 trainable parameters


In [64]:
def train_epoch(model, epoch, dl, optimizer, criterion, use_lambda=True):
    model.train()

    loss_meter = meter.AverageValueMeter()
    mse_meter = meter.AverageValueMeter()
    bpp_meter = meter.AverageValueMeter()
    bpp_mv_y_meter = meter.AverageValueMeter()
    bpp_mv_z_meter = meter.AverageValueMeter()
    bpp_y_meter = meter.AverageValueMeter()
    bpp_z_meter = meter.AverageValueMeter()
    
    optimizer.zero_grad()
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref1 = x[:,0]
        #ref2 = x[:,1]
        im = x[:,1]
        preds_dict = model(ref1, im)
        preds = preds_dict['recon_image']
        bpp = preds_dict['bpp']
        mse_loss = criterion(preds, im)
        mse_ls = mse_loss.item()
        avg_mse = mse_meter.value()[0]
        if use_lambda:
            loss = mse_loss * LAMBDA + bpp
        else:
            loss = mse_loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        ls = loss.item()
        loss_meter.add(ls)
        mse_meter.add(mse_ls)
        bpp_meter.add(bpp.item())
        bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
        bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
        bpp_y_meter.add(preds_dict['bpp_y'].item())
        bpp_z_meter.add(preds_dict['bpp_z'].item())
        if i % 1 == 0:
            avg_psnr = -10.0*np.log10(mse_meter.value()[0])
            wandb.log(
                {
                    'train/epoch': epoch,
                    'train/batch_loss': ls,
                    'train/avg_loss': loss_meter.value()[0],
                    'train/avg_mse_loss': mse_meter.value()[0],
                    'train/avg_bpp': bpp_meter.value()[0],
                    'train/avg_bpp_mv_y': bpp_mv_y_meter.value()[0],
                    'train/avg_bpp_mv_z': bpp_mv_z_meter.value()[0],
                    'train/avg_bpp_y': bpp_y_meter.value()[0],
                    'train/avg_bpp_z': bpp_z_meter.value()[0],
                    'train/avg_psnr': avg_psnr,
                }
            )
            ls = round(ls, 6)
            avg_bitrate = round(bpp_meter.value()[0], 6)
            avg_psnr = round(avg_psnr, 6)
            pbar.set_description(f'Avg PSNR/Bitrate, Batch Loss: {avg_psnr, avg_bitrate, ls}')

    return loss_meter.value()[0], False, None


In [65]:
SAVE_FOLDER = pathlib.Path(expirement_name)
os.makedirs(SAVE_FOLDER, exist_ok=True)

USE_LAMBDA = True
for i in range(1, wandb.config.epochs + 1):
    avg_loss, had_err, err_x = train_epoch(video_net, i, dl, optimizer, criterion, use_lambda=USE_LAMBDA)

Avg PSNR/Bitrate, Batch Loss: (42.596586, 0.054241, 0.290059): 100%|██████████| 9/9 [01:34<00:00, 10.52s/it]
Avg PSNR/Bitrate, Batch Loss: (42.68549, 0.053809, 0.096833): 100%|██████████| 9/9 [01:21<00:00,  9.05s/it] 
Avg PSNR/Bitrate, Batch Loss: (41.837742, 0.060076, 0.111628): 100%|██████████| 9/9 [01:19<00:00,  8.87s/it]
Avg PSNR/Bitrate, Batch Loss: (42.863766, 0.059353, 0.1041): 100%|██████████| 9/9 [01:18<00:00,  8.74s/it]  
Avg PSNR/Bitrate, Batch Loss: (42.413106, 0.067534, 0.071266): 100%|██████████| 9/9 [01:18<00:00,  8.75s/it]
Avg PSNR/Bitrate, Batch Loss: (41.592343, 0.075367, 0.243876): 100%|██████████| 9/9 [01:18<00:00,  8.75s/it]
Avg PSNR/Bitrate, Batch Loss: (42.056497, 0.078416, 0.044493): 100%|██████████| 9/9 [01:18<00:00,  8.75s/it]
Avg PSNR/Bitrate, Batch Loss: (43.194841, 0.064473, 0.11833): 100%|██████████| 9/9 [01:18<00:00,  8.75s/it] 
Avg PSNR/Bitrate, Batch Loss: (42.530699, 0.085155, 0.248313): 100%|██████████| 9/9 [01:18<00:00,  8.76s/it]
Avg PSNR/Bitrate, B

In [66]:
wandb.finish()

0,1
train/avg_bpp,▁▂▃▂▁▂▂▂▃▂▇▃▄▄▂▂▄▄▄▃▄▃▄▄▃▄▆▅▆▅▅▅▅▅▇▆▆▇██
train/avg_bpp_mv_y,▁▁▁▁▁▁▂▁▂▁▂▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▃▂▄▃▄▃▃▄▆▅▅▆█▇
train/avg_bpp_mv_z,▃▄█▃▃▂▁▃▆▂▃▂▇▄▁▃▇▃▅▄▄▄▆▃▅▅▄▄▄▂▁▃▄▃▄▄▃▃▄▄
train/avg_bpp_y,▁▂▃▂▁▂▂▂▃▃█▃▄▄▂▃▄▄▄▄▄▃▄▄▃▄▆▅▅▄▄▃▄▄▄▄▄▃▄▄
train/avg_bpp_z,▁▃▄▄▂▃▂▃▄▃█▃▆▅▂▄▆▅▄▄▃▃▅▄▂▄▅▄▅▄▅▄▃▄▅▅▆▄▄▄
train/avg_loss,▁▂▂▂▁▂▂▂▃▂█▂▃▃▂▂▃▃▃▂▂▂▃▃▂▂▄▃▃▃▃▂▂▃▄▃▃▃▃▄
train/avg_mse_loss,▁▂▂▂▁▂▂▂▂▂█▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▃▂▂▂▂▂▁▂▂▂▂▂▂▂
train/avg_psnr,█▆▅▆▇▅▆▆▅▆▁▅▅▅▆▆▅▅▆▅▆▇▆▅▇▆▄▅▅▆▆▆▇▆▅▆▅▆▆▅
train/batch_loss,▁▃▃▁▂█▂▃▃▂█▁▃▃▂▃▃▂▂▅▂▂▃▃▂▃▆▃▂▁▂▂▂▂▅▃▃▅▂▂
train/epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████

0,1
train/avg_bpp,0.197
train/avg_bpp_mv_y,0.11948
train/avg_bpp_mv_z,0.00113
train/avg_bpp_y,0.07469
train/avg_bpp_z,0.00169
train/avg_loss,0.33799
train/avg_mse_loss,7e-05
train/avg_psnr,41.6214
train/batch_loss,0.19661
train/epoch,20.0
