In [64]:
import argparse
import logging
import numpy as np
import os
from multiprocessing import Pool
from functools import partial
from scipy.misc import imresize
from scipy.ndimage import imread

# from inference import test_image
# from learning import train_image

In [65]:
from itertools import izip
import logging
import numpy as np
import networkx as nx
from numpy.random import rand, randint

from science_rcn.dilation.dilation import dilate_2d
# from preproc import Preproc

In [66]:
from collections import namedtuple
import logging
import numpy as np
import networkx as nx
from scipy.spatial import distance, cKDTree


In [67]:
import logging
import numpy as np
from scipy.ndimage import maximum_filter
from scipy.ndimage.filters import gaussian_filter
from scipy.signal import fftconvolve

==================    

learning    

==================    

In [68]:
ModelFactors = namedtuple('ModelFactors', 'frcs edge_factors graph')

def train_image(img, perturb_factor=2.):
    """Main function for training on one image.

    Parameters
    ----------
    img : 2D numpy.ndarray
        The training image.
    perturb_factor : float
        How much two points are allowed to vary on average given the distance
        between them. See Sec S2.3.2 for details.

    Returns
    -------
    frcs : numpy.ndarray of numpy.int
        Nx3 array of (feature idx, row, column), where each row represents a
        single pool center
    edge_factors : numpy.ndarray of numpy.int
        Nx3 array of (source pool index, target pool index, perturb_radius), where
        each row is a pairwise constraints on a pair of pool choices.
    graph : networkx.Graph
        An undirected graph whose edges describe the pairwise constraints between
        the pool centers.
        The tightness of the constraint is in the 'perturb_radius' edge attribute.
    """
    # Pre-processing layer (cf. Sec 4.2.1)
    preproc_layer = Preproc()
    bu_msg = preproc_layer.fwd_infer(img)
    # Sparsification (cf. Sec 5.1.1)
    frcs = sparsify(bu_msg)
    # Lateral learning (cf. 5.2)
    graph, edge_factors = learn_laterals(frcs, bu_msg, perturb_factor=perturb_factor)
    return ModelFactors(frcs, edge_factors, graph)


def sparsify(bu_msg, suppress_radius=3):
    """Make a sparse representation of the edges by greedily selecting features from the
    output of preprocessing layer and suppressing overlapping activations.

    Parameters
    ----------
    bu_msg : 3D numpy.ndarray of float
        The bottom-up messages from the preprocessing layer.
        Shape is (num_feats, rows, cols)
    suppress_radius : int
        How many pixels in each direction we assume this filter
        explains when included in the sparsification.

    Returns
    -------
    frcs : see train_image.
    """
    frcs = []
    img = bu_msg.max(0) > 0
    while True:
        r, c = np.unravel_index(img.argmax(), img.shape)
        if not img[r, c]:
            break
        frcs.append((bu_msg[:, r, c].argmax(), r, c))
        img[r - suppress_radius:r + suppress_radius + 1, 
            c - suppress_radius:c + suppress_radius + 1] = False
    return np.array(frcs)


def learn_laterals(frcs, bu_msg, perturb_factor, use_adjaceny_graph=False):
    """Given the sparse representation of each training example,
    learn perturbation laterals. See train_image for parameters and returns.
    """
    print("learn_laterals 1", use_adjaceny_graph)
    if use_adjaceny_graph:
        graph = make_adjacency_graph(frcs, bu_msg)
        graph = adjust_edge_perturb_radii(frcs, graph, perturb_factor=perturb_factor)
    else:
        graph = nx.Graph()
        graph.add_nodes_from(range(frcs.shape[0]))
    graph = add_underconstraint_edges(frcs, graph, perturb_factor=perturb_factor)
    graph = adjust_edge_perturb_radii(frcs, graph, perturb_factor=perturb_factor)
    edge_factors = np.array(
        [(edge_source, edge_target, edge_attrs['perturb_radius'])
         for edge_source, edge_target, edge_attrs in graph.edges_iter(data=True)])
    return graph, edge_factors


def make_adjacency_graph(frcs, bu_msg, max_dist=3):
    """Make a graph based on contour adjacency."""
    preproc_pos = np.transpose(np.nonzero(bu_msg > 0))[:, 1:]
    preproc_tree = cKDTree(preproc_pos)
    # Assign each preproc to the closest F1
    f1_bus_tree = cKDTree(frcs[:, 1:])
    _, preproc_to_f1 = f1_bus_tree.query(preproc_pos, k=1)
    # Add edges
    preproc_pairs = np.array(list(preproc_tree.query_pairs(r=max_dist, p=1)))
    f1_edges = np.array(list({(x, y) for x, y in preproc_to_f1[preproc_pairs] if x != y}))

    graph = nx.Graph()
    graph.add_nodes_from(range(frcs.shape[0]))
    graph.add_edges_from(f1_edges)
    return graph


def add_underconstraint_edges(frcs,
                              graph,
                              perturb_factor=2.,
                              max_cxn_length=100,
                              tolerance=4):
    """Examines all pairs of variables and greedily adds pairwise constraints
    until the pool flexibility matches the desired amount of flexibility specified by 
    perturb_factor and tolerance.

    Parameters
    ----------
    frcs : numpy.ndarray of numpy.int
        Nx3 array of (feature idx, row, column), where each row represents a 
        single pool center.
    perturb_factor : float
        How much two points are allowed to vary on average given the distance
        between them.
    max_cxn_length : int
        The maximum radius to consider adding laterals.
    tolerance : float
        How much relative error to tolerate in how much two points vary relative to each 
        other.

    Returns
    -------
    graph : see train_image.
    """
    graph = graph.copy()
    f1_bus_tree = cKDTree(frcs[:, 1:])

    close_pairs = np.array(list(f1_bus_tree.query_pairs(r=max_cxn_length)))
    dists = [distance.euclidean(frcs[x, 1:], frcs[y, 1:]) for x, y in close_pairs]

    for close_pairs_idx in np.argsort(dists):
        source, target = close_pairs[close_pairs_idx]
        dist = dists[close_pairs_idx]

        try:
            perturb_dist = nx.shortest_path_length(graph, source, target, 'perturb_radius')
        except nx.NetworkXNoPath:
            perturb_dist = np.inf

        target_perturb_dist = dist / float(perturb_factor)
        actual_perturb_dist = max(0, np.ceil(target_perturb_dist))
        if perturb_dist >= target_perturb_dist * tolerance:
            graph.add_edge(source,
                           target,
                           perturb_radius=int(actual_perturb_dist))
    return graph


def adjust_edge_perturb_radii(frcs,
                              graph,
                              perturb_factor=2):
    """Returns a new graph where the 'perturb_radius' has been adjusted to account for 
    rounding errors. See train_image for parameters and returns.
    """
    graph = graph.copy()
    
    total_rounding_error = 0
    for n1, n2 in nx.edge_dfs(graph):
        desired_radius = distance.euclidean(frcs[n1, 1:], frcs[n2, 1:]) / perturb_factor

        upper = int(np.ceil(desired_radius))
        lower = int(np.floor(desired_radius))
        round_up_error = total_rounding_error + upper - desired_radius
        round_down_error = total_rounding_error + lower - desired_radius
        if abs(round_up_error) < abs(round_down_error):
            graph.edge[n1][n2]['perturb_radius'] = upper
            total_rounding_error = round_up_error
        else:
            graph.edge[n1][n2]['perturb_radius'] = lower
            total_rounding_error = round_down_error
    return graph

==================    

inference    

==================    

In [69]:
class RCNInferenceError(Exception):
    """Raise for general errors in RCN inference."""
    pass


def test_image(img, model_factors,
               pool_shape=(25, 25), num_candidates=20, n_iters=300, damping=1.0):
    """
    Main function for testing on one image.

    Parameters
    ----------
    img : 2D numpy.ndarray
        The testing image.
    model_factors : ([numpy.ndarray], [numpy.ndarray], [networkx.Graph])
        ([frcs], [edge_factors], [graphs]), output of train_image in learning.py.
    pool_shape : (int, int)
        Vertical and horizontal pool shapes.
    num_candidates : int
        Number of top candidates for backward-pass inference.
    n_iters : int
        Maximum number of loopy BP iterations.
    damping : float
        Damping parameter for loopy BP.

    Returns
    -------
    winner_idx : int
        Training index of the winner feature.
    winner_score : float
        Score of the winning feature.
    """
    # Get bottom-up messages from the pre-processing layer
    preproc_layer = Preproc(cross_channel_pooling=True)
    bu_msg = preproc_layer.fwd_infer(img)

    # Forward pass inference
    fp_scores = np.zeros(len(model_factors[0]))
    for i, (frcs, _, graph) in enumerate(izip(*model_factors)):
        fp_scores[i] = forward_pass(frcs,
                                    bu_msg,
                                    graph,
                                    pool_shape)
    top_candidates = np.argsort(fp_scores)[-num_candidates:]

    # Backward pass inference
    winner_idx, winner_score = (-1, -np.inf)  # (training feature idx, score)
    for idx in top_candidates:
        frcs, edge_factors = model_factors[0][idx], model_factors[1][idx]
        rcn_inf = LoopyBPInference(bu_msg, frcs, edge_factors, pool_shape, preproc_layer,
                                   n_iters=n_iters, damping=damping)
        score = rcn_inf.bwd_pass()
        if score >= winner_score:
            winner_idx, winner_score = (idx, score)
    return winner_idx, winner_score


def forward_pass(frcs, bu_msg, graph, pool_shape):
    """
    Forward pass inference using a tree-approximation (cf. Sec S4.2).

    Parameters
    ----------
    frcs : numpy.ndarray of numpy.int
        Nx3 array of (feature idx, row, column), where each row represents a
        single pool center.
    bu_msg : 3D numpy.ndarray of float
        The bottom-up messages from the preprocessing layer.
        Shape is (num_feats, rows, cols)
    graph : networkx.Graph
        An undirected graph whose edges describe the pairwise constraints between
        the pool centers.
        The tightness of the constraint is in the 'perturb_radius' edge attribute.
    pool_shape : (int, int)
        Vertical and horizontal pool shapes.

    Returns
    -------
    fp_score : float
        Forward pass score.
    """
    height, width = bu_msg.shape[-2:]
    # Vertical and horizontal pool shapes
    vps, hps = pool_shape

    def _pool_slice(f, r, c):
        assert (r - vps // 2 >= 0 and r + vps - vps // 2 < height and
                c - hps // 2 >= 0 and c + hps - hps // 2 < width), \
            "Some pools are out of the image boundaries. "\
            "Consider increase image padding or reduce pool shapes."
        return np.s_[f,
                     r - vps // 2: r + vps - vps // 2,
                     c - hps // 2: c + hps - hps // 2]

    # Find a schedule to compute the max marginal for the most constrained tree
    tree_schedule = get_tree_schedule(frcs, graph)

    # If we're sending a message out from x to y, it means x has received all
    # incoming messages
    incoming_msgs = {}
    for source, target, perturb_radius in tree_schedule:
        msg_in = bu_msg[_pool_slice(*frcs[source])]
        if source in incoming_msgs:
            msg_in = msg_in + incoming_msgs[source]
            del incoming_msgs[source]
        msg_in = dilate_2d(msg_in, (2 * perturb_radius + 1, 2 * perturb_radius + 1))
        if target in incoming_msgs:
            incoming_msgs[target] += msg_in
        else:
            incoming_msgs[target] = msg_in
    fp_score = np.max(incoming_msgs[tree_schedule[-1, 1]] +
                      bu_msg[_pool_slice(*frcs[tree_schedule[-1, 1]])])
    return fp_score


def get_tree_schedule(frcs, graph):
    """
    Find the most constrained tree in the graph and returns which messages to compute
    it.  This is the minimum spanning tree of the perturb_radius edge attribute.

    See forward_pass for parameters.

    Returns
    -------
    tree_schedules : numpy.ndarray of numpy.int
        Describes how to compute the max marginal for the most constrained tree.
        Nx3 2D array of (source pool_idx, target pool_idx, perturb radius), where
        each row represents a single outgoing factor message computation.
    """
    min_tree = nx.minimum_spanning_tree(graph, 'perturb_radius')
    return np.array([(target, source, graph.edge[source][target]['perturb_radius'])
                     for source, target in nx.dfs_edges(min_tree)])[::-1]


class LoopyBPInference(object):
    """Max-product loopy belief propagation for a two-level RCN model (cf. Sec S4.4).

    Attributes
    ----------
    n_feats, n_rows, n_cols : int, int, int
        Number of features in preprocessing layer, image height, image width.
    n_pools : int
        Number of pools in the model.
    n_factors : int
        Number of edge factors in the model.
    vps, hps : int, int
        Horizontal and vertical pool shape.
    unary_messages : numpy.array
        Unary messages to each variable, obtained by cropping the receptive fields
        from bu_msg. Shape is (n_pools x vps x hps).
    lat_messages : numpy.array
        Lateral message matrix, shape is (n_pools, n_pools, vps, hps). Element
        (v1, v2, r, c) contains the message from v1 to v2, precisely the
        (unnormalized) log-message of pool v2 being in state r, c.
    """

    def __init__(self, bu_msg, frcs, edge_factors, pool_shape, preproc_layer,
                 n_iters=300, damping=1.0, tol=1e-5):
        """
        Parameters
        ----------
        bu_msg : numpy.array of float
            Bottom up messages from preprocessing layer, in the following format:
            (feature idx, row, col).
        frcs : np.ndarray of np.int
            Nx3 array of (feature idx, row, column), where each row represents a
            single pool center.
        edge_factors : numpy.ndarray of numpy.int
            Nx3 array of (source pool index, target pool index, perturb_radius), where
            each row is a pairwise constraints on a pair of pool choices.
        pool_shape : (int, int)
            Vertical and horizontal pool shapes.
        preproc_layer : Preproc
            Pre-processing layer. See preproc.py.
        n_iters : int
            Maximum number of loopy BP iterations.
        damping : float
            Damping parameter for loopy BP.
        tol : float
            Tolerance to determine loopy BP convergence.

        Raises
        ------
        RCNInferenceError
        """
        self.n_feats, self.n_rows, self.n_cols = bu_msg.shape
        self.n_pools, self.n_factors = frcs.shape[0], edge_factors.shape[0]
        self.vps, self.hps = pool_shape
        self.frcs = frcs
        self.bu_msg = bu_msg
        self.edge_factors = edge_factors
        self.preproc_layer = preproc_layer
        self.n_iters = n_iters
        self.damping = damping
        self.tol = tol

        # Check inputs
        if (np.array([0, self.vps // 2, self.hps // 2]) > frcs.min(0)).any():
            raise RCNInferenceError("Some frcs are too small for the provided pool shape")
        if (frcs.max(0) >= np.array([self.n_feats,
                                    self.n_rows - ((self.vps - 1) // 2),
                                    self.n_cols - ((self.hps - 1) // 2)])).any():
            raise RCNInferenceError("Some frcs are too big for the provided pool "
                                    "shape and/or `bu_msg`")
        if (edge_factors[:, :2].min(0) < np.array([0, 0])).any():
            raise RCNInferenceError("Some variable index in `edge_factors` is negative")
        if (edge_factors[:, :2].max(0) >= np.array([self.n_pools, self.n_pools])).any():
            raise RCNInferenceError("Some index in `edge_factors` exceeds the number of vars")
        if (edge_factors[:, 0] == edge_factors[:, 1]).any():
            raise RCNInferenceError("Some factor connects a variable to itself")
        if not issubclass(edge_factors.dtype.type, np.integer):
            raise RCNInferenceError("Factors should be an integer numpy array")

        # Initialize message
        self._reset_messages()
        self.unary_messages = np.zeros((self.n_pools, self.vps, self.hps))
        bu_msg_pert = self.bu_msg + 0.01 * (2 * rand(*bu_msg.shape) - 1)
        for i, (f, r, c) in enumerate(self.frcs):
            rstart = r - self.vps // 2
            cstart = c - self.hps // 2
            self.unary_messages[i] = bu_msg_pert[f,
                                                 rstart:rstart + self.vps,
                                                 cstart:cstart + self.hps]

    def _reset_messages(self):
        """Set all lateral messages to zero."""
        self.lat_messages = np.zeros((2, self.n_factors, self.vps, self.hps))

    @staticmethod
    def compute_1pl_message(in_mess, pert_radius):
        """Compute the outgoing message of a lateral factor given the
        perturbation radius and input message.

        Parameters
        ----------
        in_mess : numpy.array
            Input BP messages to the factor. Each message has shape vps x hps.
        pert_radius : int
            Perturbation radius corresponding to the factor.

        Returns
        -------
        out_mess : numpy.array
            Output BP message (at the opposite end of the factor from the input message).
            Shape is (vps, hps).
        """
        pert_diameter = 2 * pert_radius + 1
        out_mess = dilate_2d(in_mess, (pert_diameter, pert_diameter))
        return out_mess - out_mess.max()

    def new_messages(self):
        """Compute updated set of lateral messages (in both directions).

        Returns
        -------
        new_lat_messages : numpy.array
            Updated set of lateral messages. Shape is (2, n_factors, vps x hps).
        """
        # Compute beliefs
        beliefs = self.unary_messages.copy()
        for f, (var_i, var_j, pert_radius) in enumerate(self.edge_factors):
            beliefs[var_j] += self.lat_messages[0, f]
            beliefs[var_i] += self.lat_messages[1, f]

        # Compute outgoing messages
        new_lat_messages = np.zeros_like(self.lat_messages)
        for f, (var_i, var_j, pert_radius) in enumerate(self.edge_factors):
            new_lat_messages[0, f] = self.compute_1pl_message(
                beliefs[var_i] - self.lat_messages[1, f], pert_radius)
            new_lat_messages[1, f] = self.compute_1pl_message(
                beliefs[var_j] - self.lat_messages[0, f], pert_radius)
        return new_lat_messages

    def bwd_pass(self):
        """Perform max-product loopy BP inference and decode the max-marginals.

        Returns
        -------
        score : float
            The score of the backtraced solution, adjusted for filter overlapping.
        """
        self._reset_messages()
        # Loopy BP with parallel updates
        self.infer_pbp()
        # Decode the max-marginals
        assignments, backtrace_positions, score = self.decode()
        # Check constraints are satisfied
        if not self.laterals_are_satisfied(assignments):
            print("Lateral constraints not satisfied. Try increasing the "
                      "number of iterations.")
            score = -np.inf
        return score

    def infer_pbp(self):
        """Parallel loopy BP message passing, modifying state of `lat_messages`.
        See bwd_pass() for parameters.
        """
        for it in xrange(self.n_iters):
            new_lat_messages = self.new_messages()
            delta = new_lat_messages - self.lat_messages
            self.lat_messages += self.damping * delta
            if np.abs(delta).max() < self.tol:
                print("Parallel loopy BP converged in {} iterations".format(it))
                return
        print("Parallel loopy BP didn't converge in {} iterations".format(self.n_iters))

    def decode(self):
        """Find pool assignments by decoding the max-marginal messages.

        Returns
        -------
        assignments : 2D numpy.ndarray of int
            Each row is the row and column assignments for each pool.
        backtrace_positions : 3D numpy.ndarray of int
            Sparse top-down activations in the form of (f,r,c).
        score : float
            Sum of log-likelihoods collected by the decoded pool assignments.
        """
        # Compute beliefs
        beliefs = self.unary_messages.copy()
        for f, (var_i, var_j, pert_radius) in enumerate(self.edge_factors):
            beliefs[var_j] += self.lat_messages[0, f]
            beliefs[var_i] += self.lat_messages[1, f]

        assignments = np.zeros((self.n_pools, 2), dtype=np.int)
        backtrace = np.zeros((self.n_feats, self.n_rows, self.n_cols))
        for i, (f, r, c) in enumerate(self.frcs):
            r_max, c_max = np.where(beliefs[i] == beliefs[i].max())
            choice = randint(len(r_max))
            assignments[i] = np.array([r_max[choice], c_max[choice]])
            rstart = r - self.vps // 2
            cstart = c - self.hps // 2
            backtrace[f,
                      rstart + assignments[i, 0],
                      cstart + assignments[i, 1]] = 1
        backtrace_positions = np.transpose(np.nonzero(backtrace))
        score = recount(backtrace_positions, self.bu_msg, self.preproc_layer.pos_filters)
        return assignments, backtrace_positions, score

    def laterals_are_satisfied(self, assignments):
        """Check whether pool assignments satisfy all lateral constraints.

        Parameters
        ----------
        assignments : 2D numpy.ndarray of int
            Row and column assignments for each pool.

        Returns
        -------
        satisfied : bool
            Whether the pool assignments satisfy all lateral constraints.
        """
        satisfied = True
        for f, (var_i, var_j, pert_radius) in enumerate(self.edge_factors):
            rdist, cdist = np.abs(assignments[var_i] - assignments[var_j])
            if not (rdist <= pert_radius and cdist <= pert_radius):
                satisfied = False
                break
        return satisfied


def recount(backtrace_positions, bu_msg, filters):
    """
    Post-processing step to prevent overcounting of log-likelihoods (cf. Sec S8.2).

    Parameters
    ----------
    backtrace_positions : 3D numpy.ndarray of int
        Sparse top-down activations in the format of (f,r,c).
    bu_msg : 3D numpy.ndarray of int
        Bottom-up messages after the pre-processing layer.
    filters : [2D numpy.ndarray of float]
        Filter bank used in the pre-processing layer.

    Returns
    -------
    normalized_score : float
        Score normalized by taking filter overlaps into account.

    Raises
    ------
    RCNInferenceError
    """
    height, width = bu_msg.shape[-2:]
    f_h, f_w = filters[0].shape
    layers = np.zeros((len(backtrace_positions), height, width))
    fo_h, fo_w = f_h // 2, f_w // 2
    from_r, to_r = (np.maximum(0, backtrace_positions[:, 1] - fo_h),
                    np.minimum(height, backtrace_positions[:, 1] - fo_h + f_h))
    from_c, to_c = (np.maximum(0, backtrace_positions[:, 2] - fo_w),
                    np.minimum(width, backtrace_positions[:, 2] - fo_w + f_w))
    from_fr, to_fr = (np.maximum(0, fo_h - backtrace_positions[:, 1]),
                      np.minimum(f_h, height - backtrace_positions[:, 1] + fo_h))
    from_fc, to_fc = (np.maximum(0, fo_w - backtrace_positions[:, 2]),
                      np.minimum(f_w, width - backtrace_positions[:, 2] + fo_w))

    if not np.all(to_r - from_r == to_fr - from_fr):
        raise RCNInferenceError("Numbers of rows of filter and image patches "
                                "({}, {}) do not agree".format(
                                    to_r - from_r, to_fr - from_fr))
    if not np.all(to_c - from_c == to_fc - from_fc):
        raise RCNInferenceError("Numbers of columns of filter and image patches "
                                "({}, {}) do not agree".format(
                                    to_c - from_c, to_fc - from_fc))

    # Normalize activations by taking into account filter overlaps
    weight_sum = np.zeros((height, width))
    for i, (f, r, c) in enumerate(backtrace_positions):
        # Convolve sparse top-down activations with filters
        filt = filters[f][from_fr[i]:to_fr[i], from_fc[i]:to_fc[i]]

        weight_sum[from_r[i]:to_r[i], from_c[i]:to_c[i]] += filt
        layers[i, from_r[i]:to_r[i], from_c[i]:to_c[i]] = \
            filt**2 * bu_msg[f, r, c] / (1e-9 + filt.sum())
    normalized_score = (layers.sum(0) / (1e-9 + weight_sum)).sum()
    return normalized_score

==================    

preproc    

==================    

In [70]:
class Preproc(object):
    """
    A simplified preprocessing layer implementing Gabor filters and suppression.

    Parameters
    ----------
    num_orients : int
        Number of edge filter orientations (over 2pi).
    filter_scale : float
        A scale parameter for the filters.
    cross_channel_pooling : bool
        Whether to pool across neighboring orientation channels (cf. Sec S8.1.4).

    Attributes
    ----------
    filters : [numpy.ndarray]
        Kernels for oriented Gabor filters.
    pos_filters : [numpy.ndarray]
        Kernels for oriented Gabor filters with all-positive values.
    suppression_masks : numpy.ndarray
        Masks for oriented non-max suppression.
    """

    def __init__(self,
                 num_orients=16,
                 filter_scale=4.,
                 cross_channel_pooling=False):
        self.num_orients = num_orients
        self.filter_scale = filter_scale
        self.cross_channel_pooling = cross_channel_pooling
        self.suppression_masks = generate_suppression_masks(filter_scale=filter_scale, 
                                                            num_orients=num_orients)

    def fwd_infer(self, img, brightness_diff_threshold=40.):
        """Compute bottom-up (forward) inference.

        Parameters
        ----------
        img : numpy.ndarray
            The input image.
        brightness_diff_threshold : float
            Brightness difference threshold for oriented edges.

        Returns
        -------
        bu_msg : 3D numpy.ndarray of float
            The bottom-up messages from the preprocessing layer. 
            Shape is (num_feats, rows, cols)
        """
        filtered = np.zeros((len(self.filters),) + img.shape, dtype=np.float32)
        for i, kern in enumerate(self.filters):
            filtered[i] = fftconvolve(img, kern, mode='same')
        localized = local_nonmax_suppression(filtered, self.suppression_masks)
        # Threshold and binarize
        localized *= (filtered / brightness_diff_threshold).clip(0, 1)
        localized[localized < 1] = 0

        if self.cross_channel_pooling:
            pooled_channel_weights = [(0, 1), (-1, 1), (1, 1)]
            pooled_channels = [-np.ones_like(sf) for sf in localized]
            for i, pc in enumerate(pooled_channels):
                for channel_offset, factor in pooled_channel_weights:
                    ch = (i + channel_offset) % self.num_orients
                    pos_chan = localized[ch]
                    if factor != 1:
                        pos_chan[pos_chan > 0] *= factor
                    np.maximum(pc, pos_chan, pc)
            bu_msg = np.array(pooled_channels)
        else:
            bu_msg = localized
        # Setting background to -1
        bu_msg[bu_msg == 0] = -1.
        return bu_msg

    @property
    def filters(self):
        return get_gabor_filters(
            filter_scale=self.filter_scale, num_orients=self.num_orients, weights=False)

    @property
    def pos_filters(self):
        return get_gabor_filters(
            filter_scale=self.filter_scale, num_orients=self.num_orients, weights=True)


def get_gabor_filters(size=21, filter_scale=4., num_orients=16, weights=False):
    """Get Gabor filter bank. See Preproc for parameters and returns."""
    def _get_sparse_gaussian():
        """Sparse Gaussian."""
        size = 2 * np.ceil(np.sqrt(2.) * filter_scale) + 1
        alt = np.zeros((int(size), int(size)), np.float32)
        alt[int(size // 2), int(size // 2)] = 1
        gaussian = gaussian_filter(alt, filter_scale / np.sqrt(2.), mode='constant')
        gaussian[gaussian < 0.05 * gaussian.max()] = 0
        return gaussian

    gaussian = _get_sparse_gaussian()
    filts = []
    for angle in np.linspace(0., 2 * np.pi, num_orients, endpoint=False):
        acts = np.zeros((size, size), np.float32)
        x, y = np.cos(angle) * filter_scale, np.sin(angle) * filter_scale
        acts[int(size / 2 + y), int(size / 2 + x)] = 1.
        acts[int(size / 2 - y), int(size / 2 - x)] = -1.
        filt = fftconvolve(acts, gaussian, mode='same')
        filt /= np.abs(filt).sum()  # Normalize to ensure the maximum output is 1
        if weights:
            filt = np.abs(filt)
        filts.append(filt)
    return filts


def generate_suppression_masks(filter_scale=4., num_orients=16):
    """
    Generate the masks for oriented non-max suppression at the given filter_scale.
    See Preproc for parameters and returns.
    """
    size = 2 * int(np.ceil(filter_scale * np.sqrt(2))) + 1
    cx, cy = size // 2, size // 2
    filter_masks = np.zeros((num_orients, size, size), np.float32)
    # Compute for orientations [0, pi), then flip for [pi, 2*pi)
    for i, angle in enumerate(np.linspace(0., np.pi, num_orients // 2, endpoint=False)):
        x, y = np.cos(angle), np.sin(angle)
        for r in xrange(1, int(np.sqrt(2) * size / 2)):
            dx, dy = round(r * x), round(r * y)
            if abs(dx) > cx or abs(dy) > cy:
                continue
            filter_masks[i, int(cy + dy), int(cx + dx)] = 1
            filter_masks[i, int(cy - dy), int(cx - dx)] = 1
    filter_masks[num_orients // 2:] = filter_masks[:num_orients // 2]
    return filter_masks


def local_nonmax_suppression(filtered, suppression_masks, num_orients=16):
    """
    Apply oriented non-max suppression to the filters, so that only a single 
    orientated edge is active at a pixel. See Preproc for additional parameters.

    Parameters
    ----------
    filtered : numpy.ndarray
        Output of filtering the input image with the filter bank.
        Shape is (num feats, rows, columns).

    Returns
    -------
    localized : numpy.ndarray
        Result of oriented non-max suppression.
    """
    localized = np.zeros_like(filtered)
    cross_orient_max = filtered.max(0)
    filtered[filtered < 0] = 0
    for i, (layer, suppress_mask) in enumerate(zip(filtered, suppression_masks)):
        competitor_maxs = maximum_filter(layer, footprint=suppress_mask, mode='nearest')
        localized[i] = competitor_maxs <= layer
    localized[cross_orient_max > filtered] = 0
    return localized

==================    

run    

==================    

In [71]:
def run_experiment(data_dir='data/MNIST',
                   train_size=20,
                   test_size=20,
                   full_test_set=False,
                   pool_shape=(25, 25),
                   perturb_factor=2.,
                   parallel=True,
                   verbose=False,
                   seed=5):

    # Multiprocessing set up
    num_workers = None if parallel else 1
    pool = Pool(num_workers)

    train_data, test_data = get_mnist_data_iters(
        data_dir, train_size, test_size, full_test_set, seed=seed)

    train_partial = partial(train_image,
                            perturb_factor=perturb_factor)
    train_results = pool.map_async(train_partial, [d[0] for d in train_data]).get(9999999)
    all_model_factors = zip(*train_results)

    test_partial = partial(test_image, model_factors=all_model_factors,
                           pool_shape=pool_shape)
    test_results = pool.map_async(test_partial, [d[0] for d in test_data]).get(9999999)

    # Evaluate result
    correct = 0
    for test_idx, (winner_idx, _) in enumerate(test_results):
        correct += int(test_data[test_idx][1]) == winner_idx // (train_size // 10)
    print "Total test accuracy = {}".format(float(correct) / len(test_results))

    return all_model_factors, test_results


def get_mnist_data_iters(data_dir, train_size, test_size,
                         full_test_set=False, seed=5):
    """
    Load MNIST data.

    Assumed data directory structure:
        training/
            0/
            1/
            2/
            ...
        testing/
            0/
            ...

    Parameters
    ----------
    train_size, test_size : int
        MNIST dataset sizes are in increments of 10
    full_test_set : bool
        Test on the full MNIST 10k test set.
    seed : int
        Random seed used by numpy.random for sampling training set.

    Returns
    -------
    train_data, train_data : [(numpy.ndarray, str)]
        Each item reps a data sample (2-tuple of image and label)
        Images are numpy.uint8 type [0,255]
    """
    if not os.path.isdir(data_dir):
        raise IOError("Can't find your data dir '{}'".format(data_dir))

    def _load_data(image_dir, num_per_class, get_filenames=False):
        loaded_data = []
        for category in sorted(os.listdir(image_dir)):
            cat_path = os.path.join(image_dir, category)
            if not os.path.isdir(cat_path) or category.startswith('.'):
                continue
            if num_per_class is None:
                samples = sorted(os.listdir(cat_path))
            else:
                samples = np.random.choice(sorted(os.listdir(cat_path)), num_per_class)

            for fname in samples:
                filepath = os.path.join(cat_path, fname)
                # Resize and pad the images to (200, 200)
                image_arr = imresize(imread(filepath), (112, 112))
                img = np.pad(image_arr,
                             pad_width=tuple([(p, p) for p in (44, 44)]),
                             mode='constant', constant_values=0)
                loaded_data.append((img, category))
        return loaded_data

    np.random.seed(seed)
    train_set = _load_data(os.path.join(data_dir, 'training'),
                           num_per_class=train_size // 10)
    test_set = _load_data(os.path.join(data_dir, 'testing'),
                          num_per_class=None if full_test_set else test_size // 10)
    return train_set, test_set

In [74]:
#debug mode if this flag is set (default: False)
debug = False
#Number of training examples.
train_size = 10
#Number of testing examples.
test_size = 10
#Test on full MNIST test set.
full_test_set = False
#Pool shape.
pool_shape = 25
#Perturbation factor.
perturb_factor = 2
#Seed for numpy.random to sample training and testing dataset split.
seed = 5
#Parallelize over multi-CPUs if True.
parallel = True
#Verbosity level.
verbose = False

run_experiment(train_size=train_size,
               test_size=test_size,
               full_test_set=full_test_set,
               pool_shape=(pool_shape, pool_shape),
               perturb_factor=perturb_factor,
               seed=seed,
               verbose=verbose,
               parallel=parallel)

`imread` is deprecated in SciPy 1.0.0.
Use ``matplotlib.pyplot.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
('learn_laterals 1', False)
Parallel loopy BP converged in 38 iterations
Parallel loopy BP converged in 47 iterations
Parallel loopy BP converged in 31 iterations
Parallel loopy BP converged in 72 iterations
Parallel loopy BP converged in 55 iterations
Parallel loopy BP converged in 40 iterations
Parallel loopy BP converged in 40 iterations
Parallel loopy BP converged in 45 iterations
Parallel loopy BP converged in 54 iterations
Parallel loopy BP converged in 66 iterations
Parallel loopy BP converged in 50 iterations
Parallel loopy BP didn't converge in 300 iterations
Parallel loopy BP converged in 52 iterations
Parallel loopy BP converged in 34 iterations
Parallel loopy BP converged in 36 iterations
Parallel loopy BP didn't converge in 3

([(array([[ 11,  63, 110],
          [ 13,  63, 118],
          [ 11,  64, 106],
          [ 12,  64, 114],
          [ 13,  65, 122],
          [ 11,  66, 101],
          [ 13,  67, 126],
          [ 11,  68,  97],
          [ 13,  70, 130],
          [ 10,  71,  93],
          [ 14,  72, 134],
          [ 10,  74,  89],
          [ 15,  76, 136],
          [ 10,  78,  84],
          [  4,  78, 119],
          [  2,  79, 115],
          [  6,  79, 123],
          [  0,  80, 138],
          [ 10,  82,  80],
          [  2,  83, 111],
          [  7,  83, 125],
          [  0,  84, 138],
          [ 10,  86,  77],
          [  4,  87,  96],
          [  2,  87, 107],
          [  8,  87, 126],
          [  0,  88, 138],
          [  3,  89,  92],
          [  9,  90,  75],
          [  6,  91,  98],
          [  2,  91, 105],
          [  8,  91, 126],
          [  2,  92,  88],
          [  0,  92, 138],
          [  9,  94,  74],
          [  7,  95,  97],
          [  1,  95, 106],
 