In [1]:
import numpy as np
import PIL
import PIL.Image
import torch
import torchvision

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

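Utility classes: `NetworkWrapper` applies a preprocessing function before the network, and `Visualization` holds the learnable image, produces augmented batches from it, and can render it as a PIL image.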

In [3]:
class NetworkWrapper(torch.nn.Module):
    """Chain a preprocessing callable and a network into one module.

    The wrapped network is switched to evaluation mode when the wrapper is
    constructed.

    Args:
        network: Module that receives the preprocessed input.
        preprocess_fn: Callable applied to the raw input before the network.
    """

    def __init__(self, network, preprocess_fn):
        super().__init__()

        # Put the network in eval mode up front; the attribute assignments
        # below keep their original order (preprocess_fn, then network) so
        # that submodule registration order is unchanged.
        network.eval()
        self.preprocess_fn = preprocess_fn
        self.network = network

    def forward(self, x):
        """Preprocess ``x`` and return the network's output for it."""
        preprocessed = self.preprocess_fn(x)
        return self.network(preprocessed)

class Visualization(torch.nn.Module):
    """Learnable image that is optimized directly through its pixels.

    The image is stored as an unconstrained parameter of shape
    ``(1, 3, h, w)`` and squashed into ``(0, 1)`` with a sigmoid whenever it
    is read. ``forward`` returns a batch of randomly augmented views of that
    image; ``to_img`` renders the un-augmented image as a PIL image.

    Args:
        h: Height of the learnable image in pixels.
        w: Width of the learnable image in pixels.
    """

    def __init__(self, h, w):
        super(Visualization, self).__init__()
        self.__data = torch.nn.Parameter(torch.randn(1, 3, h, w))
        # Default the augmented output size to the image size so that
        # forward() works even if set_output_shape() is never called.
        self.out_h = h
        self.out_w = w

    def __augment(self, x, batch_size):
        """Return ``batch_size`` independently augmented views of ``x``.

        Args:
            x: Image tensor of shape ``(1, 3, H, W)``.
            batch_size: Number of augmented views to produce.

        Returns:
            Tensor of shape ``(batch_size, 3, out_h, out_w)``.
        """
        transforms = torch.nn.Sequential(
            torchvision.transforms.RandomResizedCrop(size=[self.out_h, self.out_w], scale=(0.01, 1.0)),
            torchvision.transforms.RandomRotation(degrees=20),
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.RandomPerspective(distortion_scale=0.4, p=0.5),
            torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            torchvision.transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))
        )
        # For more augmentations: https://pytorch.org/vision/0.19/transforms.html

        # torchvision samples the random parameters once per call and applies
        # them to every image of a batched tensor. Augmenting a pre-repeated
        # batch would therefore yield identical views, so each view gets its
        # own call and the results are concatenated afterwards.
        views = [transforms(x) for _ in range(batch_size)]
        return torch.cat(views, dim=0)

    def __reparameterize(self, x):
        """Map unconstrained values to the open interval (0, 1)."""
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(x)

    def set_output_shape(self, h, w):
        """Set the height and width of the augmented views returned by forward()."""
        self.out_h = h
        self.out_w = w

    def forward(self, batch_size):
        """Return ``batch_size`` augmented views of the current image.

        Args:
            batch_size: Number of augmented views to produce.

        Returns:
            Tensor of shape ``(batch_size, 3, out_h, out_w)``; gradients flow
            back to the learnable image.
        """
        x = self.__data
        x = self.__reparameterize(x)
        x = self.__augment(x, batch_size)
        return x

    def to_img(self):
        """Render the current (un-augmented) image as a PIL image.

        Returns:
            ``PIL.Image.Image`` built from an ``(h, w, 3)`` uint8 array.
        """
        # No graph is needed for rendering.
        with torch.no_grad():
            x = self.__data
            x = self.__reparameterize(x)
        x = x.squeeze(0)  # (1, 3, h, w) -> (3, h, w)
        x = x.cpu()
        x = x.numpy()
        x = np.transpose(x, (1, 2, 0))  # channels-first -> channels-last
        x = (x * 255).clip(0, 255)
        x = x.astype(np.uint8)
        pil_img = PIL.Image.fromarray(x)
        return pil_img

Initialize the model, visualization, and optimizer. The cell below loads ResNet-18; to use ViT-B/16 instead, swap in the commented-out line.

In [None]:
# Seed torch's RNG so the random initial image and the augmentation draws
# start from the same state on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Pretrained ImageNet-1k classifier whose class score will be maximized.
# `weights=` replaces the deprecated `pretrained=True` and selects the same weights.
net = torchvision.models.resnet18(weights="IMAGENET1K_V1") # ResNet-18
# net = torchvision.models.vit_b_16(weights='IMAGENET1K_V1') # ViT-B/16

# Only the visualization is optimized. Freezing the classifier avoids computing
# and accumulating weight gradients on every backward pass; gradients still
# flow through the network to the input image.
net.requires_grad_(False)

# Standard ImageNet channel statistics used with torchvision's pretrained models.
preprocess_fn = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model = NetworkWrapper(net, preprocess_fn).to(device)

# Learnable 256x256 image; augmented views are produced at 224x224.
vis = Visualization(256, 256).to(device)
vis.set_output_shape(224, 224)

# The optimizer updates only the visualization's parameters.
optimizer = torch.optim.AdamW(
    params=vis.parameters(),
    lr=0.2,
)

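Training loop: optimize the visualization so that the mean score of the selected class over the augmented views is maximized, with an L2 penalty on the views.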

In [None]:
class_index = 852 # Selected ID for activation maximization from ImageNet1K (Tennis ball)

# Loop-invariant settings, defined once instead of on every iteration.
NUM_ITERATIONS = 10000
LOG_EVERY = 100      # print the loss every N iterations
SAVE_EVERY = 1000    # save the current image every N iterations
batch_size = 8       # number of augmented views per optimization step
l2_lambda = 0.001    # weight of the L2 penalty on the augmented views

for i in range(NUM_ITERATIONS):
    iteration = i + 1  # 1-based counter used for logging and file names
    optimizer.zero_grad()

    imgs = vis(batch_size).to(device)
    outputs = model(imgs)
    target_scores = outputs[:, class_index]

    # 2-norm over all elements of the batch; torch.linalg.vector_norm is the
    # non-deprecated equivalent of torch.norm(imgs, p=2).
    l2_regularization = torch.linalg.vector_norm(imgs, ord=2)
    regularization = l2_lambda * l2_regularization

    loss = -torch.mean(target_scores) + regularization # Mean score loss + L2 regularization for activation maximization

    if iteration % LOG_EVERY == 0:
        print(f"Iteration {iteration}, Loss: {loss.item()}, L2 Regularization: {regularization.item()}")

    loss.backward()
    optimizer.step()

    # Save the visualization every SAVE_EVERY iterations. The image reflects
    # the parameters after this step; the printed loss was computed before it.
    if iteration % SAVE_EVERY == 0:
        print(f"Image created at: Iteration {iteration}, Loss: {loss.item()}")
        img = vis.to_img()
        image_name = f"visualization_{iteration}.png"
        img.save(image_name)