# Neural Implicit Representation

Faisal Qureshi     
faisal.qureshi@ontariotechu.ca

In [None]:
from importlib import reload
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

## Picking an MNIST digit

In [None]:
import torchvision as tv

def pick_mnist_image(idx=0, download=False, root='../../data', train=True):
    """Return a single MNIST digit as a NumPy array.

    Args:
        idx: Index of the sample to take from the selected split.
        download: Forwarded to ``torchvision.datasets.MNIST``; when True the
            dataset is fetched into ``root`` if it is missing.
        root: Directory holding the MNIST files (formerly hard-coded to
            ``'../../data'``, which remains the default).
        train: Take the sample from the training split (default, matching
            the previous behaviour) or, when False, from the test split.

    Returns:
        ``np.array`` built from ``data[idx][0]``. For MNIST this is presumably
        a 28x28 uint8 image; the rest of the notebook assumes 28x28.
    """
    data = tv.datasets.MNIST(root=root, train=train, download=download)
    image = np.array(data[idx][0])
    return image

In [None]:
# Load the sample at idx=0 and display it, titled with its width x height.
image = pick_mnist_image(idx=0, download=False)
h, w = image.shape

fig = plt.figure(figsize=(5, 5))
ax = fig.gca()
ax.set_title(f'{w}x{h}')
# Hide tick marks: only the pixels matter here.
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(image, cmap='gray');

## Pixel dataset

In [None]:
import sys
sys.path.append('../') # Need for positional_encoding.py

import torch
import torch.utils.data as tdata
import positional_encoding as pe

class NimDataset(tdata.Dataset):
    """Per-pixel training set for a neural implicit representation.

    Every item pairs the positional encodings of one pixel's (x, y)
    location with that pixel's intensity divided by 255.
    """

    def __init__(self, image, pos_enc_dim=8):
        # Scale intensities by 1/255 (presumably uint8 input -> [0, 1]; confirm).
        self.image = image / 255.
        table, col_codes, row_codes = self.construct_positional_encoding(
            self.image, pos_enc_dim=pos_enc_dim)
        self.pos = table
        self.enc_x = col_codes
        self.enc_y = row_codes
        self.n = len(table)

    @staticmethod
    def construct_positional_encoding(image, pos_enc_dim):
        """Build the flat pixel table and the per-axis encodings.

        Returns:
            pos: ``(h*w, 3)`` array with rows ``(x, y, intensity)`` in
                row-major pixel order.
            enc_x: encoding of column indices ``0 .. w-1``.
            enc_y: encoding of row indices ``0 .. h-1``.
        """
        n_rows, n_cols = image.shape
        col_idx = np.arange(n_cols)
        row_idx = np.arange(n_rows)
        grid_x, grid_y = np.meshgrid(col_idx, row_idx)
        table = np.stack([grid_x, grid_y, image], axis=2).reshape(-1, 3)
        col_codes = pe.positional_encoding(col_idx, pos_enc_dim)
        row_codes = pe.positional_encoding(row_idx, pos_enc_dim)
        return table, col_codes, row_codes

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # The table stores coordinates as floats next to the intensity, so
        # cast back to int before indexing the encoding tables.
        col = int(self.pos[idx, 0])
        row = int(self.pos[idx, 1])
        features = np.hstack((self.enc_x[col], self.enc_y[row]))
        target = self.pos[idx, 2:]
        return {
            'pos': torch.Tensor(features),
            'c': torch.Tensor(target),
        }

In [None]:
# Build the per-pixel dataset and inspect its third entry (pixel table row 2).
dataset = NimDataset(image, pos_enc_dim=8)
dataset[2]

## Model

In [None]:
class Nir(torch.nn.Module):
    """Small sine-activated MLP mapping an encoded (x, y) position to a value.

    Args:
        h, w: Image height and width. They are not used by ``forward``; they
            are stored so inference helpers can rebuild the sampling grid.
        pos_enc_dim: Length of each per-axis encoding. The input therefore
            has ``2 * pos_enc_dim`` features.
        output_dim: Number of output values per position.
    """

    def __init__(self, h, w, pos_enc_dim, output_dim):
        super().__init__()
        # These attribute names are the state_dict keys; renaming them would
        # break loading of previously saved checkpoints.
        self.linear1 = torch.nn.Linear(2 * pos_enc_dim, 32)
        self.linear2 = torch.nn.Linear(32, 32)
        self.linear3 = torch.nn.Linear(32, 16)
        self.linear4 = torch.nn.Linear(16, output_dim)

        self.h = h
        self.w = w
        self.pos_enc_dim = pos_enc_dim

    def forward(self, x):
        # Sine activation after every hidden layer; the output layer is linear.
        for hidden in (self.linear1, self.linear2, self.linear3):
            x = torch.sin(hidden(x))
        return self.linear4(x)

## DataLoader and compute device

In [None]:
from torch.utils.data import DataLoader

# Reshuffle pixels every epoch. Without shuffling, every epoch walks the
# image in the same raster order, and neighbouring pixels tend to be
# similar, so momentum SGD receives a correlated, fixed sequence of batches.
training_data = NimDataset(image, 8)
train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)

In [None]:
# Prefer the first GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

## Training

In [None]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

model = Nir(28, 28, 8, 1).to(device)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

num_epochs = 1
for epoch in range(num_epochs):
    loss_epoch = 0.0
    cnts = 0
    for data in train_dataloader:
        inputs = data['pos'].to(device)
        labels = data['c'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        cnts += 1
        loss_epoch += loss.item()  # .item() already copies the value to the host

    # Mean mini-batch loss for this epoch. Log this value, not the last batch's
    # `loss` tensor, so that TensorBoard and the printout agree.
    loss_epoch /= cnts
    writer.add_scalar('Loss/train', loss_epoch, epoch)

    if epoch % 1000 == 0 or epoch == num_epochs-1:
        print(f'epoch = {epoch}: loss = {loss_epoch}')

# Flush pending TensorBoard events and release the event file.
writer.close()

# NOTE(review): the inference cells below load './nir.pts', but saving is
# disabled here. Presumably this keeps a quick run (num_epochs = 1) from
# overwriting an earlier, longer-trained checkpoint; uncomment after a full run.
#torch.save(model.state_dict(), './nir.pts')

## Inference

In [None]:
def reconstruct_image(model):
    """Render the image stored in a trained ``Nir`` model at its native size.

    The model is evaluated at every integer pixel location (x, y) with
    ``0 <= x < model.w`` and ``0 <= y < model.h``.

    Args:
        model: A ``Nir`` instance; ``model.h``, ``model.w`` and
            ``model.pos_enc_dim`` define the grid and the encoding. The model
            is switched to eval mode.

    Returns:
        ``(model.h, model.w)`` float64 array holding, for each pixel, the
        first value the model outputs (the notebook uses ``output_dim=1``).
    """
    h, w, pos_enc_dim = model.h, model.w, model.pos_enc_dim
    # Run on whatever device holds the weights, rather than a global `device`.
    model_device = next(model.parameters()).device

    # An encoding depends on one coordinate only, so compute it once per
    # column and once per row instead of twice per pixel.
    enc_xs = [pe.positional_encoding(x, pos_enc_dim) for x in np.arange(w)]
    enc_ys = [pe.positional_encoding(y, pos_enc_dim) for y in np.arange(h)]

    output_image = np.empty((h, w))
    model.eval()
    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        for y, enc_y in enumerate(enc_ys):
            # Evaluate a whole image row in one forward pass.
            batch = np.stack([np.hstack((enc_x, enc_y)) for enc_x in enc_xs])
            out = model(torch.Tensor(batch).to(model_device))
            # Flatten each pixel's output so this works whether the encoder
            # adds a leading singleton axis or not; keep the first value.
            output_image[y] = out.reshape(w, -1)[:, 0].cpu().numpy()

    return output_image

In [None]:
model = Nir(28, 28, 8, 1)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa).
model.load_state_dict(torch.load('./nir.pts', map_location=device))
model.to(device)

output_image = reconstruct_image(model)
h, w = output_image.shape

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title(f'reconstructed {w}x{h}')
plt.xticks([])
plt.yticks([])
plt.imshow(output_image, cmap='gray')
plt.subplot(1,2,2)
plt.title(f'original {w}x{h}')
plt.xticks([])
plt.yticks([])
plt.imshow(image, cmap='gray');

## Super-resolution

In [None]:
def resample_image(model, new_h, new_w):
    """Render the model's image on a ``new_h`` x ``new_w`` grid.

    Sample points are spaced evenly over ``[0, model.w - 1]`` horizontally
    and ``[0, model.h - 1]`` vertically. These are the pixel coordinates
    ``NimDataset`` trains on (``np.arange(w)`` / ``np.arange(h)``), so the
    corner samples coincide with the corner training pixels.

    Args:
        model: A ``Nir`` instance; ``model.h``, ``model.w`` and
            ``model.pos_enc_dim`` are read. The model is switched to eval mode.
        new_h: Number of output rows.
        new_w: Number of output columns.

    Returns:
        ``(new_h, new_w)`` float64 array holding, for each sample, the first
        value the model outputs (the notebook uses ``output_dim=1``).
    """
    h, w, pos_enc_dim = model.h, model.w, model.pos_enc_dim
    # Run on whatever device holds the weights, rather than a global `device`.
    model_device = next(model.parameters()).device

    output_image = np.empty((new_h, new_w))
    if new_w == 0:
        # Nothing to sample; the original also returned an empty array here.
        return output_image

    # End at w-1 / h-1, not w / h. Ending at w / h put the last row and
    # column one pixel past the trained range (pure extrapolation) and
    # stretched the whole grid.
    xs = np.linspace(0, w - 1, new_w)
    ys = np.linspace(0, h - 1, new_h)
    # Column encodings are identical for every row, so compute them once.
    enc_xs = [pe.positional_encoding(x, pos_enc_dim) for x in xs]

    model.eval()
    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        for row, y in enumerate(ys):
            enc_y = pe.positional_encoding(y, pos_enc_dim)
            # Evaluate a whole output row in one forward pass.
            batch = np.stack([np.hstack((enc_x, enc_y)) for enc_x in enc_xs])
            out = model(torch.Tensor(batch).to(model_device))
            output_image[row] = out.reshape(new_w, -1)[:, 0].cpu().numpy()

    return output_image

In [None]:
model = Nir(28, 28, 8, 1)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa).
model.load_state_dict(torch.load('./nir.pts', map_location=device))
model.to(device)

new_h, new_w = 128, 128
output_image = resample_image(model, new_h, new_w)
# Size of the original digit, read from the image instead of hard-coding 28x28.
h, w = image.shape

# Titles use width x height, like the other cells in this notebook.
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title(f'reconstructed {new_w}x{new_h}')
plt.xticks([])
plt.yticks([])
plt.imshow(output_image, cmap='gray')
plt.subplot(1,2,2)
plt.title(f'original {w}x{h}')
plt.xticks([])
plt.yticks([])
plt.imshow(image, cmap='gray');