Skip to content

Commit

Permalink
Convert - Patch writer (#1352)
Browse files Browse the repository at this point in the history
* Add cli-args and config opts for patch writer
* Add supporting code to convert funcs
* writers
  - move frame count functions to _base
  - Add kwargs for pre_encode
* scripts/lib convert - add code for patch writer
* lib.image.encode_image - Add cv2.imencode args
* patch_defaults: Add face index location option
* Add patch writer with PNG support
* Send correct matrix to patch plugin
* Add canvas origin option
* Add Tiff format to Patch Writer
* Docs and locales
* Add ROI to output
* convert: choose warp border by face count
  • Loading branch information
torzdf committed Sep 28, 2023
1 parent a660eda commit e80ae04
Show file tree
Hide file tree
Showing 19 changed files with 1,060 additions and 393 deletions.
8 changes: 8 additions & 0 deletions docs/full/plugins/convert.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ writer.opencv module
:undoc-members:
:show-inheritance:

writer.patch module
--------------------

.. automodule:: plugins.convert.writer.patch
:members:
:undoc-members:
:show-inheritance:

writer.pillow module
--------------------

Expand Down
4 changes: 4 additions & 0 deletions lib/cli/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,10 @@ def get_optional_arguments() -> list[dict[str, T.Any]]:
"\nL|gif: [animated image] Create an animated gif."
"\nL|opencv: [images] The fastest image writer, but less options and formats "
"than other plugins."
"\nL|patch: [images] Outputs the raw swapped face patch, along with the "
"transformation matrix required to re-insert the face back into the original "
"frame. Use this option if you wish to post-process and composite the final "
"face within external tools."
"\nL|pillow: [images] Slower than opencv, but has more options and supports "
"more formats.")))
argument_list.append(dict(
Expand Down
158 changes: 114 additions & 44 deletions lib/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self,
coverage_ratio: float,
centering: CenteringType,
draw_transparent: bool,
pre_encode: Callable[[np.ndarray], list[bytes]] | None,
pre_encode: Callable | None,
arguments: Namespace,
configfile: str | None = None) -> None:
logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, "
Expand All @@ -94,8 +94,9 @@ def __init__(self,
self._configfile = configfile

self._scale = arguments.output_scale / 100
self._face_scale = 1.0 - self._args.face_scale / 100.
self._face_scale = 1.0 - arguments.face_scale / 100.
self._adjustments = Adjustments()
self._full_frame_output: bool = arguments.writer != "patch"

self._load_plugins()
logger.debug("Initialized %s", self.__class__.__name__)
Expand Down Expand Up @@ -183,7 +184,7 @@ def process(self, in_queue: EventQueue, out_queue: EventQueue):
"""
logger.debug("Starting convert process. (in_queue: %s, out_queue: %s)",
in_queue, out_queue)
log_once = False
logged = False
while True:
inbound: T.Literal["EOF"] | ConvertItem | list[ConvertItem] = in_queue.get()
if inbound == "EOF":
Expand All @@ -196,7 +197,8 @@ def process(self, in_queue: EventQueue, out_queue: EventQueue):

items = inbound if isinstance(inbound, list) else [inbound]
for item in items:
logger.trace("Patch queue got: '%s'", item.inbound.filename) # type: ignore
logger.trace("Patch queue got: '%s'", # type: ignore[attr-defined]
item.inbound.filename)
try:
image = self._patch_image(item)
except Exception as err: # pylint: disable=broad-except
Expand All @@ -205,16 +207,42 @@ def process(self, in_queue: EventQueue, out_queue: EventQueue):
item.inbound.filename, str(err))
image = item.inbound.image

loglevel = logger.trace if log_once else logger.warning # type: ignore
loglevel("Convert error traceback:", exc_info=True)
log_once = True
lvl = logger.trace if logged else logger.warning # type: ignore[attr-defined]
lvl("Convert error traceback:", exc_info=True)
logged = True
# UNCOMMENT THIS CODE BLOCK TO PRINT TRACEBACK ERRORS
# import sys; import traceback
# exc_info = sys.exc_info(); traceback.print_exception(*exc_info)
logger.trace("Out queue put: %s", item.inbound.filename) # type: ignore
logger.trace("Out queue put: %s", # type: ignore[attr-defined]
item.inbound.filename)
out_queue.put((item.inbound.filename, image))
logger.debug("Completed convert process")

def _get_warp_matrix(self, matrix: np.ndarray, size: int) -> np.ndarray:
""" Obtain the final scaled warp transformation matrix based on face scaling from the
original transformation matrix
Parameters
----------
matrix: :class:`numpy.ndarray`
The transformation for patching the swapped face back onto the output frame
size: int
The size of the face patch, in pixels
Returns
-------
:class:`numpy.ndarray`
The final transformation matrix with any scaling applied
"""
if self._face_scale == 1.0:
mat = matrix
else:
mat = matrix * self._face_scale
patch_center = (size / 2, size / 2)
mat[..., 2] += (1 - self._face_scale) * np.array(patch_center)

return mat

def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]:
""" Patch a swapped face onto a frame.
Expand All @@ -233,22 +261,67 @@ def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]:
function (if it has one)
"""
logger.trace("Patching image: '%s'", predicted.inbound.filename) # type: ignore
logger.trace("Patching image: '%s'", # type: ignore[attr-defined]
predicted.inbound.filename)
frame_size = (predicted.inbound.image.shape[1], predicted.inbound.image.shape[0])
new_image, background = self._get_new_image(predicted, frame_size)
patched_face = self._post_warp_adjustments(background, new_image)
patched_face = self._scale_image(patched_face)
patched_face *= 255.0
patched_face = np.rint(patched_face,
out=np.empty(patched_face.shape, dtype="uint8"),
casting='unsafe')

if self._full_frame_output:
patched_face = self._post_warp_adjustments(background, new_image)
patched_face = self._scale_image(patched_face)
patched_face *= 255.0
patched_face = np.rint(patched_face,
out=np.empty(patched_face.shape, dtype="uint8"),
casting='unsafe')
else:
patched_face = new_image

if self._writer_pre_encode is None:
retval: np.ndarray | list[bytes] = patched_face
else:
retval = self._writer_pre_encode(patched_face)
logger.trace("Patched image: '%s'", predicted.inbound.filename) # type: ignore
kwargs: dict[str, T.Any] = {}
if self.cli_arguments.writer == "patch":
kwargs["canvas_size"] = (background.shape[1], background.shape[0])
kwargs["matrices"] = np.array([self._get_warp_matrix(face.adjusted_matrix,
patched_face.shape[1])
for face in predicted.reference_faces],
dtype="float32")
retval = self._writer_pre_encode(patched_face, **kwargs)
logger.trace("Patched image: '%s'", # type: ignore[attr-defined]
predicted.inbound.filename)
return retval

def _warp_to_frame(self,
                   reference: AlignedFace,
                   face: np.ndarray,
                   frame: np.ndarray,
                   multiple_faces: bool) -> None:
    """ Affine transform the swapped face patch onto the given frame.

    The warp is written directly into the `frame` array, so this function does not
    return a value.

    Parameters
    ----------
    reference: :class:`lib.align.AlignedFace`
        The object holding the original aligned face
    face: :class:`numpy.ndarray`
        The swapped face patch
    frame: :class:`numpy.ndarray`
        The frame to affine the face onto
    multiple_faces: bool
        Controls the border mode to use. Uses BORDER_CONSTANT if there is only 1 face in
        the image, otherwise uses the inferior BORDER_TRANSPARENT
    """
    # Mask rides along in the patch's alpha channel, so a single warp handles both
    matrix = self._get_warp_matrix(reference.adjusted_matrix, face.shape[0])
    if multiple_faces:
        border_mode = cv2.BORDER_TRANSPARENT
    else:
        border_mode = cv2.BORDER_CONSTANT
    cv2.warpAffine(face,
                   matrix,
                   (frame.shape[1], frame.shape[0]),
                   frame,
                   flags=cv2.WARP_INVERSE_MAP | reference.interpolators[1],
                   borderMode=border_mode)

def _get_new_image(self,
predicted: ConvertItem,
frame_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -271,41 +344,38 @@ def _get_new_image(self,
background: :class: `numpy.ndarray`
The original frame
"""
logger.trace("Getting: (filename: '%s', faces: %s)", # type: ignore
logger.trace("Getting: (filename: '%s', faces: %s)", # type: ignore[attr-defined]
predicted.inbound.filename, len(predicted.swapped_faces))

placeholder = np.zeros((frame_size[1], frame_size[0], 4), dtype="float32")
background = predicted.inbound.image / np.array(255.0, dtype="float32")
placeholder[:, :, :3] = background
if self._full_frame_output:
background = predicted.inbound.image / np.array(255.0, dtype="float32")
placeholder[:, :, :3] = background
else:
faces = [] # Collect the faces into final array
background = placeholder # Used for obtaining original frame dimensions

for new_face, detected_face, reference_face in zip(predicted.swapped_faces,
predicted.inbound.detected_faces,
predicted.reference_faces):
predicted_mask = new_face[:, :, -1] if new_face.shape[2] == 4 else None
new_face = new_face[:, :, :3]
interpolator = reference_face.interpolators[1]

new_face = self._pre_warp_adjustments(new_face,
detected_face,
reference_face,
predicted_mask)

# Warp face with the mask
if self._face_scale == 1.0:
mat = reference_face.adjusted_matrix
if self._full_frame_output:
self._warp_to_frame(reference_face,
new_face, placeholder,
len(predicted.swapped_faces) > 1)
else:
mat = reference_face.adjusted_matrix * self._face_scale
patch_center = (new_face.shape[1] / 2, new_face.shape[0] / 2)
mat[..., 2] += (1 - self._face_scale) * np.array(patch_center)

cv2.warpAffine(new_face,
mat,
frame_size,
placeholder,
flags=cv2.WARP_INVERSE_MAP | interpolator,
borderMode=cv2.BORDER_TRANSPARENT)

logger.trace("Got filename: '%s'. (placeholders: %s)", # type: ignore
faces.append(new_face)

if not self._full_frame_output:
placeholder = np.array(faces, dtype="float32")

logger.trace("Got filename: '%s'. (placeholders: %s)", # type: ignore[attr-defined]
predicted.inbound.filename, placeholder.shape)

return placeholder, background
Expand Down Expand Up @@ -339,7 +409,7 @@ def _pre_warp_adjustments(self,
The face output from the Faceswap Model with any requested pre-warp adjustments
performed.
"""
logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore
logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore[attr-defined]
new_face.shape, predicted_mask.shape if predicted_mask is not None else None)
old_face = T.cast(np.ndarray, reference_face.face)[..., :3] / 255.0
new_face, raw_mask = self._get_image_mask(new_face,
Expand All @@ -350,7 +420,7 @@ def _pre_warp_adjustments(self,
new_face = self._adjustments.color.run(old_face, new_face, raw_mask)
if self._adjustments.seamless is not None:
new_face = self._adjustments.seamless.run(old_face, new_face, raw_mask)
logger.trace("returning: new_face shape %s", new_face.shape) # type: ignore
logger.trace("returning: new_face shape %s", new_face.shape) # type: ignore[attr-defined]
return new_face

def _get_image_mask(self,
Expand Down Expand Up @@ -381,7 +451,7 @@ def _get_image_mask(self,
:class:`numpy.ndarray`
The raw mask with no erosion or blurring applied
"""
logger.trace("Getting mask. Image shape: %s", new_face.shape) # type: ignore
logger.trace("Getting mask. Image shape: %s", new_face.shape) # type: ignore[attr-defined]
if self._args.mask_type not in ("none", "predicted"):
mask_centering = detected_face.mask[self._args.mask_type].stored_centering
else:
Expand All @@ -392,9 +462,9 @@ def _get_image_mask(self,
reference_face.pose.offset[self._centering],
self._centering,
predicted_mask=predicted_mask)
logger.trace("Adding mask to alpha channel") # type: ignore
logger.trace("Adding mask to alpha channel") # type: ignore[attr-defined]
new_face = np.concatenate((new_face, mask), -1)
logger.trace("Got mask. Image shape: %s", new_face.shape) # type: ignore
logger.trace("Got mask. Image shape: %s", new_face.shape) # type: ignore[attr-defined]
return new_face, raw_mask

def _post_warp_adjustments(self, background: np.ndarray, new_image: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -447,11 +517,11 @@ def _scale_image(self, frame: np.ndarray) -> np.ndarray:
"""
if self._scale == 1:
return frame
logger.trace("source frame: %s", frame.shape) # type: ignore
logger.trace("source frame: %s", frame.shape) # type: ignore[attr-defined]
interp = cv2.INTER_CUBIC if self._scale > 1 else cv2.INTER_AREA
dims = (round((frame.shape[1] / 2 * self._scale) * 2),
round((frame.shape[0] / 2 * self._scale) * 2))
frame = cv2.resize(frame, dims, interpolation=interp)
logger.trace("resized frame: %s", frame.shape) # type: ignore
logger.trace("resized frame: %s", frame.shape) # type: ignore[attr-defined]
np.clip(frame, 0.0, 1.0, out=frame)
return frame

0 comments on commit e80ae04

Please sign in to comment.