In [None]:
import numpy as np
from scipy import stats
import pandas as pd
from PIL import Image
import matplotlib
from matplotlib import pyplot as plt
import panel as pn
from io import BytesIO
import urllib.request

In [None]:
# Render inline matplotlib figures at high ("retina") resolution in the notebook.
# This only affects IPython's inline backend, not the Panel Matplotlib pane.
%config InlineBackend.figure_format='retina'

In [None]:
# Load Panel's front-end extension so Panel objects can be displayed in the notebook.
pn.extension()

In [None]:
def add_noise(img_tensor_minus_plus_one, beta):
    """Apply one forward-diffusion step to an image.

    Draws ``noise`` element-wise from a normal distribution with mean 0 and
    standard deviation ``sqrt(beta)`` (i.e. variance ``beta``) using the global
    ``numpy.random`` state, and combines it with the image as
    ``sqrt(1 - beta) * image + noise``.

    Parameters
    ----------
    img_tensor_minus_plus_one : numpy.ndarray
        Image array; the caller in this notebook passes an image scaled to
        [-1, 1]. The noise has the same shape.
    beta : float
        Noise variance for this step. The beta widget restricts it to [0, 1];
        values above 1 would make ``sqrt(1 - beta)`` NaN.

    Returns
    -------
    tuple of numpy.ndarray
        ``(noise, noisy_image)``, both cast to float16.
    """
    mean = np.zeros_like(img_tensor_minus_plus_one)
    std = np.ones_like(img_tensor_minus_plus_one) * np.sqrt(beta)
    sampled_noise = np.random.normal(loc=mean, scale=std)

    # Shrink the signal so that the overall variance stays bounded across steps.
    kept_signal = np.sqrt(1 - beta) * img_tensor_minus_plus_one
    noisy_image = kept_signal + sampled_noise

    return sampled_noise.astype(np.float16), noisy_image.astype(np.float16)

In [None]:
# Single-image upload widget; its value (raw bytes) feeds reset_demo.
file_input = pn.widgets.FileInput(
    accept="image/*",
    multiple=False,
    margin=(21, 10, 5, 10),
)

In [None]:
# Noise variance beta used for every corruption step, bounded to [0, 1].
beta_input = pn.widgets.FloatInput(
    name="\u03B2:",
    start=0,
    end=1,
    step=0.1,
    value=0.3,
)

In [None]:
# Number of forward-diffusion steps computed for each image.
num_corruption_steps = 16

# Filled by reset_demo with one (start_image, noise, noisier_image) tuple per
# player frame; read by plot_process.
corruption_results_list = list()

In [None]:
# Pane that displays the diagram; plot_process replaces its `object` on every step.
plot = pn.pane.Matplotlib(
    object=None,
    tight=True,
    width=700,
)

In [None]:
# Fully transparent RGBA colour. Artists that are not yet part of the animation
# are drawn in this colour instead of being omitted, so the layout stays fixed.
_TRANSPARENT = "#ffffff00"

# Arrow style shared by the three straight arrows around the "+" node.
_STRAIGHT_ARROWSTYLE = "-|>,head_width=0.8, head_length=0.8"

# Caption shown in the top-right box at specific player steps.
# NOTE(review): the sampled noise is N(0, beta_t I); the "x_t;" argument in the
# step-1 caption reads as the density of x_t — consider revising the label.
_STEP_DESCRIPTIONS = {
    1: "Generate noise as\n"r"$\mathcal{N}(\mathbf{x}_t;\ 0,\ \beta_t\,\mathbf{I})$",
    2: "Scale image by\n"r"$\sqrt{1 - \beta_t}$""\nand add the noise",
    3: r"Output from step $t$""\nbecomes the input\n"r"for step $t+1$",
    5: "Repeat process\n(generate noise, \nscale and add)",
    11: "Image progressively\nbecomes noisier,\nwith less info",
}


def _show_image(ax, img_minus_plus_one):
    """Draw an array stored in [-1, 1] on ``ax`` after mapping it to [0, 1]."""
    ax.imshow(((img_minus_plus_one + 1) / 2).astype(np.float32).clip(min=0, max=1),
              interpolation='nearest',
              aspect="auto")


def _hide_frame(ax):
    """Make the spines and ticks of ``ax`` transparent; the axes keeps its place."""
    for pos in ["left", "right", "top", "bottom"]:
        ax.spines[pos].set_edgecolor(_TRANSPARENT)
    ax.tick_params(axis='both', colors=_TRANSPARENT)


def _straight_arrow(ax, xy, xytext, color):
    """Draw an unlabelled arrow from ``xytext`` to ``xy`` in figure-fraction coordinates."""
    ax.annotate('',
                xy=xy,
                xytext=xytext,
                xycoords='figure fraction',
                arrowprops=dict(arrowstyle=_STRAIGHT_ARROWSTYLE,
                                linewidth=2,
                                color=color))


def _timestep_text(player_step):
    """Return the timestep label for a player step.

    Steps 0-3 show the first corruption step and steps 4-5 the second (those
    frames are repeated copies); from step 6 onwards the timestep is step - 3.
    """
    if player_step in [0, 1, 2, 3]:
        return "1"
    if player_step in [4, 5]:
        return "2"
    return str(player_step - 3)


def plot_process(event):
    """Redraw the forward-diffusion diagram for the player's current step.

    Layout, in figure-fraction coordinates:
    - bottom left (ax1): the input image of the step;
    - bottom right (ax2): the noisy output;
    - top (ax3): the sampled noise.
    The three panels are joined through a "+" node by arrows. Pieces that are
    not yet revealed at a given step are drawn transparent.

    Parameters
    ----------
    event : param.parameterized.Event
        Watcher event from ``player``; unused, because the step is read from
        ``player.value``.
    """
    player_step = player.value
    # The results list is empty until reset_demo has processed an image, and is
    # left empty if reset_demo fails after clearing it. Skip the redraw instead
    # of raising IndexError inside the widget callback.
    if not 0 <= player_step < len(corruption_results_list):
        return
    start_image, noise, noisier_image = corruption_results_list[player_step]

    fig = plt.Figure(figsize=(12, 8))

    ax1 = fig.add_axes([0., 0.1, 0.3, 0.3])
    _show_image(ax1, start_image)

    # ax2/ax3 copy ax1's limits, so that while they are still empty (hidden)
    # they span the same data range they will have once an image is drawn.
    ax2 = fig.add_axes([0.5, 0.1, 0.3, 0.3])
    ax2.set_xlim(ax1.get_xlim())
    ax2.set_ylim(ax1.get_ylim())

    ax3 = fig.add_axes([0.25, 0.5, 0.3, 0.3])
    ax3.set_xlim(ax1.get_xlim())
    ax3.set_ylim(ax1.get_ylim())

    # Which pieces of the diagram are visible at this step.
    show_noise = player_step not in [0, 4]
    # Output image, "+" node and the three straight arrows around it.
    show_combination = player_step not in [0, 1, 4]
    # Bracket arrow from the output back to the input ("output becomes input").
    show_feedback_arrow = player_step not in [0, 1, 2, 4]

    if show_combination:
        _show_image(ax2, noisier_image)
    else:
        _hide_frame(ax2)

    if show_noise:
        _show_image(ax3, noise)
    else:
        _hide_frame(ax3)

    ax1.annotate('',
                 xy=(0.72, 0.09),
                 xycoords='figure fraction',
                 xytext=(0.18, 0.09),
                 arrowprops=dict(arrowstyle="<|-,head_width=0.8, head_length=0.8",
                                 connectionstyle="bar,fraction=0.1",
                                 color="black" if show_feedback_arrow else _TRANSPARENT,
                                 linewidth=2))

    combination_color = "black" if show_combination else _TRANSPARENT

    # "+" node: grey ellipse with a "+" sign, positioned in figure coordinates.
    fig.add_artist(matplotlib.patches.Ellipse(xy=(0.4, 0.25),
                                              width=0.05,
                                              height=0.075,
                                              facecolor="#999999" if show_combination else _TRANSPARENT,
                                              edgecolor=combination_color))
    fig.add_artist(matplotlib.text.Text(x=0.3805,
                                        y=0.228,
                                        text="+",
                                        fontsize=40,
                                        color=combination_color))

    _straight_arrow(ax1, xy=(0.415, 0.25), xytext=(0.34, 0.25), color=combination_color)
    _straight_arrow(ax1, xy=(0.54, 0.25), xytext=(0.468, 0.25), color=combination_color)
    _straight_arrow(ax1, xy=(0.438, 0.29), xytext=(0.438, 0.5), color=combination_color)

    ax1.annotate("$t=%s$" % _timestep_text(player_step),
                 xy=(0.5, 0.5),
                 xycoords='figure fraction',
                 xytext=(0.075, 0.75),
                 textcoords='figure fraction',
                 size=20,
                 va="center",
                 ha="center",
                 bbox=dict(boxstyle="square", fc="w"))

    description = _STEP_DESCRIPTIONS.get(player_step)
    if description is not None:
        ax1.annotate(description,
                     xy=(0.5, 0.5),
                     xycoords='figure fraction',
                     xytext=(0.71, 0.75),
                     textcoords='figure fraction',
                     size=15,
                     va="center",
                     ha="center",
                     bbox=dict(boxstyle="square", fc="w"))

    plt.close(fig)
    plot.object = fig

In [None]:
# Player stepping through the frames of corruption_results_list; reset_demo
# rewinds it to step 0 whenever the image or beta changes.
# reset_demo stores num_corruption_steps frames plus 4 repeated ones (three extra
# copies of the first step, one of the second), so the last valid frame index is
# num_corruption_steps + 3 (19 for the default of 16 steps). Deriving `end` from
# the constant keeps the player in range if the number of steps is changed.
# `interval` is the delay between frames while playing (milliseconds in Panel).
player = pn.widgets.Player(name='Discrete Player',
                           start=0,
                           end=num_corruption_steps + 3,
                           loop_policy='once',
                           interval=2000,
                           show_loop_controls=False,
                           sizing_mode="stretch_width")

In [None]:
# Redraw the diagram whenever the player's step is set. onlychanged=False makes the
# callback fire even if the value is unchanged, so reset_demo's `player.value = 0`
# always triggers a redraw of the freshly computed results.
player.param.watch(plot_process, "value", onlychanged=False)

In [None]:
def reset_demo(event):
    """Recompute the corruption sequence for the current image and beta, then rewind.

    Registered as a watcher on both ``file_input.value`` and ``beta_input.value``.
    It refills the module-level ``corruption_results_list`` with one
    ``(start_image, noise, noisier_image)`` tuple per player frame. It then sets
    ``player.value`` to 0, which triggers a redraw.

    Parameters
    ----------
    event : param.parameterized.Event
        Watcher event; unused, because the current widget values are read directly.
    """
    image_bytes = file_input.value
    beta = beta_input.value
    if not image_bytes or beta is None:
        # No image uploaded yet, or the beta field was cleared. Keep the previous
        # results instead of failing inside the callback.
        return

    corruption_results_list.clear()

    with BytesIO(image_bytes) as buffer:
        # convert("RGB") normalises every PIL mode (L, LA, P, RGBA, CMYK, ...) to an
        # (H, W, 3) array and drops any alpha channel. Without it, other modes would
        # give 2-D arrays (greyscale), colour indices (palette), 2-channel arrays
        # (LA) or CMYK values. The pixel data is loaded here, while the buffer is open.
        img = Image.open(buffer).convert("RGB")
        img_tensor = np.array(img) / 255

    # Scale from [0, 1] to [-1, 1] and store as float16, the dtype add_noise returns.
    start_image = (img_tensor * 2 - 1).astype(np.float16)

    for _ in range(num_corruption_steps):
        noise, noisier_image = add_noise(start_image, beta)
        corruption_results_list.append((start_image, noise, noisier_image))
        start_image = noisier_image

    # Repeat the first step three more times and the second step once more, giving
    # plot_process extra frames to reveal the diagram piece by piece.
    first_step = corruption_results_list[0]
    second_step = corruption_results_list[1]
    corruption_results_list[1:1] = [first_step] * 3
    corruption_results_list.insert(4, second_step)

    player.value = 0

In [None]:
# Recompute the corruption sequence whenever beta or the uploaded image changes.
beta_input.param.watch(reset_demo, 'value')
file_input.param.watch(reset_demo, 'value')

In [None]:
# Pre-load a demo image so the app shows something on start-up. Assigning
# file_input.value triggers reset_demo through the watcher registered on it.
# The `with` block closes the HTTP response; the timeout avoids hanging forever
# on an unresponsive server.
DEMO_IMAGE_URL = "https://i.postimg.cc/nhVpcdKF/demo-dog.jpg"
with urllib.request.urlopen(DEMO_IMAGE_URL, timeout=30) as response:
    file_input.value = response.read()

In [None]:
# Page template for the served app (no busy indicator, dark grey header).
template = pn.template.VanillaTemplate(
    title='Forward diffusion process',
    busy_indicator=None,
    header_background="#434343",
)

# Let the header contents wrap and centre them, e.g. on narrow screens.
template.config.raw_css.append("""#header{
           flex-wrap: wrap;
           justify-content: center;
          }
""")

In [None]:
# Centred layout: stretchable spacers on both sides of a column holding the
# controls, the diagram, the player and the credits line.
app = pn.Row(
    pn.Spacer(height=10, sizing_mode="stretch_width"),
    pn.Column(
        pn.Row(file_input, beta_input),
        plot,
        player,
        pn.pane.Markdown("Demo by Julio Antonio Soto for IE University. Made with [Panel](https://panel.holoviz.org)"),
    ),
    pn.Spacer(height=10, sizing_mode="stretch_width"),
    sizing_mode="stretch_width",
)

In [None]:
# Place the app layout in the template's main content area.
template.main.append(app)

In [None]:
# Mark the template as the object to serve (e.g. with `panel serve`).
template.servable()