In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
from PIL import Image
import os

# added imports
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torchvision.transforms import functional as TF
import matplotlib.pyplot as plt
from pytorch_msssim import ssim

In [2]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from model import FusionDenoiser

In [5]:
torch.manual_seed(43)
np.random.seed(43)

In [6]:
from dataset import FocalDataset

In [7]:
ds = FocalDataset(input_channels=1, output_channels=3, augment=True, normalize=True)

In [8]:
stack, gt = ds[2]

In [9]:
stack.shape

torch.Size([8, 1, 512, 512])

In [10]:
gt.shape

torch.Size([3, 512, 512])

In [11]:
# fig, axes = plt.subplots(1, 3, figsize=(18,10), sharey=True)

# axes[0].imshow(np.moveaxis(stack[0].numpy(), 0,-1), cmap='gray')
# axes[1].imshow(np.moveaxis(stack[1].numpy(), 0,-1), cmap='gray')
# axes[2].imshow(np.moveaxis(gt.numpy(), 0,-1), cmap='gray')

In [12]:
# normalization constants used by pretrained IFCNN
# only really needed when using pretrained weights
# may be a good idea to adapt IFCNN to be similar to SwinIR in the way it works with image pixel range
mean = [0.485, 0.456, 0.406] #[0.46]*3
std = [0.229, 0.224, 0.225] #[0.225]*3

In [13]:
trial_mode = 2  # random, lytro, integrals

In [14]:
# Source: https://discuss.pytorch.org/t/how-to-add-noise-to-mnist-dataset-when-using-pytorch/59745/2
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [15]:
if trial_mode == 0:
    inp = torch.randn(16, 4, 1, 128, 128)
    inp = inp.repeat(1, 1, 3, 1, 1) # repeat gray channels as faux rgb
elif trial_mode == 1:
    trans = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Resize((128, 128), antialias=True),
    ] + ([AddGaussianNoise(std=0.015)] if trial_mode == 1 else []) + [   # Add noise to lytro images to see SwinIR effect
        transforms.Normalize(mean=mean, std=std),
    ])
    
    if trial_mode == 1:
        paths = glob('./lytro/*.jpg')
        
    inp = torch.stack([trans(Image.open(p)) for p in paths])[None, :]
elif trial_mode == 2:
    inp = transforms.Resize((128, 128), antialias=True)(ds[5][0])[None, :]
    
# inp = inp.cuda()
inp = inp.to(device)
inp.shape  # batch, n_focal_lengths, channels, dim1, dim2

torch.Size([1, 8, 1, 128, 128])

In [16]:
# plt.imshow(inp[0,0,0].cpu(), cmap='gray')

In [17]:
inp.mean().item(), inp.std().item()

(0.38773103657223723, 0.12404649765057327)

## Outline

Below one can see the **updated steps** of
- Loading the model using a cpu (should also work for gpu altough that wasn't tested). For that also the model.py file was updated.
- Defining and using the evaluation metrics, so psnr and ssim for somewhat arbitrary inputs. I have added some sample calls of the functions and used them in the training loop to give insight into how they can be used.
- Model training. Since I don't have a strong machine available, I had to stop the training early. I hope after updating the target values, the training works out fine.

**Attention:** 
- I have used a pseudo-target and not the target we have to use. This is relevant for the CustomDataset() class since that one has to be updated with the correct target.
- There is no train/validation-split. I just copied the validation set with the training set for simplicity, since we don't have data.

### Create model

Get missing file from here: https://github.com/uzeful/IFCNN/blob/master/Code/snapshots/IFCNN-MAX.pth

In [18]:
# limited to img_size=128 when using pretrained=True, as SwinIR does not have pretrained weights for that size
# model = FusionDenoiser(img_size=128, swin_version='V2', use_checkpoint=True, pretrained=True)
model = FusionDenoiser(img_size=128, swin_version='V2', use_checkpoint=True, pretrained=True)
model = model.to(device).eval()

No pretrained weights available for grayscale Swin V2


### Create evaluation metrics

In [19]:
inp.shape

torch.Size([1, 8, 1, 128, 128])

In [20]:
# the model uses 3 color channels, so we need to repeat the one grayscale dimension 3 times
inp = inp.repeat(1, 1, 3, 1, 1)

In [21]:
inp = inp.float()
output = model(inp)

In [22]:
scaled_output = tuple(o * 0.98 for o in output)

In [23]:
def calculate_psnr_tensor(target, prediction):
    mse = torch.mean((target - prediction) ** 2)
    psnr = 20 * torch.log10(torch.tensor(255.0)) - 10 * torch.log10(mse)
    return psnr.item()  # Convert to Python scalar

psnr_values = [calculate_psnr_tensor(o, s) for o, s in zip(output, scaled_output)]
# ssim_values = [calculate_ssim_tensor(o, s) for o, s in zip(output, scaled_output)]

average_psnr = sum(psnr_values) / len(psnr_values)
# average_ssim = sum(ssim_values) / len(ssim_values)

print("Average PSNR:", average_psnr)
# print("Average SSIM:", average_ssim)

Average PSNR: 39.614797592163086


In [24]:
def calculate_ssim_tensor(target, prediction):
    # ensure the input tensors are float32, as required by pytorch_msssim
    target = target.type(torch.float32)
    prediction = prediction.type(torch.float32)

    # the ssim function expects tensors in the shape of (batch, channel, height, width)
    # ensure your tensors are correctly shaped
    ssim_value = ssim(target, prediction, data_range=255, size_average=True)  # size_average to return the average of SSIM
    return ssim_value.item()  # make it python scalar

# Example usage
ssim_values = [calculate_ssim_tensor(o, s) for o, s in zip(output, scaled_output)]
average_ssim = sum(ssim_values) / len(ssim_values)

print("Average SSIM:", average_ssim)

Average SSIM: 0.999765545129776


### Train model

In [25]:
model = model.float()

In [26]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

In [27]:
# This cell needs to be adjusted for our problem and redone.

class CustomDataset(Dataset):
    def __init__(self, data, model):
        # Flatten the first two dimensions
        self.data = data.view(-1, 1, 128, 128)
        self.model = model

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        '''
            Attention: I have made up the target! The CustomDataset needs to be transformed if
            we want to use it for the project.
        '''
        inp = self.data[idx].repeat(1, 3, 1, 1)  # Convert to 3-channel image
        inp = inp.unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            model_output = self.model(inp)
            if isinstance(model_output, (list, tuple)):
                target = model_output[0]  # Assuming you want to use the first output as the target
            else:
                target = model_output
            target = target * 0.9
        return inp.squeeze(0), target.squeeze(0)  # Remove batch dimension for DataLoader compatibility

# Assuming model and inp are already defined
dataset = CustomDataset(inp, model)

# Creating the same dataset for both training and validation
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [28]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5

In [30]:
# lists to store metrics for plotting
train_losses, val_losses = [], []
train_psnrs, val_psnrs = [], []
train_ssims, val_ssims = [], []

for epoch in range(num_epochs):
    model.train()
    train_loss, train_psnr, train_ssim = 0.0, 0.0, 0.0

    for inp, target in train_loader:
        optimizer.zero_grad()
        model_output = model(inp)
        if isinstance(model_output, tuple):
            model_output = model_output[0]

        loss = criterion(model_output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_psnr += calculate_psnr_tensor(target, model_output)
        train_ssim += calculate_ssim_tensor(target, model_output)
        
        print('training loop (early stopping) done')
        break   # this needs to be removed and guarantees us that we only train on the first sample, not all.

    # average over the epoch and store metrics (training)
    train_loss /= len(train_loader)
    train_psnr /= len(train_loader)
    train_ssim /= len(train_loader)
    train_losses.append(train_loss)
    train_psnrs.append(train_psnr)
    train_ssims.append(train_ssim)

    model.eval()
    val_loss, val_psnr, val_ssim = 0.0, 0.0, 0.0
    with torch.no_grad():
        for inp, target in val_loader:
            model_output = model(inp)
            if isinstance(model_output, tuple):
                model_output = model_output[0]

            loss = criterion(model_output, target)
            val_loss += loss.item()
            val_psnr += calculate_psnr_tensor(target, model_output)
            val_ssim += calculate_ssim_tensor(target, model_output)
            
            print('validation loop (early stopping) done')
            break   # this needs to be removed and guarantees us that we only evaluate on the first sample, not all.
        
    # average over the epoch and store metrics (validation)
    val_loss /= len(val_loader)
    val_psnr /= len(val_loader)
    val_ssim /= len(val_loader)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)
    val_ssims.append(val_ssim)

    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}, Training PSNR: {train_psnr}, Training SSIM: {train_ssim}, Validation PSNR: {val_psnr}, Validation SSIM: {val_ssim}")

training loop (early stopping) done
validation loop (early stopping) done
Epoch 1/5, Training Loss: 7.80271848042806, Validation Loss: 6.423454920450847, Training PSNR: 1.058592955271403, Training SSIM: 0.041258071859677635, Validation PSNR: 1.0937918027242024, Validation SSIM: 0.041403998931248985
training loop (early stopping) done
validation loop (early stopping) done
Epoch 2/5, Training Loss: 7.269136428833008, Validation Loss: 5.911933898925781, Training PSNR: 1.0714109738667805, Training SSIM: 0.04125023384888967, Validation PSNR: 1.1088080406188965, Validation SSIM: 0.04140225052833557
training loop (early stopping) done
validation loop (early stopping) done
Epoch 3/5, Training Loss: 6.770402272542317, Validation Loss: 5.586081822713216, Training PSNR: 1.0842727025349934, Training SSIM: 0.041259231666723885, Validation PSNR: 1.1190673510233562, Validation SSIM: 0.04140132168928782
training loop (early stopping) done
validation loop (early stopping) done
Epoch 4/5, Training Loss:

In [2]:
# add plotting here

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18,10), sharey=True)

axes[0].imshow(inp[0,0,0].cpu(), cmap='gray')
axes[1].imshow(fused[0,0].cpu(), cmap='gray')
axes[2].imshow(denoised[0,0].cpu(), cmap='gray')

axes[0].set_title('One of the inputs')
axes[1].set_title('Fused')
axes[2].set_title('Denoised')

NameError: name 'fused' is not defined

# TODO

- generate AOS integral data
- Dataset with augmentation (random 0°, 90°, 180°, 270° rotation + random flip, maybe random crop, maybe noise). Data seems to be quite abundant, so just rotation and flips may be sufficient
- evaluation function
- training loop with logging, validation evaluation and checkpointing
- self-ensemble for final predictions (not used for training)

In [None]:
j = np.arange(2*3*1*2).reshape(2, 3, 1, 2)
j

In [None]:
num_rots = np.random.randint(4)
j = np.rot90(j, num_rots, (-3, -2))

if np.random.rand() > 0.5:
    j = np.flip(j, -2)

if np.random.rand() > 0.5:
    j = np.flip(j, -3)

In [None]:
np.repeat(j, 3, -2)