# Building foveated deep vision models based on kNN-convolution

Next, we dive into perceptual processing of our foveated sensor outputs, making use of the sensor manifold.

For this, we use k-nearest-neighbor (kNN) receptive fields. In 2-D, receptive fields are specified as $(h,w)$ rectangular grids; on our 3-D manifold, each receptive field is instead specified as the $k$ nearest neighbors of its center.
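To make the correspondence concrete, here is a minimal NumPy sketch (not fovi's implementation) that finds each output point's $k$ nearest input points by brute-force Euclidean distance. On a regular 2-D grid, the 9 nearest neighbors of an interior point are exactly its 3×3 patch, which is the sense in which a kNN receptive field generalizes an $(h,w)$ kernel. The function name `knn_indices` is our own.

```python
import numpy as np

def knn_indices(in_coords: np.ndarray, out_coords: np.ndarray, k: int) -> np.ndarray:
    """Return, for every output point, the indices of its k nearest input points.

    in_coords: (N_in, D) array, out_coords: (N_out, D) array.
    Result: (N_out, k) integer array, nearest neighbor first.
    """
    # Pairwise squared Euclidean distances, shape (N_out, N_in)
    d2 = ((out_coords[:, None, :] - in_coords[None, :, :]) ** 2).sum(axis=-1)
    # A stable argsort keeps ties in input order, so results are deterministic
    return np.argsort(d2, axis=1, kind="stable")[:, :k]

# A 5x5 Cartesian grid of input points, flattened row-major: index = 5*row + col
ys, xs = np.meshgrid(np.arange(5), np.arange(5), indexing="ij")
grid = np.stack([ys.ravel(), xs.ravel()], axis=1).astype(float)

# The center pixel (row 2, col 2) as the single output point
center = np.array([[2.0, 2.0]])
neighbors = knn_indices(grid, center, k=9)

# The 9 nearest neighbors are exactly the 3x3 patch around the center
patch = sorted(5 * r + c for r in range(1, 4) for c in range(1, 4))
print(sorted(neighbors[0].tolist()) == patch)  # True
```

Because only distances are used, the same function works unchanged for coordinates embedded in 3-D, and since the sampling coordinates are fixed, the neighbor indices only need to be computed once rather than on every forward pass.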

The details are described in the paper. Here, we go through the relevant code modules to see how to build networks based on kNN-convolution on the foveated sensor manifold.

We will now start looking at `fovi.arch`, where all of the architectural features relevant to foveated perceptual processing live. 

The building-block layers live in `fovi.arch.knn`. Let's use them to build a very simple one-layer convolutional network: a convolution layer followed by a pooling layer, a ReLU, and normalization.
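Conceptually, kNN max pooling takes, for each output coordinate, the channel-wise maximum over its $k$ nearest input coordinates. Below is a minimal NumPy sketch on a `[batch, channels, coords]` array, assuming a precomputed `(N_out, k)` array of neighbor indices; `knn_max_pool` is a hypothetical name, not part of `fovi.arch.knn`.

```python
import numpy as np

def knn_max_pool(x: np.ndarray, neighbor_idx: np.ndarray) -> np.ndarray:
    """Max-pool x over kNN neighborhoods.

    x: (batch, channels, N_in) features on the input coordinates.
    neighbor_idx: (N_out, k) indices into the N_in input coordinates.
    Returns (batch, channels, N_out).
    """
    # Fancy indexing gathers neighbors: shape (batch, channels, N_out, k)
    gathered = x[:, :, neighbor_idx]
    # Reduce over the neighborhood axis
    return gathered.max(axis=-1)

# Toy example: 1 batch element, 2 channels, 4 input coordinates
x = np.array([[[1.0, 5.0, 2.0, 0.0],
               [3.0, -1.0, 4.0, 7.0]]])
# Two output coordinates with k=2 neighbors each (hypothetical indices)
neighbor_idx = np.array([[0, 1],
                         [2, 3]])
print(knn_max_pool(x, neighbor_idx))
# [[[5. 2.]
#   [3. 7.]]]
```

A kNN convolution gathers neighborhoods the same way and then applies learned weights instead of a max; how `KNNConvLayer` assigns weights to neighbors (via its reference frame) is described in the paper and not reproduced here.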

In [1]:
%load_ext autoreload
%autoreload 2

from fovi.arch.knn import KNNPoolingLayer, KNNConvLayer, get_in_out_coords
from fovi.arch.norm import KNNBatchNorm
import torch.nn as nn

fov = 16
cmf_a = 0.5
device = 'cpu'

cartesian_res = 64
conv_kernel_cartesian = [7, 7]
pool_kernel_cartesian = [3, 3]
conv_stride = 1
pool_stride = 2
channels = 128

# determine neighborhood sizes based on target cartesian kernels
k_conv = conv_kernel_cartesian[0]*conv_kernel_cartesian[1]
k_pool = pool_kernel_cartesian[0]*pool_kernel_cartesian[1]

# set up coordinates based on input resolution and strides
in_cart_res = cartesian_res
sensor_coords, conv_coords, out_cart_res = get_in_out_coords(in_cart_res, fov, cmf_a, conv_stride, in_cart_res=in_cart_res, device=device)
# the previous layer is the input to the next layer
in_cart_res = out_cart_res
_, pool_coords, _ = get_in_out_coords(in_cart_res, fov, cmf_a, pool_stride, in_cart_res=in_cart_res, in_coords=conv_coords, device=device)

conv_layer = KNNConvLayer(3, channels, k_conv, sensor_coords, conv_coords, 
                          ref_frame_side_length=2*conv_kernel_cartesian[0], 
                          device=device,
                          )
pool_layer = KNNPoolingLayer(k_pool, conv_coords, pool_coords, mode='max', device=device)

full_layer = nn.Sequential(
    conv_layer,
    pool_layer,
    nn.ReLU(),
    KNNBatchNorm(len(pool_coords), channels, device=device),
).to(device)



Note that the `auto_match_cart_resources=1` option (used in the `KNNAlexNetBlock` and `KNNAlexNet` examples below) attempts to match the Cartesian-equivalent resources as closely as possible, but the match cannot be exact. The default behavior is to ensure that we select less than or equal to the Cartesian-equivalent resources.

OK!

Let's create some fake data and pass it through our simple fovi layer.

In [2]:
import torch

# Our foveated sensor outputs and KNNLayer inputs/outputs are formatted as [batch, num_channels, num_coords]
x_sensor = torch.rand([64, 3, len(sensor_coords)]).to(device)

layer_output = full_layer(x_sensor)

In [3]:
layer_output.shape

torch.Size([64, 128, 964])
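The output keeps the `[batch, channels, coords]` layout. As a rough sketch of what the `KNNBatchNorm` at the end of `full_layer` does, assuming it behaves like `torch.nn.BatchNorm1d` on `(N, C, L)` inputs (an assumption, not confirmed by fovi's source), training-mode batch normalization computes statistics per channel over both the batch and coordinate axes. The NumPy version below omits the learned scale/shift and running statistics.

```python
import numpy as np

def batch_norm_bcn(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Training-mode batch norm for a (batch, channels, coords) array.

    Statistics are computed per channel over the batch and coordinate axes;
    learned scale/shift and running averages are omitted for brevity.
    """
    mean = x.mean(axis=(0, 2), keepdims=True)  # shape (1, channels, 1)
    var = x.var(axis=(0, 2), keepdims=True)    # biased variance, as in BatchNorm
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
# Smaller batch/channel counts than above to keep the example fast
x = rng.normal(loc=3.0, scale=2.0, size=(8, 16, 964))
y = batch_norm_bcn(x)
print(y.shape)                                          # (8, 16, 964)
print(np.allclose(y.mean(axis=(0, 2)), 0, atol=1e-6))  # True
```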

In practice, the `KNNAlexNetBlock` is set up to do exactly what we just did: combine conv, pooling, nonlinearity, and normalization.

In [4]:
from fovi.arch.knnalexnet import KNNAlexNetBlock

block = KNNAlexNetBlock(3, channels, k_conv, fov, cmf_a, cartesian_res, conv_stride,
                        cart_res=cartesian_res, pool=True, pool_k=k_pool, pool_stride=pool_stride,
                        norm_type='batch', auto_match_cart_resources=1, device=device, ref_frame_mult=2)

In [5]:
block_output = block(x_sensor)

block_output.shape

torch.Size([64, 128, 964])

# Building an AlexNet-like KNN model

Now that we've seen the layers and blocks, we are ready to build a complete KNNAlexNet model. For this, we simply use the `KNNAlexNet` class and refer you to the code for further detail. `KNNAlexNet` builds AlexNet-like models but is more flexible: it supports different numbers of layers, kernel sizes, channel dimensions, etc.

In [6]:
from fovi.arch.knnalexnet import KNNAlexNet

model = KNNAlexNet(
    cartesian_res, 
    3, 
    [96, 256, 384, 256, 256], # channels per layer
    [4,1,1,1,1], # conv stride per layer
    [1,4], # pool after
    [11**2, 5**2, 3**2, 3**2, 3**2], # k per layer (squared Cartesian kernel sizes 11, 5, 3, 3, 3)
    n_classes=1000,
    out_res=None, # no output pooling
    auto_match_cart_resources=1,
    norm_type='batch',
    ref_frame_mult=2,
    fov=fov,
    cmf_a=cmf_a,
    device=device,
    )

no output pooling layer


In [7]:
model

KNNAlexNet(
  (layers): ModuleList(
    (0): KNNAlexNetBlock(
      (conv): KNNConvLayer(
      	in_channels=3
      	out_channels=96
      	k=121
      	n_ref=484
      	in_coords=SamplingCoords(length=4085, fov=16, cmf_a=0.5, resolution=53, style=isotropic)
      	out_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
      	sample_cortex=True
      )
      (norm): KNNBatchNorm(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (1): KNNAlexNetBlock(
      (conv): KNNConvLayer(
      	in_channels=96
      	out_channels=256
      	k=25
      	n_ref=100
      	in_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
      	out_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
      	sample_cortex=True
      )
      (norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
     

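In the printout above, `n_ref` is consistent with a square reference frame whose side is `ref_frame_mult` times the Cartesian kernel side: 22² = 484 for k = 11², and 10² = 100 for k = 5². This mirrors `ref_frame_side_length=2*conv_kernel_cartesian[0]` in the first cell. The helper below reproduces those numbers; the relationship is inferred from the printed output rather than taken from fovi's source, and `kernel_side` and `expected_n_ref` are hypothetical names.

```python
import math

def kernel_side(k: int) -> int:
    """Side length of the square Cartesian kernel with k = side**2 taps."""
    side = math.isqrt(k)
    if side * side != k:
        raise ValueError(f"k={k} is not a perfect square")
    return side

def expected_n_ref(k: int, ref_frame_mult: int = 2) -> int:
    # Inferred from the printout: the reference frame is a square grid whose
    # side is ref_frame_mult times the kernel side
    return (ref_frame_mult * kernel_side(k)) ** 2

for k in (11**2, 5**2):
    print(k, kernel_side(k), expected_n_ref(k))
# 121 11 484
# 25 5 100
```

A non-square k (for example 27) raises an error, which makes this a cheap sanity check on a per-layer k list.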
# Building a complete FoviNet

You may have noticed that we have only worked with fake data so far. Our kNN architectures are designed to work with outputs formatted on the sensor manifold. To get such outputs from images, we need a `RetinalTransform` object or, more simply, a `SaccadePolicy`, which also determines our fixations for us. For a refresher on these, see `step1_sampling.ipynb`.
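As a reminder of what such a transform produces, here is a deliberately simplified NumPy sketch that samples one image at a set of normalized $(x, y)$ coordinates with nearest-neighbor lookup, yielding a `[channels, num_coords]` array per image. The function name `sample_at_coords` and the coordinate convention ($[-1, 1]$, with $(-1, -1)$ at the top-left) are our assumptions; the actual `RetinalTransform` uses fovi's foveated coordinates and may interpolate differently.

```python
import numpy as np

def sample_at_coords(image: np.ndarray, coords: np.ndarray) -> np.ndarray:
    """Nearest-neighbor sampling of a (C, H, W) image.

    coords: (N, 2) array of (x, y) in [-1, 1]; (-1, -1) is the top-left corner.
    Returns a (C, N) array, one column per sampling coordinate.
    """
    _, h, w = image.shape
    # Map [-1, 1] to pixel indices [0, W-1] / [0, H-1] and round to nearest
    cols = np.rint((coords[:, 0] + 1) / 2 * (w - 1)).astype(int)
    rows = np.rint((coords[:, 1] + 1) / 2 * (h - 1)).astype(int)
    return image[:, rows, cols]

image = np.arange(2 * 4 * 4, dtype=float).reshape(2, 4, 4)
coords = np.array([[-1.0, -1.0],   # top-left
                   [1.0, 1.0],     # bottom-right
                   [1.0, -1.0]])   # top-right
samples = sample_at_coords(image, coords)
print(samples.shape)  # (2, 3)
```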

The last piece of the puzzle is the `FoviNet` class, which combines a fixation policy with a processing network. Since there are many hyperparameters, we specify them all in a hierarchical `config`. This is typically stored as a `.yaml` file, and our training scripts handle it with `hydra` and `omegaconf`. Because the config uses inheritance, we initialize and compose it with `hydra`.
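Hydra's defaults list lets one config build on others and override selected fields. The core idea, a recursive merge in which later configs override earlier ones, can be sketched with plain dictionaries. This is a simplification of what `hydra` and `omegaconf` do, and the keys below are hypothetical rather than copied from `fovi_alexnet.yaml`.

```python
from copy import deepcopy

def merge_configs(base: dict, override: dict) -> dict:
    """Recursively merge override into a copy of base (override wins)."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged

# Hypothetical keys, for illustration only
base = {"model": {"arch": "knnalexnet", "norm_type": "batch"},
        "sensor": {"fov": 16, "cmf_a": 0.5, "res": 64}}
experiment = {"sensor": {"cmf_a": 1.0}}

config = merge_configs(base, experiment)
print(config["sensor"])             # {'fov': 16, 'cmf_a': 1.0, 'res': 64}
print(config["model"]["norm_type"])  # batch
```

With `omegaconf` itself, `OmegaConf.merge` performs this kind of recursive merge on `DictConfig` objects, which is why untouched fields such as `fov` survive an override of `cmf_a`.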

In [8]:
from fovi.fovinet import FoviNet
from hydra import compose, initialize

# Use hydra/OmegaConf to compose the hierarchical config, including all defaults
with initialize(version_base=None, config_path="../config"):
    config = compose(config_name="fovi_alexnet.yaml")
print(type(config)) # OmegaConf DictConfig

fovinet = FoviNet(config, device=device)

<class 'omegaconf.dictconfig.DictConfig'>
adjusting FOV for fixation: 16.0 (full: 16.0)




Note: horizontal flip always done in the loader, to avoid differences across fixations
Number of coords per layer: [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]
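These counts can be related to the note on `auto_match_cart_resources` earlier. For each layer, the smallest square Cartesian grid that holds at least that many samples is 64×64, 32×32, 16×16, 8×8, 4×4, or 1×1, consistent with a Cartesian pyramid starting at 64 pixels, and every count is at or below the matching pixel count. The sketch below does that arithmetic; reading "resources" as the number of sampling coordinates is our interpretation.

```python
import math

# Copied from the printed output above
coords_per_layer = [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]

def smallest_square_side(n: int) -> int:
    """Smallest s such that s*s >= n."""
    s = math.isqrt(n)
    return s if s * s >= n else s + 1

sides = [smallest_square_side(n) for n in coords_per_layer]
print(sides)  # [64, 32, 16, 16, 8, 8, 8, 8, 4, 1]

for n, s in zip(coords_per_layer, sides):
    # Fraction of the Cartesian-equivalent pixel budget actually used
    print(f"{n:5d} coords vs {s}x{s} = {s * s:5d} pixels ({n / (s * s):.1%})")
```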


In [9]:
fovinet

FoviNet(
  (network): BackboneProjectorWrapper(
    (backbone): KNNAlexNet(
      (layers): ModuleList(
        (0): KNNAlexNetBlock(
          (conv): KNNConvLayer(
          	in_channels=3
          	out_channels=96
          	k=121
          	n_ref=484
          	in_coords=SamplingCoords(length=4085, fov=16.0, cmf_a=0.5, resolution=53, style=isotropic)
          	out_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)
          	sample_cortex=True
          )
          (norm): KNNBatchNorm(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation): ReLU()
          (pool): KNNPoolingLayer(
          	mode=max
          	k=9
          	in_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)
          	out_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)
          	sample_cortex=True
          )
        )
        (1): KNNAlexNetBlock(
          ...

We can now process image data with our `fovinet` model.

In [10]:
from fovi.demo import get_image_as_batch

# alternatively, use a batch of random images:
# data = torch.rand([128, 3, 256, 256]).to(device)

data = get_image_as_batch(device=device)

print(data.shape)

category_logits, layer_outputs, x_fixs = fovinet(data)

torch.Size([1, 3, 256, 256])


In [11]:
# global avg pool of final conv layer
print(layer_outputs[0].shape) # (batch, num_fixations, conv_dim)
# final MLP layer:
print(layer_outputs[-1].shape) # (batch, num_fixations, fc_dim)
# category logits averaged across fixations
print(category_logits.shape) # (batch, num_classes)

torch.Size([1, 4, 256])
torch.Size([1, 4, 1024])
torch.Size([1, 1000])
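Per the comments above, the per-fixation outputs have shape `(batch, num_fixations, dim)`, and the category logits are averaged across fixations. Assuming per-fixation logits of shape `(batch, num_fixations, num_classes)` (not printed here), that reduction is a mean over the fixation axis. A NumPy sketch with random stand-in values, not real model outputs, using 4 fixations as in the printed shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_fixations, num_classes = 1, 4, 1000

# Stand-in per-fixation logits: (batch, num_fixations, num_classes)
per_fixation_logits = rng.normal(size=(batch, num_fixations, num_classes))

# Average over the fixation axis -> (batch, num_classes)
category_logits = per_fixation_logits.mean(axis=1)
print(category_logits.shape)  # (1, 1000)

# Predicted class per image from the fixation-averaged logits
print(category_logits.argmax(axis=1).shape)  # (1,)
```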


# Loading a pre-trained model

In [12]:
from fovi import get_model_from_base_fn

base_fn = 'fovi-alexnet_a-1_res-64_rfmult-2_in1k'
model = get_model_from_base_fn(base_fn)

Model with base_fn fovi-alexnet_a-1_res-64_rfmult-2_in1k found in ../models
adjusting FOV for fixation: 16.0 (full: 16.0)




Note: horizontal flip always done in the loader, to avoid differences across fixations
Number of coords per layer: [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]
