add rife-video-frame-interpolation and model (#685)
* add rife-video-frame-interpolation pipeline and model

* add doc, change test level

---------

Co-authored-by: miracle.zjf <miracle.zjf@alibaba-inc.com>
Co-authored-by: wenmeng zhou <wenmeng.zwm@alibaba-inc.com>
3 people committed Jan 2, 2024
1 parent 39562dc commit 0b1a974
Showing 10 changed files with 550 additions and 1 deletion.
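With the registrations below in place, the new task should be reachable through ModelScope's standard `pipeline()` entry point. The following usage sketch is illustrative only: the model id, the input path, and the output key are assumptions based on ModelScope conventions, not taken from this commit.

```python
# Hedged usage sketch; the model id, input path and output key are assumptions,
# only the task registered in this commit is taken from the diff below.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

interpolator = pipeline(
    Tasks.video_frame_interpolation,
    model='damo/cv_rife_video-frame-interpolation')  # placeholder model id
result = interpolator('input_video.mp4')  # assumed: path to an input video
print(result[OutputKeys.OUTPUT_VIDEO])  # assumed output key for the interpolated video
```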
2 changes: 2 additions & 0 deletions modelscope/metainfo.py
@@ -127,6 +127,7 @@ class Models(object)
    human_image_generation = 'human-image-generation'
    image_view_transform = 'image-view-transform'
    image_control_3d_portrait = 'image-control-3d-portrait'
    rife = 'rife'
    anydoor = 'anydoor'

    # nlp models
@@ -456,6 +457,7 @@ class Pipelines(object)
    human3d_animation = 'human3d-animation'
    image_view_transform = 'image-view-transform'
    image_control_3d_portrait = 'image-control-3d-portrait'
    rife_video_frame_interpolation = 'rife-video-frame-interpolation'
    anydoor = 'anydoor'
    image_to_3d = 'image-to-3d'

3 changes: 2 additions & 1 deletion modelscope/models/cv/video_frame_interpolation/__init__.py
@@ -5,9 +5,10 @@

if TYPE_CHECKING:
    from .VFINet_arch import VFINet
    from .rife import RIFEModel

else:
    _import_structure = {'VFINet_arch': ['VFINet']}
    _import_structure = {'VFINet_arch': ['VFINet'], 'rife': ['RIFEModel']}

    import sys

119 changes: 119 additions & 0 deletions modelscope/models/cv/video_frame_interpolation/rife/IFNet_HDv3.py
@@ -0,0 +1,119 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes)
    )


class IFBlock(nn.Module):

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock0 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.convblock1 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.convblock2 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.convblock3 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
            nn.PReLU(c // 2),
            nn.ConvTranspose2d(c // 2, 4, 4, 2, 1),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
            nn.PReLU(c // 2),
            nn.ConvTranspose2d(c // 2, 1, 4, 2, 1),
        )

    def forward(self, x, flow, scale=1):
        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear",
                          align_corners=False, recompute_scale_factor=False)
        flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear",
                             align_corners=False, recompute_scale_factor=False) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        feat = self.convblock0(feat) + feat
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        feat = self.convblock3(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear",
                             align_corners=False, recompute_scale_factor=False) * scale
        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear",
                             align_corners=False, recompute_scale_factor=False)
        return flow, mask


class IFNet(nn.Module):

    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 4, c=90)
        self.block1 = IFBlock(7 + 4, c=90)
        self.block2 = IFBlock(7 + 4, c=90)
        self.block_tea = IFBlock(10 + 4, c=90)
        # self.contextnet = Contextnet()
        # self.unet = Unet()

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        loss_cons = 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            f0, m0 = block[i](
                torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
                flow, scale=scale_list[i])
            f1, m1 = block[i](
                torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
            '''
            c0 = self.contextnet(img0, flow[:, :2])
            c1 = self.contextnet(img1, flow[:, 2:4])
            tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
            res = tmp[:, 1:4] * 2 - 1
            '''
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
            # merged[i] = torch.clamp(merged[i] + res, 0, 1)
        return flow_list, mask_list[2], merged
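As a quick check of the interface defined above (not part of the commit), the minimal sketch below runs an untrained IFNet on two random frames; the spatial size is chosen as a multiple of 16 so the stride-2 convolutions and the coarsest 1/4 scale divide evenly.

```python
# Minimal shape check; assumes the module above is importable. Weights are untrained,
# so only the tensor shapes are meaningful.
import torch
from modelscope.models.cv.video_frame_interpolation.rife.IFNet_HDv3 import IFNet, device

net = IFNet().to(device).eval()
x = torch.rand(1, 6, 256, 448, device=device)  # img0 (3ch) and img1 (3ch) stacked on dim 1
with torch.no_grad():
    flow_list, mask, merged = net(x, scale_list=[4, 2, 1])
print(flow_list[-1].shape)  # torch.Size([1, 4, 256, 448]), bidirectional flow
print(mask.shape)           # torch.Size([1, 1, 256, 448]), fusion mask after sigmoid
print(merged[-1].shape)     # torch.Size([1, 3, 256, 448]), fused middle frame
```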
104 changes: 104 additions & 0 deletions modelscope/models/cv/video_frame_interpolation/rife/RIFE_HDv3.py
@@ -0,0 +1,104 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from .warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from .IFNet_HDv3 import *
import torch.nn.functional as F
from .loss import *

from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

@MODELS.register_module(Tasks.video_frame_interpolation, module_name=Models.rife)
class RIFEModel(TorchModel):
    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.flownet = IFNet()
        self.flownet.to(self.device)
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        self.load_model(model_dir, -1)
        self.eval()

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def load_model(self, path, rank=0):

        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items() if "module." in k
                }
            else:
                return param

        if rank <= 0:
            if torch.cuda.is_available():
                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))
            else:
                self.flownet.load_state_dict(
                    convert(torch.load('{}/flownet.pkl'.format(path), map_location='cpu')))

    def save_model(self, path, rank=0):
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        _, _, merged = self.flownet(imgs, scale_list)
        return merged[2].detach()

    def forward(self, inputs):
        img0 = inputs['img0']
        img1 = inputs['img1']
        scale = inputs['scale']
        return {'output': self.inference(img0, img1, scale)}

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(
            torch.cat((imgs, gt), 1), scale_list=scale, training=training)
        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # the consistency/distillation term is not computed in this HDv3 port,
        # so define it as zero to keep the optimizer step and return dict valid
        loss_cons = torch.tensor(0.0, device=imgs.device)
        # loss_vgg = self.vgg(merged[2], gt)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[2], {
            'mask': mask,
            'flow': flow[2][:, :2],
            'loss_l1': loss_l1,
            'loss_cons': loss_cons,
            'loss_smooth': loss_smooth,
        }
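For completeness, here is a hedged sketch of calling `RIFEModel.inference` directly on two frames. The model directory is a placeholder, and padding to a multiple of 32 follows the original RIFE inference scripts rather than anything specified in this commit.

```python
# Assumed direct-usage sketch; '/path/to/rife_model_dir' is a placeholder that must
# contain flownet.pkl, and the frame paths are illustrative.
import cv2
import torch
import torch.nn.functional as F
from modelscope.models.cv.video_frame_interpolation.rife import RIFEModel

model = RIFEModel('/path/to/rife_model_dir')

def to_tensor(path):
    # HWC uint8 image -> 1x3xHxW float tensor in [0, 1] on the model's device
    img = cv2.imread(path)
    t = torch.from_numpy(img).permute(2, 0, 1).float() / 255.
    return t.unsqueeze(0).to(model.device)

img0, img1 = to_tensor('frame_000.png'), to_tensor('frame_001.png')
h, w = img0.shape[2:]
# pad to a multiple of 32 so the coarse-to-fine flow blocks divide evenly
ph, pw = ((h - 1) // 32 + 1) * 32, ((w - 1) // 32 + 1) * 32
img0 = F.pad(img0, (0, pw - w, 0, ph - h))
img1 = F.pad(img1, (0, pw - w, 0, ph - h))
with torch.no_grad():
    mid = model.inference(img0, img1)[:, :, :h, :w]  # interpolated frame between img0 and img1
cv2.imwrite('frame_mid.png', (mid[0].permute(1, 2, 0).cpu().numpy() * 255).astype('uint8'))
```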
5 changes: 5 additions & 0 deletions modelscope/models/cv/video_frame_interpolation/rife/__init__.py
@@ -0,0 +1,5 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

from .RIFE_HDv3 import RIFEModel
132 changes: 132 additions & 0 deletions modelscope/models/cv/video_frame_interpolation/rife/loss.py
@@ -0,0 +1,132 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EPE(nn.Module):

    def __init__(self):
        super(EPE, self).__init__()

    def forward(self, flow, gt, loss_mask):
        loss_map = (flow - gt.detach())**2
        loss_map = (loss_map.sum(1, True) + 1e-6)**0.5
        return (loss_map * loss_mask)


class Ternary(nn.Module):

    def __init__(self):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        self.w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, img):
        patches = F.conv2d(img, self.w, padding=3, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf**2)
        return transf_norm

    def rgb2gray(self, rgb):
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        dist = (t1 - t2)**2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)


class SOBEL(nn.Module):

    def __init__(self):
        super(SOBEL, self).__init__()
        self.kernelX = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        self.kernelY = self.kernelX.clone().T
        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)

    def forward(self, pred, gt):
        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
        img_stack = torch.cat(
            [pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0)
        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
        pred_X, gt_X = sobel_stack_x[:N * C], sobel_stack_x[N * C:]
        pred_Y, gt_Y = sobel_stack_y[:N * C], sobel_stack_y[N * C:]

        L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y)
        loss = (L1X + L1Y)
        return loss


class MeanShift(nn.Conv2d):

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        self.requires_grad = False


class VGGPerceptualLoss(torch.nn.Module):

    def __init__(self, rank=0):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        pretrained = True
        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, indices=None):
        X = self.normalize(X)
        Y = self.normalize(Y)
        indices = [2, 7, 12, 21, 30]
        weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
        k = 0
        loss = 0
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            Y = self.vgg_pretrained_features[i](Y)
            if (i + 1) in indices:
                loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
                k += 1
        return loss


if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    ternary_loss = Ternary()
    print(ternary_loss(img0, img1).shape)