Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding "Image Align with Rife" and "Wavelet Color Fix" Nodes #2714

Merged
merged 22 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
81b5eb8
adding "Image Align with Rife" and "Wavelet Color Fix" nodes
pifroggi Mar 28, 2024
960fe04
Create test
pifroggi Mar 28, 2024
8459946
adding additional files for the "Image Align with Rife" node and the …
pifroggi Mar 28, 2024
e91d49f
Delete backend/src/packages/chaiNNer_pytorch/pytorch/processing/rife/…
pifroggi Mar 28, 2024
5ac2d52
Update IFNet_HDv3_v4_14_align.py
pifroggi Mar 28, 2024
feca441
Update image_align_rife.py
pifroggi Mar 28, 2024
479f098
Update wavelet_color_fix.py
pifroggi Mar 28, 2024
6db17e4
Merge branch 'chaiNNer-org:main' into main
pifroggi Mar 30, 2024
2394ff9
Delete rife model
pifroggi Mar 30, 2024
086ae82
update image_align_rife.py to download rife model only when needed
pifroggi Mar 30, 2024
d613693
cosmetic fixes image_align_rife.py
pifroggi Mar 30, 2024
592a841
Merge branch 'chaiNNer-org:main' into main
pifroggi Apr 14, 2024
075de83
added minimums, removed comments, changed download, changed name, ruff
pifroggi Apr 14, 2024
c7ad0ae
added minimum for wavelet number
pifroggi Apr 14, 2024
819a982
removed commented out code
pifroggi Apr 14, 2024
81fe7aa
removed commented out code
pifroggi Apr 14, 2024
fa8b9b5
cosmetics
pifroggi Apr 14, 2024
0219bc6
Merge remote-tracking branch 'origin/main'
joeyballentine May 10, 2024
93e7ebc
Run ruff formatting
joeyballentine May 10, 2024
9d31aea
fixes, ignore a ton of stuff
joeyballentine May 10, 2024
61a0bf1
move some files and reorder nodes in list
joeyballentine May 10, 2024
982077b
fix pyright errors
joeyballentine May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
271 changes: 271 additions & 0 deletions backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
# type: ignore
# Original Rife Frame Interpolation by hzwer
# https://github.com/megvii-research/ECCV2022-RIFE
# https://github.com/hzwer/Practical-RIFE

# Modifications to use Rife for Image Alignment by tepete/pifroggi ('Enhance Everything!' Discord Server)

# Additional helpful github issues
# https://github.com/megvii-research/ECCV2022-RIFE/issues/278
# https://github.com/megvii-research/ECCV2022-RIFE/issues/344

import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torchvision import transforms

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):  # noqa: ANN001
    """Build a biased Conv2d followed by LeakyReLU(0.2, inplace)."""
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    return nn.Sequential(convolution, nn.LeakyReLU(0.2, True))


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):  # noqa: ANN001
    """Build an unbiased Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stack.

    Bias is omitted on the convolution because BatchNorm's affine shift
    makes it redundant.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.2, True),
    ]
    return nn.Sequential(*layers)


class Head(nn.Module):
    """Shallow encoder: 3-channel image -> 8-channel feature map.

    The first conv halves the spatial resolution and the final transposed
    conv restores it, so the output is the same height/width as the input.
    Attribute names are kept stable for pretrained state-dict loading.
    """

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):  # noqa: ANN001
        """Return the final feature map; all four intermediates when feat=True."""
        x0 = self.cnn0(x)
        x1 = self.cnn1(self.relu(x0))
        x2 = self.cnn2(self.relu(x1))
        x3 = self.cnn3(self.relu(x2))
        return [x0, x1, x2, x3] if feat else x3


class ResConv(nn.Module):
    """Residual 3x3 conv block: relu(conv(x) * beta + x).

    `beta` is a learnable per-channel scale on the convolution branch,
    initialized to ones. Spatial size is preserved (padding == dilation).
    """

    def __init__(self, c, dilation=1):  # noqa: ANN001
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):  # noqa: ANN001
        branch = self.conv(x) * self.beta
        return self.relu(branch + x)


class IFBlock(nn.Module):
    """One coarse-to-fine refinement stage of IFNet.

    Predicts a 4-channel flow field and a 1-channel mask. Work happens at
    1/scale resolution: inputs are downsampled on entry, and the output is
    upsampled (with the flow magnitudes re-scaled) on exit.
    """

    def __init__(self, in_planes, c=64):  # noqa: ANN001
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # eight residual conv blocks; Sequential indices 0-7 match the
        # original checkpoint's state-dict keys
        self.convblock = nn.Sequential(*(ResConv(c) for _ in range(8)))
        self.lastconv = nn.Sequential(
            nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
        )

    def forward(self, x, flow=None, scale=1):  # noqa: ANN001
        """Return (flow, mask) predicted at full resolution."""
        inv_scale = 1.0 / scale
        x = F.interpolate(
            x, scale_factor=inv_scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            # downscale both the flow grid and its vector magnitudes
            flow = (
                F.interpolate(
                    flow, scale_factor=inv_scale, mode="bilinear", align_corners=False
                )
                * 1.0
                / scale
            )
            x = torch.cat((x, flow), 1)
        features = self.convblock(self.conv0(x))
        prediction = F.interpolate(
            self.lastconv(features),
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
        )
        # first 4 channels: flow (re-scaled back up); 5th channel: mask
        return prediction[:, :4] * scale, prediction[:, 4:5]


class IFNet(nn.Module):
    """Rife IFNet adapted for image alignment rather than interpolation.

    Estimates optical flow between two images through four coarse-to-fine
    IFBlock stages and warps the first image onto the second.
    """

    def __init__(self):
        super().__init__()
        # per-stage channel widths shrink as resolution grows (coarse -> fine)
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()

    def align_images(
        self,
        img0,  # noqa: ANN001
        img1,  # noqa: ANN001
        timestep,  # noqa: ANN001
        scale_list,  # noqa: ANN001
        blur_strength,  # noqa: ANN001
        ensemble,  # noqa: ANN001
        device,  # noqa: ANN001
    ):
        """Estimate flow from img0 to img1 and warp img0 accordingly.

        Returns (aligned_img0, flow): the warped-and-clamped image and the
        final 4-channel flow field from the last refinement stage.
        """
        # Optional blur: flow is estimated on blurred copies for robustness,
        # but the final warp is applied to the original img0.
        if blur_strength is not None and blur_strength > 0:
            blur = transforms.GaussianBlur(
                kernel_size=(5, 5), sigma=(blur_strength, blur_strength)
            )
            img0_blurred = blur(img0)
            img1_blurred = blur(img1)
        else:
            img0_blurred = img0
            img1_blurred = img1

        f0 = self.encode(img0_blurred[:, :3])
        f1 = self.encode(img1_blurred[:, :3])
        flow_list = []
        mask_list = []
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                # First stage: predict flow from raw images + encoded features.
                flow, mask = block[i](
                    torch.cat(
                        (img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1
                    ),
                    None,
                    scale=scale_list[i],
                )
                if ensemble:
                    # Run again with inputs swapped; average the
                    # channel-swapped result for a more stable estimate.
                    f_, m_ = block[i](
                        torch.cat(
                            (
                                img1_blurred[:, :3],
                                img0_blurred[:, :3],
                                f1,
                                f0,
                                1 - timestep,
                            ),
                            1,
                        ),
                        None,
                        scale=scale_list[i],
                    )
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Later stages: refine the current flow using warped features.
                wf0 = warp(f0, flow[:, :2], device)
                wf1 = warp(f1, flow[:, 2:4], device)
                fd, m0 = block[i](
                    torch.cat(
                        (
                            img0_blurred[:, :3],
                            img1_blurred[:, :3],
                            wf0,
                            wf1,
                            timestep,
                            mask,
                        ),
                        1,
                    ),
                    flow,
                    scale=scale_list[i],
                )
                if ensemble:
                    f_, m_ = block[i](
                        torch.cat(
                            (
                                img1_blurred[:, :3],
                                img0_blurred[:, :3],
                                wf1,
                                wf0,
                                1 - timestep,
                                -mask,
                            ),
                            1,
                        ),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)

        # apply warp to original image
        aligned_img0 = warp(img0, flow_list[-1][:, :2], device)

        # add clamp here instead of in warplayer script, as it changes the output there
        aligned_img0 = aligned_img0.clamp(min=0.0, max=1.0)
        return aligned_img0, flow_list[-1]

    def forward(
        self,
        x,  # noqa: ANN001
        timestep=1,  # noqa: ANN001
        training=False,  # noqa: ANN001
        fastmode=True,  # noqa: ANN001
        ensemble=True,  # noqa: ANN001
        num_iterations=1,  # noqa: ANN001
        multiplier=0.5,  # noqa: ANN001
        blur_strength=0,  # noqa: ANN001
        device="cuda",  # noqa: ANN001
    ):
        """Align the first half of x's channels onto the second half.

        x stacks both images along the channel axis: [img0 | img1].
        Returns (aligned_img0, flow) from the final alignment pass.

        Raises:
            NotImplementedError: if training=True (this alignment fork is
                inference-only; the old code silently fell through to a
                NameError in that case).
            ValueError: if num_iterations < 1 (the old code raised a
                confusing NameError instead).

        Note: `fastmode` is accepted but unused; it is kept only so callers
        written against the original Rife signature keep working.
        """
        if training:
            raise NotImplementedError(
                "IFNet alignment supports inference only (training=False)"
            )
        if num_iterations < 1:
            raise ValueError("num_iterations must be >= 1")

        channel = x.shape[1] // 2
        img0 = x[:, :channel]
        img1 = x[:, channel:]

        # coarse-to-fine processing scales for the four IFBlock stages
        scale_list = [multiplier * 8, multiplier * 4, multiplier * 2, multiplier]

        if not torch.is_tensor(timestep):
            # broadcast the scalar timestep to a (N, 1, H, W) tensor
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])  # type: ignore

        for _iteration in range(num_iterations):
            aligned_img0, flow = self.align_images(
                img0, img1, timestep, scale_list, blur_strength, ensemble, device
            )
            img0 = aligned_img0  # use the aligned image as img0 for the next iteration

        return aligned_img0, flow
38 changes: 38 additions & 0 deletions backend/src/nodes/impl/pytorch/rife/warplayer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# type: ignore
import torch

# Cache of base sampling grids keyed by (device, flow size). Entries are
# never evicted, so memory grows with the number of distinct sizes seen.
backwarp_tenGrid = {}  # noqa: N816


def warp(tenInput, tenFlow, device):  # noqa: ANN001, N803
    """Backward-warp tenInput by the per-pixel offsets in tenFlow.

    tenFlow channels 0/1 hold horizontal/vertical displacements in pixels;
    they are converted to the [-1, 1] range expected by grid_sample and
    added to a cached identity grid. Sampling is bicubic with border
    padding.
    """
    cache_key = (str(tenFlow.device), str(tenFlow.size()))
    if cache_key not in backwarp_tenGrid:
        batch, _, height, width = tenFlow.shape
        horizontal = (
            torch.linspace(-1.0, 1.0, width, device=device)
            .view(1, 1, 1, width)
            .expand(batch, -1, height, -1)
        )
        vertical = (
            torch.linspace(-1.0, 1.0, height, device=device)
            .view(1, 1, height, 1)
            .expand(batch, -1, -1, width)
        )
        backwarp_tenGrid[cache_key] = torch.cat([horizontal, vertical], 1).to(device)

    # normalize pixel offsets to the [-1, 1] coordinate system
    normalized_flow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    sample_grid = (backwarp_tenGrid[cache_key] + normalized_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=sample_grid,
        mode="bicubic",
        padding_mode="border",
        align_corners=True,
    )
5 changes: 5 additions & 0 deletions backend/src/packages/chaiNNer_pytorch/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
# Node groups for the PyTorch category; each is registered via
# pytorch_category.add_node_group and populated elsewhere.
restoration_group = pytorch_category.add_node_group("Restoration")
batch_processing_group = pytorch_category.add_node_group("Batch Processing")
utility_group = pytorch_category.add_node_group("Utility")

# Pin these node ids to the front of the Processing group's listing.
# NOTE(review): `processing_group` is defined above this diff hunk — the
# exact semantics of `.order` (presumably display ordering, with unlisted
# nodes following) should be confirmed against the node-group class.
processing_group.order = [
    "chainner:pytorch:upscale_image",
    "chainner:pytorch:inpaint",
]