In [16]:
import os
import torch
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from matplotlib.colors import LinearSegmentedColormap

def load_model_weights(model_path, state_key="student_model_state_dict"):
    """Load the weight tensors of one model stored in a ``.pth`` checkpoint.

    Args:
        model_path: path to a checkpoint saved with ``torch.save``.
        state_key: key under which the checkpoint stores the model's state dict.
            Defaults to ``"student_model_state_dict"``.

    Returns:
        dict mapping parameter name -> NumPy array, for every entry whose name
        contains ``"weight"`` (biases and other buffers are skipped).

    Raises:
        KeyError: if the checkpoint is not a dict or has no ``state_key`` entry.

    Note:
        Depending on the installed PyTorch version, ``torch.load`` may unpickle
        arbitrary objects; only load checkpoints from trusted sources.
    """
    # map_location forces CPU tensors so GPU-saved checkpoints load on CPU-only machines.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    if not isinstance(checkpoint, dict) or state_key not in checkpoint:
        available = list(checkpoint) if isinstance(checkpoint, dict) else type(checkpoint).__name__
        raise KeyError(f"{model_path}: checkpoint has no {state_key!r} entry (found: {available})")

    model_state = checkpoint[state_key]
    # detach() makes .numpy() safe even if the state dict was saved with keep_vars=True.
    return {
        name: tensor.detach().cpu().numpy()
        for name, tensor in model_state.items()
        if "weight" in name
    }

def create_heatmaps(weights_dict, model_name="Model Weights", use_plotly=False):
    num_layers = len(weights_dict)

    if use_plotly: # Interactive Plotly visualization
        fig = make_subplots(rows=1, cols=num_layers, subplot_titles=[f"Layer {i+1}" for i in range(num_layers)])
        global_min = min(w.min() for w in weights_dict.values())
        global_max = max(w.max() for w in weights_dict.values())

        # for i, (layer_name, weights) in enumerate(weights_dict.items()):
        #     heatmap = go.Heatmap(z=weights, colorscale="RdBu", zmin=global_min, zmax=global_max)
        #     fig.add_trace(heatmap, row=1, col=i+1)

        colorscale = [
            [global_min, "red"],    # Minimum value (negative)
            [0, "white"],  # Zero value
            [global_max, "blue"]   # Maximum value (positive)
        ]
        
        for i, (layer_name, weights) in enumerate(weights_dict.items()):
            heatmap = go.Heatmap(
                z=weights, 
                colorscale=colorscale,  # Custom scale
                zmid=0,  # Ensure 0 is mapped to white
                zmin=weights.min(), 
                zmax=weights.max()
        )
        fig.add_trace(heatmap, row=1, col=i+1)
        fig.add_trace(heatmap, row=1, col=i+1)

        fig.update_layout(title=f"{model_name} - Weight Heatmaps", height=500, width=250*num_layers)
        fig.show()

    else:  # Matplotlib visualization
        fig, axes = plt.subplots(1, num_layers, figsize=(4*num_layers, 4))

        if num_layers == 1:
            axes = [axes]

        for i, (layer_name, weights) in enumerate(weights_dict.items()):
            ax = axes[i]
            cmap = LinearSegmentedColormap.from_list("custom_cmap", ["red", "white", "green"], N=100)
            im = ax.imshow(weights, cmap=cmap, aspect="auto", vmin=-2, vmax=2)
            ax.set_title(f"{layer_name}")
            ax.set_xticks([])
            ax.set_yticks([])

        fig.colorbar(im, ax=axes, fraction=0.05)
        plt.suptitle(f"{model_name} - Weight Heatmaps")
        plt.show()

In [None]:
def create_all_heatmaps(directory, use_plotly=False):
    """Plot weight heatmaps for every ``.pth`` checkpoint in ``directory``.

    Files are processed in sorted name order. ``os.listdir`` order is arbitrary,
    so sorting keeps the printed log and the figure sequence reproducible.

    Args:
        directory: folder scanned (non-recursively) for files ending in ``.pth``.
        use_plotly: forwarded to ``create_heatmaps``. ``False`` (default) draws
            static Matplotlib figures; ``True`` draws interactive Plotly ones.
    """
    model_files = sorted(f for f in os.listdir(directory) if f.endswith(".pth"))
    if not model_files:
        print(f"No .pth files found in {directory}")
        return

    for model_file in model_files:
        model_path = os.path.join(directory, model_file)
        print(f"\nLoading weights from {model_file}...")

        weights_dict = load_model_weights(model_path)
        create_heatmaps(weights_dict, model_name=model_file, use_plotly=use_plotly)

In [None]:
# Render weight heatmaps for every .pth checkpoint in this experiment folder.
directory = "../experiment_output/Davide_MLP/experiments_davide_lr0.01_b16_init"
create_all_heatmaps(directory=directory)


Loading weights from nonoverlappingCNN_sigmoid__fcnn_sigmoid.pth...



Loading weights from nonoverlappingCNN_tanh__fcnn_tanh.pth...



Loading weights from nonoverlappingCNN_relu__fcnn_relu.pth...


In [None]:
# Render weight heatmaps for every .pth checkpoint in this experiment folder.
directory = "../experiment_output/Davide_MLP/experiments_davide_lr0.01_b32_init"
create_all_heatmaps(directory=directory)


Loading weights from nonoverlappingCNN_sigmoid__fcnn_sigmoid.pth...



Loading weights from nonoverlappingCNN_tanh__fcnn_tanh.pth...



Loading weights from nonoverlappingCNN_relu__fcnn_relu.pth...


In [36]:
# Render weight heatmaps for every .pth checkpoint in this experiment folder.
directory = "../experiment_output/fcnn/experiments_l1_1e-5_b32_init_hidden128"
create_all_heatmaps(directory=directory)


Loading weights from nonoverlappingCNN_sigmoid__fcnn_sigmoid.pth...



Loading weights from nonoverlappingCNN_tanh__fcnn_tanh.pth...



Loading weights from nonoverlappingCNN_relu__fcnn_relu.pth...


In [18]:
# Render weight heatmaps for every .pth checkpoint in this experiment folder.
directory = "../experiment_output/experiments_20032025"
create_all_heatmaps(directory=directory)


Loading weights from nonoverlappingCNN_relu__fcnn_decreasing_relu.pth...


ValueError: 
    Invalid value of type 'builtins.list' received for the 'colorscale' property of heatmap
        Received value: [[np.float32(-3.0578067), 'red'], [0, 'white'], [np.float32(4.648993), 'blue']]

    The 'colorscale' property is a colorscale and may be
    specified as:
      - A list of colors that will be spaced evenly to create the colorscale.
        Many predefined colorscale lists are included in the sequential, diverging,
        and cyclical modules in the plotly.colors package.
      - A list of 2-element lists where the first element is the
        normalized color level value (starting at 0 and ending at 1),
        and the second item is a valid color string.
        (e.g. [[0, 'green'], [0.5, 'red'], [1.0, 'rgb(0, 0, 255)']])
      - One of the following named colorscales:
            ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
             'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg',
             'brwnyl', 'bugn', 'bupu', 'burg', 'burgyl', 'cividis', 'curl',
             'darkmint', 'deep', 'delta', 'dense', 'earth', 'edge', 'electric',
             'emrld', 'fall', 'geyser', 'gnbu', 'gray', 'greens', 'greys',
             'haline', 'hot', 'hsv', 'ice', 'icefire', 'inferno', 'jet',
             'magenta', 'magma', 'matter', 'mint', 'mrybm', 'mygbm', 'oranges',
             'orrd', 'oryel', 'oxy', 'peach', 'phase', 'picnic', 'pinkyl',
             'piyg', 'plasma', 'plotly3', 'portland', 'prgn', 'pubu', 'pubugn',
             'puor', 'purd', 'purp', 'purples', 'purpor', 'rainbow', 'rdbu',
             'rdgy', 'rdpu', 'rdylbu', 'rdylgn', 'redor', 'reds', 'solar',
             'spectral', 'speed', 'sunset', 'sunsetdark', 'teal', 'tealgrn',
             'tealrose', 'tempo', 'temps', 'thermal', 'tropic', 'turbid',
             'turbo', 'twilight', 'viridis', 'ylgn', 'ylgnbu', 'ylorbr',
             'ylorrd'].
        Appending '_r' to a named colorscale reverses it.
