Skip to content

Commit

Permalink
fix batch face_enhancer
Browse files Browse the repository at this point in the history
  • Loading branch information
kex0 committed Jun 26, 2023
1 parent f54c54b commit 1ae7fd8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions roop/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def start() -> None:
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME)
frame_processor.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
# frame_processor.post_process()
frame_processor.post_process()
release_resources()
if is_image(roop.globals.target_path):
update_status('Processing to image succeed!')
Expand All @@ -239,7 +239,7 @@ def start() -> None:
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME)
frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
# frame_processor.post_process()
frame_processor.post_process()
release_resources()
# handles fps
if roop.globals.keep_fps:
Expand Down
9 changes: 6 additions & 3 deletions roop/processors/frame/face_enhancer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, List, Callable
import os
import cv2
import threading
import gfpgan
Expand Down Expand Up @@ -62,8 +63,9 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
return temp_frame


def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
def process_frames(source_path: str, temp_directory_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
for temp_frame_path in temp_frame_paths:
temp_frame_path = os.path.join(temp_directory_path, os.path.basename(temp_frame_path))
temp_frame = cv2.imread(temp_frame_path)
result = process_frame(None, temp_frame)
cv2.imwrite(temp_frame_path, result)
Expand All @@ -72,10 +74,11 @@ def process_frames(source_path: str, temp_frame_paths: List[str], update: Callab


def process_image(source_path: str, target_path: str, output_path: str) -> None:
target_frame = cv2.imread(target_path)
target_frame = cv2.imread(output_path)
result = process_frame(None, target_frame)
cv2.imwrite(output_path, result)


def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
roop.processors.frame.core.process_video(None, temp_frame_paths, process_frames)
temp_directory_path = os.path.join(os.path.dirname(temp_frame_paths[0]), os.path.splitext(os.path.basename(roop.globals.source_path))[0])
roop.processors.frame.core.process_video(None, temp_directory_path, temp_frame_paths, process_frames)

0 comments on commit 1ae7fd8

Please sign in to comment.