In [3]:

import torch
from autoencoder import recon_model
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
%matplotlib ipympl



import os

# Project root for the DNN experiments. Later cells build paths as
# f"{dir}experiments/..." and f"{dir}data/...", so the value must end with a
# path separator; os.path.join(..., "") guarantees that. Set the
# MAXENT_DNN_DIR environment variable to run the notebook on another machine.
# NOTE: `dir` shadows Python's built-in dir(); the name is kept because the
# other cells refer to it.
dir = os.path.join(
    os.environ.get(
        "MAXENT_DNN_DIR",
        "C:/Users/nicol/OneDrive - University of Bristol/MSc_project-DESKTOP-M3M0RRL/maxEnt_simulation/DNN/",
    ),
    "",
)
# Experiment run folder under f"{dir}experiments/" (trailing slash required).
exp = "exp_Thu-12-Jun-2025-at-10-47-11AM/"




In [5]:
# Rebuild the architecture and load the trained weights.
# map_location="cpu" lets a checkpoint saved during a GPU run load on a
# CPU-only machine, and keeps every parameter on the CPU so later
# tensor.numpy() conversions work without an explicit .cpu().
model = recon_model()
state_dict = torch.load(f"{dir}experiments/{exp}final_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # inference mode (affects dropout / batch-norm layers, if the model has any)


# Layer name -> most recent (detached) output captured by a forward hook.
activations = {}


def get_activation(name):
    """Build a forward hook that records a module's output.

    The returned callable has the ``(module, inputs, output)`` signature that
    ``nn.Module.register_forward_hook`` expects. Each time the hooked module
    runs, its output is detached from the autograd graph and stored in the
    global ``activations`` dict under ``name``; a later pass overwrites it.
    """
    def _store_output(module, inputs, output):
        activations[name] = output.detach()

    return _store_output

# Sub-modules whose outputs we want to inspect, keyed by the name under which
# each output is stored in `activations`.
hook_targets = {
    'inc': model.inc.conv,
    'down1': model.down1.double_conv,
    'down2': model.down2.double_conv,
    'down3': model.down3.conv,
    'down4': model.down4.conv,
    'up1': model.up1.deconv,
    'up2': model.up2.deconv,
    'up3': model.up3.deconv,
    'outc': model.outc.out,
}
for layer_name, module in hook_targets.items():
    module.register_forward_hook(get_activation(layer_name))


# Only the first sample is needed: mmap_mode="r" avoids reading the whole
# dataset into memory just to slice one entry; np.array() makes an in-memory copy.
intensities_path = f"{dir}data/one_square/intensities.npy"
img = np.array(np.load(intensities_path, mmap_mode="r")[:1])
# Insert a channel axis after the batch axis: (1, ...) -> (1, 1, ...).
# NOTE(review): presumably intensities are stored as (N, H, W) — confirm.
img = img[:, np.newaxis]
img = torch.as_tensor(img, dtype=torch.float32)

# Forward pass for visualisation only: no_grad() skips building the autograd
# graph. The forward hooks registered above fill `activations` during this call.
with torch.no_grad():
    pred_diffr, pred_amp = model(img)

In [6]:

# Interactive viewer: any captured layer, plus the "Input" and "After Sigmoid" pseudo-layers.

def show_activations_widget(input_img, activations_dict, base_title=""):
    """
    Visualize an arbitrary channel from any layer in activations_dict.

    A dropdown selects the layer and a slider selects the channel; the chosen
    2-D map of sample 0 is drawn with ``imshow``. Besides the keys of
    ``activations_dict``, two pseudo-layers are offered: "Input" (the network
    input) and, when an ``"outc"`` activation is present, "After Sigmoid"
    (``torch.sigmoid`` applied to that output).

    Parameters
    ----------
    input_img : torch.Tensor
        Network input indexed as ``[batch, channel, H, W]``; only sample 0 is shown.
    activations_dict : dict[str, torch.Tensor]
        Layer name -> captured output, indexed as ``[batch, channel, ...]``
        (assumed 4-D feature maps); only sample 0 is shown.
    base_title : str
        Prefix prepended to every figure title.
    """

    def layer_stack(layer):
        """Return sample 0 of `layer` as a per-channel stack, indexed [channel, H, W]."""
        if layer == "Input":
            return input_img[0]
        if layer == "After Sigmoid":
            return torch.sigmoid(activations_dict["outc"][0])
        return activations_dict[layer][0]

    layer_names = ["Input", *activations_dict]
    # Only offer the sigmoid view when there is an "outc" output to apply it to.
    if "outc" in activations_dict:
        layer_names.append("After Sigmoid")

    layer_dropdown = widgets.Dropdown(
        options=layer_names,
        value=layer_names[0],
        description='Layer:'
    )

    channel_slider = widgets.IntSlider(
        value=0, min=0, max=0,
        description='Channel:'
    )

    def update_channel_range(*args):
        # Resize the slider to the selected layer's channel count (every layer,
        # including the pseudo-layers) and jump back to the first channel.
        n_channels = layer_stack(layer_dropdown.value).shape[0]
        channel_slider.max = max(0, n_channels - 1)
        channel_slider.value = 0

    update_channel_range()
    # Registered before interactive_output below, so the slider range is
    # already updated when the redraw for a new layer runs.
    layer_dropdown.observe(update_channel_range, names='value')

    # Figure from the previous redraw; closed on the next redraw so figures do
    # not pile up in pyplot (memory growth, "More than 20 figures" warning).
    last_fig = None

    def show(layer, channel):
        nonlocal last_fig
        if last_fig is not None:
            plt.close(last_fig)

        # .cpu() so the conversion also works for tensors living on a GPU.
        img = layer_stack(layer)[channel].detach().cpu().numpy()

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(img, cmap='viridis')
        # The title shows a 1-based channel number; the slider itself is 0-based.
        ax.set_title(f"{base_title}Layer: {layer} | Channel: {channel+1}")
        ax.set_axis_off()
        plt.show()
        last_fig = fig

    out = widgets.interactive_output(show, {'layer': layer_dropdown, 'channel': channel_slider})
    display(widgets.VBox([layer_dropdown, channel_slider]), out)


In [7]:
# Browse the captured activations for the single input sample loaded above.
show_activations_widget(
    input_img=img,
    activations_dict=activations,
    base_title="Model Activations - ",
)


VBox(children=(Dropdown(description='Layer:', options=('Input', 'inc', 'down1', 'down2', 'down3', 'down4', 'up…

Output()

In [9]:
import os
from PIL import Image

import re

figs_dir = f"{dir}experiments/{exp}progression_figs/"


def _natural_key(name):
    """Sort key that orders embedded integers numerically ("batch_2" < "batch_10")."""
    # re.split with a capturing group alternates text / digit-run pieces, so odd
    # positions are always digit runs and keys compare element-wise by type.
    return [int(piece) if i % 2 else piece
            for i, piece in enumerate(re.split(r"(\d+)", name))]


# Keep only files PIL can open (skips e.g. OneDrive's desktop.ini and
# sub-folders), ordered naturally so that unpadded batch numbers sort
# numerically rather than lexicographically.
image_files = sorted(
    (f for f in os.listdir(figs_dir)
     if os.path.splitext(f)[1].lower() in Image.registered_extensions()),
    key=_natural_key,
)
if not image_files:
    raise FileNotFoundError(f"No image files found in {figs_dir!r}")

# Figure from the previous slider position; closed on the next redraw so
# figures do not accumulate in pyplot.
_last_progression_fig = None


def load_image(idx):
    """Display the idx-th progression figure (in natural filename order).

    NOTE(review): the title assumes the saved figures are the first prediction
    of successive epoch-1 batches, with idx equal to the batch index — confirm
    against the training script that writes them.
    """
    global _last_progression_fig
    if _last_progression_fig is not None:
        plt.close(_last_progression_fig)

    path = os.path.join(figs_dir, image_files[idx])
    fig, ax2 = plt.subplots(figsize=(17, 5))
    # Context manager closes the file handle; imshow converts the image to an
    # array immediately, so the data stays valid after the file is closed.
    with Image.open(path) as pil_img:
        ax2.imshow(pil_img)
    ax2.set_title(f"1st prediction of Epoch 1 Batch {idx}")
    ax2.set_axis_off()
    plt.show()
    _last_progression_fig = fig


slider = widgets.IntSlider(value=0, min=0, max=len(image_files) - 1, description='Batch:')
out = widgets.interactive_output(load_image, {'idx': slider})
display(slider, out)

IntSlider(value=0, description='Batch:', max=99)

Output()