# Get activations from a foveated model

Here we demonstrate two methods for getting activations. The first uses the model class directly; the second uses the Trainer class.

Let's load a pre-trained model

In [1]:
%load_ext autoreload
%autoreload 2

from foveation import load_config
from foveation.saccadenet import SaccadeNet

base_fn = 'fovknnalexnet_a-1_res-64_in1k'
config, state_dict, model_key = load_config(base_fn, load=True, folder='../models', device='cpu')
model = SaccadeNet(config, device='cpu')
model.load_state_dict(state_dict[model_key])

[[96, 11, 2, 5, 1], [256, 5, 1, 2, 1], [384, 3, 1, 1, 1], [384, 3, 1, 1, 1], [256, 3, 1, 1, 1]]
adjusting FOV for fixation: 16.0 (full: 16.0)
found resolution 53 giving 4085 points (desired: 4096)
found resolution 53 giving 4085 points (desired: 4096)
found resolution 26 giving 964 points (desired: 1024)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


found resolution 26 giving 964 points (desired: 1024)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 4 giving 16 points (desired: 16)
found resolution 1 giving 1 points (desired: 1)
loader_transforms: Compose(
    ToTorchImage(device=cpu, dtype=torch.float32, from_numpy=True)
    RandomHorizontalFlip(p=0.5, seed=None)
)
pre_transforms: Compose(
    RandomColorJitter(p=0.8, hue=[-0.1, 0.1], saturation=[0.8, 1.2], va

  w_delta = (w_max - w_min)/(res-1)
  fix_loc = torch.tensor(self._check_fix_loc(fix_loc, x.shape[0]), dtype=self.dtype, device=self.device)


ssl_fixator:
NoSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
))

sup_fixator:
MultiRandomSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
), n_fixations=4)

LINEAR PROBE NUM CLASSES: 1000


<All keys matched successfully>

### Now we can create some fake data and get activations.
First, let's see which layers are available to hook

In [23]:
# Module names that can be hooked; the entries passed via layer_names below
# (e.g. 'backbone.layers.3', 'backbone', 'projector') come from this list.
model.list_available_layers()

['',
 'backbone',
 'backbone.layers',
 'backbone.layers.0',
 'backbone.layers.0.conv',
 'backbone.layers.0.conv.conv',
 'backbone.layers.0.norm',
 'backbone.layers.0.activation',
 'backbone.layers.0.pool',
 'backbone.layers.1',
 'backbone.layers.1.conv',
 'backbone.layers.1.conv.conv',
 'backbone.layers.1.norm',
 'backbone.layers.1.activation',
 'backbone.layers.1.pool',
 'backbone.layers.2',
 'backbone.layers.2.conv',
 'backbone.layers.2.conv.conv',
 'backbone.layers.2.norm',
 'backbone.layers.2.activation',
 'backbone.layers.3',
 'backbone.layers.3.conv',
 'backbone.layers.3.conv.conv',
 'backbone.layers.3.norm',
 'backbone.layers.3.activation',
 'backbone.layers.4',
 'backbone.layers.4.conv',
 'backbone.layers.4.conv.conv',
 'backbone.layers.4.norm',
 'backbone.layers.4.activation',
 'backbone.layers.4.pool',
 'backbone.layers.5',
 'projector',
 'projector.layers',
 'projector.layers.fc_block_6',
 'projector.layers.fc_block_6.0',
 'projector.layers.fc_block_6.1',
 'projector.layers.

Let's hook the fourth backbone block (`backbone.layers.3`; blocks are zero-indexed), the full backbone (conv layers), and the projector (MLP).

In [35]:
import torch

# Seed torch's RNG so the fake inputs below are reproducible across runs.
# NOTE(review): the saccade policy may draw fixations from its own RNG — confirm
# whether this seed also makes the fixation locations deterministic.
torch.manual_seed(0)

# Fake batch: 10 three-channel 256x256 images, allocated directly on the CPU
# (avoids the extra copy a trailing .to('cpu') would make).
inputs = torch.rand((10, 3, 256, 256), device='cpu')
# NOTE(review): confirm get_activations puts the model in eval mode; otherwise the
# norm layers may use batch statistics and update their running stats.
outputs, acts = model.get_activations(inputs, layer_names=['backbone.layers.3', 'backbone', 'projector'])

  fix_loc = torch.tensor(self._check_fix_loc(fix_loc, x.shape[0]), dtype=self.dtype, device=self.device)


Note that the intermediate backbone block retains a spatial dimension ($n=60$), whereas the full backbone output has been globally pooled and, like the projector output, has no spatial dimension.

Note also that the second dimension of each activation tensor indexes fixations (here, 4 per image).

In [36]:
# Map each hooked layer name to the shape of its activation tensor.
{layer: act.shape for layer, act in acts.items()}

{'backbone.layers.3': torch.Size([10, 4, 384, 60]),
 'backbone': torch.Size([10, 4, 256]),
 'projector': torch.Size([10, 4, 1024])}

# Using the trainer class

An even more streamlined way of getting activations is to use the Trainer class. 

For this to work, you will need to define paths to existing dataset files. For now, these must be FFCV files; support for standard image datasets is planned. 

When loading a trainer from pre-trained, it is generally easiest to use the utility `get_trainer_from_base_fn`, which does a few basic things under the hood so we don't need to manually edit the config to turn off distributed training, etc. 

In [None]:
import os

from foveation import get_trainer_from_base_fn

base_fn = 'fovknnalexnet_a-1_res-64_in1k'
# Directory containing your ImageNet-1K FFCV files. Point the IMAGENET_FFCV_DIR
# environment variable at it instead of editing this cell; the fallback is the
# path used when this notebook was written.
FFCV_DIR = os.environ.get(
    'IMAGENET_FFCV_DIR',
    '/n/alvarez_lab_tier1/Users/nblauch/datasets/ffcv/imagenet',
)
# In general, any kwarg passed in is used to update the loaded config file.
kwargs = {
    'data.train_dataset': os.path.join(FFCV_DIR, 'train_compressed.ffcv'),
    'data.val_dataset': os.path.join(FFCV_DIR, 'val_compressed.ffcv'),
}
trainer = get_trainer_from_base_fn(base_fn, load=True, model_dirs=['../models'], **kwargs)


[[96, 11, 2, 5, 1], [256, 5, 1, 2, 1], [384, 3, 1, 1, 1], [384, 3, 1, 1, 1], [256, 3, 1, 1, 1]]
adjusting FOV for fixation: 16.0 (full: 16.0)
found resolution 53 giving 4085 points (desired: 4096)
found resolution 53 giving 4085 points (desired: 4096)
found resolution 26 giving 964 points (desired: 1024)
found resolution 26 giving 964 points (desired: 1024)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 4 giving 16

  w_delta = (w_max - w_min)/(res-1)


found resolution 53 giving 4085 points (desired: 4096)
Auto-matched resolution to 53 (4085 sampling coordinates) to best match 4096 cartesian pixels.


  fix_loc = torch.tensor(self._check_fix_loc(fix_loc, x.shape[0]), dtype=self.dtype, device=self.device)


ssl_fixator:
NoSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
))

sup_fixator:
MultiRandomSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
), n_fixations=4)

LINEAR PROBE NUM CLASSES: 1000
SaccadeNet(
  (network): BackboneProjectorWrapper(
    (backbone): KNNAlexNet(
      (layers): ModuleList(
        (0): KNNAlexNetBlock(
          (conv): KNNConvLayer(
          	in_channels=3
          	out_channels=96
          	k=121
          	in_coords=SamplingCoords(length=4085, fov=16.0, cmf_a=0.5, resolution=53, style=isotropic)
          	out_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)
          	sample_cortex=True
          )
          (norm): KNNBatc

In [37]:
# Run the validation loader through the model, hooking the same three layers as above.
outputs, activations, targets = trainer.compute_activations(
    trainer.val_loader,
    layer_names=['backbone.layers.3', 'backbone', 'projector'],
    max_batches=4,  # stop early: a handful of batches is enough for this demo
)

  1%|          | 3/391 [00:00<01:16,  5.09it/s]


In [38]:
# Shape of the activations collected for each hooked layer.
{layer: act.shape for layer, act in activations.items()}

{'backbone.layers.3': (512, 4, 384, 60),
 'backbone': (512, 4, 256),
 'projector': (512, 4, 1024)}

Note that we now also have the network outputs, which are aggregated over fixations.

In [30]:
# Network outputs: one row per sample, already aggregated over fixations
# (unlike the activations above, there is no fixation dimension).
outputs.shape

(512, 1000)

We can quickly check our top-1 accuracy (note: this is an unstable estimate, since we used only a small number of batches).

In [34]:
# Top-1 accuracy on the few batches computed above. torch.as_tensor wraps
# outputs/targets without copying when possible, and does not warn if they are
# already tensors (torch.tensor always copies).
# NOTE(review): if this meter accumulates state across calls (as torchmetrics
# metrics do), this call also updates trainer.val_meters — confirm before reusing
# the trainer for a full validation pass.
trainer.val_meters['top_1_val'](torch.as_tensor(outputs), torch.as_tensor(targets))

tensor(0.5781, device='cuda:0')