In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.DCVC_net import DCVC_net
import torch
from torchvision import transforms
import numpy as np
import pathlib
import os
import matplotlib.pyplot as plt
import wandb
import tqdm
import torchnet.meter as meter


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Rate-distortion trade-off weight: the training losses below are `mse * LAMBDA + bpp`.
LAMBDA = 2048
# Number of clips per optimisation step.
BATCH_SIZE = 4
# Root folder of the training sequences. Set the VIMEO_DATA_DIR environment
# variable to point somewhere else; the default is the original location.
DATA_DIR = pathlib.Path(
    os.environ.get('VIMEO_DATA_DIR', '/data2/jatin/vimeo_septuplet/sequences')
)
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Build the DCVC network (imported from src.models.DCVC_net) with its default
# constructor arguments. Later cells move it to DEVICE and train it.
video_net = DCVC_net()

In [None]:
# Experiment identifier: used as the wandb run name and as the name of the
# folder the checkpoints are saved into.
# NOTE(review): "lamba" looks like a typo for "lambda". It is left unchanged here
# because renaming it changes the run name and the checkpoint folder.
exp_name = f'dcvc-reproduce_lamba={LAMBDA}-train-procedure'

In [1]:
# Optional (currently disabled): replace the optical-flow sub-module with one
# loaded from disk -- presumably pretrained weights; confirm before enabling.
# video_net.opticFlow = torch.load('./optflow.pth')
# Requires the cell that creates `video_net` to have been run first; on a fresh
# kernel this cell raises NameError if executed out of order.
video_net = video_net.to(DEVICE)
# A single AdamW optimizer over every model parameter, reused by all training stages.
optimizer = torch.optim.AdamW(video_net.parameters(), lr=1e-4)

NameError: name 'video_net' is not defined

In [6]:
# Hyperparameters and run metadata stored with the wandb run.
run_config = {
    "learning_rate": 1e-4,
    "architecture": "DCVC",
    "dataset": "Vimeo-90k",
    "epochs": 20,
    "resume": True,
}
# An explicit run name is passed; without one wandb assigns a random name.
wandb.init(project="DCVC-B-frames", name=exp_name, config=run_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcodec-crew[0m (use `wandb login --relogin` to force relogin)


In [8]:
class VideoDataset(torch.utils.data.Dataset):
    """Clips of frames read from a two-level directory of 7-frame sequences.

    ``data_dir`` must be a ``pathlib.Path`` laid out as
    ``<data_dir>/<seq>/<subseq>/im1.png ... im7.png``. Every ``<subseq>`` folder
    is one item.

    Each item is a float tensor of shape ``(num_frames, C, crop_size, crop_size)``.
    All frames of a clip go through the crop together, so they share one window.
    The frame layout depends on the mode:

    * ``make_b_cut=True``: ``[ref1, ref2, b_frame]`` with ``ref1 = im1``,
      ``ref2 = im{1 + 2*interval}`` and ``b_frame = im{1 + interval}``,
      ``interval`` in {1, 2, 3}.
    * ``make_p_cut=True``: ``[im1, im2]``.
    * neither: the full sequence ``[im1, ..., im7]``.

    Args:
        data_dir: root directory of the sequences (``pathlib.Path``).
        crop_size: side length of the square crop.
        make_p_cut: return P-frame pairs.
        make_b_cut: return B-frame triplets (the default mode).
        deterministic: when True, use a centre crop instead of a random crop and
            derive the B-frame interval from the item index, so the same index
            always yields the same tensor.

    Raises:
        ValueError: if both ``make_b_cut`` and ``make_p_cut`` are set.
    """

    def __init__(self, data_dir, crop_size=256, make_p_cut=False, make_b_cut=True, deterministic=False):
        if make_b_cut and make_p_cut:
            raise ValueError('Can only choose one of B-frames or P-frames')
        self.data_dir = data_dir
        self.crop_size = crop_size
        self.make_p_cut = make_p_cut
        self.make_b_cut = make_b_cut
        self.deterministic = deterministic
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> clip mapping stable across runs and machines.
        self.all_paths = [
            self.data_dir / seq / sub
            for seq in sorted(os.listdir(self.data_dir))
            for sub in sorted(os.listdir(self.data_dir / seq))
        ]
        # Sanity check on the expected size of the full dataset.
        assert len(self.all_paths) == 91701, (
            f'expected 91701 clips under {self.data_dir}, found {len(self.all_paths)}'
        )

        if deterministic:
            crop = transforms.CenterCrop(crop_size)
        else:
            crop = transforms.RandomCrop(crop_size)
        self.transforms = torch.nn.Sequential(crop)

    def __getitem__(self, i):
        path = self.all_paths[i]
        if self.make_b_cut:
            # Two reference frames plus the B-frame halfway between them.
            if self.deterministic:
                interval = 1 + i % 3  # cycles through 1, 2, 3 by index
            else:
                interval = np.random.randint(1, 4)  # can be 1, 2, or 3
            ref1 = plt.imread(path / f'im{1}.png')
            ref2 = plt.imread(path / f'im{1 + interval*2}.png')
            # this is the B-frame, in the middle
            im = plt.imread(path / f'im{1 + interval}.png')
            imgs = [ref1, ref2, im]
        elif self.make_p_cut:
            ref = plt.imread(path / f'im1.png')
            im = plt.imread(path / f'im2.png')
            imgs = [ref, im]
        else:
            # Full sequence. The frames must be collected into the list: the
            # previous loop loaded them but never stored them, so np.stack
            # below was handed an empty list and raised.
            imgs = [plt.imread(path / f'im{k}.png') for k in range(1, 8)]

        # plt.imread is expected to return PNG data as floats in [0, 1]
        imgs = np.stack(imgs, axis=0)
        # (frames, H, W, C) -> (frames, C, H, W)
        imgs = imgs.transpose(0, 3, 1, 2)
        return self.transforms(torch.FloatTensor(imgs))

    def __len__(self):
        return len(self.all_paths)

# The training loop reads frame 0 as the reference and frame 1 as the frame to
# code, so the dataset is built in P-frame mode (im1 -> im2 pairs).
ds = VideoDataset(DATA_DIR, make_b_cut=False, make_p_cut=True)

# Shuffled batches, loaded in the background by 6 workers that each keep
# 5 batches prefetched.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=6,
    prefetch_factor=5,
)
dl = torch.utils.data.DataLoader(ds, **loader_options)

In [9]:
# Module-level mean-squared-error criterion.
# NOTE(review): `train_epoch` creates its own local `mse_criterion`, so nothing
# visible in this notebook reads this global.
mse_criterion = torch.nn.MSELoss()

In [10]:
def count_parameters(model):
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` are counted; each one contributes
    its number of elements.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many parameters of `video_net` currently require gradients.
print(f'The model has {count_parameters(video_net)} trainable parameters')

The model has 7944448 trainable parameters


In [None]:
def mse_to_psnr(mse, max_val=1.0):
    """Convert a mean squared error into a PSNR value in decibels.

    Computes ``20*log10(max_val) - 10*log10(mse)``. With the default
    ``max_val=1.0`` (signals scaled to [0, 1]) this is ``-10*log10(mse)``.

    Args:
        mse: mean squared error (scalar or array).
        max_val: peak signal value, e.g. 1.0 for [0, 1] images or 255 for 8-bit.

    Returns:
        PSNR in dB. ``mse == 0`` yields ``inf`` without a divide-by-zero warning.
    """
    # log10(0) is -inf, which is the mathematically intended result here, so
    # silence numpy's divide-by-zero warning for that case only.
    with np.errstate(divide='ignore'):
        return 20.0 * np.log10(max_val) - 10.0 * np.log10(mse)
    

In [11]:
def _training_loss(preds_dict, im, train_type, mse_criterion):
    """Build the loss for one batch according to the training stage.

    Returns ``(loss, mse, psnr_key, bpp_terms)``: the tensor to backpropagate,
    the distortion term, the wandb key the PSNR is logged under, and a dict
    mapping wandb keys to the rate tensors to track (empty for MSE-only stages).
    """
    if train_type in ('memc', 'memc_bpp'):
        # Motion stages: distortion is measured on the model's 'pred' output.
        mse = mse_criterion(preds_dict['pred'], im)
        if train_type == 'memc':
            return mse, mse, 'train/psnr', {}
        # NOTE(review): this stage reads 'mv_z_bpp'/'mv_y_bpp' while the full
        # stage reads 'bpp_mv_z'/'bpp_mv_y' -- confirm both spellings exist in
        # the model's output dict.
        bpp = preds_dict['mv_z_bpp'] + preds_dict['mv_y_bpp']
        return mse * LAMBDA + bpp, mse, 'train/psnr', {'train/bpp': bpp}

    # Reconstruction stages: distortion is measured on 'recon_image'.
    mse = mse_criterion(preds_dict['recon_image'], im)
    if train_type == 'recon':
        return mse, mse, 'train/psnr_recon', {}
    # Any other train_type is treated as full rate-distortion training.
    bpp_terms = {
        'train/bpp': preds_dict['bpp'],
        'train/bpp_mv_y': preds_dict['bpp_mv_y'],
        'train/bpp_mv_z': preds_dict['bpp_mv_z'],
        'train/bpp_y': preds_dict['bpp_y'],
        'train/bpp_z': preds_dict['bpp_z'],
    }
    return mse * LAMBDA + preds_dict['bpp'], mse, 'train/psnr_recon', bpp_terms


def train_epoch(model, epoch, dl, optimizer, train_type, save_folder=None, save_every=4000):
    """Run one pass over ``dl``, stepping the optimizer and logging to wandb each batch.

    Every batch ``x`` is moved to ``DEVICE``; ``x[:, 0]`` is used as the reference
    frame and ``x[:, 1]`` as the frame to code. Note that with the B-cut layout
    ``[ref1, ref2, b_frame]`` index 1 is the second reference, not the B-frame,
    so this loop matches the P-cut layout.

    Args:
        model: network called as
            ``model(ref, im, compress_type='train_compress', train_type=train_type)``
            and returning a dict of tensors.
        epoch: epoch number, used for logging and checkpoint names only.
        dl: iterable of batches.
        optimizer: optimizer over the model parameters.
        train_type: ``'memc'`` (MSE on ``'pred'``), ``'memc_bpp'`` (that MSE times
            ``LAMBDA`` plus motion-vector bpp), ``'recon'`` (MSE on
            ``'recon_image'``); any other value selects full rate-distortion
            training (``mse * LAMBDA + bpp``).
        save_folder: directory for intermediate checkpoints; defaults to the
            global ``SAVE_FOLDER``, looked up when a checkpoint is written.
        save_every: write a checkpoint every this many batches (falsy disables).

    Logged metrics are running averages over the epoch so far.
    """
    mse_criterion = torch.nn.MSELoss()
    model.train()

    mse_meter = meter.AverageValueMeter()
    bpp_meters = {}  # wandb key -> running-average meter, created on first use

    # set_to_none=True: parameters that receive no gradient (frozen layers) are
    # skipped by the optimizer. With zero-filled gradients instead, AdamW would
    # still move them through weight decay and leftover momentum.
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm.tqdm(dl)
    for i, x in enumerate(pbar):
        x = x.to(DEVICE)
        ref = x[:, 0]
        im = x[:, 1]
        preds_dict = model(ref, im, compress_type='train_compress', train_type=train_type)
        loss, mse, psnr_key, bpp_terms = _training_loss(preds_dict, im, train_type, mse_criterion)
        loss.backward()

        # metrics
        mse_meter.add(mse.item())
        wandb_export = {psnr_key: mse_to_psnr(mse_meter.value()[0])}
        for key, value in bpp_terms.items():
            if key not in bpp_meters:
                bpp_meters[key] = meter.AverageValueMeter()
            bpp_meters[key].add(value.item())
            wandb_export[key] = bpp_meters[key].value()[0]

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        wandb_export['train/epoch'] = epoch
        wandb_export['train/train_type'] = train_type
        wandb.log(wandb_export)

        # periodic intermediate checkpoint
        if save_every and i % save_every == save_every - 1:
            print('Saving model')
            folder = SAVE_FOLDER if save_folder is None else save_folder
            torch.save(
                {'model': model.state_dict(), 'optimizer': optimizer.state_dict()},
                folder / f"dcvc_epoch={epoch}_batch_{i}.pt",
            )
            

In [12]:
# Checkpoints go into a folder named after the experiment; create it if missing.
SAVE_FOLDER = pathlib.Path(exp_name)
SAVE_FOLDER.mkdir(parents=True, exist_ok=True)

In [None]:
def freeze_layer(l):
    """Turn off gradient tracking for every parameter of module ``l``."""
    for param in l.parameters():
        param.requires_grad_(False)

def unfreeze_layer(l):
    """Turn gradient tracking back on for every parameter of module ``l``."""
    for param in l.parameters():
        param.requires_grad_(True)

In [13]:
def run_stage(train_type, epochs):
    """Train ``video_net`` once per epoch number in ``epochs``, checkpointing after each.

    Checkpoints are written to ``SAVE_FOLDER / dcvc_{train_type}_epoch={epoch}.pt``
    and hold the model and optimizer state dicts.
    """
    for epoch in epochs:
        train_epoch(video_net, epoch, dl, optimizer, train_type=train_type)
        torch.save(
            {'model': video_net.state_dict(), 'optimizer': optimizer.state_dict()},
            SAVE_FOLDER / f"dcvc_{train_type}_epoch={epoch}.pt",
        )


# The optical-flow network stays frozen until the final end-to-end stage.
freeze_layer(video_net.opticFlow)

# Stage 1: one epoch of MSE-only training on the 'pred' output.
# The checkpoint name uses the epoch number explicitly; it previously read a
# global `i` that is not defined until the loops further down, which raised
# NameError on a fresh kernel after the whole epoch had already been trained.
run_stage('memc', [1])

# Stage 2: the same prediction with the motion-vector rate term added.
run_stage('memc_bpp', range(1, 4))

# freeze mv layers
mv_layers = [
    video_net.bitEstimator_z_mv,
    video_net.mvpriorEncoder,
    video_net.mvpriorDecoder,
    video_net.mvDecoder_part1,
    video_net.mvDecoder_part2,
    video_net.auto_regressive_mv,
    video_net.entropy_parameters_mv,
]
print('Freezing MV layers')
for l in mv_layers:
    freeze_layer(l)

# Stage 3: MSE-only training on the reconstructed image, MV layers frozen.
run_stage('recon', range(4, 10))

# Stage 4: full rate-distortion loss, MV layers and optical flow still frozen.
run_stage('full', range(10, 13))

# now unfreeze
print('Unfreezing MV layers and optical flow for end-to-end training')
for l in [video_net.opticFlow] + mv_layers:
    unfreeze_layer(l)

# Stage 5: full rate-distortion loss with every layer trainable.
run_stage('full', range(13, 19))
    

USE LAMBDA True


Avg PSNR/Bitrate, Batch Loss: (39.220927, 0.062847, 0.246654):  17%|█▋        | 3999/22926 [17:05<1:20:39,  3.91it/s]

Saving model with avg psnr 39.220927


Avg PSNR/Bitrate, Batch Loss: (38.977758, 0.067147, 0.499104):  35%|███▍      | 8000/22926 [34:09<1:13:29,  3.39it/s]

Saving model with avg psnr 38.977758


Avg PSNR/Bitrate, Batch Loss: (38.869666, 0.069992, 0.829218):  52%|█████▏    | 12000/22926 [51:13<53:27,  3.41it/s] 

Saving model with avg psnr 38.869666


Avg PSNR/Bitrate, Batch Loss: (38.385357, 0.071914, 0.397238):  70%|██████▉   | 16000/22926 [1:08:17<34:26,  3.35it/s]

Saving model with avg psnr 38.385357


Avg PSNR/Bitrate, Batch Loss: (38.225119, 0.072755, 0.295643):  87%|████████▋ | 20000/22926 [1:25:22<14:22,  3.39it/s]

Saving model with avg psnr 38.225119


Avg PSNR/Bitrate, Batch Loss: (38.190542, 0.072985, 0.250607): 100%|██████████| 22926/22926 [1:37:51<00:00,  3.90it/s]
Avg PSNR/Bitrate, Batch Loss: (38.053746, 0.074942, 0.344783):  17%|█▋        | 4000/22926 [17:05<1:34:23,  3.34it/s]

Saving model with avg psnr 38.053746


Avg PSNR/Bitrate, Batch Loss: (38.022941, 0.075042, 0.402999):  35%|███▍      | 8000/22926 [34:09<1:14:34,  3.34it/s]

Saving model with avg psnr 38.022941


Avg PSNR/Bitrate, Batch Loss: (38.062079, 0.074641, 0.346605):  52%|█████▏    | 12000/22926 [51:14<54:08,  3.36it/s] 

Saving model with avg psnr 38.062079


Avg PSNR/Bitrate, Batch Loss: (38.041525, 0.074999, 0.210176):  70%|██████▉   | 16000/22926 [1:08:18<34:18,  3.36it/s]

Saving model with avg psnr 38.041525


Avg PSNR/Bitrate, Batch Loss: (37.918479, 0.076044, 0.397714):  87%|████████▋ | 20000/22926 [1:25:21<14:37,  3.33it/s] 

Saving model with avg psnr 37.918479


Avg PSNR/Bitrate, Batch Loss: (37.862637, 0.077274, 0.748301): 100%|██████████| 22926/22926 [1:37:51<00:00,  3.90it/s]
Avg PSNR/Bitrate, Batch Loss: (37.281517, 0.08414, 0.235896):  17%|█▋        | 4000/22926 [17:04<1:35:03,  3.32it/s]  

Saving model with avg psnr 37.281517


Avg PSNR/Bitrate, Batch Loss: (37.499517, 0.083779, 0.278157):  35%|███▍      | 8000/22926 [34:07<1:13:39,  3.38it/s]

Saving model with avg psnr 37.499517


Avg PSNR/Bitrate, Batch Loss: (37.554326, 0.083896, 0.459263):  52%|█████▏    | 12000/22926 [51:12<53:50,  3.38it/s] 

Saving model with avg psnr 37.554326


Avg PSNR/Bitrate, Batch Loss: (37.598734, 0.084145, 0.461153):  70%|██████▉   | 16000/22926 [1:08:16<33:58,  3.40it/s]

Saving model with avg psnr 37.598734


Avg PSNR/Bitrate, Batch Loss: (37.617796, 0.084297, 0.554207):  87%|████████▋ | 20000/22926 [1:25:20<14:31,  3.36it/s]

Saving model with avg psnr 37.617796


Avg PSNR/Bitrate, Batch Loss: (37.625801, 0.084552, 0.316915): 100%|██████████| 22926/22926 [1:37:48<00:00,  3.91it/s]
Avg PSNR/Bitrate, Batch Loss: (37.643878, 0.088164, 0.456853):  17%|█▋        | 4000/22926 [17:05<1:34:50,  3.33it/s]

Saving model with avg psnr 37.643878


Avg PSNR/Bitrate, Batch Loss: (37.590251, 0.088868, 0.440065):  35%|███▍      | 8000/22926 [34:09<1:13:48,  3.37it/s]

Saving model with avg psnr 37.590251


Avg PSNR/Bitrate, Batch Loss: (37.596512, 0.089813, 0.543199):  52%|█████▏    | 12000/22926 [51:14<53:44,  3.39it/s] 

Saving model with avg psnr 37.596512


Avg PSNR/Bitrate, Batch Loss: (37.538685, 0.090109, 0.324686):  70%|██████▉   | 16000/22926 [1:08:18<34:16,  3.37it/s]

Saving model with avg psnr 37.538685


Avg PSNR/Bitrate, Batch Loss: (37.54284, 0.090848, 0.420329):  87%|████████▋ | 20000/22926 [1:25:22<14:26,  3.38it/s] 

Saving model with avg psnr 37.54284


Avg PSNR/Bitrate, Batch Loss: (37.420717, 0.091282, 0.95865): 100%|██████████| 22926/22926 [1:37:52<00:00,  3.90it/s]  
Avg PSNR/Bitrate, Batch Loss: (37.570614, 0.094544, 0.327195):  17%|█▋        | 4000/22926 [17:05<1:33:55,  3.36it/s]

Saving model with avg psnr 37.570614


Avg PSNR/Bitrate, Batch Loss: (37.523306, 0.094739, 0.34857):  19%|█▉        | 4451/22926 [19:00<1:18:54,  3.90it/s] 


KeyboardInterrupt: 

In [None]:
# Mark the wandb run started by wandb.init() as finished.
wandb.finish()