In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
from torchnet import meter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
LAMBDA = 2048
BATCH_SIZE = 4
DATA_DIR = pathlib.Path('/data2/jatin/vimeo_septuplet/sequences')
DEVICE = torch.device('cuda')
DEVICE

device(type='cuda')

In [4]:
exp_name = f'b_frame_lamba={LAMBDA}-dcvc-train-procedure'

In [5]:
wandb.init(
    project="DCVC-B-frames", 
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=exp_name, 
    # Track hyperparameters and run metadata
    config={
    "learning_rate": 1e-4,
    "architecture": "DCVC",
    "dataset": "UVG",
    "epochs": 20,
    },
    resume=True,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcodec-crew[0m (use `wandb login --relogin` to force relogin)


In [6]:
video_net = DCVC_net(up_strategy='default')

In [7]:
# load the good weights
video_net.opticFlow = torch.load('../DCVC-Old/optflow.pth')
# these keys do not match, so we remove it from the network and add it back later
# temporalPriorEncoder = video_net.temporalPriorEncoder
# del video_net.temporalPriorEncoder

# chpt = torch.load('dcvc-b-frame-with-bitrate-lambda-2048-tryfix-bitrate/dcvc_epoch=2_int_allquantize.pt')
# chpt = torch.load('b_frame_lamba=2048-old-continue/dcvc_epoch=5_int.pt')
# video_net.load_state_dict(chpt['model'], strict=True)

video_net = video_net.to(DEVICE)

optimizer = torch.optim.AdamW(video_net.parameters(), lr=wandb.config.learning_rate)
# optimizer.load_state_dict(chpt['optimizer'])

# video_net.opticFlow.requires_grad_ = False
# video_net.mvEncoder.requires_grad_ = False
# video_net.mvDecoder_part1.requires_grad_ = False
# video_net.mvDecoder_part2.requires_grad_ = False
# video_net.feature_extract.requires_grad_ = False
# video_net.context_refine.requires_grad_ = False
# video_net.contextualDecoder_part1.requires_grad_ = False

# del chpt

In [8]:
class VideoDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, crop_size=256, make_b_cut=True, deterministic=False):
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.make_b_cut = make_b_cut
        self.deterministic = deterministic
        self.all_paths = []
        for seq in os.listdir(self.data_dir):
            subseq = os.listdir(self.data_dir / seq)
            for s in subseq:
                self.all_paths.append(self.data_dir / seq / s)
        assert len(self.all_paths) == 91701
        
        self.transforms = torch.nn.Sequential(
            transforms.RandomCrop(crop_size)
        )
       
    def __getitem__(self, i):
        path = self.all_paths[i]
        imgs = []
        if self.make_b_cut:
            # load two reference frames and the B-frame in the middle
            #TODO: implement making this deterministic
            interval = np.random.randint(1, 4) # can be 1, 2, or 3
            ref1 = plt.imread(path / f'im{1}.png')
            ref2 = plt.imread(path / f'im{1 + interval*2}.png')
            # this is the B-frame, in the middle
            im = plt.imread(path / f'im{1 + interval}.png')
            imgs = [ref1, ref2, im]
        else:
            # load full sequence
            for i in range(1, 8):
                # should be between [0, 1]
                img = plt.imread(path / f'im{i}.png')
        
        # plt.imread should make inputs in [0, 1] for us
        imgs = np.stack(imgs, axis=0)
        # bring RGB channels in front
        imgs = imgs.transpose(0, 3, 1, 2)
        return self.transforms(torch.FloatTensor(imgs))

    def __len__(self):
        return len(self.all_paths)

ds = VideoDataset(DATA_DIR)
dl = torch.utils.data.DataLoader(
    ds,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=6,
    prefetch_factor=5
)

In [9]:
mse_criterion = torch.nn.MSELoss()

In [11]:
def count_parameters(model):
    """Return number of parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 17969762 trainable parameters


In [12]:
def mse_to_psnr(mse):
    return -10.0*np.log10(mse)
    

In [13]:
def train_epoch(model, epoch, dl, optimizer, train_type):
    mse_criterion = torch.nn.MSELoss()
    model.train()
    
    if train_type == 'memc':
        mse1_meter = meter.AverageValueMeter()
        mse2_meter = meter.AverageValueMeter()
    elif train_type == 'memc_bpp':
        mse1_meter = meter.AverageValueMeter()
        mse2_meter = meter.AverageValueMeter()
        bpp_meter = meter.AverageValueMeter()
    elif train_type == 'recon':
        mse_meter = meter.AverageValueMeter()
    else:
        mse_meter = meter.AverageValueMeter()
        bpp_meter = meter.AverageValueMeter()
        bpp_mv_y_meter = meter.AverageValueMeter()
        bpp_mv_z_meter = meter.AverageValueMeter()
        bpp_y_meter = meter.AverageValueMeter()
        bpp_z_meter = meter.AverageValueMeter()
    
    optimizer.zero_grad()
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref1 = x[:,0]
        ref2 = x[:,1]
        im = x[:,2]
        preds_dict = model(ref1, ref2, im, compress_type='train_compress', train_type=train_type)
        if train_type == 'memc':
            preds_dict = video_net(ref1, ref2, im, compress_type='train_compress', train_type='memc')
            mse1 = mse_criterion(preds_dict['pred1'], im)
            mse2 = mse_criterion(preds_dict['pred2'], im)
            mse = (mse1 + mse2)/2
            mse.backward()
            
            # metrics
            mse1_meter.add(mse1.item())
            mse2_meter.add(mse2.item())
            wandb_export = {
                'train/psnr1': mse_to_psnr(mse1_meter.value()[0]),
                'train/psnr2': mse_to_psnr(mse2_meter.value()[0]),
            }
        elif train_type == 'memc_bpp':
            preds_dict = video_net(ref1, ref2, im, compress_type='train_compress', train_type='memc_bpp')
            mse1 = mse_criterion(preds_dict['pred1'], im)
            mse2 = mse_criterion(preds_dict['pred2'], im)
            mse = (mse1 + mse2)/2
            bpp = preds_dict['mv_z_bpp'] + preds_dict['mv_y_bpp']
            loss = mse * LAMBDA + bpp
            loss.backward()
            
            # metrics
            mse1_meter.add(mse1.item())
            mse2_meter.add(mse2.item())
            bpp_meter.add(bpp.item())
            wandb_export = {
                'train/psnr1': mse_to_psnr(mse1_meter.value()[0]),
                'train/psnr2': mse_to_psnr(mse2_meter.value()[0]),
                'train/bpp': bpp_meter.value()[0]
            }
        elif train_type == 'recon':
            preds_dict = video_net(ref1, ref2, im, compress_type='train_compress', train_type='recon')
            mse = mse_criterion(preds_dict['recon_image'], im)
            mse.backward()
            
            # metrics
            mse_meter.add(mse.item())
            wandb_export = {
                'train/psnr_recon': mse_to_psnr(mse_meter.value()[0]),
            }
        else:
            preds_dict = video_net(ref1, ref2, im, compress_type='train_compress', train_type='full')
            mse = mse_criterion(preds_dict['recon_image'], im)
            loss = mse * LAMBDA + preds_dict['bpp']
            loss.backward()
            
            # metrics
            mse_meter.add(mse.item())
            bpp_meter.add(preds_dict['bpp'].item())
            bpp_mv_y_meter.add(preds_dict['bpp_mv_y'].item())
            bpp_mv_z_meter.add(preds_dict['bpp_mv_z'].item())
            bpp_y_meter.add(preds_dict['bpp_y'].item())
            bpp_z_meter.add(preds_dict['bpp_z'].item())
            
            wandb_export = {
                'train/psnr_recon': mse_to_psnr(mse_meter.value()[0]),
                'train/bpp': bpp_meter.value()[0],
                'train/bpp_mv_y': bpp_mv_y_meter.value()[0],
                'train/bpp_mv_z': bpp_mv_z_meter.value()[0],
                'train/bpp_y': bpp_y_meter.value()[0],
                'train/bpp_z': bpp_z_meter.value()[0],
            }
            

        optimizer.step()
        optimizer.zero_grad()
        
        if i % 1 == 0:
            wandb_export['train/epoch'] = epoch
            wandb_export['train/train_type'] = train_type
            wandb.log(wandb_export)
        # save every 
        if i % 4000 == 3999:
            print('Saving model')
            torch.save(
                {'model': model.state_dict(), 'optimizer': optimizer.state_dict()},
                SAVE_FOLDER / f"dcvc_epoch={epoch}_batch_{i}.pt",
            )
            

In [14]:
SAVE_FOLDER = pathlib.Path(exp_name)
os.makedirs(SAVE_FOLDER, exist_ok=True)

In [15]:
def freeze_layer(l):
    for p in l.parameters():
        p.requires_grad = False

def unfreeze_layer(l):
    for p in l.parameters():
        p.requires_grad = True

In [None]:
freeze_layer(video_net.opticFlow)

train_type = 'memc'
train_epoch(video_net, 1, dl, optimizer, train_type=train_type)
torch.save(
    {'model': video_net.state_dict(), 'optimizer': optimizer.state_dict()},
    SAVE_FOLDER / f"dcvc_{train_type}_epoch={i}.pt",
)

train_type = 'memc_bpp'
for i in range(1, 4):
    train_epoch(video_net, i, dl, optimizer, train_type=train_type)
    torch.save(
        {'model': video_net.state_dict(), 'optimizer': optimizer.state_dict()},
        SAVE_FOLDER / f"dcvc_{train_type}_epoch={i}.pt",
    )

# freeze mv layers
mv_layers = [
    video_net.bitEstimator_z_mv,
    video_net.mvpriorEncoder,
    video_net.mvpriorDecoder,
    video_net.mvDecoder_part1,
    video_net.mvDecoder_part2,
    video_net.auto_regressive_mv,
    video_net.entropy_parameters_mv,
]
print('Freezing MV layers')
for l in mv_layers:
    freeze_layer(l)

train_type = 'recon'
for i in range(4, 10):
    train_epoch(video_net, i, dl, optimizer, train_type=train_type)
    torch.save(
        {'model': video_net.state_dict(), 'optimizer': optimizer.state_dict()},
        SAVE_FOLDER / f"dcvc_{train_type}_epoch={i}.pt",
    )

train_type = 'full'
for i in range(10, 13):
    train_epoch(video_net, i, dl, optimizer, train_type=train_type)
    torch.save(
        {'model': video_net.state_dict(), 'optimizer': optimizer.state_dict()},
        SAVE_FOLDER / f"dcvc_{train_type}_epoch={i}.pt",
    )
    
# now unfreeze
print('Unfreezing MV layers and optical flow for end-to-end training')
for l in [video_net.opticFlow] + mv_layers:
    unfreeze_layer(l)
    
for i in range(13, 19):
    train_epoch(video_net, i, dl, optimizer, train_type=train_type)
    torch.save(
        {'model': video_net.state_dict(), 'optimizer': optimizer.state_dict()},
        SAVE_FOLDER / f"dcvc_{train_type}_epoch={i}.pt",
    )
    

 14%|█▎        | 3118/22926 [12:01<1:14:47,  4.41it/s]

In [15]:
wandb.finish()

0,1
train/avg_bitrate,█▆▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/avg_loss,█▇▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/avg_mse_loss,██▇▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train/avg_psnr,▁▁▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▇▇▇▇███▇████████
train/batch_loss,█▆▅▄▄▃▄▂▃▃▃▂▃▃▄▂▁▂▃▃▂▂▃▄▂▁▅▂▂▁▂▂▁▂▂▂▁▁▁▂
train/epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████

0,1
train/avg_bitrate,0.02835
train/avg_loss,0.12763
train/avg_mse_loss,5e-05
train/avg_psnr,43.14486
train/batch_loss,0.07535
train/epoch,2.0
