In [1]:
# Reload imported modules automatically before each cell execution, so edits
# to local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [38]:
import numpy as np
import chainer
import cv2
import abc
import sys

from chainer import functions as F

from ipywidgets import interactive
from tabulate import tabulate

from matplotlib import pyplot as plt
from dataclasses import dataclass
from typing import Tuple

from fve_layer.backends.chainer import links

# Record the Chainer version this notebook was executed with.
print(f"Chainer version: {chainer.__version__}")

# Set Chainer's global train flag to False for the whole notebook.
chainer.config.train = False

Chainer version: 7.7.0


# Arguments

In [53]:
@dataclass
class Args:
    """Settings shared by the layer and data initializers of this notebook.

    ``n_components``, ``n_dims``, ``init_mu``, ``rnd_mu``, ``init_sig`` and
    ``eps`` are consumed by ``new_layer``; ``x_range``, ``y_range`` and
    ``n_samples`` describe the sampling grid created by ``new_data`` and the
    axis extent used when plotting.
    """

    # --- layer layout ---
    n_components: int = 1
    n_dims: int = 2

    # --- layer initialization ---
    init_mu: float = 0
    # if False, ``new_layer`` expands ``init_mu`` to a constant array itself
    rnd_mu: bool = False
    init_sig: float = 1
    eps: float = 1e-2

    # --- sampling grid ---
    x_range: Tuple[float, float] = (-1000, 1000)
    y_range: Tuple[float, float] = (-1000, 1000)
    # number of samples per axis
    n_samples: int = 99



# Data and Layer initializers

In [55]:
def new_layer(args: Args, mu=None, sig=None, N=None):
    """Creates a fresh ``links.FVELayer_noEM`` with cleared gradients.

    Args:
        args: notebook settings; supplies the fallbacks for ``mu``, ``sig``
            and ``N`` as well as ``n_dims``, ``rnd_mu`` and ``eps``.
        mu: initial mean value; falls back to ``args.init_mu`` when None.
        sig: initial sigma value; falls back to ``args.init_sig`` when None.
        N: number of components; falls back to ``args.n_components`` when None.

    Returns:
        The newly constructed layer.
    """
    n_comp = N if N is not None else args.n_components
    init_mu = mu if mu is not None else args.init_mu
    init_sig = sig if sig is not None else args.init_sig

    if not args.rnd_mu:
        # fixed initialization: every entry of mu gets the same value
        init_mu = np.full((args.n_dims, n_comp), init_mu, dtype=np.float32)
    # otherwise the scalar is handed over unchanged; the random mu
    # initialization (uniform in [-init_mu, init_mu]) is realized by the layer

    layer = links.FVELayer_noEM(
        in_size=args.n_dims,
        n_components=n_comp,
        init_mu=init_mu,
        init_sig=init_sig,
        eps=args.eps,
    )
    layer.cleargrads()
    return layer


def new_data(args: "Args"):
    """Creates a regular 2D grid of sample points.

    The result is indexed as ``grid[row, col] = (x, y)``: the first channel
    holds values from ``args.x_range`` and changes along the columns, the
    second channel holds values from ``args.y_range`` and changes along the
    rows. This matches ``_imshow``, which maps columns to ``x_range`` and
    rows to ``y_range``.

    Args:
        args: settings providing ``x_range``, ``y_range`` and ``n_samples``.

    Returns:
        A float32 array of shape ``(n_samples, n_samples, 2)``.
    """
    xs = np.linspace(*args.x_range, num=args.n_samples)
    ys = np.linspace(*args.y_range, num=args.n_samples)

    # meshgrid(xs, ys) yields arrays of shape (len(ys), len(xs)) in which x
    # varies along the columns and y along the rows. The previous argument
    # order (ys, xs) put y_range values on the x axis, which only went
    # unnoticed because both ranges were identical.
    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.stack([grid_x, grid_y], axis=2).astype(np.float32)


# Plotting helper

In [134]:
def _norm(arr, axis=(0,1)):
    _min = arr.min(axis=axis)
    arr[:] -= _min
    _max = arr.max(axis=axis, keepdims=True)
    _max_mask = _max != 0
    if _max_mask.any():
        arr_mask = np.broadcast_to(_max_mask, arr.shape)
        arr[arr_mask] /= np.broadcast_to(_max, arr.shape)[arr_mask]
    
    return _min, np.squeeze(_max)

def _denorm(arr, _min, _max):
    _max_mask = _max != 0
    if _max_mask.any():
        arr = arr.copy()
        arr[_max_mask] *= _max[_max_mask] 

    return arr + _min

def _imshow(args, ax, im, *, title=None, YUV2RGB=False, display_values=True, ndecimals=4):
    assert im.shape[-1] >= 2
    im = getattr(im, "array", im)
    
    _im = np.zeros(im.shape[:-1] + (3,), dtype=im.dtype)
    _im[..., 1:] = im[..., :2].copy()
    
    _min, _max = _norm(_im)
    
    if YUV2RGB:
        _im = cv2.cvtColor(_im, cv2.COLOR_YUV2RGB)

    _norm(_im)
    _im[..., 0] = 0

    ax.imshow(_im.round(ndecimals), extent=args.x_range + args.y_range, origin="lower")
    
    if title is not None:
        ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    
    
    if display_values:
        N = 5
        h, w, _ = _im.shape
        xs = np.linspace(*args.x_range, N)
        ys = np.linspace(*args.y_range, N)
        _js = np.linspace(0, w-1, N).astype(int)
        _is = np.linspace(0, h-1, N).astype(int)

        fmt = "\n".join(["{x} | {y}", "R={0: .2f}", "G={1: .2f}", "B={2: .2f}"])
        for y, i in zip(ys, _is):
            for x, j in zip(xs, _js):
                
                text = fmt.format(*_denorm(_im[i, j], _min, _max), x=x, y=y)
                if i == _is.min():
                    verticalalignment = "bottom"
                elif i == _is.max():
                    verticalalignment = "top"
                else:
                    verticalalignment = "center"
                    
                if j == _js.min():
                    horizontalalignment = "left"
                elif j == _js.max():
                    horizontalalignment = "right"
                else:
                    horizontalalignment = "center"
                    
                ax.text(x,y, s=text, 
                        bbox=dict(facecolor="white", alpha=0.5),
                        horizontalalignment=horizontalalignment,
                        verticalalignment=verticalalignment,
                       )
                ax.scatter(x, y, marker="x", color="white")
                


# Printing helper


In [135]:
def _print(mu_sig_w, data_name="array", headers=["\u03BC", "\u03C3", "w"]):
    
    rows = []
    _get = lambda param: getattr(param, data_name, param)
    mu, sig, w = map(_get, mu_sig_w)
    for i, params in enumerate(zip(mu.T, sig.T, w), 1):
        rows.append([f"Comp #{i}"] + [p.round(6) for p in params])
    print(tabulate(rows, headers=headers, tablefmt="fancy_grid"))

def print_arrays(layer):
    """Prints the ``array`` values of the layer's ``mu``, ``sig`` and ``w``."""
    params = [layer.mu, layer.sig, layer.w]
    _print(params, data_name="array")
    
def print_grad(layer, enc_part):
    """Prints the ``grad`` values of the layer's ``mu``, ``_sig`` and ``_w``.

    Args:
        layer: layer whose parameter gradients are printed.
        enc_part: symbol of the encoding part the gradients belong to; it is
            inserted into the column headers.
    """
    symbols = ["\u03BC", "_\u03C3", "_w"]
    headers = [f"\u2202 F{enc_part} / \u2202 {symbol}" for symbol in symbols]
    _print([layer.mu, layer._sig, layer._w], "grad", headers=headers)
    

# Base Plotter class

In [136]:

class Plotter(abc.ABC):
    """Base class of the interactive plots in this notebook.

    An instance is callable with the layer parameters ``mu``, ``sig`` and
    ``N``; every call builds a new layer, prints its parameters and draws the
    data next to the layer's own plot. Subclasses add their figures in
    ``plot_core``.

    NOTE(review): ``__call__`` always forwards its own ``mu``/``sig``/``N``
    to ``new_layer``, so ``Args.init_mu``, ``Args.init_sig`` and
    ``Args.n_components`` only take effect when ``None`` is passed explicitly.
    """

    # class-level attribute, so instances expose ``__name__`` like a function
    __name__ = "Plotter"

    def __init__(self, max_comps=10, **kwargs):
        """Stores the settings.

        Args:
            max_comps: maximum number of components returned by ``encode``.
            **kwargs: forwarded to ``Args``.
        """
        self.args = Args(**kwargs)
        print(self.args)

        self.max_comps = max_comps
        # the sample grid is created lazily on the first call
        self.X = None

    def __call__(self=None, mu=1, sig=1, N=1):
        """Builds a layer from ``mu``, ``sig`` and ``N`` and renders all plots."""
        layer = new_layer(self.args, mu, sig, N)

        if self.X is None:
            self.X = new_data(self.args)

        print_arrays(layer)
        self.plot(layer, self.X)

        plt.show()
        plt.close()

    def encode(self, layer, X):
        """Encodes every grid point of ``X`` separately with ``layer``.

        Args:
            layer: callable layer exposing ``n_components`` and ``in_size``.
            X: array of shape ``(h, w, c)``.

        Returns:
            ``(X, mu_enc, sig_enc)``: the input wrapped in a
            ``chainer.Variable`` and the two halves of the encoding, each
            indexed as ``[point, component, dimension]``. At most
            ``self.max_comps`` components are returned.
        """
        h, w, c = X.shape
        n_points = h * w
        X = chainer.Variable(X)

        # each grid point is passed to the layer as its own single-element set
        enc = layer(X.reshape(n_points, 1, c))
        enc = enc.reshape(n_points, 2, layer.n_components, layer.in_size)

        if layer.n_components > self.max_comps:
            print("More than {0} components! Plotting only the first {0}!".format(self.max_comps),
                  file=sys.stderr)
            enc = enc[:, :, :self.max_comps]

        return X, enc[:, 0], enc[:, 1]

    def interact(self, height=None,
                 mu=(-10, 10, 0.1),
                 sig=(0.1, 100, 0.1),
                 N=(1, 32, 1)):
        """Wraps this plotter into an ``ipywidgets.interactive`` widget.

        Args:
            height: optional fixed height (in pixels) of the output area.
            mu, sig, N: widget abbreviations for the arguments of ``__call__``.

        Returns:
            The created widget.
        """
        widget = interactive(self, mu=mu, sig=sig, N=N)

        if height is None:
            return widget

        # set fixed size if flickering is annoying
        widget.children[-1].layout.height = f"{height}px"
        return widget

    def plot(self, layer, X):
        """Draws the data (left) and the layer's own plot (right), then calls ``plot_core``."""
        fig, (data_ax, layer_ax) = plt.subplots(1, 2, figsize=(16, 8))
        fig.suptitle("Data and the GMM")

        layer.plot(ax=layer_ax)
        layer_ax.set_xlim(*self.args.x_range)
        layer_ax.set_ylim(*self.args.y_range)

        _imshow(self.args, data_ax, X, title="Data")

        self.plot_core(layer, X)

    @abc.abstractmethod
    def plot_core(self, layer, X):
        """Draws the subclass-specific figures for ``layer`` and the grid ``X``."""
        pass

In [137]:
class EncodingPlotter(Plotter):
    """Plots, per component, both parts of the encoding of every grid point."""

    def plot_core(self, layer, X):
        """Creates one figure per component with its mu- and sigma-part.

        Args:
            layer: layer used to encode the grid.
            X: sample grid of shape ``(n0, n1, c)``.
        """
        n0, n1, _ = X.shape
        X, mu_enc, sig_enc = self.encode(layer, X)

        # raw strings: "\m" and "\s" are invalid escape sequences in regular
        # string literals and trigger a warning on recent Python versions
        parts = [(mu_enc, r"$\mu$-Part"), (sig_enc, r"$\sigma^2$-Part")]

        for i in range(mu_enc.shape[1]):
            fig, axs = plt.subplots(1, 2, figsize=(16,8))
            fig.suptitle(f"Component #{i}")

            for ax, (enc, title) in zip(axs, parts):
                # the encoding of component i, arranged back on the sample grid
                _imshow(self.args, ax, enc[:, i].reshape(n0, n1, 2), title=title)


# Interactive view of the encodings on a coarse 23x23 grid over [-10, 10]^2.
plotter = EncodingPlotter(
    x_range=(-10, 10),
    y_range=(-10, 10),
    
    eps=0,
    
    n_samples=23,
    
    # hand the scalar mu to the layer instead of expanding it in new_layer
    rnd_mu=True,
)

# last expression of the cell: displays the widget
plotter.interact()

Args(n_components=1, n_dims=2, init_mu=0, rnd_mu=True, init_sig=1, eps=0, x_range=(-10, 10), y_range=(-10, 10), n_samples=23)


interactive(children=(FloatSlider(value=1.0, description='mu', max=10.0, min=-10.0), FloatSlider(value=1.0, de…

In [125]:
class GradPlotter(Plotter):
    """Plots the gradient of the encoding with respect to the input grid.

    The gradients are shown for the mu-part, the sigma-part and both parts
    together, first summed over all components and then per component.
    """

    def plot_core(self, layer, X):
        """Creates the gradient figures for all components and for each one.

        Args:
            layer: layer used to encode the grid.
            X: sample grid; it is wrapped into a ``chainer.Variable`` by
                ``encode`` so that its gradient can be read afterwards.
        """
        X, mu_enc, sig_enc = self.encode(layer, X)

        # LaTeX names of the three plotted encoding parts. Raw strings are
        # required: "\m", "\s", "\T", "\d" and "\p" are invalid escape
        # sequences in regular string literals.
        enc_names = [r"\mu", r"\sigma^2", r"\Theta"]

        _title = "All components"
        print(f"=== {_title} ===")
        title_fmt = r"$\dfrac{{\partial}}{{\partial x}}F_{{{param}}}(x)$"

        encs = [mu_enc, sig_enc, F.concat([mu_enc, sig_enc], axis=1)]
        fig, axs = self.plot_grad(layer, X, encs, enc_names,
                                  title_fmt=title_fmt)
        fig.suptitle(_title)

        for i in range(mu_enc.shape[1]):
            _title = f"Component #{i}"
            print(f"=== {_title} ===")

            # the component index becomes a subscript of the parameter name
            title_fmt = (r"$\dfrac{{\partial}}{{\partial x}}F_{{{param}_{{"
                         + str(i)
                         + r"}}}}(x)$")

            encs = [mu_enc[:, i], sig_enc[:, i]]
            encs.append(F.concat(encs, axis=-1))

            fig, axs = self.plot_grad(layer, X, encs, enc_names,
                                      title_fmt=title_fmt)
            fig.suptitle(_title)

    def plot_grad(self, layer, X, encs, enc_names, *, title_fmt):
        """Back-propagates the sum of each encoding and plots ``X.grad``.

        Args:
            layer: layer whose parameter gradients are printed.
            X: input variable the encodings were computed from.
            encs: encodings to differentiate, one subplot each.
            enc_names: LaTeX names matching ``encs``.
            title_fmt: format string with a ``{param}`` placeholder for the
                subplot titles.

        Returns:
            The created figure and its axes.
        """
        fig, axs = plt.subplots(1, len(encs), figsize=(8*len(encs), 8))

        for ax, _enc, enc_name in zip(axs, encs, enc_names):
            # start from clean gradients, so every subplot shows the gradient
            # of its own encoding only
            X.grad = None
            layer.cleargrads()
            F.sum(_enc).backward()

            if "mu" in enc_name:
                enc_part = "\u03BC"
            elif "sig" in enc_name:
                enc_part = "\u03C3"
            else:
                enc_part = "\u03B8"

            print_grad(layer, enc_part=enc_part)
            _imshow(self.args, ax, X.grad, title=title_fmt.format(param=enc_name))

        return fig, axs
        

In [138]:

# Interactive view of the input gradients on a 99x99 grid over [-10, 10]^2.
# NOTE(review): Plotter.__call__ always forwards the widget values of mu and N
# to new_layer, so init_mu=5 and n_components=3 below appear to be overridden
# by the sliders - confirm whether that is intended.
plotter = GradPlotter(    
    x_range=(-10,10),
    y_range=(-10,10),
    n_components=3,
    
    init_mu=5,
    eps=0,
    
    n_samples=99,
    rnd_mu=True,
)

# last expression of the cell: displays the widget
plotter.interact()

Args(n_components=3, n_dims=2, init_mu=5, rnd_mu=True, init_sig=1, eps=0, x_range=(-10, 10), y_range=(-10, 10), n_samples=99)


interactive(children=(FloatSlider(value=1.0, description='mu', max=10.0, min=-10.0), FloatSlider(value=1.0, de…

In [11]:
# Build a layer from the current plotter settings and show the layer's `_sig`
# parameter next to the `sig` attribute it exposes.
layer = new_layer(plotter.args)

layer._sig, layer.sig

(variable _sig([[0., 0.],
                [0., 0.]]),
 variable([[1., 1.],
           [1., 1.]]))

In [None]:
# NOTE(review): unexecuted scratch cell; `foo` is not used anywhere in the
# visible part of the notebook.
foo = chainer.Variable(np.random.randn(2,3).astype(np.float32), )