In [2]:
from datasets import Dataset
import torch

In [None]:
# Cached activations of transformer_blocks.11.attn2 for the "voice" prompts.
# Other dumps that were previously loaded here instead:
#   activations/musiccaps_voice_along_freqs/stable-audio-open-1.0/transformer_blocks.11.attn2
#   activations/musiccaps_voice.csv/stable-audio-open-1.0/transformer_blocks.11.attn2
ds_voice = Dataset.load_from_disk(
    "activations/musiccaps_voice_along_time/"
    "stable-audio-open-1.0/transformer_blocks.11.attn2"
)

# Expose only the two columns used below, returned as torch objects;
# `dtype` is passed through as an extra formatting keyword.
ds_voice.set_format(type="torch", columns=["activations", "timestep"], dtype=torch.float16)

In [None]:
# Spot-check the "timestep" column at a handful of row offsets (displayed as
# the cell output) to eyeball how rows are laid out.
ds_ts = ds_voice["timestep"]
tuple(ds_ts[row] for row in (0, 1, 9, 10, 19, 20, 21))

In [126]:
import numpy as np

def idx_to_ts(idx):
    """Map a diffusion-step index to its timestep value.

    The ten steps are evenly spaced from 990 down to 90
    (990, 890, ..., 90), so index 0 -> 990 and index 9 -> 90.

    Args:
        idx: Integer index into the ten steps (Python or NumPy integer;
            negative indices count from the end, as with any NumPy array).

    Returns:
        The timestep as a plain Python ``int``.

    Raises:
        IndexError: If ``idx`` is outside ``[-10, 9]``.
    """
    # A named function instead of a lambda bound to a name, so tracebacks
    # and help() show `idx_to_ts` rather than `<lambda>`.
    return int(np.linspace(990, 90, 10)[idx].item())

In [127]:
# SAE Latents

In [128]:
import os
import torch
from src.sae.sae import Sae
from tqdm import tqdm

# Process-wide PyTorch performance switches; they are set once here and apply
# to everything executed afterwards in this kernel.

# Allow CUDA matmuls to use TensorFloat-32 (per the PyTorch docs this trades a
# little precision for speed on hardware that supports it).
torch.backends.cuda.matmul.allow_tf32 = True
# TorchInductor configuration flags.
# NOTE(review): no torch.compile call is visible in this part of the notebook,
# so these presumably have no effect on the cells shown here — confirm they are
# still needed.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True


In [129]:
# Load the SAE checkpoint for transformer_blocks.11.attn2 onto the GPU, put it
# in eval mode and cast it to fp16 (the dtype requested for the activations).
sae = (
    Sae.load_from_disk(
        "sae-ckpts/music_sae/"
        "batch_topk_expansion_factor4_k64_multi_topkFalse_auxk_alpha0.03125_lr8e-06_musiccaps_public_2_along_time/"
        "transformer_blocks.11.attn2",
        device="cuda",
    )
    .eval()
    .to(dtype=torch.float16)
)

In [130]:
# Read the whole "activations" column out of the dataset once; the encoding
# loop further down slices this object in groups of `n_ts` rows per prompt.
# NOTE(review): presumably this loads the full column into memory — confirm
# it fits for larger dumps.
activations_voice = ds_voice['activations']

In [131]:
# Full "timestep" column of the dataset.
# NOTE(review): `timesteps` is not read by any cell visible in this part of
# the notebook — confirm it is still needed.
timesteps = ds_voice['timestep']

In [None]:
# Encode every prompt's activations with the SAE and keep only the indices of
# the active (top-k) latents.
#
# Result: `tir` ("top indices, reshaped"), one entry per prompt, each of shape
# [n_ts, positions, k], stacked along a new leading prompt axis.
n_prompts = 4 * 512
n_ts = 10  # rows per prompt: consecutive groups of n_ts rows are treated as one prompt

# Fail fast if the dataset is smaller than the hard-coded prompt count.
# Without this check a short slice only blows up at reshape/stack time, after
# the loop has already run.
n_rows_needed = n_prompts * n_ts
if len(activations_voice) < n_rows_needed:
    raise ValueError(
        f"need {n_rows_needed} activation rows ({n_prompts} prompts x {n_ts} timesteps), "
        f"but the dataset only has {len(activations_voice)}"
    )

# Accumulate into a separate list so that `tir` only ever names the final
# tensor (an interrupted run no longer leaves a partial list behind in `tir`).
per_prompt_indices = []
with torch.no_grad():
    for idx in tqdm(range(n_prompts)):
        activations_prompt = activations_voice[idx * n_ts:(idx + 1) * n_ts]
        sae_in = activations_prompt.reshape(n_ts, -1, sae.d_in)

        # Only the second return value (latent indices) is used; the first is discarded.
        _, top_indices = sae.encode(sae_in.to(sae.device))
        top_indices_reshaped = top_indices.reshape(n_ts, -1, top_indices.shape[-1]).cpu()
        per_prompt_indices.append(top_indices_reshaped)

tir = torch.stack(per_prompt_indices, dim=0)

In [133]:
import numpy as np
import torch

# Along frequency axis

In [134]:
import torch

# Average pairwise overlap of active SAE latents between every pair of
# positions (third axis of `tir`) at one fixed diffusion step.
#
# Output: `avg_intersection`, an [N, N] float32 CPU tensor whose (i, j) entry
# is the mean number of latent ids shared by positions i and j.
timestep = 3  # index into the diffusion-step axis of `tir`
n_prompts = tir.shape[0]  # 2048
seed = 42
batch_size = 64

LIMIT_FIRST_FREQS = tir.shape[2]  # keep only the first positions; currently all of them

# torch.randint draws with replacement, so this is a resample of the prompts
# rather than a permutation of them.
torch.manual_seed(seed)
prompts_indices = torch.randint(0, tir.shape[0], (n_prompts,))
x_time = tir[prompts_indices, timestep, :LIMIT_FIRST_FREQS, :]  # [B, N, K]

# The int16 cast below would silently wrap latent ids above 32767 and corrupt
# the one-hot encoding, so refuse to continue instead.
max_latent_id = int(x_time.max())
if max_latent_id > torch.iinfo(torch.int16).max:
    raise OverflowError(
        f"latent id {max_latent_id} does not fit in int16; use a wider dtype for x_time"
    )
x_time = x_time.to(torch.int16)
B, N, K = x_time.shape

with torch.no_grad():
    max_token = max_latent_id + 1
    avg_intersection = torch.zeros((N, N), dtype=torch.float32)

    for start in range(0, B, batch_size):
        end = min(start + batch_size, B)
        chunk = x_time[start:end].to("cuda")  # [bs, N, K]

        bs = chunk.shape[0]
        # Multi-hot membership: one_hot[b, n, t] is True iff latent t is active
        # at position n of sample b. A single scatter_ with the full [bs, N, K]
        # index tensor sets all K entries per position at once.
        one_hot = torch.zeros((bs, N, max_token), dtype=torch.bool, device=chunk.device)
        one_hot.scatter_(2, chunk.long(), 1)

        # membership @ membership^T counts shared latents for every pair of positions.
        membership = one_hot.float()
        intersection = torch.matmul(membership, membership.transpose(1, 2))  # [bs, N, N]
        avg_intersection += intersection.sum(dim=0).cpu()

    avg_intersection /= B


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of the position-by-position average latent overlap computed above.
intersection = avg_intersection.cpu().numpy()

plt.figure(figsize=(20, 20))
ax = sns.heatmap(intersection, cmap='viridis', vmin=0, vmax=64, square=True, 
                 cbar_kws={"label": "Average # of Shared Indices"})
ax.set_title(f"SAE Latents Intersection Map\n(timestep={idx_to_ts(timestep)}, T=1000)", fontsize=36, pad=20)

# Ten evenly spaced ticks plus one extra tick at position 431.
# NOTE(review): 431 is hard-coded and presumably marks a boundary of interest
# in the data — confirm and name it.
num_ticks = 10
tick_positions = np.linspace(0, intersection.shape[0] - 1, num=num_ticks, dtype=int)
tick_positions_lst = list(tick_positions)
tick_positions_lst.insert(8, np.int64(431))
tick_positions = np.array(tick_positions_lst)

# Apply custom ticks
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(tick_positions, fontsize=14)
ax.set_yticklabels(tick_positions, fontsize=14)

# Both axes index the same positions. The y label used to be a second
# set_xlabel call, which left the y axis unlabeled.
# (Use "Frequency" for both labels when plotting the along-frequency dump.)
ax.set_xlabel("Time_idx", fontsize=24)
ax.set_ylabel("Time_idx", fontsize=24)
plt.show()


In [None]:
import seaborn as sns
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform

# Clustered view of the same overlap matrix: positions are reordered by
# hierarchical clustering so that positions sharing many latents sit together.
avg_np = avg_intersection.cpu().numpy()

# Turn the similarity into a distance: higher overlap -> smaller distance.
# 64 is the maximum overlap used as the colour-scale ceiling throughout.
distance = 64 - avg_np

# squareform(checks=False) condenses the matrix without validating symmetry or
# a zero diagonal.
# NOTE(review): SciPy documents Ward linkage as well-defined only for
# Euclidean distances; `64 - overlap` is not guaranteed to be one, so consider
# method='average' if the grouping looks odd.
linkage = sch.linkage(squareform(distance, checks=False), method='ward')

# Leaf order of the dendrogram gives the row/column permutation.
dendro = sch.dendrogram(linkage, no_plot=True)
leaf_order = dendro['leaves']
reordered = avg_np[np.ix_(leaf_order, leaf_order)]

# Plot clustered heatmap
plt.figure(figsize=(20, 20))
ax = sns.heatmap(reordered, cmap='viridis', vmin=0, vmax=64, square=True, 
                 cbar_kws={"label": "Average # of Shared Indices"})

# After reordering, the original position numbers are meaningless along the
# axes, so ticks are removed (the unused tick computation was dropped).
ax.set_xticks([])
ax.set_yticks([])

# Labeling — matches the "Time_idx" axes of the unclustered plot of this matrix.
ax.set_title(f"Clustered SAE Latents Intersection Map\n(timestep={idx_to_ts(timestep)}, T=1000)", fontsize=36, pad=20)
ax.set_xlabel("Time_idx (grouped)", fontsize=24)
ax.set_ylabel("Time_idx (grouped)", fontsize=24)
plt.tight_layout()
plt.show()

# Along diffusion timesteps

In [144]:
import torch

# Average pairwise overlap of active SAE latents between diffusion steps, at
# one fixed position (third axis of `tir`).
#
# Output: `avg_intersection`, an [N, N] float32 CPU tensor (N = number of
# diffusion steps) whose (i, j) entry is the mean number of shared latent ids.
n_prompts = tir.shape[0]  # 2048
seed = 42
batch_size = 64

frequency_idx = 400  # position to analyse (index into the third axis of `tir`)

# torch.randint draws with replacement, so this is a resample of the prompts
# rather than a permutation of them.
torch.manual_seed(seed)
prompts_indices = torch.randint(0, tir.shape[0], (n_prompts,))
x_time = tir[prompts_indices, :, frequency_idx, :]  # [B, N, K]

# The int16 cast below would silently wrap latent ids above 32767 and corrupt
# the one-hot encoding, so refuse to continue instead.
max_latent_id = int(x_time.max())
if max_latent_id > torch.iinfo(torch.int16).max:
    raise OverflowError(
        f"latent id {max_latent_id} does not fit in int16; use a wider dtype for x_time"
    )
x_time = x_time.to(torch.int16)
B, N, K = x_time.shape

with torch.no_grad():
    max_token = max_latent_id + 1
    avg_intersection = torch.zeros((N, N), dtype=torch.float32)

    for start in range(0, B, batch_size):
        end = min(start + batch_size, B)
        chunk = x_time[start:end].to("cuda")  # [bs, N, K]

        bs = chunk.shape[0]
        # Multi-hot membership: one_hot[b, n, t] is True iff latent t is active
        # at step n of sample b. A single scatter_ with the full [bs, N, K]
        # index tensor sets all K entries per step at once.
        one_hot = torch.zeros((bs, N, max_token), dtype=torch.bool, device=chunk.device)
        one_hot.scatter_(2, chunk.long(), 1)

        # membership @ membership^T counts shared latents for every pair of steps.
        membership = one_hot.float()
        intersection = torch.matmul(membership, membership.transpose(1, 2))  # [bs, N, N]
        avg_intersection += intersection.sum(dim=0).cpu()

    avg_intersection /= B


In [None]:
import matplotlib.pyplot as plt

# Heatmap of the step-by-step average latent overlap at one fixed position.
intersection = avg_intersection.cpu().numpy()

plt.figure(figsize=(10, 10))
ax = sns.heatmap(intersection, cmap='viridis', vmin=0, vmax=64, square=True, 
                 cbar_kws={"label": "Average # of Shared Indices"})
# ax.set_title(f"SAE Latents Intersection Map\n(frequency_idx={frequency_idx})", fontsize=36, pad=20)
ax.set_title(f"SAE Latents Intersection Map\n(time_idx={frequency_idx})", fontsize=36, pad=20)

# One tick per diffusion step, labelled with the timestep value.
num_ticks = 10
tick_positions = np.linspace(0, intersection.shape[0] - 1, num=num_ticks, dtype=int)
tick_labels = [f"{idx_to_ts(t)}" for t in tick_positions]

# Heatmap cell i spans [i, i + 1], so its centre is at i + 0.5. Placing the
# ticks at the bare integers put every label on a cell edge, which is clearly
# visible on a 10x10 grid.
tick_centers = tick_positions + 0.5

# Apply custom ticks
ax.set_xticks(tick_centers)
ax.set_yticks(tick_centers)
ax.set_xticklabels(tick_labels, fontsize=14)
ax.set_yticklabels(tick_labels, fontsize=14)
ax.set_xlabel("Diffusion Timestep", fontsize=24)
ax.set_ylabel("Diffusion Timestep", fontsize=24)
plt.show()


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Grid of small heatmaps: for each of `num_plots` positions (third axis of
# `tir`), the average latent overlap between diffusion steps.
n_prompts = tir.shape[0]
seed = 42
batch_size = 64

# torch.randint draws with replacement, so this is a resample of the prompts
# rather than a permutation of them.
torch.manual_seed(seed)
prompts_indices = torch.randint(0, tir.shape[0], (n_prompts,))

# Positions to plot, one subplot each.
num_plots = 40
# Evenly spaced alternative:
# freq_indices = np.linspace(0, 1535, num=num_plots, dtype=int)
# random = False
# torch.manual_seed does not seed NumPy, so the positions used to change on
# every run; a seeded generator makes the figure reproducible.
rng = np.random.default_rng(seed)
freq_indices = rng.integers(400, 1024, size=num_plots)
# NOTE(review): this flag shares its name with the stdlib `random` module and
# would shadow it if the notebook imports that module — consider renaming.
random = True

# Set up large grid of subplots
cols = 10
rows = (num_plots + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))

# Flatten axes for easy access
axes = axes.flatten()

# Keep the last heatmap's Axes so its mesh can drive the shared colorbar.
last_hm = None
int16_max = torch.iinfo(torch.int16).max

for i, freq_idx in enumerate(freq_indices.tolist()):
    ax = axes[i]
    x_time = tir[prompts_indices, :, freq_idx, :]  # [B, N, K]

    # The int16 cast below would silently wrap latent ids above 32767 and
    # corrupt the one-hot encoding, so refuse to continue instead.
    max_latent_id = int(x_time.max())
    if max_latent_id > int16_max:
        raise OverflowError(
            f"latent id {max_latent_id} does not fit in int16; use a wider dtype for x_time"
        )
    x_time = x_time.to(torch.int16)
    B, N, K = x_time.shape

    with torch.no_grad():
        max_token = max_latent_id + 1
        avg_intersection = torch.zeros((N, N), dtype=torch.float32)

        for start in range(0, B, batch_size):
            end = min(start + batch_size, B)
            chunk = x_time[start:end].to("cuda")  # [bs, N, K]

            bs = chunk.shape[0]
            # Multi-hot membership of latents per diffusion step; one scatter_
            # with the full [bs, N, K] index tensor sets all K entries at once.
            one_hot = torch.zeros((bs, N, max_token), dtype=torch.bool, device=chunk.device)
            one_hot.scatter_(2, chunk.long(), 1)

            # membership @ membership^T counts shared latents per pair of steps.
            membership = one_hot.float()
            intersection = torch.matmul(membership, membership.transpose(1, 2))  # [bs, N, N]
            avg_intersection += intersection.sum(dim=0).cpu()

        avg_intersection /= B

    mat = avg_intersection.numpy()
    last_hm = sns.heatmap(mat, cmap='viridis', vmin=0, vmax=64, square=True, cbar=False, ax=ax)

    ax.set_title(f"Time index: {freq_idx}", fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])

# Remove unused axes
for j in range(num_plots, len(axes)):
    fig.delaxes(axes[j])

fig.suptitle(f"{'Randomized ' if random else ''}SAE Latents Intersection Across Diffusion Timesteps", fontsize=34, y=1.01)

# Lay out the subplot grid first, leaving the right 10% free. The manually
# placed colorbar axes are added afterwards because tight_layout warns about
# (and ignores) axes that are not part of the grid.
plt.tight_layout(rect=[0, 0, 0.9, 1.0])

# Add shared colorbar, driven by the heatmap mesh of the last subplot. All
# subplots share vmin/vmax, so any one of them is representative.
cbar_ax = fig.add_axes([0.91, 0.05, 0.02, 0.9])  # [left, bottom, width, height]
cbar = fig.colorbar(last_hm.collections[0], cax=cbar_ax)
cbar.set_label("Average # of Shared Indices", fontsize=24)
cbar.ax.tick_params(labelsize=20)  # Set tick label size

plt.show()
