## Generation of Weight Matrices

In [None]:
import os
import re
import numpy as np
import torch
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import datetime
import seaborn as sns
import networkx as nx
from sklearn.cluster import KMeans

# For Ray Tune and TensorBoard logging
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from torch.utils.tensorboard import SummaryWriter
import pandas as pd

#############################################
#  CREATE FOLDERS & UTILITY FUNCTIONS
#############################################

def create_directory(dir_path):
    """Create ``dir_path`` (and any missing parent folders) if it does not exist.

    The call is idempotent: an already existing directory is left untouched.
    ``exist_ok=True`` replaces a separate existence check, so the function does
    not fail when the directory is created by someone else between the check
    and the creation.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

def rescale_matrix(W, desired_radius):
    """
    Scale W uniformly so that its spectral radius becomes desired_radius.

    The spectral radius is taken as the largest eigenvalue magnitude of W.
    If that magnitude is exactly zero, no rescaling is possible and a factor
    of 1.0 is applied instead.

    Returns:
        W_rescaled: W multiplied by the scaling factor,
        current_radius: spectral radius of the input W (Python float),
        scaling_factor: the multiplier that was applied.
    """
    eigen_magnitudes = torch.linalg.eigvals(W).abs()
    current_radius = eigen_magnitudes.max().item()
    if current_radius != 0:
        scaling_factor = desired_radius / current_radius
    else:
        scaling_factor = 1.0
    return W * scaling_factor, current_radius, scaling_factor

#############################################
#  BASIC HISTOGRAM PLOTS (Matplotlib & Plotly)
#############################################

def plot_and_save_histogram(matrix, title, save_path, bins=50):
    """Draw a Matplotlib histogram of all entries of ``matrix`` and write it to ``save_path`` at 300 dpi."""
    fig, ax = plt.subplots()
    ax.hist(matrix.flatten(), bins=bins, color='skyblue', edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Frequency')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_and_save_interactive_histogram(matrix, title, save_path, bins=50):
    """Build a Plotly histogram of all entries of ``matrix`` and write it to ``save_path`` as HTML."""
    axis_labels = {'value': 'Weight Value', 'count': 'Frequency'}
    figure = px.histogram(matrix.flatten(), nbins=bins, title=title, labels=axis_labels)
    figure.write_html(save_path)

#############################################
#  ADDITIONAL VISUALIZATION FUNCTIONS
#############################################

def plot_sorted_bar(matrix, title, save_path):
    """Save a bar chart of every entry of ``matrix``, ordered from smallest to largest."""
    sorted_weights = np.sort(matrix.flatten())
    positions = range(len(sorted_weights))

    fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
    ax.bar(positions, sorted_weights, color='mediumpurple')
    ax.set_title(title)
    ax.set_xlabel('Sorted Index')
    ax.set_ylabel('Weight Value')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_violin(matrix, title, save_path):
    """Save a Seaborn violin plot of all entries of ``matrix`` as a PNG at ``save_path``."""
    weights = matrix.flatten()
    fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
    sns.violinplot(data=weights, color='lightgreen', ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Weights')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_kde(matrix, title, save_path):
    """Save a kernel-density plot of the entries of ``matrix``.

    When all entries are identical (variance exactly 0) no density is
    estimated; a vertical marker from 0 to 1 is drawn at that single value
    instead and the title notes the zero variance.
    """
    values = matrix.flatten()
    fig, ax = plt.subplots(figsize=(6, 4), dpi=300)

    has_spread = np.var(values) != 0
    if has_spread:
        sns.kdeplot(values, fill=True, color='teal', warn_singular=False, ax=ax)
        ax.set_title(title)
    else:
        # Degenerate distribution: mark the single value that occurs.
        constant = values[0]
        ax.plot([constant, constant], [0, 1], color='teal')
        ax.set_title(title + "\n(Data has zero variance)")

    # Axis labels are the same in both cases.
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Density')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_adjacency_graph(matrix, title, save_path, threshold=0.01):
    """
    Draw ``matrix`` as a directed graph (intended for sparse matrices).

    Edges whose absolute weight is below ``threshold`` are removed before
    drawing. Node positions come from a spring layout with a fixed seed (42).
    """
    graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph())

    # Collect the weak edges first, then drop them in one call.
    weak_edges = [
        (source, target)
        for source, target, weight in graph.edges(data='weight')
        if abs(weight) < threshold
    ]
    graph.remove_edges_from(weak_edges)

    plt.figure(figsize=(8, 6), dpi=300)
    layout = nx.spring_layout(graph, seed=42)
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_size=300,
        node_color='lightblue',
        edge_color='gray',
    )
    plt.title(title)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()

def plot_clustered_heatmap(matrix, title, save_path, n_clusters=4):
    """
    Save a heatmap of ``matrix`` with its rows grouped by KMeans cluster.

    The number of clusters is capped at the number of distinct rows. Rows are
    reordered by ascending cluster label before plotting; columns keep their
    original order.
    """
    distinct_row_count = len(np.unique(matrix, axis=0))
    effective_clusters = min(n_clusters, distinct_row_count)

    model = KMeans(n_clusters=effective_clusters, random_state=42)
    cluster_labels = model.fit(matrix).labels_
    row_order = np.argsort(cluster_labels)
    reordered = matrix[row_order, :]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    sns.heatmap(reordered, cmap='viridis', ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Columns')
    ax.set_ylabel('Rows (clustered)')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

#############################################
#  ADVANCED VISUALIZATIONS (Heatmaps & 3D Surfaces)
#############################################

def plot_heatmap_matplotlib(W, title, save_path, interpolation='bicubic', cmap='viridis'):
    """
    Save a static 2D heatmap (PNG) of the weight matrix W.

    Rows are labelled as pre-synaptic and columns as post-synaptic neuron
    indices. ``interpolation`` and ``cmap`` are passed to ``imshow``.
    """
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    image = ax.imshow(W, cmap=cmap, interpolation=interpolation, aspect='auto')
    ax.set_title(title)
    fig.colorbar(image, ax=ax, label='Weight Value')
    ax.set_xlabel('Post-Synaptic Neuron Index')
    ax.set_ylabel('Pre-Synaptic Neuron Index')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_heatmap_plotly(W, title, save_path):
    """
    Save an interactive 2D heatmap (HTML) of the weight matrix W using Plotly.
    """
    heatmap_trace = go.Heatmap(
        z=W,
        colorscale='Viridis',
        colorbar={'title': 'Weight Value'},
    )
    figure = go.Figure(data=heatmap_trace)
    figure.update_layout(
        title=title,
        xaxis_title="Post-Synaptic Index",
        yaxis_title="Pre-Synaptic Index",
    )
    figure.write_html(save_path)

def plot_3d_surface_matplotlib(W, title, save_path, cmap='viridis'):
    """
    Save a static 3D surface plot (PNG) of the weight matrix W.

    The x axis runs over column indices, the y axis over row indices and the
    height is the weight value. ``tight_layout()`` is deliberately not called
    for this 3D figure.
    """
    # Kept from the original; presumably needed so the '3d' projection is
    # registered on older Matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa

    n_rows, n_cols = W.shape
    col_grid, row_grid = np.meshgrid(range(n_cols), range(n_rows))

    fig = plt.figure(figsize=(8, 6), dpi=300)
    ax3d = fig.add_subplot(111, projection='3d')
    surface = ax3d.plot_surface(col_grid, row_grid, W, cmap=cmap, edgecolor='none')
    ax3d.set_title(title)
    ax3d.set_xlabel("Post-Synaptic Index")
    ax3d.set_ylabel("Pre-Synaptic Index")
    ax3d.set_zlabel("Weight Value")
    fig.colorbar(surface, shrink=0.5, aspect=5)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_3d_surface_plotly(W, title, save_path):
    """
    Save an interactive 3D surface plot (HTML) of the weight matrix W using Plotly.
    """
    n_rows, n_cols = W.shape
    col_grid, row_grid = np.meshgrid(range(n_cols), range(n_rows))

    surface_trace = go.Surface(z=W, x=col_grid, y=row_grid, colorscale='Viridis')
    figure = go.Figure(data=[surface_trace])

    axis_titles = dict(
        xaxis_title='Post-Synaptic Index',
        yaxis_title='Pre-Synaptic Index',
        zaxis_title='Weight Value',
    )
    figure.update_layout(title=title, scene=axis_titles)
    figure.write_html(save_path)

#############################################
#  HELPER FUNCTION TO LOAD A MATRIX
#############################################

def load_matrix(matrix_path):
    """Read the array stored in the ``.npy`` file at ``matrix_path`` and return it."""
    loaded = np.load(matrix_path)
    return loaded

#############################################
#  MAIN CODE: GENERATE, SAVE & VISUALIZE WEIGHT MATRIX
#############################################

# ---------------- Configuration ----------------
seed = 42
reservoir_size = 100

# Bounds of the uniform distribution ("uniform" and the non-zero part of "sparse").
init_weight_a = -1.0
init_weight_b = 1.0

# Parameters of the normal distribution ("gaussian").
gaussian_mean = 1.0
gaussian_std = 1.0

# One rescaled copy of the matrix is produced per target spectral radius.
desired_radii = [0.1, 0.5, 1.0, 2.0, 10.0]

# Initialization method: "uniform", "gaussian", or "sparse".
initialization_method = "uniform"
sparsity = 0.1  # NOTE: fraction of entries kept; 0.1 means ~90% of weights are zero

# Seed both libraries so the generated matrix is reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

# ---------------- Output folder (unique per run via timestamp) ----------------
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
base_folder = '/home/workspaces/polimikel/UCR/Weight_matrices'
if initialization_method != "sparse":
    unique_folder_name = f"matrix_seed_{seed}_{initialization_method}_{timestamp}"
else:
    # The dot is replaced so the sparsity value is safe inside a folder name.
    sparsity_str = str(sparsity).replace('.', 'x')
    unique_folder_name = f"matrix_seed_{seed}_sparse{sparsity_str}_{timestamp}"
unique_folder_path = os.path.join(base_folder, unique_folder_name)
create_directory(unique_folder_path)

# ---------------- Record the run parameters next to the results ----------------
params_file = os.path.join(unique_folder_path, "initialization_params.txt")
param_lines = [
    "Initialization Parameters:\n",
    f"Seed: {seed}\n",
    f"Initialization Method: {initialization_method}\n",
]
if initialization_method == "sparse":
    param_lines.append(f"Sparse Density: {sparsity}\n")
param_lines.append(f"Reservoir Size: {reservoir_size}\n")
param_lines.append(f"Uniform Boundaries: {init_weight_a} to {init_weight_b}\n")
if initialization_method == "gaussian":
    param_lines.append(f"Gaussian Parameters: mean={gaussian_mean}, std={gaussian_std}\n")
param_lines.append(f"Desired Spectral Radii: {desired_radii}\n")
param_lines.append(f"Timestamp: {timestamp}\n")
with open(params_file, "w") as f:
    f.writelines(param_lines)

# ---------------- Draw the initial matrix ----------------
W_initial = torch.empty((reservoir_size, reservoir_size))
if initialization_method == "gaussian":
    torch.nn.init.normal_(W_initial, mean=gaussian_mean, std=gaussian_std)
elif initialization_method in ("uniform", "sparse"):
    torch.nn.init.uniform_(W_initial, a=init_weight_a, b=init_weight_b)
    if initialization_method == "sparse":
        # Keep each entry with probability `sparsity`, zero the rest.
        mask = (torch.rand(W_initial.shape) < sparsity).float()
        W_initial = W_initial * mask
else:
    raise ValueError("Invalid initialization method selected. Choose 'uniform', 'gaussian', or 'sparse'.")

# ---------------- Save and visualize the unscaled matrix ----------------
original_folder = os.path.join(unique_folder_path, 'original')
create_directory(original_folder)
original_matrix_path = os.path.join(original_folder, 'W_original.npy')
W_initial_np = W_initial.numpy()
np.save(original_matrix_path, W_initial_np)

# (plot function, title, output file name, extra keyword arguments)
original_plots = [
    (plot_and_save_histogram, 'Histogram of Original Connectivity Matrix',
     'W_original_hist.png', {}),
    (plot_and_save_interactive_histogram, 'Interactive Histogram of Original Connectivity Matrix',
     'W_original_hist_interactive.html', {}),
    (plot_heatmap_matplotlib, "Original Weight Matrix (Heatmap)",
     'W_original_heatmap.png', {}),
    (plot_heatmap_plotly, "Original Weight Matrix (Interactive Heatmap)",
     'W_original_heatmap_interactive.html', {}),
    (plot_3d_surface_matplotlib, "Original Weight Matrix (3D Surface)",
     'W_original_3Dsurface.png', {}),
    (plot_3d_surface_plotly, "Original Weight Matrix (3D Surface, Interactive)",
     'W_original_3Dsurface_interactive.html', {}),
    (plot_sorted_bar, "Sorted Bar Plot of Original Weights",
     'W_original_sorted_bar.png', {}),
    (plot_violin, "Violin Plot of Original Weights",
     'W_original_violin.png', {}),
    (plot_kde, "KDE Plot of Original Weights",
     'W_original_KDE.png', {}),
    (plot_adjacency_graph, "Adjacency Graph of Original Weights",
     'W_original_adjacency.png', {'threshold': 0.05}),
    (plot_clustered_heatmap, "Clustered Heatmap of Original Weights",
     'W_original_clustered_heatmap.png', {}),
]
for plot_fn, plot_title, file_name, extra_kwargs in original_plots:
    plot_fn(W_initial_np, plot_title, os.path.join(original_folder, file_name), **extra_kwargs)

print(f"Saved original connectivity matrix and all plots in: {original_folder}")

# ---------------- One rescaled matrix (plus plots) per target spectral radius ----------------
for radius in desired_radii:
    W_rescaled, current_radius, scaling_factor = rescale_matrix(W_initial, radius)
    # Folder names use 'x' in place of the decimal point; file names keep the dot.
    folder_name = f'rho{str(radius).replace(".", "x")}'
    rho_folder_path = os.path.join(unique_folder_path, folder_name)
    create_directory(rho_folder_path)
    matrix_save_path = os.path.join(rho_folder_path, f'W_rescaled_rho{radius}.npy')
    W_rescaled_np = W_rescaled.numpy()
    np.save(matrix_save_path, W_rescaled_np)
    base_title = (f"Rescaled to Spectral Radius = {radius}\n"
                  f"(Orig. SR: {current_radius:.3f}, Scaling: {scaling_factor:.3f})")

    rescaled_plots = [
        (plot_and_save_histogram, f"Histogram of W ({base_title})",
         f'W_rescaled_rho{radius}_hist.png', {}),
        (plot_and_save_interactive_histogram, f"Interactive Histogram of W ({base_title})",
         f'W_rescaled_rho{radius}_hist_interactive.html', {}),
        (plot_heatmap_matplotlib, f"Heatmap of W ({base_title})",
         f'W_rescaled_rho{radius}_heatmap.png', {}),
        (plot_heatmap_plotly, f"Interactive Heatmap of W ({base_title})",
         f'W_rescaled_rho{radius}_heatmap_interactive.html', {}),
        (plot_3d_surface_matplotlib, f"3D Surface of W ({base_title})",
         f'W_rescaled_rho{radius}_3Dsurface.png', {}),
        (plot_3d_surface_plotly, f"3D Surface of W ({base_title})",
         f'W_rescaled_rho{radius}_3Dsurface_interactive.html', {}),
        (plot_sorted_bar, f"Sorted Bar Plot of W ({base_title})",
         f'W_rescaled_rho{radius}_sorted_bar.png', {}),
        (plot_violin, f"Violin Plot of W ({base_title})",
         f'W_rescaled_rho{radius}_violin.png', {}),
        (plot_kde, f"KDE Plot of W ({base_title})",
         f'W_rescaled_rho{radius}_KDE.png', {}),
        (plot_adjacency_graph, f"Adjacency Graph of W ({base_title})",
         f'W_rescaled_rho{radius}_adjacency.png', {'threshold': 0.05}),
        (plot_clustered_heatmap, f"Clustered Heatmap of W ({base_title})",
         f'W_rescaled_rho{radius}_clustered_heatmap.png', {}),
    ]
    for plot_fn, plot_title, file_name, extra_kwargs in rescaled_plots:
        plot_fn(W_rescaled_np, plot_title, os.path.join(rho_folder_path, file_name), **extra_kwargs)

    print(f"Saved rescaled matrix and plots for spectral radius {radius} in: {rho_folder_path}")


## 80/20 Excitatory/Inhibitory Weight Matrix

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

def rescale_matrix(W, desired_radius):
    """
    Scale W uniformly so that its spectral radius becomes desired_radius.

    Parameters:
        W (np.ndarray): Square weight matrix.
        desired_radius (float): Target spectral radius.

    Returns:
        W_rescaled (np.ndarray): W multiplied by the scaling factor.
        current_radius (float): Largest eigenvalue magnitude of the input W.
        scaling_factor (float): Multiplier that was applied; 1.0 when the
            spectral radius of W is exactly zero and no rescaling is possible.
    """
    current_radius = np.abs(np.linalg.eigvals(W)).max()
    if current_radius == 0:
        scaling_factor = 1.0
    else:
        scaling_factor = desired_radius / current_radius
    return W * scaling_factor, current_radius, scaling_factor

def generate_weight_matrix(N, sparsity, exc_ratio=0.8, seed=0, weight_scale=1.0, 
                           shuffle_neuron_types=True, desired_radius=None,
                           dales_principle=False):
    """
    Generate a random sparse weight matrix with a mix of positive and negative weights.

    Construction:
      - Each potential connection exists with probability 'sparsity'.
      - Weight magnitudes are drawn uniformly from [0, weight_scale).
      - Signs are assigned in one of two ways:
          * dales_principle=False (default): every connection independently gets
            sign +1 with probability exc_ratio and -1 otherwise.
          * dales_principle=True: int(N * exc_ratio) neurons are excitatory and the
            rest inhibitory; all weights in a neuron's row share that neuron's sign.
      - Optionally, the matrix is rescaled to have the desired spectral radius.

    Side effects:
        Re-seeds NumPy's global random number generator with 'seed'.
        Prints the rescaling details when desired_radius is given.

    Parameters:
        N (int): Total number of neurons.
        sparsity (float): Probability (0 to 1) that a connection exists.
        exc_ratio (float): Fraction of excitatory neurons / positive weights (default is 0.8).
        seed (int): Random seed for reproducibility.
        weight_scale (float): Upper bound of the random weight magnitudes.
        shuffle_neuron_types (bool): If True, randomly shuffle which neurons are
            excitatory/inhibitory. Only affects the signs when dales_principle=True.
        desired_radius (float or None): If provided, rescales the matrix to this spectral radius.
        dales_principle (bool): If True, use row-based signs as described above.

    Returns:
        np.ndarray: The generated (and possibly rescaled) weight matrix of shape (N, N).
    """
    np.random.seed(seed)
    num_exc = int(N * exc_ratio)
    num_inh = N - num_exc

    # Define neuron types: excitatory (+1) and inhibitory (-1).
    # The shuffle is performed even when the types are not used for the signs:
    # it consumes random numbers, so skipping it would change the matrices
    # produced for existing seeds.
    neuron_types = np.concatenate((np.ones(num_exc), -np.ones(num_inh)))
    if shuffle_neuron_types:
        np.random.shuffle(neuron_types)

    # Create the connectivity mask (True where a connection exists)
    mask = np.random.rand(N, N) < sparsity

    # Create random weight magnitudes
    weights = np.random.uniform(0, weight_scale, size=(N, N))

    if dales_principle:
        # Row-based signs: every weight in row i carries the sign of neuron i.
        W = (mask * weights) * neuron_types[:, np.newaxis]
    else:
        # Per-connection signs: +1 with probability exc_ratio, -1 otherwise.
        random_signs = np.where(np.random.rand(N, N) < exc_ratio, 1, -1)
        W = mask * weights * random_signs

    # Optionally, rescale the matrix to achieve the desired spectral radius
    if desired_radius is not None:
        W, current_radius, scaling_factor = rescale_matrix(W, desired_radius)
        print(f"Seed {seed}: Original spectral radius = {current_radius:.3f}, "
              f"scaling factor = {scaling_factor:.3f}, "
              f"desired spectral radius = {desired_radius}")

    return W

def generate_multiple_weight_matrices(N, sparsity, seeds, exc_ratio=0.8, weight_scale=1.0, 
                                      output_dir="weights", shuffle_neuron_types=True,
                                      desired_radius=None):
    """
    Generate one weight matrix per seed and save each as a .npy file.

    Every matrix is produced by generate_weight_matrix and written to
    output_dir as "weight_matrix_seed_<seed>.npy". An existing file with the
    same name is overwritten.

    Parameters:
        N (int): Total number of neurons.
        sparsity (float): Connection probability (0 to 1).
        seeds (list or array-like): Seeds for generating different weight matrices.
        exc_ratio (float): Fraction of excitatory neurons (default: 0.8).
        weight_scale (float): Scale factor for random weight magnitudes.
        output_dir (str): Directory in which to save the weight matrices.
        shuffle_neuron_types (bool): If True, randomly shuffle neuron types before creating the matrix.
        desired_radius (float or None): If provided, rescales each matrix to this spectral radius.

    Returns:
        dict: A dictionary mapping seed values to the corresponding saved file path.
    """
    # exist_ok avoids a separate existence check (and the race that comes with it).
    os.makedirs(output_dir, exist_ok=True)

    weight_files = {}
    for seed in seeds:
        W = generate_weight_matrix(N, sparsity, exc_ratio=exc_ratio, seed=seed,
                                   weight_scale=weight_scale,
                                   shuffle_neuron_types=shuffle_neuron_types,
                                   desired_radius=desired_radius)
        filepath = os.path.join(output_dir, f"weight_matrix_seed_{seed}.npy")
        np.save(filepath, W)
        weight_files[seed] = filepath
        print(f"Saved weight matrix for seed {seed} to {filepath}")

    return weight_files

# --------------------- Plotting Functions ---------------------
def plot_weight_histogram(W, bins=50, title="Histogram of Weight Values", show=True, save_path=None):
    """Plot a histogram of all entries of W; optionally save it and/or display it."""
    fig, ax = plt.subplots()
    ax.hist(W.flatten(), bins=bins, color='blue', edgecolor='black')
    ax.set_xlabel("Weight Value")
    ax.set_ylabel("Frequency")
    ax.set_title(title)
    # Save before showing so the file is written regardless of the display step.
    if save_path:
        fig.savefig(save_path)
    if show:
        plt.show()
    plt.close()

def plot_heatmap(W, title="Heatmap of Weight Matrix", show=True, save_path=None):
    """Plot W as a heatmap (rows: pre-synaptic, columns: post-synaptic); optionally save and/or display it."""
    fig, ax = plt.subplots()
    image = ax.imshow(W, interpolation="nearest", aspect="auto", cmap="viridis")
    fig.colorbar(image, ax=ax, label="Weight Value")
    ax.set_xlabel("Post-synaptic neuron")
    ax.set_ylabel("Pre-synaptic neuron")
    ax.set_title(title)
    if save_path:
        fig.savefig(save_path)
    if show:
        plt.show()
    plt.close()

from matplotlib.colors import ListedColormap

def plot_connectivity_heatmap(W, title="Connectivity Pattern Heatmap", show=True, save_path=None):
    """
    Plots a connectivity pattern as a heatmap.
    If there is a connection (nonzero entry in W), the pixel is colored blue;
    if there is no connection, the pixel is white.

    Parameters:
        W (np.ndarray): The weight matrix.
        title (str): Title for the plot.
        show (bool): Whether to display the plot immediately.
        save_path (str): Path to save the figure (if provided).
    """
    # Create a binary mask: 1 where a connection exists and 0 where it does not.
    connection_mask = (W != 0).astype(float)

    # Define a colormap: white for 0 (no connection) and blue for 1 (connection exists)
    cmap = ListedColormap(["white", "blue"])

    plt.figure()
    # Pin the color limits to [0, 1]. With automatic scaling, a mask that is
    # all ones (fully connected) or all zeros has vmin == vmax and is drawn
    # entirely in the first color, i.e. a fully connected matrix would be
    # shown as "no connection".
    plt.imshow(connection_mask, interpolation="nearest", aspect="auto", cmap=cmap,
               vmin=0, vmax=1)
    plt.colorbar(ticks=[0, 1], label="Connection (0: no, 1: yes)")
    plt.xlabel("Post-synaptic neuron")
    plt.ylabel("Pre-synaptic neuron")
    plt.title(title)
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    plt.close()


def plot_row_weight_stats(W, title="Row Weight Statistics", show=True, save_path=None):
    """Plot the mean of every row of W with its standard deviation as error bars."""
    neuron_indices = np.arange(W.shape[0])
    per_row_mean = np.mean(W, axis=1)
    per_row_std = np.std(W, axis=1)

    fig, ax = plt.subplots(figsize=(10,5))
    ax.errorbar(neuron_indices, per_row_mean, yerr=per_row_std, fmt='o', markersize=3, capsize=3)
    ax.set_xlabel("Neuron index")
    ax.set_ylabel("Mean Weight (± std)")
    ax.set_title(title)
    if save_path:
        fig.savefig(save_path)
    if show:
        plt.show()
    plt.close()

def plot_all_weight_matrix_visualizations(W, base_title="Weight Matrix Visualization", show=True, save_dir=None):
    """
    Produce the standard pair of plots for W: a weight histogram and a
    zero-centered heatmap.

    When save_dir is given, it is created if missing and the plots are written
    there as "histogram.png" and "spy_plot.png"; otherwise nothing is saved.
    """
    if save_dir is not None and not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Resolve both output paths up front; None disables saving in the helpers.
    histogram_path = None
    heatmap_path = None
    if save_dir:
        histogram_path = os.path.join(save_dir, "histogram.png")
        heatmap_path = os.path.join(save_dir, "spy_plot.png")

    plot_weight_histogram(W, title=base_title + " - Histogram",
                          show=show, save_path=histogram_path)
    plot_heatmap_centered(W, title=base_title + " - Connectivity Pattern",
                          show=show, save_path=heatmap_path)



import numpy as np
import matplotlib.pyplot as plt

def plot_heatmap_centered(W, title="Centered Heatmap of Weight Matrix", show=True, save_path=None):
    """
    Plot W as a heatmap whose color scale is symmetric around zero.

    The color limits are -max|W| and +max|W|, so a weight of 0 always maps to
    the neutral middle color of the diverging 'bwr' colormap.

    Parameters:
        W (np.ndarray): The weight matrix.
        title (str): Title for the plot.
        show (bool): Whether to display the plot immediately.
        save_path (str): If provided, saves the plot to this path.
    """
    fig, ax = plt.subplots()

    # Largest absolute weight defines the symmetric color range.
    limit = np.abs(W).max()

    image = ax.imshow(W, interpolation="nearest", aspect="auto", cmap="bwr",
                      vmin=-limit, vmax=limit)
    fig.colorbar(image, ax=ax, label="Weight Value")
    ax.set_xlabel("Post-synaptic neuron")
    ax.set_ylabel("Pre-synaptic neuron")
    ax.set_title(title)

    if save_path:
        fig.savefig(save_path)
    if show:
        plt.show()
    plt.close()

# --------------------- Example Usage ---------------------
# Example driver: generate 80/20 weight matrices (single seed, or a sweep over
# seeds and sparsities), save them and plot them.
if __name__ == "__main__":
    # Parameters for the weight matrix generation
    N = 100              # number of neurons
    exc_ratio = 0.8      # 80% excitatory, 20% inhibitory
    weight_scale = 1.0   # maximum magnitude of weight values

    # Toggle for single or multiple seeds
    multiple_seeds = True

    # SPECTRAL RADIUS DEFINED HERE
    # Desired spectral radius; can be set to None to skip scaling.
    desired_radius = 0.1

    if not multiple_seeds:
        seed = 1
        sparsity = 0.1
        W = generate_weight_matrix(N, sparsity, exc_ratio, seed, weight_scale, 
                                   shuffle_neuron_types=True, desired_radius=desired_radius)
        plot_weight_histogram(W, title=f"Histogram of Weights for Seed {seed}")
        #plot_heatmap(W, title=f"Heatmap of Weight Matrix for Seed {seed}")
        #plot_connectivity_heatmap(W, title=f"Connectivity Pattern for Seed {seed}")
        plot_heatmap_centered(W, title="Centered Heatmap Example")
    else:
        seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        sparsities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

        for sparsity in sparsities:
            weight_files = generate_multiple_weight_matrices(N, sparsity, seeds, exc_ratio, weight_scale,
                                                            output_dir=f"80_20_weights_sparsity_{sparsity}_rho{desired_radius}",
                                                            shuffle_neuron_types=True,
                                                            desired_radius=desired_radius)
            # Plots go into a folder that is specific to this sparsity and
            # spectral radius. A folder keyed by seed alone would be reused by
            # every sparsity, so each pass would overwrite the previous plots.
            sparsity_plot_dir = os.path.join("plots", f"sparsity_{sparsity}_rho{desired_radius}")
            for seed in seeds:
                W = np.load(weight_files[seed])
                seed_plot_dir = os.path.join(sparsity_plot_dir, f"seed_{seed}")
                plot_all_weight_matrix_visualizations(W, base_title=f"Weight Matrix for Seed {seed} Visualization",
                                                    show=True, save_dir=seed_plot_dir)

## nn.Linear weights generator

In [82]:
import numpy as np

# Define reservoir size.
reservoir_size = 100
seed = 42

# Seed NumPy's global RNG so the vector stored under the seed-named file can
# be regenerated; without this the seed in the filename has no effect.
np.random.seed(seed)

# Generate a random vector uniformly distributed in [0, 1) with shape (reservoir_size, 1)
input_weights = np.random.uniform(0, 1, size=(reservoir_size, 1))

# Save the vector to a .npy file.
output_file = f"nnLinear_weights_seed{seed}.npy"
np.save(output_file, input_weights)

# Report the file that was actually written.
print("Input weights vector of shape", input_weights.shape, "saved as", output_file)

Input weights vector of shape (100, 1) saved as input_weights.npy


In [85]:
# Read the stored input-weight vector back from disk, then print its shape and
# its first five rows as a quick sanity check.
seed = 42
# Alternative location (per-seed file):
#loaded_weights = np.load(f"/Users/mikel/Documents/GitHub/polimikel/UCR/Weight_matrices/nnLinear_weights_10seeds/nnLinear_weights_seed{seed}.npy")
weights_file = "/Users/mikel/Documents/GitHub/polimikel/UCR/Weight_matrices/nnLinear_weights.npy"
loaded_weights = np.load(weights_file)
print("Loaded weights vector shape:", loaded_weights.shape)
print("Sample values from loaded weights vector:", loaded_weights[:5])

Loaded weights vector shape: (100, 1)
Sample values from loaded weights vector: [[0.286081  ]
 [0.55498837]
 [0.52608075]
 [0.25873068]
 [0.02077862]]


## Eigenvalue Plotting

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

# -------------------------------
# Configure the file path
# -------------------------------
# Replace with the full path to your .npy file containing the matrix W.
file_path = "/Users/mikel/Documents/GitHub/polimikel/UCR/Weight_matrices/neg_pos/matrix_seed_42_uniform_20250313_091411/rho0x1/W_rescaled_rho0.1.npy"

# -------------------------------
# Load the matrix from the file based on its extension
# -------------------------------
if file_path.endswith('.npy'):
    # Load from a NumPy binary file
    W_np = np.load(file_path)
    W = torch.tensor(W_np)
elif file_path.endswith('.py'):
    # NOTE(review): this branch executes the target file as Python code;
    # only point it at files you trust.
    import importlib.util
    spec = importlib.util.spec_from_file_location("loaded_module", file_path)
    if spec is None or spec.loader is None:
        raise FileNotFoundError(f"Could not load the module from file path: {file_path}.")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    if not hasattr(mod, "W"):
        raise AttributeError("The module does not contain an attribute named 'W'.")
    # The module may define W as a NumPy array or nested list; convert it so
    # the eigenvalue computation below always receives a tensor.
    W = torch.as_tensor(mod.W)
else:
    raise ValueError("Unsupported file extension. Please use a .npy or .py file.")

# torch.linalg.eigvals only accepts floating-point or complex input, so
# integer/boolean matrices are converted to the default float dtype.
if not (W.is_floating_point() or W.is_complex()):
    W = W.to(torch.get_default_dtype())

# -------------------------------
# Compute Eigenvalues: lambda = a + i*b where a = real part, b = imaginary part
# -------------------------------
eigenvalues = torch.linalg.eigvals(W)
eigenvalues_np = eigenvalues.detach().cpu().numpy()

# Get the directory where the file is located; the plots are saved next to it.
dir_path = os.path.dirname(file_path)

# -------------------------------
# Plot 1: Eigenvalues in the Complex Plane
# -------------------------------
plt.figure(figsize=(6,6))
plt.scatter(eigenvalues_np.real, eigenvalues_np.imag, color='blue', alpha=0.7)
plt.xlabel("Real Part")
plt.ylabel("Imaginary Part")
plt.title("Eigenvalues of W in the Complex Plane")
plt.grid(True)
plt.axhline(0, color='black', linewidth=0.5)
plt.axvline(0, color='black', linewidth=0.5)

# Save the complex plane plot
plot1_path = os.path.join(dir_path, "eigenvalues_complex_plane.png")
plt.savefig(plot1_path)
plt.show()

# -------------------------------
# Plot 2: Histogram of Eigenvalue Magnitudes, i.e. |lambda| = sqrt(a^2 + b^2)
# -------------------------------
magnitudes = np.abs(eigenvalues_np)
plt.figure()
plt.hist(magnitudes, bins=20, color='green', alpha=0.7)
plt.xlabel("Magnitude")
plt.ylabel("Frequency")
plt.title("Histogram of Eigenvalue Magnitudes")

# Save the histogram plot
plot2_path = os.path.join(dir_path, "eigenvalues_magnitude_histogram.png")
plt.savefig(plot2_path)
plt.show()

print(f"Plots saved to:\n{plot1_path}\n{plot2_path}")