<div align="center">
  <img src="https://raw.githubusercontent.com/lapalap/dora/6991c4a08f27e4171e3a9b0bdffc0a14966e07df/assets/images/logo.svg" width="350"/>
</div>

<div align="center"><h1>DORA: s-AMS generation for ImageNet networks</h1>
<h5>This notebook demonstrates DORA's capability to analyze representation spaces of commonly used Computer Vision Architectures.</h5>
</div>

In [None]:
! pip install git+https://github.com/lapalap/dora.git --quiet
! pip install umap-learn --quiet
! pip install timm --quiet

In [6]:
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from dora import Dora
from dora.objectives import ChannelObjective

from timm import create_model

from tqdm import tqdm

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 🤖 Initializing Computer Vision Models

In [5]:
def get_model(model_name):
    """Load a pretrained network by name and move it to ``device``.

    Parameters
    ----------
    model_name : str
        A torchvision model name ('resnet18', 'alexnet', 'inception_v3',
        'densenet161', 'mobilenet_v2', 'shufflenet_v2_x1_0') or a timm model
        name ('vit_base_patch16_224', 'beit_base_patch16_224').

    Returns
    -------
    The model built with ``pretrained=True`` and placed on the notebook-level
    ``device``.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names. (An unknown name
        used to fall through and return ``None``, which only failed later at
        the call site.)
    """
    torchvision_names = (
        'resnet18',
        'alexnet',
        'inception_v3',
        'densenet161',
        'mobilenet_v2',
        'shufflenet_v2_x1_0',
    )
    timm_names = (
        'vit_base_patch16_224',
        'beit_base_patch16_224',
    )

    if model_name in torchvision_names:
        # The torchvision constructor carries the same name as the model.
        # NOTE(review): recent torchvision releases deprecate `pretrained=True`
        # in favour of `weights=...`; kept unchanged so the same weights load.
        model = getattr(models, model_name)(pretrained=True)
    elif model_name in timm_names:
        model = create_model(model_name, pretrained=True)
    else:
        supported = ", ".join(torchvision_names + timm_names)
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of: {supported}"
        )

    return model.to(device)

# Architectures to process, in run order; every entry is passed to get_model().
model_names = [
    'resnet18',
    'alexnet',
    'vit_base_patch16_224',
    'beit_base_patch16_224',
    'inception_v3',
    'densenet161',
    'mobilenet_v2',
    'shufflenet_v2_x1_0',
]

# 🖼 Generating s-AMS
Here we iteratively generate s-AMS signals for the output representations of popular ImageNet pre-trained Computer Vision models. **Note:** this will take some time.

In [None]:
k = 1000  # neuron indices 0 .. k-1 are processed; presumably the 1000 ImageNet output logits -- TODO confirm
n = 3     # forwarded as `num_samples` to generate_signals()
neuron_indices = list(range(k))

# The augmentation pipeline does not depend on the model, so it is built once
# and reused for every generate_signals() call in the loop below.
# Ten identical small-translation jitters are applied back to back.
translate_jitter = [
    transforms.RandomAffine(0, translate=(0.015, 0.015), fill=0.5)
    for _ in range(10)
]
image_transforms = transforms.Compose([
    transforms.Pad(2, fill=0.5, padding_mode='constant'),
    *translate_jitter,
    transforms.RandomAffine((-20, 20),
                            scale=(0.75, 1.025),
                            fill=0.5),
    transforms.RandomCrop((224, 224),
                          padding=None,
                          pad_if_needed=True,
                          fill=0,
                          padding_mode='constant'),
])

for model_name in model_names:
    # Switch to eval mode before handing the network to DORA.
    model = get_model(model_name).eval()

    d = Dora(model=model,
             storage_dir="dora/",
             device=device)

    # One experiment per architecture, named after the model.
    experiment_name = model_name

    d.generate_signals(
        neuron_idx=neuron_indices,
        num_samples=n,
        layer=model,  # the whole network is passed as the layer
        only_maximization=True,
        image_transforms=image_transforms,
        objective_fn=ChannelObjective(),
        lr=0.05,
        width=224,
        height=224,
        iters=500,
        experiment_name=experiment_name,
        overwrite_experiment=True,  ## will still use what already exists if generation params are same
    )