Skip to content

Commit

Permalink
refactor(export): guard clauses, pathlib
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Jul 15, 2023
1 parent 316ded8 commit e59fc51
Showing 1 changed file with 102 additions and 108 deletions.
210 changes: 102 additions & 108 deletions stimupy/utils/export.py
Expand Up @@ -2,6 +2,7 @@
import json
import pickle
from hashlib import md5
from pathlib import Path

import numpy as np
from PIL import Image
Expand All @@ -21,126 +22,125 @@


def array_to_checksum(arr):
"""Hash (md5) values, and save only the hex
"""Hash (md5) array values, and return hex-checksum
Parameters
----------
arr : np.ndarray
Array to be hashed.
arr : numpy.ndarray
array to be hashed.
Returns
----------
hex
str
hex-string representation of hash (MD5) of given array
"""
return md5(np.ascontiguousarray(arr.round(8))).hexdigest()


def array_to_image(arr, filename, norm=True):
def array_to_image(arr, filename, format=None, norm=True):
"""Save a 2D numpy array as a grayscale image file.
Parameters
----------
arr : np.ndarray
Array to be saved. Values will be cropped to [0,255].
filename : str
full path to the file to be created.
arr : numpy.ndarray
array to be exported. Values will be cropped to [0,255].
filename : Path or str
(full) path to the file to be created.
norm : bool
if True (default), multiply array by 255
multiply array by 255, by default True
"""
if filename[-4:] != ".png" and filename[-4:] != ".jpg":
filename += ".png"
try:
arr_to_write = np.ascontiguousarray(arr)
except TypeError as e:
raise e from ValueError("arr should be a numpy.ndarray(-like)")

if isinstance(arr, (np.ndarray, list)):
arr = np.array(arr)
if norm:
arr_to_write = arr_to_write * 255

if norm:
arr = arr * 255

if Image:
imsize = arr.shape
im = Image.new("L", (imsize[1], imsize[0]))
im.putdata(arr.flatten())
im.save(filename)
else:
raise ValueError("arr should be a np.ndarray")
im = Image.fromarray(arr_to_write.astype("uint8"), mode="L")
im.save(filename)


def array_to_npy(arr, filename):
"""Save a numpy array to npy-file.
Parameters
----------
arr : np.ndarray
Array to be saved.
filename : str
full path to the file to be creaated.
arr : numpy.ndarray
array to be exported.
filename : Path or str
(full) path to the file to be created.
"""
if isinstance(arr, (np.ndarray, list)):
np.save(filename, arr)
else:
raise ValueError("arr should be a np.ndarray")
try:
arr_to_write = np.ascontiguousarray(arr)
except TypeError as e:
raise e from ValueError("arr should be a numpy.ndarray(-like)")

filepath = Path(filename).resolve().with_suffix(".npy")

np.save(filepath, arr_to_write)


def array_to_mat(arr, filename):
"""Save a numpy array to a mat-file.
Parameters
----------
arr : np.ndarray
Array to be saved.
filename : str
full path to the file to be creaated.
arr : numpy.ndarray
array to be exported.
filename : Path or str
(full) path to the file to be created.
"""
if filename[-4:] != ".mat":
filename += ".mat"
try:
arr_to_write = np.ascontiguousarray(arr)
except TypeError as e:
raise e from ValueError("arr should be a numpy.ndarray(-like)")

filepath = Path(filename).resolve().with_suffix(".mat")

if isinstance(arr, (np.ndarray, list)):
savemat(filename, {"arr": arr})
else:
raise ValueError("arr should be a np.ndarray")
savemat(filepath, {"arr": arr_to_write})


def array_to_pickle(arr, filename):
"""Save a numpy array to a pickle-file.
Parameters
----------
arr : np.ndarray
Array to be saved.
filename : str
full path to the file to be creaated.
arr : numpy.ndarray
array to be exported.
filename : Path or str
(full) path to the file to be created.
"""
if filename[-7:] != ".pickle":
filename += ".pickle"
try:
arr_to_write = np.ascontiguousarray(arr)
except TypeError as e:
raise e from ValueError("arr should be a numpy.ndarray(-like)")

if isinstance(arr, (np.ndarray, list)):
with open(filename, "wb") as handle:
pickle.dump({"arr": arr}, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
raise ValueError("arr should be a np.ndarray")
filepath = Path(filename).resolve().with_suffix(".pickle")

with filepath.open("wb") as file:
pickle.dump({"arr": arr_to_write}, file, protocol=pickle.HIGHEST_PROTOCOL)


def array_to_json(arr, filename):
"""Save a numpy array to a (pretty) JSON.
Parameters
----------
arr : np.ndarray
Array to be saved.
filename : str
full path to the file to be creaated.
arr : numpy.ndarray
array to be exported.
filename : Path or str
(full) path to the file to be created.
"""
if filename[-5:] != ".json":
filename += ".json"
try:
arr_to_write = np.ascontiguousarray(arr)
except TypeError as e:
raise e from ValueError("arr should be a numpy.ndarray(-like)")

filepath = Path(filename).resolve().with_suffix(".json")

if isinstance(arr, np.ndarray):
with open(filename, "w", encoding="utf-8") as f:
json.dump(arr.tolist(), f, ensure_ascii=False, indent=4)
elif isinstance(arr, list):
with open(filename, "w", encoding="utf-8") as f:
json.dump(arr, f, ensure_ascii=False, indent=4)
else:
raise ValueError("arr should be a np.ndarray")
with filepath.open("w", encoding="utf-8") as file:
json.dump(arr_to_write.tolist(), file, ensure_ascii=False, indent=4)


def arrays_to_checksum(stim, keys=["img", "mask"]):
Expand All @@ -149,7 +149,7 @@ def arrays_to_checksum(stim, keys=["img", "mask"]):
Parameters
----------
stim : dict
stimulus dictionary containing keys
stimulus dictionary to export.
keys : str of list of str
keys of dict for which the hashing should be performed
Expand All @@ -175,27 +175,24 @@ def to_json(stim, filename):
Parameters
----------
stim : dict
stimulus dictionary containing keys
filename : str
full path to the file to be creaated.
stimulus dictionary to export.
filename : Path or str
(full) path to the file to be created.
"""
if filename[-5:] != ".json":
filename += ".json"
if not isinstance(stim, dict):
raise ValueError("stim should be a dict")

# stimulus-dict(s) as (pretty) JSON
if isinstance(stim, dict):
stim2 = copy.deepcopy(stim)
filepath = Path(filename).resolve().with_suffix(".json")

for key in stim2.keys():
# np.ndarrays are not serializable; change to list
if isinstance(stim2[key], np.ndarray):
stim2[key] = stim2[key].tolist()
stim2 = copy.deepcopy(stim)
for key in stim2.keys():
# np.ndarrays are not serializable; change to list
if isinstance(stim2[key], np.ndarray):
stim2[key] = stim2[key].tolist()

with open(filename, "w", encoding="utf-8") as f:
json.dump(stim2, f, ensure_ascii=False, indent=4)
else:
raise ValueError("stim should be a dict")
with filepath.open("w", encoding="utf-8") as file:
json.dump(stim2, file, ensure_ascii=False, indent=4)


def to_mat(stim, filename):
Expand All @@ -204,43 +201,40 @@ def to_mat(stim, filename):
Parameters
----------
stim : dict
stimulus dictionary containing keys
filename : str
full path to the file to be creaated.
stimulus dictionary to export.
filename : Path or str
(full) path to the file to be created.
"""
if filename[-4:] != ".mat":
filename += ".mat"

if isinstance(stim, dict):
savemat(filename, stim)
else:
if not isinstance(stim, dict):
raise ValueError("stim should be a dict")

filepath = Path(filename).resolve().with_suffix(".mat")

savemat(filepath, stim)


def to_pickle(stim, filename):
"""Save stimulus-dict(s) as pickle-file
Parameters
----------
stim : dict
stimulus dictionary containing keys
filename : str
full path to the file to be creaated.
stimulus dictionary to export.
filename : Path or str
(full) path to the file to be created.
"""
if filename[-7:] != ".pickle":
filename += ".pickle"
if not isinstance(stim, dict):
raise ValueError("stim should be a dict")

if isinstance(stim, dict):
stim2 = copy.deepcopy(stim)
filepath = Path(filename).resolve().with_suffix(".pickle")

for key in stim2.keys():
# certain classes can cause problems for pickles; change to list
if key in ["visual_size", "ppd", "shape"]:
stim2[key] = list(stim2[key])
stim2 = copy.deepcopy(stim)
for key in stim2.keys():
# certain classes can cause problems for pickles; change to list
if key in ["visual_size", "ppd", "shape"]:
stim2[key] = list(stim2[key])

with open(filename, "wb") as handle:
pickle.dump(stim2, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
raise ValueError("stim should be a dict")
with filepath.open("wb") as file:
pickle.dump(stim2, file, protocol=pickle.HIGHEST_PROTOCOL)

0 comments on commit e59fc51

Please sign in to comment.