In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Restrict this process to the GPU with index 1. The variable is written
# right after importing os, i.e. before torch is imported further down.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [None]:
import alpine

In [None]:
import torch
import numpy as np
import skimage.data, skimage.io


from matplotlib import pyplot as plt
from functools import partial
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics import MetricTracker

In [None]:
import alpine

In [None]:
from alpine.models import Siren
from alpine.models.utils import get_coords_nd, get_coords_spatial
from alpine.vis import pca

In [None]:
# Number of fitting iterations; handed to fit_signal as n_iters further down.
epochs = 2_000

In [None]:
# SIREN taking 2 input features and producing 3 output features,
# built with 5 hidden layers of 256 units and moved to the GPU.
siren_model = Siren(
    in_features=2,
    out_features=3,
    hidden_features=256,
    hidden_layers=5,
    outermost_linear=True,
)
siren_model = siren_model.cuda()
print(siren_model)

# Compile with the library defaults.
siren_model.compile()
# Alternative kept for reference: compile with a LambdaLR schedule whose
# multiplier goes from 1 to 0.1 over the first 1000 steps.
# siren_model.compile(scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda x:  0.1**min(x/1000,1)))

In [None]:
# Coordinates from get_coords_spatial(256, 256), moved to the GPU, with a
# leading axis of size 1 prepended.
coords = get_coords_spatial(256, 256).cuda().unsqueeze(0)
print(coords.shape)

In [None]:
# Sanity-check forward pass before fitting: the model's result is indexed
# by key, and the 'output' entry's shape is printed.
output = siren_model(coords)
_initial_shape = output['output'].shape
print(_initial_shape)

In [None]:
import skimage.transform

# Ground-truth signal: the skimage astronaut sample image resized to 256x256.
_astronaut = skimage.data.astronaut()
gt_img = skimage.transform.resize(_astronaut, (256, 256))

# Convert to a float32 tensor on the GPU and prepend an axis of size 1,
# mirroring what is done for the coordinates.
gt = torch.from_numpy(gt_img).float().cuda().unsqueeze(0)



In [None]:
# Print the model once more for reference before the fitting cell runs.
print(siren_model)

In [None]:
# Echo the iteration count and the coordinate tensor's shape that are about
# to be passed to fit_signal.
print(epochs, coords.shape)

In [None]:
# Fit the SIREN to the ground-truth image, tracking PSNR during fitting.
#
# data_range is passed explicitly: the target comes from
# skimage.transform.resize on a uint8 image, which yields floats in [0, 1].
# Leaving it unset makes the metric derive the range from the targets it
# observes, so the reported PSNR would not be on a fixed, comparable scale.
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to('cuda')

outputs = siren_model.fit_signal(
    coords,
    signal=gt,
    n_iters=epochs,
    enable_tqdm=True,
    return_features=True,  # the 'features' entry is consumed by the PCA cells below
    metric_trackers={'psnr': MetricTracker(psnr_metric)},
    # NOTE(review): keyword spelling is kept exactly as in the original call;
    # confirm it matches the name in fit_signal's signature.
    track_loss_histroy=True,
)

In [None]:
# Pull the fitted output and the tracked metrics out of the fit_signal result.
output = outputs['output']
all_metrics = outputs['metrics']

# One figure per tracked metric, titled with the metric's key.
for m, mv in all_metrics.items():
    # Tensor.numpy() raises for tensors that live on the GPU or are attached
    # to the autograd graph, so detach and move to the CPU first. This is the
    # same conversion the other plotting cells in this notebook apply.
    metric_history = mv.detach().cpu().numpy()
    plt.figure()
    plt.plot(metric_history)
    plt.title(m)
    plt.show()

### Visualizing INR output

In [None]:
# output = siren_model.render(coords)['output']

# Bring the fitted output back to the host as a 256x256 RGB array.
rendered_rgb = output.detach().cpu().numpy().reshape(256, 256, 3)

plt.figure()
plt.imshow(rendered_rgb)
plt.axis('off')
plt.show()

### Visualizing PCA projections of the learned INR features

In [None]:
print(outputs.keys())

# Stack the returned feature tensors along a new leading axis, then drop
# every axis of size 1.
# NOTE(review): presumably one tensor per layer - confirm against fit_signal.
_feature_list = outputs['features']
features = torch.stack(_feature_list).squeeze()
print(features.shape)

In [None]:
# Reduce the stacked features to 5 PCA components for a 256x256 signal.
# An axis of size 1 is prepended before the call.
feats = pca.compute_pca_features(
    features.unsqueeze(0),
    num_components=5,
    signal_shape=(256, 256),
)
print(feats.shape)

In [None]:
# Grid of PCA component images: one row per entry along feats' first axis,
# one column per entry along its last axis.
n_rows, n_cols = feats.shape[0], feats.shape[-1]

# squeeze=False keeps `ax` two-dimensional even when there is a single row or
# column; with the default, plt.subplots collapses such grids to a 1-D array
# (or a bare Axes) and the ax[i, j] indexing below would fail.
fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
for i in range(n_rows):
    for j in range(n_cols):
        ax[i, j].imshow(feats[i, ..., j].cpu().detach().numpy())
        ax[i, j].axis('off')

plt.suptitle("PCA features of all INR layers")
plt.show()

### Visualizing deep network geometry using Local Complexity (Humayun et al.)

In [None]:
from alpine.vis import partitions

In [None]:
# Value passed as `sampled_points` to get_partitions_from_inr below.
NUM_SAMPLED_POINTS = 2 ** 10

In [None]:
# Arguments for the partition computation: the fitted model is evaluated over
# [-1, 1] x [-1, 1] for a 256x256 signal, with sampled points handled in
# batches of 256.
_partition_args = {
    'x_bounds': [-1, 1],
    'y_bounds': [-1, 1],
    'model': siren_model,
    'signal_dims': (256, 256),
    'sampled_points': NUM_SAMPLED_POINTS,
    'sampled_points_batch_sizes': 256,
}
input_space_subdivision = partitions.get_partitions_from_inr(**_partition_args)

In [None]:
# Move the partition result to host memory as a NumPy array before plotting.
_partition_array = input_space_subdivision.detach().cpu().numpy()
partitions.show_partitions(_partition_array, normalize_each=True, dpi=80)