# Load pre-trained models

Some folks will want to get straight to testing models. If that's you, start here. Otherwise, see the other, more detailed notebooks.

## Foveated AlexNet-like CNN model

Our first model is an AlexNet-like CNN trained under resource constraints: instead of processing 224x224 pixel images at uniform resolution, it processes ~64x64 pixel images at a variable resolution that peaks at the center of gaze and falls off progressively. These constraints model the limits that brain size places on the human brain; here, they reduce the number of neurons by a factor of 16. We can scale the processing demands back up by moving our eyes and processing more fixations over time, trading energy usage for improved visual performance.

```python
from foveation import load_config
from foveation.saccadenet import SaccadeNet

# Name of the pre-trained checkpoint; its files are looked up in ../models
base_fn = 'fovknnalexnet_a-1_res-64_in1k'
config, state_dict, model_key = load_config(base_fn, load=True, folder='../models', device='cpu')

# Build the model from the config, then restore the pre-trained weights
model = SaccadeNet(config, device='cpu')
model.load_state_dict(state_dict[model_key])
```

[[96, 11, 2, 5, 1], [256, 5, 1, 2, 1], [384, 3, 1, 1, 1], [384, 3, 1, 1, 1], [256, 3, 1, 1, 1]]
adjusting FOV for fixation: 16.0 (full: 16.0)
found resolution 53 giving 4085 points (desired: 4096)
found resolution 53 giving 4085 points (desired: 4096)
found resolution 26 giving 964 points (desired: 1024)
found resolution 26 giving 964 points (desired: 1024)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 13 giving 230 points (desired: 256)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 7 giving 60 points (desired: 64)
found resolution 4 giving 16

  w_delta = (w_max - w_min)/(res-1)
  fix_loc = torch.tensor(self._check_fix_loc(fix_loc, x.shape[0]), dtype=self.dtype, device=self.device)


ssl_fixator:
NoSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
))

sup_fixator:
MultiRandomSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
), n_fixations=4)

LINEAR PROBE NUM CLASSES: 1000


<All keys matched successfully>
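
The loading log above seems to report, for each sampling stage, the grid resolution that was chosen and how many sampling points it yields against the desired count. As a minimal sketch for summarizing that output, the snippet below parses lines in the printed format and computes the fraction of the desired points that was achieved. The helper name `parse_sampler_line` and the regular expression are illustrative assumptions based only on the text shown; they are not part of the `foveation` package.

```python
import re

# Hypothetical pattern, written to match the log lines printed above.
LOG_PATTERN = re.compile(
    r"found resolution (\d+) giving (\d+) points \(desired: (\d+)\)"
)


def parse_sampler_line(line):
    """Return (resolution, found, desired, fraction), or None if the line does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    resolution, found, desired = (int(g) for g in match.groups())
    return resolution, found, desired, found / desired


# A few lines copied from the output above.
log_lines = [
    "found resolution 53 giving 4085 points (desired: 4096)",
    "found resolution 26 giving 964 points (desired: 1024)",
    "found resolution 13 giving 230 points (desired: 256)",
    "found resolution 7 giving 60 points (desired: 64)",
]

for line in log_lines:
    resolution, found, desired, fraction = parse_sampler_line(line)
    print(f"resolution {resolution:>2}: {found}/{desired} points ({fraction:.1%} of desired)")
```

Lines that do not follow this format, such as warnings or truncated entries, return `None` and can simply be skipped.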

## Foveated ViT-S DINOv3

Our second model is a ViT-S that was pretrained under the DINOv3 protocol. More precisely, a larger ViT was trained first and then distilled into the pre-trained ViT-S DINOv3 model that we used as a starting point. We then adapted this model to receive foveated inputs by replacing the patch embedding with a foveated one and applying low-rank adaptation (LoRA) so that the network handles foveated inputs better. Like the previous model, this one was trained with resource constraints: instead of processing 224x224 pixel images at uniform resolution, it processes ~64x64 pixel images at a variable resolution that peaks at the center of gaze and falls off progressively.

Here, the resource constraints are less extreme than for the previous model. This is because the number of patches, rather than pixels, primarily determines the processing resources in a ViT, and we opted to use 8x8-like patches ($k=64$) over a 64x64-like input sensor manifold ($n\approx4096$). This model thus reduces the standard 14x14 or 16x16 patch grid to an 8x8 grid, which is still a 3-4x reduction in the number of patches. Per fixation, the savings are therefore 3-4x in the linear operations and 9-16x in the (quadratic) attention operations. We take multiple fixations to unfold the processing constraints over time.
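
The savings quoted above follow from simple patch-count arithmetic, which can be checked with a short sketch. It assumes that linear operations scale with the number of patches and attention scales with its square, and it ignores extra tokens such as a class token. The function name `patch_savings` is illustrative only.

```python
def patch_savings(baseline_side, foveated_patches=64):
    """Return (linear, quadratic) savings factors per fixation.

    baseline_side: side length of the standard square patch grid (e.g. 14 or 16).
    foveated_patches: number of patches in the foveated model (k = 64 here).
    """
    baseline_patches = baseline_side * baseline_side
    linear = baseline_patches / foveated_patches  # cost proportional to patch count
    quadratic = linear ** 2                       # cost proportional to patch count squared
    return linear, quadratic


for side in (14, 16):
    linear, quadratic = patch_savings(side)
    print(f"{side}x{side} baseline: linear {linear:.2f}x, attention {quadratic:.2f}x")
# 14x14 baseline: linear 3.06x, attention 9.38x
# 16x16 baseline: linear 4.00x, attention 16.00x
```

These are per-fixation figures; with several fixations per image, the total cost scales up with the number of fixations.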

```python
# Reuses load_config and SaccadeNet imported in the first cell
base_fn = 'fovknndinov3-s_a-2.78_res-64_in1k'
config, state_dict, model_key = load_config(base_fn, load=True, folder='../models', device='cpu')
model = SaccadeNet(config, device='cpu')
model.load_state_dict(state_dict[model_key])
```

found resolution 44 giving 3976 points (desired: 4096)


  num_neighbors = torch.minimum(torch.tensor(self.k*m), torch.tensor(self.in_coords.shape[0]))


found resolution 6 giving 64 points (desired: 64)
found resolution 6 giving 64 points (desired: 64)


100%|██████████| 64/64 [00:00<00:00, 755.57it/s]
100%|██████████| 64/64 [00:00<00:00, 373.66it/s]


adjusting FOV for fixation: 16.0 (full: 16.0)
loader_transforms: Compose(
    ToTorchImage(device=cpu, dtype=torch.float32, from_numpy=True)
    RandomHorizontalFlip(p=0.5, seed=None)
)
pre_transforms: Compose(
    RandomColorJitter(p=0.8, hue=[-0.1, 0.1], saturation=[0.8, 1.2], value=[0.6, 1.4], contrast=[0.6, 1.4], seed=None)
    RandomGrayscale(p=0.2, num_output_channels=3, seed=None)
    NormalizeGPU(mean=tensor([0.4850, 0.4560, 0.4060], dtype=torch.float64), std=tensor([0.2290, 0.2240, 0.2250], dtype=torch.float64), inplace=True)
)
post_transforms: None
found resolution 44 giving 3976 points (desired: 4096)
Auto-matched resolution to 44 (3976 sampling coordinates) to best match 4096 cartesian pixels.


  fix_loc = torch.tensor(self._check_fix_loc(fix_loc, x.shape[0]), dtype=self.dtype, device=self.device)


ssl_fixator:
NoSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=2.785765, style=isotropic, resolution=44, mode=nearest, n=3976)
))

sup_fixator:
MultiRandomSaccadePolicy(retinal_transform=RetinalTransform(
  (foveal_color): GaussianColorDecay(sigma=None)
  (sampler): GridSampler(fov=16.0, cmf_a=2.785765, style=isotropic, resolution=44, mode=nearest, n=3976)
), n_fixations=4)

LINEAR PROBE NUM CLASSES: 1000


<All keys matched successfully>