In [None]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from tifffile import imread, imwrite
from tqdm.notebook import tqdm

: 

In [None]:
# SIREN network
# Code adapted from the following GitHub repository:
# https://github.com/vsitzmann/siren?tab=readme-ov-file
from torch import nn
import torch
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt
from tqdm import trange
from torch.utils.data import DataLoader, Dataset
import matplotlib
from IPython import display

class SineLayer(nn.Module):
    """Linear layer followed by a sine activation: ``sin(omega_0 * (W x + b))``.

    See the SIREN paper, Sec. 3.2 (final paragraph) and supplement Sec. 1.5, for
    the discussion of ``omega_0``.

    * ``is_first=True``: ``omega_0`` is a frequency factor that simply multiplies
      the pre-activations before the nonlinearity. Different signals may require
      a different ``omega_0`` in the first layer - this is a hyperparameter.
    * ``is_first=False``: the initial weights are divided by ``omega_0`` so that
      the magnitude of the activations stays constant while gradients to the
      weight matrix are boosted (see supplement Sec. 1.5).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the underlying ``nn.Linear`` has a bias term.
        is_first: selects the first-layer weight initialisation.
        omega_0: frequency factor applied before the sine.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights uniformly from the SIREN init range."""
        with torch.no_grad():
            if self.is_first:
                bound = 1 / self.in_features
            else:
                bound = np.sqrt(6 / self.in_features) / self.omega_0
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        """Return ``(activation, pre_activation)`` for activation visualisation.

        ``pre_activation`` is ``omega_0 * linear(input)`` and ``activation`` is
        its sine, i.e. exactly what ``forward`` returns.
        """
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate

class Siren(nn.Module):
    """Stack of ``SineLayer`` blocks forming a SIREN implicit representation.

    Args:
        in_features: dimensionality of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden ``SineLayer`` blocks after the first one.
        out_features: dimensionality of the output.
        outermost_linear: if True the last layer is a plain ``nn.Linear``
            (initialised with the hidden-layer SIREN range); otherwise it is
            another ``SineLayer``.
        first_omega_0: ``omega_0`` of the first layer.
        hidden_omega_0: ``omega_0`` of all later layers.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        for _ in range(hidden_layers):
            layers.append(SineLayer(hidden_features, hidden_features,
                                    is_first=False, omega_0=hidden_omega_0))
        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_linear.weight.uniform_(-bound, bound)
            layers.append(final_linear)
        else:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))
        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return ``(output, coords)``.

        ``coords`` is a detached clone of the input with ``requires_grad`` set,
        which allows taking derivatives of the output w.r.t. the input.
        """
        coords = coords.clone().detach().requires_grad_(True)
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!

        The result is an ordered mapping: ``'input'`` first, then for every
        ``SineLayer`` its pre-activation followed by its activation, and for any
        other layer its output. With ``retain_grad=True`` every stored tensor
        keeps its ``.grad`` after a backward pass.
        '''
        # Local import: OrderedDict is not imported at the top of the notebook.
        from collections import OrderedDict

        activations = OrderedDict()
        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for layer in self.net:
            if isinstance(layer, SineLayer):
                # Computed inline (same formula as SineLayer.forward) so this does
                # not rely on an optional helper method of the layer.
                intermed = layer.omega_0 * layer.linear(x)
                x = torch.sin(intermed)
                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()
                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)
                if retain_grad:
                    x.retain_grad()
            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1
        return activations

class ImageFitting(Dataset):
    """Map-style dataset that serves one movie as consecutive chunks of pixels.

    The movie returned by ``load_movie_tensor(id)`` is flattened to an ``(N, 1)``
    column of pixel values, and ``get_mgrid`` supplies the matching coordinate
    rows (row ``k`` of ``coords`` is meant to be the position of row ``k`` of
    ``pixels``). Item ``i`` is the ``i``-th chunk of ``batch_size`` rows; the
    last chunk may be shorter.

    Args:
        id: index of the movie file to load.
        batch_size: number of pixels per chunk; must be positive.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, id=0, batch_size=65536):
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.batch_size = batch_size
        img = load_movie_tensor(id)
        self.pixels = img.reshape(-1, 1)
        self.coords = get_mgrid(img)
        self.batches = int(np.ceil(len(self.pixels) / self.batch_size))

    def __len__(self):
        """Number of chunks."""
        return self.batches

    def __getitem__(self, idx):
        """Return ``(coords_chunk, pixels_chunk)`` for chunk ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range. Without this, slicing past
                the end silently returns empty tensors and plain iteration over
                the dataset (which stops on ``IndexError``) never terminates.
        """
        if idx < 0:
            idx += self.batches  # sequence-style negative indexing
        if not 0 <= idx < self.batches:
            raise IndexError(f"chunk index out of range for {self.batches} chunks")
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.coords[start:stop], self.pixels[start:stop]

def laplace(y, x):
    """Laplacian of ``y`` w.r.t. ``x``: the divergence of the gradient of ``y``."""
    return divergence(gradient(y, x), x)

def divergence(y, x):
    """Divergence of the vector field ``y`` w.r.t. ``x``.

    For every component ``k`` along the last axis of ``y`` the derivative of
    that component w.r.t. ``x`` is taken and its ``k``-th slice is accumulated.
    The graph is kept (``create_graph=True``) so the result can be
    differentiated again. Returns the float ``0.`` when ``y`` has no components.
    """
    total = 0.
    n_components = y.shape[-1]
    for k in range(n_components):
        component = y[..., k]
        seed = torch.ones_like(component)
        (d_component,) = torch.autograd.grad(component, x, seed, create_graph=True)
        total += d_component[..., k:k + 1]
    return total

def gradient(y, x, grad_outputs=None):
    """Gradient of ``y`` w.r.t. ``x`` with the graph kept for higher derivatives.

    ``grad_outputs`` weights the elements of ``y``; when omitted, a tensor of
    ones shaped like ``y`` is used.
    """
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    return torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True)[0]

def load_movie_tensor(id):
    """Read resliced movie number ``id`` from ``data_batch_01`` as a torch tensor.

    The axes are reordered with ``permute(2, 0, 1)``; per the original author's
    note the file comes as (y, t, x), so the result is (x, y, t).
    """
    path = f"data_batch_01/resliced_{id:03d}.tif"
    return torch.from_numpy(imread(path)).permute(2, 0, 1)

def get_mgrid(imgtensor, dim=3):
    '''Generates a flattened grid of coordinates in a range of -1 to 1.

    Row ``k`` of the result is the normalised position of element ``k`` of
    ``imgtensor.reshape(-1)``: column ``j`` runs over axis ``j`` of
    ``imgtensor`` and the last axis varies fastest. The axes must NOT be
    reversed here, otherwise the coordinates no longer line up with the
    flattened pixel values whenever the shape is not palindromic.

    imgtensor: tensor whose shape defines the grid (values are not read)
    dim: int, expected number of axes of ``imgtensor``

    Returns a ``(imgtensor.numel(), dim)`` float tensor.
    Raises ValueError if ``imgtensor`` does not have ``dim`` axes (a mismatched
    reshape would otherwise silently produce meaningless rows).
    '''
    shape = tuple(imgtensor.shape)
    if len(shape) != dim:
        raise ValueError(f"expected a tensor with {dim} axes, got shape {shape}")
    axes = [torch.linspace(-1, 1, steps=n) for n in shape]
    # indexing="ij" is the behaviour the default has always had; spelling it out
    # avoids the deprecation warning of recent PyTorch releases.
    mgrid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    return mgrid.reshape(-1, dim)

: 

In [None]:
if __name__ == '__main__':
    device = "cpu"
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    movie_id = 22
    resliced = ImageFitting(movie_id)
    n_pixels = len(resliced.pixels)
    # Shape used to fold the flat network output back into a volume; taken from
    # the movie itself instead of hard-coding (256, 256, 100).
    volume_shape = tuple(load_movie_tensor(movie_id).shape)

    # batch_size=1: every dataset item already is a whole chunk of pixels, so the
    # loader only adds a leading axis of length 1 (removed again below).
    # Pinning host memory is only useful when copying to a CUDA device.
    dataloader = DataLoader(resliced, batch_size=1, num_workers=0,
                            pin_memory=(torch.device(device).type == "cuda"))
    img_siren = Siren(in_features=3, out_features=1, hidden_features=256,
                      hidden_layers=3, outermost_linear=True, first_omega_0=80, hidden_omega_0=80.)
    img_siren.to(device)

    # One optimizer step per epoch: gradients are accumulated over all chunks, so
    # `epochs` is the number of full-image gradient descent steps.
    epochs = 20

    optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())

    for epoch in trange(epochs):
        predictions = []
        epoch_loss = 0.0
        optim.zero_grad()
        for model_input, ground_truth in tqdm(dataloader, 'step', resliced.batches):
            model_input = model_input.squeeze(0).to(device)
            ground_truth = ground_truth.squeeze(0).to(device=device, dtype=torch.float32)
            model_output, _ = img_siren(model_input)
            chunk_loss = ((model_output - ground_truth) ** 2).mean()
            # Weight each chunk by its share of the pixels so the accumulated
            # gradient equals that of the mean squared error over the whole movie.
            weight = ground_truth.shape[0] / n_pixels
            (chunk_loss * weight).backward()
            epoch_loss += chunk_loss.item() * weight
            # Detach before storing: keeping the graph of every chunk alive
            # would hold all intermediate activations in memory.
            predictions.append(model_output.detach().cpu())
        optim.step()
        print(f"epoch {epoch}: mse = {epoch_loss:.6g}")
        output_scene = torch.cat(predictions, dim=0).reshape(volume_shape).numpy()
        # tifffile.imwrite takes the file name first, then the data.
        imwrite(f"output_{epoch}.tif", output_scene)

plt.close()

: 

In [None]:
def load_movie_tensor(id):
    """Read resliced movie number ``id`` from ``data_batch_01`` as a torch tensor.

    The axes are reordered with ``permute(2, 0, 1)``; per the original author's
    note the file comes as (y, t, x), so the result is (x, y, t).
    """
    path = f"data_batch_01/resliced_{id:03d}.tif"
    return torch.from_numpy(imread(path)).permute(2, 0, 1)

# Scratch check: load movie 0 and display its shape after the axis permutation.
# `tensor` is reused by the get_mgrid check in the next cell.
tensor = load_movie_tensor(0)
tensor.shape

In [None]:
def get_mgrid(imgtensor, dim=3):
    '''Generates a flattened grid of coordinates in a range of -1 to 1.

    Row ``k`` of the result is the normalised position of element ``k`` of
    ``imgtensor.reshape(-1)``: column ``j`` runs over axis ``j`` of
    ``imgtensor`` and the last axis varies fastest. The axes must NOT be
    reversed here, otherwise the coordinates no longer line up with the
    flattened pixel values whenever the shape is not palindromic.

    imgtensor: tensor whose shape defines the grid (values are not read)
    dim: int, expected number of axes of ``imgtensor``

    Returns a ``(imgtensor.numel(), dim)`` float tensor.
    Raises ValueError if ``imgtensor`` does not have ``dim`` axes (a mismatched
    reshape would otherwise silently produce meaningless rows).
    '''
    shape = tuple(imgtensor.shape)
    if len(shape) != dim:
        raise ValueError(f"expected a tensor with {dim} axes, got shape {shape}")
    axes = [torch.linspace(-1, 1, steps=n) for n in shape]
    # indexing="ij" is the behaviour the default has always had; spelling it out
    # avoids the deprecation warning of recent PyTorch releases.
    mgrid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    return mgrid.reshape(-1, dim)

# Scratch check: display the coordinate grid built for `tensor` (loaded in the previous cell).
get_mgrid(tensor)

In [None]:
# Scratch check: how movie 22 flattens into the (N, 1) pixel column that
# ImageFitting.__init__ builds with the same reshape.
img = load_movie_tensor(22)
print(img.shape)
pixels = img.reshape(-1,1)
pixels.shape
# Leftover from ImageFitting.__init__ (the matching coordinate grid), kept disabled:
# self.coords = get_mgrid(img)