In [None]:
import math

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from tqdm import tqdm

from metrics import MSE
from network import ProgressiveSiren

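1D scalar function fitting (Sec. 4.1). A streamable neural field (progressive SIREN) is fitted to a synthetic 1D signal made of 10 sinusoids with random phases. The network is widened in stages, and the gradients of the previously trained sub-network are cleared at every stage.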

In [None]:
# Reproducibility: seed torch's RNG (it drives the random phases below) and
# make cuDNN deterministic.
random_seed = 31210
torch.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic target: sum of n_modes sinusoids on [0, 1] with frequencies
# f, f + f_step, ..., each with a phase drawn uniformly from [0, 2*pi).
x = torch.linspace(0, 1, steps = 200)
y = torch.zeros_like(x)
n_modes = 10
f = 5            # frequency of the first mode
f_step = 5       # frequency increment between consecutive modes
pi = math.pi     # exact constant (previously the truncated literal 3.14159)
for i in range(n_modes):
    phase = torch.rand(1) * 2 * pi
    y += torch.sin(2 * pi * f * x + phase)
    f += f_step
# (N,) -> (N, 1): one scalar coordinate / target value per sample
x, y = x.to(device).unsqueeze(1), y.to(device).unsqueeze(1)

#visualize target ground truth function
print("ground truth")
plt.plot(x.cpu(), y.cpu(), 'y')
plt.grid(True)

In [None]:
# Build the streamable neural field and train it stage by stage: the first
# stage fits the base SIREN, and every later stage widens the network, selects
# the newest sub-network and clears the gradients of the previous one before
# each optimizer step.
starting_width = 10     # hidden width of the base sub-network
increased_width = 10    # hidden units added at every growth stage
n_subnets = 4           # number of stages (base + growth stages)

net = ProgressiveSiren(in_feats = 1, hidden_feats = starting_width, n_hidden_layers = 3, out_feats = 1)
net.to(device)
print(net)

lr = 1e-4
epochs = 150
mse = MSE()                         # training objective
trained_widths = [starting_width]   # cumulative width after each stage

for stage in range(n_subnets):
    is_growth_stage = stage > 0
    if is_growth_stage:
        net.grow_width(width = increased_width)
        trained_widths.append(trained_widths[-1] + increased_width)
    net.select_subnet(stage)
    # Optimizer is rebuilt every stage (presumably so parameters added by
    # grow_width are included — confirm against ProgressiveSiren).
    optimizer = optim.Adam(net.parameters(), lr = lr)
    print(f"current width: {trained_widths[-1]}")
    for _ in tqdm(range(epochs)):
        optimizer.zero_grad()
        yhat = net(x)
        loss = mse(yhat, y)
        loss.backward()
        if is_growth_stage:
            # clear gradients of the pretrained sub-network before the update
            net.freeze_subnet(stage - 1)
        optimizer.step()

    print("MSE:", loss.detach().cpu().numpy())
    print()
print("training done")

In [None]:
#visualize output
# Top row: full output of every trained sub-network against the ground truth.
# Bottom row: residual output of every grown sub-network (the base one has none).
# The grid is sized from trained_widths, so any n_subnets works (not only 4).
n_cols = max(len(trained_widths), 1)
fig, axes = plt.subplots(2, n_cols, figsize = (9 * n_cols, 10), squeeze = False)
axes[1, 0].axis('off')          # no residual panel for the base sub-network
y_range = (-5.7, 4.7)           # fixed y-limits shared by all panels
x_cpu, y_cpu = x.cpu(), y.cpu()
ax = []                         # axes in drawing order: full outputs, then residuals


def _decorate(panel, title):
    """Apply the shared y-range, grid, legend and title to one subplot."""
    panel.set_ylim(*y_range)
    panel.grid(True)
    panel.legend(loc = 'lower left', fontsize = 15)
    panel.set_title(title, fontsize = 20)


with torch.no_grad():
    net.eval()
    for i, w in enumerate(trained_widths):
        net.select_subnet(i)    # select sub-network
        yhat = net(x)           # forward prop.
        panel = axes[0, i]
        panel.plot(x_cpu, y_cpu, 'y--', label = 'gt')
        panel.plot(x_cpu, yhat.cpu(), 'g', label = 'width: {}'.format(w))
        _decorate(panel, f'width: {w}')
        ax.append(panel)
    for i, w in enumerate(trained_widths[1:], start = 1):
        net.select_subnet(i)                # select sub-network
        yhat_res = net.forward_residual(x)  # forward prop. for residual output
        panel = axes[1, i]
        panel.plot(x_cpu, yhat_res.cpu(), 'k', label = 'width: {}'.format(w))
        _decorate(panel, f'width: {w} (residual)')
        ax.append(panel)