In [None]:
# read all the permutations files in 'output/CIFAR10/frank_wolfe/history'

import json
import os
import re

from ccmm.matching.utils import load_permutations
from nn_core.common import PROJECT_ROOT

In [None]:
# Load the chosen permutation matrix from every snapshot saved in the Frank-Wolfe history directory.
permutations_dir = os.path.join(PROJECT_ROOT, "output", "CIFAR10", "frank_wolfe", "history")
chosen_perm = "P_bg0"  # which permutation to track across snapshots


def _natural_sort_key(filename):
    """Sort key comparing embedded integers numerically, so "step_2.json" sorts before "step_10.json".

    ``re.split`` with a capturing group alternates non-digit text (even positions) and digit
    runs (odd positions), so two keys always compare str-to-str and int-to-int.
    """
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(re.split(r"(\d+)", filename))]


all_permutations = []
# Plain lexicographic sorting would put unpadded "..._10" before "..._2", scrambling the step order
# that later cells use for their "Step i" titles.
for filename in sorted(os.listdir(permutations_dir), key=_natural_sort_key):
    if filename.endswith(".json"):
        file = os.path.join(permutations_dir, filename)
        permutation = load_permutations(file, matrix_format=True)
        # Only the "a" -> "b" entry for `chosen_perm` is kept from each snapshot.
        all_permutations.append(permutation["a"]["b"][chosen_perm])

assert all_permutations, f"no .json permutation files found in {permutations_dir}"

In [None]:
import matplotlib.pyplot as plt
from ccmm.matching.utils import perm_indices_to_perm_matrix
import numpy as np
import torch

# Plot every loaded permutation snapshot (`all_permutations`, in load order) as a heatmap grid,
# one panel per snapshot, each with its own colorbar.
num_matrices = len(all_permutations)
# Guard: an empty history would give a zero-height figure, which matplotlib rejects.
assert num_matrices > 0, f"no permutations loaded from {permutations_dir}"
ncols = 3  # panels per row; change as needed
nrows = (num_matrices + ncols - 1) // ncols  # ceiling division

# squeeze=False keeps `axes` 2-D even for a single row, so ravel() always works.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)
flat_axes = axes.ravel()

for i, (ax, perm) in enumerate(zip(flat_axes, all_permutations)):
    image = ax.imshow(perm.numpy(), cmap="hot", interpolation="nearest")
    ax.set_title(f"Matrix {i+1}")
    fig.colorbar(image, ax=ax)

# Trailing panels in the last row have no matrix; hide them instead of drawing empty frames.
for ax in flat_axes[num_matrices:]:
    ax.axis("off")

fig.tight_layout()
plt.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import plotly.io as pio

# Publication figure: the first few snapshots of the chosen permutation (initialization followed
# by the first Frank-Wolfe steps) on a shared color scale, shown inline and exported to PDF.
num_matrices = min(6, len(all_permutations))  # at most 6 panels; never more than were loaded
assert num_matrices > 0, "all_permutations is empty; nothing to plot"
ncols = 3
nrows = (num_matrices + ncols - 1) // ncols  # ceiling division

fig = make_subplots(
    rows=nrows,
    cols=ncols,
    # Snapshot 0 is the initialization; snapshot i (i >= 1) is "Step i".
    subplot_titles=["Initialization"] + [f"Step {i}" for i in range(1, num_matrices)],
    vertical_spacing=0.2,
)

fig.update_layout(
    height=300 * nrows,
    width=300 * ncols,
    plot_bgcolor="rgba(255,255,255,1)",  # White plot area (rgba alpha lies in [0, 1])
    margin=dict(l=25, r=25, t=25, b=25),  # Margin around the whole figure
    paper_bgcolor="rgba(255,255,255,1)",  # White background for the paper
    font=dict(size=25, color="black"),  # Font for titles and labels
)
fig.update_annotations(font_size=25)  # subplot titles are annotations with their own font size

# Applies to every subplot. Reversing y draws row 0 of each matrix at the top, matching imshow above.
fig.update_xaxes(showticklabels=False, showgrid=True)
fig.update_yaxes(showticklabels=False, showgrid=True, autorange="reversed")

# Exact zeros -> cornsilk; any strictly positive entry -> red, shading to blue at 1.
colorscale = [[0, "cornsilk"], [1e-8, "red"], [1, "blue"]]

for i, perm in enumerate(all_permutations[:num_matrices]):
    fig.add_trace(
        go.Heatmap(
            z=perm.numpy(),
            colorscale=colorscale,
            # Pin the range so equal values get equal colors in every panel and match the colorbar.
            # NOTE(review): assumes entries lie in [0, 1] (permutation / doubly-stochastic) — confirm.
            zmin=0,
            zmax=1,
            # One shared colorbar, attached to the last real panel. The old dummy z=[[0, 1]] trace
            # was drawn on top of that panel's data.
            showscale=(i == num_matrices - 1),
        ),
        row=i // ncols + 1,
        col=i % ncols + 1,
    )

fig.show()

os.makedirs("figures", exist_ok=True)  # write_image does not create missing directories
pio.write_image(fig, "figures/permutation_matrices.pdf")

In [None]:
# Variant of the figure above: white zeros and a black frame around each panel's plot area.
num_matrices = min(num_matrices, len(all_permutations))  # never index past the loaded history
ncols = 3  # panels per row; change as needed
nrows = (num_matrices + ncols - 1) // ncols  # ceiling division

fig = make_subplots(
    rows=nrows,
    cols=ncols,
    # Snapshot 0 is the initialization; snapshot i (i >= 1) is "Step i".
    subplot_titles=["Initialization"] + [f"Step {i}" for i in range(1, num_matrices)],
)

# Exact zeros -> white; any strictly positive entry -> red, shading to blue at 1.
colorscale = [[0, "white"], [1e-8, "red"], [1, "blue"]]

for i, matrix in enumerate(all_permutations[:num_matrices]):
    row = i // ncols + 1
    col = i % ncols + 1
    fig.add_trace(
        # Fixed range so colors are comparable across panels.
        # NOTE(review): assumes entries in [0, 1] — confirm.
        go.Heatmap(z=matrix.numpy(), colorscale=colorscale, zmin=0, zmax=1, showscale=False),
        row=row,
        col=col,
    )
    # Frame this subplot with mirrored axis lines. They follow the real subplot domain, whereas
    # hand-computed paper coordinates ignore make_subplots' default spacing and drift off the panels.
    fig.update_xaxes(showline=True, mirror=True, linecolor="black", linewidth=2, row=row, col=col)
    fig.update_yaxes(showline=True, mirror=True, linecolor="black", linewidth=2, row=row, col=col)

fig.update_layout(
    height=300 * nrows,
    width=300 * ncols,
    plot_bgcolor="rgba(255,255,255,1)",  # White plot area (rgba alpha lies in [0, 1])
    paper_bgcolor="rgba(255,255,255,1)",  # White background for the paper
    font=dict(size=25, color="black"),  # Font for titles and labels
)

# Applies to every subplot. Reversing y draws row 0 of each matrix at the top, matching imshow.
fig.update_xaxes(showticklabels=False, showgrid=True)
fig.update_yaxes(showticklabels=False, showgrid=True, autorange="reversed")

fig.show()