In [1]:
import os, sys
import math
import pathlib
from functools import partial

import torch
from torch import nn
from torchvision.io import read_image
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from model import *

from kornia.filters import laplacian, spatial_gradient

In [2]:
torch.manual_seed(2)

### VARIOUS INIT FUNCTIONS ###
# Weight initialisers to compare. Each one gets its own TensorBoard run directory
# in which the output of every sub-module of the network is logged as a histogram.
# NOTE(review): 'paper' passes None to image_siren — presumably that selects the
# model's built-in (SIREN paper) initialisation; confirm in model.py.
init_functions = {
                    'ones': torch.nn.init.ones_,
                    'eye': torch.nn.init.eye_,
                    'default': partial(torch.nn.init.kaiming_uniform_, a=5 ** (.5)),
                    'paper': None
                 }

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def layer_logger(inst, inp, out, number=0, writer=None):
    '''
    Forward hook: log a module's output tensor as a TensorBoard histogram
    named "<number>_<ModuleClassName>".

    The writer is bound explicitly (via ``partial``) rather than read from the
    enclosing loop variable, so a hook can never log into another run's writer.
    '''
    layer_name = f"{number}_{inst.__class__.__name__}"
    writer.add_histogram(layer_name, out)


for init_name, init_function in init_functions.items():
    path = pathlib.Path.cwd() / 'tensorboard_logs' / init_name
    # Context manager guarantees the event file is flushed and closed,
    # even if building or running the model fails.
    with SummaryWriter(log_dir=str(path)) as writer:
        model = image_siren(
                                hidden_layers=10,
                                hidden_features=200,
                                first_omega=30,
                                hidden_omega=30,
                                c=6,
                                custom_initalizationg_function=init_function
                            )
        model = model.to(device)

        for i, layer in enumerate(model.model.modules()):
            if not i:
                # index 0 is the container itself, not a layer
                continue
            # modules 1,2 -> layer 1; 3,4 -> layer 2; ...
            layer.register_forward_hook(
                partial(layer_logger, number=(i + 1) // 2, writer=writer)
            )

        # uniform random inputs in [-1, 1) for both coordinates
        inp = 2 * (torch.rand(10000, 2) - .5)
        inp = inp.to(device)

        writer.add_histogram('0', inp)

        # only activations are logged, so no autograd graph is needed
        with torch.no_grad():
            res = model(inp)

        del model, inp

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [3]:
def generate_coordinates(n: int) -> torch.Tensor:
    '''
    Generates the grid of integer 2D pixel coordinates {0, ..., n-1} x {0, ..., n-1}
    params
    -----------
    n: int
        Side length of the square grid (number of points along each axis)
    returns:
    -----------
    coords_abs: torch.Tensor
        integer (row, col) coordinates of shape (n**2, 2) in row-major order:
        [[0, 0], [0, 1], ..., [0, n-1], [1, 0], ..., [n-1, n-1]]
    '''
    # 'ij' indexing makes the first output vary along rows, i.e. (row, col) pairs.
    # Passing it explicitly also silences torch's warning about the default.
    rows, cols = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    coords_abs = torch.stack([rows.reshape(-1), cols.reshape(-1)], dim=-1)

    return coords_abs

In [6]:
class pixel_dataset(Dataset):
    '''
    Custom Dataset for the torch training: one sample per pixel of a
    single-channel square image.
    params:
    ------------
    img: torch.tensor
        single-channel square image of shape (size, size)
    laplace_kernel_size: int
        kernel size passed to kornia's ``laplacian`` (must be odd); defaults to 3
    attributes:
    ------------
    size: int
        height and width of image
    coords_abs: torch.tensor
        integer (row, col) pixel coordinates of shape (size**2, 2)
    grad: torch.tensor
        unnormalised Sobel gradient approximation of shape (2, size, size);
        the leading axis holds the two spatial directions in kornia's order
    grad_norm: torch.tensor
        per-pixel gradient magnitude (L2 norm over the two directions), shape (size, size)
    laplace: torch.tensor
        unnormalised Laplacian of the image, shape (size, size)
    '''
    def __init__(self, img, laplace_kernel_size=3):
        if img.dim() != 2 or img.size(0) != img.size(1):
            raise ValueError('Image should be single channel square image')

        #creating dataset
        self.img = img
        self.size = img.size(0)
        self.coords_abs = generate_coordinates(self.size)

        # kornia filters work on a floating-point (B, C, H, W) batch
        batch = img.reshape(1, 1, self.size, self.size)
        if not batch.is_floating_point():
            batch = batch.float()

        # better not normalize
        # (1, 1, 2, H, W) -> (2, H, W)
        self.grad = spatial_gradient(batch, mode='sobel', normalized=False)[0, 0]
        self.grad_norm = torch.linalg.norm(self.grad, dim=0)
        # laplacian requires an explicit kernel size; (1, 1, H, W) -> (H, W)
        self.laplace = laplacian(batch, laplace_kernel_size, normalized=False)[0, 0]

    def __len__(self):
        '''number of samples, i.e. number of pixels (size ** 2)'''
        return self.size ** 2

    def __getitem__(self, idx):
        '''
        get all relevant data for one single (integer) pixel index ``idx``
        in row-major order
        '''
        coords_abs = self.coords_abs[idx]
        r, c = coords_abs

        # maps pixel index 0 -> -1 and size-1 -> 1 - 2/size
        # NOTE(review): upper end is 1 - 2/size, not exactly 1 — confirm intended
        coords = 2 * ((coords_abs / self.size) - .5)

        return {
                    'coords': coords,
                    'coords_abs': coords_abs,
                    'intensity': self.img[r, c],
                    'grad_norm': self.grad_norm[r, c],
                    # grad is (2, H, W): take both directions at this pixel
                    'grad': self.grad[:, r, c],
                    'laplace': self.laplace[r, c]
               }

In [7]:
class gradient_utils:
    '''Autograd helpers for differential operators of a network output w.r.t. its input coordinates.'''

    @staticmethod
    def gradient(target, coords):
        '''
        Compute the gradient of ``target`` with respect to ``coords``
        -------------
        params
        -------------
        target: torch.Tensor
            2D tensor of shape (n_coords, ?) representing the targets;
            must have been computed from ``coords`` (which requires grad)
        -------------
        coords: torch.Tensor
            2D tensor of shape (n_coords, 2) representing the coordinates
        -------------
        returns: grad: torch.Tensor
            2D tensor of shape (n_coords, 2) representing the gradient
            (summed over target columns when target has more than one)
        '''
        # create_graph=True so the result can be differentiated again
        # (needed by divergence / laplace)
        return torch.autograd.grad(
                                        target,
                                        coords,
                                        grad_outputs=torch.ones_like(target),
                                        create_graph=True
                                  )[0]

    @staticmethod
    def divergence(grad, coords):
        '''
        Compute the divergence of a vector field given per coordinate.
        When ``grad`` is a gradient, this is the sum of the pure second order
        partial derivatives; in the 2D case F_{xx} + F_{yy}
        --------------
        params
        --------------
        grad : torch.Tensor
            2D tensor of shape (n_coords, 2) representing the gradient w.r.t x and y
        --------------
        coords: torch.Tensor
            2D tensor of shape (n_coords, 2) representing the coordinates
        ---------------
        returns: divergence: torch.Tensor
            2D tensor of shape (n_coords, 1) representing the divergence
        '''
        div = .0
        for i in range(coords.shape[1]):
            # d(grad_i)/d(coords), keep only the i-th column: d^2 F / d x_i^2
            div += torch.autograd.grad(
                                            grad[..., i],
                                            coords,
                                            torch.ones_like(grad[..., i]),
                                            create_graph=True
                                      )[0][..., i: i + 1]
        return div

    @staticmethod
    def laplace(target, coords):
        '''
        Compute the Laplacian of ``target`` with respect to ``coords``
        ------------------
        params
        ------------------
        target: torch.Tensor
            2D tensor of shape (n_coords, 1) representing the targets
        ------------------
        coords: torch.Tensor
            2D tensor of shape (n_coords, 2) representing the coordinates
        ------------------
        returns: laplace: torch.Tensor
            2D tensor of shape (n_coords, 1) representing the laplace
        '''
        grad = gradient_utils.gradient(target, coords)

        return gradient_utils.divergence(grad, coords)
        