In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import numpy as np
import cv2
import os
import sys

from PWC_src import PWC_Net, flow_to_image
from PWC_src.pwc import FlowEstimate

import warnings
# Notebook-wide settings.
# Every warning category is silenced for the rest of the session.
warnings.filterwarnings(action="ignore")
# Device string that the model cells below pass to `.to(...)`.
device = "cuda"
# First epoch to train; the checkpoint-loading cell overwrites this on resume.
starting_epoch = 0

In [2]:
def dense_warp(image, flow):
    """
    Densely warps an image using optical flow.

    Each output pixel (x, y) is sampled from ``image`` at
    ``(x + flow[:, 0], y + flow[:, 1])``, i.e. the flow is expressed in pixels
    with channel 0 the horizontal and channel 1 the vertical displacement.
    Sampling uses ``F.grid_sample`` with its default interpolation and padding
    and ``align_corners=True``.

    Args:
        image (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
        flow (torch.Tensor): Optical flow tensor of shape (batch_size, 2, height, width).

    Returns:
        torch.Tensor: Warped image tensor of shape (batch_size, channels, height, width).
    """
    batch_size, _, height, width = image.size()

    # Build the base pixel-coordinate grid directly on the target device and
    # in the flow's dtype (avoids a CPU int64 grid plus a host->device copy).
    # indexing='ij' is the row-major layout the old implicit default produced.
    rows = torch.arange(height, device=image.device, dtype=flow.dtype)
    cols = torch.arange(width, device=image.device, dtype=flow.dtype)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=-1)
    grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1)
    new_grid = grid + flow.permute(0, 2, 3, 1)

    # Normalize the grid coordinates between -1 and 1.
    # A 1-pixel-wide (or tall) image would divide by zero and yield NaN
    # coordinates, so the span is clamped to at least 1.
    span = torch.tensor([max(width - 1, 1), max(height - 1, 1)],
                        dtype=new_grid.dtype, device=image.device)
    new_grid = new_grid / span * 2 - 1

    # Perform the dense warp using grid_sample
    warped_image = F.grid_sample(image, new_grid, align_corners=True)

    return warped_image

In [3]:
from models import ResNet, UNet
class DIFRINT(nn.Module):
    """Two-stage model built from a UNet, a ResNet and a pretrained PWC-Net.

    ``forward`` warps the two neighbouring frames with flow estimated against
    ``fs``, lets the UNet produce an intermediate frame ``fint``, then warps
    ``ft`` with flow estimated against ``fint`` and lets the ResNet produce the
    final frame ``fout``.

    The PWC-Net is only used for flow estimation and is kept in eval mode
    regardless of calls to ``train()`` on this module.
    """

    def __init__(self):
        super(DIFRINT,self).__init__()
        self.resnet = ResNet(hidden_size=64).to(device).train()
        self.unet = UNet(hidden_size=64).to(device).train()
        self.pwc = PWC_Net('./ckpt/sintel.pytorch').to(device).eval()

    def train(self, mode=True):
        """Set train/eval mode on the sub-networks, leaving PWC-Net in eval mode.

        ``nn.Module.train`` recurses into every child, so a plain
        ``DIFRINT().train()`` would silently undo the ``.eval()`` applied to
        ``self.pwc`` in ``__init__``. This override restores it.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for evaluation.

        Returns:
            DIFRINT: ``self``, matching ``nn.Module.train``.
        """
        super(DIFRINT, self).train(mode)
        self.pwc.eval()
        return self

    def get_flow(self,img1,img2):
        """Estimate flow between two image batches with the PWC-Net.

        Inputs are rescaled with ``(x + 1) / 2`` before being passed to
        ``FlowEstimate`` (presumably mapping [-1, 1] images to [0, 1] -- TODO
        confirm against the data generator). The result is detached from the
        autograd graph.
        """
        img1_t = (img1 + 1) / 2 
        img2_t = (img2 + 1) / 2 
        flow = FlowEstimate(img1_t,img2_t, self.pwc)
        return flow.detach()
    
    def forward(self, ft_minus, ft, fs, ft_plus):
        """Run both stages and return ``(fint, fout)``.

        Args:
            ft_minus, ft, fs, ft_plus (torch.Tensor): Image batches; all four
                are warped or compared at the same resolution.

        Returns:
            tuple: ``fint`` (UNet output) and ``fout`` (ResNet output).
        """
        # NOTE(review): the UNet forward pass runs under no_grad, so `fint`
        # carries no gradient and losses computed on it cannot update the
        # UNet. Only the ResNet below is trained -- confirm this is intended.
        with torch.no_grad():
            flo1 = self.get_flow(ft_minus, fs)
            flo2 = self.get_flow(ft_plus, fs)
            warped1 = dense_warp(ft_minus,flo1)
            warped2 = dense_warp(ft_plus,flo2)
            fint = self.unet(warped1, warped2, flo1, flo2, ft_minus, ft_plus)
            flo3 = self.get_flow(ft,fint)
            warped3 = dense_warp(ft,flo3)
        fout = self.resnet(fint, warped3,flo3, ft)
        return fint,fout


In [4]:
difrint = DIFRINT().train().to(device)
optimizer = torch.optim.Adam(difrint.parameters(), lr=1e-4,betas=(0.9, 0.99))

# Feature extractor for the perceptual loss: the first child of torchvision's
# VGG-19 (its convolutional stack) with only the last module dropped.
# NOTE(review): the earlier comment here said "up to relu3_3", but `[:-1]`
# keeps far more layers than that; slice with `[:16]` if relu3_3 was intended.
vgg19 = models.vgg19(weights='IMAGENET1K_V1')
vgg19 = nn.Sequential(*list(vgg19.children())[0][:-1])
vgg19.eval()
# The VGG weights are never passed to an optimizer. Freezing them stops every
# backward pass from computing and accumulating gradients on them; gradients
# with respect to the input images still flow.
vgg19.requires_grad_(False)
vgg19.to(device)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [5]:
def perceptual_loss(img1, img2):
    """Distance between channel-normalised ``vgg19`` features of two batches.

    Both inputs go through the module-level ``vgg19`` network. Each feature
    map is divided by its L2 norm over the channel axis (plus a small epsilon),
    the element-wise differences (also offset by epsilon) are squared and
    summed, and the total is divided by ``c * h * w`` taken from ``img1``'s
    shape -- the input shape, not the feature shape.

    Autograd is left enabled for the VGG forward pass.

    Args:
        img1 (torch.Tensor): 4-D image batch; its shape provides the divisor.
        img2 (torch.Tensor): Image batch compared against ``img1``.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    _, chans, rows, cols = img1.shape
    eps = 1e-8

    features = [vgg19(img1), vgg19(img2)]
    # Scale every spatial position to (approximately) unit norm across channels.
    unit_a, unit_b = [
        feat / (torch.sqrt(torch.sum(feat**2, dim=1, keepdim=True)) + eps)
        for feat in features
    ]

    offset_diff = unit_a - unit_b + eps
    distance = torch.sqrt(torch.sum(offset_diff**2))
    return distance ** 2 / (chans * rows * cols)

# Mean absolute error, used alongside the perceptual term in training.
l1_loss = nn.L1Loss()

In [6]:
unet_path = './ckpts/unet/'
resnet_path = './ckpts/resnet'


def _checkpoint_index(file_name):
    """Return the integer N from a ``<prefix>_<N>.<ext>`` file name, else None."""
    try:
        return int(file_name.split('.')[0].split('_')[1])
    except (IndexError, ValueError):
        # Not a numbered checkpoint (e.g. a stray or hidden file): skip it
        # instead of aborting the whole cell.
        return None


def _load_latest_checkpoint(directory):
    """Load the highest-numbered checkpoint stored in ``directory``.

    The directory is created when missing so that a fresh run (and the
    later ``torch.save`` calls in the training loop) do not fail.

    Returns:
        tuple: ``(file_name, state_dict)``, or ``(None, None)`` when the
        directory holds no numbered checkpoint.
    """
    os.makedirs(directory, exist_ok=True)
    numbered = []
    for file_name in os.listdir(directory):
        index = _checkpoint_index(file_name)
        if index is not None:
            numbered.append((index, file_name))
    if not numbered:
        return None, None
    _, latest_name = max(numbered)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location places the tensors on the notebook's configured device.
    loaded = torch.load(os.path.join(directory, latest_name), map_location=device)
    return latest_name, loaded


# Load UNet checkpoints
latest, state_dict = _load_latest_checkpoint(unet_path)
if latest is not None:
    difrint.unet.load_state_dict(state_dict['model'])
    print(f'Loaded UNet {latest}')

# Load ResNet checkpoints; the ResNet checkpoint decides the resume epoch.
# The optimizer state saved in the checkpoint is not restored here.
latest, state_dict = _load_latest_checkpoint(resnet_path)
if latest is not None:
    starting_epoch = state_dict['epoch'] + 1
    difrint.resnet.load_state_dict(state_dict['model'])
    print(f'Loaded ResNet {latest}')
    print(f'Starting from epoch {starting_epoch}')

Loaded UNet unet_20.pth
Loaded ResNet from the previous session
Starting from epoch 21


In [7]:
# Rebuild the optimizer with a learning rate of 5e-5; this replaces the
# instance created earlier in the notebook with lr=1e-4.
optimizer = torch.optim.Adam(
    difrint.parameters(),
    lr=5e-5,
    betas=(0.9, 0.99),
)

In [8]:
from datagen import DataLoader
# Project data generator built from the DAVIS train list; `shape` is passed
# through to it unchanged.
data_gen = DataLoader(
    'E:/Datasets/DAVIS/JPEGImages/480p/trainlist.txt',
    shape=(256, 256, 3),
)
from torch.utils import data
class IterDataset(data.IterableDataset):
    """Expose a zero-argument callable that returns an iterable as a dataset.

    The callable is invoked anew each time the dataset is iterated.
    """

    def __init__(self, data_generator):
        """Store ``data_generator``, the callable that produces the samples."""
        super().__init__()
        self.data_generator = data_generator

    def __iter__(self):
        """Call the stored generator and return an iterator over its result."""
        samples = self.data_generator()
        return iter(samples)
# Wrap the generator as an iterable dataset (rebinding `data_gen`) and batch
# it one sample at a time.
data_gen = IterDataset(data_generator=data_gen)
train_ds = data.DataLoader(dataset=data_gen, batch_size=1)

"from datagen import DataLoader\ndata_gen = DataLoader('E:/Datasets/DAVIS/JPEGImages/480p/trainlist.txt', shape = (256,256,3))\nfrom torch.utils import data\nclass IterDataset(data.IterableDataset):\n    def __init__(self, data_generator):\n        super(IterDataset, self).__init__()\n        self.data_generator = data_generator\n\n    def __iter__(self):\n        return iter(self.data_generator())\ndata_gen = IterDataset(data_gen)\ntrain_ds = data.DataLoader(data_gen, batch_size=1)"

In [10]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard would log to "runs" by default; use a dataset-specific
# sub-directory instead.
writer = SummaryWriter(log_dir='runs/DAVIS')

In [11]:
# Per-epoch stride of the TensorBoard global step.
# NOTE(review): presumably the number of batches per epoch -- confirm against
# the data generator.
dataset_len = 6028
EPOCHS = 200
LOG_EVERY = 100  # batches between TensorBoard points and checkpoint writes
running_loss = 0.0


def _to_display_bgr(frame):
    """Convert one CHW image tensor into an HxWx3 uint8 BGR array for OpenCV.

    Pixel values are mapped with ``(x + 1) / 2 * 255`` and clipped to
    [0, 255] before the uint8 cast, so out-of-range network outputs saturate
    instead of wrapping around.
    """
    rgb = frame.detach().cpu().permute(1, 2, 0).numpy()
    rgb = np.clip((rgb + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _save_checkpoints(epoch):
    """Write the UNet and ResNet weights (plus optimizer state) for ``epoch``."""
    for directory, prefix, module in ((unet_path, 'unet', difrint.unet),
                                      (resnet_path, 'resnet', difrint.resnet)):
        torch.save({'model': module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch},
                   os.path.join(directory, f'{prefix}_{epoch}.pth'))


stop_requested = False
cv2.namedWindow('window',cv2.WINDOW_NORMAL)
try:
    for epoch in range(starting_epoch, EPOCHS):
        # The learning rate depends only on the epoch, so set it once per
        # epoch rather than once per batch.
        if epoch > 100:
            decayed_lr = 1e-4 - 1e-6 * (epoch - 100)
            for param_group in optimizer.param_groups:
                param_group['lr'] = decayed_lr

        # Start each epoch with a clean accumulator so that the tail of the
        # previous epoch does not leak into this epoch's first average.
        running_loss = 0.0

        # `batch` (not `data`) so the `torch.utils.data` module imported
        # earlier in the notebook is not shadowed.
        # NOTE(review): the recorded run failed on this unpacking with
        # "expected 4, got 3" -- confirm the data generator yields four frames.
        for idx, batch in enumerate(train_ds):
            ft_minus, ft, fs, ft_plus = (frame.to(device) for frame in batch)

            fint, fout = difrint(ft_minus, ft, fs, ft_plus)
            optimizer.zero_grad()
            loss1 = l1_loss(fs, fout) + perceptual_loss(fs, fout)
            # `fint` is produced under torch.no_grad() in DIFRINT.forward, so
            # this term affects the reported loss but carries no gradient.
            loss2 = l1_loss(fs, fint) + perceptual_loss(fs, fint)

            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()

            # Live preview: UNet output | ResNet output | target frame.
            preview = cv2.hconcat([_to_display_bgr(fint[0, ...]),
                                   _to_display_bgr(fout[0, ...]),
                                   _to_display_bgr(fs[0, ...])])
            cv2.imshow('window', preview)
            if cv2.waitKey(1) & 0xFF == ord('9'):
                # Pressing '9' stops training altogether, not just this epoch.
                stop_requested = True
                break

            running_loss += total_loss.item()
            # Report before the periodic reset below so the displayed average
            # is never computed from a freshly zeroed accumulator.
            print(f'\repoch: {epoch}, batch: {idx},running_loss: {running_loss / (idx % LOG_EVERY + 1)}',end = '')

            if idx % LOG_EVERY == LOG_EVERY - 1:
                writer.add_scalar('training_loss',
                                  running_loss / LOG_EVERY,
                                  epoch * dataset_len + idx)
                running_loss = 0.0
                _save_checkpoints(epoch)

        if stop_requested:
            break
finally:
    # Release the preview window and push pending TensorBoard events even
    # when the loop is interrupted or raises.
    cv2.destroyAllWindows()
    writer.flush()

ValueError: not enough values to unpack (expected 4, got 3)