In [None]:
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from metrics import MSE, PSNR
from network import ProgressiveSiren
from utils import make_grid2d

Spectral growing of an image (Sec. 4.2)

In [None]:
# Reproducibility: seed torch's RNG and pin cuDNN to deterministic kernels.
random_seed = 31210
torch.manual_seed(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the reconstruction target, report its size and display it.
target_image = Image.open('cameraman.tif')
print('target size:', target_image.size)
plt.imshow(target_image, cmap = 'gray')

# Convert the image to a tensor on the selected device; its shape is unpacked
# as (C, H, W) and reused by the cells below.
target_chw = transforms.ToTensor()(target_image).to(device)
C, H, W = target_chw.shape

# Network input: the output of make_grid2d flattened to rows of 2 values.
x = make_grid2d(H, W).to(device).reshape(-1, 2)
# Training target: one flattened row per channel, i.e. shape (C, H * W).
y = target_chw.reshape(C, -1)

In [None]:
# Build the streamable neural field and train it progressively:
# train the narrowest sub-network first, then repeatedly widen the network
# and train only the newly added part.

# --- architecture settings ---
starting_width = 20     # hidden width of the first sub-network
increased_width = 20    # width added at every growing step
n_subnets = 4           # total number of sub-networks to train

net = ProgressiveSiren(
    in_feats = 2,
    hidden_feats = starting_width,
    n_hidden_layers = 3,
    out_feats = 1,
)
net.to(device)
print(net)

# --- optimisation settings ---
lr = 2e-4
epochs = 5000
mse = MSE()                             # loss function
trained_widths = [starting_width]       # width reached after each stage

# --- progressive training ---
for stage in range(n_subnets):
    if stage > 0:
        # widen the network and record the new total width
        net.grow_width(width = increased_width)
        trained_widths.append(trained_widths[-1] + increased_width)
    net.select_subnet(stage)

    # a fresh Adam optimizer is created for every stage
    optimizer = optim.Adam(net.parameters(), lr = lr)
    print(f"current width: {trained_widths[-1]}")

    for epoch in tqdm(range(epochs)):
        optimizer.zero_grad()                   # reset accumulated gradients
        yhat = net(x)                           # forward pass
        loss = mse(yhat.permute(1, 0), y)       # loss against the (C, H * W) target
        loss.backward()                         # backward pass
        if stage > 0:
            # clear gradients of the previously trained sub-network so that
            # the optimizer step below leaves it untouched
            net.freeze_subnet(stage - 1)
        optimizer.step()                        # apply the update

    # loss value of the last epoch of this stage
    print("MSE:", loss.detach().cpu().numpy())
    print()
print("training done")

In [None]:
# Show results.
# Top row: reconstruction of every trained sub-network with its PSNR.
# Bottom row: residual output of every sub-network after the first one.
ax = []
fig = plt.figure(figsize = (20, 10))

# One column per trained sub-network, so the grid follows `trained_widths`
# instead of assuming exactly four sub-networks.
n_cols = len(trained_widths)

# Build the (H, W, C) reference image once, under a new name. The global `y`
# keeps its (C, H * W) training layout, so the training cell can be re-run
# after this one. Going through (C, H, W) -> permute keeps the pixel layout
# correct for C > 1 as well; a plain reshape to (H, W, C) is only valid for C == 1.
y_img = y.reshape(C, H, W).permute(1, 2, 0).contiguous().cpu()

with torch.no_grad():
    net.eval()
    for i, w in enumerate(trained_widths):
        net.select_subnet(i)                    # select sub-network
        yhat = net(x).reshape(H, W, C)          # forward prop.
        psnr = PSNR()(y_img, yhat.cpu())        # compute PSNR
        ax.append(fig.add_subplot(2, n_cols, i + 1))
        ax[-1].imshow(yhat.cpu(), cmap = 'gray')
        ax[-1].set_title(f'width: {w}, PSNR: {psnr:.2f}', fontsize = 20)

    for i, w in enumerate(trained_widths):
        if i == 0:
            # the first sub-network has no residual panel; its bottom-row slot stays empty
            continue
        net.select_subnet(i)                    # select sub-network
        yhat_res = net.forward_residual(x)      # forward prop. for residual output
        # bottom row, same column as the matching reconstruction above
        ax.append(fig.add_subplot(2, n_cols, n_cols + i + 1))
        ax[-1].imshow(yhat_res.reshape(H, W, C).cpu(), cmap = 'gray')
        ax[-1].set_title(f'width: {w} (residual)', fontsize = 20)

Growing frequency spectrum (supplementary Sec. 2.1)

In [None]:
# Load the pretrained model: rebuild the network at its final width, declare
# the sub-network widths on its layers, then restore the saved weights.
net = ProgressiveSiren(in_feats = 2, hidden_feats = 120, n_hidden_layers = 4, out_feats = 3)
net.to(device)
trained_widths = [30, 60, 90, 120]
# NOTE(review): 6 is presumably the number of layers in `net.net` for
# n_hidden_layers = 4 -- confirm against ProgressiveSiren.
for i in range(6):
    # give every layer its own copy of the width list
    net.net[i].subnet_widths = list(trained_widths)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine,
# which matches the cuda/cpu fallback used when `device` was selected.
net.load_state_dict(torch.load('sunflower.pth', map_location = device))
print(net)

# Resolution at which the pretrained model is evaluated, and the matching
# input grid flattened to rows of 2 values.
C, H, W = 3, 800, 800
x = make_grid2d(H, W).to(device)
x = x.reshape(-1, 2)

In [None]:
# Show results.
# Top row: output image of every sub-network.
# Bottom row: its centred log-magnitude frequency spectrum, summed over channels.
ax = []
fig = plt.figure(figsize = (20, 10))
with torch.no_grad():
    net.eval()
    for i, w in enumerate(trained_widths):
        net.select_subnet(i)                # select sub-network
        yhat = net(x).reshape(H, W, C)      # forward prop., laid out as an image

        # reconstructed image in the top row
        image_axes = fig.add_subplot(2, 4, i + 1)
        ax.append(image_axes)
        image_axes.imshow(yhat.cpu(), cmap = 'gray')
        image_axes.set_title(f'width: {w}', fontsize = 20)

        # 2-D FFT per channel, log10 of the magnitude, summed over the channel axis
        spectrum = torch.fft.fft2(yhat.permute(2, 0, 1))
        log_magnitude = torch.log10(torch.abs(spectrum))
        yhat_f = torch.sum(log_magnitude, 0)

        # roll both axes by half their length to move the zero frequency to the centre
        half_rows = yhat_f.shape[0] // 2
        half_cols = yhat_f.shape[1] // 2
        yhat_f = torch.roll(yhat_f, shifts = (half_rows, half_cols), dims = (0, 1))

        # spectrum in the bottom row, directly below its image
        spectrum_axes = fig.add_subplot(2, 4, i + 5)
        ax.append(spectrum_axes)
        spectrum_axes.imshow(yhat_f.cpu(), cmap = 'magma')