In [1]:

import torch
from acoustic_autoencoder import recon_model
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import os
#%matplotlib ipympl
%matplotlib inline



In [2]:
# Build the autoencoder and load its trained weights.
# map_location="cpu": the model and the input below both stay on the CPU, so
# remap any CUDA tensors in the checkpoint; otherwise a checkpoint saved from a
# GPU run fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (on torch >= 1.13, weights_only=True restricts it to plain tensors).
model = recon_model()
model.load_state_dict(torch.load("final_model.pth", map_location="cpu"))

# Switch to inference behaviour (affects dropout / batch-norm, if the model has any).
model.eval()


# Layer outputs captured during the forward pass, keyed by the name given to
# get_activation(). Later passes overwrite earlier entries.
activations = {}


def get_activation(name):
    """Create a forward hook that records a module's output as ``activations[name]``.

    The returned callable has the ``(module, inputs, output)`` signature that
    ``nn.Module.register_forward_hook`` expects. The stored tensor is detached,
    so it keeps no reference to the autograd graph.
    """
    def _store_output(module, inputs, result):
        activations[name] = result.detach()

    return _store_output

# Record the output of each listed submodule into `activations` on every
# forward pass, under the given key.
for layer_key, layer_module in [
    ("inc", model.inc.conv),
    ("down1", model.down1.double_conv),
    ("down2", model.down2.double_conv),
    ("spec_block", model.spec_block),
    ("up2", model.up2.deconv),
    ("up3", model.up3.deconv),
    ("outc", model.outc.out),
]:
    layer_module.register_forward_hook(get_activation(layer_key))


# Load the example field and insert a channel axis at position 1.
# Assumes trap.npy is stored as (N, H, W), giving img of shape (N, 1, H, W) —
# TODO confirm against the data file.
trap_field = np.load("trap.npy")
img = torch.as_tensor(trap_field[:, np.newaxis], dtype=torch.float32)

# This forward pass only exists to fire the hooks. no_grad avoids building an
# autograd graph that nothing uses (and that `_` would otherwise keep alive).
with torch.no_grad():
    _ = model(img)

In [3]:

# Besides the hooked layers, the viewer below also offers the raw input and the sigmoid of the final ("outc") layer.

def show_activations_widget(input_img, activations_dict, base_title=""):
    """
    Interactively browse single 2-D maps from the input, any captured layer,
    or the sigmoid of the final ("outc") layer.

    Displays a layer dropdown and a channel slider. Whenever either changes,
    the selected map is redrawn with ``imshow``. The slider range follows the
    channel count (dim 1) of whatever the selected entry reads from.

    Parameters
    ----------
    input_img : torch.Tensor
        The tensor that was fed to the model. Indexed as ``input_img[0, c]``,
        so a ``[B, C, H, W]`` layout is assumed.
    activations_dict : dict[str, torch.Tensor]
        Layer name -> captured activation, each indexed as ``[0, c]``
        (``[B, C, H, W]`` assumed). The "After Sigmoid" entry is offered only
        when an ``"outc"`` key is present.
    base_title : str, optional
        Prefix prepended to every figure title.
    """
    input_label = "Input"
    sigmoid_label = "After Sigmoid"

    layer_names = [input_label, *activations_dict]
    if "outc" in activations_dict:
        layer_names.append(sigmoid_label)

    def source_tensor(layer):
        """Return the full tensor a dropdown entry reads its maps from."""
        if layer == input_label:
            return input_img
        if layer == sigmoid_label:
            return activations_dict["outc"]
        return activations_dict[layer]

    def last_channel(layer):
        """Highest valid channel index for ``layer`` (0 for single-channel maps)."""
        return max(0, source_tensor(layer).shape[1] - 1)

    layer_dropdown = widgets.Dropdown(
        options=layer_names,
        value=layer_names[0],
        description='Layer:'
    )

    channel_slider = widgets.IntSlider(
        value=0, min=0, max=last_channel(layer_names[0]),
        description='Channel:'
    )

    def update_channel_range(*args):
        # Reset to channel 0 before changing the range. Lowering `max` below
        # the current value would otherwise clamp it and trigger an extra redraw.
        channel_slider.value = 0
        channel_slider.max = last_channel(layer_dropdown.value)

    # Must be registered before interactive_output below. The slider is then
    # already valid for the new layer when the redraw observer runs.
    layer_dropdown.observe(update_channel_range, names='value')

    def show(layer, channel):
        tensor = source_tensor(layer)[0, channel]  # ([B, C, H, W] → one image)
        if layer == sigmoid_label:
            tensor = torch.sigmoid(tensor)

        # .cpu() so this also works if the activations live on a GPU.
        frame = tensor.detach().cpu().numpy()

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(frame, cmap='viridis')
        # The title shows a 1-based channel number; the slider itself is 0-based.
        ax.set_title(f"{base_title}Layer: {layer} | Channel: {channel+1}")
        ax.set_axis_off()
        plt.show()

    out = widgets.interactive_output(show, {'layer': layer_dropdown, 'channel': channel_slider})
    # `display` is IPython's notebook builtin (not imported in this file).
    display(widgets.VBox([layer_dropdown, channel_slider]), out)


In [4]:
# Launch the interactive viewer on the forward pass captured above.
show_activations_widget(
    input_img=img,
    activations_dict=activations,
    base_title="Model Activations - ",
)


VBox(children=(Dropdown(description='Layer:', options=('Input', 'inc', 'down1', 'down2', 'spec_block', 'up2', …

Output()

In [None]:
# import os
# from PIL import Image
# import re


# figs_dir = "progression_arrays/"

# def plot_3_ims(path, og_diffr, pred_diffr, pred_phase):

#     match = re.search(r"epoch_(\d+)_batch(\d+)", path)
#     epoch = int(match.group(1))
#     batch = int(match.group(2))

#     fig, axes = plt.subplots(1, 3, figsize = (12, 4))
#     im1 = axes[0].imshow(og_diffr)
#     axes[0].set_title("Original Acoustic Field Magnitude")
#     im2 = axes[1].imshow(pred_diffr)
#     axes[1].set_title("Predicted Acoustic Field Magnitude")
#     im3 = axes[2].imshow(pred_phase, cmap = "twilight")
#     axes[2].set_title("Predicted Phases")

#     for ax in axes: 
#         ax.set_axis_off()

#     fig.colorbar(im1, ax = axes[0], shrink = 0.7)
#     fig.colorbar(im2, ax = axes[1], shrink = 0.7)
#     fig.colorbar(im3, ax = axes[2], shrink = 0.7)

#     fig.suptitle(f"1st prediction of Epoch {epoch} Batch {batch}")
#     plt.show()


# def natural_sort_key(s):
#     return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', s)]

# files = sorted(os.listdir(figs_dir), key=natural_sort_key)



# def load_image(idx):

#     path = os.path.join(figs_dir, files[idx])
#     array_of_3 = np.load(path)

#     og_diffr, pred_diffr, pred_phase = array_of_3[0], array_of_3[1], array_of_3[2]

#     plot_3_ims(path, og_diffr, pred_diffr, pred_phase)
    

# slider = widgets.IntSlider(0, 0, len(files)-1, description='Idx:')
# out = widgets.interactive_output(load_image, {'idx': slider})
# display(slider, out)

IntSlider(value=0, description='Idx:', max=174)

Output()