In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

from torch.utils.data import Dataset, DataLoader
import glob

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
class SineLayer(nn.Module):
    '''Linear layer followed by a sine activation: ``sin(omega_0 * linear(x))``.

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

    The forward pass always multiplies the linear output by ``omega_0`` before the sine.
    What differs between the first layer and the others is the weight initialisation:

    * ``is_first=True``: weights are drawn from U(-1/in_features, 1/in_features), so
      ``omega_0`` acts purely as a frequency factor. Different signals may require a
      different first-layer ``omega_0`` - it is a hyperparameter.
    * ``is_first=False``: the init bound sqrt(6/in_features) is divided by ``omega_0``,
      which keeps the magnitude of the activations constant while boosting gradients
      to the weight matrix (see supplement Sec. 1.5).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the underlying ``nn.Linear`` has a bias term.
        is_first: selects the first-layer initialisation scheme.
        omega_0: frequency factor applied before the sine.
    '''

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        '''Re-draw the linear weights in place from the SIREN uniform distribution.'''
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        # Only the weight matrix is touched; the bias keeps nn.Linear's default init.
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        '''Return ``sin(omega_0 * linear(input))``.'''
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)
    
    
class Siren(nn.Module):
    '''Stack of SineLayers mapping input coordinates to an output signal.

    The network is: one first SineLayer, ``hidden_layers`` hidden SineLayers, and an
    output stage chosen by ``outermost_linear``.

    Args:
        in_features: dimensionality of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden SineLayers after the first layer.
        out_features: dimensionality of the output.
        outermost_linear: if True, the output stage is an ``nn.Linear`` (weights drawn
            with the same bound as a hidden SineLayer) followed by ``nn.Sigmoid``;
            if False, the output stage is another SineLayer.
        first_omega_0: omega_0 of the first layer.
        hidden_omega_0: omega_0 of all later layers (and of the final linear's init bound).
    '''

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers.extend(
            SineLayer(hidden_features, hidden_features,
                      is_first=False, omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if not outermost_linear:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))
        else:
            head = nn.Linear(hidden_features, out_features)
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
            # The sigmoid squashes the linear output into (0, 1).
            layers += [head, nn.Sigmoid()]

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        '''Evaluate the network on a detached, gradient-tracking copy of ``coords``.'''
        # Cloning + requires_grad allows taking derivatives w.r.t. the input.
        tracked = coords.clone().detach().requires_grad_(True)
        return self.net(tracked)

In [None]:
class DIV2K_valid_HR_dataset(Dataset):
    '''Images from a folder of PNGs, each returned as a flat list of RGB pixels.

    Args:
        path: directory that directly contains the ``*.png`` files.
        sidelength: every image is resized to ``sidelength x sidelength``.
        max_images: keep only the first ``max_images`` files (in sorted filename
            order). Defaults to 32; pass ``None`` to use every file found.
    '''

    def __init__(self, path, sidelength, max_images=32):
        super().__init__()
        self.img_path_list = sorted(glob.glob(path + "/*.png"))[:max_images]
        self.transform = Compose([
            Resize((sidelength, sidelength)),
            ToTensor()
        ])

    def __len__(self):
        '''Number of images in the dataset.'''
        return len(self.img_path_list)

    def __getitem__(self, idx):
        '''Load image ``idx`` and return it as a ``(sidelength * sidelength, 3)`` tensor.

        Rows are pixels in row-major (top-left to bottom-right) order; columns are
        the R, G, B channels as produced by torchvision's ``ToTensor``.
        '''
        img_path = self.img_path_list[idx]
        # Force three channels so grayscale / RGBA / palette PNGs do not break the
        # (-1, 3) reshape below, and close the file handle as soon as it is decoded.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        chw = self.transform(rgb)
        # (3, H, W) -> (H, W, 3) -> (H*W, 3)
        # The result is also kept on the instance, as the original code did.
        self.pixels = chw.permute(1, 2, 0).reshape(-1, 3)
        return self.pixels

In [None]:
# Images are resized to 512x512 and served one per batch, in file order.
dataset = DIV2K_valid_HR_dataset(
    path="DIV2K_valid_HR",
    sidelength=512,
)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=0,
    pin_memory=True,
)

In [None]:
def get_mgrid(sidelen, dim=2):
    '''Build a flattened grid of coordinates spanning [-1, 1] along every axis.

    sidelen: int, number of samples per axis
    dim: int, number of axes

    Returns a tensor of shape (sidelen ** dim, dim). The grid uses "xy" indexing,
    so for dim=2 the first column changes fastest as the flat index increases.
    '''
    axis = torch.linspace(-1, 1, steps=sidelen)
    grids = torch.meshgrid(*([axis] * dim), indexing="xy")
    return torch.stack(grids, dim=-1).reshape(-1, dim)

def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    '''Convert a channel-first image tensor to a channel-last uint8 NumPy array.

    Values are multiplied by 256, limited to [0, 255], and truncated to uint8.
    The (C, H, W) axes are then reordered to (H, W, C). The input is not modified.
    '''
    scaled = (tensor * 256).clamp(min=0, max=255)
    return scaled.to(torch.uint8).permute(1, 2, 0).cpu().numpy()

In [None]:
# https://cvnote.ddlee.cc/2019/09/12/psnr-ssim-python

import math
def calculate_psnr(img1, img2):
    '''Peak signal-to-noise ratio, in dB, between two images with range [0, 255].

    Both inputs are promoted to float64 before subtracting, so unsigned integer
    arrays do not wrap around. Returns ``inf`` when the images are identical.
    '''
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

# https://cvnote.ddlee.cc/2019/09/12/psnr-ssim-python
import math
import numpy as np
import cv2

def ssim(img1, img2):
    '''Mean structural similarity between two same-shaped images with range [0, 255].

    Local statistics are computed with an 11-tap Gaussian window (sigma 1.5) built
    as an outer product, and only the region where the window fits entirely inside
    the image ("valid" region, 5 pixels cropped per side) contributes to the mean.
    '''
    # Stabilising constants, scaled for a dynamic range of 255.
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def windowed(arr):
        # Gaussian-weighted local average, cropped to the valid region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_a = windowed(a)
    mu_b = windowed(b)
    mu_a_sq = mu_a**2
    mu_b_sq = mu_b**2
    mu_ab = mu_a * mu_b

    # Local variances and covariance: E[x*y] - E[x]*E[y]
    var_a = windowed(a**2) - mu_a_sq
    var_b = windowed(b**2) - mu_b_sq
    cov_ab = windowed(a * b) - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()


def calculate_ssim(img1, img2):
    '''Calculate SSIM between two images.

    Intended to give the same outputs as MATLAB's (per the reference implementation
    this was adapted from).

    img1, img2: arrays with values in [0, 255] and identical shape, laid out as
        (H, W), (H, W, 1) or (H, W, 3).

    Returns the SSIM of the single channel, or the mean of the per-channel SSIMs
    for 3-channel input.

    Raises:
        ValueError: if the shapes differ, or the layout is not one of those above.
    '''
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[2]
        if channels == 3:
            # Score each colour channel on its own, then average.
            per_channel = [ssim(img1[..., c], img2[..., c]) for c in range(3)]
            return np.array(per_channel).mean()
        if channels == 1:
            return ssim(img1[..., 0], img2[..., 0])
    # Reached for ndim not in {2, 3} and for 3-D inputs with an unsupported channel
    # count (which previously fell through and returned None).
    raise ValueError('Wrong input image dimensions.')

In [None]:
# Fit a fresh SIREN to each image and record the reconstruction PSNR / SSIM,
# keyed by the image's index in the dataloader.
SIDELEN = 512     # must match the `sidelength` the dataset was built with
NUM_STEPS = 2000  # optimisation steps per image

# The coordinate grid is the same for every image, so build it once.
xy_grid = get_mgrid(sidelen=SIDELEN, dim=2).unsqueeze(0).to(device)

psnr_dict = {}
ssim_dict = {}
for ind, target in enumerate(tqdm(dataloader)):
    siren_model = Siren(in_features=2, out_features=3, hidden_features=256,
                        hidden_layers=3, outermost_linear=True).to(device)

    optimizer = torch.optim.Adam(lr=1e-4, params=siren_model.parameters())

    target = target.to(device)

    for step in range(NUM_STEPS):
        optimizer.zero_grad()

        generated = siren_model(xy_grid)

        loss = ((generated - target)**2).mean()

        loss.backward()
        optimizer.step()

    # Score the model as it stands after the final update, not the output that
    # was produced just before the last optimizer step.
    with torch.no_grad():
        generated = siren_model(xy_grid)

    # Both tensors are (1, H*W, 3) with pixels in row-major order. Rebuild the
    # image as (H, W, 3) and only then move channels first. Viewing the buffer
    # directly as (3, H, W) would reinterpret memory and scramble the image,
    # which invalidates the spatially-windowed SSIM.
    target_img = tensor_to_numpy(target[0].reshape(SIDELEN, SIDELEN, 3).permute(2, 0, 1))
    generated_img = tensor_to_numpy(generated[0].reshape(SIDELEN, SIDELEN, 3).permute(2, 0, 1))

    psnr_dict[ind] = calculate_psnr(target_img, generated_img)
    ssim_dict[ind] = calculate_ssim(target_img, generated_img)
    print(psnr_dict[ind], ssim_dict[ind])

In [None]:
import torchsummary

# Summarise the most recently trained model for one 512x512 grid of (x, y) inputs.
torchsummary.summary(siren_model, (512 * 512, 2))

In [None]:
# Per-image PSNR values, keyed by the image's index in the dataloader.
psnr_dict

In [None]:
# Per-image SSIM values, keyed by the image's index in the dataloader.
ssim_dict

In [None]:
# Mean PSNR over every image that was fitted.
psnr_total = sum(psnr_dict.values())
psnr_average = psnr_total / len(psnr_dict)
psnr_average

In [None]:
# Mean SSIM over every image that was fitted.
ssim_total = sum(ssim_dict.values())
ssim_average = ssim_total / len(ssim_dict)
ssim_average

In [None]:
import json

# Persist both metric tables next to the notebook.
# json.dump writes the integer image indices as string keys.
for out_name, metrics in (
    ('baseline_SIREN_psnr_dict.json', psnr_dict),
    ('baseline_SIREN_ssim_dict.json', ssim_dict),
):
    with open(out_name, 'w') as fp:
        json.dump(metrics, fp)