In [None]:
import glob
import json
import os

import pandas as pd
import numpy as np

from experiments.mmshap import compute_mm_score

# Parse output json file

In [None]:
def combine_output(json_list):
    """
    Load several JSON output files and collect their contents in one list.

    Parameters
    ----------
    json_list : iterable of str
        Paths of the JSON files to read.

    Returns
    -------
    list
        The decoded content of each file, in the same order as `json_list`.
        An empty input yields an empty list.

    Raises
    ------
    OSError
        If a file cannot be opened.
    json.JSONDecodeError
        If a file does not contain valid JSON.
    """
    data = []
    for json_path in json_list:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, which differs between machines.
        with open(json_path, "r", encoding="utf-8") as f:
            data.append(json.load(f))
    return data

In [None]:
# Collect the per-run JSON output files. glob.glob returns paths in arbitrary
# (filesystem-dependent) order, so sort them to keep the combined files and the
# resulting DataFrame row order reproducible across machines and re-runs.
mullama_fs_list = sorted(glob.glob("../data/output_data/mullama_muchomusic_musiccaps_fs/*.json"))
qwen_fs_list = sorted(glob.glob("../data/output_data/qwenaudio_muchomusic_musiccaps_fs/*.json"))
qwen_desc_list = sorted(glob.glob("../data/output_data/qwenaudio_muchomusic_musiccaps_desc/*.json"))

In [None]:
# Load every JSON file of each run into an in-memory list
# (one list entry per file, see combine_output).
qwen_fs = combine_output(qwen_fs_list)
qwen_desc = combine_output(qwen_desc_list)
mu_fs = combine_output(mullama_fs_list)

In [None]:
# Persist each combined list as a single JSON file. qwen_fs.json and
# mullama_fs.json are read back with pandas later in this notebook.
with open("../data/output_data/qwen_fs.json", "w") as f:
    json.dump(qwen_fs, f)

with open("../data/output_data/qwen_desc.json", "w") as f:
    json.dump(qwen_desc, f)

with open("../data/output_data/mullama_fs.json", "w") as f:
    json.dump(mu_fs, f)

# Compute MM-SHAP values

In [None]:
def extract_answer_pandas(
    model_output,
    answer_options,
    prefix="The correct answer is:",
    letter_options=["A", "B", "C", "D"]):
    """
    Map a free-text model output to the index of the answer option it chose.

    Adaptation of the muchomusic function, applied to our pandas dataframe.

    Parameters
    ----------
    model_output : str
        Raw text generated by the model.
    answer_options : sequence of str
        Answer texts, in the order in which they were shown to the model.
    prefix : str, optional
        Only the text after the last occurrence of this marker is inspected.
        If the marker is absent the whole output is inspected.
    letter_options : sequence of str, optional
        Option letters; the position of a letter is the index that is returned.
        The default list is never mutated.

    Returns
    -------
    int
        Index of the selected option, or -1 when no option could be identified.
    """
    output = model_output.split(prefix)[-1].strip()

    # Character-level intersection: a letter counts as "present" if it appears
    # anywhere in the output. Only an unambiguous single hit is accepted.
    matched_letters = set(letter_options).intersection(output)
    if len(matched_letters) == 1:
        return letter_options.index(next(iter(matched_letters)))

    # Fallback: case-insensitive search for the answer text itself; the first
    # option found in the output wins.
    normalized_output = output.lower().strip()
    for position, option in enumerate(answer_options):
        if option.lower().strip() in normalized_output:
            return position

    # Nothing matched. This also covers an empty `answer_options`, which
    # previously left the result unbound and raised UnboundLocalError.
    return -1

def compare_answers(response, answer_orders):
    """
    Grade an extracted response.

    Returns 1 when `response` equals the position of the value 0 inside
    `answer_orders` (presumably the slot of the correct option after
    shuffling -- TODO confirm), -1 when `response` is -1 (unanswered), and
    0 otherwise (incorrect).

    Raises ValueError if `answer_orders` does not contain 0.
    """
    # Looked up first, so a missing 0 is reported even for unanswered rows.
    correct_position = answer_orders.index(0)

    if response == correct_position:
        return 1
    return -1 if response == -1 else 0

def accuracy(df):
    """
    Fraction of rows whose "final_answer" equals 1.

    The denominator is the number of non-null "final_answer" entries, so
    incorrect (0) and unanswered (-1) rows both count against the score.
    """
    graded = df["final_answer"]
    n_correct = df.loc[graded == 1, "final_answer"].count()
    n_total = graded.count()
    return n_correct / n_total


In [None]:
def compute_mmshap_row(row):
    """
    Compute the audio/text MM-SHAP scores for one dataframe row.

    Reads `<question_id>_shapley_values.npy` and `<question_id>_tokens.npy`
    from `../<row.output_folder>` and passes the Shapley values, together with
    the number of negative token entries (used as the audio length), to
    `compute_mm_score` with method="sum".

    Returns a pandas Series labelled "a-shap", "t-shap" and "tokens".
    NOTE(review): these labels use hyphens, while the notebook stores the
    result in columns named with underscores ("a_shap", "t_shap").
    """
    results_dir = os.path.join("..", row.output_folder)
    shap_file = os.path.join(results_dir, f"{row.question_id}_shapley_values.npy")
    token_file = os.path.join(results_dir, f"{row.question_id}_tokens.npy")

    shapley_values = np.load(shap_file)
    # The token array is stored with two leading singleton axes.
    tokens = np.load(token_file).squeeze(0).squeeze(0)

    # Negative token entries are counted as audio positions.
    audio_length = np.nonzero(tokens < 0)[0].size
    audio_score, text_score = compute_mm_score(
        shap_values=shapley_values,
        audio_length=audio_length,
        method="sum",
    )

    return pd.Series({"a-shap": audio_score, "t-shap": text_score, "tokens": tokens})

In [None]:
# Read the combined outputs written earlier in this notebook into DataFrames
# (qfs: Qwen-Audio run, mfs: MU-LLaMA run, judging by the file names).
qfs = pd.read_json("../data/output_data/qwen_fs.json")
mfs = pd.read_json("../data/output_data/mullama_fs.json")

In [None]:
# extracted_response: index of the option picked by the model (-1 if none could be parsed).
# final_answer: 1 = correct, 0 = incorrect, -1 = unanswered (see compare_answers).
qfs["extracted_response"] = qfs[["model_output", "answers"]].apply(lambda x: extract_answer_pandas(x.model_output, x.answers), axis=1)
qfs["final_answer"] = qfs[["extracted_response", "answer_orders"]].apply(lambda x: compare_answers(x.extracted_response, x.answer_orders), axis=1)

In [None]:
# Same grading as for qfs: extracted_response is the picked option index (-1 if unparsed),
# final_answer is 1 = correct, 0 = incorrect, -1 = unanswered.
mfs["extracted_response"] = mfs[["model_output", "answers"]].apply(lambda x: extract_answer_pandas(x.model_output, x.answers), axis=1)
mfs["final_answer"] = mfs[["extracted_response", "answer_orders"]].apply(lambda x: compare_answers(x.extracted_response, x.answer_orders), axis=1)
# question: the prompt text that follows the last "Question: " marker.
mfs["question"] = mfs[["prompt"]].apply(lambda x: x.prompt.split("Question: ")[-1], axis=1)
# NOTE(review): str.replace removes EVERY "data/" substring, not just a leading
# prefix (e.g. ".../metadata/..." would become ".../meta..."); confirm the paths are safe.
mfs["audio_path"] = mfs[["audio_path"]].apply(lambda x: x.audio_path.replace("data/", ""), axis=1)

In [None]:
# Compute MM-SHAP scores row by row (loads two .npy files per row).
# NOTE(review): compute_mmshap_row labels its Series "a-shap"/"t-shap", while the
# target columns here are "a_shap"/"t_shap"; this relies on pandas pairing the
# columns by position rather than by label -- verify on the pandas version in use.
qfs[["a_shap", "t_shap", "tokens"]] = qfs.apply(compute_mmshap_row, axis=1)
mfs[["a_shap", "t_shap", "tokens"]] = mfs.apply(compute_mmshap_row, axis=1)

In [None]:
# Correctly answered rows of mfs, with the model output and the audio score.
mfs[mfs["final_answer"] == 1][["question", "model_output", "a_shap"]]

# Plot Shapley values for a single example

In [None]:
import librosa
import matplotlib.pyplot as plt
import matplotlib as mpl
import IPython.display as ipd

In [None]:
# Correctly answered rows of mfs, ordered from highest to lowest audio score,
# used to choose an example to plot.
mfs[mfs["final_answer"] == 1][["question_id", "question", "a_shap", "tokens", "audio_path"]].sort_values(by="a_shap", ascending=False)

In [None]:
# Select the row(s) of question 869 for detailed inspection. The same id is
# hard-coded in the .npy file names loaded further down.
example = mfs[mfs["question_id"] == 869]

In [None]:
# Show the selected example row.
example

In [None]:
# Root folder of the audio datasets. It is machine-specific, so it can be
# overridden with the DATASET_PATH environment variable; the previous
# hard-coded location is kept as the default.
dataset_path = os.environ.get("DATASET_PATH", "/media/gigibs/DD02EEEC68459F17/datasets")

In [None]:
# Load the example's audio at a 24 kHz sampling rate:
# x is the waveform, fs the sampling rate returned by librosa.
x, fs = librosa.load(os.path.join(dataset_path, example["audio_path"].values[0]), sr=24000)

In [None]:
# Shape of the loaded waveform array.
x.shape

In [None]:
# Audio player widget for listening to the example.
ipd.Audio(x, rate=fs)

In [None]:
# Folder with this example's saved outputs (the *_tokens.npy and
# *_shapley_values.npy files loaded below), relative to the parent directory.
data_path = os.path.join("..", example["output_folder"].values[0])

In [None]:
# List the files available in the example's output folder.
os.listdir(data_path)

In [None]:
# Token array of question 869 with its two leading singleton axes removed.
tokens = np.load(os.path.join(data_path, "869_tokens.npy")).squeeze(0).squeeze(0)
# Indices (along the last axis) of negative entries, which are treated as the
# audio positions -- the same convention compute_mmshap_row uses.
audio_tokens = np.where(tokens < 0)[-1]

In [None]:
# Total number of tokens vs. number of audio positions.
tokens.shape, audio_tokens.shape

In [None]:
# Shapley values of question 869 with the two leading singleton axes removed.
all_shapley_values = np.load(os.path.join(data_path, "869_shapley_values.npy")).squeeze(0).squeeze(0)
# Keep only the rows that belong to the audio positions found above.
audio_shapley_values = all_shapley_values[audio_tokens]

In [None]:
# Shape of the audio-only Shapley array.
audio_shapley_values.shape

In [None]:
# Per-audio-position totals over axis 1 (presumably the output tokens -- TODO confirm):
# absolute magnitude, positive contributions only, negative contributions only.
abs_audio_shapley = np.abs(audio_shapley_values).sum(axis=1)
pos_audio_shapley = np.clip(audio_shapley_values, a_min=0, a_max=None).sum(axis=1)
neg_audio_shapley = np.clip(audio_shapley_values, a_min=None, a_max=0).sum(axis=1)

In [None]:
# One aggregated value per audio position.
pos_audio_shapley.shape

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl

def plot_shapley_analysis(signal, audio_shapley_values, sample_rate,
                         gt_start, gt_end, colormap='viridis', 
                         figsize=(12, 8), idx=None, output_token=None):
    """
    Plot a waveform above three Shapley-value heat strips sharing one time axis.

    The figure has four stacked rows: the signal with the ground-truth interval
    shaded in red, then the absolute, positive-only and negative-only (shown as
    magnitude) Shapley values. All three strips use the same colour scale,
    from 0 to the maximum absolute value. The figure is shown with
    `plt.show()`; nothing is returned.

    Parameters:
    -----------
    signal : array-like
        The input signal (waveform) to display
    audio_shapley_values : array-like
        Raw Shapley values. When both `idx` and `output_token` are None the
        values are summed over axis 1, so a 2D array is expected; otherwise
        they are used as given (callers in this notebook pass one column).
    sample_rate : int
        Sampling rate of the audio signal (Hz)
    gt_start : float
        Start time of ground truth event (seconds)
    gt_end : float
        End time of ground truth event (seconds)
    colormap : str, optional
        Matplotlib colormap to use (default: 'viridis')
    figsize : tuple, optional
        Figure size (width, height) in inches (default: (12, 8))
    idx : int, optional
        Marks the plot as belonging to a single output token; when not None
        the title names `output_token`. The value itself is not used for
        indexing inside this function.
    output_token : str, optional
        Text of the output token shown in the title when `idx` is given.
    """

    # Split the values into magnitude, positive part and negative part.
    abs_shapley = np.abs(audio_shapley_values)
    pos_shapley = np.clip(audio_shapley_values, a_min=0, a_max=None)
    neg_shapley = np.clip(audio_shapley_values, a_min=None, a_max=0)

    # Aggregate over axis 1 only in the "all output tokens" mode.
    # NOTE(review): passing output_token without idx skips this sum while the
    # title below still reads "sum over all output tokens".
    if idx is None and output_token is None:
        abs_shapley = abs_shapley.sum(axis=1)
        pos_shapley = pos_shapley.sum(axis=1)
        neg_shapley = neg_shapley.sum(axis=1)

    # Create figure with subplots; the waveform row is twice as tall as a strip.
    fig, axes = plt.subplots(4, 1, figsize=figsize, 
                           sharex=True, 
                           gridspec_kw={'height_ratios': [2, 1, 1, 1]})

    total_duration = len(signal) / sample_rate
    time_axis = np.linspace(0, total_duration, len(signal))
    # NOTE(review): shapley_time_axis is computed but not used below; the
    # strips are positioned through `extent` instead.
    shapley_time_axis = np.linspace(0, total_duration, len(abs_shapley))

    # --- 1. Signal plot (top subplot) ---
    ax_signal = axes[0]
    ax_signal.plot(time_axis, signal, color='gray', alpha=0.7, linewidth=0.5)
    ax_signal.set_ylabel('Amplitude', fontsize=10)

    # Add ground truth rectangle (only on signal plot)
    # NOTE(review): ymin/ymax are read but not used; axvspan below spans the
    # full axes height via its own ymin=0, ymax=1 arguments.
    ymin, ymax = ax_signal.get_ylim()
    ax_signal.axvspan(gt_start, gt_end, ymin=0, ymax=1, 
                     color='red', alpha=0.3, label='Ground Truth')
    ax_signal.legend(loc='upper right')

    # --- 2. Absolute Shapley values ---
    # The 1D values are repeated into 10 identical rows so imshow draws a strip.
    ax_abs = axes[1]
    im_abs = ax_abs.imshow(
        np.repeat(abs_shapley.reshape(1, -1), 10, axis=0),
        aspect='auto',
        cmap=colormap,
        extent=[0, total_duration, 0, 1],
        vmin=0,  # Ensure consistent scaling
        vmax=np.max(abs_shapley)
    )
    ax_abs.set_ylabel('Absolute\nValue', rotation=0, ha='right', va='center', fontsize=10)
    ax_abs.set_yticks([])

    # --- 3. Positive Shapley values ---
    ax_pos = axes[2]
    im_pos = ax_pos.imshow(
        np.repeat(pos_shapley.reshape(1, -1), 10, axis=0),
        aspect='auto',
        cmap=colormap,
        extent=[0, total_duration, 0, 1],
        vmin=0,
        vmax=np.max(abs_shapley)  # Same scale as absolute
    )
    ax_pos.set_ylabel('Positive\nOnly', rotation=0, ha='right', va='center', fontsize=10)
    ax_pos.set_yticks([])

    # --- 4. Negative Shapley values ---
    ax_neg = axes[3]
    im_neg = ax_neg.imshow(
        np.repeat(np.abs(neg_shapley).reshape(1, -1), 10, axis=0),  # Show magnitude
        aspect='auto',
        cmap=colormap,
        extent=[0, total_duration, 0, 1],
        vmin=0,
        vmax=np.max(abs_shapley)  # Same scale as absolute
    )
    ax_neg.set_ylabel('Negative\nOnly', rotation=0, ha='right', va='center', fontsize=10)
    ax_neg.set_yticks([])
    ax_neg.set_xlabel('Time (seconds)', fontsize=12)

    # --- Colorbar ---
    # One colorbar, taken from the absolute strip, serves all three strips
    # because they share vmin/vmax. It lives in a manually placed axes.
    cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(im_abs, cax=cax, label='Shapley Value Magnitude')

    # --- Formatting ---
    fig_title = "Shapley Values (sum over all output tokens)"
    if idx is not None:
        fig_title = f'Shapley Values -- Output Token: {output_token}'
    plt.suptitle(fig_title, y=0.98, fontsize=14)
    # NOTE(review): tight_layout runs after the manual colorbar axes was added;
    # matplotlib may warn about incompatible axes here -- confirm.
    plt.tight_layout()

    # Remove boxes around subplots
    for ax in axes:
        ax.set_frame_on(False)

    # Leave room on the right for the colorbar and tighten the row spacing.
    plt.subplots_adjust(right=0.9, hspace=0.1)
    plt.show()

In [None]:
# Re-check the shape: axis 1 is indexed per output token in the loop below.
audio_shapley_values.shape

In [None]:
# One figure per output token. The token strings are typed in by hand, so
# their order must match axis 1 of audio_shapley_values.
# Each call receives a single column, which the function plots without summing.
for idx, t in enumerate(["The", "tele", "phone", "sound", "effect", "is", "present", "at", "the", "beginning", "of", "this", "music", "piece", "."]):
    plot_shapley_analysis(
        x, 
        sample_rate=fs, 
        audio_shapley_values=audio_shapley_values[:,idx], 
        gt_start=0.5, 
        gt_end=3.5, 
        colormap="viridis", 
        idx=idx,
        output_token=t
    )

In [None]:
# Aggregate view: without idx/output_token the function sums the values
# over axis 1 before plotting.
plot_shapley_analysis(
    x, 
    sample_rate=fs, 
    audio_shapley_values=audio_shapley_values, 
    gt_start=0.5, 
    gt_end=3.5, 
    colormap="viridis", 
)

In [None]:
# TODO: load the model and tokenizer to encode the output text (no code for this step yet)

In [None]:
# Text generated by the model for the selected example.
example["model_output"].values[0]

In [None]:
# Prompt that was given to the model for the selected example.
example["prompt"].values[0]

In [None]:
from matplotlib.colors import rgb2hex
import matplotlib.cm as cm
import numpy as np

def highlight_tokens(shapley_values, tokens, max_abs_value=None):
    """
    Highlight text tokens based on Shapley values using HTML span tags.

    Positive values get a red background and negative values a blue one; in
    both cases a larger magnitude gives a darker colour. Tokens whose value is
    exactly zero are emitted without a span.

    Args:
        shapley_values: List of Shapley values (one per token)
        tokens: List of text tokens (same length as shapley_values)
        max_abs_value: Optional maximum absolute value for color scaling.
                      If None, will use max absolute value from shapley_values.

    Returns:
        HTML string with the tokens joined by single spaces and colored based
        on Shapley values. Token text is HTML-escaped. Empty input returns "".

    Raises:
        ValueError: If shapley_values and tokens differ in length.
    """
    # Local import keeps this cell self-contained.
    import html

    if len(shapley_values) != len(tokens):
        raise ValueError("Shapley values and tokens must have the same length")

    shapley_values = np.array(shapley_values, dtype=float)

    # np.max raises on an empty array, so handle "nothing to render" up front.
    if shapley_values.size == 0:
        return ""

    # Determine color scaling
    if max_abs_value is None:
        max_abs_value = np.max(np.abs(shapley_values))

    # Normalize to [-1, 1]; the epsilon avoids 0/0 when every value is zero,
    # and the clip keeps the intensity inside the colormap's [0, 1] domain
    # when a caller-supplied max_abs_value is smaller than the data.
    normalized_values = np.clip(shapley_values / (max_abs_value + 1e-10), -1.0, 1.0)

    # Both maps run from light to dark. The reversed Blues_r map used before
    # made the most negative tokens the lightest ones.
    red_cmap = cm.Reds
    blue_cmap = cm.Blues

    highlighted_text = []
    for value, token in zip(normalized_values, tokens):
        # Escape so tokens such as "<s>" or "&" are displayed literally
        # instead of being interpreted as markup.
        safe_token = html.escape(str(token))

        if value == 0:  # Zero impact: leave the token unstyled
            highlighted_text.append(safe_token)
            continue

        # Scale to [0.4, 1] range to avoid very light colors
        intensity = 0.4 + 0.6 * abs(value)
        rgba = red_cmap(intensity) if value > 0 else blue_cmap(intensity)

        hex_color = rgb2hex(rgba)
        highlighted_text.append(
            f'<span style="background-color: {hex_color}">{safe_token}</span>'
        )

    return ' '.join(highlighted_text)

In [None]:
# Toy example for highlight_tokens (made-up values, not model output).
# NOTE(review): this rebinds `tokens`, which earlier in the notebook holds the
# token array of question 869.
shap_values = [0.5, -0.3, 0.1, -0.8, 0.0]
tokens = ["The", "movie", "was", "terrible", "!"]

html_output = highlight_tokens(shap_values, tokens)

# To display in Jupyter notebook:
from IPython.display import HTML
HTML(html_output)