In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Linear, ReLU, Sequential
from torch.utils.data import DataLoader
import tqdm
from core.image_siren import GradientUtils, ImageSiren, PixelDataset

In [None]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Read the image from disk, rescale it and subsample it.
# `img_` keeps the raw array as returned by `plt.imread`.
downsampling_factor = 2  # keep every `downsampling_factor`-th pixel along the first two axes
img_ = plt.imread('dog.png')
# x -> 2 * (x - 0.5); this centres the values around zero
# (maps [0, 1] onto [-1, 1], assuming imread returned floats in [0, 1] - TODO confirm)
img = (2 * (img_ - 0.5))[::downsampling_factor, ::downsampling_factor]

In [None]:
# Preview the rescaled, downsampled image with a grayscale colormap
plt.imshow(img, cmap='gray')

In [None]:
# Length of the image along its first axis; used below to derive the batch size
size = len(img)

In [None]:
# Create the dataset
# The training loop below reads the keys 'coords', 'coords_abs', 'intensity',
# 'grad' and 'laplace' from the batches it yields, and the plotting code reads
# its `img`, `grad_norm` and `laplace` attributes.
dataset = PixelDataset(img)

In [None]:
# Training hyper-parameters
logging_freq = 50  # a comparison figure is drawn every `logging_freq` epochs
n_epochs = 5000  # number of passes over the dataloader
# One batch holds size * size samples
# (i.e. every pixel at once if the image is square - TODO confirm)
batch_size = int(size * size)

In [None]:
# Architecture settings consumed by the model-construction cell
hidden_layers = 3  # number of hidden blocks
hidden_features = 512  # width of each hidden layer
# Which network to build: 'siren' or 'mlp_relu'
model_name = 'siren'

In [None]:
# Quantity the training loss is computed on. The training loop accepts exactly
# the three values listed here and raises ValueError for anything else.
target = 'intensity'  # 'intensity', 'grad' or 'laplace'

In [None]:
# Build the network selected by `model_name` and move it to `device`
if model_name == 'mlp_relu':
    # Plain ReLU MLP: 2 inputs -> `hidden_layers` hidden blocks -> 1 output
    layers = [Linear(2, hidden_features), ReLU()]
    for _ in range(hidden_layers):
        layers += [Linear(hidden_features, hidden_features), ReLU()]
    layers.append(Linear(hidden_features, 1))
    model = Sequential(*layers)

    # Re-initialise the weight matrix of every Linear layer (biases keep their default init)
    for module in model.modules():
        if isinstance(module, Linear):
            torch.nn.init.xavier_uniform_(module.weight)
elif model_name == 'siren':
    model = ImageSiren(
        hidden_features=hidden_features,
        hidden_layers=hidden_layers,
        hidden_omega=30,
    )
else:
    raise ValueError('Unknown model name')

model.to(device)

In [None]:
# Adam over all model parameters with a fixed learning rate
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
# Shuffled loader yielding `batch_size` samples per batch
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Train the model
#
# Every epoch makes one optimisation pass over `dataloader`. The loss is the MSE
# between the prediction and the ground truth of the quantity named by `target`
# ('intensity', 'grad' or 'laplace'); any other value raises ValueError.
# Every `logging_freq` epochs (including epoch 0) the model is evaluated on all
# batches and a 3x2 figure compares ground truth (left) with prediction (right)
# for the image, the gradient norm and the Laplacian.
for e in range(n_epochs):
    losses = []
    for d_batch in tqdm.tqdm(dataloader):
        x_batch = d_batch['coords'].to(torch.float32).to(device)
        x_batch.requires_grad = True  # Allow taking derivatives w.r.t. the coordinates

        y_true_batch = d_batch['intensity'].to(torch.float32).to(device)
        y_true_batch = y_true_batch[:, None]  # Add another dimension

        y_pred_batch = model(x_batch)

        if target == 'intensity':
            loss = torch.nn.functional.mse_loss(y_pred_batch, y_true_batch)
        elif target == 'grad':
            y_pred_g_batch = GradientUtils.gradient(y_pred_batch, x_batch)
            # Ground truth must live on the same device as the prediction
            y_true_g_batch = d_batch['grad'].to(torch.float32).to(device)
            loss = torch.nn.functional.mse_loss(y_pred_g_batch, y_true_g_batch)
        elif target == 'laplace':
            # Same helper name as the logging code below uses
            y_pred_l_batch = GradientUtils.laplace(y_pred_batch, x_batch)
            y_true_l_batch = d_batch['laplace'].to(torch.float32).to(device)
            # Align the ground truth with the prediction so mse_loss cannot
            # silently broadcast (assumes both hold one value per sample;
            # this is a no-op when the shapes already agree - TODO confirm)
            y_true_l_batch = y_true_l_batch.reshape(y_pred_l_batch.shape)
            loss = torch.nn.functional.mse_loss(y_pred_l_batch, y_true_l_batch)
        else:
            raise ValueError('Unknown target')

        losses.append(loss.item())

        optim.zero_grad()
        loss.backward()
        optim.step()

    print(e, np.mean(losses))

    if e % logging_freq == 0:
        # Display the prediction image
        pred_img = np.zeros_like(img)
        pred_img_grad_norm = np.zeros_like(img)
        pred_img_laplace = np.zeros_like(img)

        for d_batch in tqdm.tqdm(dataloader):
            # Derivatives are taken w.r.t. the CPU copy of the coordinates;
            # autograd tracks the transfer to `device` and back.
            coords_cpu = d_batch['coords'].to(torch.float32)
            coords_cpu.requires_grad = True

            coords = coords_cpu.to(device)
            # Indices used to scatter the predictions back into image arrays
            coords_abs = d_batch['coords_abs'].numpy()

            pred = model(coords).cpu()
            pred_n = pred.detach().numpy().squeeze()
            pred_g = GradientUtils.gradient(pred, coords_cpu).norm(dim=-1).detach().numpy().squeeze()
            pred_l = GradientUtils.laplace(pred, coords_cpu).detach().numpy().squeeze()

            pred_img[coords_abs[:, 0], coords_abs[:, 1]] = pred_n
            pred_img_grad_norm[coords_abs[:, 0], coords_abs[:, 1]] = pred_g
            pred_img_laplace[coords_abs[:, 0], coords_abs[:, 1]] = pred_l

        fig, axes = plt.subplots(3, 2, constrained_layout=True)
        axes[0, 0].imshow(dataset.img, cmap='gray')
        axes[0, 1].imshow(pred_img, cmap='gray')
        axes[1, 0].imshow(dataset.grad_norm, cmap='gray')
        axes[1, 1].imshow(pred_img_grad_norm, cmap='gray')
        axes[2, 0].imshow(dataset.laplace, cmap='gray')
        axes[2, 1].imshow(pred_img_laplace, cmap='gray')

        for row in axes:
            for ax in row:
                ax.axis('off')

        fig.suptitle('Epoch {}'.format(e))
        axes[0, 0].set_title('Original')
        axes[0, 1].set_title('Prediction')

        plt.show()
        # Release the figure so repeated logging does not accumulate open figures
        plt.close(fig)