diff --git a/README.rst b/README.rst
index ef05fe7..32c14dd 100644
--- a/README.rst
+++ b/README.rst
@@ -102,10 +102,10 @@ API Usage
filenames = [os.path.join(src_path, f) for f in os.listdir(src_path)
if f.lower().endswith(FILE_EXTS)]
+ cm = ColorMatcher()
for i, fname in enumerate(filenames):
img_src = load_img_file(fname)
- obj = ColorMatcher(src=img_src, ref=img_ref, method='mkl')
- img_res = obj.main()
+ img_res = cm.transfer(src=img_src, ref=img_ref, method='mkl')
img_res = Normalizer(img_res).uint8_norm()
save_img_file(img_res, os.path.join(os.path.dirname(fname), str(i)+'.png'))
diff --git a/color_matcher/__init__.py b/color_matcher/__init__.py
index 5c01ec5..516b5c2 100644
--- a/color_matcher/__init__.py
+++ b/color_matcher/__init__.py
@@ -16,7 +16,7 @@
along with this program. If not, see .
"""
-__version__ = '0.4.1'
+__version__ = '0.5.0'
from .top_level import ColorMatcher
from .hist_matcher import HistogramMatcher
diff --git a/color_matcher/baseclass.py b/color_matcher/baseclass.py
index 0b86fdc..2273c22 100644
--- a/color_matcher/baseclass.py
+++ b/color_matcher/baseclass.py
@@ -17,6 +17,7 @@
"""
import numpy as np
+import warnings
class MatcherBaseclass(object):
@@ -25,6 +26,7 @@ def __init__(self, *args, **kwargs):
self._src = None
self._ref = None
+ self._funs = []
if len(args) == 2:
self._src = args[0]
@@ -34,8 +36,6 @@ def __init__(self, *args, **kwargs):
self._src = kwargs['src'] if 'src' in kwargs else self._src
self._ref = kwargs['ref'] if 'ref' in kwargs else self._ref
- self.validate_img_dims()
-
def validate_img_dims(self):
"""
This function validates the image dimensions. It throws an exception if the dimension are unequal to 2 or 3.
@@ -45,17 +45,61 @@ def validate_img_dims(self):
self._src = self._src[..., np.newaxis] if len(self._src.shape) == 2 else self._src
self._ref = self._ref[..., np.newaxis] if len(self._ref.shape) == 2 else self._ref
- if len(self._src.shape) != 3 or len(self._ref.shape) != 3:
- raise BaseException('Wrong image dimensions')
+ if len(self._src.shape) not in (2, 3) or len(self._ref.shape) not in (2, 3):
+ raise BaseException('Each image must have 2 or 3 dimensions')
return True
def validate_color_chs(self):
"""
- This function checks whether provided images consist of 3 color channels. An exception is thrown otherwise.
+ This function checks whether provided images consist of a valid number of color channels.
"""
- if self._src.shape[2] != 3 or self._ref.shape[2] != 3:
- raise BaseException('Each image must have 3 color channels')
+ if len(self._src.shape) == 3 or len(self._ref.shape) == 3:
+ if self._src.shape[2] > 4 or self._ref.shape[2] > 4:
+ raise BaseException('Each image cannot have more than 4 color channels')
+ elif self._src.shape[2] == 3 and self._ref.shape[2] == 4:
+ self._ref = self._ref[..., :3]
+ elif self._src.shape[2] == 4 and self._ref.shape[2] == 3:
+ self._src = self._src[..., :3]
+ elif self._src.shape[2] == 1 and self._ref.shape[2] == 3:
+ self._ref = self.rgb2gray(self._ref)
+ elif self._src.shape[2] == 3 and self._ref.shape[2] == 1:
+ self._src = self.rgb2gray(self._src)
+
+ if self._src.shape[2] == 1 and self._ref.shape[2] == 1:
+ # restrict monochromatic transfer to histogram matching
+ self._funs = [self.hist_match]
+ warnings.warn('Transfer restricted to histogram matching due to monochromatic input')
return True
+
+ @staticmethod
+ def rgb2gray(rgb: np.ndarray = None, standard: str = 'HDTV') -> np.ndarray:
+ """ Convert RGB color space to monochromatic color space
+
+ :param rgb: input array in red, green and blue (RGB) space
+ :type rgb: :class:`~numpy:numpy.ndarray`
+ :param standard: option that determines whether head- and footroom are excluded ('HDTV') or considered otherwise
+ :type standard: :class:`string`
+ :return: array in monochromatic space
+ :rtype: ~numpy:np.ndarray
+
+ """
+
+ # store shape
+ shape = rgb.shape
+
+ # reshape image to channel vectors
+ rgb = rgb.reshape(-1, 3).T
+
+ # choose standard
+ mat = np.array([0.2126, 0.7152, 0.0722]) if standard == 'HDTV' else np.array([0.299, 0.587, 0.114])
+
+ # convert to gray
+ arr = np.dot(mat, rgb)
+
+ # reshape to 2-D image
+ arr = arr.reshape(shape[:2] + (1,))
+
+ return arr
diff --git a/color_matcher/mvgd_matcher.py b/color_matcher/mvgd_matcher.py
index 444c321..0eb150b 100644
--- a/color_matcher/mvgd_matcher.py
+++ b/color_matcher/mvgd_matcher.py
@@ -36,15 +36,14 @@ def __init__(self, *args, **kwargs):
try:
self._fun_name = [kw for kw in list(self._fun_dict.keys()) if kwargs['method'].__contains__(kw)][0]
except (BaseException, IndexError):
- # use MKL as default
+ # default function
self._fun_name = 'mkl'
self._fun_call = self._fun_dict[self._fun_name] if self._fun_name in self._fun_dict else self.mkl_solver
# initialize variables
self.r, self.z, self.cov_r, self.cov_z, self.mu_r, self.mu_z, self.transfer_mat = [None]*7
- self._init_vars()
- def _init_vars(self):
+ def init_vars(self):
# reshape source and reference images
self.r, self.z = self._src.reshape([-1, self._src.shape[2]]).T, self._ref.reshape([-1, self._ref.shape[2]]).T
@@ -55,14 +54,17 @@ def _init_vars(self):
# compute color channel means
self.mu_r, self.mu_z = self.r.mean(axis=1)[..., np.newaxis], self.z.mean(axis=1)[..., np.newaxis]
- def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: FunctionType = None) -> np.ndarray:
+ # validate dimensionality
+ self.check_dims()
+
+ def multivar_transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: FunctionType = None) -> np.ndarray:
"""
Transfer function to map colors based on for Multi-Variate Gaussian Distributions (MVGDs).
:param src: Source image that requires transfer
:param ref: Palette image which serves as reference
- :param fun: optional argument to pass a transfer function to solve for covariance matrices
+ :param fun: Optional argument to pass a transfer function to solve for covariance matrices
:param res: Resulting image after the mapping
:type src: :class:`~numpy:numpy.ndarray`
@@ -82,7 +84,7 @@ def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: Function
self.validate_color_chs()
# re-initialize variables to account for change in src and ref when passed to self.transfer()
- self._init_vars()
+ self.init_vars()
# set solver function for transfer matrix
self._fun_call = fun if fun is FunctionType else self._fun_call
@@ -108,14 +110,17 @@ def mkl_solver(self):
"""
+ # validate dimensionality
+ self.check_dims()
+
eig_val_r, eig_vec_r = np.linalg.eig(self.cov_r)
eig_val_r[eig_val_r < 0] = 0
val_r = np.diag(np.sqrt(eig_val_r[::-1]))
vec_r = np.array(eig_vec_r[:, ::-1])
inv_r = np.diag(1. / (np.diag(val_r + np.spacing(1))))
- mat_c = np.dot(val_r, np.dot(vec_r.T, np.dot(self.cov_z, np.dot(vec_r, val_r))))
- [eig_val_c, eig_vec_c] = np.linalg.eig(mat_c)
+ mat_c = val_r @ vec_r.T @ self.cov_z @ vec_r @ val_r
+ eig_val_c, eig_vec_c = np.linalg.eig(mat_c)
eig_val_c[eig_val_c < 0] = 0
val_c = np.diag(np.sqrt(eig_val_c))
@@ -133,6 +138,11 @@ def analytical_solver(self) -> np.ndarray:
"""
+ # validate dimensionality
+ self.check_dims()
+ if self.r.shape[-1] != self.z.shape[-1]:
+ raise Exception('Analytical MVGD solution requires spatial dimensions of both images to be equal')
+
cov_r_inv = np.linalg.pinv(self.cov_r)
cov_z_inv = np.linalg.pinv(self.cov_z)
@@ -164,3 +174,13 @@ def w2_dist(mu_a: np.ndarray, mu_b: np.ndarray, cov_a: np.ndarray, cov_b: np.nda
vars_dist = np.trace(cov_a+cov_b - 2*(np.dot(np.abs(cov_b)**.5, np.dot(np.abs(cov_a), np.abs(cov_b)**.5))**.5))
return float(mean_dist + vars_dist)
+
+ def check_dims(self):
+ """
+ Catch error for wrong color channel number (e.g., gray scale image)
+
+ :return: None
+ """
+
+ if np.ndim(self.cov_r) == 0 or np.ndim(self.cov_z) == 0:
+ raise Exception('Wrong color channel dimensionality for %s method' % self._fun_name)
diff --git a/color_matcher/top_level.py b/color_matcher/top_level.py
index 290d0e3..5447c77 100644
--- a/color_matcher/top_level.py
+++ b/color_matcher/top_level.py
@@ -34,36 +34,65 @@ def __init__(self, *args, **kwargs):
super(ColorMatcher, self).__init__(*args, **kwargs)
self._method = kwargs['method'] if 'method' in kwargs else 'default'
+ self._funs = []
- def main(self, method: str = None) -> np.ndarray:
+ def main(self) -> np.ndarray:
"""
- The main function is the high-level entry point performing the mapping. Valid methods are:
+ The main function is the high-level entry point performing the mapping based on instantiation arguments.
+ :return: Resulting image after color mapping
+ :rtype: np.ndarray
+ """
+
+ self.transfer()
+
+ return self._src
+
+ def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, method: str = None) -> np.ndarray:
+ """
+
+ Transfer function to map colors based on provided transfer method.
+
+ :param src: Source image that requires transfer
+ :param ref: Palette image which serves as reference
:param method: ('default', 'hm', 'reinhard', 'mvgd', 'mkl', 'hm-mvgd-hm', 'hm-mkl-hm') determining color mapping
+
+ :type src: :class:`~numpy:numpy.ndarray`
+ :type ref: :class:`~numpy:numpy.ndarray`
:type method: :class:`str`
:return: Resulting image after color mapping
:rtype: np.ndarray
+
"""
+ # assign input arguments to variables (if provided)
self._method = self._method.lower() if method is None else method.lower()
+ self._src = src if src is not None else self._src
+ self._ref = ref if ref is not None else self._ref
# color transfer methods (to be iterated through)
if self._method == METHODS[0]:
- funs = [self.transfer]
+ self._funs = [self.multivar_transfer]
elif self._method == METHODS[1]:
- funs = [self.hist_match]
+ self._funs = [self.hist_match]
elif self._method == METHODS[2]:
- funs = [self.reinhard]
+ self._funs = [self.reinhard]
elif self._method in METHODS[3:5]:
- funs = [self.transfer]
+ self._funs = [self.multivar_transfer]
elif self._method in METHODS[5:]:
- funs = [self.hist_match, self.transfer, self.hist_match]
+ self._funs = [self.hist_match, self.multivar_transfer, self.hist_match]
else:
raise BaseException('Method type \'%s\' not recognized' % method)
+ # check if three color channels are provided
+ self.validate_img_dims()
+
+ # check provided color channels
+ self.validate_color_chs()
+
# proceed with the color match
- for fun in funs:
+ for fun in self._funs:
self._src = fun(self._src, self._ref)
return self._src
diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo
index e683da1..ec0f432 100644
--- a/docs/build/html/.buildinfo
+++ b/docs/build/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 36e65157a5b8d9f8afb88b0180e76b66
+config: d69c61c77e9336c13f9a1800c01ddd93
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js
index e9b7668..25eaf45 100644
--- a/docs/build/html/_static/documentation_options.js
+++ b/docs/build/html/_static/documentation_options.js
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
- VERSION: '0.4.1',
+ VERSION: '0.5.0',
LANGUAGE: 'None',
COLLAPSE_INDEX: false,
BUILDER: 'html',
diff --git a/docs/build/html/apidoc.html b/docs/build/html/apidoc.html
index ed8bdde..4d09d63 100644
--- a/docs/build/html/apidoc.html
+++ b/docs/build/html/apidoc.html
@@ -5,7 +5,7 @@
- API documentation — color-matcher 0.4.1 documentation
+ API documentation — color-matcher 0.5.0 documentation
@@ -30,7 +30,7 @@ Navigation
previous |
- color-matcher 0.4.1 documentation »
+ color-matcher 0.5.0 documentation »
API documentation
@@ -63,11 +63,29 @@ Class hierarchy