In [None]:
import math
import os

import matplotlib.pyplot as plt
import torch

In [None]:
# Decay rate of the reference kernel exp(-w * t_diff); it is also interpolated into the
# file name below, so it selects which saved attention dump is loaded.
w = 0.003
attn_score_pt_path = f"w_{w}_num_layers_1_time_feat_dim_8_linear_attn_score.pt"
# map_location="cpu" lets the dump load on a machine that lacks the device the tensors
# were saved from; this notebook only computes summary statistics and plots, so CPU
# tensors are sufficient.
# NOTE: torch.load unpickles the file -- only load dumps from a trusted source.
attn_score_dict = torch.load(
    os.path.join("attn_analysis_output", attn_score_pt_path),
    map_location="cpu",
)

In [None]:
def get_attn_score(attn_score, t_diff, w):
    """
    Build row-normalized learned and reference attention over all but the last position.

    attn_score: [batch_size, n, n]; only the slice at index -1 along dim 1 is used
    t_diff: [batch_size, n]
    w: float, exponential decay factor of the reference kernel exp(-w * t_diff)

    Returns (learned, reference), each [batch_size, n - 1]; every row is divided by
    its own sum, so rows sum to 1.
    """
    # Learned side: take the slice at index -1 of dim 1, drop its last column, renormalize.
    learned = attn_score[:, -1, :-1]
    learned = learned / learned.sum(dim=1, keepdim=True)

    # Reference side: exponential decay in t_diff, last column dropped, renormalized.
    decay = torch.exp(-w * t_diff)
    kept = decay[:, :-1]
    reference = kept / kept.sum(dim=1, keepdim=True)

    return learned, reference


def get_attn_err(attn_score_dict, w):
    """
    Mean squared error between reference and learned normalized attention, per split.

    attn_score_dict: mapping that holds "<split>_attn_score" and "<split>_t_diff"
        entries for the train, val and test splits
    w: float, exponential decay factor forwarded to get_attn_score

    Returns (train_attn_err, val_attn_err, test_attn_err), each the result of
    torch.mean over all elements of the squared difference.
    """
    split_keys = (
        ("train_attn_score", "train_t_diff"),
        ("val_attn_score", "val_t_diff"),
        ("test_attn_score", "test_t_diff"),
    )

    errors = []
    for score_key, t_diff_key in split_keys:
        learned, reference = get_attn_score(
            attn_score_dict[score_key], attn_score_dict[t_diff_key], w
        )
        squared_diff = (reference - learned) ** 2
        errors.append(torch.mean(squared_diff))

    return tuple(errors)

In [None]:
def _plot_mean_with_band(ax, scores, label, num_runs):
    """Draw the dim-0 mean of `scores` on `ax` with 1.96 * std / sqrt(num_runs) error bars."""
    # matplotlib cannot convert accelerator-resident or grad-tracking tensors, so move
    # the data to plain Python floats before plotting.
    scores = scores.detach().cpu()
    mean = scores.mean(dim=0)
    half_width = 1.96 * scores.std(dim=0) / math.sqrt(num_runs)
    ax.errorbar(
        range(1, len(mean) + 1),
        mean.tolist(),
        yerr=half_width.tolist(),
        label=label,
    )


def viz_attn_score(attn_score_dict, w, num_runs=10):
    """
    Plot learned vs. reference normalized attention for the train, val and test splits.

    One panel per split (stacked vertically). Each panel shows, per event index, the
    mean over dim 0 of the learned and reference scores returned by get_attn_score,
    with error bars of 1.96 * std / sqrt(num_runs).

    attn_score_dict: mapping that holds "<split>_attn_score" and "<split>_t_diff"
        entries for the train, val and test splits
    w: float, exponential decay factor forwarded to get_attn_score
    num_runs: positive number used as the sample count in the error-bar formula
        (default 10, the value that was previously hard-coded).
        NOTE(review): the std is taken over dim 0 of the score tensors, so the bars
        are a 95% interval of the mean only if dim 0 really holds `num_runs`
        samples -- confirm against how the dump was produced.

    Raises ValueError if num_runs is not positive.
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be positive, got {num_runs}")

    panels = (
        ("Train", "train_attn_score", "train_t_diff"),
        ("Val", "val_attn_score", "val_t_diff"),
        ("Test", "test_attn_score", "test_t_diff"),
    )

    _, axes = plt.subplots(3, 1, figsize=(8, 16))

    for ax, (title, score_key, t_diff_key) in zip(axes, panels):
        learned, reference = get_attn_score(
            attn_score_dict[score_key], attn_score_dict[t_diff_key], w
        )
        _plot_mean_with_band(ax, learned, "Learned", num_runs)
        _plot_mean_with_band(ax, reference, "Truth", num_runs)
        ax.legend()
        ax.set_ylabel("Attention Score")
        ax.set_xlabel("Event Index")
        ax.set_title(title)

    plt.show()

In [None]:
# Mean squared error between the reference and learned normalized attention, per split.
train_attn_err, val_attn_err, test_attn_err = get_attn_err(attn_score_dict, w)

In [None]:
# Bare last expression so the notebook displays the (train, val, test) errors.
train_attn_err, val_attn_err, test_attn_err

In [None]:
# Plot mean learned vs. reference attention per event index for the train/val/test splits.
viz_attn_score(attn_score_dict, w)

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# # Parameters
# a = 0.003       # decay rate
# omega = 0.02   # angular frequency

# # Time range
# t = np.linspace(0, 800, 1000)

# # Define the function
# f = np.exp(-a * t) * np.cos(omega * t)**2

# # Plot
# plt.figure(figsize=(8, 4))
# plt.plot(t, f, label=r'$e^{-a t}\cos^2(\omega t)$')
# plt.title('Damped Oscillation: $e^{-a t}\\cos^2(\\omega t)$')
# plt.xlabel('t')
# plt.ylabel('f(t)')
# plt.ylim(0, 1.1)  # A bit above 1 for clarity
# plt.legend()
# plt.grid(True)
# plt.show()