Skip to content

Commit

Permalink
Merge branch 'staging'
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Dec 18, 2022
2 parents 8c63b7a + 48c886b commit 6913e42
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 72 deletions.
24 changes: 14 additions & 10 deletions lib/gpu_stats/amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def active_devices(self) -> List[int]:
@property
def _plaid_ids(self) -> List[str]:
""" list: The device identification for each GPU device that PlaidML has discovered. """
return [device.id.decode("utf-8") for device in self._all_devices]
return [device.id.decode("utf-8", errors="replace") for device in self._all_devices]

@property
def _experimental_indices(self) -> List[int]:
Expand Down Expand Up @@ -186,7 +186,9 @@ def _get_supported_devices(self) -> List[plaidml._DeviceConfig]:

supported = [d for d in devices
if d.details
and json.loads(d.details.decode("utf-8")).get("type", "cpu").lower() == "gpu"]
and json.loads(
d.details.decode("utf-8",
errors="replace")).get("type", "cpu").lower() == "gpu"]

self._log("debug", f"Obtained supported devices: {supported}")
return supported
Expand All @@ -206,7 +208,9 @@ def _get_all_devices(self) -> List[plaidml._DeviceConfig]:

experi = [d for d in devices
if d.details
and json.loads(d.details.decode("utf-8")).get("type", "cpu").lower() == "gpu"]
and json.loads(
d.details.decode("utf-8",
errors="replace")).get("type", "cpu").lower() == "gpu"]

self._log("debug", f"Obtained experimental Devices: {experi}")

Expand Down Expand Up @@ -240,7 +244,7 @@ def _get_fallback_devices(self) -> List[plaidml._DeviceConfig]:
raise RuntimeError("No valid devices could be found for plaidML.")

self._log("warning", f"PlaidML could not find a GPU. Falling back to: "
f"{[d.id.decode('utf-8') for d in devices]}")
f"{[d.id.decode('utf-8', errors='replace') for d in devices]}")
return devices

def _get_device_details(self) -> List[dict]:
Expand All @@ -254,10 +258,10 @@ def _get_device_details(self) -> List[dict]:
details = []
for dev in self._all_devices:
if dev.details:
details.append(json.loads(dev.details.decode("utf-8")))
details.append(json.loads(dev.details.decode("utf-8", errors="replace")))
else:
details.append(dict(vendor=dev.id.decode("utf-8"),
name=dev.description.decode("utf-8"),
details.append(dict(vendor=dev.id.decode("utf-8", errors="replace"),
name=dev.description.decode("utf-8", errors="replace"),
globalMemSize=4 * 1024 * 1024 * 1024)) # 4GB dummy ram
self._log("debug", f"Obtained Device details: {details}")
return details
Expand All @@ -284,11 +288,11 @@ def _select_largest_gpu(self) -> None:
self._log("error", "Please run `plaidml-setup` to set up your GPU.")
sys.exit(1)

max_vram = max([self._all_vram[idx] for idx in indices])
max_vram = max(self._all_vram[idx] for idx in indices)
self._log("debug", f"Max VRAM: {max_vram}")

gpu_idx = min([idx for idx, vram in enumerate(self._all_vram)
if vram == max_vram and idx in indices])
gpu_idx = min(idx for idx, vram in enumerate(self._all_vram)
if vram == max_vram and idx in indices)
self._log("debug", f"GPU IDX: {gpu_idx}")

selected_gpu = self._plaid_ids[gpu_idx]
Expand Down
14 changes: 9 additions & 5 deletions lib/gui/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,11 @@ def _get_branches():
retcode = cmd.poll()
if retcode != 0:
logger.debug("Unable to list git branches. return code: %s, message: %s",
retcode, stdout.decode().strip().replace("\n", " - "))
retcode,
stdout.decode(locale.getpreferredencoding(),
errors="replace").strip().replace("\n", " - "))
return None
return stdout.decode(locale.getpreferredencoding())
return stdout.decode(locale.getpreferredencoding(), errors="replace")

@staticmethod
def _filter_branches(stdout):
Expand Down Expand Up @@ -321,7 +323,9 @@ def _switch_branch(branch):
retcode = cmd.poll()
if retcode != 0:
logger.error("Unable to switch branch. return code: %s, message: %s",
retcode, stdout.decode().strip().replace("\n", " - "))
retcode,
stdout.decode(locale.getdefaultlocale(),
errors="replace").strip().replace("\n", " - "))
return
logger.info("Succesfully switched to '%s'. You may want to check for updates to make sure "
"that you have the latest code.", branch)
Expand Down Expand Up @@ -402,7 +406,7 @@ def check_for_updates(encoding, check=False):
msg = ("Git is not installed or you are not running a cloned repo. "
"Unable to check for updates")
else:
chk = stdout.decode(encoding).splitlines()
chk = stdout.decode(encoding, errors="replace").splitlines()
for line in chk:
if line.lower().startswith("your branch is ahead"):
msg = "Your branch is ahead of the remote repo. Not updating"
Expand Down Expand Up @@ -434,7 +438,7 @@ def do_update(encoding):
bufsize=1,
cwd=_WORKING_DIR) as cmd:
while True:
output = cmd.stdout.readline().decode(encoding)
output = cmd.stdout.readline().decode(encoding, errors="replace")
if output == "" and cmd.poll() is not None:
break
if output:
Expand Down
9 changes: 5 additions & 4 deletions lib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,11 @@ def read_image_meta(filename):
elif field == b"iTXt":
keyword, value = infile.read(length).split(b"\0", 1)
if keyword == b"faceswap":
retval["itxt"] = literal_eval(value[4:].decode("utf-8"))
retval["itxt"] = literal_eval(value[4:].decode("utf-8", errors="replace"))
break
else:
logger.trace("Skipping iTXt chunk: '%s'", keyword.decode("latin-1", "ignore"))
logger.trace("Skipping iTXt chunk: '%s'", keyword.decode("latin-1",
errors="ignore"))
length = 0 # Reset marker for next chunk
infile.seek(length + 4, 1)
logger.trace("filename: %s, metadata: %s", filename, retval)
Expand Down Expand Up @@ -645,9 +646,9 @@ def png_read_meta(png):
pointer += 8
keyword, value = png[pointer:pointer + length].split(b"\0", 1)
if keyword == b"faceswap":
retval = literal_eval(value[4:].decode("utf-8"))
retval = literal_eval(value[4:].decode("utf-8", errors="ignore"))
break
logger.trace("Skipping iTXt chunk: '%s'", keyword.decode("latin-1", "ignore"))
logger.trace("Skipping iTXt chunk: '%s'", keyword.decode("latin-1", errors="ignore"))
pointer += length + 4
return retval

Expand Down
4 changes: 2 additions & 2 deletions lib/keypress.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def getch(self):
if (self.is_gui or not sys.stdout.isatty()) and os.name != "nt":
return None
if os.name == "nt":
return msvcrt.getch().decode("utf-8")
return msvcrt.getch().decode("utf-8", errors="replace")
return sys.stdin.read(1)

def getarrow(self):
Expand All @@ -83,7 +83,7 @@ def getarrow(self):
char = sys.stdin.read(3)[2]
vals = [65, 67, 66, 68]

return vals.index(ord(char.decode("utf-8")))
return vals.index(ord(char.decode("utf-8", errors="replace")))

def kbhit(self):
""" Returns True if keyboard character was hit, False otherwise. """
Expand Down
28 changes: 10 additions & 18 deletions lib/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,11 @@ def unmarshal(self, serialized_data):
logger.debug("returned data type: %s", type(retval))
return retval

@classmethod
def _marshal(cls, data):
def _marshal(self, data):
""" Override for serializer specific marshalling """
raise NotImplementedError()

@classmethod
def _unmarshal(cls, data):
def _unmarshal(self, data):
""" Override for serializer specific unmarshalling """
raise NotImplementedError()

Expand All @@ -188,13 +186,11 @@ def __init__(self):
super().__init__()
self._file_extension = "yml"

@classmethod
def _marshal(cls, data):
def _marshal(self, data):
return yaml.dump(data, default_flow_style=False).encode("utf-8")

@classmethod
def _unmarshal(cls, data):
return yaml.load(data.decode("utf-8"), Loader=yaml.FullLoader)
def _unmarshal(self, data):
return yaml.load(data.decode("utf-8", errors="replace"), Loader=yaml.FullLoader)


class _JSONSerializer(Serializer):
Expand All @@ -203,13 +199,11 @@ def __init__(self):
super().__init__()
self._file_extension = "json"

@classmethod
def _marshal(cls, data):
def _marshal(self, data):
return json.dumps(data, indent=2).encode("utf-8")

@classmethod
def _unmarshal(cls, data):
return json.loads(data.decode("utf-8"))
def _unmarshal(self, data):
return json.loads(data.decode("utf-8", errors="replace"))


class _PickleSerializer(Serializer):
Expand All @@ -218,12 +212,10 @@ def __init__(self):
super().__init__()
self._file_extension = "pickle"

@classmethod
def _marshal(cls, data):
def _marshal(self, data):
return pickle.dumps(data)

@classmethod
def _unmarshal(cls, data):
def _unmarshal(self, data):
return pickle.loads(data)


Expand Down
63 changes: 32 additions & 31 deletions lib/sysinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,54 +95,54 @@ def _fs_command(self):
@property
def _installed_pip(self):
""" str: The list of installed pip packages within Faceswap's scope. """
pip = Popen("{} -m pip freeze".format(sys.executable),
shell=True, stdout=PIPE)
installed = pip.communicate()[0].decode().splitlines()
with Popen(f"{sys.executable} -m pip freeze", shell=True, stdout=PIPE) as pip:
installed = pip.communicate()[0].decode(self._encoding, errors="replace").splitlines()
return "\n".join(installed)

@property
def _installed_conda(self):
""" str: The list of installed Conda packages within Faceswap's scope. """
if not self._is_conda:
return None
conda = Popen("conda list", shell=True, stdout=PIPE, stderr=PIPE)
stdout, stderr = conda.communicate()
with Popen("conda list", shell=True, stdout=PIPE, stderr=PIPE) as conda:
stdout, stderr = conda.communicate()
if stderr:
return "Could not get package list"
installed = stdout.decode().splitlines()
installed = stdout.decode(self._encoding, errors="replace").splitlines()
return "\n".join(installed)

@property
def _conda_version(self):
""" str: The installed version of Conda, or `N/A` if Conda is not installed. """
if not self._is_conda:
return "N/A"
conda = Popen("conda --version", shell=True, stdout=PIPE, stderr=PIPE)
stdout, stderr = conda.communicate()
with Popen("conda --version", shell=True, stdout=PIPE, stderr=PIPE) as conda:
stdout, stderr = conda.communicate()
if stderr:
return "Conda is used, but version not found"
version = stdout.decode().splitlines()
version = stdout.decode(self._encoding, errors="replace").splitlines()
return "\n".join(version)

@property
def _git_branch(self):
""" str: The git branch that is currently being used to execute Faceswap. """
git = Popen("git status", shell=True, stdout=PIPE, stderr=PIPE)
stdout, stderr = git.communicate()
with Popen("git status", shell=True, stdout=PIPE, stderr=PIPE) as git:
stdout, stderr = git.communicate()
if stderr:
return "Not Found"
branch = stdout.decode().splitlines()[0].replace("On branch ", "")
branch = stdout.decode(self._encoding,
errors="replace").splitlines()[0].replace("On branch ", "")
return branch

@property
def _git_commits(self):
""" str: The last 5 git commits for the currently running Faceswap. """
git = Popen("git log --pretty=oneline --abbrev-commit -n 5",
shell=True, stdout=PIPE, stderr=PIPE)
stdout, stderr = git.communicate()
with Popen("git log --pretty=oneline --abbrev-commit -n 5",
shell=True, stdout=PIPE, stderr=PIPE) as git:
stdout, stderr = git.communicate()
if stderr:
return "Not Found"
commits = stdout.decode().splitlines()
commits = stdout.decode(self._encoding, errors="replace").splitlines()
return ". ".join(commits)

@property
Expand Down Expand Up @@ -193,14 +193,14 @@ def full_info(self):
"gpu_cuda": self._cuda_version,
"gpu_cudnn": self._cudnn_version,
"gpu_driver": self._gpu["driver"],
"gpu_devices": ", ".join(["GPU_{}: {}".format(idx, device)
"gpu_devices": ", ".join([f"GPU_{idx}: {device}"
for idx, device in enumerate(self._gpu["devices"])]),
"gpu_vram": ", ".join(["GPU_{}: {}MB".format(idx, int(vram))
"gpu_vram": ", ".join([f"GPU_{idx}: {int(vram)}MB"
for idx, vram in enumerate(self._gpu["vram"])]),
"gpu_devices_active": ", ".join(["GPU_{}".format(idx)
"gpu_devices_active": ", ".join([f"GPU_{idx}"
for idx in self._gpu["devices_active"]])}
for key in sorted(sys_info.keys()):
retval += ("{0: <20} {1}\n".format(key + ":", sys_info[key]))
retval += (f"{key + ':':<20} {sys_info[key]}\n")
retval += "\n=============== Pip Packages ===============\n"
retval += self._installed_pip
if self._is_conda:
Expand All @@ -219,11 +219,11 @@ def _format_ram(self):
str
The total, available, used and free RAM displayed in Megabytes
"""
retval = list()
retval = []
for name in ("total", "available", "used", "free"):
value = getattr(self, "_ram_{}".format(name))
value = getattr(self, f"_ram_{name}")
value = int(value / (1024 * 1024))
retval.append("{}: {}MB".format(name.capitalize(), value))
retval.append(f"{name.capitalize()}: {value}MB")
return ", ".join(retval)


Expand All @@ -241,7 +241,8 @@ def get_sysinfo():
try:
retval = _SysInfo().full_info()
except Exception as err: # pylint: disable=broad-except
retval = "Exception occured trying to retrieve sysinfo: {}".format(err)
retval = f"Exception occured trying to retrieve sysinfo: {str(err)}"
raise
return retval


Expand Down Expand Up @@ -284,7 +285,7 @@ def _parse_configs(self, config_files):
for cfile in config_files:
fname = os.path.basename(cfile)
ext = os.path.splitext(cfile)[1]
formatted += "\n--------- {} ---------\n".format(fname)
formatted += f"\n--------- {fname} ---------\n"
if ext == ".ini":
formatted += self._parse_ini(cfile)
elif fname == ".faceswap":
Expand All @@ -305,14 +306,14 @@ def _parse_ini(self, config_file):
The current configuration in the config file formatted in a human readable format
"""
formatted = ""
with open(config_file, "r") as cfile:
with open(config_file, "r", encoding="utf-8", errors="replace") as cfile:
for line in cfile.readlines():
line = line.strip()
if line.startswith("#") or not line:
continue
item = line.split("=")
if len(item) == 1:
formatted += "\n{}\n".format(item[0].strip())
formatted += f"\n{item[0].strip()}\n"
else:
formatted += self._format_text(item[0], item[1])
return formatted
Expand All @@ -331,7 +332,7 @@ def _parse_json(self, config_file):
The current configuration in the config file formatted as a python dictionary
"""
formatted = ""
with open(config_file, "r") as cfile:
with open(config_file, "r", encoding="utf-8", errors="replace") as cfile:
conf_dict = json.load(cfile)
for key in sorted(conf_dict.keys()):
formatted += self._format_text(key, conf_dict[key])
Expand All @@ -353,7 +354,7 @@ def _format_text(key, value):
str
The formatted key value pair for display
"""
return "{0: <25} {1}\n".format(key.strip() + ":", value.strip())
return f"{key.strip() + ':':<25} {value.strip()}\n"


class _State(): # pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -395,12 +396,12 @@ def _get_state_file(self):
"""
if not self._is_training or self._model_dir is None or self._trainer is None:
return ""
fname = os.path.join(self._model_dir, "{}_state.json".format(self._trainer))
fname = os.path.join(self._model_dir, f"{self._trainer}_state.json")
if not os.path.isfile(fname):
return ""

retval = "\n\n=============== State File =================\n"
with open(fname, "r") as sfile:
with open(fname, "r", encoding="utf-8", errors="replace") as sfile:
retval += sfile.read()
return retval

Expand Down

0 comments on commit 6913e42

Please sign in to comment.