Skip to content

Commit

Permalink
refactor(dist): add metric plot as svg
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Jun 27, 2021
1 parent 6a08209 commit 1e4a834
Show file tree
Hide file tree
Showing 3 changed files with 6,552 additions and 7 deletions.
10 changes: 10 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ Results

|
|vspace|

|metric|

|vspace|

Installation
------------

Expand Down Expand Up @@ -114,6 +120,10 @@ API Usage
.. _source: https://github.com/hahnec/color-matcher/archive/master.zip

.. |metric| raw:: html

<img src="https://raw.githubusercontent.com/hahnec/color-matcher/develop/hist+wasser_dist.svg" max-width:"100%">

.. |src_photo| raw:: html

<img src="https://raw.githubusercontent.com/hahnec/color-matcher/master/tests/data/scotland_house.png" width="200px" max-width:"100%">
Expand Down
14 changes: 7 additions & 7 deletions color_matcher/mvgd_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, *args, **kwargs):
self._fun_call = self._fun_dict[self._fun_name] if self._fun_name in self._fun_dict else self.mkl_solver

# initialize variables
self.r, self.z, self.cov_r, self.cov_z, self.mu_r, self.mu_z = [None]*6
self.r, self.z, self.cov_r, self.cov_z, self.mu_r, self.mu_z, self.transfer_mat = [None]*7
self._init_vars()

def _init_vars(self):
Expand Down Expand Up @@ -88,10 +88,10 @@ def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: Function
self._fun_call = fun if fun is FunctionType else self._fun_call

# compute transfer matrix
transfer_mat = self._fun_call()
self.transfer_mat = self._fun_call()

# transfer the intensity distributions
res = np.dot(transfer_mat, self.r - self.mu_r) + self.mu_z
res = np.dot(self.transfer_mat, self.r - self.mu_r) + self.mu_z

# reshape pixel array
res = res.T.reshape(self._src.shape)
Expand Down Expand Up @@ -130,13 +130,13 @@ def analytical_solver(self) -> np.ndarray:
"""

cov_r_inv = np.linalg.inv(self.cov_r)
cov_z_inv = np.linalg.inv(self.cov_z)
cov_r_inv = np.linalg.pinv(self.cov_r)
cov_z_inv = np.linalg.pinv(self.cov_z)

# compute transfer matrix using analytical method
transfer_mat = np.linalg.pinv((self.z-self.mu_z).T @ cov_z_inv) @ (self.r-self.mu_r).T @ cov_r_inv
self.transfer_mat = np.linalg.pinv((self.z-self.mu_z).T @ cov_z_inv) @ (self.r-self.mu_r).T @ cov_r_inv

return transfer_mat
return self.transfer_mat

@staticmethod
def w2_dist(mu_a: np.ndarray, mu_b: np.ndarray, cov_a: np.ndarray, cov_b: np.ndarray) -> float:
Expand Down
Loading

0 comments on commit 1e4a834

Please sign in to comment.