In [177]:
import gc
import math
import os
import re
import shutil
from textwrap import wrap

import matplotlib.animation as animation
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from IPython.display import HTML
from matplotlib import rc
from matplotlib import rcParams
from tqdm import tqdm

In [178]:
def read_tensors_aux(dir, pt_name_predicate=None):
    """Recursively load every ``.pt`` file below ``dir`` into a nested dict.

    Args:
        dir: folder to scan.
        pt_name_predicate: optional ``f(name_without_ext, folder) -> bool``; only files for
            which it returns True are loaded. ``None`` accepts everything.

    Returns:
        dict mapping each accepted file's name (without ``.pt``) to the loaded object, and each
        sub-folder to a nested dict built the same way. Sub-folders named ``timestep_<x>`` are
        keyed by ``<x>`` (the second ``_``-separated field); other sub-folders keep their name.
        Files that fail to load, and entries that are neither ``.pt`` files nor folders, are
        reported with ``print`` and left out.
    """
    accept = pt_name_predicate if pt_name_predicate is not None else (lambda name, folder: True)
    loaded = {}
    for entry in os.listdir(dir):
        full_path = os.path.join(dir, entry)
        if entry.endswith('.pt'):
            stem = entry[:-3]
            if not accept(stem, dir):
                continue
            try:
                loaded[stem] = torch.load(full_path)
            except Exception as err:
                print(f"Error loading {entry}: {err}")
        elif os.path.isdir(full_path):
            sub_key = entry.split('_')[1] if entry.startswith('timestep_') else entry
            loaded[sub_key] = read_tensors_aux(full_path, pt_name_predicate)
        else:
            print(f"Skipping {entry}, not a tensor or directory")
    return loaded

In [179]:
# Root folder of the DIFT generation runs (one sub-directory per run).
# NOTE: `dir` shadows the builtin; the name is kept because later cells read this global.
dir = 'tensors/outputs/sd3.5_medium/dift'
# os.listdir returns entries in arbitrary order; sort so the run order (and the run picked
# by `[0]` when several names match in a later cell) is deterministic.
gen_experiments = sorted(os.listdir(dir))

In [181]:
# List every run with its seed, which is the last '_'-separated field of the run name.
for exp in gen_experiments:
    exp_dir = os.path.join(dir, exp, '000000')  # per-run output folder (re-derived in later cells)
    seed_field = exp.rsplit('_', 1)[-1]
    seed = int(seed_field)
    print(f"Experiment: {exp}, Seed: {seed}")

Experiment: 20251019_051518_dift-01_-.h*.Lskip-.Lresgate*.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_15, Seed: 15
Experiment: 20251019_050908_dift-01_-.h*.Lskip-.Lresgate*.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_11, Seed: 11
Experiment: 20251019_051205_dift-01_-.h*.Lskip-.Lresgate*.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_13, Seed: 13
Experiment: 20251019_051640_dift-01_-.h*.Lskip-.Lresgate*.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_16, Seed: 16
Experiment: 20251019_051138_dift-02_-.h*.Lskip-.Lresgate*.*_A_trapeze_duo_swapping_bars_mid-air_while_twisting_12, Seed: 12
Experiment: 20251019_051339_dift-01_-.h*.Lskip-.Lresgate*.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_14, Seed: 14
Experiment: 20251019_051628_dift-02_-.h*.Lskip-.Lresgate*.*_A_trapeze_duo_swapping_bars_mid-air_while_twisting_15, Seed: 15
Experiment: 20251019_051034_dift-01_-.h*.Lskip-.Lresgate*.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_12, Seed: 12
Experime

In [None]:
# `exp_short_name` must be set by hand to a substring identifying one run in `gen_experiments`;
# no visible cell of this notebook defines it.
if 'exp_short_name' not in globals():
    raise NameError("Set `exp_short_name` to a substring of one of the run names in `gen_experiments` before running this cell.")
# Sorted so the chosen run is deterministic when several runs match.
matching_experiments = sorted(e for e in gen_experiments if exp_short_name in e)
if not matching_experiments:
    raise ValueError(f"No experiment in {dir!r} matches {exp_short_name!r}")
if len(matching_experiments) > 1:
    print(f"{len(matching_experiments)} experiments match {exp_short_name!r}; using {matching_experiments[0]}")
exp_full_name = matching_experiments[0]
exp_dir = os.path.join(dir, exp_full_name, '000000')
tensors = read_tensors_aux(exp_dir)

In [None]:
# Shapes at t=1000 of: x_t, pos_out (conditioned output), neg_out (unconditioned output), x_grad.
tuple(tensors[f'{name}_t=1000'].shape for name in ('x', 'pos_out', 'neg_out', 'x_grad'))

In [None]:
def _stack_descending(prefix):
    """Stack ``tensors['<prefix>_t=TTTT']`` for t = 1000 down to 0 (present keys only).

    Returns a numpy array with the timestep axis first and the tensors' first axis moved last
    (``permute(0, 2, 3, 1)``; assumes each per-timestep tensor is 3-D, presumably (C, H, W)).
    """
    wanted = (f'{prefix}_t={step:04d}' for step in range(1000, -1, -1))
    frames = [tensors[key] for key in wanted if key in tensors]
    return torch.stack(frames).permute(0, 2, 3, 1).cpu().numpy()


grad_t, x_t, pos_out_t, neg_out_t = (_stack_descending(p) for p in ('x_grad', 'x', 'pos_out', 'neg_out'))

# max pool over channels: largest absolute gradient along the last (channel) axis
grad_t_max = np.abs(grad_t).max(axis=-1)

In [None]:
# grad_t_max.shape is (T=50, W=128, H=128)
# plot a grid of images of grad_t_max
def plot_grid(tensors, suptitle, title_fn=lambda x: f't={x}', nrows=5, ncols=10, figsize=(20, 10), cmap='viridis', normalize=False, save_path=None, title_prefix=None):
    """Plot the first ``nrows * ncols`` slices of ``tensors`` (along axis 0) as an image grid.

    Args:
        tensors: array indexed along axis 0; each ``tensors[i]`` is passed to ``imshow``.
        suptitle: figure-level title.
        title_fn: maps a panel index to that panel's title (default ``'t=<index>'``).
        nrows, ncols: grid size; slices beyond ``nrows * ncols`` are not drawn.
        figsize: matplotlib figure size in inches.
        cmap: colormap forwarded to ``imshow``.
        normalize: if True, min-max scale the whole array to [0, 1] (skipped when constant).
        save_path: if given, save to ``visualizations/<save_path>`` instead of showing.
        title_prefix: optional shortcut; when set, panel titles are ``f'{title_prefix}{index}'``
            and ``title_fn`` is ignored.
    """
    if title_prefix is not None:
        title_fn = lambda idx: f'{title_prefix}{idx}'

    if normalize:
        lo, hi = tensors.min(), tensors.max()
        if hi > lo:
            tensors = (tensors - lo) / (hi - lo)

    # squeeze=False keeps `ax` 2-D even when nrows or ncols is 1
    fig, ax = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i in range(nrows):
        for j in range(ncols):
            t = i * ncols + j
            if t < tensors.shape[0]:
                ax[i, j].imshow(tensors[t], cmap=cmap)
                ax[i, j].set_title(title_fn(t))
            ax[i, j].axis('off')
    plt.tight_layout()
    plt.suptitle(suptitle, fontsize=40, y=1.05)
    if save_path is not None:
        out_path = f'visualizations/{save_path}'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, bbox_inches='tight')
    else:
        plt.show()
    # Close explicitly: callers invoke this in loops (e.g. once per timestep), and figures
    # that are saved but never shown would otherwise stay open.
    plt.close(fig)

In [None]:
# Per-timestep saliency (max |grad| over channels), min-max normalized over the whole stack.
plot_grid(grad_t_max, suptitle=f'Conditioning Saliency Maps (experiment: {exp_short_name})', normalize=True)

In [None]:
# Channel 7 of x_t at every stacked timestep.
plot_grid(x_t[..., 7], 'x_t (channel 7)', normalize=True)

In [None]:
# Every channel of x_t at stacked index 49 (presumably the final step if T=50), one panel per channel.
# plot_grid has no `title_prefix` parameter here by default; label panels through `title_fn`.
plot_grid(x_t[49, :, :, :].transpose(2, 0, 1), 'x_t at t=49 (all channels)', title_fn=lambda c: f'c={c}', nrows=2, ncols=8, figsize=(20, 10), normalize=True)

In [None]:
# Channel 7 of the conditioned model output at every stacked timestep.
plot_grid(pos_out_t[..., 7], 'Conditioned Model Output (channel 7)', normalize=True)

In [None]:
# Channel 7 of the unconditioned model output at every stacked timestep.
plot_grid(neg_out_t[..., 7], 'Unconditioned Model Output (channel 7)', normalize=True)

In [None]:
# Run folders under outputs/sd3.5_medium whose name contains 'dataset-01'.
files = [f for f in os.listdir('outputs/sd3.5_medium') if 'dataset-01' in f]

In [None]:
# Copy each dataset-01 run's 000000.png into dataset_images/, renamed to the run's first 16
# characters followed by name[27:-3]. shutil.copy replaces a shell `cp`: run names in this
# notebook contain the glob character '*', and os.system silently ignored failures.
os.makedirs('dataset_images', exist_ok=True)
for f in files:
    shutil.copy(
        os.path.join('outputs/sd3.5_medium', f, '000000.png'),
        f'dataset_images/{f[:16]}{f[27:-3]}.png',
    )

In [None]:
# First 16 characters of the first run name (the prefix used for the copied file names above)
files[0][:16]

In [None]:
def compute_saliency_mass(grads):
    """
        Total absolute saliency per timestep.

        grads: tensor or numpy array of shape (T, W, H) (any extra trailing dims are summed too),
               or a list/tuple of T tensors/arrays of shape (W, H)
        Returns: numpy array of shape (T,) holding sum(|grads[t]|) for each timestep t
                 (an empty array for an empty list)
        Raises: ValueError if the input has fewer than 2 dimensions.
    """
    if isinstance(grads, (list, tuple)):
        if len(grads) == 0:
            return np.zeros(0)
        # torch.tensor() cannot build a tensor from a list of multi-element tensors; stack instead
        grads = torch.stack([torch.as_tensor(g) for g in grads])
    else:
        grads = torch.as_tensor(grads)  # also accept numpy arrays
    if grads.dim() < 2:
        raise ValueError(f"expected grads with shape (T, ...), got {tuple(grads.shape)}")
    # Flatten everything after the timestep axis so the result is always (T,)
    mass = grads.abs().flatten(start_dim=1).sum(dim=1)
    return mass.detach().cpu().numpy()

In [None]:
def compute_metric_in_stream(base_dir='tensors/outputs/sd3.5_medium', dir_perdicate=lambda x: True, tensor_predicate=lambda x, d: x.startswith('x_grad_t='), metric_fn=lambda x: np.mean(np.array(x))):
    """Compute one metric value per saved tensor for every run under ``base_dir``.

    Each run directory ``base_dir/<run>`` accepted by ``dir_perdicate`` is read from its
    ``000000`` sub-folder with ``read_tensors_aux`` (filtered by ``tensor_predicate``).
    Each loaded tensor keyed ``<tensor_name>=<timestep>`` yields one record; nested
    sub-folder dicts and keys without ``'='`` are skipped.

    Args:
        base_dir: folder containing one sub-directory per run.
        dir_perdicate: ``f(run_name) -> bool`` selecting which runs to process.
        tensor_predicate: ``f(tensor_name, folder) -> bool`` forwarded to ``read_tensors_aux``.
        metric_fn: called with the tensor plus a leading axis of size 1. It may return a scalar,
            or a length-1 array/list/tuple (its first element is used).

    Returns:
        list of dicts with keys ``path``, ``tensor_name``, ``timestep`` (int) and ``value``.
    """
    metrics = []
    # List `base_dir` itself rather than the notebook-global `dir`.
    for path in os.listdir(base_dir):
        if not dir_perdicate(path):
            continue
        tensors = read_tensors_aux(os.path.join(base_dir, path, '000000'), tensor_predicate)
        for k, v in tensors.items():
            if isinstance(v, dict) or '=' not in k:
                continue
            tensor_name, timestep = k.rsplit('=', 1)
            value = metric_fn(v[np.newaxis, ...])
            # np.mean-style metrics return a scalar; per-batch metrics return a length-1 array
            if isinstance(value, (list, tuple)) or getattr(value, 'ndim', 0) > 0:
                value = value[0]
            metrics.append({
                'path': path,
                'tensor_name': tensor_name,
                'timestep': int(timestep),
                'value': value,
            })
    return metrics


In [None]:
# Raw saliency-mass records (list of dicts) for every run.
compute_metric_in_stream(
    metric_fn=compute_saliency_mass,
)

In [None]:
# Same records as a DataFrame (recomputed from disk).
pd.DataFrame(data=compute_metric_in_stream(metric_fn=compute_saliency_mass))

In [None]:
# Load precomputed values, one row per (path, tensor_name, metric_name, timestep).
df = pd.read_csv('tensor_metrics_df.csv')
# Run names look like '20251002_022502_dataset-03_-.*_<escaped prompt>_<seed>': the first 31
# characters are the timestamp/tag/pattern prefix, and the seed follows the last '_'.
df['escaped_prompt'] = df['path'].map(lambda p: p[31:p.rfind('_')])
df['seed'] = df['path'].map(lambda p: int(p[p.rfind('_') + 1:]))
df = df.sort_values(by='path')

# Escape the labelled prompts the same way (characters outside [\w-.] -> '_', first 50 chars)
# so they can be joined on `escaped_prompt`; presumably mirrors how run names were built — TODO confirm.
labels_df = pd.read_excel('hallucinations_dataset.xlsx')
labels_df['escaped_prompt'] = labels_df['prompt'].map(lambda s: re.sub(r"[^\w\-\.]", "_", s)[:50])

# Inner join: only runs whose prompt appears in the label sheet are kept.
df = pd.merge(df, labels_df, on='escaped_prompt', how='inner')

# normalized_value = value / value at timestep 1000 for the same (path, tensor_name, metric_name).
# The (inner) merge drops groups that have no timestep-1000 row.
values_at_1000 = df.loc[df['timestep'] == 1000, ['path', 'tensor_name', 'value', 'metric_name']]
df = df.merge(values_at_1000, on=['path', 'tensor_name', 'metric_name'], suffixes=('', '_at_1000'))
df['normalized_value'] = df['value'].div(df['value_at_1000'])

In [None]:
# Baseline rows (timestep 1000) used for normalization.
df.loc[df['timestep'].eq(1000)]

In [None]:
label_name = 'hallucinations'  # or 'coherence'
# One figure per metric: normalized value over timesteps, comparing label levels 0 and 2 only.
for metric in ['var_on_diff', 'var_on_abs_diff', 'saliency_mass', 'l2', 'l1', 'max', 'mean', 'var']:
    subset = df[(df['metric_name'] == metric) & (df[label_name].isin([0, 2]))]
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=subset, x='timestep', y='normalized_value', hue=label_name, markers=True, dashes=False, palette='tab10', ax=ax)
    ax.set_yscale('log')
    ax.set_title(f'{metric.capitalize()} over Timesteps')
    ax.set_xlabel('Timestep')
    ax.set_xlim(1000, 800)  # inverted so t=1000 is on the left
    # The plotted column is normalized_value (value / value at t=1000), not the raw metric.
    ax.set_ylabel(f'{metric.capitalize()} (normalized to t=1000)')
    ax.legend(title=f'{label_name.capitalize()} level', bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()

In [None]:
# Hand-picked seed-23 runs for qualitative comparison: display name -> run directory name.
examples = dict([
    ('Hallucination (gymnastics)', '20251002_022502_dataset-03_-.*_A_trapeze_duo_swapping_bars_mid-air_while_twisting_23'),
    ('Hallucination (fidget spinner)', '20251001_180741_dataset-03_-.*_A_hand_holding_a_yellow_fidget_spinner._The_hand_i_23'),
    ('Normal (dog with ball)', '20251001_170130_dataset-03_-.*_A_dog_playing_with_an_orange_ball_with_blue_stripe_23'),
    ('Normal (walking in field)', '20251001_175612_dataset-03_-.*_A_cinematic_shot_of_a_person_walking_along_a_quiet_23'),
])

In [None]:
# Full prompts of the hand-picked example runs.
df.loc[df['path'].isin(list(examples.values())), 'prompt'].unique()

In [None]:
def read_concat(dir, tensor_predicate=None):
    """Load the tensors in ``dir`` and stack them into one numpy array, ordered by name descending.

    For zero-padded ``'<name>_t=TTTT'`` names this puts t=1000 first. Sub-folders (returned by
    ``read_tensors_aux`` as nested dicts) are skipped.

    Args:
        dir: folder passed to ``read_tensors_aux``.
        tensor_predicate: optional ``f(tensor_name, folder) -> bool`` filter.

    Returns:
        numpy array of shape (N, *tensor_shape), or an empty array when nothing matched.
    """
    tensors = read_tensors_aux(dir, tensor_predicate)
    ks = sorted((k for k, v in tensors.items() if not isinstance(v, dict)), reverse=True)
    if not ks:
        return np.array([])
    # torch.stack + .cpu() also handles tensors saved on a GPU, which np.array() cannot convert
    return torch.stack([torch.as_tensor(tensors[k]) for k in ks]).detach().cpu().numpy()

In [None]:
# Saliency stacks (descending timestep) for each hand-picked example run.
# NOTE(review): `dir` is the DIFT run folder set at the top of the notebook, while `examples`
# names dataset-03 runs; confirm those runs actually live under `dir`.
examples_tensors = {}
for name, path in examples.items():
    examples_tensors[name] = read_concat(os.path.join(dir, path, '000000'), lambda x, d: x.startswith('x_grad_t='))

In [None]:
# Display names of the loaded example runs (the keys of `examples`)
examples_tensors.keys()

In [None]:
# One saliency grid per hand-picked example, each min-max normalized on its own.
for title, saliency in examples_tensors.items():
    plot_grid(saliency, f"{title}", normalize=True)

In [None]:
# Variance along dim 1 (across rows) of each raw saliency stack: one (timestep, column) map per
# example (assumes each stack is 3-D, timestep first).
col_var = {name: torch.var(torch.tensor(t), dim=1).cpu().numpy() for name, t in examples_tensors.items()}

fig = plt.figure(figsize=(16, 12))
fig.suptitle('Variance over Columns of Saliency Maps', fontsize=25)

for idx, (name, var) in enumerate(col_var.items(), start=1):
    panel = fig.add_subplot(2, 2, idx)
    image = panel.imshow(var, aspect='auto', cmap='viridis')
    fig.colorbar(image, ax=panel, label='Variance')
    panel.set_title(name)
    panel.set_xlabel('Width (W)')
    panel.set_ylabel('Timestep')

fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
# Variance along dim 2 (across columns, within each row) of |saliency|: one (timestep, row) map
# per example (assumes each stack is 3-D, timestep first).
row_var = {name: torch.var(torch.abs(torch.tensor(t)), dim=2).cpu().numpy() for name, t in examples_tensors.items()}

fig = plt.figure(figsize=(16, 12))
fig.suptitle('Variance per Row of Saliency Maps', fontsize=25)

for i, (name, var) in enumerate(row_var.items()):
    ax = fig.add_subplot(2, 2, i + 1)
    im = ax.imshow(var, aspect='auto', cmap='viridis')
    fig.colorbar(im, ax=ax, label='Variance')
    ax.set_title(name)
    # dim 2 was reduced away, so the remaining horizontal axis indexes rows, not width
    ax.set_xlabel('Row (dim 1 of the map)')
    ax.set_ylabel('Timestep')

fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
# Variance along dim 1 (across rows) of |saliency|: one (timestep, column) map per example
# (assumes each stack is 3-D, timestep first).
col_var = {name: torch.var(torch.abs(torch.tensor(t)), dim=1).cpu().numpy() for name, t in examples_tensors.items()}

fig = plt.figure(figsize=(16, 12))
fig.suptitle('Variance over Columns of Saliency Maps', fontsize=25)

for idx, (name, var) in enumerate(col_var.items(), start=1):
    panel = fig.add_subplot(2, 2, idx)
    image = panel.imshow(var, aspect='auto', cmap='viridis')
    fig.colorbar(image, ax=panel, label='Variance')
    panel.set_title(name)
    panel.set_xlabel('Width (W)')
    panel.set_ylabel('Timestep')

fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
# Variance along dim 2 (across columns, within each row) of |saliency|: one (timestep, row) map
# per example (assumes each stack is 3-D, timestep first).
row_var = {name: torch.var(torch.abs(torch.tensor(t)), dim=2).cpu().numpy() for name, t in examples_tensors.items()}

fig = plt.figure(figsize=(16, 12))
fig.suptitle('Variance per Row of Saliency Maps', fontsize=25)

for i, (name, var) in enumerate(row_var.items()):
    ax = fig.add_subplot(2, 2, i + 1)
    im = ax.imshow(var, aspect='auto', cmap='viridis')
    fig.colorbar(im, ax=ax, label='Variance')
    ax.set_title(name)
    # dim 2 was reduced away, so the remaining horizontal axis indexes rows, not width
    ax.set_xlabel('Row (dim 1 of the map)')
    ax.set_ylabel('Timestep')

fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
# Temporal variability: |difference between consecutive stacked steps| for the first 20
# differences, then variance over those steps -> one spatial (row, column) map per example.
time_var = {name: torch.var(torch.abs(torch.diff(torch.tensor(t), dim=0)[:20, :, :]), dim=0).cpu().numpy() for name, t in examples_tensors.items()}

fig = plt.figure(figsize=(16, 12))
fig.suptitle('Variance over Time of |diff(Saliency)| (first 20 steps)', fontsize=25)

for i, (name, var) in enumerate(time_var.items()):
    ax = fig.add_subplot(2, 2, i + 1)
    im = ax.imshow(var, aspect='auto', cmap='viridis')
    fig.colorbar(im, ax=ax, label='Variance')
    ax.set_title(name)
    # The timestep axis was reduced away; what remains is the map's own (row, column) grid.
    ax.set_xlabel('Column')
    ax.set_ylabel('Row')

fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
def add_short_name(df, examples):
    """Add a ``'short_name'`` column holding each row's display name from ``examples``.

    ``examples`` maps display name -> run path. Rows whose ``path`` is not one of those runs
    get ``'Unknown'``. ``df`` is modified in place and also returned.
    """
    name_by_path = dict(zip(examples.values(), examples.keys()))
    df['short_name'] = df['path'].map(lambda p: name_by_path.get(p, 'Unknown'))
    return df

In [None]:
label_name = 'hallucinations'  # or 'coherence'
filenames = list(examples.values())
# Lines are coloured by 'short_name', which add_short_name creates. Add it here, so this cell
# does not depend on the later add_short_name cell having run first. That call only
# overwrites the column, so running it twice is harmless.
df = add_short_name(df, examples)
for metric in ['var_on_diff', 'var_on_abs_diff', 'saliency_mass', 'l2', 'l1', 'max', 'mean', 'var']:
    subset = df[(df['metric_name'] == metric) & (df[label_name].isin([0, 2])) & (df['path'].isin(filenames))]
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=subset, x='timestep', y='normalized_value', hue='short_name', markers=True, dashes=False, palette='tab10', ax=ax)
    ax.set_yscale('log')
    ax.set_title(f'{metric.capitalize()} over Timesteps')
    ax.set_xlabel('Timestep')
    ax.set_xlim(1000, 800)  # inverted so t=1000 is on the left
    ax.set_ylabel(f'{metric.capitalize()} (normalized to t=1000)')
    # hue is the example run, not the label level
    ax.legend(title='Example', loc='lower center')
    fig.tight_layout()
    plt.show()

In [None]:
# Label rows belonging to the hand-picked example runs (others get 'Unknown').
df = add_short_name(df=df, examples=examples)


In [None]:
# Rows behind the per-example plot for one metric. The metric is named explicitly instead of
# relying on the loop variable `metric` left over from the plotting loop (its last value, 'var').
inspect_metric = 'var'
df[(df['metric_name'] == inspect_metric) & (df[label_name].isin([0, 2])) & (df['path'].isin(list(examples.values())))]

In [None]:
dog_dir = 'tensors/outputs/sd3.5_medium/dog/'

# head id -> run directory. The id is the text from two characters after the first '.' up to
# the next '.' of the run name (e.g. '*' in a name containing '-.h*.L'); '*' is stored as -1.
dog_files = {}
for run_name in os.listdir(dog_dir):
    start = run_name.index('.') + 2
    head_label = run_name[start:start + run_name[start:].index('.')]
    dog_files[int(head_label if head_label != '*' else -1)] = run_name

# head id -> that run's '000000' tensors stacked in descending key order
dog_tensors = {}
for head_id, run_name in dog_files.items():
    per_key = read_tensors_aux(dog_dir + run_name)['000000']
    dog_tensors[head_id] = torch.stack([per_key[k] for k in sorted(per_key, reverse=True)])
# single tensor with head runs along axis 0, sorted by id (so -1 / '*' comes first)
dog_tensors = torch.stack([dog_tensors[head_id] for head_id in sorted(dog_tensors)])

In [None]:
# Axis 0: head runs in sorted head-id order (-1, i.e. '*', first);
# axis 1: each run's tensors in descending key order; remaining axes are the tensors' own.
dog_tensors.shape

In [None]:
# Axis 0 of dog_tensors follows sorted(dog_files) (-1 = all heads first). Label each grid with
# the real head id rather than its position: the position is never negative, so the original
# labels never showed '*' and were off by one.
for h, head_id in enumerate(sorted(dog_files)):
    plot_grid(dog_tensors[h].cpu().numpy(), f"Dog experiment, head={head_id if head_id >= 0 else '*'}", normalize=True)

In [None]:
# One saved grid per stacked timestep index, one panel per head run. Panel titles use the real
# head ids (axis 0 of dog_tensors follows sorted(dog_files), with -1 meaning all heads).
head_ids = sorted(dog_files)
for t in range(dog_tensors.shape[1]):
    plot_grid(
        dog_tensors[:, t, :, :].cpu().numpy(),
        f"Dog experiment, timestep={t}",
        title_fn=lambda x: f'head={head_ids[x]}' if head_ids[x] >= 0 else 'all heads',
        nrows=5, ncols=5, figsize=(20, 15), normalize=False,
        save_path=f'attn_heads/dog_not_normalized/timestep_{t:02d}.png',
    )
    plt.close('all')  # saved figures are never shown; close them so they don't accumulate

In [None]:
def save_video_from_dir(dir, save_dir=None, fps=5):
    """Assemble the ``.png`` frames in ``dir`` (in sorted file-name order) into ``animation.mp4``.

    Args:
        dir: folder containing the PNG frames.
        save_dir: output folder (created if missing); defaults to ``dir``.
        fps: frames per second of the output video.

    Raises:
        ValueError: if ``dir`` contains no ``.png`` files.

    Uses matplotlib's ``'ffmpeg'`` writer, which must be available.
    """
    frame_files = [name for name in sorted(os.listdir(dir)) if name.endswith('.png')]
    if not frame_files:
        raise ValueError(f"No .png frames found in {dir!r}")
    images = [plt.imread(os.path.join(dir, name)) for name in frame_files]

    fig, ax = plt.subplots()
    ax.axis('off')
    ims = [[ax.imshow(img, animated=True)] for img in images]
    ani = animation.ArtistAnimation(fig, ims, interval=1000 / fps, blit=True, repeat_delay=1000)
    fig.tight_layout()
    plt.close(fig)
    if save_dir is None:
        save_dir = dir
    os.makedirs(save_dir, exist_ok=True)
    ani.save(os.path.join(save_dir, 'animation.mp4'), writer='ffmpeg', fps=fps)

In [None]:
# Turn both per-timestep frame folders into videos (saved next to the frames).
for variant in ('dog_normalized', 'dog_not_normalized'):
    save_video_from_dir(f'visualizations/attn_heads/{variant}', fps=5)