In [1]:
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import GradScaler, autocast

import torch.nn.functional as F
from torch.nn import Conv2d, BatchNorm2d, ReLU, Sigmoid


if torch.cuda.is_available():
    print("CUDA is available. PyTorch is running on the GPU.")
    print("Device name:", torch.cuda.get_device_name(0))
else:
    print("CUDA is not available. PyTorch is running on the CPU.")

CUDA is available. PyTorch is running on the GPU.
Device name: Quadro RTX 8000


# Import Data

In [3]:
data_dir = '/vast/xj2173/diffeo/data/temp_inv_data/'
numbers = [f"{i:02}" for i in range(40)]
data_name = [data_dir + f'15-100-4-4-3-224-224_image-00{i}_activation_layer-09.pt' for i in numbers]

In [4]:
# data[0] is the 0th picture, data[1] is the 1st picture, etc..
data = [torch.load(file_name, map_location='cpu') for file_name in tqdm(data_name)]
data = torch.stack(data, dim=0)
data.shape

100%|██████████| 40/40 [01:09<00:00,  1.75s/it]


torch.Size([40, 1500, 48, 56, 56])

- 40 different images
- 1500 different diffeos
    - 15 diffeo strengths
    - 100 individual diffeos per diffeo strength
- 48 channels
- 56 x 56 image

In [14]:
target_channel = 1
num_of_features_cutoff = 1500

features = data[:, 1:, target_channel, :, :]
pics, diffeo, size, _ = features.shape
features = features.permute(1, 0, 2, 3)
features = features.reshape(pics * diffeo, size, size)
features = features[:num_of_features_cutoff, :, :]

class NaiveInverseDiffeo(Dataset):
    def __init__(self, transform=None):
        self.label = data[target_picture, 0, target_channel, :, :]
        # self.num_pics = pics
        
    def __len__(self):
        return int(features.shape[0])

    def __getitem__(self, idx):
        # num_pics = self.num_pics
        label = self.label
        feature = features[idx]
        
        return feature, label


dataset = NaiveInverseDiffeo()

In [15]:
features[0].shape

torch.Size([56, 56])

# Objective

We've experimentally confirmed:
$$
g_{naive}^{-1} \star N_i (g \cdot I) = N_i(I) + \eta_i
$$
Where $g_{naive}^{-1}$ is just the inverse of $g$ using the $\cdot$ algebra. So basically, it looks like the original but w/ some noise $\eta_i$. 

Now the objective is to see if we can learn $\eta_i$. Or rewritten in another way
$$
h^{-1} \star g_{naive}^{-1} \star  N_i (g \cdot I) = N_i(I)
$$
We want to learn $h^{-1}$, s.t. the 2nd equation is true


# Learning $h^{-1}$

- 40 different images
- 15 diffeo strengths
- 100 individual diffeos per diffeo strength
- 48 channels
- 56 x 56 image

### Split Data

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [17]:
split = 0.8
batch_size = 25

train_size = int(split * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(dataset=val_dataset,   batch_size=batch_size, shuffle=True)

### Initalize Model

In [65]:
class Model0(nn.Module):
    def __init__(self, batch_size):
        super(Model0, self).__init__()
        self.batch_size = batch_size  # Store batch_size if needed for some reason
        
        self.conv1 = nn.Conv2d(self.batch_size, 16, kernel_size=3, stride=1, padding=1)  
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(1568, 100)  # Adjust the input size according to the output size after pooling
        self.fc2 = nn.Linear(100, 3136)

    def forward(self, x):
        # Convoultion Layer
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)  # Apply 2x2 max pooling
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)  # Apply 2x2 max pooling

        # MLP
        x = torch.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.reshape(1, 56, 56)
        return x


In [None]:
learning_rate = 0.001
num_epochs = 10000

model = Model0(batch_size)
model = model.to(device)

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum = 0.4)
scaler = GradScaler()

print("Starting training...")
for epoch in tqdm(range(num_epochs)):
    model.train()  # Set the model to training mode
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Move inputs and targets to the GPU
        inputs, targets = inputs.to(device), targets.to(device)
    
        with autocast():
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        optimizer.zero_grad()  # Zero the gradient buffers
        scaler.scale(loss).backward()  # Backpropagation
        scaler.step(optimizer)  # Update weights
        scaler.update()

    if (epoch + 1) % 25 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')


Starting training...


  0%|          | 26/10000 [00:04<28:05,  5.92it/s]

Epoch [25/10000], Loss: 27.0307


  1%|          | 51/10000 [00:08<29:49,  5.56it/s]

Epoch [50/10000], Loss: 25.3266


  1%|          | 76/10000 [00:13<27:18,  6.06it/s]

Epoch [75/10000], Loss: 24.2716


  1%|          | 101/10000 [00:17<29:43,  5.55it/s]

Epoch [100/10000], Loss: 24.2808


  1%|▏         | 126/10000 [00:21<29:26,  5.59it/s]

Epoch [125/10000], Loss: 25.2602


  2%|▏         | 151/10000 [00:26<28:47,  5.70it/s]

Epoch [150/10000], Loss: 23.5202


  2%|▏         | 176/10000 [00:30<29:52,  5.48it/s]

Epoch [175/10000], Loss: 25.1893


  2%|▏         | 201/10000 [00:34<27:26,  5.95it/s]

Epoch [200/10000], Loss: 24.1025


  2%|▏         | 226/10000 [00:39<26:52,  6.06it/s]

Epoch [225/10000], Loss: 23.0251


  3%|▎         | 251/10000 [00:43<26:48,  6.06it/s]

Epoch [250/10000], Loss: 24.8919


  3%|▎         | 276/10000 [00:47<28:44,  5.64it/s]

Epoch [275/10000], Loss: 25.2920


  3%|▎         | 301/10000 [00:52<27:35,  5.86it/s]

Epoch [300/10000], Loss: 25.2624


  3%|▎         | 326/10000 [00:56<28:04,  5.74it/s]

Epoch [325/10000], Loss: 24.0748


  4%|▎         | 351/10000 [01:01<28:43,  5.60it/s]

Epoch [350/10000], Loss: 24.6176


  4%|▍         | 376/10000 [01:05<28:32,  5.62it/s]

Epoch [375/10000], Loss: 24.4556


  4%|▍         | 401/10000 [01:10<27:54,  5.73it/s]

Epoch [400/10000], Loss: 23.2515


  4%|▍         | 426/10000 [01:14<28:50,  5.53it/s]

Epoch [425/10000], Loss: 23.6945


  5%|▍         | 451/10000 [01:19<28:28,  5.59it/s]

Epoch [450/10000], Loss: 25.6904


  5%|▍         | 476/10000 [01:23<26:27,  6.00it/s]

Epoch [475/10000], Loss: 24.2744


  5%|▌         | 501/10000 [01:27<28:10,  5.62it/s]

Epoch [500/10000], Loss: 24.6097


  5%|▌         | 526/10000 [01:32<27:34,  5.72it/s]

Epoch [525/10000], Loss: 24.0893


  6%|▌         | 551/10000 [01:36<27:44,  5.68it/s]

Epoch [550/10000], Loss: 26.0782


  6%|▌         | 576/10000 [01:40<27:46,  5.65it/s]

Epoch [575/10000], Loss: 23.7773


  6%|▌         | 601/10000 [01:45<26:20,  5.95it/s]

Epoch [600/10000], Loss: 24.0634


  6%|▋         | 626/10000 [01:49<27:01,  5.78it/s]

Epoch [625/10000], Loss: 26.3283


  7%|▋         | 651/10000 [01:53<27:08,  5.74it/s]

Epoch [650/10000], Loss: 23.7246


  7%|▋         | 676/10000 [01:58<27:13,  5.71it/s]

Epoch [675/10000], Loss: 24.2288


  7%|▋         | 701/10000 [02:02<27:06,  5.72it/s]

Epoch [700/10000], Loss: 25.6198


  7%|▋         | 726/10000 [02:06<26:30,  5.83it/s]

Epoch [725/10000], Loss: 24.4263


  8%|▊         | 751/10000 [02:11<26:54,  5.73it/s]

Epoch [750/10000], Loss: 22.5718


  8%|▊         | 776/10000 [02:15<25:41,  5.98it/s]

Epoch [775/10000], Loss: 25.2752


  8%|▊         | 801/10000 [02:19<26:56,  5.69it/s]

Epoch [800/10000], Loss: 24.0190


  8%|▊         | 826/10000 [02:24<26:44,  5.72it/s]

Epoch [825/10000], Loss: 24.2374


  9%|▊         | 851/10000 [02:28<28:08,  5.42it/s]

Epoch [850/10000], Loss: 26.5488


  9%|▉         | 876/10000 [02:33<27:38,  5.50it/s]

Epoch [875/10000], Loss: 25.0820


  9%|▉         | 901/10000 [02:37<25:42,  5.90it/s]

Epoch [900/10000], Loss: 24.5311


  9%|▉         | 926/10000 [02:41<26:52,  5.63it/s]

Epoch [925/10000], Loss: 23.3288


 10%|▉         | 951/10000 [02:46<25:57,  5.81it/s]

Epoch [950/10000], Loss: 26.2496


 10%|▉         | 976/10000 [02:50<24:47,  6.07it/s]

Epoch [975/10000], Loss: 24.5677


 10%|█         | 1001/10000 [02:54<27:08,  5.52it/s]

Epoch [1000/10000], Loss: 23.6407


 10%|█         | 1026/10000 [02:59<27:17,  5.48it/s]

Epoch [1025/10000], Loss: 24.3482


 11%|█         | 1051/10000 [03:03<26:32,  5.62it/s]

Epoch [1050/10000], Loss: 22.7020


 11%|█         | 1076/10000 [03:08<26:29,  5.61it/s]

Epoch [1075/10000], Loss: 23.5180


 11%|█         | 1101/10000 [03:12<26:32,  5.59it/s]

Epoch [1100/10000], Loss: 22.3236


 11%|█▏        | 1126/10000 [03:17<26:21,  5.61it/s]

Epoch [1125/10000], Loss: 24.0832


 12%|█▏        | 1151/10000 [03:21<24:26,  6.03it/s]

Epoch [1150/10000], Loss: 23.8781


 12%|█▏        | 1176/10000 [03:25<25:54,  5.68it/s]

Epoch [1175/10000], Loss: 22.9316


 12%|█▏        | 1201/10000 [03:29<26:29,  5.53it/s]

Epoch [1200/10000], Loss: 23.9436


 12%|█▏        | 1226/10000 [03:34<25:02,  5.84it/s]

Epoch [1225/10000], Loss: 24.3404


 13%|█▎        | 1251/10000 [03:38<25:47,  5.65it/s]

Epoch [1250/10000], Loss: 25.8432


 13%|█▎        | 1276/10000 [03:43<26:36,  5.46it/s]

Epoch [1275/10000], Loss: 25.0474


 13%|█▎        | 1301/10000 [03:47<25:46,  5.62it/s]

Epoch [1300/10000], Loss: 22.9803


 13%|█▎        | 1326/10000 [03:52<25:34,  5.65it/s]

Epoch [1325/10000], Loss: 24.2588


 14%|█▎        | 1351/10000 [03:56<26:04,  5.53it/s]

Epoch [1350/10000], Loss: 24.6399


 14%|█▍        | 1376/10000 [04:00<25:13,  5.70it/s]

Epoch [1375/10000], Loss: 25.4635


 14%|█▍        | 1401/10000 [04:05<25:27,  5.63it/s]

Epoch [1400/10000], Loss: 25.9594


 14%|█▍        | 1426/10000 [04:09<25:52,  5.52it/s]

Epoch [1425/10000], Loss: 24.5203


 15%|█▍        | 1451/10000 [04:14<25:06,  5.67it/s]

Epoch [1450/10000], Loss: 23.6262


 15%|█▍        | 1476/10000 [04:18<24:43,  5.75it/s]

Epoch [1475/10000], Loss: 24.0044


 15%|█▌        | 1501/10000 [04:22<25:20,  5.59it/s]

Epoch [1500/10000], Loss: 25.3960


 15%|█▌        | 1526/10000 [04:27<24:50,  5.69it/s]

Epoch [1525/10000], Loss: 24.2039


 16%|█▌        | 1551/10000 [04:31<25:13,  5.58it/s]

Epoch [1550/10000], Loss: 25.3396


 16%|█▌        | 1576/10000 [04:36<25:21,  5.54it/s]

Epoch [1575/10000], Loss: 25.3265


 16%|█▌        | 1601/10000 [04:40<25:34,  5.47it/s]

Epoch [1600/10000], Loss: 24.2019


 16%|█▋        | 1626/10000 [04:45<24:29,  5.70it/s]

Epoch [1625/10000], Loss: 23.8895


 17%|█▋        | 1651/10000 [04:49<24:45,  5.62it/s]

Epoch [1650/10000], Loss: 24.0799


 17%|█▋        | 1676/10000 [04:54<26:03,  5.32it/s]

Epoch [1675/10000], Loss: 22.8501


 17%|█▋        | 1701/10000 [04:58<24:33,  5.63it/s]

Epoch [1700/10000], Loss: 23.0749


 17%|█▋        | 1726/10000 [05:03<25:09,  5.48it/s]

Epoch [1725/10000], Loss: 24.5341


 18%|█▊        | 1751/10000 [05:07<23:35,  5.83it/s]

Epoch [1750/10000], Loss: 24.7382


 18%|█▊        | 1776/10000 [05:12<24:04,  5.69it/s]

Epoch [1775/10000], Loss: 23.2900


 18%|█▊        | 1801/10000 [05:16<24:05,  5.67it/s]

Epoch [1800/10000], Loss: 24.7882


 18%|█▊        | 1826/10000 [05:20<24:06,  5.65it/s]

Epoch [1825/10000], Loss: 24.3045


 19%|█▊        | 1851/10000 [05:25<25:25,  5.34it/s]

Epoch [1850/10000], Loss: 24.8437


 19%|█▉        | 1876/10000 [05:29<23:35,  5.74it/s]

Epoch [1875/10000], Loss: 26.1230


 19%|█▉        | 1901/10000 [05:34<24:23,  5.53it/s]

Epoch [1900/10000], Loss: 24.4710


 19%|█▉        | 1926/10000 [05:38<23:57,  5.62it/s]

Epoch [1925/10000], Loss: 23.9836


 20%|█▉        | 1951/10000 [05:43<24:33,  5.46it/s]

Epoch [1950/10000], Loss: 25.4260


 20%|█▉        | 1976/10000 [05:47<23:53,  5.60it/s]

Epoch [1975/10000], Loss: 24.9317


 20%|██        | 2001/10000 [05:52<23:41,  5.63it/s]

Epoch [2000/10000], Loss: 23.1841


 20%|██        | 2026/10000 [05:56<24:10,  5.50it/s]

Epoch [2025/10000], Loss: 23.8828


 21%|██        | 2051/10000 [06:00<23:35,  5.62it/s]

Epoch [2050/10000], Loss: 22.9248


 21%|██        | 2076/10000 [06:05<23:45,  5.56it/s]

Epoch [2075/10000], Loss: 25.9835


 21%|██        | 2101/10000 [06:09<22:56,  5.74it/s]

Epoch [2100/10000], Loss: 24.9057


 21%|██▏       | 2126/10000 [06:14<23:10,  5.66it/s]

Epoch [2125/10000], Loss: 24.8962


 22%|██▏       | 2151/10000 [06:18<24:23,  5.36it/s]

Epoch [2150/10000], Loss: 26.3830


 22%|██▏       | 2176/10000 [06:23<24:00,  5.43it/s]

Epoch [2175/10000], Loss: 24.4479


 22%|██▏       | 2201/10000 [06:27<23:46,  5.47it/s]

Epoch [2200/10000], Loss: 24.5233


 22%|██▏       | 2226/10000 [06:31<24:04,  5.38it/s]

Epoch [2225/10000], Loss: 25.6216


 23%|██▎       | 2251/10000 [06:36<22:53,  5.64it/s]

Epoch [2250/10000], Loss: 27.1541


 23%|██▎       | 2276/10000 [06:40<22:17,  5.78it/s]

Epoch [2275/10000], Loss: 24.4176


 23%|██▎       | 2301/10000 [06:45<22:32,  5.69it/s]

Epoch [2300/10000], Loss: 24.4873


 23%|██▎       | 2326/10000 [06:49<22:12,  5.76it/s]

Epoch [2325/10000], Loss: 23.4229


 24%|██▎       | 2351/10000 [06:53<22:25,  5.68it/s]

Epoch [2350/10000], Loss: 25.9259


 24%|██▍       | 2376/10000 [06:58<22:30,  5.64it/s]

Epoch [2375/10000], Loss: 24.5369


 24%|██▍       | 2401/10000 [07:02<23:21,  5.42it/s]

Epoch [2400/10000], Loss: 23.6148


 24%|██▍       | 2426/10000 [07:07<22:13,  5.68it/s]

Epoch [2425/10000], Loss: 25.4628


 25%|██▍       | 2451/10000 [07:11<21:49,  5.77it/s]

Epoch [2450/10000], Loss: 23.5991


 25%|██▍       | 2476/10000 [07:16<22:13,  5.64it/s]

Epoch [2475/10000], Loss: 24.3066


 25%|██▌       | 2501/10000 [07:20<21:28,  5.82it/s]

Epoch [2500/10000], Loss: 23.4688


 25%|██▌       | 2526/10000 [07:24<20:33,  6.06it/s]

Epoch [2525/10000], Loss: 27.5116


 26%|██▌       | 2551/10000 [07:29<22:48,  5.44it/s]

Epoch [2550/10000], Loss: 26.0750


 26%|██▌       | 2576/10000 [07:33<22:00,  5.62it/s]

Epoch [2575/10000], Loss: 26.4256


 26%|██▌       | 2601/10000 [07:38<21:26,  5.75it/s]

Epoch [2600/10000], Loss: 23.8610


 26%|██▋       | 2626/10000 [07:42<22:17,  5.51it/s]

Epoch [2625/10000], Loss: 23.6979


 27%|██▋       | 2651/10000 [07:47<21:43,  5.64it/s]

Epoch [2650/10000], Loss: 23.8115


 27%|██▋       | 2676/10000 [07:51<21:26,  5.70it/s]

Epoch [2675/10000], Loss: 24.1046


 27%|██▋       | 2701/10000 [07:56<22:05,  5.51it/s]

Epoch [2700/10000], Loss: 27.8465


 27%|██▋       | 2726/10000 [08:00<22:03,  5.50it/s]

Epoch [2725/10000], Loss: 23.9275


 28%|██▊       | 2751/10000 [08:04<21:23,  5.65it/s]

Epoch [2750/10000], Loss: 24.3877


 28%|██▊       | 2776/10000 [08:09<21:02,  5.72it/s]

Epoch [2775/10000], Loss: 23.7199


 28%|██▊       | 2801/10000 [08:13<20:28,  5.86it/s]

Epoch [2800/10000], Loss: 23.6340


 28%|██▊       | 2826/10000 [08:18<21:18,  5.61it/s]

Epoch [2825/10000], Loss: 25.3749


 29%|██▊       | 2851/10000 [08:22<21:09,  5.63it/s]

Epoch [2850/10000], Loss: 24.3267


 29%|██▉       | 2876/10000 [08:27<21:28,  5.53it/s]

Epoch [2875/10000], Loss: 24.2698


 29%|██▉       | 2901/10000 [08:31<20:11,  5.86it/s]

Epoch [2900/10000], Loss: 23.2799


 29%|██▉       | 2926/10000 [08:35<19:40,  5.99it/s]

Epoch [2925/10000], Loss: 26.5897


 30%|██▉       | 2951/10000 [08:40<20:00,  5.87it/s]

Epoch [2950/10000], Loss: 25.7073


 30%|██▉       | 2976/10000 [08:44<20:55,  5.59it/s]

Epoch [2975/10000], Loss: 25.2314


 30%|███       | 3001/10000 [08:49<20:01,  5.83it/s]

Epoch [3000/10000], Loss: 24.6850


 30%|███       | 3026/10000 [08:53<20:09,  5.77it/s]

Epoch [3025/10000], Loss: 25.7481


 31%|███       | 3051/10000 [08:57<20:41,  5.60it/s]

Epoch [3050/10000], Loss: 24.1065


 31%|███       | 3076/10000 [09:02<20:02,  5.76it/s]

Epoch [3075/10000], Loss: 22.6569


 31%|███       | 3101/10000 [09:06<18:57,  6.07it/s]

Epoch [3100/10000], Loss: 24.6359


 31%|███▏      | 3126/10000 [09:10<20:18,  5.64it/s]

Epoch [3125/10000], Loss: 23.3497


 32%|███▏      | 3151/10000 [09:14<19:56,  5.72it/s]

Epoch [3150/10000], Loss: 24.2688


 32%|███▏      | 3176/10000 [09:19<19:44,  5.76it/s]

Epoch [3175/10000], Loss: 23.4404


 32%|███▏      | 3201/10000 [09:23<18:57,  5.98it/s]

Epoch [3200/10000], Loss: 25.1433


 32%|███▏      | 3226/10000 [09:27<19:19,  5.84it/s]

Epoch [3225/10000], Loss: 25.8716


 33%|███▎      | 3251/10000 [09:32<19:27,  5.78it/s]

Epoch [3250/10000], Loss: 23.8516


 33%|███▎      | 3276/10000 [09:36<20:02,  5.59it/s]

Epoch [3275/10000], Loss: 24.2992


 33%|███▎      | 3292/10000 [09:39<19:57,  5.60it/s]

Things I've tried...
- Changed the minimization method
  - Adam
  - SGD (w/ various momentum: `0.1`, `0.4`)

In [66]:
class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(48, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU())
        
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
        
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 48, kernel_size=3, padding=1),
            nn.BatchNorm2d(48),
            nn.ReLU())
        
        self.final_conv = nn.Conv2d(48, 48, kernel_size=1)  # Maintain the channel dimension
        
        self.fc1 = nn.Linear(48 * 56 * 56, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 48 * 56 * 56)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.final_conv(out)
        
        out = out.view(out.size(0), -1)  # Flatten the tensor
        out = nn.ReLU()(self.fc1(out))
        out = nn.ReLU()(self.fc2(out))
        out = self.fc3(out)
        
        out = out.view(out.size(0), 48, 56, 56)  # Reshape to the original shape
        return out