# Switch and Stack Development

## Switch Development

Let's say the switch is made up of individual `filters` that compute the residual and apply the weighting.

Do we apply a non-linearity to the weighted output?

In [None]:
def norm(residual):
    """Return the summed absolute residual divided by its first-axis length.

    For a 1D array (or an ``(n, 1)`` column) this equals the mean absolute
    value of the residual.

    NOTE(review): for an ``(n, m)`` array with ``m > 1`` the total is still
    divided by ``n`` only, so it is not the element-wise mean; confirm intent.

    Arg:
        residual - numpy array.
    """
    magnitudes = np.absolute(residual)
    return magnitudes.sum() / magnitudes.shape[0]


In [None]:
class SwitchFilter:
    """A single switch component paired with one stage.

    Each filter holds a scalar weight, remembers the latest residual and keeps
    a rolling history of residual norms.
    """

    def __init__(self, input_len, norm_buffer_len=30):
        """Set up the weight, residual store and norm history.

        Args:
            input_len - integer setting the input / prediction array length.
            norm_buffer_len - number of past norm values kept in the history
                buffer (its column count).
        """
        # Multiplier applied to the residual in iterate(); starts at unity.
        self.weight = 1
        # Latest residual; starts as an all-zero (input_len, 1) array.
        # NOTE(review): get_residual() replaces it with input_data - pred_input,
        # whose shape follows the inputs rather than (input_len, 1).
        self.residual = np.zeros(shape=(input_len, 1))
        # Norm history: one column per step, newest value in the last column.
        self.norm_buffer = np.zeros(shape=(input_len, norm_buffer_len))

    def get_residual(self, input_data, pred_input):
        """Store and return the difference between input and prediction.

        The residual's norm is also written into the last column of
        ``norm_buffer`` after the existing columns are shifted one step left.

        Args:
            input_data and pred_input are 1D numpy arrays of the same size.

        Returns:
            The residual ``input_data - pred_input``.
        """
        self.residual = input_data - pred_input
        # Shift the history left by one column; the old first column wraps
        # around to the end and is overwritten on the next line.
        self.norm_buffer = np.roll(self.norm_buffer, shift=-1, axis=1)
        # norm() yields a single value, so every row of the newest column
        # receives that same value.
        latest_norm = norm(self.residual).flatten()
        self.norm_buffer[..., -1] = latest_norm
        return self.residual

    def post_processing(self, weighted):
        """Hook for extra processing of the weighted residual.

        Currently returns ``weighted`` unchanged.
        """
        return weighted

    def iterate(self, input_data, pred_input):
        """Run one step: residual, weighting, then post-processing.

        Args:
            input_data and pred_input are 1D numpy arrays of the same size.

        Returns:
            The post-processed, weighted residual.
        """
        current = self.get_residual(input_data, pred_input)
        return self.post_processing(self.weight * current)

In [None]:
class Switch:
    """Manage a chain of stacks fed from a single sensor.

    Models some thalamus function.
    """

    def __init__(self, sensor, stacks):
        """Store the sensor and stacks and reserve one result slot per stack.

        Args:
            sensor - a Sensor object to provide input data.
            stacks - a list of Stack objects.
        """
        self.sensor = sensor
        self.stacks = stacks
        # Per-stack residual slots, filled in by iterate().
        self.residuals = [None for _stack in stacks]
        # Per-stack norm slots, filled in by iterate().
        self.norms = [None for _stack in stacks]

    def iterate(self):
        """Run every stack once, feeding each the previous stack's residual.

        The first stack receives a fresh sensor frame. For each stack, the
        residual (its input minus its predicted output) and that residual's
        norm are stored at the stack's index in ``self.residuals`` and
        ``self.norms``. The residual then becomes the next stack's input.

        Returns:
            residuals - list of residuals.
            norms - list of norms for each residual.
        """
        signal = self.sensor.get_frame()
        for index, stack in enumerate(self.stacks):
            _, prediction = stack.iterate(signal, None)
            # Whatever the stack failed to predict is passed down the chain.
            signal = signal - prediction
            self.residuals[index] = signal
            self.norms[index] = norm(signal)
        return self.residuals, self.norms

## Brain

This may be thought of as the wrapper for the switch and the stacks.

Does each stack have its own `switch_filter`? But each … *(unfinished note)*

In [None]:
class Brain:
    """Wrapper for processing: sensors, switch filters and stacks.

    The three lists below are index-aligned: entry ``i`` of each belongs to
    the ``i``-th sensor registered with ``add_sensor``.
    """

    def __init__(self):
        """Initialise empty per-sensor collections."""
        # Input pre-processor: list of sensors.
        self.pre_processor = list()
        # Switch: one SwitchFilter per sensor.
        self.switch = list()
        # Modality processors: one list of stacks per sensor.
        self.cortex = list()

    def add_sensor(self, sensor):
        """Register a sensor together with a switch filter and a stack list.

        Args:
            sensor - object providing get_data_length() and get_frame().
        """
        # Add sensor to pre-processor
        self.pre_processor.append(sensor)
        sensor_length = sensor.get_data_length()
        # Add a switch filter sized to the sensor's data length
        switch_filter = SwitchFilter(sensor_length)
        self.switch.append(switch_filter)
        # Add an (initially empty) stack list so self.cortex stays aligned
        # with self.pre_processor; stacks are appended to it separately.
        # TODO: decide how stacks should be created for a new sensor.
        self.cortex.append(list())

    def iterate(self):
        """Perform one iteration over every registered sensor.

        For each sensor, a frame is read and passed through that sensor's
        stack list in ``self.cortex`` in order, mirroring ``Switch.iterate``.
        Each stack predicts its input, the residual (input minus prediction)
        and its norm are recorded, and the residual becomes the next stack's
        input.

        NOTE(review): the switch filters in ``self.switch`` are not applied
        here yet. How one filter per sensor should combine with several
        stacks per sensor is still an open design question.

        Returns:
            residuals - list (one entry per sensor) of lists of residuals.
            norms - list (one entry per sensor) of lists of residual norms.
        """
        # The previous body referenced self.sensor / self.stacks /
        # self.residuals / self.norms, which Brain never defines, so every
        # call raised AttributeError. Use Brain's own per-sensor state instead.
        residuals = []
        norms = []
        for sensor, stacks in zip(self.pre_processor, self.cortex):
            input_data = sensor.get_frame()
            sensor_residuals = []
            sensor_norms = []
            for stack in stacks:
                # Iterate stack
                _, pred_out = stack.iterate(input_data, None)
                # Compute residual and its norm
                residual = input_data - pred_out
                sensor_residuals.append(residual)
                sensor_norms.append(norm(residual))
                # Set input data for next stack as residual
                input_data = residual
            residuals.append(sensor_residuals)
            norms.append(sensor_norms)
        return residuals, norms