In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.animation as animation
from matplotlib import rc
from IPython.display import HTML
from matplotlib import rcParams
import seaborn as sns
from textwrap import wrap
from tqdm import tqdm
import pandas as pd
import re
import math
import gc

In [None]:
def read_tensors_aux(dir, pt_name_predicate=None):
    """Recursively load the ``.pt`` files under ``dir`` into a nested dict.

    Args:
        dir: directory to scan.
        pt_name_predicate: optional callable ``(key, dir) -> bool`` where ``key``
            is the file name without its ``.pt`` suffix; a file is loaded only
            when the callable returns a truthy value. ``None`` accepts every file.

    Returns:
        dict mapping each loaded file's stem to the object returned by
        ``torch.load`` and each sub-directory to the dict produced by recursing
        into it. A sub-directory named ``timestep_<n>...`` is stored under the
        text between its first and second underscore; any other sub-directory
        is stored under its own name.

    Files that fail to load, and entries that are neither ``.pt`` files nor
    directories, are reported with ``print`` and left out of the result.
    """
    accept = pt_name_predicate if pt_name_predicate is not None else (lambda x, y: True)

    loaded = {}
    for path in os.listdir(dir):
        full_path = os.path.join(dir, path)

        if path.endswith('.pt'):
            stem = path[:-3]
            if not accept(stem, dir):
                continue
            try:
                loaded[stem] = torch.load(full_path)
            except Exception as e:
                # best effort: report the unreadable file and keep scanning
                print(f"Error loading {path}: {e}")
            continue

        if not os.path.isdir(full_path):
            print(f"Skipping {path}, not a tensor or directory")
            continue

        if path.startswith('timestep_'):
            subkey = path.split('_')[1]
        else:
            subkey = path
        loaded[subkey] = read_tensors_aux(full_path, pt_name_predicate)
    return loaded

In [None]:
def move_files_from_subdirs_to_dir(root_dir='outputs'):
    """Flatten ``root_dir/<name>/<file>`` into ``root_dir/<name>.png``.

    For every immediate sub-directory of ``root_dir``:

    * exactly one regular file inside: the file is renamed to
      ``<sub-directory path>.png`` and the now-empty sub-directory is removed;
    * nothing inside: the empty sub-directory is removed;
    * anything else (several entries, or a nested directory): the sub-directory
      is left untouched and a message is printed. Every entry would be renamed
      to the same destination, so moving them would silently keep only one.

    Args:
        root_dir: directory whose sub-directories are flattened.
    """
    for subdir in os.listdir(root_dir):
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue

        entries = os.listdir(subdir_path)
        has_nested_dir = any(os.path.isdir(os.path.join(subdir_path, name)) for name in entries)
        if len(entries) > 1 or has_nested_dir:
            # all entries map to the same '<subdir>.png'; refuse rather than overwrite
            print(f"Skipping {subdir_path}: expected a single file, found {len(entries)} entries")
            continue

        dst_path = subdir_path + '.png'
        for file_name in entries:  # zero or one entry at this point
            os.rename(os.path.join(subdir_path, file_name), dst_path)
        os.rmdir(subdir_path)

In [None]:
# Flatten outputs/dift/<subdir>/<file> into outputs/dift/<subdir>.png.
# NOTE(review): this renames and removes things on disk every time the cell runs.
move_files_from_subdirs_to_dir('outputs/dift')

In [None]:
# Folder holding the saved tensors; keep only the entries whose name starts with '2025'.
# NOTE(review): `dir` shadows the Python builtin, but later cells read this name, so it is kept.
dir = 'tensors/outputs/sd3.5_medium/dift'
gen_experiments = [ f for f in os.listdir(dir) if f.startswith('2025') ]

In [None]:
# Display the experiment folder names that were found.
gen_experiments

In [None]:
# Restrict to the runs whose folder name contains '_A_dog_'.
gen_experiments = [exp for exp in gen_experiments if '_A_dog_' in exp]

# tensors[seed] -> {'x': stacked tensor, 'grad': stacked tensor}
# The seed is the integer after the last '_' of the folder name; entries of each
# stack are ordered from the largest to the smallest integer found after '=' in
# the tensor names.
tensors = {}
for exp in gen_experiments:
    exp_dir = os.path.join(dir, exp, '000000')
    seed = int(exp.split('_')[-1])
    _t = read_tensors_aux(exp_dir)

    # split the loaded tensors by whether their name mentions 'grad'
    steps = {
        kind: {
            int(name.split('=')[-1]): value.squeeze(0)
            for name, value in _t.items()
            if ('grad' in name) == is_grad
        }
        for kind, is_grad in (('x', False), ('grad', True))
    }
    tensors[seed] = {
        kind: torch.stack([by_step[t] for t in sorted(by_step, reverse=True)])
        for kind, by_step in steps.items()
    }


In [None]:
# One entry per seed was loaded into `tensors`.
num_seeds = len(tensors)
print(f"Number of seeds: {num_seeds}")

### plotting images

In [None]:
# Show the 'dog' PNG files as one row, highest seed first.
images_dir = 'outputs/dift'
filenames = sorted([f for f in os.listdir(images_dir) if f.endswith('.png') and 'dog' in f])
# The seed is parsed from the text between the last '_' and the first following '.'.
seed_to_image = { int(f.split('_')[-1].split('.')[0]): f for f in filenames }
num_seeds = len(seed_to_image.keys())

fig, axes = plt.subplots(1, num_seeds, figsize=(num_seeds * 2 + 1, 2))
# With a single image plt.subplots returns a bare Axes, which cannot be zipped over.
axes = np.atleast_1d(axes)

for seed, ax in zip(sorted(seed_to_image.keys())[::-1], axes):
    img = plt.imread(os.path.join(images_dir, seed_to_image[seed]))
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f"Seed {seed}", fontsize=8)
plt.show()

### plotting $x_t$

In [None]:
# Plot a (num_timesteps x num_seeds) grid showing one channel of x at each timestep for each seed.
channel_to_plot = 7

seeds = sorted(list(tensors.keys()))[::-1]
num_timesteps = tensors[seeds[0]]['x'].shape[0]
# Take the column count from the seeds plotted here, not from whatever an earlier cell left in `num_seeds`.
num_seeds = len(seeds)

# squeeze=False keeps `axes` 2-D even when there is a single row or column.
fig, axes = plt.subplots(num_timesteps, num_seeds, figsize=(num_seeds * 2 + 1, num_timesteps * 2 + 1), squeeze=False)
for row_idx, (timestep, ax_row) in enumerate(zip(range(num_timesteps), axes)):
    for col_idx, (seed, ax) in enumerate(zip(seeds, ax_row)):
        ax.imshow(tensors[seed]['x'][timestep, channel_to_plot, :, :].cpu().numpy(), aspect='auto', cmap='viridis')
        # Hide only the ticks: ax.axis('off') would also hide the row label set below.
        ax.set_xticks([])
        ax.set_yticks([])
        if row_idx == 0:
            ax.set_title(f'Seed {seed}', fontsize=10)
        if col_idx == 0:
            ax.set_ylabel(f'Timestep {timestep}', fontsize=10)
plt.tight_layout()
plt.show()

### plotting $grad_t$

In [None]:
# Plot a (num_timesteps x num_seeds) grid showing one channel of grad at each timestep for each seed.
channel_to_plot = 0

seeds = sorted(list(tensors.keys()))[::-1]
# Size the grid from the 'grad' stack, since that is the tensor indexed below.
num_timesteps = tensors[seeds[0]]['grad'].shape[0]
# Take the column count from the seeds plotted here, not from whatever an earlier cell left in `num_seeds`.
num_seeds = len(seeds)

# squeeze=False keeps `axes` 2-D even when there is a single row or column.
fig, axes = plt.subplots(num_timesteps, num_seeds, figsize=(num_seeds * 2 + 1, num_timesteps * 2 + 3), squeeze=False)
for row_idx, (timestep, ax_row) in enumerate(zip(range(num_timesteps), axes)):
    for col_idx, (seed, ax) in enumerate(zip(seeds, ax_row)):
        ax.imshow(tensors[seed]['grad'][timestep, channel_to_plot, :, :].cpu().numpy(), aspect='auto', cmap='viridis')
        # Hide only the ticks so the row label set below stays visible.
        ax.set_xticks([])
        ax.set_yticks([])
        if row_idx == 0:
            ax.set_title(f'Seed {seed}', fontsize=10)
        if col_idx == 0:
            ax.set_ylabel(f'Timestep {timestep}', fontsize=10)
plt.suptitle(f'Plotting Channel {channel_to_plot} in $grad_t$', fontsize=30, y=1)
plt.tight_layout()
plt.show()

### plotting images with heatmaps overlaid

In [None]:
def plot_latents_with_heatmaps(latents, heatmaps, col_titles=None, latent_channel_to_plot=7, suptitle=None):
    """
        Draw one row of panels: a grayscale latent channel with a half-transparent
        red heatmap drawn on top, then show the figure.

        latents: [tensors of shape (C, H, W)], one panel per entry
        heatmaps: [tensors of shape (1, H, W) or (H,W)], indexed in step with
                  `latents`; for a 3-D entry only index 0 along the first axis is
                  drawn. None substitutes an all-zero tensor for every latent.
        col_titles: [str], optional; entry i is the title of panel i
        latent_channel_to_plot: int, channel of each latent used as the background
        suptitle: str, optional; when given, the figure is made one unit taller
                  and the text is used as the figure title
    """
    if heatmaps is None:
        heatmaps = [torch.zeros_like(latents[0]) for _ in latents]

    def _with_leading_axis(h):
        # normalise every heatmap to 3-D so the panels can always read index 0
        if len(h.shape) == 2:
            return h.unsqueeze(0)
        assert len(h.shape) == 3
        return h

    heatmaps = [_with_leading_axis(h) for h in heatmaps]

    num_cols = len(latents)
    extra_height = 1 if suptitle is not None else 0
    fig, axes = plt.subplots(1, num_cols, figsize=(num_cols * 3, 3 + extra_height))
    # a single column yields a bare Axes; wrap it so the loop below is uniform
    for col_idx, ax in enumerate(np.atleast_1d(axes)):
        background = latents[col_idx][latent_channel_to_plot].cpu().numpy()
        overlay = heatmaps[col_idx][0].cpu().numpy()

        ax.imshow(background, aspect='auto', cmap='gist_gray', alpha=1)
        ax.imshow(overlay, aspect='auto', cmap='Reds', alpha=0.5)
        ax.axis('off')
        if col_titles is not None:
            ax.set_title(col_titles[col_idx], fontsize=10)

    if suptitle is not None:
        fig.suptitle(suptitle, fontsize=20)
    fig.tight_layout()
    plt.show()

In [None]:
def convert_one_hot_to_heatmap(point, latent, radius=1):
    """Return a tensor shaped like ``latent`` with a square of ones around ``point``.

    Args:
        point: ``(row, col)`` centre of the square.
        latent: tensor whose shape and dtype the result copies; it is indexed
            as ``[0, row, col]``, so it needs at least three dimensions.
        radius: half-width of the square; the square spans
            ``2 * radius + 1`` cells per side before clipping.

    Returns:
        A new tensor that is zero everywhere except index 0 of the first axis,
        where the part of the square that lies inside the grid is set to 1.
    """
    row, col = point
    heatmap = torch.zeros_like(latent)
    height, width = heatmap.shape[1], heatmap.shape[2]

    # clip the square to the grid so negative indices never wrap around
    top, bottom = max(row - radius, 0), min(row + radius + 1, height)
    left, right = max(col - radius, 0), min(col + radius + 1, width)
    if top < bottom and left < right:
        heatmap[0, top:bottom, left:right] = 1.0
    return heatmap

In [None]:
# Named landmarks as (row, col) positions on the latent / gradient grid.
# NOTE(review): presumably hand-picked on the first seed's image — confirm before reusing with other runs.
anchor_points = {
    'Left Eye': (35, 38),
    'Right Eye': (32, 80),
    'Left Nostril': (50, 46),
    'Right Nostril': (50, 63),
    'Left Ear': (25, 20),
    'Right Ear': (25, 100),
    'Ball Center': (110, 60)
}

In [None]:
# Last entry of each seed's stacked 'x' tensor, in the order of `seeds`.
# The stacks were built in descending key order, so the last entry is the smallest key.
latents = [ tensors[s]['x'][-1] for s in seeds ]

In [None]:
# Sanity check of the 'Left Eye' anchor: draw its marker on the first latent only.
# The [:1] slices reduce the figure to one panel, so only the first title is read.
plot_latents_with_heatmaps(
    latents=latents[:1],
    heatmaps=[ convert_one_hot_to_heatmap(anchor_points['Left Eye'], latents[0], radius=1) for _ in seeds ][:1],
    col_titles=[f'Seed {s}' for s in seeds],
    latent_channel_to_plot=7
)

In [None]:
def save_video_from_dir(dir, save_dir=None, fps=5):
    """Turn the PNG files of ``dir`` into ``animation.mp4``.

    Frames are the ``.png`` files of ``dir`` taken in sorted file-name order.

    Args:
        dir: directory containing the PNG frames.
        save_dir: directory that receives ``animation.mp4``; created if missing.
            Defaults to ``dir``.
        fps: frames per second, used both for the frame interval and for the
            saved file.
    """
    frame_names = [name for name in sorted(os.listdir(dir)) if name.endswith('.png')]
    frames = [plt.imread(os.path.join(dir, name)) for name in frame_names]

    fig = plt.figure()
    plt.axis('off')
    # one single-artist frame per image
    artists = [[plt.imshow(frame, animated=True)] for frame in frames]
    ani = animation.ArtistAnimation(fig, artists, interval=1000/fps, blit=True, repeat_delay=1000)
    plt.tight_layout()
    plt.close()

    target_dir = dir if save_dir is None else save_dir
    os.makedirs(target_dir, exist_ok=True)
    ani.save(os.path.join(target_dir, 'animation.mp4'), writer='ffmpeg', fps=fps)

In [None]:
# save_video_from_dir('visualizations/attn_heads/dog_normalized', fps=5)
# save_video_from_dir('visualizations/attn_heads/dog_not_normalized', fps=5)

## DIFT with grads

In [None]:
def compute_dift_heatmap(tensors, grad_timestep, anchor_points, point_desc, seeds=None):
    """Compare the gradient vector at an anchor point with every position of every seed.

    The anchor vector is read from the FIRST seed's gradient at
    ``anchor_points[point_desc]``; each seed's gradient map is then compared
    against it position by position with cosine similarity.

    Args:
        tensors: mapping ``seed -> {'grad': tensor}`` where ``tensors[seed]['grad'][t]``
            has shape (C, H, W); all seeds must share that shape.
        grad_timestep: index into the first axis of each ``'grad'`` tensor.
        anchor_points: mapping from a description to a ``(row, col)`` position.
        point_desc: key of ``anchor_points`` to use as the anchor.
        seeds: optional ordering of the seeds to process. Defaults to the keys
            of ``tensors`` in descending order, so the function no longer
            depends on a notebook-level ``seeds`` variable.

    Returns:
        ``(cossim, softmax_cossim)``: two lists with one (H, W) tensor per seed,
        in the order of ``seeds``. ``cossim`` holds the cosine similarities;
        ``softmax_cossim`` holds a softmax taken over all H*W positions of the
        corresponding ``cossim`` map.
    """
    if seeds is None:
        seeds = sorted(tensors.keys(), reverse=True)

    anchor_row, anchor_col = anchor_points[point_desc]
    grads_at_t = [tensors[s]['grad'][grad_timestep] for s in seeds]
    C, H, W = grads_at_t[0].shape
    anchor_dift_vec = grads_at_t[0][:, anchor_row, anchor_col]  # shape = [C]

    # reshape (not view): the permuted tensor is not guaranteed to be viewable as (H*W, C)
    cossim = [
        F.cosine_similarity(g.permute(1, 2, 0).reshape(-1, C), anchor_dift_vec.unsqueeze(0), dim=1).view(H, W)
        for g in grads_at_t
    ]
    softmax_cossim = [F.softmax(sim.reshape(-1), dim=0).view(H, W) for sim in cossim]

    return cossim, softmax_cossim

In [None]:
# Similarity maps to the 'Left Eye' anchor, computed from the gradients at index 10.
cossim, softmax_cossim = compute_dift_heatmap(tensors, 10, anchor_points, 'Left Eye')

In [None]:
# Value range of the first seed's cosine-similarity map.
cossim[0].max(), cossim[0].min()

In [None]:
# Plot softmax_cossim[0], the softmax-normalised similarity map of the first seed.
fig = plt.figure()
plt.imshow(softmax_cossim[0].numpy(), cmap='Blues')
plt.show()

In [None]:

# First panel: the first latent with the anchor marker. Remaining panels: every
# seed's latent overlaid with its raw cosine-similarity map.
plot_latents_with_heatmaps(
    latents=[latents[0]] + latents,
    heatmaps=[ convert_one_hot_to_heatmap(anchor_points['Left Eye'], latents[0], radius=3) ] + cossim,
    col_titles=['Anchor'] + [f'Seed {s}' for s in seeds],
    latent_channel_to_plot=7,
    suptitle='Left Eye'
)

In [None]:
# Same layout without similarity data: heatmaps=None makes
# plot_latents_with_heatmaps overlay an all-zero map on every panel.
plot_latents_with_heatmaps(
    latents=[latents[0]] + latents,
    heatmaps=None,
    col_titles=['Anchor'] + [f'Seed {s}' for s in seeds],
    latent_channel_to_plot=7,
    suptitle='Left Eye'
)

In [None]:
# Sweep every stored gradient timestep and overlay the softmax-normalised
# similarity to the 'Left Eye' anchor on each seed's latent.

# The anchor marker does not depend on the timestep, so build it once.
anchor_marker = convert_one_hot_to_heatmap(anchor_points['Left Eye'], latents[0], radius=3)

# Iterate over the 'grad' stack itself: it is the tensor compute_dift_heatmap indexes.
for grad_timestep in range(tensors[seeds[0]]['grad'].shape[0]):
    cossim, softmax_cossim = compute_dift_heatmap(tensors, grad_timestep, anchor_points, 'Left Eye')
    plot_latents_with_heatmaps(
        latents=[latents[0]] + latents,
        heatmaps=[anchor_marker] + softmax_cossim,
        col_titles=['Anchor'] + [f'Seed {s}' for s in seeds],
        latent_channel_to_plot=7,
        suptitle=f'Using gradients at timestep {grad_timestep} for DIFT'
    )

In [None]:
# Recompute the similarity maps from the gradients at index 33.
cossim, softmax_cossim = compute_dift_heatmap(tensors, 33, anchor_points, 'Left Eye')

In [None]:
# Distribution of all gradient values for seed 23 at index 33, flattened to 1-D.
grad_33 = tensors[23]['grad'][33].view(-1).numpy()

fig, ax = plt.subplots(figsize=(8,4))
ax.hist(grad_33, bins=100)
ax.set_title('Histogram of gradients at timestep 33 for seed 23')
ax.set_xlabel('Gradient value')
ax.set_ylabel('Frequency')
plt.show()