In [None]:
import argparse
import datetime
import deepspeed
import numpy as np
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import json
import os
from functools import partial
from pathlib import Path
from collections import OrderedDict

from dataset.datasets import build_dataset
import models.ast_clip_cast
from run_bidirection_compo import get_args
from timm.models import create_model
import util_tools.utils as utils

def main(args, ds_init):
    """Build the validation dataset and load a fine-tuned model for inspection.

    Side effects on *args*: ``audio_path`` is set to ``None`` when
    ``ucf101_type`` contains ``'all'``; ``nb_classes`` is overwritten with the
    value returned by ``build_dataset``; ``window_size`` (16) and
    ``patch_size`` (14) are set.

    Args:
        args: options parsed by ``get_args()``.
        ds_init: second value returned by ``get_args()``; not used here.

    Returns:
        ``(args, model, dataset_val)``: the mutated *args*; the model built
        with ``create_model``, moved to ``args.device`` and strictly loaded
        from ``checkpoint['module']`` of ``args.fine_tune``; and the
        validation dataset.
    """
    device = torch.device(args.device)

    # No audio input when the UCF101 split type contains 'all'.
    args.audio_path = None if 'all' in args.ucf101_type else args.audio_path

    # Per-rank seeding of torch and numpy.
    rank_seed = args.seed + utils.get_rank()
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)

    dataset_val, args.nb_classes = build_dataset(is_train=False, test_mode=False, args=args)

    args.window_size = 16
    args.patch_size = 14
    print(f"Patch size = {args.patch_size}")

    window_area = args.window_size * args.window_size
    model_args = dict(
        model_name=args.vmae_model,
        pretrained=False,
        num_classes=args.nb_classes,
        all_frames=args.num_frames * args.num_segments,
        tubelet_size=args.tubelet_size,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        fusion_method=args.fusion_method,
    )

    if args.audio_path is not None:
        audio_patch = args.audio_height * args.audio_width // window_area
        print(f"Audio_Patch size = {audio_patch}")
        model_args['audio_patch'] = audio_patch

    if args.bcast_method is not None:
        print(f"bcast_method = {args.bcast_method}")
        model_args['bcast_method'] = args.bcast_method

    if args.time_encoding:
        model_args['time_encoding'] = args.time_encoding
        model_args['spec_shape'] = [args.audio_height // args.window_size,
                                    args.audio_width // args.window_size]

    if args.audio_only_finetune:
        model_args['audio_only_finetune'] = True

    if '_ast_' in args.vmae_model:
        # AST-style audio branch: patch strides and input spectrogram size.
        model_args.update(
            fstride=args.stride,
            tstride=args.stride,
            input_fdim=args.audio_height,
            input_tdim=args.audio_width,
        )

    if args.enable_audio_stride:
        # Token grid for 16x16 patches with a stride of 10 along both axes.
        # NOTE(review): only 'fstride' is overridden here; 'tstride' keeps
        # args.stride even though tdim assumes a time stride of 10 -- confirm.
        fdim = int((args.audio_height - 16) / 10) + 1
        tdim = int((args.audio_width - 16) / 10) + 1
        model_args['fstride'] = 10
        model_args['audio_patch'] = fdim * tdim
        model_args['spec_shape'] = [fdim, tdim]

    # NOTE(review): forwards use_stpos=False when the flag value is False;
    # whether that matches "--not_use_stpos" depends on how get_args()
    # declares it (store_true vs store_false) -- confirm.
    if args.not_use_stpos == False:
        model_args['use_stpos'] = args.not_use_stpos
    if args.pre_time_encoding == True:
        model_args['pre_time_encoding'] = args.pre_time_encoding
    if args.split_time_mlp == True:
        model_args['split_time_mlp'] = args.split_time_mlp
    model_args['use_Adapter'] = True

    model = create_model(**model_args)
    model.to(device)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model = {model}")
    print('number of params:', trainable)

    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which may reject DeepSpeed checkpoints that store
    # non-tensor objects -- confirm against the installed version.
    state = torch.load(args.fine_tune, map_location=device)
    model.load_state_dict(state['module'], strict=True)

    return args, model, dataset_val


# Dataset selector: 1 -> EPIC-Sounds, anything else -> VGGSound.
data_sel = 2
if data_sel == 1:
    # EPIC-Sounds: separate video / wav folders, 44 classes, 400-wide spectrograms.
    ANNO_PATH, DATA_PATH, AUDIO_PATH = (
        '/data/joohyun7u/project/CAST/dataset/epic_sounds',
        '/local_datasets/epic_sounds/video',
        '/local_datasets/epic_sounds/wav',
    )
    DATA_SET, CLASS, WIDTH = 'EPIC_sounds', '44', '400'
    VMAE_PATH = '/data/datasets/Epickitchens100_clips/epic_checkpoint-2400.pth'
else:
    # VGGSound: audio and video share one folder, 309 classes, 1024-wide spectrograms.
    ANNO_PATH = '/data/joohyun7u/project/CAST/dataset/vggsound'
    DATA_PATH = AUDIO_PATH = '/local_datasets/vggsound/resize256/'
    DATA_SET, CLASS, WIDTH = 'VGGSound', '309', '1024'
    VMAE_PATH = '/data/dataset/epic_audio/vit_b_hybrid_pt_800e.pth'

# Shared by both datasets.
CLIP_PATH = '/data/datasets/Epickitchens100_clips/ViT-B-16.pt'
OUTPUT_DIR = '/data/joohyun7u/project/CAST/xlog'
    
# Command-line tokens fed to get_args() via sys.argv (see the argv swap below).
# Flag/value order matters: each flag is followed by its value(s).
# NOTE(review): many options here (optimizer, lr, epochs, mixup, ...) look
# training-only; they are presumably kept because get_args() is the training
# parser -- confirm they do not affect this inspection run.
input_args = [
'--data_set', DATA_SET,
'--nb_classes', CLASS,
'--data_path', DATA_PATH,
'--anno_path', ANNO_PATH,
'--vmae_finetune', VMAE_PATH,
'--clip_finetune', CLIP_PATH,
'--log_dir', OUTPUT_DIR,
'--output_dir', OUTPUT_DIR,
'--batch_size', '6',
'--input_size', '224',
'--short_side_size', '224',
'--save_ckpt_freq', '10',
'--num_sample', '1',
'--num_frames', '16',
'--opt', 'adamw',
'--lr', '5e-4',
'--opt_betas', '0.9', '0.999',
'--layer_decay', '0.8',
'--weight_decay', '0.05',
'--epochs', '50',
'--dist_eval',
'--test_num_segment', '2',
'--test_num_crop', '3',
'--num_workers', '16',
'--seed', '0',
'--warmup_epochs', '5',
'--enable_deepspeed',
'--reprob', '0.',
'--init_scale', '1.',
'--unfreeze_layers', 'cross', 'clip_temporal_embedding', 'space_time_pos', 'vmae_fc_norm', 'last_proj', 'Adapter', 'ln_post', 'head', 'concat_head', 'clip_text_positional_embedding', 'time_mlp',
'--update_freq', '2',
'--drop_path', '0.2',
'--head_drop', '0.',
'--cutmix', '0.',
'--mixup_switch_prob', '0',
'--mixup_prob', '0.9',
'--device','cuda',
]
# Audio options used when data_sel == 1 (EPIC-Sounds); already includes the
# time-encoding / split-time-MLP / no-stpos flags.
audio_args = [
'--vmae_model', 'single_ast_clip_vit_base_patch16_224',
'--audio_path', AUDIO_PATH,
'--audio_type', 'single',
'--realtime_audio',
'--audio_height', '128',
'--audio_width', WIDTH,
# '--spec_augment',
'--specnorm',
'--process_type', 'ast',
'--split_time_mlp',
# '--mask_audio_token', '0.25',
# "--ablation_eval", "missing"
'--time_encoding',
'--not_use_stpos',
]
# Audio options used when data_sel != 1 (VGGSound); the time-encoding flags
# are appended separately when the final argument list is assembled.
datasel2_audio_args = [
'--vmae_model', 'single_ast_clip_vit_base_patch16_224',
'--audio_path', AUDIO_PATH,
'--audio_type', 'single',
'--realtime_audio',
'--audio_height', '128',
'--audio_width', WIDTH,
'--spec_augment',
'--specnorm',
'--process_type', 'ast',
'--mixup_spec',
'--add_noise',
'--spec_cutmix',
# '--mask_audio_token', '0.25',
# "--ablation_eval", "missing"
]

# Assemble the full argument list plus the fine-tuned checkpoint to inspect
# (commented lines are alternative checkpoints).
if data_sel == 1:
    # Attention shows a peak around layer 8.
    time_args = input_args + audio_args + ['--fine_tune','/data/joohyun7u/project/CAST/log4/epic_sounds/EKSound_ASTCLIP_split_time_encoding_v5_MixSpecCut_AddNoise_UF4/OUT/checkpoint-best/mp_rank_00_model_states.pt']
    # time_args = input_args + audio_args + ['--fine_tune','/data/joohyun7u/project/CAST/log4/epic_sounds/EKSound_ASTCLIP_time_encoding_v5_MixSpecCut_AddNoise_UF6/OUT/checkpoint-best/mp_rank_00_model_states.pt']
else:
    time_args = input_args + datasel2_audio_args + ['--time_encoding', '--split_time_mlp', '--not_use_stpos','--fine_tune','/data/joohyun7u/project/CAST/log4/vggsound/VGGSound_ASTCLIP_split_time_encoding_v5_VMHY_MixSpecCut_AddNoise/OUT/checkpoint-best/mp_rank_00_model_states.pt']
    # time_args = input_args + datasel2_audio_args + ['--fine_tune','/data/joohyun7u/project/CAST/log3/vggsound/VGGSound_ASTCLIP_VMHY_MixSpecCut_AddNoise/OUT/checkpoint-best/mp_rank_00_model_states.pt']
    # time_args = input_args + datasel2_audio_args + ['--time_encoding','--fine_tune','/data/joohyun7u/project/CAST/log3/vggsound/VGGSound_ASTCLIP_VMHY_MixSpec_AddNoise/OUT/checkpoint-best/mp_rank_00_model_states.pt']
# Temporarily replace sys.argv so the argparse-based get_args() parses the
# notebook-defined arguments. The original argv is restored in `finally`, so a
# failure in parsing or model construction does not leave the kernel's argv
# pointing at these arguments.
import sys
original_argv = sys.argv.copy()
sys.argv = ['attn_map.py'] + time_args
try:
    opts, ds_init = get_args()
    # NOTE: `time_args` is rebound from the argument list to the parsed options.
    time_args, time_model, data_loader_val = main(opts, ds_init)
finally:
    sys.argv = original_argv

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from util_tools import audio_transforms
import importlib
importlib.reload(audio_transforms)
from util_tools.audio_transforms import Spectrogram
import os, torchaudio
import IPython.display as ipd
# from ipywidgets import Video
from einops import rearrange
    
def tensor_to_img(spectrogram,output=None):
    """Draw a 2-D tensor transposed as a pcolormesh with a colour bar.

    The input is transposed (``'h w -> w h'``) before plotting, so its second
    axis ends up on the y axis. The transposed tensor's shape is shown with
    ``display`` (notebook builtin).

    Args:
        spectrogram: 2-D tensor.
        output: if truthy, the transposed tensor is returned.

    Returns:
        The transposed tensor when *output* is truthy, otherwise ``None``.
    """
    transposed = rearrange(spectrogram, 'h w -> w h')
    mesh = plt.pcolormesh(transposed, shading='auto')
    # NOTE(review): the axes are bin indices, not Hz / seconds, despite the labels.
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time [sec]')
    plt.colorbar(mesh)
    plt.show()
    display(transposed.shape)
    return transposed if output else None
    
def show_wav(sample, sr):
    """Plot a waveform and the FFT magnitude of its first channel.

    Args:
        sample: tensor whose transpose is plotted as the waveform; ``sample[0]``
            is the channel analysed in the frequency domain.
        sr: sample rate used for the frequency axis and the Nyquist limit.
    """
    # Time domain (drawn on the current figure).
    plt.plot(sample.t().numpy())
    plt.title('Waveform')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')
    plt.show()

    # Frequency domain of the first channel.
    first_channel = sample[0].numpy()
    magnitude = np.abs(np.fft.fft(first_channel))
    freqs = np.fft.fftfreq(len(first_channel), 1 / sr)

    plt.figure(figsize=(12, 4))
    plt.plot(freqs, magnitude)
    plt.title('Frequency Domain')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Amplitude')
    plt.xlim(0, sr / 2)  # show only up to the Nyquist frequency
    plt.show()
    
def show_wav_compare(original_sample, noise_reduced_sample=None, sr=None):
    """Plot a waveform, optionally with a second one overlaid, in the top half of a new figure.

    The figure is not shown or closed here. *sr* is accepted but unused.
    Always returns ``None``.
    """
    plt.figure(figsize=(12, 4))
    plt.subplot(2, 1, 1)  # upper panel of a 2x1 grid; the lower one stays empty
    plt.plot(original_sample.t().numpy())
    has_overlay = noise_reduced_sample is not None
    if has_overlay:
        plt.plot(noise_reduced_sample, alpha=1)
    plt.title('Original & Noise Reduced Audio Waveform')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')
    
def display_audio(sample_id=None,
                  audio_path='/data/datasets/epic_audio/wav_split/',
                  video_path='/data/datasets/Epickitchens100_clips/videos',
                  data_set='EPIC',
                  log=False, log_spec=False, preds=None, log_wav=False,
                  output=None):
    """Load an audio file, plot its waveform and embed an audio player.

    For ``data_set == 'EPIC_sounds'`` the audio/video paths are derived from
    *sample_id*; otherwise *audio_path* is used as the file to load.
    ``log_spec``, ``preds`` and ``log_wav`` are accepted but unused.

    Returns:
        The loaded waveform tensor when *output* is truthy, otherwise ``None``.
    """
    if data_set == 'EPIC_sounds':
        audio_path = os.path.join('/data/datasets/epic_sounds/wav/', sample_id + '.wav')
        video_path = os.path.join('/data/datasets/epic_sounds/video/', sample_id + '.mp4')
    else:
        audio_path = os.fspath(audio_path)
    if log:
        print(audio_path,'\n',video_path)

    waveform, sample_rate = torchaudio.load(audio_path)
    show_wav_compare(waveform, sr=sample_rate)
    display(ipd.Audio(waveform, rate=sample_rate))
    if output:
        return waveform
    # Video playback is currently disabled (ipywidgets.Video import is commented out).
    
# Audio token grid from the model for the configured stride and spectrogram
# size; later cells use these as the (af, at) factors when splitting attention.
f_dim, t_dim = time_model.get_shape(time_args.stride, time_args.stride, time_args.audio_height, time_args.audio_width)
# Video time steps in the attention maps (the `vt` factor in later cells).
# NOTE(review): assumes the cross-attention sees num_frames // 2 frames -- confirm.
T = time_args.num_frames//2
f_dim, t_dim  # notebook cell output: show the grid size

In [None]:
# Dataset settings for visualisation. `data_loader_val` is the dataset object
# returned by main(); attribute semantics are defined by that dataset class.
# data_loader_val.spectrogram.noisereduce = True
data_loader_val.spectrogram.noisereduce = False
data_loader_val.sampling_type = 'uniform'
# data_loader_val.sampling_type = 'dense'

In [None]:
time_model.eval()
device='cuda'  # NOTE(review): hard-coded; ignores time_args.device

from models.prompt import dataset_class
# `action['action']` is indexed by class id below to obtain label names.
action = dataset_class(time_args.data_set, time_args.anno_path)

# Candidate samples; exactly one `batch = ...` line should be active.
# data_iter = iter(data_loader_val)
# batch = next(data_iter)
# batch = data_loader_val.get_item_by_index(42)
# batch = data_loader_val.get_item_by_index(922)
# batch = data_loader_val.get_item_by_index(6922)
# batch = data_loader_val.get_item_by_index(4822)
# batch = data_loader_val.get_item_by_index(8000)  # the one used in the slides (PPT)
# batch = data_loader_val.get_item_by_index(2222)
# batch = data_loader_val2.get_item_by_index(1252)
# ek-sound 
# batch = data_loader_val.get_item_by_index(2340)
# batch = data_loader_val.get_item_by_index(1507)
# batch = data_loader_val.get_item_by_index(1326)
# batch = data_loader_val.get_item_by_index(5722) 
# batch = data_loader_val.get_item_by_index(626) # not great

# # VGGSound

# batch = data_loader_val.get_item_by_index(9542)
# batch = data_loader_val.get_item_by_index(7422) # Cat; weak overall, the latter part may be OK
# batch = data_loader_val.get_item_by_index(9569) # Dog
batch = data_loader_val.get_item_by_index(6929)
# batch = data_loader_val.get_item_by_index(8643) # ambiguous

import random

# NOTE(review): Python's `random` is not seeded here (main() seeds only torch
# and numpy), so this index differs between runs. It is only used if one of
# the commented lines below is enabled.
random_idx = random.randint(0, len(data_loader_val) - 1)  # integer index
# batch = data_loader_val.get_item_by_index(random_idx)
# batch = data_loader_val.get_item_by_index(9569)


# Unpack the sample: batch[0] video clip, batch[1] class id, batch[2] sample id,
# batch[3] spectrogram(s), batch[4] captions, batch[5] time indices.
time_idx = torch.tensor(batch[5]).unsqueeze(0)
print(batch[2],time_idx)
samples = batch[0]
target = batch[1]
# NOTE(review): computed before unsqueeze(0), i.e. dim 0 of a single clip
# (likely channels), not a batch size; unused in the visible cells.
batch_size = samples.shape[0]
samples = samples.to(device, non_blocking=True).unsqueeze(0)  # add batch dim
if time_args.audio_path is not None:
    if time_args.collate:
        # NOTE(review): produces a list; later `spec[0,0,:,:]` indexing assumes a tensor.
        spec = [spe.to(device, non_blocking=True).half() for spe in batch[3]]
    else:
        spec = batch[3].to(device, non_blocking=True).unsqueeze(0)
else:
    spec = None
captions = batch[4]
ste = time_args.audio_width // 16  # only used by the commented-out segment code below

# time_idx = time_idx.reshape(time_idx.shape[0],-1,2).to(dtype=samples.dtype, device=samples.device)
# start_end = time_idx[:,0,:]
# linspace = torch.linspace(0, 1, steps=ste+1).to(dtype=samples.dtype, device=samples.device)
# segments = start_end[:, 0:1] + (start_end[:, 1:2] - start_end[:, 0:1]) * linspace[:-1]
# next_segments = start_end[:, 0:1] + (start_end[:, 1:2] - start_end[:, 0:1]) * linspace[1:]
# segments = torch.stack([segments, next_segments], dim=-1).view(time_idx.shape[0], ste, 2)

# EPIC-Sounds stores separate .wav files; otherwise the audio is loaded from
# the sample's .mp4 under audio_path. The video always comes from data_path.
path = os.path.join(time_args.audio_path,
                    batch[2] + ('.wav' if DATA_SET == "EPIC_sounds" else '.mp4'))
video_path = os.path.join(time_args.data_path, batch[2] + '.mp4')

# Transposed spectrogram of the first item (also plotted by tensor_to_img).
audio_spec = tensor_to_img(spec[0,0,:,:].cpu(),output=True)


f_dim, t_dim = time_model.get_shape(time_args.stride, time_args.stride, time_args.audio_height, time_args.audio_width)

# Inference only: without no_grad the returned logits / attention maps keep the
# whole autograd graph (and its activations) alive on the GPU.
with torch.no_grad():
    logits, all_atts = time_model(samples[:1], caption=captions, spec=spec[:1], idx=time_idx[:1], output_attentions=True)

# NOTE: `random_idx` is only a candidate; the shown sample is whichever
# get_item_by_index(...) line is active in the selection cell.
print('random_idx',random_idx)
print('action :', action['action'][batch[1]], '\tpred :', action['action'][torch.argmax(logits).item()])
audio_wav = display_audio(audio_path=path,output=True)
# batch[0] -> permute(1, 2, 3, 0) puts frames first and channels last;
# [::2] keeps every second frame (8 of 16 with the default --num_frames 16).
video_numpy = batch[0].permute(1, 2, 3, 0)[::2,:,:,:].numpy()

# Concatenate the frames horizontally into one [H, n_frames * W, C] strip.
concatenated_frames = np.concatenate(video_numpy, axis=1)

# Min-max normalise for imshow; the epsilon avoids 0/0 on a constant clip.
plt.figure(figsize=(20, 5))
plt.imshow((concatenated_frames - concatenated_frames.min())
           / (concatenated_frames.max() - concatenated_frames.min() + 1e-6))
plt.axis('off')
plt.show()
    
# `all_atts` holds two maps per cross-attention layer, interleaved:
#   even positions -> T2S: audio queries over video keys (e.g. [1, 12, 468, 1568])
#   odd positions  -> S2T: video queries over audio keys (e.g. [1, 12, 1568, 468])
depth = len(all_atts) // 2
t2s_all_attn = list(all_atts[0:2 * depth:2])
s2t_all_attn = list(all_atts[1:2 * depth:2])

In [None]:
import os

def make_save_dir(batch, action, base_dir="results"):
    """Create (if needed) and return the output directory for one sample.

    The folder is named ``<batch[2]>_<label>_<batch[1]>_sampling``, where
    ``label = action['action'][batch[1]]`` with spaces turned into hyphens,
    e.g. ``1234_open-door_17_sampling``. Path separators in the label are also
    replaced by hyphens, so the result is always one directory directly under
    *base_dir*. Without this, a label such as ``"open / close"`` would create
    nested folders.

    Args:
        batch: indexable sample; ``batch[2]`` is the sample id and ``batch[1]``
            the class index (used both for the label lookup and in the name).
        action: mapping whose ``'action'`` entry is indexable by ``batch[1]``.
        base_dir: parent directory; created if missing.

    Returns:
        Path of the (existing) output directory.
    """
    label = action['action'][batch[1]].replace(' ', '-')
    for sep in {'/', os.sep}:
        label = label.replace(sep, '-')
    folder = f"{batch[2]}_{label}_{batch[1]}_sampling"
    save_dir = os.path.join(base_dir, folder)
    os.makedirs(save_dir, exist_ok=True)
    return save_dir

def save_audio_spec(audio_spec, save_path):
    """Save *audio_spec* as a borderless reversed-grey image (origin at the bottom), then close the figure."""
    imshow_kwargs = dict(aspect="auto", origin="lower", cmap="gray_r")
    save_kwargs = dict(dpi=150, bbox_inches="tight", pad_inches=0)

    plt.figure(figsize=(4, 4))
    plt.imshow(audio_spec, **imshow_kwargs)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(save_path, **save_kwargs)
    plt.close()
    
def _prep_rgb(img):
    """Prepare an RGB image ``[H, W, 3]`` for ``imshow``.

    Tensors are detached, moved to CPU and converted to float32 NumPy arrays.
    ``uint8`` arrays are returned unchanged (0-255). Any other array with
    values outside ``[0, 1]`` is min-max rescaled into that range; the result
    is then clipped to ``[0, 1]``. Non-tensor float inputs keep their dtype.
    """
    arr = img.detach().cpu().float().numpy() if torch.is_tensor(img) else img
    if arr.dtype == np.uint8:
        return arr
    lo, hi = arr.min(), arr.max()
    if lo < 0 or hi > 1:
        arr = (arr - lo) / (hi - lo + 1e-6)
    return np.clip(arr, 0.0, 1.0)
def save_rgb_image(img, save_path):
    """Save an RGB image (normalised via ``_prep_rgb``) as a wide 32x4-inch borderless figure, then close it."""
    rgb = _prep_rgb(img)
    plt.figure(figsize=(32, 4))
    plt.imshow(rgb, aspect="auto")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0)
    plt.close()


In [None]:
import math, torch, torch.nn.functional as F
import matplotlib.pyplot as plt

def show_s2t_attention_overlay(audio_spec: torch.Tensor,
                               attn_audio: torch.Tensor,
                               *,
                               mode: str = "per_frame",   # "per_frame" | "mean"
                               frames: list[int] | None = None,  # vt indices to draw
                               alpha: float = 0.4,
                               cmap: str = "jet",
                               save_path=None,):
    """Overlay S2T (video-query -> audio-key) attention on a spectrogram.

    Args:
        audio_spec: 2-D spectrogram (e.g. ``[128, 400]``); its shape is the
            upsampling target and it is drawn as the grey background.
        attn_audio: ``[vt, f_dim, t_dim]`` attention per video time step
            (e.g. ``[8, 12, 39]``), bilinearly upsampled to ``audio_spec.shape``.
        mode: ``"per_frame"`` draws one panel per selected vt, each min-max
            normalised on its own. ``"mean"`` averages over vt and draws a
            single overlay.
        frames: vt indices for ``"per_frame"`` (``None`` = all).
        alpha: heat-map opacity.
        cmap: heat-map colormap.
        save_path: if not ``None``, the figure is also saved there and
            closed after showing.

    Raises:
        ValueError: if *mode* is not ``"per_frame"`` or ``"mean"``.
    """
    if mode not in {"per_frame", "mean"}:
        raise ValueError("mode must be 'per_frame' or 'mean'")

    # [vt, f, t] -> [vt, H_spec, W_spec]
    upsampled = F.interpolate(attn_audio.unsqueeze(1), size=audio_spec.shape,
                              mode="bilinear", align_corners=False).squeeze(1)

    if mode == "per_frame":
        lo = torch.amin(upsampled, dim=(-2, -1), keepdim=True)
        hi = torch.amax(upsampled, dim=(-2, -1), keepdim=True)
        normed = (upsampled - lo) / (hi - lo + 1e-6)

        selected = list(range(normed.size(0))) if frames is None else frames
        count = len(selected)
        ncols = min(4, count)  # at most 4 panels per row
        nrows = math.ceil(count / ncols)
        fig, grid = plt.subplots(nrows, ncols, figsize=(4*ncols, 3*nrows),
                                 sharex=True, sharey=True)
        panels = [grid] if count <= 1 else grid.flatten()

        for panel, vt in zip(panels, selected):
            panel.imshow(audio_spec.cpu(), aspect="auto", origin="lower", cmap="gray_r")
            panel.imshow(normed[vt].cpu(), aspect="auto", origin="lower",
                         cmap=cmap, alpha=alpha)
            panel.axis("off")
        for panel in panels[count:]:  # hide unused grid cells
            panel.axis("off")
    else:
        averaged = upsampled.mean(0, keepdim=True)  # [1, H_spec, W_spec]
        averaged = (averaged - averaged.min()) / (averaged.max() - averaged.min() + 1e-6)
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.imshow(audio_spec.cpu(), aspect="auto", origin="lower", cmap="gray_r")
        ax.imshow(averaged[0].cpu(), aspect="auto", origin="lower",
                  cmap=cmap, alpha=alpha)
        ax.axis("off")

    plt.tight_layout()
    should_save = save_path is not None
    if should_save:
        plt.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0)
    plt.show()
    if should_save:
        plt.close()



# Output folder for this sample; loop-invariant, so it is created once. It is
# also defined when depth == 0 (it is needed by the clip copy below).
save_dir = make_save_dir(batch, action, base_dir="./results")

for layer_id in range(depth):
    print(f"Layer {layer_id+1}")
    s2t = s2t_all_attn[layer_id]
    # Video tokens factor as (vn patches, vt frames); audio tokens as (af, at).
    s2t_split = rearrange(s2t,'b h (vn vt) (af at) -> b h vt vn af at',vt=T, vn=196, af=f_dim, at=t_dim)
    # Average over spatial patches (vn), then heads -> [b, vt, af, at]
    s2t_split_attn_audio = s2t_split.mean(-3).mean(1)

    base_name = f"layer{layer_id+1:02d}"

    # 1) Raw spectrogram (saving disabled)
    # spec_path = os.path.join(save_dir, f"spec.png")
    # save_audio_spec(audio_spec.detach(), spec_path)

    # 2) Overlay, mean mode
    # overlay_path = os.path.join(save_dir, f"{base_name}_spec_overlay.png")
    show_s2t_attention_overlay(audio_spec.detach(),
                               s2t_split_attn_audio[0].detach(),
                               mode='mean',
                            #    save_path=overlay_path
                               )

# Keep a copy of the source clip next to the figures.
from pathlib import Path
import shutil
video_path = Path(time_args.data_path) / (batch[2] + '.mp4')
shutil.copy2(video_path, save_dir)


In [None]:
import math, torch, torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import rearrange               # pip install einops

def _prep_rgb(img):
    """Prepare an RGB image ``[H, W, 3]`` for ``imshow``.

    Tensors are detached, moved to CPU and converted to float32 NumPy arrays.
    ``uint8`` arrays are returned unchanged (0-255). Any other array with
    values outside ``[0, 1]`` is min-max rescaled into that range; the result
    is then clipped to ``[0, 1]``. Non-tensor float inputs keep their dtype.
    """
    arr = img.detach().cpu().float().numpy() if torch.is_tensor(img) else img
    if arr.dtype == np.uint8:
        return arr
    lo, hi = arr.min(), arr.max()
    if lo < 0 or hi > 1:
        arr = (arr - lo) / (hi - lo + 1e-6)
    return np.clip(arr, 0.0, 1.0)

def show_t2s_attention_overlay(concat_frames: torch.Tensor | np.ndarray,
                               attn_video: torch.Tensor,
                               *,
                               mode: str = "per_ts",      # "per_ts" | "mean"
                               ts: list[int] | None = None,   # audio time-step (at) indices to draw
                               alpha: float = 0.4,
                               cmap: str = "jet",
                               save_path=None):
    """Overlay T2S (audio-query -> video-key) attention on a strip of frames.

    Args:
        concat_frames: ``[H, W, 3]`` RGB strip of frames concatenated
            horizontally (e.g. ``[224, 1792, 3]`` for 8 frames of 224x224).
        attn_video: ``[at, vt, 196]`` attention from each audio time step to
            the 14x14 patches of each frame (e.g. ``[39, 8, 196]``).
        mode: ``"per_ts"`` draws one panel per selected audio step, each
            min-max normalised on its own. ``"mean"`` averages over all audio
            steps.
        ts: audio-step indices for ``"per_ts"`` (``None`` = all). Any
            iterable of ints is accepted.
        alpha: heat-map opacity.
        cmap: heat-map colormap.
        save_path: if truthy, the figure is also saved there and closed
            after showing.

    Raises:
        ValueError: on an unknown *mode*, or an empty *ts* in ``"per_ts"`` mode.
    """
    if mode not in {"per_ts", "mean"}:
        raise ValueError("mode must be 'per_ts' or 'mean'")

    H, W = concat_frames.shape[:2]
    concat_frames = _prep_rgb(concat_frames)
    # (at, vt, 14*14) -> (at, 14, vt*14): place each frame's patch grid side by
    # side so the map lines up with the horizontally concatenated frames.
    attn_grid = rearrange(attn_video, 'at vt (ph pw) -> at ph (vt pw)', ph=14, pw=14)
    # Upsample to the strip resolution -> [at, H, W]
    attn_up = F.interpolate(attn_grid.unsqueeze(1), size=(H, W),
                            mode="bilinear", align_corners=False).squeeze(1)

    if mode == "per_ts":
        # list(...) so a tuple is not treated as a multi-dimensional index.
        ts_list = list(ts) if ts is not None else list(range(attn_up.size(0)))
        if not ts_list:
            raise ValueError("ts must contain at least one audio time-step index")
        # Min-max normalise each selected audio step independently.
        selected = attn_up[ts_list]
        min_val = torch.amin(selected, dim=(-2, -1), keepdim=True)
        max_val = torch.amax(selected, dim=(-2, -1), keepdim=True)
        selected = (selected - min_val) / (max_val - min_val + 1e-6)

        n = len(ts_list)
        cols = min(4, n)
        rows = math.ceil(n / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3*rows),
                                 sharex=True, sharey=True)
        axes = axes.flatten() if n > 1 else [axes]

        for ax, heat in zip(axes, selected):
            ax.imshow(concat_frames, aspect="auto")
            ax.imshow(heat.cpu(), aspect="auto", cmap=cmap, alpha=alpha)
            ax.axis("off")
        for ax in axes[n:]:  # hide unused grid cells
            ax.axis("off")

    else:  # mode == "mean"
        attn_mean = attn_up.mean(0)   # [H, W], averaged over audio steps
        attn_mean = (attn_mean - attn_mean.min()) / (attn_mean.max() - attn_mean.min() + 1e-6)

        fig, ax = plt.subplots(figsize=(32, 4))
        ax.imshow(concat_frames, aspect="auto")
        ax.imshow(attn_mean.cpu(), aspect="auto", cmap=cmap, alpha=alpha)
        ax.axis("off")

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0)
    plt.show()
    if save_path:
        plt.close()


# Loop-invariant output locations, computed once.
save_dir = make_save_dir(batch, action, base_dir="./results")
# 1) Raw frame strip (saving disabled)
spec_path = os.path.join(save_dir, "video.png")
# save_rgb_image(concatenated_frames, spec_path)

for layer_id in range(depth):
    print(f"Layer {layer_id+1}")

    t2s = t2s_all_attn[layer_id]
    t2s_split = rearrange(t2s,'b h (af at) (vn vt) -> b h af at vt vn',vt=T, vn=196, af=f_dim, at=t_dim)
    # Average over audio frequency rows (af), then heads -> [b, at, vt, vn]
    t2s_split_attn_video = t2s_split.mean(-4).mean(1)

    base_name = f"layer{layer_id+1:02d}"
    overlay_path = os.path.join(save_dir, f"{base_name}_video_overlay.png")
    show_t2s_attention_overlay(concatenated_frames,
                               t2s_split_attn_video[0].detach(),
                               mode="mean",
                               ts=[0,10,20,30],  # only used in "per_ts" mode
                               # save_path=overlay_path
                               )

In [None]:
import math, torch, torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import rearrange            
import gc

def _prep_rgb(img):
    """Prepare an RGB image ``[H, W, 3]`` for ``imshow``.

    Tensors are detached, moved to CPU and converted to float32 NumPy arrays.
    ``uint8`` arrays are returned unchanged (0-255). Any other array with
    values outside ``[0, 1]`` is min-max rescaled into that range; the result
    is then clipped to ``[0, 1]``. Non-tensor float inputs keep their dtype.
    """
    arr = img.detach().cpu().float().numpy() if torch.is_tensor(img) else img
    if arr.dtype == np.uint8:
        return arr
    lo, hi = arr.min(), arr.max()
    if lo < 0 or hi > 1:
        arr = (arr - lo) / (hi - lo + 1e-6)
    return np.clip(arr, 0.0, 1.0)

def _postprocess_heatmap(h, *, blur_ks=5, cutoff=0.2, gamma=2.0):
    """Smooth, threshold, rescale and gamma-correct a 2-D heat map.

    Args:
        h: ``[H, W]`` heat-map tensor (non-negative values expected).
            float16/bfloat16 input is promoted to float32, because
            ``torch.quantile`` only supports float32/float64.
        blur_ks: box-blur kernel size; falsy or ``1`` skips blurring. Border
            pixels are averaged over in-image pixels only
            (``count_include_pad=False``), so edges are not pulled toward 0.
            The output keeps the ``[H, W]`` shape for odd and even kernels.
        cutoff: quantile in ``[0, 1]``; values are shifted down by it and
            clamped at 0 (e.g. ``0.2`` zeroes the lowest 20 %). ``0`` disables it.
        gamma: exponent applied after rescaling; ``>1`` emphasises bright
            regions, ``<1`` lifts dark ones.

    Returns:
        The processed ``[H, W]`` tensor with its peak rescaled to ~1.
    """
    if h.dtype in (torch.float16, torch.bfloat16):
        h = h.float()
    height, width = h.shape[-2], h.shape[-1]

    if blur_ks and blur_ks > 1:
        pad = blur_ks // 2
        blurred = F.avg_pool2d(h.unsqueeze(0).unsqueeze(0), blur_ks,
                               stride=1, padding=pad, count_include_pad=False)
        # Even kernels yield one extra row/column; crop back to the input size.
        h = blurred[0, 0, :height, :width]

    if cutoff > 0:
        thr = torch.quantile(h, cutoff)
        h = torch.clamp(h - thr, min=0.0)

    h = h / (h.max() + 1e-6)
    if gamma != 1.0:
        h = h.pow(gamma)
    return h

def _piecewise_strength(h, boundary=0.8,
                        low_max=0.3, high_min=0.3, high_max=1.0):
    """Remap a ``[0, 1]`` heat map with two linear segments.

    Values below *boundary* are scaled from ``[0, boundary)`` onto
    ``[0, low_max)``. Values at or above it are mapped from ``[boundary, 1]``
    onto ``[high_min, high_max]``. The input tensor is left untouched, and the
    result has the same shape and dtype as *h*.

    Args:
        h: heat-map tensor with values in ``[0, 1]``.
        boundary: split point (e.g. 0.8 -> the top 20 % form the high segment).
        low_max: strength reached just below *boundary*.
        high_min: strength at *boundary*.
        high_max: strength at 1.0.
    """
    below = h < boundary
    low_part = h / boundary * low_max
    high_part = high_min + (h - boundary) / (1 - boundary) * (high_max - high_min)
    return torch.where(below, low_part, high_part).to(h.dtype)


def show_t2s_attention_overlay(concat_frames: torch.Tensor | np.ndarray,
                               attn_video: torch.Tensor,
                               *,
                               mode: str = "per_ts",      # "per_ts" | "mean"
                               ts: list[int] | None = None,   # audio time-step (at) indices to draw
                               alpha: float = 0.4,
                               cmap: str = "jet",
                               blur_ks=5, cutoff=0.2, gamma=2.0,
                               boundary=0.8, low_max=0.3, high_min=0.3, high_max=1.0,
                               save_path=None,
                               use_strength: bool = False):
    """Overlay post-processed T2S attention on a strip of frames.

    Args:
        concat_frames: ``[H, W, 3]`` RGB strip of frames concatenated
            horizontally (e.g. ``[224, 1792, 3]`` for 8 frames of 224x224).
        attn_video: ``[at, vt, 196]`` attention from each audio time step to
            the 14x14 patches of each frame (e.g. ``[39, 8, 196]``).
        mode: ``"per_ts"`` draws one min-max normalised panel per selected
            audio step. ``"mean"`` averages over audio steps and applies
            ``_postprocess_heatmap``.
        ts: audio-step indices for ``"per_ts"`` (``None`` = all).
        alpha: heat-map opacity.
        cmap: heat-map colormap.
        blur_ks, cutoff, gamma: forwarded to ``_postprocess_heatmap``
            (``"mean"`` mode).
        boundary, low_max, high_min, high_max: forwarded to
            ``_piecewise_strength`` when *use_strength* is true.
        save_path: if truthy, the figure is also saved there and closed
            after showing.
        use_strength: in ``"mean"`` mode, draw the piecewise-remapped map
            instead of the post-processed one. The default ``False`` keeps the
            previous output. The remap was previously computed but never used.

    Raises:
        ValueError: if *mode* is not ``"per_ts"`` or ``"mean"``.
    """
    if mode not in {"per_ts", "mean"}:
        raise ValueError("mode must be 'per_ts' or 'mean'")

    H, W = concat_frames.shape[:2]
    concat_frames = _prep_rgb(concat_frames)
    # (at, vt, 14*14) -> (at, 14, vt*14): place each frame's patch grid side by side.
    attn_grid = rearrange(attn_video, 'at vt (ph pw) -> at ph (vt pw)', ph=14, pw=14)
    # Upsample to the strip resolution -> [at, H, W]
    attn_up = F.interpolate(attn_grid.unsqueeze(1), size=(H, W),
                            mode="bilinear", align_corners=False).squeeze(1)

    if mode == "per_ts":
        ts_list = ts if ts is not None else list(range(attn_up.size(0)))
        # Min-max normalise each selected audio step independently.
        selected = attn_up[ts_list]
        min_val = torch.amin(selected, dim=(-2, -1), keepdim=True)
        max_val = torch.amax(selected, dim=(-2, -1), keepdim=True)
        selected = (selected - min_val) / (max_val - min_val + 1e-6)

        n = len(ts_list)
        cols = min(4, n)
        rows = math.ceil(n / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3*rows),
                                 sharex=True, sharey=True)
        axes = axes.flatten() if n > 1 else [axes]

        for ax, heat in zip(axes, selected):
            ax.imshow(concat_frames, aspect="auto")
            ax.imshow(heat.cpu(), aspect="auto", cmap=cmap, alpha=alpha)
            ax.axis("off")
        for ax in axes[n:]:  # hide unused grid cells
            ax.axis("off")

    else:  # mode == "mean"
        h = _postprocess_heatmap(attn_up.mean(0), blur_ks=blur_ks,
                                 cutoff=cutoff, gamma=gamma)
        if use_strength:
            h = _piecewise_strength(h, boundary=boundary, low_max=low_max,
                                    high_min=high_min, high_max=high_max)

        fig, ax = plt.subplots(figsize=(32, 4))
        ax.imshow(concat_frames, aspect="auto")
        ax.imshow(h.cpu(), aspect="auto", cmap=cmap, alpha=alpha)
        ax.axis("off")

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0)
    plt.show()
    if save_path:
        plt.close()
    # Locals (attn_up, h, ...) are released on return. The old explicit
    # `del attn_up, h` raised UnboundLocalError in "per_ts" mode, where h is
    # never assigned.


# Loop-invariant output locations, computed once.
save_dir = make_save_dir(batch, action, base_dir="./results")
# 1) Raw frame strip (saving disabled)
spec_path = os.path.join(save_dir, "video.png")
# save_rgb_image(concatenated_frames, spec_path)

for layer_id in range(depth):
    print(f"Layer {layer_id+1}")

    t2s = t2s_all_attn[layer_id]
    t2s_split = rearrange(t2s,'b h (af at) (vn vt) -> b h af at vt vn',vt=T, vn=196, af=f_dim, at=t_dim)
    # Average over audio frequency rows (af), then heads -> [b, at, vt, vn]
    t2s_split_attn_video = t2s_split.mean(-4).mean(1)

    base_name = f"layer{layer_id+1:02d}"
    with torch.no_grad():
        overlay_path = os.path.join(save_dir, f"{base_name}_video_overlay.png")
        show_t2s_attention_overlay(concatenated_frames,
                                   t2s_split_attn_video[0].detach(),
                                   mode="mean",
                                   ts=[0,10,20,30],    # only used in "per_ts" mode
                                   alpha=0.4,          # overall heat-map opacity
                                   blur_ks=1,          # 1 = no smoothing
                                   cutoff=0,           # 0 = keep all values (0.2 would zero the lowest 20 %)
                                   gamma=0.5,          # < 1 lifts low-attention regions
                                   # save_path=overlay_path,
                                   # piecewise-strength parameters (see _piecewise_strength)
                                   boundary=0.7, low_max=0.4, high_min=0.4, high_max=1.0
                                   )

    del t2s_split_attn_video

# Collect unreachable Python objects first so their CUDA tensors are freed,
# then return the now-unused cached blocks. The reverse order leaves memory
# freed by gc.collect() sitting in PyTorch's cache.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import math, torch, torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def _s2t_minmax01(x: torch.Tensor) -> torch.Tensor:
    """Min-max scale *x* into [0, 1]; the 1e-6 keeps a constant input finite."""
    return (x - x.min()) / (x.max() - x.min() + 1e-6)


def _s2t_draw_band(ax, t: np.ndarray, wav: np.ndarray, heat: torch.Tensor,
                   *, cmap: str, alpha: float) -> None:
    """Plot *wav* against *t* on *ax* and alpha-blend *heat* ([S]) as a band over it."""
    ax.plot(t, wav, lw=0.7, color='black')
    x0, x1 = float(t[0]), float(t[-1])
    y0, y1 = float(wav.min()), float(wav.max())
    # A single sample or a flat waveform (e.g. digital silence) would give a
    # zero-width/height extent that matplotlib cannot render meaningfully.
    if x1 == x0:
        x1 = x0 + 1.0
    if y1 == y0:
        y0, y1 = y0 - 0.5, y1 + 0.5
    ax.imshow(heat.unsqueeze(0).repeat(50, 1).numpy(),   # 50 identical rows -> visible band
              extent=[x0, x1, y0, y1],
              origin='lower', aspect='auto',
              cmap=cmap, alpha=alpha)
    ax.axis("off")


def show_s2t_attention_waveform(waveform: torch.Tensor | np.ndarray,
                                attn_audio: torch.Tensor,
                                *,
                                sr: int | None = None,          # sample rate (optional)
                                mode: str = "per_frame",        # "per_frame" | "mean"
                                frames: list[int] | None = None,# vt indices to show
                                alpha: float = 0.4,
                                cmap: str = "jet",
                                save_path: str | None = None):
    """Overlay audio-side s2t attention on a waveform plot.

    Args:
        waveform: ``[S]`` or ``[C, S]`` samples (torch tensor or array-like);
            for ``[C, S]`` only channel 0 is drawn.
        attn_audio: attention of shape ``(vt, f_dim, t_dim)``, e.g.
            ``[8, 12, 39]``. It may live on any device/dtype; it is detached
            and converted to CPU float32 before plotting.
        sr: sample rate. When given, the x-axis is ``sample_index / sr``;
            otherwise it is the raw sample index.
        mode: ``"per_frame"`` draws one panel per ``vt``; ``"mean"`` draws a
            single panel averaged over ``vt``.
        frames: ``vt`` indices to draw in ``"per_frame"`` mode (default: all).
        alpha: heat-map opacity.
        cmap: matplotlib colormap name.
        save_path: if set, save the figure there and close it; otherwise show it.

    Method:
        1) average over f_dim      -> ``[vt, t_dim]``
        2) linear upsample in time -> ``[vt, S]`` (S = waveform length)
        3) min-max normalise each row and alpha-blend it via ``imshow(extent=...)``

    Raises:
        ValueError: unknown ``mode``, empty ``waveform``, or empty ``frames``.
    """
    if mode not in {"per_frame", "mean"}:
        raise ValueError("mode must be 'per_frame' or 'mean'")

    # --- data preparation ---------------------------------------------------
    if torch.is_tensor(waveform):
        wav = waveform.detach().cpu()
    else:
        wav = torch.as_tensor(waveform, dtype=torch.float32)
    if wav.ndim == 2:       # [C, S] -> first channel only
        wav = wav[0]
    wav = wav.float().numpy()
    S = wav.size
    if S == 0:
        raise ValueError("waveform is empty")
    t = np.arange(S) / sr if sr else np.arange(S)   # x axis

    # attn: [vt, f, t] -> mean over f -> [vt, t_dim] -> upsample -> [vt, S].
    # Detach + CPU so a CUDA / grad-tracking tensor can still be drawn.
    attn = attn_audio.detach().float().cpu().mean(1)
    attn_up = F.interpolate(attn.unsqueeze(1), size=S,
                            mode='linear', align_corners=False).squeeze(1)

    # --- overlay ------------------------------------------------------------
    if mode == "per_frame":
        vt_list = list(frames) if frames is not None else list(range(attn_up.size(0)))
        if not vt_list:
            # Otherwise cols == 0 and the grid computation divides by zero.
            raise ValueError("frames must contain at least one index")
        n = len(vt_list)
        cols = min(4, n)
        rows = math.ceil(n / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 2.5*rows),
                                 sharex=True, sharey=True, squeeze=False)
        axes = axes.ravel()   # squeeze=False -> always a 2-D grid, even for n == 1

        for ax, vt in zip(axes, vt_list):
            _s2t_draw_band(ax, t, wav, _s2t_minmax01(attn_up[vt]),
                           cmap=cmap, alpha=alpha)
            ax.set_title(f"vt={vt}", fontsize=9)
        for ax in axes[n:]:   # hide unused grid cells
            ax.axis("off")

    else:  # mode == "mean"
        fig, ax = plt.subplots(figsize=(8, 2.5))
        _s2t_draw_band(ax, t, wav, _s2t_minmax01(attn_up.mean(0)),
                       cmap=cmap, alpha=alpha)

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight', pad_inches=0)
        plt.close(fig)
    else:
        plt.show()
        
        
# Per-layer visualisation of the "s2t" cross-attention over the audio waveform.
# Relies on names defined in earlier cells: depth, s2t_all_attn, rearrange, T,
# f_dim, t_dim, make_save_dir, batch, action, audio_wav.
for layer_id in range(depth):
    print(f"Layer {layer_id+1}")
    s2t = s2t_all_attn[layer_id]
    # Unflatten the token axes: (vn vt) = 196 spatial tokens x T frames and
    # (af at) = audio tokens. Assuming the layout is [batch, heads, query, key]
    # (TODO confirm), video tokens are the queries here.
    s2t_split = rearrange(s2t,'b h (vn vt) (af at) -> b h vt vn af at',vt=T, vn=196, af=f_dim, at=t_dim)
    # Average over `vn` (dim -3) and then over heads (dim 1) -> [b, vt, af, at],
    # the (vt, f_dim, t_dim) layout that show_s2t_attention_waveform expects.
    s2t_split_attn_audio = s2t_split.mean(-3).mean(1)
    
    # save_dir / base_name only feed the commented-out save paths below
    # (make_save_dir is still called on every iteration).
    save_dir = make_save_dir(batch, action, base_dir="./results")
    base_name = f"layer{layer_id+1:02d}"

    # 1) Original spectrogram
    # spec_path = os.path.join(save_dir, f"spec.png")
    # save_audio_spec(audio_spec.detach(), spec_path)

    # 2) Overlay (example using 'mean' mode)
    # overlay_path = os.path.join(save_dir, f"{base_name}_spec_overlay.png")
    show_s2t_attention_waveform(audio_wav,
                               s2t_split_attn_audio[0].detach().cpu(),
                               mode='mean',
                            #    save_path=overlay_path
                               )


# Multi Entropy

In [None]:
# Display the waveform that the s2t visualisation cell passes to show_s2t_attention_waveform.
audio_wav

In [None]:
import numpy as np
import torch
from models.prompt import dataset_class
# Class-name table for the evaluated dataset (looked up as action['action'][label]).
action = dataset_class(time_args.data_set, time_args.anno_path)

device = "cuda"
time_model.eval()  # evaluation mode for the whole entropy pass

N = 200  # number of validation samples to evaluate (increase if desired)
# Each block contributes one t2s and one s2t attention map, so this also
# equals len(all_atts) // 2.
layer_depth = len(time_model.blocks)

# Entropy buffers indexed as [layer][sample]; two independent list-of-lists.
s2t_all, t2s_all = ([[] for _ in range(layer_depth)] for _ in range(2))
# Raw attention stores; the loop in this cell only fills them if its
# commented-out append lines are re-enabled.
s2t_all_attn, t2s_all_attn = [], []

# @torch.no_grad()
# def compute_entropy(attn):
#     p = attn.clamp_min(1e-12)
#     h = -(p * p.log2()).sum(-1)       # [B, head, query]
#     return h.mean().item()

@torch.no_grad()
def compute_entropy(attn, mode="ratio"):
    """Return the mean Shannon entropy (in bits) of attention rows, optionally rescaled.

    Args:
        attn: attention probabilities whose last axis indexes keys, e.g.
            ``[B, head, Q, K]``. Each row is assumed to sum to 1 — TODO
            confirm for every caller.
        mode: ``"ratio"`` (default) -> effective key share ``2**H / K``,
            in [1/K, 1] for normalised rows;
            ``"norm"`` -> ``H / log2(K)``, in [0, 1];
            any other value -> raw entropy ``H`` in bits.

    Returns:
        float: the (rescaled) entropy averaged over all non-key axes.
    """
    import math  # local, so this cell does not depend on an earlier cell's import

    p = attn.clamp_min(1e-12)            # avoid log2(0) = -inf; clamped zeros add ~0
    h = -(p * p.log2()).sum(-1)          # [B, head, Q] entropy in bits

    if mode == "norm":
        K = attn.size(-1)
        # log2(1) == 0: a single key carries no uncertainty, so report 0
        # instead of dividing by zero (which yields NaN).
        h = h / math.log2(K) if K > 1 else torch.zeros_like(h)
    elif mode == "ratio":
        K = attn.size(-1)
        h = (2 ** h) / K

    return h.mean().item()

# Evaluate attention entropy on N distinct random samples from the validation set.
total_samples = len(data_loader_val)
# np.random.choice(..., replace=False) raises ValueError when N exceeds the
# population, so never request more samples than the dataset holds.
N = min(N, total_samples)
rand_indices = np.random.choice(total_samples, N, replace=False)
ind_correct = []

for idx in rand_indices:
    batch = data_loader_val.get_item_by_index(idx)

    # Batch-of-one inputs. The tuple layout (0: frames, 1: label, 3: spec,
    # 4: captions, 5: time index) is inferred from usage here — TODO confirm.
    time_idx = torch.tensor(batch[5]).unsqueeze(0).to(device)
    samples  = batch[0].unsqueeze(0).to(device, non_blocking=True)
    spec     = batch[3].unsqueeze(0).to(device, non_blocking=True)
    captions = batch[4]

    # Inference only: without no_grad the forward pass records an autograd
    # graph for every sample, wasting GPU memory and time.
    with torch.no_grad():
        logits, all_atts = time_model(samples,
                                      caption=captions,
                                      spec=spec,
                                      idx=time_idx,
                                      output_attentions=True)

    # Resolve the class names once; reuse them for logging and the accuracy tally.
    label_name = action['action'][batch[1]]
    pred_name = action['action'][torch.argmax(logits).item()]
    print('action :', label_name, '\tpred :', pred_name)
    ind_correct.append(label_name == pred_name)

    # all_atts holds two maps per block: index 2*l is "t2s", 2*l + 1 is "s2t".
    # NOTE(review): earlier comments called t2s "video Q <- audio KV", but the
    # visualisation cells rearrange t2s as '(af at) (vn vt)' (audio tokens on
    # the first token axis) — confirm which direction each map really holds.
    for l in range(layer_depth):
        t2s = all_atts[2*l]
        s2t = all_atts[2*l + 1]

        s2t_all[l].append(compute_entropy(s2t.detach().cpu()))
        t2s_all[l].append(compute_entropy(t2s.detach().cpu()))
        # s2t_all_attn.append(s2t.detach().cpu())
        # t2s_all_attn.append(t2s.detach().cpu())

# Per-layer mean / standard deviation across the evaluated samples (numpy).
mean_s2t = [np.mean(x) for x in s2t_all]
std_s2t  = [np.std(x)  for x in s2t_all]
mean_t2s = [np.mean(x) for x in t2s_all]
std_t2s  = [np.std(x)  for x in t2s_all]

# Guard the percentage against an empty evaluation (N == 0).
accuracy = 100 * sum(ind_correct) / N if N else float("nan")
print(f"Correct predictions: {sum(ind_correct)}/{N} ({accuracy:.2f}%)")

print(f"\nEvaluated {N} samples")
print("Layer |  S2T  μ±σ   |  T2S  μ±σ")
print("----------------------------------")
for i, (m1, s1, m2, s2) in enumerate(zip(mean_s2t, std_s2t, mean_t2s, std_t2s), 1):
    print(f"{i:2d}   | {m1:5.2f}±{s1:4.2f} | {m2:5.2f}±{s2:4.2f}")


In [None]:
import pandas as pd
# Tabulate the per-layer entropy statistics (A2V <- s2t, V2A <- t2s naming).
layers = list(range(1, len(mean_s2t)+1))
df = pd.DataFrame({
    "layer": layers,
    "mean_A2V_ratio": mean_s2t,
    "std_A2V_ratio": std_s2t,
    "mean_V2A_ratio": mean_t2s,
    "std_V2A_ratio": std_t2s
})

# NOTE(review): "vggsoud" looks like a typo for "vggsound"; left unchanged so
# the name still matches any files already written with it.
file_path = "cava_entropy_stats_vggsoud_speccut_missing0.csv"
SAVE_CSV = False  # set True to actually write the CSV (previously commented out)
if SAVE_CSV:
    df.to_csv(file_path, index=False)
    print(f"CSV saved at {file_path}")
else:
    # Do not claim a save that did not happen.
    print(f"CSV not saved (SAVE_CSV=False); target path: {file_path}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# Register a local Helvetica TTF with matplotlib's font manager and make it the
# default font family for the plots below.
font_path = '/data/joohyun7u/.fonts/Helvetica.ttf'
if os.path.isfile(font_path):
    fm.fontManager.addfont(font_path)
    # This name must match the family name stored inside the TTF.
    plt.rcParams['font.family'] = 'Helvetica'
else:
    # The path is machine-specific and addfont raises when the file is missing;
    # keep matplotlib's default family instead of crashing the cell.
    print(f"Font file not found: {font_path}; using the default font family.")

# FontProperties for the same file (used by the commented-out rcParams
# variants below) and a DejaVu Sans system font, presumably for symbol glyphs.
prop = fm.FontProperties(fname=font_path)
# plt.rcParams['font.family'] = prop.get_name()
# plt.rcParams['font.family'] = [prop.get_name(), 'DejaVu Sans', 'Noto Sans Symbols']
symbol_font = fm.FontProperties(family='DejaVu Sans')  # system font

# Select where the per-layer statistics come from:
#   1   -> recompute from the in-memory entropy lists (t2s_all / s2t_all)
#   >=2 -> reload a previously exported CSV
sel_type = 1
if sel_type >= 2:
    import pandas as pd
    # entropy_path = './cava_entropy_stats_missing0.csv'
    entropy_path = './cava_entropy_stats_epicsound_pth3_missing0.csv'
    # entropy_path = './cava_entropy_stats_vggsound_missing0.csv'
    entropy_df = pd.read_csv(entropy_path)
    # Column order here fixes the unpacking order on the left.
    mean_t2s, mean_s2t, std_t2s, std_s2t = (
        entropy_df[column].tolist()
        for column in ('mean_V2A_ratio', 'mean_A2V_ratio', 'std_V2A_ratio', 'std_A2V_ratio')
    )
elif sel_type == 1:
    mean_t2s, std_t2s = [np.mean(v) for v in t2s_all], [np.std(v) for v in t2s_all]
    mean_s2t, std_s2t = [np.mean(v) for v in s2t_all], [np.std(v) for v in s2t_all]

# 1-based layer indices for the x-axis.
layers = np.arange(1, len(mean_t2s) + 1)
# layers = np.arange(0, len(mean_t2s) + 0)


def _errorbar_entropy_plot(xs, curves, title):
    """Draw one error-bar line per (means, stds, fmt, label) tuple and show the figure."""
    _, axis = plt.subplots(figsize=(12,5))
    for means, stds, fmt, label in curves:
        axis.errorbar(xs, means, yerr=stds, fmt=fmt, capsize=4, label=label)
    axis.set_xlabel('Layer', fontsize=14)
    axis.set_ylabel('Entropy Ratio', fontsize=14)
    axis.set_title(title, fontsize=16)
    axis.grid(True)
    axis.legend()
    # plt.savefig("layerwise_entropy_vggsoud_speccut_missing0.pdf", bbox_inches='tight')
    plt.show()


_arrow = r"$\rightarrow$"  # mathtext right arrow used in labels and title
_errorbar_entropy_plot(
    layers,
    [
        (mean_t2s, std_t2s, 'o-', 'V' + _arrow + 'A Entropy'),
        (mean_s2t, std_s2t, 's-', 'A' + _arrow + 'V Entropy'),
    ],
    'Layer-wise Attention Entropy (V' + _arrow + 'A & A' + _arrow + 'V)',
)


# # entropy ratio로 나눈 값 시각화
# mean_t2s_ratio = [m / np.log2(time_model.blocks[0].cross_t_down.out_features) for m in mean_t2s]
# mean_s2t_ratio = [m / np.log2(time_model.blocks[0].cross_s_down.out_features) for m in mean_s2t]

# plt.figure(figsize=(12,5))
# plt.errorbar(layers, mean_t2s_ratio, yerr=std_t2s, fmt='o-', capsize=4, label='T→S Entropy Ratio')
# plt.errorbar(layers, mean_s2t_ratio, yerr=std_s2t, fmt='s-', capsize=4, label='S→T Entropy Ratio')
# plt.xlabel('Layer')
# plt.ylabel('Entropy Ratio')
# plt.title('Layer-wise Attention Entropy Ratio (T→S & S→T)')
# plt.grid(True)
# plt.legend()
# # plt.savefig("layerwise_entropy_ratio.pdf", bbox_inches='tight')
# plt.show()