Skip to content

Commit

Permalink
feat(cm): major update for gray-scale and alpha-channel image handling
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Aug 25, 2021
1 parent b387c5f commit f4d4339
Show file tree
Hide file tree
Showing 17 changed files with 335 additions and 89 deletions.
4 changes: 2 additions & 2 deletions README.rst
Expand Up @@ -102,10 +102,10 @@ API Usage
filenames = [os.path.join(src_path, f) for f in os.listdir(src_path)
if f.lower().endswith(FILE_EXTS)]
cm = ColorMatcher()
for i, fname in enumerate(filenames):
img_src = load_img_file(fname)
obj = ColorMatcher(src=img_src, ref=img_ref, method='mkl')
img_res = obj.main()
img_res = cm.transfer(src=img_src, ref=img_ref, method='mkl')
img_res = Normalizer(img_res).uint8_norm()
save_img_file(img_res, os.path.join(os.path.dirname(fname), str(i)+'.png'))
Expand Down
2 changes: 1 addition & 1 deletion color_matcher/__init__.py
Expand Up @@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

__version__ = '0.4.1'
__version__ = '0.5.0'

from .top_level import ColorMatcher
from .hist_matcher import HistogramMatcher
Expand Down
58 changes: 51 additions & 7 deletions color_matcher/baseclass.py
Expand Up @@ -17,6 +17,7 @@
"""

import numpy as np
import warnings


class MatcherBaseclass(object):
Expand All @@ -25,6 +26,7 @@ def __init__(self, *args, **kwargs):

self._src = None
self._ref = None
self._funs = []

if len(args) == 2:
self._src = args[0]
Expand All @@ -34,8 +36,6 @@ def __init__(self, *args, **kwargs):
self._src = kwargs['src'] if 'src' in kwargs else self._src
self._ref = kwargs['ref'] if 'ref' in kwargs else self._ref

self.validate_img_dims()

def validate_img_dims(self):
"""
This function validates the image dimensions. It throws an exception if the dimension are unequal to 2 or 3.
Expand All @@ -45,17 +45,61 @@ def validate_img_dims(self):
self._src = self._src[..., np.newaxis] if len(self._src.shape) == 2 else self._src
self._ref = self._ref[..., np.newaxis] if len(self._ref.shape) == 2 else self._ref

if len(self._src.shape) != 3 or len(self._ref.shape) != 3:
raise BaseException('Wrong image dimensions')
if len(self._src.shape) not in (2, 3) or len(self._ref.shape) not in (2, 3):
raise BaseException('Each image must have 2 or 3 dimensions')

return True

def validate_color_chs(self):
"""
This function checks whether provided images consist of 3 color channels. An exception is thrown otherwise.
This function checks whether provided images consist of a valid number of color channels.
"""

if self._src.shape[2] != 3 or self._ref.shape[2] != 3:
raise BaseException('Each image must have 3 color channels')
if len(self._src.shape) == 3 or len(self._ref.shape) == 3:
if self._src.shape[2] > 4 or self._ref.shape[2] > 4:
raise BaseException('Each image cannot have more than 4 color channels')
elif self._src.shape[2] == 3 and self._ref.shape[2] == 4:
self._ref = self._ref[..., :3]
elif self._src.shape[2] == 4 and self._ref.shape[2] == 3:
self._src = self._src[..., :3]
elif self._src.shape[2] == 1 and self._ref.shape[2] == 3:
self._ref = self.rgb2gray(self._ref)
elif self._src.shape[2] == 3 and self._ref.shape[2] == 1:
self._src = self.rgb2gray(self._src)

if self._src.shape[2] == 1 and self._ref.shape[2] == 1:
# restrict monochromatic transfer to histogram matching
self._funs = [self.hist_match]
warnings.warn('Transfer restricted to histogram matching due to monochromatic input')

return True

@staticmethod
def rgb2gray(rgb: np.ndarray = None, standard: str = 'HDTV') -> np.ndarray:
""" Convert RGB color space to monochromatic color space
:param rgb: input array in red, green and blue (RGB) space
:type rgb: :class:`~numpy:numpy.ndarray`
:param standard: option that determines whether head- and footroom are excluded ('HDTV') or considered otherwise
:type standard: :class:`string`
:return: array in monochromatic space
:rtype: ~numpy:np.ndarray
"""

# store shape
shape = rgb.shape

# reshape image to channel vectors
rgb = rgb.reshape(-1, 3).T

# choose standard
mat = np.array([0.2126, 0.7152, 0.0722]) if standard == 'HDTV' else np.array([0.299, 0.587, 0.114])

# convert to gray
arr = np.dot(mat, rgb)

# reshape to 2-D image
arr = arr.reshape(shape[:2] + (1,))

return arr
36 changes: 28 additions & 8 deletions color_matcher/mvgd_matcher.py
Expand Up @@ -36,15 +36,14 @@ def __init__(self, *args, **kwargs):
try:
self._fun_name = [kw for kw in list(self._fun_dict.keys()) if kwargs['method'].__contains__(kw)][0]
except (BaseException, IndexError):
# use MKL as default
# default function
self._fun_name = 'mkl'
self._fun_call = self._fun_dict[self._fun_name] if self._fun_name in self._fun_dict else self.mkl_solver

# initialize variables
self.r, self.z, self.cov_r, self.cov_z, self.mu_r, self.mu_z, self.transfer_mat = [None]*7
self._init_vars()

def _init_vars(self):
def init_vars(self):

# reshape source and reference images
self.r, self.z = self._src.reshape([-1, self._src.shape[2]]).T, self._ref.reshape([-1, self._ref.shape[2]]).T
Expand All @@ -55,14 +54,17 @@ def _init_vars(self):
# compute color channel means
self.mu_r, self.mu_z = self.r.mean(axis=1)[..., np.newaxis], self.z.mean(axis=1)[..., np.newaxis]

def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: FunctionType = None) -> np.ndarray:
# validate dimensionality
self.check_dims()

def multivar_transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: FunctionType = None) -> np.ndarray:
"""
Transfer function to map colors based on for Multi-Variate Gaussian Distributions (MVGDs).
:param src: Source image that requires transfer
:param ref: Palette image which serves as reference
:param fun: optional argument to pass a transfer function to solve for covariance matrices
:param fun: Optional argument to pass a transfer function to solve for covariance matrices
:param res: Resulting image after the mapping
:type src: :class:`~numpy:numpy.ndarray`
Expand All @@ -82,7 +84,7 @@ def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: Function
self.validate_color_chs()

# re-initialize variables to account for change in src and ref when passed to self.transfer()
self._init_vars()
self.init_vars()

# set solver function for transfer matrix
self._fun_call = fun if fun is FunctionType else self._fun_call
Expand All @@ -108,14 +110,17 @@ def mkl_solver(self):
"""

# validate dimensionality
self.check_dims()

eig_val_r, eig_vec_r = np.linalg.eig(self.cov_r)
eig_val_r[eig_val_r < 0] = 0
val_r = np.diag(np.sqrt(eig_val_r[::-1]))
vec_r = np.array(eig_vec_r[:, ::-1])
inv_r = np.diag(1. / (np.diag(val_r + np.spacing(1))))

mat_c = np.dot(val_r, np.dot(vec_r.T, np.dot(self.cov_z, np.dot(vec_r, val_r))))
[eig_val_c, eig_vec_c] = np.linalg.eig(mat_c)
mat_c = val_r @ vec_r.T @ self.cov_z @ vec_r @ val_r
eig_val_c, eig_vec_c = np.linalg.eig(mat_c)
eig_val_c[eig_val_c < 0] = 0
val_c = np.diag(np.sqrt(eig_val_c))

Expand All @@ -133,6 +138,11 @@ def analytical_solver(self) -> np.ndarray:
"""

# validate dimensionality
self.check_dims()
if self.r.shape[-1] != self.z.shape[-1]:
raise Exception('Analytical MVGD solution requires spatial dimensions of both images to be equal')

cov_r_inv = np.linalg.pinv(self.cov_r)
cov_z_inv = np.linalg.pinv(self.cov_z)

Expand Down Expand Up @@ -164,3 +174,13 @@ def w2_dist(mu_a: np.ndarray, mu_b: np.ndarray, cov_a: np.ndarray, cov_b: np.nda
vars_dist = np.trace(cov_a+cov_b - 2*(np.dot(np.abs(cov_b)**.5, np.dot(np.abs(cov_a), np.abs(cov_b)**.5))**.5))

return float(mean_dist + vars_dist)

def check_dims(self):
"""
Catch error for wrong color channel number (e.g., gray scale image)
:return: None
"""

if np.ndim(self.cov_r) == 0 or np.ndim(self.cov_z) == 0:
raise Exception('Wrong color channel dimensionality for %s method' % self._fun_name)
45 changes: 37 additions & 8 deletions color_matcher/top_level.py
Expand Up @@ -34,36 +34,65 @@ def __init__(self, *args, **kwargs):
super(ColorMatcher, self).__init__(*args, **kwargs)

self._method = kwargs['method'] if 'method' in kwargs else 'default'
self._funs = []

def main(self, method: str = None) -> np.ndarray:
def main(self) -> np.ndarray:
"""
The main function is the high-level entry point performing the mapping. Valid methods are:
The main function is the high-level entry point performing the mapping based on instantiation arguments.
:return: Resulting image after color mapping
:rtype: np.ndarray
"""

self.transfer()

return self._src

def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, method: str = None) -> np.ndarray:
"""
Transfer function to map colors based on provided transfer method.
:param src: Source image that requires transfer
:param ref: Palette image which serves as reference
:param method: ('default', 'hm', 'reinhard', 'mvgd', 'mkl', 'hm-mvgd-hm', 'hm-mkl-hm') determining color mapping
:type src: :class:`~numpy:numpy.ndarray`
:type ref: :class:`~numpy:numpy.ndarray`
:type method: :class:`str`
:return: Resulting image after color mapping
:rtype: np.ndarray
"""

# assign input arguments to variables (if provided)
self._method = self._method.lower() if method is None else method.lower()
self._src = src if src is not None else self._src
self._ref = ref if ref is not None else self._ref

# color transfer methods (to be iterated through)
if self._method == METHODS[0]:
funs = [self.transfer]
self._funs = [self.multivar_transfer]
elif self._method == METHODS[1]:
funs = [self.hist_match]
self._funs = [self.hist_match]
elif self._method == METHODS[2]:
funs = [self.reinhard]
self._funs = [self.reinhard]
elif self._method in METHODS[3:5]:
funs = [self.transfer]
self._funs = [self.multivar_transfer]
elif self._method in METHODS[5:]:
funs = [self.hist_match, self.transfer, self.hist_match]
self._funs = [self.hist_match, self.multivar_transfer, self.hist_match]
else:
raise BaseException('Method type \'%s\' not recognized' % method)

# check if three color channels are provided
self.validate_img_dims()

# check provided color channels
self.validate_color_chs()

# proceed with the color match
for fun in funs:
for fun in self._funs:
self._src = fun(self._src, self._ref)

return self._src
2 changes: 1 addition & 1 deletion docs/build/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 36e65157a5b8d9f8afb88b0180e76b66
config: d69c61c77e9336c13f9a1800c01ddd93
tags: 645f666f9bcd5a90fca523b33c5a78b7
2 changes: 1 addition & 1 deletion docs/build/html/_static/documentation_options.js
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
VERSION: '0.4.1',
VERSION: '0.5.0',
LANGUAGE: 'None',
COLLAPSE_INDEX: false,
BUILDER: 'html',
Expand Down

0 comments on commit f4d4339

Please sign in to comment.