### Import modules + Check GPU

In [2]:
import os
import time
from collections import OrderedDict

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.io import read_image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt
from skimage import io
from tqdm import tqdm
import scipy.ndimage
from torch.utils.tensorboard import SummaryWriter

In [5]:
cuda_available = torch.cuda.is_available()
print("GPU available: {}".format(cuda_available))
# get_device_name raises when no CUDA device exists, so only query it when one does.
if cuda_available:
    print("Device: {}".format(torch.cuda.get_device_name(0)))
else:
    print("Device: none (train() calls .cuda() and needs a GPU)")

GPU available: True
Device: Tesla K80


### Data Generation

In [None]:
def _neumann_laplacian(arr):
    """Return the sum over all axes of the second-order finite difference of ``arr``.

    Boundaries are zero-flux (Neumann): the difference "outside" the array is taken as 0,
    so the result sums to zero and diffusion conserves the total intensity.
    """
    laplacian = np.zeros_like(arr)
    for axis in range(arr.ndim):
        head = [slice(None)] * arr.ndim
        head[axis] = slice(None, -1)
        tail = [slice(None)] * arr.ndim
        tail[axis] = slice(1, None)

        # Forward differences; the last entry along `axis` stays 0 (no flux out of the edge).
        # A fresh buffer per call avoids stale boundary values leaking between steps.
        flux = np.zeros_like(arr)
        flux[tuple(head)] = np.diff(arr, axis=axis)

        # Backward difference of the flux; the first entry keeps flux[0] (no flux in).
        divergence = flux.copy()
        divergence[tuple(tail)] = np.diff(flux, axis=axis)
        laplacian += divergence
    return laplacian


def _cast_like(arr, dtype):
    """Cast float array ``arr`` to ``dtype`` (always returns a new array).

    For integer dtypes the values are rounded to nearest and clipped to the dtype's range
    first, so out-of-range values saturate instead of wrapping around.
    """
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        arr = np.clip(np.rint(arr), limits.min, limits.max)
    return arr.astype(dtype)


def isotropic_diffusion(img, niter=1, kappa=50, gamma=0.1, voxelspacing=None):
    """Simulate linear (isotropic) diffusion of ``img`` with an explicit Euler scheme.

    Each step applies ``I <- I + gamma * laplacian(I)``, where the Laplacian sums the
    second-order finite differences along every array axis with zero-flux boundaries.

    Args:
        img: array to diffuse (e.g. a 2-D grayscale image); every axis is diffused.
        niter: number of states to record; values below 1 are treated as 1, which
            records only ``img`` itself.
        kappa: unused (isotropic diffusion has no edge-stopping function); kept for
            call-site compatibility.
        gamma: time step per iteration. NOTE(review): the explicit scheme is only
            stable for ``gamma <= 1 / (2 * img.ndim)``.
        voxelspacing: unused; kept for call-site compatibility (unit spacing assumed).

    Returns:
        ``(pixels, dIdt, times)``, three lists of length ``max(niter, 1)``:
        ``pixels[k]`` is state k cast to ``img.dtype`` (rounded and clipped for integer
        dtypes), ``dIdt[k]`` is the float32 Laplacian of state k (its time derivative),
        and ``times[k]`` is the accumulated diffusion time ``k * gamma``.
    """
    img = np.asarray(img)
    state = np.array(img, dtype=np.float32, copy=True)
    n_states = max(int(niter), 1)

    pixels, dIdts, times = [], [], []
    t = 0
    for step in tqdm(range(n_states)):
        dIdt = _neumann_laplacian(state)
        pixels.append(_cast_like(state, img.dtype))
        # Keep derivatives in float32: casting to e.g. uint8 would wrap negative values.
        dIdts.append(dIdt)
        times.append(t)
        # The state after the final step is never recorded, so it is not computed.
        if step < n_states - 1:
            state += gamma * dIdt
            t += gamma
    return pixels, dIdts, times

def get_mgrid(sidelen=256, dim=2):
    '''Return every point of a regular grid on [-1, 1]^dim as a (sidelen**dim, dim) tensor.

    Points are listed in row-major ('ij') order: the first coordinate varies slowest.

    sidelen: int, number of samples along each axis
    dim: int, number of dimensions'''
    axis_samples = torch.linspace(-1, 1, steps=sidelen)
    # cartesian_prod enumerates the same points, in the same order, as stacking
    # torch.meshgrid(..., indexing='ij') along the last axis and flattening.
    grid_points = torch.cartesian_prod(*([axis_samples] * dim))
    return grid_points.reshape(-1, dim)

class ImageFitting(Dataset):
    """Dataset of successive isotropic-diffusion states of a single image.

    Item ``idx`` is the ``idx``-th diffusion state. The model input is the flattened
    256x256 grid from ``get_mgrid`` with the state's diffusion time appended as a third
    feature. The targets are the normalized pixels plus their Sobel gradients, Laplacian
    and time derivative ``dIdt``.
    """

    # Normalize(mean, std) maps ToTensor's [0, 1] output to [-1, 1].
    NORM_MEAN = 0.5
    NORM_STD = 0.5

    def __init__(self, img_path, niter):
        """Load ``img_path`` and precompute its diffusion states with ``niter`` iterations."""
        self.transform = Compose([
            Resize(256),
            ToTensor(),
            Normalize(torch.Tensor([self.NORM_MEAN]), torch.Tensor([self.NORM_STD]))
        ])
        self.coords = get_mgrid()

        print("-----Generating Data-----")
        self.base_img = io.imread(img_path)
        # gamma = 1 / (niter + 1) keeps every recorded diffusion time below 1.
        self.imgs_pixels, self.imgs_dIdt, self.imgs_time = isotropic_diffusion(
            self.base_img, niter=niter, kappa=50, gamma=1/(niter+1))

        print("-----Finished-----")

        self.len = len(self.imgs_pixels)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(model_input, targets)`` for diffusion state ``idx``.

        model_input: (N, 3) tensor of grid coordinates plus the diffusion time.
        targets: dict with 'pixels' (N, 1), 'grads' (N, 2), 'laplace' (N, 1) and
        'dIdt' (N,), all in the normalized intensity units of 'pixels'.
        """
        raw = self.imgs_pixels[idx]
        image = self.transform(Image.fromarray(raw))  # (C, H, W); NOTE(review): C == 1 assumed

        pixels = image.permute(1, 2, 0).reshape(-1, 1)
        # Explicit dtype: the first time value is the int 0, for which torch.full would
        # build an integer tensor (or raise, depending on the PyTorch version).
        step_val = torch.full((self.coords.size(0), 1), float(self.imgs_time[idx]),
                              dtype=self.coords.dtype)
        model_input = torch.cat((self.coords, step_val), 1)

        # Spatial targets are computed on the transformed (resized, normalized) image.
        image_np = image.numpy()
        grads_x = torch.from_numpy(scipy.ndimage.sobel(image_np, axis=1).squeeze(0)[..., None])
        grads_y = torch.from_numpy(scipy.ndimage.sobel(image_np, axis=2).squeeze(0)[..., None])
        grads = torch.stack((grads_x, grads_y), dim=-1).view(-1, 2)
        laplace = torch.from_numpy(scipy.ndimage.laplace(image_np).squeeze(0)[..., None]).view(-1, 1)

        # dIdt is in raw intensity units. Rescale it to the normalized units of 'pixels':
        # ToTensor divides only uint8 images by 255, then Normalize divides by NORM_STD.
        to_tensor_scale = 1.0 / 255.0 if raw.dtype == np.uint8 else 1.0
        dIdt = np.asarray(self.imgs_dIdt[idx], dtype=np.float32) * (to_tensor_scale / self.NORM_STD)
        dIdt = torch.from_numpy(dIdt).reshape(-1)
        if dIdt.numel() != self.coords.size(0):
            raise ValueError(
                "dIdt has %d values but the coordinate grid has %d; dIdt is not resized, "
                "so the source image must already be 256x256" % (dIdt.numel(), self.coords.size(0)))

        return model_input, {'pixels': pixels, 'grads': grads, 'laplace': laplace, 'dIdt': dIdt}

### Loss Calculation

In [None]:
def computeJacobianFull(x, outputs, create_graph):
    """Gradient of ``outputs.sum()`` with respect to ``x``.

    This equals the per-point gradient when each output point depends only on its own
    input point, as for the coordinate MLPs in this notebook.

    Args:
        x: input tensor (NOTE(review): (B, N, D) assumed) that ``outputs`` was computed
            from with ``requires_grad=True``.
        outputs: tensor of shape (B, N, C); with C == 1 this is the exact gradient.
        create_graph: keep the derivative's graph so it can be differentiated again
            (needed for Laplacians and for losses on the gradient).

    Returns:
        Tensor of shape (B, N, D). All zeros if ``outputs`` does not depend on ``x``.
    """
    dy_dx = torch.autograd.grad(outputs=outputs, inputs=x, grad_outputs=torch.ones_like(outputs),
                                retain_graph=True, create_graph=create_graph, allow_unused=True)[0]

    if dy_dx is None:
        # allow_unused=True yields None (instead of raising) when x is not in the graph.
        dy_dx = torch.zeros_like(x)

    return dy_dx.view(outputs.size(0), outputs.size(1), dy_dx.size(-1))

def computeLaplaceFull(x, jacobian, create_graph):
    """Laplacian (trace of the Hessian) computed from a first-derivative tensor.

    Args:
        x: input tensor the Jacobian was taken with respect to.
        jacobian: first derivatives of shape (B, N, K), built with ``create_graph=True``.
            Component ``j`` is differentiated w.r.t. input feature ``j``, so only the
            first K input features contribute.
        create_graph: keep the result's graph so it can be differentiated further.

    Returns:
        ``sum_j d(jacobian[..., j]) / d x[..., j]`` with a trailing singleton axis, or
        the int 0 when K == 0.
    """
    second_derivatives = (
        torch.autograd.grad(
            outputs=jacobian[:, :, k],
            inputs=x,
            grad_outputs=torch.ones_like(jacobian[:, :, k]),
            retain_graph=True,
            create_graph=create_graph,
        )[0][..., k:k + 1]
        for k in range(jacobian.size(-1))
    )
    return sum(second_derivatives)

def calcLoss(coords, model_output, gt):
    """Compute the four supervision losses for one batch.

    Args:
        coords: network input returned by the model's forward (requires grad). The last
            feature is diffusion time; the others are spatial coordinates.
        model_output: prediction, NOTE(review): (B, N, 1) assumed (out_features=1).
        gt: targets as produced by ImageFitting (batched): 'pixels' (B, N, 1),
            'grads' (B, N, 2), 'laplace' (B, N, 1), 'dIdt' (B, N).

    Returns:
        ``(pixel_loss, grad_loss, laplacian_loss, dIdt_loss)`` scalar tensors, all
        differentiable with respect to the network parameters.
    """
    pixel_loss = ((model_output - gt['pixels']) ** 2).mean()

    gradients = computeJacobianFull(coords, model_output, create_graph=True)
    spatial_grads = gradients[:, :, :-1]
    grad_loss = (spatial_grads - gt['grads']).pow(2).sum(-1).mean()

    # create_graph=True is required: with False the Laplacian is detached from the
    # network parameters, so laplacian_loss would contribute no gradient at all.
    laplacian = computeLaplaceFull(coords, spatial_grads, create_graph=True)
    laplacian_loss = ((laplacian - gt['laplace']) ** 2).mean()

    # derivative with respect to the time input (last feature)
    dIdt_loss = ((gradients[:, :, -1] - gt['dIdt']) ** 2).mean()

    return pixel_loss, grad_loss, laplacian_loss, dIdt_loss


### SIREN Network Architecture

In [None]:
class SineLayer(nn.Module):
    """Fully connected layer followed by a sine activation: ``sin(omega_0 * (W x + b))``.

    See the SIREN paper, sec. 3.2 (final paragraph) and supplement sec. 1.5, for a
    discussion of ``omega_0``.

    * ``is_first=True``: weights ~ U(-1/in_features, 1/in_features). ``omega_0`` just
      scales the pre-activation and is a per-signal hyperparameter.
    * ``is_first=False``: weights ~ U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0).
      Dividing by ``omega_0`` keeps activation magnitudes constant while boosting the
      gradients reaching the weight matrix.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Redraw ``linear.weight`` uniformly in the SIREN range; the bias keeps nn.Linear's init."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        activation, _ = self.forward_with_intermediate(input)
        return activation

    def forward_with_intermediate(self, input):
        """Return ``(sin(z), z)`` with ``z = omega_0 * linear(input)``.

        ``z`` is exposed for visualizing activation distributions.
        """
        scaled = self.omega_0 * self.linear(input)
        return torch.sin(scaled), scaled
    
    
class Siren(nn.Module):
    """Sine-activated MLP (SIREN) mapping input coordinates to signal values.

    Layers: a first ``SineLayer`` (``first_omega_0``), ``hidden_layers`` hidden
    ``SineLayer``s (``hidden_omega_0``), then either a plain ``nn.Linear`` with
    SIREN-style uniform init (``outermost_linear=True``) or a final ``SineLayer``.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers.extend(
            SineLayer(hidden_features, hidden_features, is_first=False, omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if not outermost_linear:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))
        else:
            head = nn.Linear(hidden_features, out_features)
            limit = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-limit, limit)
            layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return ``(output, coords)``.

        ``coords`` is returned as a fresh leaf tensor with ``requires_grad=True``, so
        derivatives of ``output`` w.r.t. the input can be taken.
        """
        leaf_coords = coords.clone().detach().requires_grad_(True)
        return self.net(leaf_coords), leaf_coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!'''
        activations = OrderedDict()
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x

        count = 0
        for layer in self.net:
            class_name = str(layer.__class__)
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)
                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()
                # record the pre-activation of sine layers as an extra entry
                activations['_'.join((class_name, "%d" % count))] = intermed
                count += 1
            else:
                x = layer(x)
                if retain_grad:
                    x.retain_grad()
            activations['_'.join((class_name, "%d" % count))] = x
            count += 1

        return activations

### ELU Network Architecture

In [None]:
class ELULayer(nn.Module):
    """Fully connected layer followed by an ELU activation (baseline for ``SineLayer``).

    The weight matrix is Xavier-uniform initialized; the bias keeps nn.Linear's default.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Redraw ``linear.weight`` with Xavier-uniform initialization."""
        with torch.no_grad():
            nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, input):
        activated, _ = self.forward_with_intermediate(input)
        return activated

    def forward_with_intermediate(self, input):
        """Return ``(elu(z), z)`` with ``z = linear(input)``, for visualizing activations."""
        pre_activation = self.linear(input)
        return F.elu(pre_activation), pre_activation
    
    
class Base(nn.Module):
    """ELU-activated MLP baseline with the same constructor signature as ``Siren``.

    Layers: ``ELULayer(in_features -> hidden_features)``, ``hidden_layers`` hidden
    ``ELULayer``s, then a Xavier-initialized ``nn.Linear`` (``outermost_linear=True``)
    or a final ``ELULayer``. ``first_omega_0`` and ``hidden_omega_0`` are accepted only
    so call sites can swap ``Base`` for ``Siren``; they are ignored.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [ELULayer(in_features, hidden_features)]
        for _ in range(hidden_layers):
            layers.append(ELULayer(hidden_features, hidden_features))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            with torch.no_grad():
                nn.init.xavier_uniform_(final_linear.weight)
            layers.append(final_linear)
        else:
            # ELULayer takes no is_first/omega_0 arguments (passing them raised TypeError).
            layers.append(ELULayer(hidden_features, out_features))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return ``(output, coords)``.

        ``coords`` is returned as a fresh leaf tensor with ``requires_grad=True``, so
        derivatives of ``output`` w.r.t. the input can be taken.
        """
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!'''
        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for layer in self.net:
            # This network is built from ELULayers; checking for SineLayer (as before)
            # never matched, so pre-activations were never recorded.
            if isinstance(layer, ELULayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations

### Train Network

In [None]:
def _log_visual_summary(writer, coords, model_output, gt, global_step, side=256):
    """Log ground-truth vs. prediction images to TensorBoard and plot them inline.

    Only the first sample of the batch is shown. Every field is reshaped to a
    ``side`` x ``side`` image (256 matches the default ``get_mgrid`` grid).
    Panels, in order: pixels, spatial-gradient norm, spatial Laplacian, dI/dt.
    """
    def to_image(field):
        return field.reshape(1, -1, side, side)

    img_grad = computeJacobianFull(coords, model_output, create_graph=True)
    spatial_grad = img_grad[:, :, :-1]  # the last input feature is time
    # Spatial part only, as in calcLoss; the full Jacobian would also add d2/dt2.
    img_laplacian = computeLaplaceFull(coords, spatial_grad, create_graph=False)

    # (tensorboard tag, ground truth, prediction)
    panels = [
        ('pixels', gt['pixels'][0], model_output[0]),
        ('grads', gt['grads'][0].norm(dim=-1), spatial_grad[0].norm(dim=-1)),
        ('laplacians', gt['laplace'][0], img_laplacian[0]),
        ('dIdt', gt['dIdt'][0], img_grad[0, :, -1]),
    ]
    for tag, gt_field, pred_field in panels:
        # ground truth on the left, prediction on the right
        img_grid = torchvision.utils.make_grid(torch.cat((to_image(gt_field), to_image(pred_field)), 0), 2)
        if tag == 'pixels':
            img_grid = img_grid * 0.5 + 0.5  # undo Normalize(0.5, 0.5) for display
        writer.add_image(tag, img_grid, global_step)

    fig, axes = plt.subplots(2, 4, figsize=(18, 6))
    for col, (_, gt_field, pred_field) in enumerate(panels):
        axes[0, col].imshow(gt_field.detach().cpu().reshape(side, side).numpy())
        axes[1, col].imshow(pred_field.detach().cpu().reshape(side, side).numpy())
    plt.show()
    plt.close(fig)  # do not accumulate open figures across summaries


def train(net, writer, img_path, niter, total_epochs=50, beta_0=1, beta_1=1, beta_2=1, beta_3=1, cyclic=False):

    """Fit ``net`` to the diffusion sequence of one image on the GPU, logging to ``writer``.

    Args:
        net: Network to Train (its forward must return ``(output, coords)``)
        writer: SummaryWriter for logging
        img_path: path to default state image
        niter: number of diffusion states passed to ``isotropic_diffusion``
            (``niter=1`` trains on the input image only)
        total_epochs: number of epochs to train
        beta_0: constant for loss on pixel value
        beta_1: constant for loss on gradients
        beta_2: constant for loss on laplacian
        beta_3: constant for loss on pixel time derivative
        cyclic: use a CyclicLR schedule between 1e-7 and 1e-4 (step_size_up=250 iterations)
    """

    image = ImageFitting(img_path=img_path, niter=niter)
    dataloader = DataLoader(image, batch_size=1, pin_memory=True, num_workers=0)

    net.cuda()

    epochs_til_summary = 5 #UPDATE ACCORDINGLY
    steps_til_summary = 1 #UPDATE ACCORDINGLY

    optim = torch.optim.Adam(lr=1e-4, params=net.parameters())

    if cyclic:
        scheduler = torch.optim.lr_scheduler.CyclicLR(optim, base_lr=1e-7, max_lr=1e-4, step_size_up=250, cycle_momentum=False)

    loss_names = ('total', 'pixel', 'grad', 'laplacian', 'dIdt')

    print("-----Begin Training-----")
    for epoch in range(1, total_epochs + 1):

        epoch_totals = dict.fromkeys(loss_names, 0.0)

        for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):

            model_input = batch[0].cuda()
            gt = {key: value.cuda() for key, value in batch[1].items()}

            model_output, coords = net(model_input)

            pixel_loss, grad_loss, laplacian_loss, dIdt_loss = calcLoss(coords, model_output, gt)

            loss = beta_0 * pixel_loss + beta_1 * grad_loss + beta_2 * laplacian_loss + beta_3 * dIdt_loss

            # sample-weighted sums; divided by the dataset size at the end of the epoch
            batch_size = model_output.shape[0]
            for name, value in zip(loss_names, (loss, pixel_loss, grad_loss, laplacian_loss, dIdt_loss)):
                epoch_totals[name] += batch_size * value.item()

            if not epoch % epochs_til_summary and step % steps_til_summary == steps_til_summary - 1:
                _log_visual_summary(writer, coords, model_output, gt, epoch * len(dataloader) + step + 1)

            optim.zero_grad()
            loss.backward()
            optim.step()

            if cyclic:
                scheduler.step()

        # logging epoch loss (mean per sample)
        epoch_means = {name: total / len(image) for name, total in epoch_totals.items()}
        for name in loss_names:
            writer.add_scalar('epoch_loss/' + name, epoch_means[name], epoch)
        print("epoch %d, Epoch loss: total %0.6f, pixel %0.6f, grad %0.6f, laplacian %0.6f, dIdt %0.6f"
              % ((epoch,) + tuple(epoch_means[name] for name in loss_names)))

    print("-----Finished-----")
                

In [None]:
# Pixel loss only (beta_0=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/siren/cameraman_experiment_pixels') as writer:
    img_siren = Siren(in_features=3, out_features=1, hidden_features=512,
                      hidden_layers=3, outermost_linear=True)

    train(img_siren, writer, img_path='original/cameraman.png', niter=1, total_epochs=10, beta_0=1, beta_1=0, beta_2=0, beta_3=0)

In [None]:
# Gradient loss only (beta_1=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/siren/cameraman_experiment_grads') as writer:
    img_siren = Siren(in_features=3, out_features=1, hidden_features=512,
                      hidden_layers=3, outermost_linear=True)

    train(img_siren, writer, img_path='original/cameraman.png', niter=1, total_epochs=1000, beta_0=0, beta_1=1, beta_2=0, beta_3=0)

In [None]:
# Laplacian loss only (beta_2=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/siren/cameraman_experiment_laplace') as writer:
    img_siren = Siren(in_features=3, out_features=1, hidden_features=512,
                      hidden_layers=3, outermost_linear=True)

    train(img_siren, writer, img_path='original/cameraman.png', niter=1, total_epochs=1000, beta_0=0, beta_1=0, beta_2=1, beta_3=0)

In [None]:
# Time-derivative loss only (beta_3=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/siren/cameraman_experiment_dIdt') as writer:
    img_siren = Siren(in_features=3, out_features=1, hidden_features=512,
                      hidden_layers=3, outermost_linear=True)

    train(img_siren, writer, img_path='original/cameraman.png', niter=1, total_epochs=1000, beta_0=0, beta_1=0, beta_2=0, beta_3=1)

In [None]:
# All four losses with a cyclic learning rate. The context manager closes the writer even if train() raises.
with SummaryWriter('runs/siren/cameraman_experiment_all') as writer:
    img_siren = Siren(in_features=3, out_features=1, hidden_features=512,
                      hidden_layers=3, outermost_linear=True)

    train(img_siren, writer, img_path='original/cameraman.png', niter=1, total_epochs=1000, beta_0=1, beta_1=.01, beta_2=.001, beta_3=.001, cyclic=True)

In [None]:
# ELU baseline, pixel loss only (beta_0=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/base/cameraman_experiment_pixels') as writer:
    img_base = Base(in_features=3, out_features=1, hidden_features=512,
                    hidden_layers=3, outermost_linear=True)

    train(img_base, writer, img_path='original/cameraman.png', niter=1, total_epochs=10, beta_0=1, beta_1=0, beta_2=0, beta_3=0)

In [None]:
# ELU baseline, gradient loss only (beta_1=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/base/cameraman_experiment_grads') as writer:
    img_base = Base(in_features=3, out_features=1, hidden_features=512,
                    hidden_layers=3, outermost_linear=True)

    train(img_base, writer, img_path='original/cameraman.png', niter=1, total_epochs=10, beta_0=0, beta_1=1, beta_2=0, beta_3=0)

In [None]:
# ELU baseline, Laplacian loss only (beta_2=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/base/cameraman_experiment_laplace') as writer:
    img_base = Base(in_features=3, out_features=1, hidden_features=512,
                    hidden_layers=3, outermost_linear=True)

    train(img_base, writer, img_path='original/cameraman.png', niter=1, total_epochs=10, beta_0=0, beta_1=0, beta_2=1, beta_3=0)

In [None]:
# ELU baseline, time-derivative loss only (beta_3=1). The context manager closes the writer even if train() raises.
with SummaryWriter('runs/base/cameraman_experiment_dIdt') as writer:
    img_base = Base(in_features=3, out_features=1, hidden_features=512,
                    hidden_layers=3, outermost_linear=True)

    train(img_base, writer, img_path='original/cameraman.png', niter=1, total_epochs=10, beta_0=0, beta_1=0, beta_2=0, beta_3=1)

In [None]:
# ELU baseline, all four losses weighted equally. The context manager closes the writer even if train() raises.
with SummaryWriter('runs/base/cameraman_experiment_all') as writer:
    img_base = Base(in_features=3, out_features=1, hidden_features=512,
                    hidden_layers=3, outermost_linear=True)

    train(img_base, writer, img_path='original/cameraman.png', niter=1, total_epochs=10, beta_0=1, beta_1=1, beta_2=1, beta_3=1)

In [None]:
# Start TensorBoard from inside the notebook; it reads every run the experiment cells
# wrote under ./runs (runs/siren/... and runs/base/...).
%load_ext tensorboard
%tensorboard --logdir="runs"