<a href="https://colab.research.google.com/github/hmichaeli/alias_free_convnets/blob/main/AFC_shift_equvariance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
from huggingface_hub import hf_hub_download
import torch
from torchvision import datasets, transforms
from alias_free_convnets.models.convnext_afc import convnext_afc_tiny
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import transforms
from models.ideal_lpf import UpsampleRFFT

In [26]:
# !rm -rf alias_free_convnets
!git clone https://github.com/hmichaeli/alias_free_convnets.git
%cd alias_free_convnets

Cloning into 'alias_free_convnets'...
remote: Enumerating objects: 124, done.[K
remote: Counting objects: 100% (124/124), done.[K
remote: Compressing objects: 100% (94/94), done.[K
remote: Total 124 (delta 33), reused 102 (delta 20), pack-reused 0 (from 0)[K
Receiving objects: 100% (124/124), 10.05 MiB | 19.19 MiB/s, done.
Resolving deltas: 100% (33/33), done.
/content/alias_free_convnets/alias_free_convnets/alias_free_convnets


In [27]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# baseline
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_basline.pth")
ckpt = torch.load(path, map_location="cpu")
base_model = convnext_afc_tiny(pretrained=False, num_classes=1000)
base_model.load_state_dict(ckpt, strict=True)
base_model = base_model.to(device).eval()

# AFC
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_afc.pth")
ckpt = torch.load(path, map_location="cpu")
afc_model = convnext_afc_tiny(
        pretrained=False,
        num_classes=1000,
        activation='up_poly_per_channel',
        activation_kwargs={'in_scale': 7, 'out_scale': 7, 'train_scale': True},
        blurpool_kwargs={"filter_type": "ideal", "scale_l2": False},
        normalization_type='CHW2',
        stem_activation_kwargs={"in_scale": 7, "out_scale": 7, "train_scale": True, "cutoff": 0.75},
        normalization_kwargs={},
        stem_mode='activation_residual', stem_activation='lpf_poly_per_channel'
    )
afc_model.load_state_dict(ckpt, strict=False)
afc_model = afc_model.to(device).eval()


  ckpt = torch.load(path, map_location="cpu")


ConvNext kwargs:  {'num_classes': 1000}


  ckpt = torch.load(path, map_location="cpu")


ConvNext kwargs:  {'num_classes': 1000, 'activation': 'up_poly_per_channel', 'activation_kwargs': {'in_scale': 7, 'out_scale': 7, 'train_scale': True}, 'blurpool_kwargs': {'filter_type': 'ideal', 'scale_l2': False}, 'normalization_type': 'CHW2', 'stem_activation_kwargs': {'in_scale': 7, 'out_scale': 7, 'train_scale': True, 'cutoff': 0.75}, 'normalization_kwargs': {}, 'stem_mode': 'activation_residual', 'stem_activation': 'lpf_poly_per_channel'}


In [28]:
# load example image
interpolation = transforms.InterpolationMode.BICUBIC
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize(256, interpolation=interpolation),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img_path = "./assets/n01608432_3247.JPEG"
image = Image.open(img_path)
image = transform(image).unsqueeze(0).to(device)


In [29]:
@torch.no_grad()
def shift_and_compare(model, image, shift_x, shift_y):
    """
    Cyclic-Shifts the image, extracts features, upsamples, shifts back, and compares.

    Args:
      model: The PyTorch model to use for feature extraction.
      image: The input image tensor.
      shift_x: Horizontal shift amount.
      shift_y: Vertical shift amount.

    Returns:
      A tuple containing:
        - The original feature map.
        - The shifted and reversed feature map.
        - The difference between the two feature maps.
    """

    # Shift the image cyclically
    shifted_image = torch.roll(image, shifts=(shift_x, shift_y), dims=(3, 2))

    # Get feature maps from the model
    feature_map = model.forward_features(image, avgpool=False)
    shifted_feature_map = model.forward_features(shifted_image, avgpool=False)

    # Upsample to the original image size
    size_ratio = int(image.shape[-1] / feature_map.shape[-1])
    feature_map = UpsampleRFFT(size_ratio)(feature_map)
    shifted_feature_map = UpsampleRFFT(size_ratio)(shifted_feature_map)

    # Reverse the shift
    shifted_feature_map = torch.roll(shifted_feature_map, shifts=(-shift_x, -shift_y), dims=(3, 2))

    # Featuremap shift-equivariance diff
    difference = torch.abs(shifted_feature_map - feature_map)
    print("featuremap avg diff: ", torch.mean(difference))

    # Feature-vector invariance / sum-shift invariance
    feature_vec = torch.mean(feature_map, dim=(2, 3))
    shifted_feature_vec = torch.mean(shifted_feature_map, dim=(2, 3))

    print("feature vector diff: ", torch.mean(torch.abs(feature_vec - shifted_feature_vec)))


print("baseline: ")
shift_and_compare(base_model, image, 1, 1)
print("afc: ")
shift_and_compare(afc_model, image, 1, 1)


baseline: 
featuremap avg diff:  tensor(0.2781, device='cuda:0')
feature vector diff:  tensor(0.0941, device='cuda:0')
afc: 
featuremap avg diff:  tensor(1.3378e-05, device='cuda:0')
feature vector diff:  tensor(3.4836e-06, device='cuda:0')
