<a href="https://colab.research.google.com/github/hallpaz/nov23google/blob/main/code/representational_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Representational Networks

## Hallison Paz

### November 8th, 2023

Lecture given at Google Brasil's office in São Paulo.

Additional resources are available in [this repository](https://github.com/hallpaz/googlesptalk).



In [19]:
from IPython.display import HTML

# Embed the YouTube video (id _ZtQ0-tDwbY) directly in the notebook output.
video_iframe = '''<iframe width="560" height="315"
        src="https://www.youtube.com/embed/_ZtQ0-tDwbY"
        frameborder="0" allow="accelerometer; autoplay; encrypted-media;
        gyroscope; picture-in-picture" allowfullscreen></iframe>'''
HTML(video_iframe)



# Training a Representational Network for Images

In [None]:
import torch
from torch import nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor
import matplotlib.pyplot as plt
import imageio
import numpy as np
from typing import Sequence

# Defining a dataset structure

In [None]:
def make_grid_coords(nsamples, start, end, dim, flatten=True):
  """Return a regular grid of points in a ``dim``-dimensional box.

  Args:
    nsamples: number of samples per axis; a scalar is used for every axis,
      a sequence gives one value per axis.
    start: lower coordinate of each axis (scalar or one value per axis).
    end: upper coordinate of each axis (scalar or one value per axis);
      both ``start`` and ``end`` are included, as in ``torch.linspace``.
    dim: number of axes.
    flatten: if True, return shape (prod(nsamples), dim); otherwise the
      unflattened grid of shape (*nsamples, dim).

  Raises:
    ValueError: if a per-axis sequence does not have exactly ``dim`` entries.
  """
  def _per_axis(value):
    # Broadcast a scalar to one entry per axis; leave sequences untouched.
    return value if isinstance(value, Sequence) else dim * [value]

  nsamples, start, end = _per_axis(nsamples), _per_axis(start), _per_axis(end)
  if any(len(seq) != dim for seq in (nsamples, start, end)):
    raise ValueError("'nsamples'; 'start'; and 'end' should be a single value or have same  length as 'dim'")

  axes = [torch.linspace(lo, hi, steps=n) for lo, hi, n in zip(start, end, nsamples)]
  # 'ij' indexing: the first axis varies slowest once the grid is flattened.
  grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
  if not flatten:
    return grid
  return grid.reshape(-1, dim)

class ImageDataset(Dataset):
  """Whole-image dataset for fitting a coordinate network.

  Loads one image and exposes it as a single sample ``(coords, pixels)``:
  ``coords`` is an (H*W, 2) grid over [-1, 1]^2 and ``pixels`` the matching
  (H*W, C) colors in [0, 1], both enumerated row by row.

  Args:
    filepath: path of the image to load.
    size: if > 0, the image is resized to ``size x size``; otherwise its
      original resolution (which may be non-square) is kept.
    color_space: PIL mode the image is converted to (e.g. 'L', 'RGB').
  """

  def __init__(self, filepath, size=0, color_space='L'):
    super().__init__()
    # Close the underlying file once the converted copy is in memory.
    with Image.open(filepath) as src:
      img = src.convert(color_space)
    if size > 0:
      img = img.resize((size, size))
    # One channel per band of the converted image ('L' -> 1, 'RGB' -> 3,
    # 'RGBA' -> 4), so the reshape below never mixes bands across pixels.
    self.channels = len(img.getbands())
    # (H*W) x 2; (height, width) order matches the row-major pixel layout
    # below, so non-square images also get exactly one coordinate per pixel.
    self.coords = make_grid_coords((img.height, img.width), -1, 1, 2)
    # C x H x W -> H x W x C -> (H*W) x C
    self.pixels = to_tensor(img).permute(1, 2, 0).reshape(-1, self.channels)

  def __len__(self):
    # The whole image is one sample: training is full-batch.
    return 1

  def __getitem__(self, index):
    return self.coords, self.pixels

def network_to_image(model, channels=3, res=512,
                     return_img=False, device='cpu'):
  """Render a coordinate network on a ``res x res`` grid over [-1, 1]^2.

  Args:
    model: module mapping (N, 2) coordinates to (N, channels) values.
    channels: number of output channels produced by ``model``.
    res: side length, in pixels, of the rendered square image.
    return_img: if True, return the image instead of plotting it.
    device: device on which the coordinate grid is evaluated.

  Returns:
    If ``return_img`` is True, a NumPy array clamped to [0, 1] with shape
    (res, res, channels), or (res, res) when ``channels == 1``. Otherwise
    None; the image is displayed with matplotlib.
  """
  coords = make_grid_coords(res, -1, 1, 2).to(device)
  # Rendering is inference only: don't build an autograd graph over res*res points.
  with torch.no_grad():
    pixels = model(coords).clamp(0, 1).reshape(res, res, channels)
  img = pixels.squeeze(-1).detach().cpu().numpy()
  if return_img:
    return img
  plt.imshow(img)
  # Flush the figure now so repeated calls (e.g. during training) each show
  # their own image instead of drawing over one shared axes.
  plt.show()

**Note:** The MASP image is a crop of: https://www.flickr.com/photos/governosp/52692996751

In [None]:
!wget https://github.com/hallpaz/nov23google/blob/ecca56318491c043b4741cd230073a9d9a76fac1/img/masp.jpg?raw=true -O masp.jpg

In [None]:
# Training configuration. The dataset cell resizes the image to a square of
# side 'width'.
hyper = {
    'width': 512,    # training/render resolution (pixels per side)
    'height': 512,
    'channels': 3,   # output channels of the networks (RGB)
    'epochs': 1000,  # number of optimization steps
    'seed': 0,       # RNG seed for reproducible weight initialization
}
# Seed before any model is created so initial weights, and therefore results,
# are reproducible across runs. The ';' hides the returned Generator repr.
torch.manual_seed(hyper['seed']);

In [None]:
# Relative path: matches the `-O masp.jpg` download target in the working
# directory (which is /content on Colab).
dataset = ImageDataset("masp.jpg", hyper['width'], "RGB")
# The dataset holds a single sample (the whole image), so batch_size=1 yields
# one full batch with a leading batch dimension.
dataloader = DataLoader(dataset, batch_size=1)

In [None]:
# Sanity check: reshape the flat ground-truth pixels back into an image.
# Pixels are stored row by row, so the layout is (height, width, channels).
_, gt_pixels = dataset[0]
gt_image = gt_pixels.reshape(hyper['height'], hyper['width'], hyper['channels'])
plt.imshow(gt_image.numpy());

# Training routine

In [None]:
def train(model, dataloader, hyper, device,
          steps_til_summary=20, gif_path=""):
    """Fit ``model`` to the first batch of ``dataloader`` with an MSE loss.

    Args:
        model: coordinate network mapping input coordinates to colors.
        dataloader: loader whose first batch is ``(coords, pixels)``; only that
            batch is used, so training is full-batch.
        hyper: dict with 'width' (render resolution), 'epochs' (number of
            optimization steps) and 'channels' (output channels).
        device: device used for training and rendering.
        steps_til_summary: print the loss and show a render every this many steps.
        gif_path: if non-empty, write a GIF with a render every 5 steps plus
            the final result.
    """
    dim = hyper['width']
    epochs = hyper['epochs']
    channels = hyper['channels']
    model.to(device)
    model.train()
    optim = torch.optim.Adam(lr=1e-3, params=model.parameters())
    model_input, ground_truth = next(iter(dataloader))
    model_input, ground_truth = model_input.to(device), ground_truth.to(device)

    def render():
        # Inference-only snapshot; no_grad avoids building a graph over dim*dim points.
        with torch.no_grad():
            return network_to_image(model, channels, dim, return_img=True, device=device)

    def show(img):
        # plt.show() gives every summary its own figure instead of one overlaid axes.
        plt.imshow(img)
        plt.show()

    # NOTE(review): the unit of `duration` (seconds vs. milliseconds) depends on
    # the installed imageio version/plugin — confirm the intended frame rate.
    writer = imageio.get_writer(gif_path, mode='I', duration=0.3) if gif_path else None
    try:
        for step in range(epochs):
            model_output = model(model_input)
            loss = ((model_output - ground_truth) ** 2).mean()

            summary_step = step % steps_til_summary == 0
            gif_step = writer is not None and step % 5 == 0
            if summary_step or gif_step:
                # Render once and reuse it for both the summary and the GIF frame.
                img = render()
                if summary_step:
                    print("Step %d, Total loss %0.6f" % (step, loss.item()))
                    show(img)
                if gif_step:
                    writer.append_data(np.uint8(img * 255))

            optim.zero_grad()
            loss.backward()
            optim.step()

        # Last inference: render the fully trained model at the training resolution.
        model.eval()
        img = render()
        show(img)
        if writer is not None:
            writer.append_data(np.uint8(img * 255))
    finally:
        # Close the GIF even if training is interrupted, so the file is valid.
        if writer is not None:
            writer.close()

# Defining a Neural Network model

In [None]:
class ReluNetwork(nn.Module):
  """Plain MLP: 2 -> 256 -> 256 -> 256 -> ``channels``.

  Each hidden Linear layer is followed by a ReLU; the output layer is linear.

  Args:
    channels: number of output values per input coordinate.
  """

  def __init__(self, channels):
    super().__init__()
    widths = [2, 256, 256, 256]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
      modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    # Linear output head, no activation.
    modules.append(nn.Linear(widths[-1], channels))
    self.layers = nn.Sequential(*modules)

  def forward(self, input):
    return self.layers(input)


In [None]:
# Build the ReLU baseline with the configured channel count and render it
# before training to show its initial output.
relu_model = ReluNetwork(hyper['channels'])
network_to_image(relu_model, hyper['channels'], hyper['width'])

In [None]:
# Prefer the GPU when one is available; every later cell reuses `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

train(relu_model, dataloader, hyper, device, gif_path='masp.gif')

In [None]:
class FourierNetwork(nn.Module):
  """Sine-activated first layer followed by a ReLU MLP.

  Computes ``layers(sin(omega_0 * first_layer(coords)))`` for 2-D input
  coordinates.

  Args:
    omega_0: frequency multiplier applied inside the sine.
    channels: number of output channels.
  """

  def __init__(self, omega_0=30, channels=1):
    super().__init__()
    self.omega_0 = omega_0
    self.first_layer = nn.Linear(2, 256)
    hidden = [nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU()]
    self.layers = nn.Sequential(*hidden, nn.Linear(256, channels))

    # The sine's frequencies are omega_0 * weight, so this range bounds them;
    # the bias keeps nn.Linear's default initialization.
    half_width = 1 / 2
    with torch.no_grad():
      self.first_layer.weight.uniform_(-half_width, half_width)

  def forward(self, coords):
    features = torch.sin(self.omega_0 * self.first_layer(coords))
    return self.layers(features)

In [None]:
# Channel count comes from the shared config instead of a duplicated literal.
fourier_model = FourierNetwork(omega_0=30, channels=hyper['channels'])
train(fourier_model,
      dataloader,
      hyper,
      device,
      gif_path='fourier_masp.gif')

In [None]:
class SineLayer(nn.Module):
    """Linear layer followed by ``sin(omega_0 * x)``.

    See paper sec. 3.2 (final paragraph) and supplement sec. 1.5 for the role
    of omega_0.

    * ``is_first=True``: omega_0 is a frequency factor that simply scales the
      pre-activations; different signals may need a different omega_0 in the
      first layer (a hyperparameter).
    * ``is_first=False``: the weight range is divided by omega_0 to keep the
      activation magnitude constant while boosting gradients to the weight
      matrix (supplement sec. 1.5).
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        # First layer: U(-1/n, 1/n); later layers: U(-sqrt(6/n)/omega_0, +sqrt(6/n)/omega_0).
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        # Also returns the pre-activation, for visualizing activation distributions.
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation


class Siren(nn.Module):
    """MLP built from SineLayer blocks.

    Layout: one first SineLayer (``first_omega_0``), ``hidden_layers`` hidden
    SineLayers (``hidden_omega_0``), then either a plain Linear head
    (``outermost_linear=True``) or a final SineLayer.

    Args:
        in_features: input coordinate dimension.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden SineLayers between first and last.
        out_features: output dimension.
        outermost_linear: use a Linear output layer instead of a SineLayer.
        first_omega_0: frequency factor of the first layer.
        hidden_omega_0: frequency factor of the remaining layers.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers.extend(
            SineLayer(hidden_features, hidden_features,
                      is_first=False, omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
            # Same weight range as a hidden SineLayer (bias keeps nn.Linear's default).
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
        else:
            head = SineLayer(hidden_features, out_features,
                             is_first=False, omega_0=hidden_omega_0)
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        # Work on a fresh leaf copy of the input with requires_grad=True, so
        # derivatives w.r.t. the coordinates can be taken; only the output is returned.
        coords = coords.clone().detach().requires_grad_(True)
        return self.net(coords)

In [None]:
# Keyword arguments make the architecture readable; the output size follows
# the shared config instead of a duplicated literal.
siren_model = Siren(in_features=2, hidden_features=256, hidden_layers=2,
                    out_features=hyper['channels'], outermost_linear=True,
                    first_omega_0=30)
train(siren_model,
      dataloader,
      hyper,
      device,
      gif_path='siren_masp.gif')

# Exploring the continuity of the model's input space

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interactive, Box, interact_manual

In [None]:
# Range slider choosing the (start, end) coordinate interval passed to plot_model.
slider_options = dict(
    value=[-1.0, 1.0],
    min=-7,
    max=7,
    step=0.1,
    description='Interval:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    layout=widgets.Layout(width='50%'),
)
slider = widgets.FloatRangeSlider(**slider_options)

In [None]:
model = siren_model
res = hyper['width']
channels = hyper['channels']

def plot_model(interval):
  """Render ``model`` on a ``res x res`` grid spanning ``interval`` on both axes.

  Args:
    interval: (start, end) pair from the slider. The training grid covers
      [-1, 1], so wider intervals show how the network behaves outside it.

  Returns:
    A PIL image built from the uint8 RGB pixel array.
  """
  model.to(device)
  grid = make_grid_coords(res, *interval, dim=2).to(device)
  # Inference only: avoid building an autograd graph on every slider change.
  with torch.no_grad():
    model_out = torch.clamp(model(grid), 0.0, 1.0)

  pixels = model_out.detach().cpu().view(res, res, channels)
  pixels = (pixels * 255).numpy().astype(np.uint8)
  if channels == 1:
    # Replicate the single gray channel into 3 identical channels.
    pixels = np.repeat(pixels, 3, axis=-1)
  return Image.fromarray(pixels)

interact(plot_model, interval=slider)