# 1. Setup

In [None]:
!pip install -q -U einops datasets matplotlib tqdm torchmetrics torchvision plotly scipy altair

In [None]:
!pip install gradio==v3.13.0

In [None]:
#External libraries
import math
import numpy as np
import imageio

import gradio as gr

from sklearn.metrics import mean_absolute_error
from skimage.metrics import structural_similarity as ssim
from scipy.spatial import distance

import torch
from torchvision import transforms
from torch.optim import Adam

import plotly.express as px

In [None]:
#Local functions
from DDPM import GaussianDiffusion, linear_beta_schedule
from UNet import Unet
from hooks import ObscureChannelHook
import interface_helpers
import visualization

# 2. Settings

In [None]:
# Attribute names the models can be conditioned on.
label_list = ["Bangs", "Male", "Smiling", "Wearing_Lipstick"]

# Top-level UNet blocks offered for analysis.
block_list = [
    "init_conv",
    *(f"downs.{idx}" for idx in range(4)),
    "mid",
    *(f"ups.{idx}" for idx in range(3)),
    "final_conv",
]

# Every layer that can be tracked. Each down/up stage exposes the same eight
# sub-layers; the mid layers are listed after the up stages.
layer_list = [
    "init_conv",
    *(
        f"{stage}.{suffix}"
        for stage in [f"downs.{idx}" for idx in range(4)] + [f"ups.{idx}" for idx in range(3)]
        for suffix in (
            "0.ds_conv", "0.net", "0.res_conv",
            "1.ds_conv", "1.net", "1.res_conv",
            "2.fn", "3",
        )
    ),
    *(f"mid_block1.{part}" for part in ("ds_conv", "net", "res_conv")),
    "mid_attn.fn",
    *(f"mid_block2.{part}" for part in ("ds_conv", "net", "res_conv")),
    "final_conv",
]

In [None]:
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Labels for both the 64 and the 128 model.
# Every attribute owns one bit, so the sum over the selected labels is a
# distinct integer for each combination (this sum is what gets passed on as
# the condition).
label_map = {
    name: 1 << bit
    for bit, name in zip((3, 2, 1, 0), ("Bangs", "Male", "Smiling", "Wearing_Lipstick"))
}

In [None]:
# Registry of the channel-masking hooks that are currently attached,
# keyed by layer id concatenated with the channel.
hooks = dict()

# 4. Interface

In [None]:
def load_model(model_select):
    """
    Load model and configure DDPM process. As Gradio does not offer an option to return model as a datatype,
    we need to declare the model and diffusion as global variable to access from other functions.

    Parameters:
    ----------
    model_select: str
        ID of the selected model as a string ("CelebA 128" or "CelebA 64").

    Returns:
    --------
    dict mapping Gradio components to their new values:
    blocks_select: update whose choices are all block IDs of the UNet.
    act_unet_image: image displaying the model architecture (Activations tab).
    grad_unet_image: image displaying the model architecture (Gradient tab).

    Raises:
    -------
    AssertionError: if model_select is not one of the known model IDs.
    """

    global model, diffusion

    timesteps = 1000

    # Per-model settings: (image size, checkpoint path, architecture image path).
    # Both models share the same UNet configuration and differ only in these values.
    model_configs = {
        "CelebA 128": (128, 'models/128_RGB/model_c_128_rgb_chkpt_32.pth', 'images/unet_128.png'),
        "CelebA 64": (64, 'models/64_RGB_NEW/model_c_64_rgb_chkpt_50.pth', 'images/unet_64.png'),
    }

    if model_select not in model_configs:
        raise AssertionError ("Please select a model before proceeding")

    image_size, checkpoint_path, unet_image_path = model_configs[model_select]
    channels, mode = 3, 'c'

    # Build and load into locals first so that a failing load does not leave a
    # half-initialised model in the globals other callbacks rely on.
    net = Unet(
        dim=64,
        channels = channels+1,
        out_dim = channels,
        dim_mults=(1, 2, 4, 8)
    )
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    unet_image = imageio.imread(unet_image_path)

    net.to(device)

    betas = linear_beta_schedule(timesteps)

    ddpm = GaussianDiffusion(
                    net,
                    mode = mode,
                    image_size = image_size,
                    channels = channels,
                    timesteps = timesteps,
                    loss_type = 'l1',
                    betas = betas,
                    device = device
                    )

    model, diffusion = net, ddpm

    return {
           blocks_select: gr.update(choices=block_list),
           act_unet_image: unet_image,
           grad_unet_image: unet_image,
           }

def configure_interface(blocks_select):
    """
    Configures interface. Updates the selection option of layers depending on selected blocks.

    Parameters:
    ----------
    blocks_select: list[str]
        list containing the selection of six blocks by their names

    Returns:
    --------
    dict mapping Gradio components to updates:
    l_b0 ... l_b5: made visible, with choices being all entries of layer_list
        whose name contains the respective selected block name.
    mask_layer: choices being all layers in the network.

    Raises:
    -------
    AssertionError: if no model has been loaded yet, or if not exactly six blocks are selected.
    """

    # `model` only becomes a global once load_model has run; a bare reference
    # before that raises NameError instead of the intended message.
    if globals().get("model") is None:
        raise AssertionError("Please select a model to proceed")
    if blocks_select is None or len(blocks_select) != 6:
        raise AssertionError("Please select 6 blocks to proceed")

    dropdowns = (l_b0, l_b1, l_b2, l_b3, l_b4, l_b5)

    # The i-th selected block feeds the i-th layer dropdown.
    updates = {
        dropdown: gr.update(choices=[s for s in layer_list if block in s], visible=True)
        for dropdown, block in zip(dropdowns, blocks_select)
    }
    updates[mask_layer] = gr.update(choices=layer_list)

    return updates


In [None]:
def diffuse(labels, sample_res, l_b0, l_b1, l_b2, l_b3, l_b4, l_b5, random_seed, track_mode, image_mask = None, switch_labels = None, switch_step = None):
    """
        Runs diffusion for given input parameters and returns samples and analysis figures.

        Parameters:
        ----------
        labels (list[str]): Selected labels to condition the model on (keys of label_map).
        sample_res (int): Sampling resolution, defining the interval of extracting activations/gradients.
        l_b0,...,l_b5 (str): Name of the selected layer to track.
        random_seed (int): Initial seed for the diffusion process.
        track_mode (int): Flag indicating whether activations (0) or gradients (1) should be sampled.
        image_mask (dict): Gradio sketch output; its "mask" entry is used as mask in gradient mode.
        switch_labels (list[str]): Labels to switch the condition to (activation mode only).
        switch_step (int): Timestep at which the label switch happens (activation mode only).

        Returns:
        --------
        dict mapping Gradio components to their new values.
        Activation mode: act_out_sample, ssim_fig, activation_dist_fig, activation_step,
            activation_layers, activations.
        Gradient mode: grad_out_sample, gradient_step, gradient_layers, ssim_fig,
            gradient_dist_fig, gradient_volume_fig, gradients.

        Raises:
        -------
        AssertionError: in gradient mode when no image mask was provided.
        ValueError: if track_mode is neither 0 nor 1.
    """

    # The condition is the sum of the values of all selected labels.
    y = torch.tensor([sum(label_map[k] for k in labels)]).to(device)

    layers = [l_b0, l_b1, l_b2, l_b3, l_b4, l_b5]

    random_seed = int(random_seed)

    # Timesteps (descending, excluding 0) at which activations/gradients are recorded.
    sample_steps = [k  for k in np.arange(0,1000)[::-1] if k % sample_res == 0 and k > 0]

    if track_mode == 0: #Activations
        # Defaults for "no label switch"; same values the gradient branch passes.
        # Without them s_y was unbound whenever no switch was configured.
        s_y, s_t = None, []
        if switch_labels is not None and switch_step is not None:
            s_y = torch.tensor([sum(label_map[k] for k in switch_labels)]).to(device)
            s_t = int(switch_step)
        analysis_settings = interface_helpers.set_analysis_settings(sample_steps, layers, track_mode = track_mode, mask = None, s_y=s_y, s_t=s_t)

    elif track_mode == 1: #Gradients
        if image_mask is None:
            raise AssertionError('Please input a image mask for gradient calculation')
        # Keep the first channel of the drawn mask and reshape it to (1, image_size, image_size).
        mask = torch.from_numpy(np.asarray(image_mask["mask"])[:,:,0:1]).to(device).reshape(1,diffusion.image_size,diffusion.image_size)
        analysis_settings = interface_helpers.set_analysis_settings(sample_steps = sample_steps, layers = layers, track_mode = track_mode, mask = mask,s_y=None, s_t=[])

    else:
        raise ValueError("Unsupported track_mode: " + str(track_mode))

    #Sample from DDPM with previously specified settings
    samples, gv, acts, grads = diffusion.sample(batch_size=1, random_seed=random_seed, condition = y, analysis=analysis_settings)

    #SSIM for every timestep with previous one (axis 0 of each sample is moved to the end first)
    ssim_t = [ssim(np.moveaxis(samples[i][0], [0], [2]), np.moveaxis(samples[i-1][0], [0], [2]), multichannel=True) for i in range(1,1000)]

    # NOTE(review): the activation branch plots ssim_t reversed while the gradient
    # branch plots it as computed, although both write to the same ssim_fig
    # component - confirm which orientation is intended.
    if track_mode == 0: #Activations
        return {
                    act_out_sample: get_sample_out(samples),
                    ssim_fig: visualization.line_plot(values = ssim_t[::-1], x_range = range(1,1000),title='SSIM with previous sample over time', x_title="Time Step", y_title="SSIM"),
                    activation_dist_fig: visualization.dist_fig(acts, track_mode, layers, sample_steps),
                    activation_step: gr.update(choices=[str(i) for i in sample_steps]),
                    activation_layers: layers,
                    activations: acts
                }

    #Gradients
    return {
                grad_out_sample: get_sample_out(samples),
                gradient_step: gr.update(choices=[str(i) for i in sample_steps]),
                gradient_layers: layers,
                ssim_fig:visualization.line_plot(values = ssim_t, x_range = range(1,1000),title='SSIM with previous sample over time', x_title="Time Step", y_title="SSIM"),
                gradient_dist_fig: visualization.dist_fig(grads, track_mode, layers, sample_steps),
                gradient_volume_fig: visualization.gradient_volume_fig(gv, sample_steps, im_shape = diffusion.image_size),
                gradients:grads
            }


def compare_diffuse(label1, label2, random_seed, sample_res, l_b0, l_b1, l_b2, l_b3, l_b4, l_b5, track_mode, image_mask = None, vs_grad_step= None):
    """
        Samples twice from the same noise seed with two different label sets and compares the runs.

        Parameters:
        ----------
        label1, label2 (list[str]): Labels of the first / second sample (keys of label_map).
        random_seed (int): Seed shared by both diffusion runs.
        sample_res (int): Interval of the timesteps to track (not used in gradient mode).
        l_b0,...,l_b5 (str): Name of the selected layer to track.
        track_mode (int): 0 = activations, 1 = gradients, 2 = attention.
        image_mask (dict): Gradio sketch output; its "mask" entry is required in gradient mode.
        vs_grad_step (int): The single timestep analysed in gradient mode.

        Returns:
        --------
        track_mode 0: (samples 1, samples 2, activation distance figure, SSIM figure)
        track_mode 1: (samples 1, samples 2, SSIM figure, gradient volume figure 1, gradient volume figure 2)
        track_mode 2: (samples 1, samples 2, SSIM figure, attention map figure 1, attention map figure 2,
                       attention distance figure)
        Any other track_mode returns None.

        Raises:
        -------
        AssertionError: in gradient mode when no image mask was provided.
    """
    y1 = torch.tensor([sum(label_map[k] for k in label1)]).to(device)
    y2 = torch.tensor([sum(label_map[k] for k in label2)]).to(device)

    layers =  [l_b0,l_b1,l_b2,l_b3,l_b4, l_b5]

    if track_mode == 1:#Gradients
        # Explicit raise instead of assert so the check survives `python -O`.
        if image_mask is None:
            raise AssertionError('Please input a image mask for gradient calculation')
        # Gradient volumes are compared at one timestep only.
        sample_steps = [int(vs_grad_step)]
        #get gradient mask from input
        mask = torch.from_numpy(np.asarray(image_mask["mask"])[:,:,0:1]).to(device).reshape(1,diffusion.image_size,diffusion.image_size)
        analysis_settings = interface_helpers.set_analysis_settings(sample_steps = sample_steps, layers = layers, track_mode = track_mode, mask = mask,s_y=None, s_t=[])
    else:
        sample_steps = [k  for k in np.arange(0,1000)[::-1] if k % sample_res == 0 and k > 0]
        # Attention mode is configured like activation mode (track_mode=0).
        analysis_settings = interface_helpers.set_analysis_settings(sample_steps = sample_steps, layers = layers, track_mode=0, mask = None,s_y=None, s_t=[])

    #Sample twice from same noise seed
    samples1, gv1, acts1, grads1 = diffusion.sample(batch_size=1, random_seed=random_seed, condition = y1, analysis=analysis_settings)
    samples2, gv2, acts2, grads2 = diffusion.sample(batch_size=1, random_seed=random_seed, condition = y2, analysis=analysis_settings)

    name1, name2 = str(y1.item()), str(y2.item())

    #SSIM between the two runs at every sample index
    ssim_t = [ssim(samples1[i][0].T, samples2[i][0].T, multichannel=True) for i in range(0,1000)]
    ssim_fig = visualization.line_plot(values= ssim_t, x_range = range(0,1000)[::-1],title='SSIM between Label '+name1+' and Label '+name2, x_title="Sample Step", y_title="SSIM")

    if track_mode == 0: #Activations
        distance_fig = visualization.vs_dist_fig(acts1, acts2, track_mode, layers, sample_steps, y1, y2)
        return get_sample_out(samples1),get_sample_out(samples2), distance_fig, ssim_fig

    elif track_mode == 1: #Gradients
        img_vol1_fig = visualization.gradient_volume_fig(gv1, str(vs_grad_step), diffusion.image_size)
        img_vol2_fig = visualization.gradient_volume_fig(gv2, str(vs_grad_step), diffusion.image_size)
        return get_sample_out(samples1), get_sample_out(samples2), ssim_fig, img_vol1_fig, img_vol2_fig

    elif track_mode == 2: # Attention
        map1 = interface_helpers.interpolate_attention_map(acts1, sample_steps)
        map2 = interface_helpers.interpolate_attention_map(acts2, sample_steps)

        # Euclidean distance between the two attention maps at every tracked step.
        attn_dists = [np.linalg.norm(map1[step] - map2[step]) for step in np.arange(len(map1))]

        attn_dist_fig = visualization.scatter_plot(x=sample_steps,y=attn_dists, title="Euclidean Distance between Attention Maps of Label "+name1+" and Label "+name2, x_title="Timestep", y_title="Euclidean Distance")

        attn_map1_fig = visualization.img_plot(values=np.stack(map1, axis=0), slider_title="Timestep",color="Inferno", title="Attention Map Label "+name1)

        attn_map2_fig = visualization.img_plot(values=np.stack(map2, axis=0), slider_title="Timestep",color="Inferno", title="Attention Map Label "+name2)

        return get_sample_out(samples1), get_sample_out(samples2), ssim_fig, attn_map1_fig, attn_map2_fig, attn_dist_fig


def generate_activation_maps(ts, signal, layers, image_mask=None):
    """
    Build one image plot per tracked layer from the activations stored for a timestep.

    Parameters:
    ----------
    ts: timestep to visualise; converted to int and used (as string) as key into signal.
    signal: nested mapping signal[str(timestep)][layer_name] holding the recorded activations.
    layers: names of the tracked layers; the first six figures are shown and returned.
    image_mask: accepted for interface compatibility, not used here.

    Returns:
    --------
    Tuple with the six figures of layers 0 to 5.
    """
    step_key = str(int(ts))

    figures = {}
    for position, layer_name in enumerate(layers):
        # First entry along axis 0, min-max scaled to [0, 1].
        layer_values = signal[step_key][layer_name][0, :, :, :]
        lowest, highest = layer_values.min(), layer_values.max()
        scaled = (layer_values - lowest) / (highest - lowest)
        figures[str(position)] = visualization.img_plot(scaled, slider_title="Channel", color="Inferno", title=layer_name)

    print("done")

    figure_keys = ("0", "1", "2", "3", "4", "5")
    for key in figure_keys:
        figures[key].show()

    return tuple(figures[key] for key in figure_keys)

def generate_gradient_maps(ts, signal, layers):
    """
    Build one image plot per tracked layer from the gradients stored for a timestep.

    Parameters:
    ----------
    ts: timestep to visualise; converted to int and used (as string) as key into signal.
    signal: nested mapping signal[str(timestep)][layer_name] holding the recorded gradients.
    layers: names of the tracked layers; the first six figures are shown and returned.

    Returns:
    --------
    Tuple with the six figures of layers 0 to 5.
    """
    step_key = str(int(ts))

    figures = {}
    for position, layer_name in enumerate(layers):
        # Stack the stored gradients, sum over the stacking axis and keep the first entry.
        summed = np.stack(signal[step_key][layer_name], axis=0).sum(axis=0)[0]
        lowest, highest = summed.min(), summed.max()
        scaled = (summed - lowest) / (highest - lowest)
        figures[str(position)] = visualization.img_plot(scaled, slider_title="Channel", color="Ice", title=layer_name)

    print("done")

    figure_keys = ("0", "1", "2", "3", "4", "5")
    for key in figure_keys:
        figures[key].show()

    return tuple(figures[key] for key in figure_keys)


def calculate_max_channels(ts, signal, layers, mode, image_mask = None):
    """
    Determine the max channel of every tracked layer at a timestep.

    Parameters:
    ----------
    ts: timestep to analyse; converted to int and used (as string) as key into signal.
    signal: nested mapping signal[str(timestep)][layer_name] with activations or gradients.
    layers: names of the tracked layers.
    mode: 0 to treat signal as activations, 1 to treat it as gradients.
    image_mask: forwarded unchanged to interface_helpers.max_sum_channel.

    Returns:
    --------
    Tuple with the results of interface_helpers.max_sum_channel for layers 0 to 5.
    """
    def _rescale(arr):
        # Min-max scaling to [0, 1].
        lowest, highest = arr.min(), arr.max()
        return (arr - lowest) / (highest - lowest)

    step_key = str(int(ts))
    max_channels = {}

    for position, layer_name in enumerate(layers):
        if mode == 0: #activations
            data = _rescale(signal[step_key][layer_name][0, :, :, :])
        elif mode == 1: #gradients
            data = _rescale(np.stack(signal[step_key][layer_name], axis=0).sum(axis=0)[0])

        max_channels[str(position)] = interface_helpers.max_sum_channel(data = data, img_mask=image_mask)

    return tuple(max_channels[key] for key in ("0", "1", "2", "3", "4", "5"))

def set_msk_ch(layer_id, channel):
    """
    Attach an ObscureChannelHook to one channel of a layer of the loaded model.

    Parameters:
    ----------
    layer_id: str
        Name of the layer as listed by model.named_modules().
    channel: number
        Channel to obscure; converted to int for the hook.

    Raises:
    -------
    KeyError: if layer_id is not a module name of the model.
    """
    layer = dict(model.named_modules())[layer_id]
    channel_index = int(channel)
    key = layer_id + str(channel)

    # Attaching twice to the same layer/channel used to overwrite the registry
    # entry, leaving the first hook attached with no way to close it again.
    previous = hooks.pop(key, None)
    if previous is not None:
        previous.close()

    hooks[key] = ObscureChannelHook(layer, channel_index)

    print("Hook attached to layer", layer_id, " at channel", channel_index)

def remove_msk_ch(layer_id, channel):
    """
    Close the hook registered for a layer/channel and drop it from the registry.

    Parameters:
    ----------
    layer_id: str
        Name of the layer the hook was attached to.
    channel: number
        Channel the hook was attached to.

    Raises:
    -------
    KeyError: if no hook is registered for this layer/channel.
    """
    key = layer_id + str(channel)

    # Pop instead of index so a closed hook does not linger in the registry.
    try:
        hook = hooks.pop(key)
    except KeyError:
        raise KeyError("No hook attached to layer " + str(layer_id) + " at channel " + str(channel)) from None

    hook.close()

    print("Hook removed from layer", layer_id, " at channel", int(channel))

def get_sample_out(samples):
    """
    Pick the samples shown in the gallery and normalize them.

    Takes the stacked samples at indices 100, 200, ... plus the sample at index 999,
    moves axis 1 of each (presumably the channel axis - confirm) to the last position
    and passes each through interface_helpers.normalize.

    Returns:
    --------
    list with the normalized images, the sample at index 999 last.
    """
    stacked = np.vstack(samples)

    # Every 100th sample without index 0, with axis 1 moved to the end.
    every_hundredth = stacked[100::100, :, :, :].swapaxes(1, 2).swapaxes(2, 3)
    gallery = list(map(interface_helpers.normalize, every_hundredth))

    # Last sample at step 999, rearranged the same way.
    final_sample = samples[999][0, :, :, :].swapaxes(0, 1).swapaxes(1, 2)
    gallery.append(interface_helpers.normalize(final_sample))

    return gallery



In [None]:
with gr.Blocks() as demo:
    gr.Markdown(
      """
      # <center> XAI Diffusion Interface </center>
      """
    )

######################################### Settings ########################################################
    # Model selection.
    with gr.Row():

        gr.Markdown("""
                                #### ⌛ Load Model:
                            """)
    with gr.Row():
        model_select = gr.Dropdown(choices=["CelebA 64", "CelebA 128"],label="Which model would you like to use?")
        btn_model_select = gr.Button("Select Model")

    # Block selection; the choices are filled in once a model has been loaded.
    with gr.Row():
        gr.Markdown("""
                            #### 🧱 Select UNet blocks to analyze:
                            """)
    with gr.Row():
        blocks_select = gr.CheckboxGroup(choices=[""], label="Pick 6 blocks of the UNet to analyze")
        btn_blocks_select = gr.Button("Select Blocks")

    #states to store activations and gradients
    activations, gradients  = gr.State([]), gr.State([])

    # Sampling settings shared by all tabs.
    with gr.Row():
        gr.Markdown("""
                            #### 📣 Set initial noise, sampling resolution and labels:
                            """)
    with gr.Row():
        with gr.Column():
            # The callbacks compute `k % sample_res`, so 0 (previously the minimum and
            # therefore the default) is a modulo by zero. 1 tracks every timestep.
            sample_res = gr.Slider(minimum=1, maximum= 998, label="Sample resolution")
        with gr.Column():
            random_seed = gr.Number(value=67654654, label="Initial Random Noise", precision=0)
        with gr.Column():
            diffusion_labels = gr.CheckboxGroup(choices=label_list, label="Labels you want to condition on")

    # Gradient mask and the six layer dropdowns (filled after the blocks are selected).
    with gr.Row():# as set_2:
        gr.Markdown("""
                    #### ⛏️ Define Image Mask for gradient calculation and pick layers to extract activations/gradients from:
                            """)
    with gr.Row():
        with gr.Column():

            image_mask = gr.Image(
                        source = "upload",
                        tool="sketch",
                        image_mode="L",
                        label="Input Mask for gradient calculation",
                        type="pil"
                    ).style(height=220)
        with gr.Column():
            l_b0 = gr.Dropdown(choices=[""], label="Pick layer in Block 1")
            l_b1 = gr.Dropdown(choices=[""], label="Pick layer in Block 2")
            l_b2 = gr.Dropdown(choices=[""], label="Pick layer in Block 3")
        with gr.Column():
            # Listed bottom-up (Block 6, 5, 4) in the right column.
            l_b5 = gr.Dropdown(choices=[""], label="Pick layer in Block 6")
            l_b4 = gr.Dropdown(choices=[""],label="Pick layer in Block 5")
            l_b3 = gr.Dropdown(choices=[""], label="Pick layer in Block 4")

    with gr.Row():
        # Channel masking controls.
        with gr.Column():
            gr.Markdown("""
                            #### 😷 Channel Masking:
                        """)
            with gr.Row():
                with gr.Column():
                    mask_layer = gr.Dropdown(label="Select the layer")
                with gr.Column():
                    mask_channel = gr.Number(label="Input the channel")
            with gr.Row():
                with gr.Column():
                    btn_msk_channel = gr.Button("Attach mask to channel")
                with gr.Column():
                    btn_unmsk_channel = gr.Button("Remove mask from channel")
        # Optional label switch during sampling.
        with gr.Column():
            gr.Markdown("""
                            #### 🔄 Label Switch:
                        """)
            with gr.Row():
                with gr.Column(min_width=300):
                    switch_labels = gr.CheckboxGroup(choices=label_list, label="Label to switch to")
                with gr.Column():
                    switch_step = gr.Number(label="Timestep of Switch")
######################################### ACTIVATIONS ########################################################

    # Row holding the analysis tabs.
    with gr.Row():
        # Activations tab: sampling button, sample gallery, six activation plots
        # arranged left and right of the UNet image, then the SSIM and distance plots.
        with gr.Tab("Activations"):
            # Layer names of the last activation run (filled by the diffuse callback).
            activation_layers = gr.State([])
            with gr.Row():
                with gr.Column():
                    btn_diffuse_act = gr.Button("1. Sample with Activations")
            with gr.Row():
                act_out_sample = gr.Gallery(
                    label="Every 100th sample"
                ).style(grid=[10], height="auto")

            with gr.Row():
                # Left column: plots for blocks 1 to 3.
                with gr.Column():

                    with gr.Row():
                        fig0 = gr.Plot(label="Activations Block 1")

                    with gr.Row():
                        fig1 = gr.Plot(label="Activations Block 2")

                    with gr.Row():
                        fig2 = gr.Plot(label="Activations Block 3")

                # Centre column: architecture image, timestep selector, max channels, action buttons.
                with gr.Column(min_width=800):
                    with gr.Row():
                        act_unet_image = gr.Image(label="UNet")
                    with gr.Row():
                            # Choices are replaced with the tracked timesteps after sampling.
                            activation_step = gr.Radio(choices=[0], label="Select Timestep for Activations")
                    with gr.Row():
                        with gr.Column():
                            mc0 = gr.Number(value=0,label="Max Channel Block 1")
                            mc1 = gr.Number(value=0,label="Max Channel Block 2")
                            mc2 = gr.Number(value=0,label="Max Channel Block 3")
                        with gr.Column():
                            mc5 = gr.Number(value=0,label="Max Channel Block 6")
                            mc4 = gr.Number(value=0,label="Max Channel Block 5")
                            mc3 = gr.Number(value=0,label="Max Channel Block 4")
                        with gr.Row():
                            acts_btn = gr.Button("2. Compute Activation Maps")
                            a_cnl_btn = gr.Button("3. Calculate Max Channels")

                # Right column: plots for blocks 6 to 4, listed top to bottom.
                with gr.Column():

                    with gr.Row():
                        fig5 = gr.Plot(label="Activations Block 6")

                    with gr.Row():
                        fig4 = gr.Plot(label="Activations Block 5")

                    with gr.Row():
                        fig3 = gr.Plot(label="Activations Block 4")


            with gr.Row():
                with gr.Column():
                    ssim_fig = gr.Plot(label="SSIM between sample and previous one over time")
                with gr.Column(min_width=700):
                    activation_dist_fig = gr.Plot(label="Euclidean Distance between activation maps over time")


    ######################################### GRADIENT ########################################################

        # Gradient tab: same arrangement as the Activations tab, plus the gradient volume plot.
        with gr.Tab("Gradient"):
            # Layer names of the last gradient run (filled by the diffuse callback).
            gradient_layers = gr.State([])

            with gr.Row():
                with gr.Column():
                    btn_diffuse_grad = gr.Button("1. Sample with Gradient")

            with gr.Row():

                grad_out_sample = gr.Gallery(
                        label="Every 100th sample"
                    ).style(grid=[10], height="auto")

            with gr.Row():
                # Left column: first three gradient plots.
                with gr.Column():

                    with gr.Row():
                        fig7 = gr.Plot(label="Gradient Block 0")

                    with gr.Row():
                        fig8 = gr.Plot(label="Gradient Block 1")

                    with gr.Row():
                        fig9 = gr.Plot(label="Gradient Block 2")

                # Centre column: architecture image, timestep selector, action buttons, max channels.
                with gr.Column(min_width=700):
                    with gr.Row():
                        grad_unet_image = gr.Image(label="UNet")
                    with gr.Row():
                        # Choices are replaced with the tracked timesteps after sampling.
                        gradient_step = gr.Radio(choices=[0], label="Select Timestep for Gradients")
                    with gr.Row():
                        gr_btn = gr.Button("2. Get Gradient Maps")
                        gr_cnl_btn = gr.Button("3. Calculate Max Channels")
                    with gr.Row():
                        # NOTE(review): these labels name fixed blocks (Downs/Mid/Ups) although
                        # the analysed blocks are chosen by the user - confirm the wording.
                        with gr.Column():
                            mc7 = gr.Number(value=0,label="Max Channel Downs 0")
                            mc8 = gr.Number(value=0,label="Max Channel Downs 1")
                            mc9 = gr.Number(value=0,label="Max Channel Downs 2")
                        with gr.Column():
                            mc12 = gr.Number(value=0,label="Max Channel Ups 1")
                            mc11 = gr.Number(value=0,label="Max Channel Ups 0")
                            mc10 = gr.Number(value=0,label="Max Channel Mid")

                # Right column: last three gradient plots, listed in reverse order.
                with gr.Column():
                    with gr.Row():
                        fig12 = gr.Plot(label="Gradient Block 5")

                    with gr.Row():
                        fig11 = gr.Plot(label="Gradient Block 4")

                    with gr.Row():
                        fig10 = gr.Plot(label="Gradient Block 3")

            with gr.Row():
                with gr.Column(min_width=700):
                    gradient_dist_fig = gr.Plot(label="Euclidean Distance between gradient timesteps")


            with gr.Row():
                with gr.Column():
                    gradient_volume_fig = gr.Plot(label="Image Channel Gradient")


    ######################################### MUTUAL INFO ##########################################################

        # Compare tab: two label selections, both sample galleries around the SSIM plot,
        # and one sub-tab per comparison (activations, gradient volumes, attention maps).
        with gr.Tab("Compare samples"):

            with gr.Row():
                with gr.Column():
                    vs_label_1 = gr.CheckboxGroup(label_list, label="First Sample Label")
                with gr.Column():
                    vs_label_2 = gr.CheckboxGroup(label_list, label="Second Sample Label")

            with gr.Row():
                with gr.Column():
                    vs_sample_1 = gr.Gallery(
                        label="Every 100th sample"
                    ).style(grid=[4], height="auto")
                with gr.Column(min_width=700):
                    vs_ssim_plot = gr.Plot(label="SSIM between samples")
                with gr.Column():
                    vs_sample_2 = gr.Gallery(
                        label="Every 100th sample"
                    ).style(grid=[4], height="auto")

            with gr.Row():
                # Sub-tab: distance between the activations of both runs.
                with gr.Tab("Activation Distance"):
                    gr.Markdown("""
                                    ## <center> Activation Maps Distance</center>

                                """)

                    with gr.Row():
                        with gr.Column():
                            vs_act_plot = gr.Plot(label="L2 distance between activations")

                    with gr.Row():
                        btn_vs_act = gr.Button("Calculate activation distance")

                # Sub-tab: gradient volumes of both runs at one chosen timestep.
                with gr.Tab("Gradient Volumes"):
                    with gr.Row():
                        gr.Markdown("""
                                        ## <center>Compare Gradient Volumes at one timestep</center>

                                    """)
                    with gr.Row():
                        vs_grad_step = gr.Number(value=500, label="Sample Timestep", precision = 0)
                    with gr.Row():
                        with gr.Column():
                            img_vol1_fig = gr.Plot(label="Gradient Volume Sample 1")
                        with gr.Column():
                            img_vol2_fig = gr.Plot(label="Gradient Volume Sample 2")
                    with gr.Row():
                        btn_vs_grad = gr.Button("Calculate gradient volumes")

                # Sub-tab: attention maps of both runs and the distance between them.
                with gr.Tab("Attention Map"):
                    with gr.Row():
                        gr.Markdown("""
                                        ## <center>Attention Maps Distance</center>

                                        <center>To compute the attention map, you need activations from all attention layers, e.g. layer ending in fn.to_out</center>
                                    """)
                    with gr.Row():
                        with gr.Column():
                            attn_map1 = gr.Plot(label="Attention Map 1")
                        with gr.Column(min_width=700):
                            attn_dist_fig = gr.Plot(label="Distance between Attention Maps")
                        with gr.Column():
                            attn_map2 = gr.Plot(label="Attention Map 2")
                    with gr.Row():
                        btn_vs_attn = gr.Button("Calculate Attention Distance")


    # Constant mode flags handed to the callbacks as inputs
    # (0 = activations, 1 = gradients, 2 = attention).
    act_mode = gr.State(0)
    grad_mode = gr.State(1)
    attn_mode = gr.State(2)


    #Individual Trajectories
    btn_diffuse_act.click(diffuse,
                          inputs=[diffusion_labels,sample_res, l_b0,l_b1,l_b2,l_b3,l_b4,l_b5, random_seed, act_mode, image_mask, switch_labels,switch_step],
                          outputs=[act_out_sample, ssim_fig, activation_dist_fig, activation_step, activation_layers, activations])



    btn_diffuse_grad.click(diffuse,
                         inputs=[diffusion_labels,sample_res,l_b0,l_b1,l_b2,l_b3,l_b4,l_b5, random_seed, grad_mode, image_mask],
                         outputs=[grad_out_sample, gradient_step, gradient_layers, ssim_fig, gradient_dist_fig,  gradient_volume_fig, gradients])


    # Sample comparison. The layer dropdowns are passed in parameter order
    # (l_b0 ... l_b5); they were previously passed with l_b5 and l_b4 swapped,
    # which swapped blocks 5 and 6 relative to the individual trajectories.
    btn_vs_act.click(compare_diffuse,
                     inputs=[vs_label_1, vs_label_2, random_seed,sample_res, l_b0, l_b1, l_b2, l_b3, l_b4, l_b5, act_mode],
                     outputs=[vs_sample_1, vs_sample_2, vs_act_plot, vs_ssim_plot])

    btn_vs_grad.click(compare_diffuse,
                      inputs=[vs_label_1, vs_label_2, random_seed,sample_res, l_b0, l_b1, l_b2, l_b3, l_b4, l_b5, grad_mode, image_mask, vs_grad_step],
                      outputs=[vs_sample_1, vs_sample_2, vs_ssim_plot,img_vol1_fig, img_vol2_fig])

    btn_vs_attn.click(compare_diffuse, inputs=[vs_label_1, vs_label_2, random_seed, sample_res, l_b0, l_b1, l_b2, l_b3, l_b4, l_b5, attn_mode],
                   outputs=[vs_sample_1, vs_sample_2, vs_ssim_plot, attn_map1, attn_map2, attn_dist_fig])

    #Channel Masking

    btn_msk_channel.click(set_msk_ch,
                          inputs=[mask_layer, mask_channel])

    btn_unmsk_channel.click(remove_msk_ch,
                           inputs=[mask_layer, mask_channel])

    #Maps Generation

    acts_btn.click(generate_activation_maps,
                   inputs=[activation_step, activations, activation_layers, image_mask],
                   outputs=[fig0, fig1, fig2, fig3, fig4, fig5])

    gr_btn.click(generate_gradient_maps,
                 inputs=[gradient_step, gradients, gradient_layers],
                 outputs=[fig7, fig8, fig9, fig10, fig11, fig12])

    #Max Channel

    a_cnl_btn.click(calculate_max_channels,
                    inputs=[activation_step, activations, activation_layers, act_mode, image_mask],
                   outputs=[mc0, mc1, mc2, mc3, mc4, mc5])



    # No image mask is passed here, so calculate_max_channels runs with image_mask=None.
    gr_cnl_btn.click(calculate_max_channels,
                    inputs=[gradient_step, gradients, gradient_layers, grad_mode],
                    outputs=[mc7, mc8, mc9, mc10, mc11, mc12])

    #Configure Interface
    btn_model_select.click(load_model,
                   inputs=[model_select],
                   outputs=[blocks_select, act_unet_image, grad_unet_image])

    btn_blocks_select.click(configure_interface,
                           inputs=[blocks_select],
                           outputs=[l_b0,l_b1,l_b2,l_b3,l_b4,l_b5,mask_layer])


# Start the interface.
# NOTE(review): share=True asks Gradio for a publicly reachable link - confirm that
# exposing the demo outside the local machine is intended.
demo.launch(debug=False, share=True, show_error=True)