In [1]:
from datasets import Dataset
import torch

In [2]:
import os
from src.sae.sae import Sae
from tqdm import tqdm

# Global PyTorch performance switches; they apply to the whole kernel session.
# Allow TF32 for CUDA matmuls (trades a little precision for speed on GPUs that support it).
torch.backends.cuda.matmul.allow_tf32 = True
# TorchInductor tuning flags. NOTE(review): these only matter for torch.compile'd code,
# and no torch.compile call is visible in this notebook -- confirm they are still needed.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns


In [4]:
import numpy as np

def idx_to_ts(idx, n_timesteps=10, ts_start=990, ts_end=90):
    """Map a diffusion-step index to its normalized timestep value.

    The schedule is `n_timesteps` evenly spaced values from `ts_start` down to
    `ts_end` (inclusive), divided by 1000. With the defaults this yields
    0.99, 0.89, ..., 0.09 for idx = 0..9.

    Args:
        idx: Position in the schedule (negative indices count from the end).
        n_timesteps: Number of points in the schedule.
        ts_start: First (un-normalized) timestep of the schedule.
        ts_end: Last (un-normalized) timestep of the schedule.

    Returns:
        The timestep as a Python float.

    Raises:
        IndexError: If `idx` is outside the schedule.
    """
    schedule = np.linspace(ts_start, ts_end, n_timesteps)
    return float(schedule[idx].item() / 1000)

In [17]:
def retrieve_top_latents_per_timestep(activations_tensor, sae, diffusion_timesteps, n_prompts = 4 * 512):
    """Encode activations prompt by prompt and collect the SAE's second encode output.

    The rows of `activations_tensor` are consumed in consecutive groups of
    `diffusion_timesteps` rows, one group per prompt. Each group is reshaped to
    (diffusion_timesteps, -1, sae.d_in), moved to `sae.device` and encoded.

    Args:
        activations_tensor: Tensor sliced along dim 0; must hold at least
            n_prompts * diffusion_timesteps rows.
        sae: Object exposing `d_in`, `device` and `encode(x)`; `encode` must
            return a 2-tuple whose second item is kept (presumably the top-k
            latent indices -- verify against the Sae implementation).
        diffusion_timesteps: Number of consecutive rows belonging to one prompt.
        n_prompts: How many prompt groups to process.

    Returns:
        CPU tensor of shape (n_prompts, diffusion_timesteps, -1, last_dim_of_indices).
    """
    collected = []
    with torch.no_grad():
        for prompt_idx in tqdm(range(n_prompts)):
            row_start = prompt_idx * diffusion_timesteps
            row_stop = (prompt_idx + 1) * diffusion_timesteps
            prompt_rows = activations_tensor[row_start:row_stop]

            # One leading entry per diffusion step, SAE input width last.
            encoder_input = prompt_rows.reshape(diffusion_timesteps, -1, sae.d_in)
            _, latent_ids = sae.encode(encoder_input.to(sae.device))

            last_dim = latent_ids.shape[-1]
            collected.append(latent_ids.reshape(diffusion_timesteps, -1, last_dim).cpu())
        stacked = torch.stack(collected, dim=0)
    return stacked

def _resolve_device(device):
    """Return `device` as a torch.device, falling back to CPU if CUDA is requested but unavailable."""
    resolved = torch.device(device)
    if resolved.type == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return resolved


def _sample_prompt_indices(n_available, n_prompts, seed):
    """Draw `n_prompts` prompt indices in [0, n_available) with replacement.

    A private generator is used so the global torch RNG state is left untouched;
    for a given seed it yields the same indices as seeding the global CPU RNG.
    """
    if n_available <= 0:
        raise ValueError("top_latents_per_ts contains no prompts")
    if n_prompts <= 0:
        raise ValueError(f"n_prompts must be positive, got {n_prompts}")
    generator = torch.Generator().manual_seed(int(seed))
    return torch.randint(0, n_available, (n_prompts,), generator=generator)


def _average_pairwise_intersection(latent_ids, batch_size, device):
    """Average, over dim 0, the number of latent ids shared by every pair of positions.

    Args:
        latent_ids: Integer tensor [B, N, K] of non-negative latent ids.
        batch_size: Number of samples processed per chunk.
        device: Device used for the one-hot / matmul work.

    Returns:
        float32 numpy array [N, N]; entry (i, j) is the mean count of distinct
        latent ids present at both position i and position j.
    """
    if latent_ids.numel() == 0:
        raise ValueError("cannot compute an intersection map from an empty selection")
    # int32 (not int16): ids above 32767 would otherwise wrap to negative values.
    latent_ids = latent_ids.to(torch.int32)
    if latent_ids.min().item() < 0:
        raise ValueError("latent ids must be non-negative")

    n_samples, n_positions, _ = latent_ids.shape
    target = _resolve_device(device)

    with torch.no_grad():
        n_latents = int(latent_ids.max().item()) + 1
        total = torch.zeros((n_positions, n_positions), dtype=torch.float32)

        for start in range(0, n_samples, batch_size):
            # chunk: [bsz, N, K]
            chunk = latent_ids[start:start + batch_size].to(target).long()
            one_hot = torch.zeros((chunk.shape[0], n_positions, n_latents), dtype=torch.bool, device=chunk.device)
            # A single scatter marks all K ids of every position at once.
            one_hot.scatter_(2, chunk, 1)

            # [bsz, N, N]: dot product of membership vectors == size of the intersection.
            membership = one_hot.float()
            total += torch.matmul(membership, membership.transpose(1, 2)).sum(dim=0).cpu()

        total /= n_samples
    return total.numpy()


def compute_latents_intersection_map(top_latents_per_ts, timestep, n_prompts, seed=42, batch_size=64, limit_to_first_freqs=None, device="cuda"):
    """Average latent overlap between all pairs of freq/time positions at one diffusion timestep.

    Args:
        top_latents_per_ts: Integer tensor
            [n_all_prompts, n_timesteps, n_freqs/n_time, top_k_latents].
        timestep: Index into dim 1 selecting the diffusion timestep.
        n_prompts: Number of prompts to sample. Sampling is WITH replacement
            (torch.randint), so prompts may repeat even when n_prompts equals
            the number of available prompts.
        seed: Seed for the prompt sampling.
        batch_size: Prompts processed per chunk.
        limit_to_first_freqs: Keep only the first this-many positions of dim 2
            (None keeps all).
        device: Compute device; silently falls back to CPU when CUDA is
            requested but not available.

    Returns:
        float32 numpy array [n_positions, n_positions].
    """
    # top_latents_per_ts: [n_all_prompts, n_timesteps, n_freqs/n_time, top_k_latents]
    if limit_to_first_freqs is None:
        limit_to_first_freqs = top_latents_per_ts.shape[2]
    n_timesteps = top_latents_per_ts.shape[1]
    assert 0 <= timestep < n_timesteps, f"timestep must be in [0, {n_timesteps}), got {timestep}"

    prompts_indices = _sample_prompt_indices(top_latents_per_ts.shape[0], n_prompts, seed)
    # x_time: [prompts, n_freqs/n_time, top_k_latents]
    x_time = top_latents_per_ts[prompts_indices, int(timestep), :limit_to_first_freqs, :]
    return _average_pairwise_intersection(x_time, batch_size, device)


def compute_latents_intersection_map_diffusion_timesteps(top_latents_per_ts, freq_seq_idx, n_prompts, seed=42, batch_size=64,device="cuda"):
    """Average latent overlap between all pairs of diffusion timesteps at one freq/time position.

    Args:
        top_latents_per_ts: Integer tensor
            [n_all_prompts, n_timesteps, n_freqs/n_time, top_k_latents].
        freq_seq_idx: Index into dim 2 selecting the frequency / time position.
        n_prompts: Number of prompts to sample (with replacement).
        seed: Seed for the prompt sampling.
        batch_size: Prompts processed per chunk.
        device: Compute device; silently falls back to CPU when CUDA is
            requested but not available.

    Returns:
        float32 numpy array [n_timesteps, n_timesteps].
    """
    # top_latents_per_ts: [n_all_prompts, n_timesteps, n_freqs/n_time, top_k_latents]
    n_positions = top_latents_per_ts.shape[2]
    assert 0 <= freq_seq_idx < n_positions, f"freq_seq_idx must be in [0, {n_positions}), got {freq_seq_idx}"

    prompts_indices = _sample_prompt_indices(top_latents_per_ts.shape[0], n_prompts, seed)
    # x_time: [prompts, n_timesteps, top_k_latents]
    x_time = top_latents_per_ts[prompts_indices, :, int(freq_seq_idx), :]
    return _average_pairwise_intersection(x_time, batch_size, device)


def plot_intersection_map(intersection_map, timestep_idx, top_k_latents, freq_or_time, tick_position = None, title=None, vmin=None):
    """Draw a square heatmap of a latent intersection map and show it.

    Args:
        intersection_map: 2-D array [N, N] of average shared-latent counts.
        timestep_idx: Diffusion-step index; only used to build the default title.
        top_k_latents: Upper bound of the colour scale (vmax).
        freq_or_time: "time" labels the axes "Time_idx"; anything else "Frequency".
        tick_position: Optional extra index to mark on both axes.
        title: Custom title; a default mentioning the timestep is used if None.
        vmin: Lower bound of the colour scale (0 if None).
    """
    vmin = 0 if vmin is None else vmin
    if title is None:
        title = f"SAE Latents Intersection Map\n(timestep={idx_to_ts(timestep_idx)})"

    fig, ax = plt.subplots(figsize=(20, 20))
    sns.heatmap(intersection_map, cmap='viridis', vmin=vmin, vmax=top_k_latents, square=True,
                cbar_kws={"label": "Average # of Shared Indices"}, ax=ax)
    ax.set_title(title, fontsize=36, pad=20)

    num_ticks = 10
    # A set removes duplicates: linspace(dtype=int) repeats values on maps with
    # fewer than `num_ticks` rows, and `tick_position` may equal a regular tick.
    tick_indices = set(np.linspace(0, intersection_map.shape[0] - 1, num=num_ticks, dtype=int).tolist())
    if tick_position is not None:
        tick_indices.add(int(tick_position))
    tick_indices = np.array(sorted(tick_indices))
    # Heatmap cell i spans [i, i + 1]; +0.5 puts the tick on the cell centre
    # instead of its left/top edge.
    tick_locations = tick_indices + 0.5
    ax.set_xticks(tick_locations)
    ax.set_yticks(tick_locations)
    ax.set_xticklabels(tick_indices, fontsize=14)
    ax.set_yticklabels(tick_indices, fontsize=14)

    xy_label = "Time_idx" if freq_or_time == "time" else "Frequency"
    ax.set_xlabel(xy_label, fontsize=24)
    ax.set_ylabel(xy_label, fontsize=24)
    plt.show()

def plot_intersection_map_diffusion_timesteps(intersection_map, time_freq_idx, top_k_latents, freq_or_time, title=None, vmin=None):
    """Draw a timestep-by-timestep latent intersection heatmap and show it.

    Args:
        intersection_map: 2-D array [n_timesteps, n_timesteps].
        time_freq_idx: Frequency/time position the map was computed for; only
            used to build the default title.
        top_k_latents: Upper bound of the colour scale (vmax).
        freq_or_time: "freq" names the position "frequency" in the default
            title; anything else names it "time".
        title: Custom title; a default is used if None.
        vmin: Lower bound of the colour scale (0 if None).
    """
    vmin = 0 if vmin is None else vmin
    if title is None:
        obj = "frequency" if freq_or_time == "freq" else "time"
        title = f"SAE Latents Intersection Map\n({obj}={time_freq_idx})"

    fig, ax = plt.subplots(figsize=(20, 20))
    sns.heatmap(intersection_map, cmap='viridis', vmin=vmin, vmax=top_k_latents, square=True,
                cbar_kws={"label": "Average # of Shared Indices"}, ax=ax)
    ax.set_title(title, fontsize=36, pad=20)

    num_ticks = 10
    # np.unique drops the repeated indices linspace(dtype=int) produces on maps
    # with fewer than `num_ticks` rows.
    tick_indices = np.unique(np.linspace(0, intersection_map.shape[0] - 1, num=num_ticks, dtype=int))
    tick_labels = [idx_to_ts(int(i)) for i in tick_indices]
    # Heatmap cell i spans [i, i + 1]; +0.5 centres each label on its cell
    # rather than on the cell boundary.
    tick_locations = tick_indices + 0.5
    ax.set_xticks(tick_locations)
    ax.set_yticks(tick_locations)
    ax.set_xticklabels(tick_labels, fontsize=14)
    ax.set_yticklabels(tick_labels, fontsize=14)

    xy_label = "Diffusion timestep"
    ax.set_xlabel(xy_label, fontsize=24)
    ax.set_ylabel(xy_label, fontsize=24)
    plt.show()

def plot_intersection_map_diff_ts_multiple_freqs(top_latents_per_ts, top_k_latents, freq_or_time, title=None, vmin=None, seed=42, n_plots = 40, n_cols = 10, start_from_idx = None):
    """Plot a grid of timestep-by-timestep intersection maps for randomly chosen positions.

    Args:
        top_latents_per_ts: Integer tensor
            [n_all_prompts, n_timesteps, n_freqs/n_time, top_k_latents].
        top_k_latents: Upper bound of the shared colour scale (vmax).
        freq_or_time: "freq" titles each panel "Freq index"; anything else "Time index".
        title: Figure title; a default is used if None.
        vmin: Lower bound of the colour scale (0 if None).
        seed: Seeds both the choice of positions and the prompt sampling of
            every panel.
        n_plots: Number of panels (positions drawn with replacement).
        n_cols: Panels per row.
        start_from_idx: Smallest position index that may be drawn (0 if None).

    Raises:
        ValueError: If `n_plots` or `n_cols` is not positive.
    """
    if n_plots <= 0 or n_cols <= 0:
        raise ValueError(f"n_plots and n_cols must be positive, got n_plots={n_plots}, n_cols={n_cols}")

    # top_latents_per_ts: [n_all_prompts, n_timesteps, n_freqs/n_time, top_k_latents]
    vmin = 0 if vmin is None else vmin
    start_from_idx = 0 if start_from_idx is None else start_from_idx

    n_prompts = top_latents_per_ts.shape[0]
    # A private legacy RandomState draws the same indices as
    # np.random.seed(seed) + np.random.randint, without reseeding the global RNG.
    rng = np.random.RandomState(seed)
    freq_time_indices = rng.randint(start_from_idx, top_latents_per_ts.shape[2], size=n_plots)

    n_rows = (n_plots + n_cols - 1) // n_cols
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()
    obj = "Freq" if freq_or_time == "freq" else "Time"
    for ax, freq_time_idx in zip(axes, freq_time_indices):
        mat = compute_latents_intersection_map_diffusion_timesteps(top_latents_per_ts=top_latents_per_ts, freq_seq_idx=int(freq_time_idx), n_prompts=n_prompts, seed=seed)
        sns.heatmap(mat, cmap='viridis', vmin=vmin, vmax=top_k_latents, square=True, cbar=False, ax=ax)
        ax.set_title(f"{obj} index: {freq_time_idx}", fontsize=18)
        ax.set_xticks([])
        ax.set_yticks([])

    # Remove the unused panels of the last row.
    for unused_ax in axes[n_plots:]:
        fig.delaxes(unused_ax)

    # shared colorbar; all panels use the same vmin/vmax, so the last panel's mesh represents them all
    cbar_ax = fig.add_axes([0.91, 0.05, 0.02, 0.9])  # [left, bottom, width, height]
    cbar = fig.colorbar(axes[n_plots - 1].collections[0], cax=cbar_ax)
    cbar.set_label("Average # of Shared Indices", fontsize=24)
    cbar.ax.tick_params(labelsize=20)  # Set tick label size

    title = "Randomized SAE Latents Intersection Across Diffusion Timesteps" if title is None else title
    fig.suptitle(title, fontsize=34, y=1.01)
    plt.tight_layout(rect=[0, 0, 0.9, 1.0])
    plt.show()

In [6]:
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform

def cluster_intersection_map(intersection_map, top_k_latents, freq_or_time, timestep, title=None):
    """Reorder an intersection map by hierarchical clustering and plot it.

    Pairwise distance is defined as `top_k_latents - intersection_map`, Ward
    linkage is computed on it, and rows/columns are permuted into dendrogram
    leaf order before drawing the heatmap.

    Args:
        intersection_map: Symmetric 2-D numpy array [N, N].
        top_k_latents: Maximum possible overlap; also the colour-scale maximum.
        freq_or_time: Either "freq" or "time"; selects the axis label.
        timestep: Diffusion-step index; only used to build the default title.
        title: Custom title; a default mentioning the timestep is used if None.
    """
    assert freq_or_time in ["freq", "time"], f"freq_or_time must be 'freq' or 'time', got {freq_or_time!r}"
    distances = top_k_latents - intersection_map
    # checks=False: skip squareform's validation of the square matrix
    # (symmetry / zero diagonal) before it is condensed.
    linkage = sch.linkage(squareform(distances, checks=False), method='ward')
    leaf_order = sch.dendrogram(linkage, no_plot=True)['leaves']
    reordered = intersection_map[np.ix_(leaf_order, leaf_order)]

    plt.figure(figsize=(20, 20))
    ax = sns.heatmap(reordered, cmap='viridis', vmin=0, vmax=top_k_latents, square=True,
                    cbar_kws={"label": "Average # of Shared Indices"})
    # Original indices are meaningless after reordering, so ticks are hidden.
    ax.set_xticks([])
    ax.set_yticks([])

    title = f"Clustered SAE Latents Intersection Map\n(timestep={idx_to_ts(timestep)})" if title is None else title
    ax.set_title(title, fontsize=36, pad=20)
    xy_label = "Time_idx" if freq_or_time == "time" else "Frequency"
    xy_label += " (clustered)"
    ax.set_xlabel(xy_label, fontsize=24)
    ax.set_ylabel(xy_label, fontsize=24)
    plt.tight_layout()
    plt.show()

# Frequency SAE

In [58]:
# Settings for the frequency-SAE section. These names are reassigned in the
# "Time SAE" section below, so re-run this cell before re-running frequency cells.
N_DIFFUSION_TIMESTEPS = 10  # consecutive activation rows per prompt
TOP_K_LATENTS = 32  # used as the colour-scale maximum in the plots
FREQ_OR_TIME = "freq"  # selects axis/title wording in the plotting helpers

In [None]:
# Load the saved along-frequency activations and expose the two needed columns
# as torch tensors (dtype=torch.float16 is forwarded to the torch formatter).
ds_voice = Dataset.load_from_disk('activations/dpmscheduler/musiccaps_voice4_10s_alongfreq/stable-audio-open-1.0/transformer_blocks.11.attn2')
ds_voice.set_format(
    type="torch",
    columns=["activations", "timestep"],
    dtype=torch.float16
)
# NOTE(review): ds_ts is not read anywhere in the visible notebook.
ds_ts = ds_voice['timestep']
activations_voice = ds_voice['activations']

In [27]:
# Load the frequency SAE checkpoint onto the GPU in eval mode, then cast it to
# float16 to match the dtype requested for the activations.
sae = Sae.load_from_disk(
    "sae-ckpts/music_sae/stableaudio_dpm/batch_topk_expansion_factor4_k32_multi_topkFalse_auxk_alpha0.03125_lr8e-06_epochs2_musiccaps_public2_10s_alongfreq/transformer_blocks.11.attn2", device="cuda"
).eval()
sae = sae.to(dtype=torch.float16)

In [None]:
# Encode every prompt; n_prompts is the row count divided by rows-per-prompt
# (any trailing partial group is ignored by the floor division).
tir = retrieve_top_latents_per_timestep(activations_tensor=activations_voice, sae=sae, diffusion_timesteps=N_DIFFUSION_TIMESTEPS, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS)

In [71]:
TIMESTEP_IDX = 8  # diffusion-step index; idx_to_ts(8) == 0.19
LIMIT_TO_FIRST_FREQS = 256  # keep only the first 256 positions along dim 2

In [72]:
# Average pairwise latent overlap at TIMESTEP_IDX. Prompts are drawn with
# replacement inside the helper, so this is a bootstrap-style sample even though
# n_prompts equals the number of available prompts.
avg_interesection_map = compute_latents_intersection_map(top_latents_per_ts=tir, timestep=TIMESTEP_IDX, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS, limit_to_first_freqs=LIMIT_TO_FIRST_FREQS)

In [None]:
# Heatmap of the map computed above; the colour scale is capped at TOP_K_LATENTS.
plot_intersection_map(avg_interesection_map, timestep_idx=TIMESTEP_IDX, top_k_latents=TOP_K_LATENTS, freq_or_time=FREQ_OR_TIME)

In [None]:
# Relative difference between the intersection maps of two diffusion timesteps.
TIMESTEP_IDX1 = 2
TIMESTEP_IDX2 = 8
LIMIT_TO_FIRST_FREQS = 256

avg_interesection_map1 = compute_latents_intersection_map(top_latents_per_ts=tir, timestep=TIMESTEP_IDX1, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS, limit_to_first_freqs=LIMIT_TO_FIRST_FREQS)
avg_interesection_map2 = compute_latents_intersection_map(top_latents_per_ts=tir, timestep=TIMESTEP_IDX2, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS, limit_to_first_freqs=LIMIT_TO_FIRST_FREQS)
# Alternative visualisation (sign of the difference), kept for reference:
# diff_map = np.where(
#     avg_interesection_map1 > avg_interesection_map2,
#     1,
#     np.where(
#         avg_interesection_map1 < avg_interesection_map2,
#         -1,
#         0
#     )
# )
# title = "YELLOW: INTERSECT$_{ts=0.79}>$INTERSECT$_{ts=0.19}$\nVIOLET: INTERSECT$_{ts=0.79}<$INTERSECT$_{ts=0.19}$"

# Cells whose reference overlap is 0 have no defined relative difference; leave
# them NaN (seaborn masks NaN cells) instead of letting inf/nan leak into vmin/vmax.
abs_diff = np.abs(avg_interesection_map1 - avg_interesection_map2)
diff_map = np.divide(abs_diff, avg_interesection_map1, out=np.full_like(abs_diff, np.nan), where=avg_interesection_map1 > 0)

# Build the title from the selected indices so it cannot drift from the constants above.
ts1 = idx_to_ts(TIMESTEP_IDX1)
ts2 = idx_to_ts(TIMESTEP_IDX2)
title = r"$\frac{|INTERSECT_{ts=%g}-INTERSECT_{ts=%g}|}{INTERSECT_{ts=%g}}$" % (ts1, ts2, ts1)

vmax = np.ceil(np.nanmax(diff_map)).item()
vmin = np.floor(np.nanmin(diff_map)).item()
# timestep_idx is only used for the default title; pass this cell's own index
# rather than a TIMESTEP_IDX left over from another cell.
plot_intersection_map(diff_map, timestep_idx=TIMESTEP_IDX1, top_k_latents=vmax, freq_or_time=FREQ_OR_TIME, vmin=vmin, title=title)




In [None]:
# Recompute the map at step index 2 (idx_to_ts(2) == 0.79) over all positions
# (LIMIT_TO_FIRST_FREQS=None disables the cut-off), then show it in clustered order.
# NOTE: this overwrites TIMESTEP_IDX / LIMIT_TO_FIRST_FREQS / avg_interesection_map from earlier cells.
TIMESTEP_IDX = 2
LIMIT_TO_FIRST_FREQS = None
avg_interesection_map = compute_latents_intersection_map(top_latents_per_ts=tir, timestep=TIMESTEP_IDX, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS, limit_to_first_freqs=LIMIT_TO_FIRST_FREQS)

cluster_intersection_map(intersection_map=avg_interesection_map, top_k_latents=TOP_K_LATENTS, freq_or_time=FREQ_OR_TIME, timestep=TIMESTEP_IDX)

+ Latent intersection along diffusion timesteps (for a fixed frequency index)

In [None]:
FREQUENCY_IDX = 300  # position along dim 2 of `tir` to inspect

# Overlap between every pair of diffusion timesteps at that single position.
dts_intersection_map = compute_latents_intersection_map_diffusion_timesteps(top_latents_per_ts=tir, freq_seq_idx=FREQUENCY_IDX, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS)
plot_intersection_map_diffusion_timesteps(intersection_map=dts_intersection_map, time_freq_idx=FREQUENCY_IDX, top_k_latents=TOP_K_LATENTS, freq_or_time=FREQ_OR_TIME)

In [None]:
# Grid of timestep-vs-timestep maps for 40 randomly drawn positions (4 rows of 10).
plot_intersection_map_diff_ts_multiple_freqs(
    top_latents_per_ts=tir,
    top_k_latents=TOP_K_LATENTS,
    freq_or_time=FREQ_OR_TIME,
    seed=42,
    n_plots=40,
    n_cols=10
)

# Time SAE

In [7]:
# Settings for the time-SAE section; these overwrite the frequency-SAE values
# of the same names defined earlier in the notebook.
N_DIFFUSION_TIMESTEPS = 10  # consecutive activation rows per prompt
TOP_K_LATENTS = 64  # used as the colour-scale maximum in the plots
FREQ_OR_TIME = "time"  # selects axis/title wording in the plotting helpers

In [None]:
# Load the saved along-time activations; this replaces the frequency dataset
# held in ds_voice / activations_voice.
ds_voice = Dataset.load_from_disk('activations/dpmscheduler/musiccaps_voice4_10s_alongtime/stable-audio-open-1.0/transformer_blocks.11.attn2')
ds_voice.set_format(
    type="torch",
    columns=["activations", "timestep"],
    dtype=torch.float16
)
# NOTE(review): ds_ts is not read anywhere in the visible notebook.
ds_ts = ds_voice['timestep']
activations_voice = ds_voice['activations']

In [9]:
# Load the time SAE checkpoint onto the GPU in eval mode (replacing the
# frequency SAE bound to `sae`), then cast it to float16.
sae = Sae.load_from_disk(
    "sae-ckpts/music_sae/stableaudio_dpm/batch_topk_expansion_factor4_k64_multi_topkFalse_auxk_alpha0.03125_lr8e-06_musiccaps_public2_10s_alongtime/transformer_blocks.11.attn2", device="cuda"
).eval()
sae = sae.to(dtype=torch.float16)

In [None]:
# Encode every prompt with the time SAE; overwrites the frequency-SAE `tir`.
tir = retrieve_top_latents_per_timestep(activations_tensor=activations_voice, sae=sae, diffusion_timesteps=N_DIFFUSION_TIMESTEPS, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS)

In [None]:
# Intersection map across all time positions at step index 2 (idx_to_ts(2) == 0.79).
TIMESTEP_IDX = 2
LIMIT_TO_FIRST_FREQS = None  # None keeps every position along dim 2
avg_interesection_map = compute_latents_intersection_map(top_latents_per_ts=tir, timestep=TIMESTEP_IDX, n_prompts=activations_voice.shape[0] // N_DIFFUSION_TIMESTEPS, limit_to_first_freqs=LIMIT_TO_FIRST_FREQS)
plot_intersection_map(avg_interesection_map, timestep_idx=TIMESTEP_IDX, top_k_latents=TOP_K_LATENTS, freq_or_time=FREQ_OR_TIME)

In [None]:
# Grid of timestep-vs-timestep maps for 20 randomly drawn time positions
# (4 rows of 5); positions are drawn from index 500 upwards.
plot_intersection_map_diff_ts_multiple_freqs(
    top_latents_per_ts=tir,
    top_k_latents=TOP_K_LATENTS,
    freq_or_time=FREQ_OR_TIME,
    seed=333,
    n_plots=20,
    n_cols=5,
    start_from_idx=500
)