# HIDDEN LAYER INTERACTION
This code builds an interface for interacting with the hidden layers during image synthesis. The P-GAN model in the backend has been altered, so that its convolutional channels can be manipulated.

### Load the model and everything else you need.

In [1]:
import torch
import torch.nn as nn
import torch.hub as hub
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display

torch.backends.mps.is_built()
mps_device = torch.device('mps')

#load PGAN 
from models.progressive_gan import ProgressiveGAN as PGAN
gan_local = PGAN()

#load trained checkpoints
gan_local.load_state_dict(torch.load('FashionGen_s6_i96000.pt', map_location=torch.device('mps')))

Average network found !


In [3]:
#INDEXING LAYERS: scale/layer, conv, size
layers = [[0, 0, 512, 4, 4], #0
          [1, 0, 512, 8, 8], #1
          [1, 1, 512, 8, 8], #2
          [2, 0, 512, 16, 16], #3
          [2, 1, 512, 16, 16], #4
          [3, 0, 512, 32, 32], #5
          [3, 1, 512, 32, 32], #6
          [4, 0, 256, 64, 64], #7
          [4, 1, 256, 64, 64], #8
          [5, 0, 128, 128, 128], #9
          [5, 1, 128, 128, 128], #11
          [6, 0, 64, 256, 256], #12
          [6, 1, 64, 256, 256]] #13

layer_labels = [f"{layer[0]}-{layer[1]}" for layer in layers]

layers = np.array(layers)

layer_index = 4 #initialize to layer 2-0

In [5]:
#CHOOSE NOISE
# noise, _ = gan_local.buildNoiseData(1) #random noise

#save noise you like
# torch.save(noise, f'my_data/noise/fave/timmy.pt')
# torch.save(noise, f'my_data/noise/fave/plaindress.pt')

#load noise
noise = torch.load('my_data/noise/fave/tina.pt')

## The interface:

In [6]:
def slidechannels(layer_index, channel, manipulation_tech, factor, vert_range, hor_range):
    # noise, _ = gan_local.buildNoiseData(1) #random noise
    scale = layers[layer_index][0]
    conv = layers[layer_index][1]
    genimg = gan_local.netG(noise,
                        manipulation_tech=manipulation_tech,
                        factor=[factor],
                        layer=scale,
                        conv=conv,
                        channel=[channel],
                        pos_x=slice(vert_range[0], vert_range[1]),
                        pos_y=slice(hor_range[0], hor_range[1])
                        )
    plt.imshow(genimg[0].permute(1,2,0).detach().numpy())
    plt.title(f"Layer {layers[layer_index][0]}-{layers[layer_index][1]}, Channel {channel}")
    plt.text(95, 268, f"{manipulation_tech} ({np.round(factor, decimals=2)})")
    plt.axis('off')
    ax = plt.gca()
    rect = plt.Rectangle(
        (0, 0), 1, 1,
        transform=ax.transAxes, 
        linewidth=1, 
        edgecolor="black", 
        facecolor="none" 
    )
    ax.add_patch(rect)
    plt.show()

def update_sliders(change):
    layer_index = change['new']
    channel_slider.max = layers[layer_index][2]-1
    vertical_slider.max = layers[layer_index][4]
    horizontal_slider.max = layers[layer_index][4]
    vertical_slider.value = [0, layers[layer_index][4]]
    horizontal_slider.value = [0, layers[layer_index][4]]

#dropdowns for layer, and sliders for channel and factor using ipywidgets
layer_dropdown = widgets.Dropdown(options=[(label, idx) for idx, label in enumerate(layer_labels)], value=3, description='Layer:', layout=widgets.Layout(width='20%'))
mani_dropdown = widgets.Dropdown(options=('Scaling', 'Adding', 'FeatViz', 'Overwriting', 'Random'), description='Method:')
channel_slider = widgets.IntSlider(min=0, max=layers[layer_index][2], step=1, value=390, description='Channel:', layout=widgets.Layout(width='70%'))
factor_slider = widgets.FloatSlider(min=-20.0, max=20.0, step=0.01, value=1.0, description='Factor:', layout=widgets.Layout(width='50%'))

#slicing sliders
vertical_slider = widgets.IntRangeSlider(min=0, max=layers[layer_index][4], step=1, value=[0, layers[layer_index][4]], description='Vertical:', orientation='horizontal', readout=True)
horizontal_slider = widgets.IntRangeSlider(min=0, max=layers[layer_index][4], step=1, value=[0, layers[layer_index][4]], description='Horizontal:', orientation='horizontal', readout=True)

#interactive output that updates when any widget changes
interactive_output = widgets.interactive_output(slidechannels, {'layer_index': layer_dropdown, 'channel': channel_slider, 'manipulation_tech': mani_dropdown, 'factor': factor_slider, 'vert_range': vertical_slider, 'hor_range': horizontal_slider})
layer_dropdown.observe(update_sliders, names='value')

#display the widgets and the interactive output
display(widgets.HBox([layer_dropdown, channel_slider]))
display(widgets.HBox([vertical_slider, horizontal_slider]))
display(widgets.HBox([mani_dropdown, factor_slider]))
display(interactive_output)

HBox(children=(Dropdown(description='Layer:', index=3, layout=Layout(width='20%'), options=(('0-0', 0), ('1-0'…

HBox(children=(IntRangeSlider(value=(0, 16), description='Vertical:', max=16), IntRangeSlider(value=(0, 16), d…

HBox(children=(Dropdown(description='Method:', options=('Scaling', 'Adding', 'FeatViz', 'Overwriting', 'Random…

Output()

### Interfering two channels:

In [36]:
def slidechannels(layer_index, channel1, channel2, manipulation_tech, factor1, factor2, vert_range, hor_range):
    scale = layers[layer_index][0]
    conv = layers[layer_index][1]
    genimg = gan_local.netG(noise,
                        manipulation_tech=manipulation_tech,
                        factor=[factor1, factor2],
                        layer=scale,
                        conv=conv,
                        channel=[channel1, channel2],
                        pos_x=slice(vert_range[0], vert_range[1]),
                        pos_y=slice(hor_range[0], hor_range[1])
                        )
    plt.imshow(genimg[0].permute(1,2,0).detach().numpy())
    # plt.title(f"Layer {layers[layer_index][0]}-{layers[layer_index][1]}, Channel {channel}, Factor {}")
    plt.title(f"Layer {layers[layer_index][0]}-{layers[layer_index][1]}, Channel {channel1}, {channel2}")
    plt.text(85, 268, f"Combining ({np.round(factor1, decimals=2)}, {np.round(factor2, decimals=2)})")
    plt.axis('off')
# Create a black border by adding a rectangle to the axes
    ax = plt.gca()  # Get current axes
    rect = plt.Rectangle(
        (0, 0), 1, 1,  # Bottom-left corner and dimensions as fraction of axes
        transform=ax.transAxes,  # Use axes coordinates
        linewidth=1,  # Border thickness
        edgecolor="black",  # Border color
        facecolor="none"  # Transparent background
    )
    ax.add_patch(rect)
    plt.show()

def update_sliders(change):
    layer_index = change['new']
    channel_slider.max = layers[layer_index][2]-1
    vertical_slider.max = layers[layer_index][3]
    horizontal_slider.max = layers[layer_index][4]
    vertical_slider.value = [0, layers[layer_index][4]]
    horizontal_slider.value = [0, layers[layer_index][4]]

#interactive dropdowns for layer, and sliders for channel and factor using ipywidgets
layer_dropdown = widgets.Dropdown(options=list(range(len(layers))), value=3, description='Layer:', layout=widgets.Layout(width='20%'))
mani_dropdown = widgets.Dropdown(options=('Scaling', 'Adding', 'Overwriting', 'Random', 'Inferring'), description='Method:')
channel1_slider = widgets.IntSlider(min=0, max=layers[layer_index][2], step=1, value=109, description='Channel1:', layout=widgets.Layout(width='35%'))
factor1_slider = widgets.FloatSlider(min=-20.0, max=20.0, step=0.01, value=1.0, description='Factor1:', layout=widgets.Layout(width='35%'))
channel2_slider = widgets.IntSlider(min=0, max=layers[layer_index][2], step=1, value=109, description='Channel2:', layout=widgets.Layout(width='35%'))
factor2_slider = widgets.FloatSlider(min=-20.0, max=20.0, step=0.01, value=1.0, description='Factor2:', layout=widgets.Layout(width='35%'))


#slicing sliders
vertical_slider = widgets.IntRangeSlider(min=0, max=layers[layer_index][4], step=1, value=[0, layers[layer_index][4]], description='Vertical:', orientation='horizontal', readout=True)
horizontal_slider = widgets.IntRangeSlider(min=0, max=layers[layer_index][4], step=1, value=[0, layers[layer_index][4]], description='Horizontal:', orientation='horizontal', readout=True)

#interactive output that updates when any widget changes
interactive_output = widgets.interactive_output(slidechannels, {'layer_index': layer_dropdown, 'channel1': channel1_slider, 'channel2': channel2_slider, 'manipulation_tech': mani_dropdown, 'factor1': factor1_slider, 'factor2': factor2_slider, 'vert_range': vertical_slider, 'hor_range': horizontal_slider})
layer_dropdown.observe(update_sliders, names='value')

#display the widgets and the interactive output
display(widgets.HBox([layer_dropdown, channel1_slider, channel2_slider]))
display(widgets.HBox([vertical_slider, horizontal_slider]))
display(widgets.HBox([mani_dropdown, factor1_slider, factor2_slider]))
display(interactive_output)

HBox(children=(Dropdown(description='Layer:', index=3, layout=Layout(width='20%'), options=(0, 1, 2, 3, 4, 5, …

HBox(children=(IntRangeSlider(value=(0, 16), description='Vertical:', max=16), IntRangeSlider(value=(0, 16), d…

HBox(children=(Dropdown(description='Method:', options=('Scaling', 'Adding', 'Overwriting', 'Random', 'Inferri…

Output()

read more about the project here: https://doi.org/10.1145/3715336.3735437