# Import

In [14]:
import timm
import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, Subset, random_split

import matplotlib.pyplot as plt
from tqdm import tqdm

## Deform Images

In [3]:
def sin_distortion(x_length: int,
                   y_length: int,
                   A_nm: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """
    Sin distortion for creating deformation maps.

    Args:
    - x_length (int): Length of x-axis of image.
    - y_length (int): Length of y-axis of image.
    - A_nm (torch.Tensor): Square matrix of coefficients. Sets size of cut off.

    Returns:
    (torch.Tensor, torch.Tensor): Deformation maps for x and y coordinates.
    """
    if A_nm.shape[0] != A_nm.shape[1]:
        raise ValueError('A_nm must be square matrix.')

    A_nm = A_nm.float()

    # Create Coordinates
    x = torch.linspace(-1, 1, x_length, dtype=torch.float32)
    y = torch.linspace(-1, 1, y_length, dtype=torch.float32)
    X, Y = torch.meshgrid(x, y, indexing='ij')

    # Create Diffeo
    x_pert = torch.linspace(0, 1, x_length, dtype=torch.float32)
    y_pert = torch.linspace(0, 1, y_length, dtype=torch.float32)

    n = torch.arange(1, A_nm.shape[0] + 1, dtype=torch.float32)
    x_basis = torch.sin(torch.pi * torch.outer(n, x_pert)).T
    y_basis = torch.sin(torch.pi * torch.outer(n, y_pert))

    perturbation = torch.matmul(x_basis, torch.matmul(A_nm, y_basis))

    x_map = X + perturbation
    y_map = Y + perturbation

    return x_map, y_map

def apply_transformation(image_tensor,
                         A_nm: torch.Tensor,
                         interpolation_type='bilinear'):
    """
    Wrapper of `sin_distortion`. Gets torch.tensor and returns the distorted
    torch.tensor according to A_nm.

    Args:
        image_tensor (torch.Tensor): Inputted image.
        A_nm (torch.Tensor): Characterizes diffeo according to `sin_distortion`.
        interpolation_type (str): Interpolation method ('bilinear' or 'nearest').

    Returns:
        image_tensor_deformed (torch.Tensor): Diffeo applied to `image_tensor`.
    """
    # Create deformation map
    x_length, y_length = image_tensor.shape[1:3]
    x_map, y_map  = sin_distortion(x_length, y_length, A_nm)

    return apply_flowgrid(image_tensor, x_map, y_map, interpolation_type=interpolation_type)


def apply_flowgrid(image_tensor, x_map, y_map, interpolation_type='bilinear'):
    # Stack and unsqueeze to form grid
    grid = torch.stack((y_map, x_map), dim=-1).unsqueeze(0).to(image_tensor.device)

    # Apply grid sample
    image_tensor_deformed = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0),
                                                            grid,
                                                            mode=interpolation_type,
                                                            align_corners=True)

    return image_tensor_deformed.squeeze(0)

In [4]:
def diffeo_dataset(tensor):
    A_nm = torch.tensor([[0.0, 0.14],
                         [-0.02, 0.01]])
    return apply_transformation(tensor, A_nm)

# Get ImageNet

In [6]:
import torch as t

root = '/imagenet/'
total_images = 1000
pct_train = 0.8


num_train = int(total_images * pct_train)
num_val = total_images - num_train

# Preprocess the image w/o diffeo
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to grayscale
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])

# Preprocess the image w/ diffeo
preprocess_diffeo = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to grayscale
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
    diffeo_dataset,
])



dataset_images_og = torchvision.datasets.ImageNet(root=root, 
                                                        split='train', 
                                                        transform=preprocess, )
dataset_images_diffeo = torchvision.datasets.ImageNet(root=root, 
                                                            split='train', 
                                                            transform=preprocess_diffeo,)

# Image Classification Model

In [7]:
# Load a pre-trained ViT model
model = torchvision.models.resnet18(pretrained=True)

# Move the model to GPU if available
model = model.to(device)



In [16]:
tensor_train_images_og     = torch.cat([dataset_images_og[i][0].unsqueeze(0).to(device) for i in tqdm(range(len(dataset_images_og)))])

100%|██████████| 800/800 [00:04<00:00, 179.80it/s]
100%|██████████| 800/800 [00:05<00:00, 149.98it/s]
100%|██████████| 200/200 [00:01<00:00, 141.80it/s]
100%|██████████| 200/200 [00:01<00:00, 121.68it/s]


In [17]:
def get_activation(model, input, layer_index: list):

  activation = {}
  def getActivation(name):
      # the hook signature
      def hook(model, input, output):
          activation[name] = output.detach()
      return hook

  handles = []
  def retrieve_layer_activation(model, input, layer_index):
    if len(input) == 3: input = input[None, :, :, :]

    layers = list(model.children())
    layers_flat = flatten(layers)

    for index in layer_index:
      handles.append(layers_flat[index - 1].register_forward_hook(getActivation(str(index))))

    with t.no_grad(): model(input)
    for handle in handles: handle.remove()

    return

  def flatten(array):
      result = []
      for element in array:
          if hasattr(element, "__iter__"):
              result.extend(flatten(element))
          else:
              result.append(element)
      return result

  retrieve_layer_activation(model, input, layer_index)
  return activation

In [23]:
layer_id = 13

activation_train_og = get_activation(model, tensor_train_images_og, [layer_id])[f'{layer_id}'].flatten(start_dim=1).to('cpu')
activation_train_diffeo = get_activation(model, tensor_train_images_diffeo, [layer_id])[f'{layer_id}'].flatten(start_dim=1).to('cpu')

final_dataset = TensorDataset(activation_train_diffeo, activation_train_og)

torch.save(final_dataset, "resnet18_imagenet1k_train.pt")

Num Parameters 512
torch.Size([800, 512])
torch.Size([800, 512])
