add rife-video-frame-interpolation and model (#685)

* add rife-video-frame-interpolation pipeline and model
* add doc, change test level

Co-authored-by: miracle.zjf <miracle.zjf@alibaba-inc.com>
Co-authored-by: wenmeng zhou <wenmeng.zwm@alibaba-inc.com>
1 parent 39562dc · commit 0b1a974 · Showing 10 changed files with 550 additions and 1 deletion.
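For orientation, a minimal sketch of how the new pipeline might be invoked through the standard ModelScope entry point. The hub model id and the output key are assumptions for illustration; they are not confirmed by this diff.

import torch

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Hypothetical hub id for the RIFE model added in this commit.
vfi = pipeline(Tasks.video_frame_interpolation,
               model='damo/cv_rife_video-frame-interpolation')
result = vfi('input_video.mp4')
print(result[OutputKeys.OUTPUT_VIDEO])  # assumed: path to the interpolated video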
119 changes: 119 additions & 0 deletions
modelscope/models/cv/video_frame_interpolation/rife/IFNet_HDv3.py
@@ -0,0 +1,119 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes)
    )


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock0 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.convblock1 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.convblock2 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.convblock3 = nn.Sequential(
            conv(c, c),
            conv(c, c)
        )
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
            nn.PReLU(c // 2),
            nn.ConvTranspose2d(c // 2, 4, 4, 2, 1),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
            nn.PReLU(c // 2),
            nn.ConvTranspose2d(c // 2, 1, 4, 2, 1),
        )

    def forward(self, x, flow, scale=1):
        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear",
                          align_corners=False, recompute_scale_factor=False)
        flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear",
                             align_corners=False, recompute_scale_factor=False) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        feat = self.convblock0(feat) + feat
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        feat = self.convblock3(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear",
                             align_corners=False, recompute_scale_factor=False) * scale
        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear",
                             align_corners=False, recompute_scale_factor=False)
        return flow, mask

class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 4, c=90)
        self.block1 = IFBlock(7 + 4, c=90)
        self.block2 = IFBlock(7 + 4, c=90)
        self.block_tea = IFBlock(10 + 4, c=90)
        # self.contextnet = Contextnet()
        # self.unet = Unet()

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        # This HDv3 forward only implements the inference path: x is the two
        # input frames concatenated along the channel axis.
        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        loss_cons = 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            # Coarse-to-fine refinement: each level predicts residual flow and
            # occlusion mask; the second call swaps the frame order so both
            # flow directions are estimated symmetrically.
            f0, m0 = block[i](
                torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
                flow, scale=scale_list[i])
            f1, m1 = block[i](
                torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        '''
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, 1:4] * 2 - 1
        '''
        for i in range(3):
            # Blend the two warped frames with a sigmoid occlusion mask.
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
            # merged[i] = torch.clamp(merged[i] + res, 0, 1)
        return flow_list, mask_list[2], merged
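A quick shape check for IFNet, as a minimal sketch: it assumes the sibling warplayer module from this commit is importable and that frame height and width are padded to multiples of 32, as RIFE inference conventionally requires.

import torch

from modelscope.models.cv.video_frame_interpolation.rife.IFNet_HDv3 import IFNet

net = IFNet().eval()
img0 = torch.rand(1, 3, 256, 448)  # two RGB frames in [0, 1]
img1 = torch.rand(1, 3, 256, 448)
with torch.no_grad():
    flow_list, mask, merged = net(torch.cat((img0, img1), 1))
mid_frame = merged[2]      # finest-level blend: the interpolated middle frame
bidir_flow = flow_list[2]  # 4 channels: flow toward img0 and toward img1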
104 changes: 104 additions & 0 deletions
modelscope/models/cv/video_frame_interpolation/rife/RIFE_HDv3.py
@@ -0,0 +1,104 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

import torch
from torch.optim import AdamW

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

from .IFNet_HDv3 import *
from .loss import *


@MODELS.register_module(Tasks.video_frame_interpolation, module_name=Models.rife)
class RIFEModel(TorchModel):
    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.flownet = IFNet()
        self.flownet.to(self.device)
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        self.load_model(model_dir, -1)
        self.eval()

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def load_model(self, path, rank=0):
        # Checkpoints saved under DataParallel/DDP carry a "module." prefix;
        # strip it when loading outside a distributed setup (rank == -1).
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param

        if rank <= 0:
            if torch.cuda.is_available():
                self.flownet.load_state_dict(
                    convert(torch.load('{}/flownet.pkl'.format(path))))
            else:
                self.flownet.load_state_dict(
                    convert(torch.load('{}/flownet.pkl'.format(path), map_location='cpu')))

    def save_model(self, path, rank=0):
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        _, _, merged = self.flownet(imgs, scale_list)
        return merged[2].detach()

    def forward(self, inputs):
        img0 = inputs['img0']
        img1 = inputs['img1']
        scale = inputs['scale']
        return {'output': self.inference(img0, img1, scale)}

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(
            torch.cat((imgs, gt), 1), scale_list=scale, training=training)
        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # loss_vgg = self.vgg(merged[2], gt)
        # The teacher-consistency term of full RIFE is not computed in this
        # HDv3 port; use a zero placeholder so both branches and the returned
        # dict stay well-defined. Note the training path also relies on
        # IFNet's training branch, which this HDv3 IFNet does not implement.
        loss_cons = torch.zeros((), device=imgs.device)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[2], {
            'mask': mask,
            'flow': flow[2][:, :2],
            'loss_l1': loss_l1,
            'loss_cons': loss_cons,
            'loss_smooth': loss_smooth,
        }
5 changes: 5 additions & 0 deletions
modelscope/models/cv/video_frame_interpolation/rife/__init__.py
@@ -0,0 +1,5 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

from .RIFE_HDv3 import RIFEModel
132 changes: 132 additions & 0 deletions
modelscope/models/cv/video_frame_interpolation/rife/loss.py
@@ -0,0 +1,132 @@
# The implementation here is modified based on ECCV2022-RIFE,
# originally MIT License, Copyright (c) Megvii Inc.,
# and publicly available at https://github.com/megvii-research/ECCV2022-RIFE

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EPE(nn.Module):
    def __init__(self):
        super(EPE, self).__init__()

    def forward(self, flow, gt, loss_mask):
        loss_map = (flow - gt.detach()) ** 2
        loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
        return (loss_map * loss_mask)

class Ternary(nn.Module):
    def __init__(self):
        super(Ternary, self).__init__()
        # 7x7 identity kernels extract each pixel of a patch as a channel,
        # implementing a soft census transform.
        patch_size = 7
        out_channels = patch_size * patch_size
        self.w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, img):
        patches = F.conv2d(img, self.w, padding=3, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf**2)
        return transf_norm

    def rgb2gray(self, rgb):
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)

class SOBEL(nn.Module):
    def __init__(self):
        super(SOBEL, self).__init__()
        self.kernelX = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        self.kernelY = self.kernelX.clone().T
        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)

    def forward(self, pred, gt):
        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
        img_stack = torch.cat(
            [pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0)
        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
        pred_X, gt_X = sobel_stack_x[:N * C], sobel_stack_x[N * C:]
        pred_Y, gt_Y = sobel_stack_y[:N * C], sobel_stack_y[N * C:]

        L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y)
        loss = (L1X + L1Y)
        return loss

class MeanShift(nn.Conv2d):
    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        # Freeze the (de)normalization parameters; setting the module
        # attribute alone would not stop gradient flow.
        for p in self.parameters():
            p.requires_grad = False

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, rank=0):
        super(VGGPerceptualLoss, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.normalize = MeanShift([0.485, 0.456, 0.406],
                                   [0.229, 0.224, 0.225], norm=True).to(device)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, indices=None):
        X = self.normalize(X)
        Y = self.normalize(Y)
        # Fixed VGG-19 feature taps; the `indices` argument is ignored.
        indices = [2, 7, 12, 21, 30]
        weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
        k = 0
        loss = 0
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            Y = self.vgg_pretrained_features[i](Y)
            if (i + 1) in indices:
                loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
                k += 1
        return loss

if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    ternary_loss = Ternary()
    print(ternary_loss(img0, img1).shape)
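Beyond the built-in __main__ check, a small sketch of how SOBEL serves as the smoothness term in RIFE_HDv3.update above, where it is weighted by 0.1 against the reconstruction loss:

import torch

from modelscope.models.cv.video_frame_interpolation.rife.loss import SOBEL, device

flow = torch.randn(2, 4, 64, 64, device=device)  # bidirectional flow field
loss_smooth = SOBEL()(flow, torch.zeros_like(flow)).mean()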