In [4]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import datetime

###############################
#  CREATE FOLDERS & UTILITIES
###############################

def create_directory(dir_path):
    """Create ``dir_path`` (including any missing parents) if it does not exist.

    ``exist_ok=True`` replaces the former "check, then create" sequence, which
    could raise ``FileExistsError`` if another process created the folder
    between the existence check and ``os.makedirs``.

    Args:
        dir_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir_path`` already exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

def rescale_matrix(W, desired_radius):
    """
    Scale the square matrix ``W`` so that its spectral radius becomes ``desired_radius``.

    The spectral radius is the largest eigenvalue modulus of ``W``. If it is exactly
    zero the matrix cannot be rescaled, so a factor of 1.0 is used and ``W`` is
    returned unchanged in value.

    Returns:
        W_rescaled: ``W`` multiplied by the scaling factor,
        current_radius: The spectral radius of the input ``W`` (a Python float),
        scaling_factor: The multiplier that was applied.
    """
    spectrum = torch.linalg.eigvals(W)
    current_radius = torch.max(torch.abs(spectrum)).item()

    if current_radius == 0:
        # Degenerate spectrum: leave the matrix as-is rather than divide by zero.
        scaling_factor = 1.0
    else:
        scaling_factor = desired_radius / current_radius

    return W * scaling_factor, current_radius, scaling_factor

###############################
#  BASIC HISTOGRAM PLOTS
###############################

def plot_and_save_histogram(matrix, title, save_path, bins=50):
    """Save a static PNG histogram of every entry of ``matrix`` using Matplotlib.

    Args:
        matrix: Array with a ``flatten()`` method (e.g. a NumPy array); all of its
            entries are pooled into one histogram.
        title: Figure title.
        save_path: Destination of the PNG file (written at 300 dpi).
        bins: Number of histogram bins.
    """
    fig, ax = plt.subplots()
    ax.hist(matrix.flatten(), bins=bins, color='skyblue', edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Frequency')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_and_save_interactive_histogram(matrix, title, save_path, bins=50):
    """Plot and save an interactive (HTML) histogram of matrix values using Plotly.

    Args:
        matrix: Array with a ``flatten()`` method (e.g. a NumPy array); all of its
            entries are pooled into one histogram.
        title: Plot title. Newline characters are converted to ``<br>`` because
            Plotly renders newlines in titles as spaces, not as line breaks.
        save_path: Destination of the HTML file.
        bins: Requested number of bins (passed to Plotly as ``nbins``).
    """
    data = matrix.flatten()
    fig = px.histogram(data, nbins=bins, title=title.replace("\n", "<br>"),
                       labels={'value': 'Weight Value', 'count': 'Frequency'})
    fig.write_html(save_path)

###############################
#  ADVANCED VISUALIZATIONS
###############################

# 1) Matplotlib 2D Heatmap
def plot_heatmap_matplotlib(W, title, save_path, interpolation='bicubic', cmap='viridis'):
    """
    Save a static PNG heatmap of the 2D weight matrix ``W`` using Matplotlib.

    ``imshow`` places row 0 of ``W`` at the top, with rows along the vertical axis
    and columns along the horizontal axis.

    Args:
        W: 2D array of weights.
        title: Figure title.
        save_path: Destination of the PNG file (300 dpi).
        interpolation: Matplotlib interpolation mode. Smooth modes such as the
            default 'bicubic' blend neighbouring entries; 'nearest' shows each
            weight as a distinct cell.
        cmap: Matplotlib colormap name.
    """
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    image = ax.imshow(W, cmap=cmap, interpolation=interpolation, aspect='auto')
    ax.set_title(title)
    fig.colorbar(image, ax=ax, label='Weight Value')
    ax.set_xlabel('Post-Synaptic Neuron Index')
    ax.set_ylabel('Pre-Synaptic Neuron Index')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

# 2) Interactive Plotly 2D Heatmap
def plot_heatmap_plotly(W, title, save_path):
    """
    Plot an interactive 2D heatmap of the weight matrix W using Plotly.
    Saves an HTML file.

    The y axis is reversed so that row 0 is drawn at the top, matching the
    orientation of the Matplotlib ``imshow`` heatmap produced for the same matrix.

    Args:
        W: 2D array-like of weights (passed as Plotly's ``z``).
        title: Plot title. Newline characters are converted to ``<br>`` because
            Plotly renders newlines in titles as spaces, not as line breaks.
        save_path: Destination of the HTML file.
    """
    fig = go.Figure(data=go.Heatmap(
        z=W,
        colorscale='Viridis',
        colorbar=dict(title='Weight Value')
    ))
    fig.update_layout(
        title=title.replace("\n", "<br>"),
        xaxis_title="Post-Synaptic Index",
        yaxis_title="Pre-Synaptic Index"
    )
    # Plotly draws the first row at the bottom by default; imshow draws it at the top.
    fig.update_yaxes(autorange='reversed')
    fig.write_html(save_path)

# 3) Matplotlib 3D Surface
def plot_3d_surface_matplotlib(W, title, save_path, cmap='viridis'):
    """
    Save a static PNG 3D surface of the 2D weight matrix ``W`` using Matplotlib.

    The surface height at grid point (column j, row i) is ``W[i, j]``.

    Args:
        W: 2D array of weights (must expose ``.shape`` as (rows, cols)).
        title: Figure title.
        save_path: Destination of the PNG file (300 dpi).
        cmap: Matplotlib colormap name.
    """
    # Imported for its side effect of registering the '3d' projection
    # (needed on older Matplotlib releases).
    from mpl_toolkits.mplot3d import Axes3D  # noqa

    n_rows, n_cols = W.shape
    col_idx, row_idx = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300, subplot_kw={'projection': '3d'})
    surface = ax.plot_surface(col_idx, row_idx, W, cmap=cmap, edgecolor='none')
    ax.set_title(title)
    ax.set_xlabel("Post-Synaptic Index")
    ax.set_ylabel("Pre-Synaptic Index")
    ax.set_zlabel("Weight Value")
    fig.colorbar(surface, shrink=0.5, aspect=5)
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

# 4) Interactive Plotly 3D Surface
def plot_3d_surface_plotly(W, title, save_path):
    """
    Plot a 3D surface of the weight matrix W using Plotly.
    Saves an HTML file for interactive exploration.

    Args:
        W: 2D array of weights (must expose ``.shape`` as (rows, cols)); the
            surface height at (column j, row i) is ``W[i, j]``.
        title: Plot title. Newline characters are converted to ``<br>`` because
            Plotly renders newlines in titles as spaces, not as line breaks.
        save_path: Destination of the HTML file.
    """
    rows, cols = W.shape
    X, Y = np.meshgrid(range(cols), range(rows))

    fig = go.Figure(data=[go.Surface(z=W, x=X, y=Y, colorscale='Viridis')])
    fig.update_layout(
        title=title.replace("\n", "<br>"),
        scene=dict(
            xaxis_title='Post-Synaptic Index',
            yaxis_title='Pre-Synaptic Index',
            zaxis_title='Weight Value'
        )
    )
    fig.write_html(save_path)

###############################
#  MAIN CODE
###############################

# PARAMETERS
seed = 42
reservoir_size = 100

# Bounds of the uniform distribution (used by the "uniform" and "sparse" methods).
init_weight_a = 0.0
init_weight_b = 1.0

# Mean / standard deviation of the normal distribution (used by "gaussian").
gaussian_mean = 0.0
gaussian_std = 1.0

# Target spectral radii; one rescaled copy of the matrix is produced per value.
desired_radii = [0.1, 0.5, 1.0, 2.0, 10.0]

# Initialization method: "uniform", "gaussian", or "sparse".
initialization_method = "uniform"

# If using sparse initialization, adjust the connection density below.
# NOTE: despite its name, this value is a DENSITY: each weight is kept with
# probability `sparsity`, so 0.1 keeps ~10% of the weights (~90% are zero).
sparsity = 0.1

# Validate the configuration before anything is written to disk, so a typo does
# not leave behind an empty run folder and a misleading parameters file.
_VALID_INIT_METHODS = ("uniform", "gaussian", "sparse")
if initialization_method not in _VALID_INIT_METHODS:
    raise ValueError("Invalid initialization method selected. Use 'uniform', 'gaussian', or 'sparse'.")
if initialization_method == "sparse" and not 0.0 <= sparsity <= 1.0:
    raise ValueError(f"sparsity (connection density) must be in [0, 1], got {sparsity}")

# Set random seeds for reproducibility.
torch.manual_seed(seed)
np.random.seed(seed)

# Create a unique folder for this run based on seed, initialization method, and timestamp.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
base_folder = '/home/workspaces/polimikel/UCR/Weight_matrices'
if initialization_method == "sparse":
    # The decimal point is dropped from the density, e.g. 0.1 -> "sparse01".
    sparsity_str = str(sparsity).replace('.', '')
    unique_folder_name = f"matrix_seed_{seed}_sparse{sparsity_str}_{timestamp}"
else:
    unique_folder_name = f"matrix_seed_{seed}_{initialization_method}_{timestamp}"
unique_folder_path = os.path.join(base_folder, unique_folder_name)
create_directory(unique_folder_path)

# Save the hyperparameters and initialization parameters in a text file.
# Only the parameters of the distribution that is actually sampled are recorded,
# so the file cannot be misread as describing a different initialization.
params_file = os.path.join(unique_folder_path, "initialization_params.txt")
with open(params_file, "w", encoding="utf-8") as f:
    f.write("Initialization Parameters:\n")
    f.write(f"Seed: {seed}\n")
    f.write(f"Initialization Method: {initialization_method}\n")
    if initialization_method == "sparse":
        f.write(f"Sparse Density: {sparsity}\n")
    f.write(f"Reservoir Size: {reservoir_size}\n")
    if initialization_method in ("uniform", "sparse"):
        # Both methods draw from U(init_weight_a, init_weight_b).
        f.write(f"Uniform Initialization Boundaries: {init_weight_a} to {init_weight_b}\n")
    if initialization_method == "gaussian":
        f.write(f"Gaussian Initialization Parameters: mean={gaussian_mean}, std={gaussian_std}\n")
    f.write(f"Desired Spectral Radii: {desired_radii}\n")
    f.write(f"Timestamp: {timestamp}\n")

############################################
#  Generate the initial connectivity matrix
############################################

W_initial = torch.empty((reservoir_size, reservoir_size))
if initialization_method not in ("uniform", "gaussian", "sparse"):
    raise ValueError("Invalid initialization method selected. Use 'uniform', 'gaussian', or 'sparse'.")

if initialization_method == "gaussian":
    # Dense weights drawn from a normal distribution with the configured mean/std.
    torch.nn.init.normal_(W_initial, mean=gaussian_mean, std=gaussian_std)
else:
    # "uniform" and "sparse" both start from U(init_weight_a, init_weight_b).
    torch.nn.init.uniform_(W_initial, a=init_weight_a, b=init_weight_b)
    if initialization_method == "sparse":
        # Keep each entry with probability `sparsity`; zero the rest.
        keep = torch.rand(W_initial.shape) < sparsity
        W_initial = W_initial * keep.float()

###############################
#  SAVE & PLOT ORIGINAL MATRIX
###############################
original_folder = os.path.join(unique_folder_path, 'original')
create_directory(original_folder)

# Output locations for the un-rescaled matrix and its plots.
original_matrix_path = os.path.join(original_folder, 'W_original.npy')
hist_path_static = os.path.join(original_folder, 'W_original_hist.png')
hist_path_interactive = os.path.join(original_folder, 'W_original_hist_interactive.html')
heatmap_path_static = os.path.join(original_folder, 'W_original_heatmap.png')
heatmap_path_interactive = os.path.join(original_folder, 'W_original_heatmap_interactive.html')
surface_path_static = os.path.join(original_folder, 'W_original_3Dsurface.png')
surface_path_interactive = os.path.join(original_folder, 'W_original_3Dsurface_interactive.html')

# Every artefact below is produced from this single NumPy view of the matrix.
W_original_np = W_initial.numpy()

# Raw matrix first, then each plot kind as a static PNG followed by interactive HTML.
np.save(original_matrix_path, W_original_np)
plot_and_save_histogram(W_original_np, 'Histogram of Original Connectivity Matrix', hist_path_static)
plot_and_save_interactive_histogram(W_original_np, 'Interactive Histogram of Original Connectivity Matrix', hist_path_interactive)
plot_heatmap_matplotlib(W_original_np, "Original Weight Matrix (Heatmap)", heatmap_path_static)
plot_heatmap_plotly(W_original_np, "Original Weight Matrix (Interactive Heatmap)", heatmap_path_interactive)
plot_3d_surface_matplotlib(W_original_np, "Original Weight Matrix (3D Surface)", surface_path_static)
plot_3d_surface_plotly(W_original_np, "Original Weight Matrix (3D Surface, Interactive)", surface_path_interactive)

print(f"Saved original connectivity matrix and plots in: {original_folder}")

###############################
#  RESCALE & PLOT FOR EACH RHO
###############################
# Folder names drop the decimal point (0.1 -> "rho01", 1.0 -> "rho10",
# 10.0 -> "rho100"), so two *different* radii can map to the same folder
# (e.g. 1.0 and the int 10 both give "rho10"). Check this before writing
# anything so one radius can never silently overwrite another's results.
_rho_folder_owner = {}
for _radius in desired_radii:
    _name = f'rho{str(_radius).replace(".", "")}'
    _owner = _rho_folder_owner.setdefault(_name, _radius)
    if _owner != _radius:
        raise ValueError(
            f"Spectral radii {_owner} and {_radius} both map to folder '{_name}'; "
            "their outputs would overwrite each other."
        )

for radius in desired_radii:

    # Rescale the matrix (returns the original spectral radius and the factor used).
    W_rescaled, current_radius, scaling_factor = rescale_matrix(W_initial, radius)

    # Create a subfolder for this spectral radius.
    folder_name = f'rho{str(radius).replace(".", "")}'
    rho_folder_path = os.path.join(unique_folder_path, folder_name)
    create_directory(rho_folder_path)

    # Save the rescaled matrix.
    matrix_save_path = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}.npy')
    np.save(matrix_save_path, W_rescaled.numpy())

    # Create a title for the plots including rescaling details.
    base_title = (f"Rescaled to Spectral Radius = {radius}\n"
                  f"(Original SR: {current_radius:.3f}, Scaling: {scaling_factor:.3f})")

    # HISTOGRAMS
    hist_save_path_static = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}_hist.png')
    plot_and_save_histogram(W_rescaled.numpy(), f"Histogram of W ({base_title})", hist_save_path_static)

    hist_save_path_interactive = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}_hist_interactive.html')
    plot_and_save_interactive_histogram(W_rescaled.numpy(), f"Interactive Histogram of W ({base_title})", hist_save_path_interactive)

    # HEATMAPS (2D)
    heatmap_static_path = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}_heatmap.png')
    plot_heatmap_matplotlib(W_rescaled.numpy(), f"Heatmap of W ({base_title})", heatmap_static_path)

    heatmap_interactive_path = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}_heatmap_interactive.html')
    plot_heatmap_plotly(W_rescaled.numpy(), f"Interactive Heatmap of W ({base_title})", heatmap_interactive_path)

    # 3D SURFACES
    surface_static_path = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}_3Dsurface.png')
    plot_3d_surface_matplotlib(W_rescaled.numpy(), f"3D Surface of W ({base_title})", surface_static_path)

    surface_interactive_path = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}_3Dsurface_interactive.html')
    plot_3d_surface_plotly(W_rescaled.numpy(), f"3D Surface of W ({base_title})", surface_interactive_path)

    print(f"Saved rescaled matrix and plots for spectral radius {radius} in: {rho_folder_path}")


Saved original connectivity matrix and plots in: /home/workspaces/polimikel/UCR/Weight_matrices/matrix_seed_42_uniform_20250305_170543/original
Saved rescaled matrix and plots for spectral radius 0.1 in: /home/workspaces/polimikel/UCR/Weight_matrices/matrix_seed_42_uniform_20250305_170543/rho01
Saved rescaled matrix and plots for spectral radius 0.5 in: /home/workspaces/polimikel/UCR/Weight_matrices/matrix_seed_42_uniform_20250305_170543/rho05
Saved rescaled matrix and plots for spectral radius 1.0 in: /home/workspaces/polimikel/UCR/Weight_matrices/matrix_seed_42_uniform_20250305_170543/rho10
Saved rescaled matrix and plots for spectral radius 2.0 in: /home/workspaces/polimikel/UCR/Weight_matrices/matrix_seed_42_uniform_20250305_170543/rho20
Saved rescaled matrix and plots for spectral radius 10.0 in: /home/workspaces/polimikel/UCR/Weight_matrices/matrix_seed_42_uniform_20250305_170543/rho100
