Adding "Image Align with Rife" and "Wavelet Color Fix" Nodes #2714

Merged
merged 22 commits on May 10, 2024
Changes from 11 commits
Commits
81b5eb8
adding "Image Align with Rife" and "Wavelet Color Fix" nodes
pifroggi Mar 28, 2024
960fe04
Create test
pifroggi Mar 28, 2024
8459946
adding additional files for the "Image Align with Rife" node and the …
pifroggi Mar 28, 2024
e91d49f
Delete backend/src/packages/chaiNNer_pytorch/pytorch/processing/rife/…
pifroggi Mar 28, 2024
5ac2d52
Update IFNet_HDv3_v4_14_align.py
pifroggi Mar 28, 2024
feca441
Update image_align_rife.py
pifroggi Mar 28, 2024
479f098
Update wavelet_color_fix.py
pifroggi Mar 28, 2024
6db17e4
Merge branch 'chaiNNer-org:main' into main
pifroggi Mar 30, 2024
2394ff9
Delete rife model
pifroggi Mar 30, 2024
086ae82
update image_align_rife.py to download rife model only when needed
pifroggi Mar 30, 2024
d613693
cosmetic fixes image_align_rife.py
pifroggi Mar 30, 2024
592a841
Merge branch 'chaiNNer-org:main' into main
pifroggi Apr 14, 2024
075de83
added minimums, removed comments, changed download, changed name, ruff
pifroggi Apr 14, 2024
c7ad0ae
added minimum for wavelet number
pifroggi Apr 14, 2024
819a982
removed commented out code
pifroggi Apr 14, 2024
81fe7aa
removed commented out code
pifroggi Apr 14, 2024
fa8b9b5
cosmetics
pifroggi Apr 14, 2024
0219bc6
Merge remote-tracking branch 'origin/main'
joeyballentine May 10, 2024
93e7ebc
Run ruff formatting
joeyballentine May 10, 2024
9d31aea
fixes, ignore a ton of stuff
joeyballentine May 10, 2024
61a0bf1
move some files and reorder nodes in list
joeyballentine May 10, 2024
982077b
fix pyright errors
joeyballentine May 10, 2024
image_align_rife.py
@@ -0,0 +1,211 @@

# Original Rife Frame Interpolation by hzwer
# https://github.com/megvii-research/ECCV2022-RIFE
# https://github.com/hzwer/Practical-RIFE

# Modifications to use Rife for Image Alignment by tepete/pifroggi ('Enhance Everything!' Discord Server)

# Additional helpful GitHub issues
# https://github.com/megvii-research/ECCV2022-RIFE/issues/278
# https://github.com/megvii-research/ECCV2022-RIFE/issues/344

import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import requests
import zipfile
from pathlib import Path
from api import NodeContext
from nodes.properties.inputs import ImageInput, EnumInput, NumberInput, BoolInput, SliderInput
from nodes.properties.outputs import ImageOutput
from ...settings import PyTorchSettings, get_settings
from .. import processing_group
from nodes.impl.pytorch.utils import np2tensor, tensor2np
from nodes.impl.resize import resize, ResizeFilter
from nodes.utils.utils import get_h_w_c
from packages.chaiNNer_pytorch.pytorch.processing.rife.IFNet_HDv3_v4_14_align import IFNet
from enum import Enum

class PrecisionMode(Enum):
FIFTY_PERCENT = 2000
ONE_HUNDRED_PERCENT = 1000
TWO_HUNDRED_PERCENT = 500
FOUR_HUNDRED_PERCENT = 250
EIGHT_HUNDRED_PERCENT = 125

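# RIFE's four-level pyramid repeatedly downsamples the input, so width and
# height must be padded to a multiple that depends on the precision mode: the
# coarser the internal scale (lower precision percentage), the larger the
# required multiple. Worked example (illustration, not from the PR): a 720x480
# image at 100% precision uses pad_value=32, so
# pad_height = (32 - 480 % 32) % 32 = 0 and pad_width = (32 - 720 % 32) % 32 = 16.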
def calculate_padding(height, width, precision_mode):
if precision_mode == PrecisionMode.EIGHT_HUNDRED_PERCENT:
pad_value = 4
elif precision_mode == PrecisionMode.FOUR_HUNDRED_PERCENT:
pad_value = 8
elif precision_mode == PrecisionMode.TWO_HUNDRED_PERCENT:
pad_value = 16
elif precision_mode == PrecisionMode.ONE_HUNDRED_PERCENT:
pad_value = 32
else:
pad_value = 64

pad_height = (pad_value - height % pad_value) % pad_value
pad_width = (pad_value - width % pad_value) % pad_value
return pad_height, pad_width

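# Downloads the RIFE v4.14 weights on first use: fetches the zip from the
# pinned URL, extracts only flownet.pkl from the train_log folder inside it,
# moves it into model_path, then removes the emptied folder and the zip.
# Note that the whole zip is buffered in memory before being written to disk.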
def download_model(download_url, model_path, model_file, zip_inner_path):
model_dir = Path(model_path)
model_dir.mkdir(parents=True, exist_ok=True)
zip_path = model_dir / "model.zip"

if not (model_dir / model_file).exists():
try:
response = requests.get(download_url)
response.raise_for_status()
with open(zip_path, 'wb') as f:
f.write(response.content)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
specific_file_path = zip_inner_path + '/' + model_file
zip_ref.extract(specific_file_path, model_dir)
extracted_file_path = model_dir / specific_file_path
final_path = model_dir / model_file
if extracted_file_path != final_path:
extracted_file_path.rename(final_path)

#cleanup
train_log_dir = model_dir / zip_inner_path
train_log_dir.rmdir()

zip_path.unlink()
except requests.RequestException as e:
print(f"Failed to download the model. Error: {e}")

def align_images(context, target_img: np.ndarray, source_img: np.ndarray, precision_mode, model_path='python/models/rife_v4.14', model_file='flownet.pkl', multiplier=1, alignment_passes=1, blur_strength=0, ensemble=True) -> np.ndarray:
download_url = "https://drive.usercontent.google.com/download?id=1BjuEY7CHZv1wzmwXSQP9ZTj0mLWu_4xy&export=download&authuser=0"
zip_inner_path = 'train_log'
download_model(download_url, model_path, model_file, zip_inner_path)

source_h, source_w, _ = get_h_w_c(source_img)
target_h, target_w, _ = get_h_w_c(target_img)

    # resize, then shift reference one pixel left because rife shifts slightly to the right
target_img_resized = resize(target_img, (source_w, source_h), filter=ResizeFilter.LANCZOS)
target_img_resized = np.roll(target_img_resized, -1, axis=1)
target_img_resized[:, -1] = target_img_resized[:, -2]

    # alternative approach to the subpixel shift (kept for reference):
#scale_x = source_w / target_w
#scale_y = source_h / target_h
#shift_x = -1.0
#shift_y = 0
#transformation_matrix = np.float32([[scale_x, 0, shift_x], [0, scale_y, shift_y]])
#target_img_resized = cv2.warpAffine(target_img, transformation_matrix, (source_w, source_h), flags=cv2.INTER_LANCZOS4, borderMode=cv2.BORDER_REPLICATE)

#padding because rife can only work with multiples of 32 (changes with precision mode)
pad_h, pad_w = calculate_padding(source_h, source_w, precision_mode)
top_pad = pad_h // 2
bottom_pad = pad_h - top_pad
left_pad = pad_w // 2
right_pad = pad_w - left_pad
target_img_padded = np.pad(target_img_resized, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), mode='edge')
source_img_padded = np.pad(source_img, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), mode='edge')
#target_img_padded = target_img
#source_img_padded = source_img_resized

exec_options = get_settings(context)
device = exec_options.device

#load model
model_full_path = os.path.join(model_path, model_file)
model = IFNet().to(device)#.half()
state_dict = torch.load(model_full_path, map_location=device)
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.eval()

#convert to tensors
target_tensor_padded = np2tensor(target_img_padded, change_range=True).to(device)#.half()
source_tensor_padded = np2tensor(source_img_padded, change_range=True).to(device)#.half()

#concatenate images
img_pair = torch.cat((target_tensor_padded, source_tensor_padded), dim=1)

with torch.no_grad():
aligned_img, _ = model(img_pair, multiplier=multiplier, num_iterations=alignment_passes, blur_strength=blur_strength, ensemble=ensemble, device=device)

#convert back to numpy and crop
result_img = tensor2np(aligned_img.squeeze(0).cpu(), change_range=False, imtype=np.float32)
result_img = result_img[top_pad:top_pad+source_h, left_pad:left_pad+source_w]

    return result_img


@processing_group.register(
schema_id="chainner:pytorch:image_align_rife",
name="Image Align with Rife",
    description="Aligns an Image with a Reference Image using Rife. Images should already be roughly aligned before using this Node. The Output Image will have the same dimensions as the Reference Image. Resize the Reference Image to get the desired output scale.",
icon="BsRulers",
inputs=[
ImageInput(label="Image", channels=3),
ImageInput(label="Reference Image", channels=3),
EnumInput(
PrecisionMode,
label="Precision",
default=PrecisionMode.ONE_HUNDRED_PERCENT,
option_labels={
PrecisionMode.FIFTY_PERCENT: "50%",
PrecisionMode.ONE_HUNDRED_PERCENT: "100%",
PrecisionMode.TWO_HUNDRED_PERCENT: "200%",
PrecisionMode.FOUR_HUNDRED_PERCENT: "400%",
PrecisionMode.EIGHT_HUNDRED_PERCENT: "800% (VRAM!)",
},
)
.with_docs(
"If the Alignment is very close, try a **high** value.",
"If the Alignment is **not** very close, try a **low** value.",
"Higher values will internally align at higher resolutions to increase precision, which will in turn increase processing time and VRAM usage. Lower values are less precise, but can align over larger distances.",
hint=True,
),
NumberInput(
"Alignment Passes",
controls_step=1,
maximum=1000,
default=1,
unit="#",
)
.with_docs(
"Runs the alignment multiple times.",
"With more than around 4 passes, artifacts can appear. Try to keep it low.",
hint=True,
),
NumberInput(
"Blur Strength",
minimum=0,
maximum=100,
default=0,
precision=1,
controls_step=1,
unit="⌀"
)
.with_docs(
"Blur is only used internally and will not be visible on the Output Image. It will reduce accuracy, try to keep it **low**. The **best** alignment will be at **Blur 0**.",
            "Blur can help to ignore strong degradations (like compression or noise). If the lines on the Output Image get thinner or thicker, try increasing the blur a little as well.",
hint=True,
),
],
outputs=[
ImageOutput().with_never_reason("Returns the aligned image.")
],
node_context=True,
)
def image_aligner_node(context: NodeContext, target_img: np.ndarray, source_img: np.ndarray, precision: PrecisionMode, alignment_passes: int, blur_strength: float) -> np.ndarray:
    multiplier = precision.value / 1000
    return align_images(context, target_img, source_img, precision, multiplier=multiplier, alignment_passes=alignment_passes, blur_strength=blur_strength, ensemble=True)
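
As a side note (not part of the diff): the precision percentages above map to the internal pyramid scales through multiplier = precision.value / 1000 and the scale_list built in IFNet.forward below. A minimal, self-contained sketch of that mapping:

from enum import Enum

class PrecisionMode(Enum):  # mirrors the node's enum above
    FIFTY_PERCENT = 2000
    ONE_HUNDRED_PERCENT = 1000
    TWO_HUNDRED_PERCENT = 500
    FOUR_HUNDRED_PERCENT = 250
    EIGHT_HUNDRED_PERCENT = 125

for mode in PrecisionMode:
    multiplier = mode.value / 1000
    # same construction as in IFNet.forward: coarsest scale first
    scale_list = [multiplier * 8, multiplier * 4, multiplier * 2, multiplier]
    print(f"{mode.name}: {scale_list}")
# e.g. ONE_HUNDRED_PERCENT: [8.0, 4.0, 2.0, 1.0]; the scales are downsample
# factors, so lower precision (larger scales) aligns over larger distances,
# while 800% works near full resolution and needs more VRAM.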
IFNet_HDv3_v4_14_align.py
@@ -0,0 +1,187 @@

# Original Rife Frame Interpolation by hzwer
# https://github.com/megvii-research/ECCV2022-RIFE
# https://github.com/hzwer/Practical-RIFE

# Modifications to use Rife for Image Alignment by tepete/pifroggi ('Enhance Everything!' Discord Server)

# Additional helpful GitHub issues
# https://github.com/megvii-research/ECCV2022-RIFE/issues/278
# https://github.com/megvii-research/ECCV2022-RIFE/issues/344

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from .warplayer import warp
# from train_log.refine import *
#import logging

# Setup logging
#logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s %(levelname)s:%(message)s')


#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.LeakyReLU(0.2, True)
)

def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(out_planes),
nn.LeakyReLU(0.2, True)
)

class Head(nn.Module):
def __init__(self):
super(Head, self).__init__()
self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
self.relu = nn.LeakyReLU(0.2, True)

def forward(self, x, feat=False):
x0 = self.cnn0(x)
x = self.relu(x0)
x1 = self.cnn1(x)
x = self.relu(x1)
x2 = self.cnn2(x)
x = self.relu(x2)
x3 = self.cnn3(x)
if feat:
return [x0, x1, x2, x3]
return x3

class ResConv(nn.Module):
def __init__(self, c, dilation=1):
super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
self.relu = nn.LeakyReLU(0.2, True)

def forward(self, x):
return self.relu(self.conv(x) * self.beta + x)

class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
)
self.lastconv = nn.Sequential(
nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
nn.PixelShuffle(2)
)

def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
x = torch.cat((x, flow), 1)
feat = self.conv0(x)
feat = self.convblock(feat)
tmp = self.lastconv(feat)
tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * scale  # first four channels: two 2-channel flows, rescaled back to input resolution
        mask = tmp[:, 4:5]  # fifth channel: blending mask logit
        return flow, mask


class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+16, c=192)
self.block1 = IFBlock(8+4+16, c=128)
self.block2 = IFBlock(8+4+16, c=96)
self.block3 = IFBlock(8+4+16, c=64)
self.encode = Head()

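    # Aligns img0 to img1 and returns the warped image plus the final flow.
    # The optional Gaussian blur only affects the copies used for flow
    # estimation; the final warp is applied to the unblurred original img0.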
def align_images(self, img0, img1, timestep, scale_list, blur_strength, ensemble, device):
#optional blur
if blur_strength is not None and blur_strength > 0:
blur = transforms.GaussianBlur(kernel_size=(5, 5), sigma=(blur_strength, blur_strength))
img0_blurred = blur(img0)
img1_blurred = blur(img1)
else:
img0_blurred = img0
img1_blurred = img1

f0 = self.encode(img0_blurred[:, :3])
f1 = self.encode(img1_blurred[:, :3])
flow_list = []
mask_list = []
flow = None
mask = None
block = [self.block0, self.block1, self.block2, self.block3]
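        # Coarse-to-fine refinement: block0 estimates an initial flow from the
        # images and encoder features; each later block warps the features with
        # the current flow and predicts a residual fd that is added to it. With
        # ensemble=True each block also runs on the swapped pair and the two
        # estimates are averaged.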
for i in range(4):
if flow is None:
flow, mask = block[i](torch.cat((img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i])
if ensemble:
f_, m_ = block[i](torch.cat((img1_blurred[:, :3], img0_blurred[:, :3], f1, f0, 1-timestep), 1), None, scale=scale_list[i])
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (mask + (-m_)) / 2
else:
wf0 = warp(f0, flow[:, :2], device)
wf1 = warp(f1, flow[:, 2:4], device)
fd, m0 = block[i](torch.cat((img0_blurred[:, :3], img1_blurred[:, :3], wf0, wf1, timestep, mask), 1), flow, scale=scale_list[i])
if ensemble:
f_, m_ = block[i](torch.cat((img1_blurred[:, :3], img0_blurred[:, :3], wf1, wf0, 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0
flow = flow + fd
mask_list.append(mask)
flow_list.append(flow)

#apply warp to original image
aligned_img0 = warp(img0, flow_list[-1][:, :2], device)
aligned_img0 = aligned_img0.clamp(min=0.0, max=1.0)
return aligned_img0, flow_list[-1]

def forward(self, x, timestep=1, training=False, fastmode=True, ensemble=True, num_iterations=1, multiplier=0.5, blur_strength=0, device='cuda'):
        if not training:  # inference-only path; training mode is not used for alignment
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]

scale_list = [multiplier * 8, multiplier * 4, multiplier * 2, multiplier]
#logging.debug(f"Generated scale list: {scale_list}")

if not torch.is_tensor(timestep):
timestep = (x[:, :1].clone() * 0 + 1) * timestep
else:
timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])

for iteration in range(num_iterations):
aligned_img0, flow = self.align_images(img0, img1, timestep, scale_list, blur_strength, ensemble, device)
img0 = aligned_img0 # Use the aligned image as img0 for the next iteration

return aligned_img0, flow
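
A rough usage sketch (an assumption, not part of the PR): once the weights are loaded, the network takes a single 6-channel tensor holding both images, exactly as image_align_rife.py constructs it. On random data with untrained weights:

import torch
# import path mirrors image_align_rife.py and assumes backend/src is on sys.path
from packages.chaiNNer_pytorch.pytorch.processing.rife.IFNet_HDv3_v4_14_align import IFNet

model = IFNet().eval()  # real use loads the downloaded flownet.pkl state dict first
img0 = torch.rand(1, 3, 256, 256)  # image to align; dims already multiples of 32
img1 = torch.rand(1, 3, 256, 256)  # reference image
pair = torch.cat((img0, img1), dim=1)  # forward() splits this back into two images
with torch.no_grad():
    aligned, flow = model(pair, multiplier=1, num_iterations=1, device="cpu")
print(aligned.shape, flow.shape)  # (1, 3, 256, 256) and (1, 4, 256, 256)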