In [None]:
import numpy as np
import tensorflow as tf

from matplotlib import pyplot as plt
%matplotlib inline

from scipy.linalg import dft
import matplotlib.cm as cm

TODO

Plot functions to help set up the parameters
 - how many m's for different filter masks
 - ...


refactor the code: 


Rename variables so that the same quantity has the same name in every context,
 and so that names are unique and semantically meaningful

- e.g. filter_size == k_size
-      n_channels -> num_output_channels

- Use "constants" for magic numbers

- more comments

- the term "weight" is used twice:
 - interpolation weight: for sampling from polar to pixel filter, i.e. the $g_i({\bf r}_j)$
 - ring weight: weighting of the different rings for a given m, which are summed up to construct a filter
 - Use the names `interpolation_weight` and `ring_weight`
 
 
 - Todo : Remove tensorflow - here we just need numpy!

Open questions:
 - Are the hyperparameter heuristics (e.g. num_rings, radii) meaningful?

These heuristics should live in separate functions in one place, so that we can play with them.

- Visualization to get a "feeling" for these parameters

In [None]:
# TF1 interactive session: it installs itself as the default session, so later
# cells can call `.eval()` on tensors and `.run()` on ops without passing it.
session = tf.InteractiveSession()

In [None]:
#from models.rot_mnist_downloader import *
#from models.utilities import *

In [None]:
def n_samples(filter_size, min_samples=101):
    """Return the number of angular sample points used on each ring.

    The count is ``max(ceil(pi * filter_size), min_samples)``. With the
    default ``min_samples=101`` this is exactly 101 for every
    ``filter_size <= 32`` (pi * 32 ~= 100.5); only larger filters get more
    samples.

    NOTE(review): the fixed floor of 101 was flagged in the original notebook
    as "one source of instability"; it is a parameter so it can be explored.

    filter_size: side length of the (square) filter in pixels.
    min_samples: lower bound on the number of samples per ring (default 101).
    returns: int
    """
    return int(np.maximum(np.ceil(np.pi * filter_size), min_samples))

In [None]:
def L2_grid(shape):
    """Return centred pixel-centre coordinates of a square ``shape x shape`` grid.

    Pixel centres are placed at ``index + 0.5`` and then shifted by the grid
    centre ``shape / 2``, so the grid is symmetric about the origin.
    Used as the pixel coordinates for the filter masks.

    shape: side length of the square grid (an int, e.g. the filter size).
    returns: array of shape (2, shape**2). Row 0 holds the row (I) offsets and
        row 1 the column (J) offsets, in row-major pixel order.
    """
    # BUGFIX: the centre used to be computed from the global ``filter_size``
    # instead of the argument. That silently mis-centred the grid whenever the
    # two differed, and raised NameError when no such global existed.
    foveal_center = np.asarray([shape, shape]) / 2.
    lin = np.arange(shape) + 0.5
    J, I = np.meshgrid(lin, lin)
    I = I - foveal_center[1]
    J = J - foveal_center[0]
    return np.vstack((np.reshape(I, -1), np.reshape(J, -1)))

In [None]:
def get_num_rings(filter_size):
    """Heuristic ring count: half the filter size rounded down, but at least 2."""
    half_size = int(filter_size / 2)
    return max(half_size, 2)


In [None]:
def get_interpolation_weights(filter_size, m, n_rings=None, plot=False, bandwidth=0.5):
    """Resample a square filter patch onto concentric rings with Gaussian weights.

    Computes the interpolation weights $g_i(r_j)$ (see Figure 6 of the paper).
    For every sample point on every ring it builds a Gaussian weighting over
    all pixel centres of the ``filter_size x filter_size`` patch, normalised
    to sum to 1 over the pixels.

    filter_size: side length of the square filter in pixels.
    m: rotation order. The ring radii start at 0 for ``m == 0`` and at 1
        otherwise (``np.linspace(m != 0, ...)``).
    n_rings: number of rings; defaults to ``get_num_rings(filter_size)``.
    plot: if True, print diagnostics and plot pixel centres vs. ring samples.
    bandwidth: standard deviation (in pixels) of the Gaussian interpolation
        kernel (default 0.5, the previously hard-coded value).
    returns: array of shape (n_rings, n_samples(filter_size), filter_size**2).
    """
    if n_rings is None:
        n_rings = get_num_rings(filter_size)

    # Ring radii, linearly spaced up to n_rings - 0.5 pixels. For m != 0 the
    # innermost ring starts at radius 1 instead of 0, presumably because
    # non-zero-order harmonics are ill-defined at the origin (TODO confirm).
    radii = np.linspace(m != 0, n_rings - 0.5, n_rings)

    # Equally spaced angles on each ring.
    num_angles = n_samples(filter_size)
    angles = (2 * np.pi * np.arange(num_angles)) / num_angles
    # Unit direction vectors in (I, J) coordinates, shape (2, num_angles).
    ring_locations = np.vstack([-np.sin(angles), np.cos(angles)])

    # Pixel-centre coordinates as returned by L2_grid, shape (2, filter_size**2).
    coords = L2_grid(filter_size)

    # Broadcast to (n_rings, 2, num_angles, n_pixels) and take squared distances.
    radii = radii[:, np.newaxis, np.newaxis, np.newaxis]
    ring_locations = ring_locations[np.newaxis, :, :, np.newaxis]
    diff = radii * ring_locations - coords[np.newaxis, :, np.newaxis, :]
    dist2 = np.sum(diff**2, axis=1)  # (n_rings, num_angles, n_pixels)

    # Gaussian low-pass (anti-aliasing) weights. For the factor 1/2 see page 12
    # of "Fourier Analysis in Polar and Spherical Coordinates".
    weights = np.exp(-0.5 * dist2 / (bandwidth**2))

    if plot:  # illustration only
        print("number of rings", n_rings)
        print("radii", radii.flatten())
        print("number of samples per ring:", num_angles)
        print("dist2.shape", dist2.shape)
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.plot(coords[0], coords[1], "r*")
        for r in radii.flatten():
            ax.plot(r * ring_locations[:, 0].flatten(),
                    r * ring_locations[:, 1].flatten(), "b.")
        ax.set_aspect("equal")
        ax.set_title("filter_size={}, m={}: pixel centres (red), ring samples (blue)".format(filter_size, m))

    # Normalise over the pixels so each ring sample's weights sum to 1.
    return weights / np.sum(weights, axis=2, keepdims=True)


In [None]:
# For each sample point on each ring, compute the Gaussian interpolation
# weights over all pixels of the filter.
filter_size = ksize = 4
m = 1  # for m == 0 the innermost ring sits at the centre (radius 0)
interpolation_weights = get_interpolation_weights(filter_size=filter_size, m=m, plot=True)
# output shape: (nb_rings, nb_points_per_ring, nb_sample_points)
nb_rings, nb_points_per_ring, nb_sample_points = interpolation_weights.shape
# one sample point per pixel of the filter
assert nb_sample_points == ksize**2
# the weights of every ring sample are normalised over the pixels
np.testing.assert_allclose(interpolation_weights.sum(axis=2), 1.0)

In [None]:
# ouput_shape: nb_rings, nb_points_per_ring, nb_sample_points
# interpolation_weights

In [None]:
# DFT matrix sized to the number of angular samples actually used on each ring
# (previously hard-coded to 101), so it stays consistent with interpolation_weights.
nb_samples = n_samples(filter_size)
DFT = dft(nb_samples)
assert DFT.shape[0] == interpolation_weights.shape[1]

In [None]:
# Row m of the DFT matrix, DFT[m, n] = exp(-2j*pi*m*n/N), is the circular
# harmonic of order m sampled at the N equally spaced points of a ring
# (see e.g. https://en.wikipedia.org/wiki/DFT_matrix).
m = 1
basis_m = DFT[m, :]
fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(np.real(basis_m), "r-", label="real")
ax.plot(np.imag(basis_m), "g-", label="imag")
ax.set_xlabel("sample index on the ring")
ax.set_title("DFT basis function m={}".format(m))
ax.legend()

# Negative orders: with numpy's negative indexing, m=-1 selects row N-1
# (100 for N=101), which is the complex conjugate of row m=1.
omega = np.exp(2*np.pi*1j/nb_samples)
np.testing.assert_almost_equal(omega**1, np.conj(omega**(nb_samples-1)))
np.testing.assert_allclose(DFT[-m, :], np.conj(DFT[m, :]), atol=1e-9)

In [None]:
# Kept for intuition; it may help when further improving the model.
#
# Figure 6 of the paper: W(x_i) = \sum_j g_i(r_j) W(r_j).
# np.dot(DFT, interpolation_weights) performs the sum over the ring samples j
# for every order m at once, i.e. it Fourier-transforms the interpolation
# weights along each ring. The frequency bin index is the rotation order m.
assert DFT.shape[1] == interpolation_weights.shape[1], "DFT size must match samples per ring"
LPF = np.dot(DFT, interpolation_weights)
# LPF.shape == (nb_samples, nb_rings, nb_sample_points), indexed as (m, ring, pixel)
print(LPF.shape, interpolation_weights.shape)
ring_nr = 1
pixel_index = 10
fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(interpolation_weights[ring_nr, :, pixel_index], "b-", label="weights along ring (x = sample index)")
ax.plot(np.real(LPF[:, ring_nr, pixel_index]), "r.", label="Re LPF (x = m)")
ax.plot(np.imag(LPF[:, ring_nr, pixel_index]), "g.", label="Im LPF (x = m)")
ax.set_xlabel("m")
ax.legend()

In [None]:
# Visualise the first max_m Fourier components (orders m = 0 .. max_m-1) of the
# selected ring/pixel as vectors in the complex plane.
max_m = 3
real = np.real(LPF[:max_m, ring_nr, pixel_index])
imag = np.imag(LPF[:max_m, ring_nr, pixel_index])
t = np.linspace(0, 2*np.pi, 101)
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
ax.plot(0.1*np.cos(t), 0.1*np.sin(t))  # small circle marking the origin
origin = np.zeros(len(real))
ax.quiver(origin, origin, real, imag, angles="xy", scale_units="xy", scale=1, color="r")
ax.plot(real, imag, "ro")
for order, (x, y) in enumerate(zip(real, imag)):
    ax.annotate("m={}".format(order), (x, y))
ax.set_aspect("equal")
ax.set_xlabel("real")
ax.set_ylabel("imag")

In [None]:
# Complex filter contributed by a single ring for rotation order m.
# The CNN filters are weighted sums of such per-ring filters.
m = 1
ring_nr = 0
ring_filter = LPF[m, ring_nr, :]
k = int(np.sqrt(len(ring_filter)))
f = ring_filter.reshape((k, k))
fr, fi = np.real(f), np.imag(f)
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, title, image in ((ax1, "real", fr), (ax2, "imag", fi)):
    axis.set_title(title)
    axis.imshow(image, interpolation='nearest', cmap=cm.Greys_r)
plt.show()

In [None]:
# Sum the per-ring filters for order m. In the network these filters are
# combined with learnt ring weights (the radial profile R(r)); here every ring
# simply gets weight 1.
m = 1
# Derive k from LPF itself instead of from `ring_nr` set in an earlier cell.
n_pixels = LPF.shape[-1]
k = int(np.sqrt(n_pixels))
assert k * k == n_pixels, "LPF pixels do not form a square filter"
# BUGFIX: the previous loop started from `fr = np.real(f)`, which is a *view*
# into LPF, so `fr += ...` silently overwrote LPF[m, 0, :] (and re-running the
# cell gave different results). Summing into a new array leaves LPF intact.
f = LPF[m, :, :].sum(axis=0).reshape((k, k))
fr = np.real(f)
fi = np.imag(f)
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.set_title("real")
ax1.imshow(fr, interpolation='nearest', cmap=cm.Greys_r)
ax2.set_title("imag")
ax2.imshow(fi, interpolation='nearest', cmap=cm.Greys_r)
plt.show()

In [None]:
def get_filters(ring_weight_dict, filter_size, P=None, n_rings=None, plot=False):
    """Build the real/imaginary harmonic filters from the ring weights.

    For each rotation order m, the Gaussian ring-interpolation weights are
    projected onto the m-th DFT basis function (a single-frequency DFT along
    each ring). The result is then combined over the rings using the ring
    weights (a matrix product over the ring axis).

    ring_weight_dict: dict {m: tensor}. Each tensor's first three static
        dimensions are read as (num_rings, num_input_channels,
        num_output_channels).
    filter_size: side length k of the square filter (ksize).
    P: optional phase dict {m: tensor}. If given, each (cos, sin) filter pair
        is rotated by the angle P[m], which must broadcast against
        [k, k, in, out].
    n_rings: number of rings passed to get_interpolation_weights (defaults
        there to get_num_rings(filter_size)). It must match num_rings above.
    plot: if True, print some shapes for debugging.
    returns: dict {m: (ucos, usin)} of tf tensors, each of shape
        [k, k, num_input_channels, num_output_channels].
    raises: ValueError if the ring counts of the interpolation weights and of
        a ring weight tensor disagree.
    """
    k = filter_size
    filters = {}
    num_of_samples = n_samples(k)  # sample points per ring
    DFT_Matrix = dft(num_of_samples)  # (num_samples, num_samples)

    for m, r in ring_weight_dict.items():
        rsh = r.get_shape().as_list()
        num_rings = rsh[0]
        num_input_channels = rsh[1]
        num_output_channels = rsh[2]

        # (nb_rings, nb_points_per_ring, nb_sample_points)
        interpolation_weights = get_interpolation_weights(filter_size=k, m=m, n_rings=n_rings)
        if interpolation_weights.shape[0] != num_rings:
            raise ValueError(
                "order {}: interpolation uses {} rings but ring weights have {}".format(
                    m, interpolation_weights.shape[0], num_rings))

        # Only the m-th angular frequency is used. A negative m selects row
        # N+m via numpy negative indexing, the complex conjugate of row -m.
        DFT_m = DFT_Matrix[m, :]

        # (nb_points_per_ring,) . (nb_rings, nb_points_per_ring, nb_sample_points)
        # -> (nb_rings, nb_sample_points), transposed to (nb_sample_points, nb_rings)
        LPF = np.dot(DFT_m, interpolation_weights).T

        cosine = tf.constant(np.real(LPF).astype(np.float32))
        sine = tf.constant(np.imag(LPF).astype(np.float32))

        # Weighted sum over rings:
        # (nb_sample_points, nb_rings) x (nb_rings, in*out) -> reshape to (k, k, in, out)
        r = tf.reshape(r, tf.stack([num_rings, num_input_channels*num_output_channels]))
        ucos = tf.reshape(tf.matmul(cosine, r), tf.stack([k, k, num_input_channels, num_output_channels]))
        usin = tf.reshape(tf.matmul(sine, r), tf.stack([k, k, num_input_channels, num_output_channels]))

        if P is not None:
            # Rotate the (cos, sin) pair by the phase offset P[m] ("beta").
            ucos_ = tf.cos(P[m])*ucos + tf.sin(P[m])*usin
            usin = -tf.sin(P[m])*ucos + tf.cos(P[m])*usin
            ucos = ucos_
        filters[m] = (ucos, usin)

        if plot:
            print("m:", m, DFT_m.shape)
            # BUGFIX: previously referenced the undefined name `weights` (NameError).
            print("weights.shape", interpolation_weights.shape)
            print("LPF.shape", LPF.shape)
    return filters

In [None]:
# TODO: Think about the problem and find a better initialization!!
#
# Is "He" really appropriate?
# "Xavier" resp. "He" assumes that the channels/pixels are randomly sampled independently
# and from the number of inputs and outputs (for the forward and backward pass) the
# values are initialized to get a unit preactivation in the range (-1,1).
# experimentally: compute an appropriate statistic
#
# How relevant is this if we use Batch-Normalization?
#
# Here the filter pixels are computed from the ring weights!
# Hyperparameter: std_mult=0.4

def get_weights(filter_shape, W_init=None, std_mult=0.4, name='W'):
    """Create a float32 ring-weight variable, He-initialised by default.

    filter_shape: list of dimensions (num_rings, num_input_channels, num_output_channels)
    W_init: initial values. None (default) uses a scaled He normal init; a
        numpy array of shape filter_shape or any tf initializer is also accepted.
    std_mult: multiplier for the He standard deviation (default 0.4)
    name: variable name (default 'W')
    returns: tf variable of shape filter_shape
    """
    # BUGFIX: `W_init == None` is elementwise for numpy arrays and then raises
    # "truth value of an array is ambiguous"; use an identity check.
    if W_init is None:
        # He init: stddev = std_mult * sqrt(2 / prod(filter_shape[:3])).
        # NOTE(review): with the (rings, in, out) layout this product also
        # includes the output channels; confirm this is the intended fan-in.
        stddev = std_mult*np.sqrt(2.0 / np.prod(filter_shape[:3]))
        W_init = tf.random_normal_initializer(stddev=stddev)
    elif isinstance(W_init, np.ndarray):
        # tf.get_variable rejects a constant (non-callable) initializer when
        # `shape` is also given, so wrap numpy values in an initializer.
        W_init = tf.constant_initializer(W_init)
    return tf.get_variable(name, dtype=tf.float32, shape=filter_shape,
                           initializer=W_init)

# For each ring and every input-output channel pair, a weight is randomly
# initialized. These weights represent the learned radial part of the basis
# functions and are trained on the data.

In [None]:
def get_weights_dict(shape, max_order, std_mult=0.4, n_rings=None, name='W'):
    """Return a dict {m: ring-weight variable} for every rotation order m.

    shape: filter shape [h, w, i, o] (h == w). Only shape[0] (the filter
        size) and shape[2:] (input/output channels) are used.
    max_order: an int gives orders m = -max_order, ..., max_order. A tuple
        (a, b) gives orders m = -(b-a), ..., (b-a).
    std_mult: He init scaled by std_mult (default 0.4)
    n_rings: number of rings; defaults to get_num_rings(shape[0]).
    name: base variable name; variables are named '<name>_<m>' (default 'W').
    returns: dict mapping m to a variable of shape [n_rings] + shape[2:].
    """
    if isinstance(max_order, int):
        orders = range(-max_order, max_order+1)
    else:
        # NOTE(review): only the width of the tuple range is used; the orders
        # are always symmetric around 0, not the range (a, b) itself.
        diff = max_order[1]-max_order[0]
        orders = range(-diff, diff+1)
    # Loop-invariant: the ring count and variable shape are the same for every m.
    filter_size = shape[0]
    if n_rings is None:
        n_rings = get_num_rings(filter_size)
    # list(...) so that tuple shapes work as well as lists.
    sh = [n_rings] + list(shape[2:])
    weights_dict = {}
    for i in orders:
        weights_dict[i] = get_weights(sh, std_mult=std_mult, name=name + '_' + str(i))
    return weights_dict

In [None]:
def get_phase_dict(n_in, n_out, max_order, name='b'):
    """Create one learnable phase-offset variable per rotation order.

    n_in: number of input channels
    n_out: number of output channels
    max_order: an int gives orders -max_order..max_order; a tuple (a, b) gives
        orders -(b-a)..(b-a)
    name: base name; each variable is called '<name>_<m>'
    returns: dict {m: float32 variable of shape [1, 1, n_in, n_out]},
        initialised uniformly in [0, 2*pi) from numpy's global RNG.
    """
    half_width = max_order if isinstance(max_order, int) else max_order[1] - max_order[0]
    phase_shape = [1, 1, n_in, n_out]
    phases = {}
    for order in range(-half_width, half_width + 1):
        initial = np.float32(np.random.rand(*phase_shape) * 2. * np.pi)
        phases[order] = tf.get_variable(
            name + '_' + str(order),
            dtype=tf.float32,
            shape=phase_shape,
            initializer=tf.constant_initializer(initial))
    return phases

# the phase differences between the different filters 
# these are learned, too.

In [None]:
# Configuration for the filter demo.
ksize = filter_size = 4  # kernel size
input_channels = 1
output_channels = 3
max_order = 2
stddev = 0.4  # std_mult for the He initialisation of the ring weights
SEED = 42

n_rings = get_num_rings(filter_size)

# Seed the graph so the random ring weights are reproducible on Restart & Run All.
tf.set_random_seed(SEED)

shape = [ksize, ksize, input_channels, output_channels]
name = "_test"

# Ring weights: {m: tf.Variable of shape (n_rings, input_channels, output_channels)},
# e.g. {0: <tf.Variable 'W_test_0:0' shape=(2, 1, 3) dtype=float32_ref>, 1: ..., -1: ..., 2: ..., -2: ...}
# NOTE: tf.get_variable refuses to create 'W_test_*' twice in the same graph,
# so re-running this cell requires a kernel restart.
ring_weight_dict = get_weights_dict(shape, max_order, std_mult=stddev, n_rings=n_rings, name='W'+name)


In [None]:
# Shape of the ring weights for order m=-1: (n_rings, input_channels, output_channels).
print(ring_weight_dict[-1].shape)

In [None]:
# Initialise all graph variables (ring weights) in the interactive session.
session.run(tf.global_variables_initializer())

In [None]:
# todo:
# manual manipulation of the "ring_weight_dict"
# to see in the plot's the influences of the "learned ring weighting"

In [None]:
# Evaluate all filter tensors in a single session.run call, then split them into
# dicts of real and imaginary numpy arrays keyed by the order m.
filters_ = get_filters(ring_weight_dict, filter_size=filter_size, P=None, n_rings=n_rings)
filters_np = session.run(filters_)
filters_real = {order: re_part for order, (re_part, _) in filters_np.items()}
filters_imag = {order: im_part for order, (_, im_part) in filters_np.items()}
    

In [None]:
# Plot the real and imaginary parts of the filter for every order m,
# for one (input, output) channel pair.
in_channel = 0
out_channel = 1
orders = sorted(filters_real)
fig, axes = plt.subplots(len(orders), 2, figsize=(4, 2 * len(orders)), squeeze=False)
for row, m in enumerate(orders):
    f_real = filters_real[m]
    f_imag = filters_imag[m]
    for ax, part, image in ((axes[row, 0], "real", f_real), (axes[row, 1], "imag", f_imag)):
        ax.set_title("m={} {}".format(m, part))
        ax.imshow(image[:, :, in_channel, out_channel], interpolation='nearest', cmap=cm.Greys_r)
        ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# TODO: Plot Phase influences (filter) 