In [None]:
# import statements

import torch
import torch.nn as nn
import torch.nn.functional as F
import samplers
import numpy as np
import importlib
import matplotlib.pyplot as plt
importlib.reload(samplers)

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
class InputMapping(nn.Module):
    """Random Fourier-feature mapping  x -> [sin(Bx) * m, cos(Bx) * m, x_passthrough].

    ``B`` is a fixed (non-trainable) ``(n_freq, d_in)`` Gaussian frequency matrix
    whose rows are sorted by increasing L2 norm, and ``m`` is a ``(1, n_freq)``
    mask that can be switched on progressively via :meth:`step`.

    Args:
        d_in: Input dimension (size of the last axis of the input).
        n_freq: Number of random frequencies (rows of ``B``).
        sigma: Frequency scale; entries are drawn with std ``pi * sigma / sqrt(d_in)``.
        tdiv: Divisor applied to the column-0 frequencies (column 0 is treated as time).
        incrementalMask: If True the mask starts at zero and is filled by
            :meth:`step`; if False it starts at one.
        Tperiod: If given, column-0 frequencies are rounded to whole multiples of
            ``2*pi / Tperiod`` (features become Tperiod-periodic in x[..., 0]) and
            column 0 is dropped from the raw-input passthrough.
        kill: If True the mask starts at zero regardless of ``incrementalMask``.

    The module is created on the CPU; move it with ``.to(device)``.
    """

    def __init__(
        self, d_in, n_freq, sigma=1, tdiv=2, incrementalMask=True, Tperiod=None, kill=False
    ):
        super().__init__()
        Bmat = torch.randn(n_freq, d_in) * np.pi * sigma / np.sqrt(d_in)  # gaussian
        # Column 0 is treated as time: its frequencies are divided by ``tdiv``
        # (half of the spatial ones with the default tdiv=2).
        Bmat[:, 0] /= tdiv

        self.Tperiod = Tperiod
        if Tperiod is not None:
            # Snap time frequencies to a whole number of cycles per period so
            # every feature is exactly Tperiod-periodic in x[..., 0].
            Tcycles = (Bmat[:, 0] * Tperiod / (2 * np.pi)).round()
            Bmat[:, 0] = Tcycles * (2 * np.pi) / Tperiod

        # Order frequencies from low to high norm so an incremental mask
        # unlocks the low frequencies first.
        _, sort_indices = torch.sort(torch.norm(Bmat, p=2, dim=1))
        Bmat = Bmat[sort_indices, :]

        self.d_in = d_in
        self.n_freq = n_freq
        # Output width: sin + cos features plus the raw input passthrough
        # (minus the time column when Tperiod is set).
        self.d_out = n_freq * 2 + d_in if Tperiod is None else n_freq * 2 + d_in - 1
        # B maps d_in -> n_freq; out_features must match the number of rows of Bmat.
        self.B = nn.Linear(d_in, n_freq, bias=False)
        # Fixed frequencies. They are kept on the same device as the mask, and
        # the caller moves the whole module with .to(device).
        self.B.weight = nn.Parameter(Bmat, requires_grad=False)

        # The mask is fully on only when masking is neither incremental nor killed.
        self.incrementalMask = incrementalMask
        mask_on = (not incrementalMask) and (not kill)
        init_mask = torch.ones(1, n_freq) if mask_on else torch.zeros(1, n_freq)
        self.mask = nn.Parameter(init_mask, requires_grad=False)

    def step(self, progressPercent):
        """Unmask frequencies in proportion to training progress.

        When ``incrementalMask`` is set, the first
        ``int(progressPercent * n_freq / 0.7)`` frequencies are enabled, clipped
        to ``[0, n_freq]``. Every frequency is on once progressPercent reaches 0.7.
        This method does nothing otherwise.
        """
        if not self.incrementalMask:
            return
        float_filled = (progressPercent * self.n_freq) / 0.7
        # Clip so a negative progress never unmasks tail frequencies via a
        # negative slice bound.
        int_filled = min(max(0, int(float_filled // 1)), self.n_freq)
        with torch.no_grad():
            self.mask[0, :int_filled] = 1

    def forward(self, xi):
        """Map ``xi`` of shape (..., d_in), or a single (d_in,) vector, to Fourier features.

        Returns a tensor of shape (..., d_out). A 1-D input comes back with a
        leading batch dimension of size 1.
        """
        y = self.B(xi)
        # Promote a single un-batched vector to a batch of one.
        if y.dim() == 1:
            y = y.unsqueeze(0)
        if xi.dim() == 1:
            xi = xi.unsqueeze(0)
        # With Tperiod set, the time column 0 is carried only by the periodic features.
        passthrough = xi if self.Tperiod is None else xi[..., 1 : self.d_in]
        return torch.cat(
            [torch.sin(y) * self.mask, torch.cos(y) * self.mask, passthrough], dim=-1
        )

In [None]:
def smooth_leaky_relu(x, alpha=0.1):
    """Smooth leaky ReLU: a convex mix of the identity and softplus.

    Args:
        x: Input tensor.
        alpha: Weight of the linear (identity) part. ``1 - alpha`` weights softplus.

    Returns:
        ``alpha * x + (1 - alpha) * softplus(x)``, elementwise.
    """
    linear_part = alpha * x
    smooth_part = (1 - alpha) * F.softplus(x)
    return linear_part + smooth_part

In [None]:
# Define a two-layer MLP with output_dim=hidden_dims -- this is our "h" function

class MLPh(nn.Module):
    """Two-layer MLP mapping inputs to a ``hidden_dims``-dimensional feature vector ("h").

    Args:
        base_dims: Raw input dimension. It is ignored when ``fourier_map`` is given,
            because fc1 then consumes the mapped features (width ``fourier_map.d_out``).
        hidden_dims: Width of both layers.
        fourier_map: Optional module applied to the input first; it must expose ``d_out``.
        residual: Stored on the module but not used by ``forward``.
    """

    def __init__(self, base_dims, hidden_dims, fourier_map=None, residual=False):
        super().__init__()
        self.fourier_map = fourier_map
        self.residual = residual
        self.base_dims = base_dims if fourier_map is None else fourier_map.d_out
        self.fc1 = nn.Linear(self.base_dims, hidden_dims)
        self.fc2 = nn.Linear(hidden_dims, hidden_dims)

    def forward(self, x):
        """Return fc2(elu(fc1(fourier_map(x)))), skipping fourier_map when it is None."""
        if self.fourier_map is not None:
            x = self.fourier_map(x)
        x = F.elu(self.fc1(x))  # elu works well!
        return self.fc2(x)

In [None]:
# Two-layer MLP head with out_dims outputs (default 1) -- this is our "g" function

class MLPg(nn.Module):
    """Two-layer MLP head ("g"): hidden_dims -> hidden_dims -> out_dims.

    Args:
        hidden_dims: Input and hidden width.
        out_dims: Output width (1 by default).
        residual: Stored on the module but not used by ``forward``.

    ``forward`` returns ``fc2(elu(fc1(x))).squeeze()``, so every size-1 axis is removed.
    """

    def __init__(self, hidden_dims, out_dims=1, residual=False):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dims, hidden_dims)
        self.fc2 = nn.Linear(hidden_dims, out_dims)
        self.residual = residual

    def forward(self, x):
        hidden = F.elu(self.fc1(x))  # elu works well!
        return self.fc2(hidden).squeeze()

In [None]:
# Target function: indicator of the closed ball of radius `radius`

def target_fn(x, radius=1):
    """Return 1.0 for rows of ``x`` whose Euclidean norm is <= ``radius``, else 0.0.

    Args:
        x: Tensor of shape (n_points, dim).
        radius: Ball radius.

    Returns:
        float32 tensor of shape (n_points,).
    """
    distances = torch.linalg.norm(x, dim=1)
    return distances.le(radius).float()
    # Alternative targets (not used here):
    # Indicator of unit square in 2D
    # return (torch.abs(x[:, 0]) <= radius).float() * (torch.abs(x[:, 1]) <= radius).float()
    # Indicator of union of two unit squares in 2D
    # return (torch.abs(x[:, 0] - 0.5) <= radius).float() * (torch.abs(x[:, 1] - 0.5) <= radius).float() + (torch.abs(x[:, 0] + 0.5) <= radius).float() * (torch.abs(x[:, 1] + 0.5) <= radius).float()

In [None]:
# Plot target function on grid

n_grid = 100
x_grid = torch.linspace(-2, 2, n_grid)
# indexing='xy': row i <-> x2 = x_grid[i], column j <-> x1 = x_grid[j]. That is
# the layout pcolormesh(x, y, C) expects for C[row, col], so the x1/x2 axis
# labels below are correct. Passing indexing also silences torch's warning.
X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)
y_grid_target = target_fn(X_grid).squeeze().cpu()

# Heatmap

plt.figure()
plt.pcolormesh(x_grid, x_grid, y_grid_target.reshape(n_grid, n_grid), cmap='coolwarm')
plt.colorbar()
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Target function');

In [None]:
def random_ball(num_points, dimension, radius=1):
    """Draw ``num_points`` points uniformly from a ``dimension``-dimensional ball.

    Uses NumPy's global RNG. Returns a float32 tensor of shape
    (num_points, dimension) on the notebook's ``device``.
    """
    # Normalised isotropic Gaussian vectors are uniformly distributed directions.
    gaussian = np.random.normal(size=(dimension, num_points))
    unit_dirs = gaussian / np.linalg.norm(gaussian, axis=0)
    # r = U**(1/d) gives a radial density proportional to r**(d-1), i.e. to the
    # surface area of the sphere of radius r, which makes the points uniform in the ball.
    radii = np.random.random(num_points) ** (1 / dimension)
    points = (unit_dirs * radii).T * radius
    return torch.from_numpy(points).float().to(device)

### $d=5$ visualizations.

In [None]:
# Load the saved curves and trained models for the d = 5 experiment.

reg_param = 0.05
results_dir_d5 = '/results/rof_ours_denoising_vs_exact_results_100k_iters_d5'
suffix = '_reg_param_' + str(reg_param)

# Curves recorded during training
avg_values_exact_nuc = np.load(f'{results_dir_d5}/avg_values_exact_nuc{suffix}.npy')
avg_values_ours = np.load(f'{results_dir_d5}/avg_values_our_nuc{suffix}.npy')

losses_exact_nuc = np.load(f'{results_dir_d5}/exact_nuc_losses{suffix}.npy')
losses_ours = np.load(f'{results_dir_d5}/our_nuc_losses{suffix}.npy')

abs_errors_exact_nuc = np.load(f'{results_dir_d5}/abs_errors_exact_nuc{suffix}.npy')
abs_errors_ours = np.load(f'{results_dir_d5}/abs_errors_our_nuc{suffix}.npy')

# Load models

d = 5
base_dims = d
hidden_dims = 100

# Each model gets its OWN Fourier map. If one InputMapping were shared,
# loading the second checkpoint would overwrite the frequency matrix and mask
# the first model was restored with.
fourier_map_exact = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)
fourier_map_ours = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)

model_exact_nuc = nn.Sequential(
    MLPh(base_dims, hidden_dims, fourier_map_exact, residual=False),
    MLPg(hidden_dims, residual=False),
).to(device)
# NOTE(review): this checkpoint path is relative, unlike every other artifact of
# this run (read from results_dir_d5). Confirm where it is stored.
model_exact_nuc.load_state_dict(torch.load(
    'rof_ours_denoising_vs_exact_results_100k_iters_d5/exact_nuc_model' + suffix + '.pt',
    map_location=device))  # map_location: load GPU-saved checkpoints on any device

g_model_our_nuc = MLPg(hidden_dims, residual=False).to(device)
h_model_our_nuc = MLPh(base_dims, hidden_dims, fourier_map=fourier_map_ours, residual=False).to(device)
g_model_our_nuc.load_state_dict(torch.load(f'{results_dir_d5}/our_nuc_g_model{suffix}.pt', map_location=device))
h_model_our_nuc.load_state_dict(torch.load(f'{results_dir_d5}/our_nuc_h_model{suffix}.pt', map_location=device))

In [None]:
# Extrapolate avg_values_exact_nuc by repeating its last value so both curves
# span the same number of iterations. max(0, ...) makes this a no-op when the
# exact run is already at least as long, since a negative pad width is invalid.
n_pad = max(0, len(avg_values_ours) - len(avg_values_exact_nuc))
avg_values_exact_nuc = np.pad(avg_values_exact_nuc, (0, n_pad), mode='edge')

# Plot average values

plt.figure()
plt.plot(avg_values_ours, label='Our Nuclear Norm', color='C1')
plt.plot(avg_values_exact_nuc, label='Exact Nuclear Norm', color='C0')

# Horizontal line at the correct average value, 1 - d * reg_param
correct_val = 1 - d * reg_param
plt.axhline(y=correct_val, color='r', linestyle='--', label='Correct value')

plt.xlabel('Iteration')
plt.ylabel('Average value of function on unit ball')  # d = 5: a ball, not a disc
plt.title("Comparison of average values on unit ball")
plt.legend();

In [None]:
# Extrapolate abs_errors_exact_nuc by repeating its last value (no-op when the
# exact run is already at least as long).
n_pad = max(0, len(abs_errors_ours) - len(abs_errors_exact_nuc))
abs_errors_exact_nuc = np.pad(abs_errors_exact_nuc, (0, n_pad), mode='edge')

# Smooth with a moving average into NEW arrays, so re-running this cell does
# not smooth the raw curves a second time.
window_size = 10
kernel = np.ones(window_size) / window_size
abs_errors_ours_smoothed = np.convolve(abs_errors_ours, kernel, mode='valid')
abs_errors_exact_nuc_smoothed = np.convolve(abs_errors_exact_nuc, kernel, mode='valid')

# Plot absolute errors

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.05 -> '0p05' for file names

plt.figure()
plt.plot(abs_errors_ours_smoothed, label='Our regularizer', color='C1')
plt.plot(abs_errors_exact_nuc_smoothed, label='Exact nuclear norm', color='C0')
plt.xlabel('Iteration')
plt.ylabel('Absolute error (smoothed)')
plt.title(rf"Absolute errors, $\eta = {reg_param}$")  # raw string: '\e' is not a valid escape
plt.legend()

plt.savefig(f'results/abs_errors_d{d}_eta_{eta_tag}.png', dpi=300)

In [None]:
# Extrapolate losses_exact_nuc by repeating its last value (no-op when the
# exact run is already at least as long).
n_pad = max(0, len(losses_ours) - len(losses_exact_nuc))
losses_exact_nuc = np.pad(losses_exact_nuc, (0, n_pad), mode='edge')

# Smooth the losses with a moving average

window_size = 10
kernel = np.ones(window_size) / window_size
losses_ours_smoothed = np.convolve(losses_ours, kernel, mode='valid')
losses_exact_nuc_smoothed = np.convolve(losses_exact_nuc, kernel, mode='valid')

# Plot loss

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.05 -> '0p05' for file names

plt.figure()
plt.plot(np.log(losses_ours_smoothed), label='Our regularizer', color='C1')
plt.plot(np.log(losses_exact_nuc_smoothed), label='Exact nuclear norm', color='C0')
plt.xlabel('Iteration')
plt.ylabel('Log loss')
plt.title(rf'Log-losses, $\eta={reg_param}$')  # raw string: '\e' is not a valid escape
plt.legend()

plt.savefig(f'results/log_losses_d{d}_eta_{eta_tag}.png', dpi=300)

In [None]:
# Plot model_exact_nuc on the x1-x2 plane (remaining coordinates set to zero)

n_grid = 100
x_grid = torch.linspace(-2, 2, n_grid)
# indexing='xy': row i <-> x2 = x_grid[i], column j <-> x1 = x_grid[j]. That is
# the layout pcolormesh(x, y, C) expects, so the x1/x2 labels below are correct.
X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)
# Zero-pad the remaining d - 2 coordinates to match the model's input dimension.
X_grid = torch.cat((X_grid, torch.zeros(X_grid.shape[0], d - 2, device=device)), dim=1)
with torch.no_grad():  # inference only; no autograd graph needed
    y_grid_exact = model_exact_nuc(X_grid).squeeze().cpu()

# Heatmap

plt.figure()
plt.pcolormesh(x_grid, x_grid, y_grid_exact.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)
plt.colorbar()
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Model with Exact Nuclear Norm Regularization');

In [None]:
# Plot model_our_nuc on the x1-x2 plane (remaining coordinates set to zero)

n_grid = 100
x_grid = torch.linspace(-2, 2, n_grid)
# indexing='xy': row i <-> x2 = x_grid[i], column j <-> x1 = x_grid[j]. That is
# the layout pcolormesh(x, y, C) expects, so the x1/x2 labels below are correct.
X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)
# Zero-pad the remaining d - 2 coordinates to match the input dimension of h_model_our_nuc.
X_grid = torch.cat((X_grid, torch.zeros(X_grid.shape[0], d - 2, device=device)), dim=1)


def model_our_nuc(x):
    """Composite model g(h(x)) trained with our regularizer."""
    return g_model_our_nuc(h_model_our_nuc(x))


with torch.no_grad():  # inference only; no autograd graph needed
    y_grid_ours = model_our_nuc(X_grid).squeeze().cpu()

# Heatmap

plt.figure()
plt.pcolormesh(x_grid, x_grid, y_grid_ours.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)
plt.colorbar()
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Model with Our Nuclear Norm Regularization');

### $d=2$ visualizations.

In [None]:
# Load the saved curves and trained models for the d = 2 experiment.

reg_param = 0.1
results_dir_d2 = '/results/rof_ours_denoising_vs_exact_results_100k_iters'
suffix = '_reg_param_' + str(reg_param)

# Curves recorded during training
avg_values_exact_nuc = np.load(f'{results_dir_d2}/avg_values_exact_nuc{suffix}.npy')
avg_values_ours = np.load(f'{results_dir_d2}/avg_values_our_nuc{suffix}.npy')

losses_exact_nuc = np.load(f'{results_dir_d2}/exact_nuc_losses{suffix}.npy')
losses_ours = np.load(f'{results_dir_d2}/our_nuc_losses{suffix}.npy')

abs_errors_exact_nuc = np.load(f'{results_dir_d2}/abs_errors_exact_nuc{suffix}.npy')
abs_errors_ours = np.load(f'{results_dir_d2}/abs_errors_our_nuc{suffix}.npy')

# Load models

base_dims = 2
hidden_dims = 100

# Each model gets its OWN Fourier map. If one InputMapping were shared,
# loading the second checkpoint would overwrite the frequency matrix and mask
# the first model was restored with.
fourier_map_exact = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)
fourier_map_ours = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)

model_exact_nuc = nn.Sequential(
    MLPh(base_dims, hidden_dims, fourier_map_exact, residual=False),
    MLPg(hidden_dims, residual=False),
).to(device)
# NOTE(review): this checkpoint path is relative, unlike every other artifact of
# this run (read from results_dir_d2). Confirm where it is stored.
model_exact_nuc.load_state_dict(torch.load(
    'rof_ours_denoising_vs_exact_results_100k_iters/exact_nuc_model' + suffix + '.pt',
    map_location=device))  # map_location: load GPU-saved checkpoints on any device

g_model_our_nuc = MLPg(hidden_dims, residual=False).to(device)
h_model_our_nuc = MLPh(base_dims, hidden_dims, fourier_map=fourier_map_ours, residual=False).to(device)
g_model_our_nuc.load_state_dict(torch.load(f'{results_dir_d2}/our_nuc_g_model{suffix}.pt', map_location=device))
h_model_our_nuc.load_state_dict(torch.load(f'{results_dir_d2}/our_nuc_h_model{suffix}.pt', map_location=device))

In [None]:
# Extrapolate avg_values_exact_nuc by repeating its last value so both curves
# span the same number of iterations. max(0, ...) makes this a no-op when the
# exact run is already at least as long, since a negative pad width is invalid.
n_pad = max(0, len(avg_values_ours) - len(avg_values_exact_nuc))
avg_values_exact_nuc = np.pad(avg_values_exact_nuc, (0, n_pad), mode='edge')

# Plot average values

plt.figure()
plt.plot(avg_values_ours, label='Our Nuclear Norm', color='C1')
plt.plot(avg_values_exact_nuc, label='Exact Nuclear Norm', color='C0')

# Horizontal line at the correct average value, 1 - base_dims * reg_param
correct_val = 1 - base_dims * reg_param
plt.axhline(y=correct_val, color='r', linestyle='--', label='Correct value')

plt.xlabel('Iteration')
plt.ylabel('Average value of function on unit disc')
plt.title("Comparison of average values on unit disc")
plt.legend();

In [None]:
# Extrapolate abs_errors_exact_nuc by repeating its last value (no-op when the
# exact run is already at least as long).
n_pad = max(0, len(abs_errors_ours) - len(abs_errors_exact_nuc))
abs_errors_exact_nuc = np.pad(abs_errors_exact_nuc, (0, n_pad), mode='edge')

# Plot absolute errors

# Publication styling: thick lines, larger bold fonts (global rcParams).
plt.rcParams.update({
    'lines.linewidth': 3,
    'font.size': 14,
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.1 -> '0p1' for file names

plt.figure()
plt.plot(abs_errors_ours, label='Our regularizer', color='C1')
plt.plot(abs_errors_exact_nuc, label='Exact nuclear norm', color='C0')
plt.xlabel('Iteration')
plt.ylabel('Absolute error')
plt.title(rf"Absolute errors, $\eta = {reg_param}$")  # raw string: '\e' is not a valid escape
plt.legend()

plt.savefig(f'results/abs_errors_d{base_dims}_eta_{eta_tag}.png', dpi=300)

In [None]:
# Extrapolate losses_exact_nuc by repeating its last value (no-op when the
# exact run is already at least as long).
n_pad = max(0, len(losses_ours) - len(losses_exact_nuc))
losses_exact_nuc = np.pad(losses_exact_nuc, (0, n_pad), mode='edge')

# Plot loss

# Publication styling: thick lines, larger bold fonts (global rcParams).
plt.rcParams.update({
    'lines.linewidth': 3,
    'font.size': 14,
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.1 -> '0p1' for file names

plt.figure()
plt.plot(np.log(losses_ours), label='Our regularizer', color='C1')
plt.plot(np.log(losses_exact_nuc), label='Exact nuclear norm', color='C0')
plt.xlabel('Iteration')
plt.ylabel('Log loss')
plt.title(rf'Log-losses, $\eta={reg_param}$')  # raw string: '\e' is not a valid escape
plt.legend()

plt.savefig(f'results/log_losses_d{base_dims}_eta_{eta_tag}.png', dpi=300)

In [None]:
# Plot model_exact_nuc on grid

n_grid = 100
x_grid = torch.linspace(-2, 2, n_grid)
# indexing='xy': row i <-> x2 = x_grid[i], column j <-> x1 = x_grid[j]. That is
# the layout pcolormesh(x, y, C) expects, so the x1/x2 labels below are correct.
X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    y_grid_exact = model_exact_nuc(X_grid).squeeze().cpu()

# Heatmap

# Larger bold fonts (global rcParams).
plt.rcParams.update({
    'font.size': 14,
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.1 -> '0p1' for file names

plt.figure()
plt.pcolormesh(x_grid, x_grid, y_grid_exact.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)
plt.colorbar()
plt.xlabel('x1')
plt.ylabel('x2')
plt.title(rf'Exact nuclear norm, $\eta = {reg_param}$')  # raw string: '\e' is not a valid escape

plt.savefig(f'results/exact_nuclear_norm_eta_{eta_tag}.png', dpi=300)

In [None]:
# Plot model_our_nuc on grid

n_grid = 100
x_grid = torch.linspace(-2, 2, n_grid)
# indexing='xy': row i <-> x2 = x_grid[i], column j <-> x1 = x_grid[j]. That is
# the layout pcolormesh(x, y, C) expects for C[row, col].
X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)


def model_our_nuc(x):
    """Composite model g(h(x)) trained with our regularizer."""
    return g_model_our_nuc(h_model_our_nuc(x))


with torch.no_grad():  # inference only; no autograd graph needed
    y_grid_ours = model_our_nuc(X_grid).squeeze().cpu()

# Heatmap

# Larger bold fonts (global rcParams).
plt.rcParams.update({
    'font.size': 14,
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.1 -> '0p1' for file names

plt.figure()
plt.pcolormesh(x_grid, x_grid, y_grid_ours.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)
plt.colorbar()
# plt.xlabel('x1')
# plt.ylabel('x2')
plt.title(rf'Our regularizer, $\eta = {reg_param}$')  # raw string: '\e' is not a valid escape

plt.savefig(f'results/our_nuclear_norm_eta_{eta_tag}.png', dpi=300)

In [None]:
# Plot exact solution on grid -- target_fn * (1 - base_dims * reg_param), the same
# correct value used for the reference line in the average-value plot.

n_grid = 100
x_grid = torch.linspace(-2, 2, n_grid)
# indexing='xy': row i <-> x2 = x_grid[i], column j <-> x1 = x_grid[j]. That is
# the layout pcolormesh(x, y, C) expects for C[row, col].
X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)
y_grid_target = target_fn(X_grid).squeeze().cpu()
y_grid_exact_solution = y_grid_target * (1 - base_dims * reg_param)

# Heatmap

# Larger bold fonts (global rcParams).
plt.rcParams.update({
    'font.size': 14,
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

eta_tag = str(reg_param).replace('.', 'p')  # e.g. 0.1 -> '0p1' for file names

plt.figure()
plt.pcolormesh(x_grid, x_grid, y_grid_exact_solution.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)
plt.colorbar()
plt.title(rf'Exact solution, $\eta = {reg_param}$')  # raw string: '\e' is not a valid escape

plt.savefig(f'results/exact_solution_eta_{eta_tag}.png', dpi=300)