# Import

In [1]:
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mclarkmiyamoto[0m ([33mclarkmiyamoto-new-york-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [14]:
import timm
import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, Subset, random_split

import matplotlib.pyplot as plt
from tqdm import tqdm

## Deform Images

In [3]:
def sin_distortion(x_length: int,
                   y_length: int,
                   A_nm: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """
    Sin distortion for creating deformation maps.

    Args:
    - x_length (int): Length of x-axis of image.
    - y_length (int): Length of y-axis of image.
    - A_nm (torch.Tensor): Square matrix of coefficients. Sets size of cut off.

    Returns:
    (torch.Tensor, torch.Tensor): Deformation maps for x and y coordinates.
    """
    if A_nm.shape[0] != A_nm.shape[1]:
        raise ValueError('A_nm must be square matrix.')

    A_nm = A_nm.float()

    # Create Coordinates
    x = torch.linspace(-1, 1, x_length, dtype=torch.float32)
    y = torch.linspace(-1, 1, y_length, dtype=torch.float32)
    X, Y = torch.meshgrid(x, y, indexing='ij')

    # Create Diffeo
    x_pert = torch.linspace(0, 1, x_length, dtype=torch.float32)
    y_pert = torch.linspace(0, 1, y_length, dtype=torch.float32)

    n = torch.arange(1, A_nm.shape[0] + 1, dtype=torch.float32)
    x_basis = torch.sin(torch.pi * torch.outer(n, x_pert)).T
    y_basis = torch.sin(torch.pi * torch.outer(n, y_pert))

    perturbation = torch.matmul(x_basis, torch.matmul(A_nm, y_basis))

    x_map = X + perturbation
    y_map = Y + perturbation

    return x_map, y_map

def apply_transformation(image_tensor,
                         A_nm: torch.Tensor,
                         interpolation_type='bilinear'):
    """
    Wrapper of `sin_distortion`. Gets torch.tensor and returns the distorted
    torch.tensor according to A_nm.

    Args:
        image_tensor (torch.Tensor): Inputted image.
        A_nm (torch.Tensor): Characterizes diffeo according to `sin_distortion`.
        interpolation_type (str): Interpolation method ('bilinear' or 'nearest').

    Returns:
        image_tensor_deformed (torch.Tensor): Diffeo applied to `image_tensor`.
    """
    # Create deformation map
    x_length, y_length = image_tensor.shape[1:3]
    x_map, y_map  = sin_distortion(x_length, y_length, A_nm)

    return apply_flowgrid(image_tensor, x_map, y_map, interpolation_type=interpolation_type)


def apply_flowgrid(image_tensor, x_map, y_map, interpolation_type='bilinear'):
    # Stack and unsqueeze to form grid
    grid = torch.stack((y_map, x_map), dim=-1).unsqueeze(0).to(image_tensor.device)

    # Apply grid sample
    image_tensor_deformed = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0),
                                                            grid,
                                                            mode=interpolation_type,
                                                            align_corners=True)

    return image_tensor_deformed.squeeze(0)

In [4]:
def diffeo_dataset(tensor):
    A_nm = torch.tensor([[0.0, 0.14],
                         [-0.02, 0.01]])
    return apply_transformation(tensor, A_nm)

# Get ImageNet

In [6]:
import torch as t

root = '/imagenet/'
total_images = 1000
pct_train = 0.8


num_train = int(total_images * pct_train)
num_val = total_images - num_train

# Preprocess the image w/o diffeo
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to grayscale
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])

# Preprocess the image w/ diffeo
preprocess_diffeo = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to grayscale
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
    diffeo_dataset,
])



dataset_train_images_og = torchvision.datasets.ImageNet(root=root, 
                                                        split='train', 
                                                        transform=preprocess, )
dataset_train_images_diffeo = torchvision.datasets.ImageNet(root=root, 
                                                            split='train', 
                                                            transform=preprocess_diffeo,)
dataset_val_images_og = torchvision.datasets.ImageNet(root=root, 
                                                      split='val', 
                                                      transform=preprocess,)
dataset_val_images_diffeo = torchvision.datasets.ImageNet(root=root, 
                                                     split='val', 
                                                     transform=preprocess_diffeo)

dataset_train_images_og = Subset(dataset_train_images_og, indices=range(num_train))
dataset_train_images_diffeo = Subset(dataset_train_images_diffeo, indices=range(num_train))
dataset_val_images_og = Subset(dataset_val_images_og, indices=range(num_val))
dataset_val_images_diffeo = Subset(dataset_val_images_diffeo, indices=range(num_val))

# Image Classification Model

In [7]:
# Load a pre-trained ViT model
model = torchvision.models.resnet18(pretrained=True)

# Move the model to GPU if available
model = model.to(device)



In [16]:
tensor_train_images_og     = torch.cat([dataset_train_images_og[i][0].unsqueeze(0).to(device) for i in tqdm(range(len(dataset_train_images_og)))])
tensor_train_images_diffeo = torch.cat([dataset_train_images_diffeo[i][0].unsqueeze(0).to(device) for i in tqdm(range(len(dataset_train_images_og)))])
tensor_val_images_og       = torch.cat([dataset_val_images_og[i][0].unsqueeze(0).to(device) for i in tqdm(range(len(dataset_val_images_og)))])
tensor_val_images_diffeo   = torch.cat([dataset_val_images_diffeo[i][0].unsqueeze(0).to(device) for i in tqdm(range(len(dataset_val_images_og)))])

100%|██████████| 800/800 [00:04<00:00, 179.80it/s]
100%|██████████| 800/800 [00:05<00:00, 149.98it/s]
100%|██████████| 200/200 [00:01<00:00, 141.80it/s]
100%|██████████| 200/200 [00:01<00:00, 121.68it/s]


In [17]:
def get_activation(model, input, layer_index: list):

  activation = {}
  def getActivation(name):
      # the hook signature
      def hook(model, input, output):
          activation[name] = output.detach()
      return hook

  handles = []
  def retrieve_layer_activation(model, input, layer_index):
    if len(input) == 3: input = input[None, :, :, :]

    layers = list(model.children())
    layers_flat = flatten(layers)

    for index in layer_index:
      handles.append(layers_flat[index - 1].register_forward_hook(getActivation(str(index))))

    with t.no_grad(): model(input)
    for handle in handles: handle.remove()

    return

  def flatten(array):
      result = []
      for element in array:
          if hasattr(element, "__iter__"):
              result.extend(flatten(element))
          else:
              result.append(element)
      return result

  retrieve_layer_activation(model, input, layer_index)
  return activation

In [23]:
layer_id = 13

activation_train_og = get_activation(model, tensor_train_images_og, [layer_id])[f'{layer_id}'].flatten(start_dim=1).to('cpu')
activation_train_diffeo = get_activation(model, tensor_train_images_diffeo, [layer_id])[f'{layer_id}'].flatten(start_dim=1).to('cpu')
activation_val_og = get_activation(model, tensor_val_images_og, [layer_id])[f'{layer_id}'].flatten(start_dim=1).to('cpu')
activation_val_diffeo = get_activation(model, tensor_val_images_diffeo, [layer_id])[f'{layer_id}'].flatten(start_dim=1).to('cpu')

# Dataset to attempt representation finding
train_dataset = TensorDataset(activation_train_og, activation_train_diffeo)
val_dataset = TensorDataset(activation_val_og, activation_val_diffeo)

print('Num Parameters', activation_train_og.shape[-1])
print(activation_train_og.shape)
print(activation_train_diffeo.shape)

Num Parameters 512
torch.Size([800, 512])
torch.Size([800, 512])


# Regression

In [40]:
class Decoder(nn.Module):
  def __init__(self):
    super(Decoder, self).__init__()

    # Model
    self.encoder = nn.Sequential(
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
    )

  def forward(self, x):
    x = self.encoder(x)
    return x

### Run Experiment

In [41]:
model_ginv = Decoder()
model_ginv.to(device)

trainable_params = sum(p.numel() for p in model_ginv.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params}")

Trainable parameters: 2036224


In [42]:
# -------------------
# 1. Initialize Weights & Biases
# -------------------
wandb.init(
    project="diffeo",
    name=f"MLP",
    config={
        "epochs": 1500,
        "batch_size": 1,
        "lr": 1e-5,
        "weight_decay":0.01,
        "dataset": "ImageNet1k",
    },
)
config = wandb.config

In [43]:
# Create DataLoaders for each split
train_loader = DataLoader(train_dataset,
                          batch_size=config.batch_size,
                          shuffle=True,
                          num_workers=1)

val_loader = DataLoader(val_dataset,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=1)

In [44]:
# -------------------
# Define Model, Optimizer, Loss
# -------------------

optimizer = optim.AdamW(model_ginv.parameters(), 
                        lr=config.lr, 
                        weight_decay=config.weight_decay,
                      )
criterion = nn.MSELoss()

# Baseline error on the first sample
print('Baseline', criterion(*train_dataset[0]))

# -------------------
# Training Loop
# -------------------
for epoch in range(config.epochs):
    # ---- Training ----
    model_ginv.train()
    total_train_loss = 0.0
    for batch_X, batch_Y in train_loader:
        # Move to GPU if available
        batch_X = batch_X.to(device)
        batch_Y = batch_Y.to(device)

        # Forward pass
        predictions = model_ginv(batch_X)
        loss = criterion(predictions, batch_Y)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    # Average training loss over the epoch
    avg_train_loss = total_train_loss / len(train_loader)

    # ---- Validation ----
    model_ginv.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch_X, batch_Y in val_loader:
            batch_X = batch_X.to(device)
            batch_Y = batch_Y.to(device)

            predictions = model_ginv(batch_X)
            loss = criterion(predictions, batch_Y)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    # ---- Logging ----
    # Log both training loss and validation loss at the end of each epoch
    wandb.log({
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "epoch": epoch
    })

    print(f"Epoch [{epoch+1}/{config.epochs}] | "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f}")

Baseline tensor(0.2336)
Epoch [1/1500] | Train Loss: 1.2882 | Val Loss: 0.5017
Epoch [2/1500] | Train Loss: 0.4628 | Val Loss: 0.4952
Epoch [3/1500] | Train Loss: 0.4606 | Val Loss: 0.5062
Epoch [4/1500] | Train Loss: 0.4602 | Val Loss: 0.4968
Epoch [5/1500] | Train Loss: 0.4599 | Val Loss: 0.4941
Epoch [6/1500] | Train Loss: 0.4582 | Val Loss: 0.4974
Epoch [7/1500] | Train Loss: 0.4577 | Val Loss: 0.4917
Epoch [8/1500] | Train Loss: 0.4570 | Val Loss: 0.4940
Epoch [9/1500] | Train Loss: 0.4567 | Val Loss: 0.4904
Epoch [10/1500] | Train Loss: 0.4557 | Val Loss: 0.4944
Epoch [11/1500] | Train Loss: 0.4563 | Val Loss: 0.4899
Epoch [12/1500] | Train Loss: 0.4555 | Val Loss: 0.4911
Epoch [13/1500] | Train Loss: 0.4553 | Val Loss: 0.4908
Epoch [14/1500] | Train Loss: 0.4550 | Val Loss: 0.4954
Epoch [15/1500] | Train Loss: 0.4546 | Val Loss: 0.4910
Epoch [16/1500] | Train Loss: 0.4541 | Val Loss: 0.4892


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14c73458c180>
Traceback (most recent call last):
  File "/ext3/miniforge3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/ext3/miniforge3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1568, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/ext3/miniforge3/lib/python3.12/multiprocessing/process.py", line 142, in join
    def join(self, timeout=None):

KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
train_dataset, val_dataset
pic_id = 1

x, y = train_dataset[pic_id][0], train_dataset[pic_id][1]
x_plot = x.flatten().cpu()
y_plot = y.flatten().cpu()
prediction = model_ginv(x.unsqueeze(0))
prediction_plot = prediction.flatten().cpu().detach().numpy()

plt.plot(x_plot, label='input', alpha=0.3)
plt.plot(y_plot, label='target', alpha=0.3)
plt.plot(x_plot - y_plot, label='baseline')
plt.plot(y_plot - prediction_plot, label='residual', color='black')
plt.legend()