In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, jit

from sklearn import linear_model
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

import torch

from tqdm import tqdm
import matplotlib.pyplot as plt

# Hogg's Idea

We want to learn $h^{-1}$ such that
$$
h^{-1}(\theta) \star g_{naive}^{-1}(\vec r) \star \vec N_i(g(\theta, \vec r) \cdot I(\vec r)) = \vec N_i(I(\vec r))
$$
where $\vec N_i \in \mathbb R^c \times \mathbb R^{l \times l}$, it's a vector over channels, containing images of size $l$. Note, Hogg believes $\partial_{\vec r} h^{-1} = 0$, $h^{-1}$, can be expressed independent of coordinates.

Hogg's proposition is the learned $h^{-1} $ is a real-valued matrix which acts on the level of the channels
$$
h^{-1} = \begin{pmatrix}
h_{11}^{-1} & h_{21}^{-1} & ...\\
h_{21}^{-1} & h_{22}^{-1} & ...\\
\vdots & \vdots & \ddots
\end{pmatrix} \in \mathbb M^{c \times c}(\mathbb R)
$$

## Methodology
Solve
$$
Ax = y
$$
such that $x := g_{naive}^{-1}(\vec r) \star \vec N_i(g(\theta, \vec r) \cdot I(\vec r))$ and $y := \vec N_i(I(\vec r))$. So the matrix multiplication is occurring on the level of the channels.

## Download Data
Get layers 14-19. All same resolution

In [10]:
def LOAD_NgI(target_pic, numbers):
    data_dir = '/vast/xj2173/diffeo/data/all_cnn_layers/'
    data_name = [data_dir + f'15-100-4-4-3-224-224_image-{target_pic}_activation_layer-{i}.pt' for i in numbers]
    
    
    # data[0] is the 0th picture, data[1] is the 1st picture, etc..
    data = [torch.load(file_name, map_location='cpu') for file_name in tqdm(data_name)]
    data = torch.stack(data, dim=0)

    return data

def LOAD_gInvNgI(target_pic, numbers, data):
    mode = 'nearest'
    res_of_layer = data.shape[-1]
    
    data_dir = '/vast/xj2173/diffeo/data/all_cnn_layers/'
    file_name = data_dir + '15-100-4-4-3-224-224_inv_grid_sample.pt'
    inv_diffeos_maps = torch.load(file_name, map_location='cpu')
    inv_diffeos_maps = torch.stack(inv_diffeos_maps)
    # Downsample transformation to resolution of convolutional layer
    
    inv_diffeos_maps = inv_diffeos_maps.reshape(15 * 100, 224, 224, 2)
    inv_diffeos_maps = inv_diffeos_maps.permute(0, 3, 1, 2)
    inv_diffeos_maps = torch.nn.functional.interpolate(inv_diffeos_maps, size=(res_of_layer, res_of_layer), mode=mode)
    inv_diffeos_maps = inv_diffeos_maps.permute(0, 2, 3, 1)

    data_inv = [torch.nn.functional.grid_sample(pic_data, inv_diffeos_maps, mode = mode) for pic_data in tqdm(data)]
    data_inv = torch.stack(data_inv, dim=0)

    return data_inv

def LOAD_NI(target_pic, numbers):
    data_dir = '/vast/xj2173/diffeo/data/reference/'
    data_name = [data_dir + f'val_image-{target_pic}_activation_layer-{i}.pt' for i in numbers]
    
    # ref_data[0] is the 0th picture, ref_data[1] is the 1st picture, etc..
    ref_data = [torch.load(file_name, map_location='cpu').squeeze(0) for file_name in tqdm(data_name)]
    ref_data = torch.stack(ref_data)
    return ref_data

def LOAD_all(target_pic, numbers):
    data = LOAD_NgI(target_pic, numbers)
    data_inv = LOAD_gInvNgI(target_pic, numbers, data)
    ref_data = LOAD_NI(target_pic, numbers)

    return data, data_inv, ref_data
    

### `LOAD_all` documentation
1. `data` is $N(g * i)$
2. `data_inv` is $g^{-1}_{naive} \star N(g * I)$
3. `ref_data` is $N_i(I)$

## Attempt 2: Prevent Overfitting
$$
\forall i,j:  \ \ \  \min_{[h^{-1}]_{ij}} \mathcal L = \min_{[h^{-1}]_{ij}} \text{MSE} \left ( h^{-1} g_{naive}^{-1} N_\alpha(g \cdot I)  ,   N_\alpha(I)\right)_{I, \vec r}
$$
where $\text{MSE}(\cdot, *)_{I, \vec r}$ is the quantity average over set of images and pixels

In [11]:
target_pics = [f'{i:04}' for i in range(0,100)]
layer_numbers = [f"{i:02}" for i in range(14,15)]

In [12]:
data = []
data_inv = []
ref_data = []
for pic in tqdm(target_pics):
    data_i, data_inv_i, ref_data_i = LOAD_all('0001', layer_numbers)
    data_i = data_i.squeeze(0)
    data_inv_i = data_inv_i.squeeze(0)
    ref_data_i = ref_data_i.squeeze(0)
    data.append(data_i)
    data_inv.append(data_inv_i)
    ref_data.append(ref_data_i)
data = torch.stack(data)
data_inv = torch.stack(data_inv)
ref_data = torch.stack(ref_data)
    
data.shape, data_inv.shape, ref_data.shape

  0%|          | 0/100 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 11.80it/s]

100%|██████████| 1/1 [00:00<00:00, 26.91it/s]

100%|██████████| 1/1 [00:00<00:00, 1013.61it/s]
  1%|          | 1/100 [00:00<01:07,  1.48it/s]
100%|██████████| 1/1 [00:00<00:00, 12.04it/s]

100%|██████████| 1/1 [00:00<00:00, 27.15it/s]

100%|██████████| 1/1 [00:00<00:00, 1292.15it/s]
  2%|▏         | 2/100 [00:01<01:04,  1.51it/s]
100%|██████████| 1/1 [00:00<00:00, 12.12it/s]

100%|██████████| 1/1 [00:00<00:00, 26.83it/s]

100%|██████████| 1/1 [00:00<00:00, 1116.10it/s]
  3%|▎         | 3/100 [00:01<01:03,  1.52it/s]
100%|██████████| 1/1 [00:00<00:00, 12.11it/s]

100%|██████████| 1/1 [00:00<00:00, 27.17it/s]

100%|██████████| 1/1 [00:00<00:00, 1082.40it/s]
  4%|▍         | 4/100 [00:02<01:02,  1.53it/s]
100%|██████████| 1/1 [00:00<00:00, 12.13it/s]

100%|██████████| 1/1 [00:00<00:00, 26.56it/s]

100%|██████████| 1/1 [00:00<00:00, 1136.05it/s]
  5%|▌         | 5/100 [00:03<01:01,  1.53it/s]
100%|█████

(torch.Size([100, 1500, 128, 14, 14]),
 torch.Size([100, 1500, 128, 14, 14]),
 torch.Size([100, 128, 14, 14]))

In [13]:
diffeo_idx = 0

feature = data_inv[:, diffeo_idx, :, :, :]
label = ref_data[:, :, :, :]
feature.shape, label.shape

(torch.Size([100, 128, 14, 14]), torch.Size([100, 128, 14, 14]))

In [None]:
import torch
import torch.optim as optim

# Define the loss function
def loss(A, features, labels):
    features
    predictions = torch.einsum('ab,iaxy->ibxy', A, features)
    return torch.mean((predictions - labels) ** 2)


# Initialize A with normal distribution
features_shape = (feature.shape[1], feature.shape[1])  # Replace with the actual shape of your features
A = torch.randn(features_shape, requires_grad=True)

# Hyperparameters
learning_rate = 0.0001
num_iterations = 100000
threshold = 1e-6  # Define a threshold for change in loss
patience = 100  # Define the number of iterations to wait before stopping if no improvement
counter = 0  # Initialize a counter to track the number of iterations without significant change
previous_loss = float('inf')  # Initialize the previous loss to a high value


# Define the optimizer
optimizer = optim.Adam([A], lr=learning_rate)

for i in range(num_iterations):
    optimizer.zero_grad()  # Zero the gradients before each iteration
    current_loss = loss(A, feature, label)
    current_loss.backward()  # Backpropagate to compute gradients
    optimizer.step()  # Update parameters
    
    if i % 500 == 0:
        print(f"Iteration {i}: Loss = {current_loss.item()}")
    
    # Check for early stopping
    if abs(previous_loss - current_loss.item()) < threshold:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at iteration {i}: Loss = {current_loss.item()}")
            break
    else:
        counter = 0  # Reset counter if there is significant change

    previous_loss = current_loss.item()  # Update previous loss

Iteration 0: Loss = 1816.99853515625
Iteration 500: Loss = 1539.9759521484375
Iteration 1000: Loss = 1317.5291748046875
Iteration 1500: Loss = 1134.6761474609375
Iteration 2000: Loss = 982.0088500976562
Iteration 2500: Loss = 853.0062255859375
Iteration 3000: Loss = 742.9746704101562
Iteration 3500: Loss = 648.4883422851562
Iteration 4000: Loss = 566.9557495117188
Iteration 4500: Loss = 496.32366943359375
Iteration 5000: Loss = 434.9242858886719
Iteration 5500: Loss = 381.3919982910156
Iteration 6000: Loss = 334.5985107421875
Iteration 6500: Loss = 293.6014709472656
Iteration 7000: Loss = 257.6108703613281
Iteration 7500: Loss = 225.9676055908203
Iteration 8000: Loss = 198.12332153320312
Iteration 8500: Loss = 173.61624145507812
Iteration 9000: Loss = 152.05039978027344
Iteration 9500: Loss = 133.0829315185547
Iteration 10000: Loss = 116.41386413574219
Iteration 10500: Loss = 101.77722930908203
Iteration 11000: Loss = 88.93477630615234
Iteration 11500: Loss = 77.67372131347656
Iteratio

In [None]:
def plot(tensor, start_index=0, end_index=12, title=None, vmin=None, vmax=None):
    num_plots = end_index - start_index
    num_cols = 4  # Number of columns in the plot grid
    num_rows = (num_plots + num_cols - 1) // num_cols  # Calculate the number of rows needed
    
    # Create the figure and subplots
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, num_rows * 5))  # Increase figure size for better fit

    # Adjust the layout to give more space
    fig.subplots_adjust(right=0.85, left=0.05, top=0.92, bottom=0.1, wspace=0.1, hspace=0.1)

    # Create a dummy Axes to hold the common colorbar
    cbar_ax = fig.add_axes([0.9, 0.15, 0.03, 0.7])  # Adjust these values to fit the layout

    for i, ax in enumerate(axes.flatten()):
        if i < num_plots:
            heatmap = tensor[start_index + i].detach().numpy()  # Assuming tensor is a PyTorch tensor
            im = ax.imshow(heatmap, cmap='viridis', vmin=vmin, vmax=vmax)
            ax.set_title(f'Channel {start_index + i + 1}')
            ax.axis('off')
        else:
            ax.axis('off')  # Turn off axis for unused subplot

    # Create a colorbar for the whole figure
    fig.colorbar(im, cax=cbar_ax)

    # Set the overall title, if provided
    if title:
        fig.suptitle(title, size=20, y=0.98)  # Adjust the vertical position of the title

    plt.show()

In [None]:
# Selct Image
image_idx = 0
NOTHING_feature = data[image_idx, diffeo_idx, :, :, :]
feature = data_inv[image_idx, diffeo_idx, :, :, :]
label = ref_data[image_idx, :, :, :]

approx_identity = torch.einsum('ab,axy->bxy', A, feature)

plot(approx_identity, 
     title='$h^{-1} g^{-1}_{naive} N_i(g \cdot I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(feature - label, 
     title = '$g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(approx_identity - label, 
     title = '$h^{-1} g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
torch.mean(torch.abs(feature))

In [None]:
error_per_pixel = lambda y_hat, y: torch.sum(torch.abs(y_hat - y)) / torch.numel(y)

print(f'For Image {image_idx}')
print(f'Avg Error Per Pixel, NO CORRECTION: {float(error_per_pixel(NOTHING_feature, label)):.7f}')
print(f'Avg Error Per Pixel, using g^-1: {float(error_per_pixel(feature, label)):.7f}')
print(f'Avg Error Per Pixel, using g^-1 and h^-1: {float(error_per_pixel(approx_identity, label)):.7f}')

## Attempt 1: Matrix Inverse, Independent of Position

In this case, entries of $h^{-1}$ have the same value at every pixel position

Our loss function $\mathcal L$ is:
$$
\forall i,j; \ \  \  \min_{[h^{-1}]_{ij}} \mathcal L = \min_{[h^{-1}]_{ij}} \text{MSE} \left ( h^{-1} g_{naive}^{-1} N_\alpha(g \cdot I)  ,   N_\alpha(I)\right)_{\vec r}
$$

$\text{MSE}(\cdot, *)_{\vec r}$ means the quantities are averaged over all pixel positions

In [None]:
layer_idx = 5 #0 := layer 14, 5 := layer 19
diffeo_idx = 0

feature = data_inv[layer_idx, diffeo_idx, :, :, :]
label = ref_data[layer_idx, :, :, :]
feature.shape, label.shape

In [None]:
import torch
import torch.optim as optim

# Define the loss function
def loss(A, features, labels):
    predictions = torch.einsum('ab,axy->bxy', A, features)
    return torch.mean((predictions - labels) ** 2)


# Initialize A with normal distribution
features_shape = (len(feature), len(feature))  # Replace with the actual shape of your features
A = torch.randn(features_shape, requires_grad=True)

# Hyperparameters
learning_rate = 0.0001
num_iterations = 100000
threshold = 1e-6  # Define a threshold for change in loss
patience = 100  # Define the number of iterations to wait before stopping if no improvement
counter = 0  # Initialize a counter to track the number of iterations without significant change
previous_loss = float('inf')  # Initialize the previous loss to a high value


# Define the optimizer
optimizer = optim.Adam([A], lr=learning_rate)

for i in range(num_iterations):
    optimizer.zero_grad()  # Zero the gradients before each iteration
    current_loss = loss(A, feature, label)
    current_loss.backward()  # Backpropagate to compute gradients
    optimizer.step()  # Update parameters
    
    if i % 500 == 0:
        print(f"Iteration {i}: Loss = {current_loss.item()}")
    
    # Check for early stopping
    if abs(previous_loss - current_loss.item()) < threshold:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at iteration {i}: Loss = {current_loss.item()}")
            break
    else:
        counter = 0  # Reset counter if there is significant change

    previous_loss = current_loss.item()  # Update previous loss

In [None]:
def plot(tensor, start_index=0, end_index=12, title=None, vmin=None, vmax=None):
    num_plots = end_index - start_index
    num_cols = 4  # Number of columns in the plot grid
    num_rows = (num_plots + num_cols - 1) // num_cols  # Calculate the number of rows needed
    
    # Create the figure and subplots
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, num_rows * 5))  # Increase figure size for better fit

    # Adjust the layout to give more space
    fig.subplots_adjust(right=0.85, left=0.05, top=0.92, bottom=0.1, wspace=0.1, hspace=0.1)

    # Create a dummy Axes to hold the common colorbar
    cbar_ax = fig.add_axes([0.9, 0.15, 0.03, 0.7])  # Adjust these values to fit the layout

    for i, ax in enumerate(axes.flatten()):
        if i < num_plots:
            heatmap = tensor[start_index + i].detach().numpy()  # Assuming tensor is a PyTorch tensor
            im = ax.imshow(heatmap, cmap='viridis', vmin=vmin, vmax=vmax)
            ax.set_title(f'Channel {start_index + i + 1}')
            ax.axis('off')
        else:
            ax.axis('off')  # Turn off axis for unused subplot

    # Create a colorbar for the whole figure
    fig.colorbar(im, cax=cbar_ax)

    # Set the overall title, if provided
    if title:
        fig.suptitle(title, size=20, y=0.98)  # Adjust the vertical position of the title

    plt.show()


In [None]:
# Selct Image
image_idx = 1
feature = data_inv[image_idx, diffeo_idx, :, :, :]
label = ref_data[image_idx, :, :, :]

approx_identity = torch.einsum('ab,axy->bxy', A, feature)

plot(approx_identity, 
     title='$h^{-1} g^{-1}_{naive} N_i(g \cdot I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(feature - label, 
     title = '$g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(approx_identity - label, 
     title = '$h^{-1} g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

## Results

### On Trained Image

In [None]:
# Selct Image
layer_idx = layer_idx
feature = data_inv[layer_idx, diffeo_idx, :, :, :]
label = ref_data[layer_idx, :, :, :]

approx_identity = torch.einsum('ab,axy->bxy', A, feature)

plot(approx_identity, 
     title='$h^{-1} g^{-1}_{naive} N_i(g \cdot I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(feature - label, 
     title = '$g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(approx_identity - label, 
     title = '$h^{-1} g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
error_per_pixel = lambda y_hat, y: torch.sum(torch.abs(y_hat - y)) / torch.numel(y)

print(f'For Layer {layer_idx + 14}')
print(f'Avg Error Per Pixel, NOT using h^-1: {float(error_per_pixel(feature, label)):.7f}')
print(f'Avg Error Per Pixel, using h^-1: {float(error_per_pixel(approx_identity, label)):.7f}')

### Do Results Generalize Beyond Image

In [None]:
next_pic = f"{int(target_pic) + 10:04}"
data, data_inv, ref_data = LOAD_all(next_pic, layer_numbers)

In [None]:
# Selct Image
layer_idx = layer_idx
feature = data_inv[layer_idx, diffeo_idx, :, :, :]
label = ref_data[layer_idx, :, :, :]

approx_identity = torch.einsum('ab,axy->bxy', A, feature)

plot(approx_identity, 
     title='$h^{-1} g^{-1}_{naive} N_i(g \cdot I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(feature - label, 
     title = '$g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
plot(approx_identity - label, 
     title = '$h^{-1} g_{naive}^{-1} N_i(g \cdot I) - N(I)$',
     vmin=-12,
     vmax=12)

In [None]:
error_per_pixel = lambda y_hat, y: torch.sum(torch.abs(y_hat - y)) / torch.numel(y)

print(f'For Layer {layer_idx + 14}')
print(f'Avg Error Per Pixel, NOT using h^-1: {float(error_per_pixel(feature, label)):.7f}')
print(f'Avg Error Per Pixel, using h^-1: {float(error_per_pixel(approx_identity, label)):.7f}')