Skip to content

Commit

Permalink
Alignments Tool updates
Browse files Browse the repository at this point in the history
  - Copy info back to alignments file from faces
  • Loading branch information
torzdf committed Sep 26, 2022
1 parent 5805d76 commit c79175c
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 53 deletions.
2 changes: 1 addition & 1 deletion lib/align/alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def video_meta_data(self) -> Dict[str, Optional[Union[List[int], List[float]]]]:
pts_time: List[float] = []
keyframes: List[int] = []
for idx, key in enumerate(sorted(self.data)):
if not self.data[key]["video_meta"]:
if not self.data[key].get("video_meta", {}):
return retval
meta = self.data[key]["video_meta"]
pts_time.append(cast(float, meta["pts_time"]))
Expand Down
23 changes: 16 additions & 7 deletions tools/alignments/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tqdm import tqdm

from .media import Faces, Frames
from .jobs_faces import FaceToFile

if sys.version_info < (3, 8):
from typing_extensions import Literal
Expand All @@ -21,7 +22,7 @@

if TYPE_CHECKING:
from argparse import Namespace
from lib.align.alignments import PNGHeaderSourceDict
from lib.align.alignments import PNGHeaderDict
from .media import AlignmentData

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,7 +81,7 @@ def _get_source_dir(self, arguments: "Namespace") -> str:
logger.debug("type: '%s', source_dir: '%s'", self._type, source_dir)
return source_dir

def _get_items(self) -> Union[List[Dict[str, str]], List[Dict[str, "PNGHeaderSourceDict"]]]:
def _get_items(self) -> Union[List[Dict[str, str]], List[Tuple[str, "PNGHeaderDict"]]]:
""" Set the correct items to process
Returns
Expand All @@ -93,14 +94,21 @@ def _get_items(self) -> Union[List[Dict[str, str]], List[Dict[str, "PNGHeaderSou
assert self._type is not None
items: Union[Frames, Faces] = globals()[self._type.title()](self._source_dir)
self._is_video = items.is_video
return cast(Union[List[Dict[str, str]], List[Dict[str, "PNGHeaderSourceDict"]]],
return cast(Union[List[Dict[str, str]], List[Tuple[str, "PNGHeaderDict"]]],
items.file_list_sorted)

def process(self) -> None:
""" Process the frames check against the alignments file """
assert self._type is not None
logger.info("[CHECK %s]", self._type.upper())
items_output = self._compile_output()

if self._type == "faces":
filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._items)
check_update = FaceToFile(self._alignments, [val[1] for val in filelist])
if check_update():
self._alignments.save()

self._output_results(items_output)

def _validate(self) -> None:
Expand Down Expand Up @@ -185,12 +193,13 @@ def _get_multi_faces_faces(self) -> Generator[Tuple[str, int], None, None]:
The frame name and the face id of any frames which have multiple faces
"""
self.output_message = "Multiple faces in frame"
for item in tqdm(cast(List[Tuple[str, "PNGHeaderSourceDict"]], self._items),
for item in tqdm(cast(List[Tuple[str, "PNGHeaderDict"]], self._items),
desc=self.output_message,
leave=False):
if not self._alignments.frame_has_multiple_faces(item["source_filename"]):
src = item[1]["source"]
if not self._alignments.frame_has_multiple_faces(src["source_filename"]):
continue
retval = (item[0], item[1]["face_index"])
retval = (item[0], src["face_index"])
logger.trace("Returning: '%s'", retval) # type:ignore
yield retval

Expand Down Expand Up @@ -222,7 +231,7 @@ def _get_missing_frames(self) -> Generator[str, None, None]:
The frame name of any frames in alignments with no matching file
"""
self.output_message = "Missing frames that are in alignments file"
frames = set(item["frame_fullname"] for item in self._items)
frames = set(item["frame_fullname"] for item in cast(List[Dict[str, str]], self._items))
for frame in tqdm(self._alignments.data.keys(), desc=self.output_message, leave=False):
if frame not in frames:
logger.debug("Returning: '%s'", frame)
Expand Down
111 changes: 100 additions & 11 deletions tools/alignments/jobs_faces.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python3
""" Tools for manipulating the alignments using extracted Faces as a source """
import os
import logging
import os
import sys
from argparse import Namespace
from operator import itemgetter
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING

import numpy as np
from tqdm import tqdm
Expand All @@ -15,10 +16,15 @@

from .media import Faces

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal

if TYPE_CHECKING:
from .media import AlignmentData
from lib.align.alignments import (AlignmentDict, AlignmentFileDict,
PNGHeaderDict, PNGHeaderSourceDict)
PNGHeaderDict, PNGHeaderAlignmentsDict)

logger = logging.getLogger(__name__)

Expand All @@ -37,7 +43,7 @@ def __init__(self, alignments: None, arguments: Namespace) -> None:
logger.debug("Initializing %s: (alignments: %s, arguments: %s)",
self.__class__.__name__, alignments, arguments)
self._faces_dir = arguments.faces_dir
self._faces = Faces(arguments.faces_dir, with_alignments=True)
self._faces = Faces(arguments.faces_dir)
logger.debug("Initialized %s", self.__class__.__name__)

def process(self) -> None:
Expand Down Expand Up @@ -240,7 +246,7 @@ def __init__(self,
self.__class__.__name__, arguments, faces)
self._alignments = alignments

kwargs: Dict[str, Union[bool, "AlignmentData"]] = dict(with_alignments=False)
kwargs = {}
if alignments.version < 2.1:
# Update headers of faces generated with hash based alignments
kwargs["alignments"] = alignments
Expand All @@ -254,14 +260,19 @@ def __init__(self,
def process(self) -> None:
""" Process the face renaming """
logger.info("[RENAME FACES]") # Tidy up cli output
filelist = cast(List[Tuple[str, "PNGHeaderSourceDict"]], self._faces.file_list_sorted)
rename_mappings = sorted([(face[0], face[1]["original_filename"])
filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted)
rename_mappings = sorted([(face[0], face[1]["source"]["original_filename"])
for face in filelist
if face[0] != face[1]["original_filename"]],
if face[0] != face[1]["source"]["original_filename"]],
key=lambda x: x[1])
rename_count = self._rename_faces(rename_mappings)
logger.info("%s faces renamed", rename_count)

filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted)
copyback = FaceToFile(self._alignments, [val[1] for val in filelist])
if copyback():
self._alignments.save()

def _rename_faces(self, filename_mappings: List[Tuple[str, str]]) -> int:
""" Rename faces back to their original name as exists in the alignments file.
Expand Down Expand Up @@ -325,7 +336,7 @@ def __init__(self, alignments: "AlignmentData", arguments: Namespace) -> None:
logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
self._alignments = alignments

kwargs: Dict[str, Union[bool, "AlignmentData"]] = dict(with_alignments=False)
kwargs = {}
if alignments.version < 2.1:
# Update headers of faces generated with hash based alignments
kwargs["alignments"] = alignments
Expand Down Expand Up @@ -367,10 +378,11 @@ def _update_png_headers(self) -> None:
to like this and has a tendency to throw permission errors, so this remains single threaded
for now.
"""
filelist = cast(List[Tuple[str, "PNGHeaderSourceDict"]], self._items.file_list_sorted)
items = cast(Dict[str, List[int]], self._items.items)
srcs = [(x[0], x[1]["source"])
for x in cast(List[Tuple[str, "PNGHeaderDict"]], self._items.file_list_sorted)]
to_update = [ # Items whose face index has changed
x for x in filelist
x for x in srcs
if x[1]["face_index"] != items[x[1]["source_filename"]].index(x[1]["face_index"])]

for item in tqdm(to_update, desc="Updating PNG Headers", leave=False):
Expand Down Expand Up @@ -400,3 +412,80 @@ def _update_png_headers(self) -> None:
update_existing_metadata(fullpath, meta)

logger.info("%s Extracted face(s) had their header information updated", len(to_update))


class FaceToFile(): # pylint:disable=too-few-public-methods
""" Updates any optional/missing keys in the alignments file with any data that has been
populated in a PNGHeader. Includes masks and identity fields.
Parameters
---------
alignments: :class:`tools.alignments.media.AlignmentsData`
The loaded alignments containing faces to be removed
face_data: list
List of :class:`PNGHeaderDict` objects
"""
def __init__(self, alignments: "AlignmentData", face_data: List["PNGHeaderDict"]) -> None:
logger.debug("Initializing %s: alignments: %s, face_data: %s",
self.__class__.__name__, alignments, len(face_data))
self._alignments = alignments
self._face_alignments = face_data
self._updatable_keys: List[Literal["identity", "mask"]] = ["identity", "mask"]
self._counts: Dict[str, int] = {}
logger.debug("Initialized %s", self.__class__.__name__)

def _check_and_update(self,
alignment: "PNGHeaderAlignmentsDict",
face: "AlignmentFileDict") -> None:
""" Check whether the key requires updating and update it.
alignment: dict
The alignment dictionary from the PNG Header
face: dict
The alignment dictionary for the face from the alignments file
"""
for key in self._updatable_keys:
if key == "mask":
exist_masks = face["mask"]
for mask_name, mask_data in alignment["mask"].items():
if mask_name in exist_masks:
continue
exist_masks[mask_name] = mask_data
count_key = f"mask_{mask_name}"
self._counts[count_key] = self._counts.get(count_key, 0) + 1
continue

if not face.get(key, {}) and alignment.get(key):
face[key] = alignment[key]
self._counts[key] = self._counts.get(key, 0) + 1

def __call__(self) -> bool:
""" Parse through the face data updating any entries in the alignments file.
Returns
-------
bool
``True`` if any alignment information was updated otherwise ``False``
"""
for meta in tqdm(self._face_alignments,
desc="Updating Alignments File from PNG Header",
leave=False):
src = meta["source"]
alignment = meta["alignments"]
if not any(alignment.get(key, {}) for key in self._updatable_keys):
continue

faces = self._alignments.get_faces_in_frame(src["source_filename"])
if len(faces) < src["face_index"] + 1: # list index out of range
logger.debug("Skipped face '%s'. Index does not exist in alignments file",
src["original_filename"])
continue

face = faces[src["face_index"]]
self._check_and_update(alignment, face)

retval = False
if self._counts:
retval = True
logger.info("Updated alignments file from PNG Data: %s", self._counts)
return retval
5 changes: 3 additions & 2 deletions tools/alignments/jobs_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,10 @@ def _get_count(self) -> Optional[int]:
meta = self._alignments.video_meta_data
has_meta = all(val is not None for val in meta.values())
if has_meta:
retval = None
retval: Optional[int] = len(cast(Dict[str, Union[List[int], List[float]]],
meta["pts_time"]))
else:
retval = len(cast(Dict[str, Union[List[int], List[float]]], meta["pts_time"]))
retval = None
logger.debug("Frame count from alignments file: (has_meta: %s, %s", has_meta, retval)
return retval

Expand Down
43 changes: 11 additions & 32 deletions tools/alignments/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

if TYPE_CHECKING:
import numpy as np
from lib.align.alignments import AlignmentFileDict, PNGHeaderDict, PNGHeaderSourceDict
from lib.align.alignments import AlignmentFileDict, PNGHeaderDict

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -151,14 +151,12 @@ def valid_extension(filename) -> bool:
return retval

def sorted_items(self) -> Union[List[Dict[str, str]],
List[Tuple[str, "PNGHeaderSourceDict"]],
List[Tuple[str, "PNGHeaderDict"]]]:
""" Override for specific folder processing """
raise NotImplementedError()

def process_folder(self) -> Union[Generator[Dict[str, str], None, None],
Generator[Tuple[str, "PNGHeaderDict"], None, None],
Generator[Tuple[str, "PNGHeaderSourceDict"], None, None]]:
Generator[Tuple[str, "PNGHeaderDict"], None, None]]:
""" Override for specific folder processing """
raise NotImplementedError()

Expand Down Expand Up @@ -265,21 +263,12 @@ class Faces(MediaLoader):
The alignments object that contains the faces. Used to update legacy hash based faces
for <v2.1 alignments to png header based version. Pass in ``None`` to not update legacy
faces (raises error instead). Default: ``None``
with_alignments: bool, optional
By default, only the source information stored in the PNG header will be returned in
:attr:`file_list_sorted`. Set to ``True`` to include alignment information as well.
Default:``False``
"""
def __init__(self,
folder: str,
alignments: Optional[Alignments] = None,
with_alignments: bool = False) -> None:
def __init__(self, folder: str, alignments: Optional[Alignments] = None) -> None:
self._alignments = alignments
self._with_alignments = with_alignments
super().__init__(folder)

def process_folder(self) -> Union[Generator[Tuple[str, "PNGHeaderDict"], None, None],
Generator[Tuple[str, "PNGHeaderSourceDict"], None, None]]:
def process_folder(self) -> Generator[Tuple[str, "PNGHeaderDict"], None, None]:
""" Iterate through the faces folder pulling out various information for each face.
Yields
Expand Down Expand Up @@ -321,13 +310,11 @@ def process_folder(self) -> Union[Generator[Tuple[str, "PNGHeaderDict"], None, N
f"Some of the faces being passed in from '{self.folder}' could not be "
f"matched to the alignments file '{self._alignments.file}'\nPlease double "
"check your sources and try again.")
sub_dict = data if self._with_alignments else data["source"]
sub_dict = data
else:
sub_dict = (metadata["itxt"] if self._with_alignments
else metadata["itxt"]["source"])
sub_dict = cast("PNGHeaderDict", metadata["itxt"])

retval: Union[Tuple[str, "PNGHeaderDict"], Tuple[str, "PNGHeaderSourceDict"]]
retval = (os.path.basename(fullpath), sub_dict) # type:ignore
retval = (os.path.basename(fullpath), sub_dict)
yield retval

def load_items(self) -> Dict[str, List[int]]:
Expand All @@ -339,29 +326,21 @@ def load_items(self) -> Dict[str, List[int]]:
The source filename as key with list of face indices for the frame as value
"""
faces: Dict[str, List[int]] = {}
for face in cast(Union[List[Tuple[str, "PNGHeaderDict"]],
List[Tuple[str, "PNGHeaderSourceDict"]]],
self.file_list_sorted):
src: "PNGHeaderSourceDict" = cast(
"PNGHeaderDict",
face[1])["source"] if self._with_alignments else cast("PNGHeaderSourceDict",
face[1])
for face in cast(List[Tuple[str, "PNGHeaderDict"]], self.file_list_sorted):
src = face[1]["source"]
faces.setdefault(src["source_filename"], []).append(src["face_index"])
logger.trace(faces) # type: ignore
return faces

def sorted_items(self) -> Union[List[Tuple[str, "PNGHeaderDict"]],
List[Tuple[str, "PNGHeaderSourceDict"]]]:
def sorted_items(self) -> List[Tuple[str, "PNGHeaderDict"]]:
""" Return the items sorted by the saved file name.
Returns
--------
list
List of `dict` objects for each face found, sorted by the face's current filename
"""
items = cast(Union[List[Tuple[str, "PNGHeaderDict"]],
List[Tuple[str, "PNGHeaderSourceDict"]]],
sorted(self.process_folder(), key=itemgetter(0)))
items = sorted(self.process_folder(), key=itemgetter(0))
logger.trace(items) # type: ignore
return items

Expand Down

0 comments on commit c79175c

Please sign in to comment.