In [3]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import pandas as pd
import gzip

In [7]:
# Directory holding the trained checkpoint; attention plots are written beneath it.
MODEL_PATH = "baseline_11_29/"

In [None]:
# Load the model
from capture24.patch_tst import PatchTST
import yaml

# Build PatchTST from its YAML config with config['hook_attention_maps'] enabled
# (the maps are read back later via model.attention_maps), restore the trained
# weights onto the CPU, and switch the model to eval mode.
with open('capture24/config_patchtst.yaml', 'r') as cfg_file:
    config = yaml.safe_load(cfg_file)
config['hook_attention_maps'] = True

model = PatchTST(config)
state_dict = torch.load(
    MODEL_PATH + 'patchtst_model.pth',
    weights_only=True,
    map_location=torch.device('cpu'),
)
model.load_state_dict(state_dict)
model.eval()

In [4]:
# Load the held-out test split (gzip-compressed .npy arrays) and the JSON label maps.
def _read_gz_npy(gz_path):
    """Decompress a single ``.npy.gz`` file and return the stored numpy array."""
    with gzip.open(gz_path, 'rb') as handle:
        return np.load(handle)


X_test = _read_gz_npy('capture24/final_data_512/X_test.npy.gz')
Y_test = _read_gz_npy('capture24/final_data_512/Y_test.npy.gz')

with open('capture24/final_data_512/label_to_index.json', 'r') as handle:
    data = json.load(handle)

# The JSON file carries both directions of the class-name <-> index mapping.
label_to_idx = data['label_to_index']
idx_to_label = data['index_to_label']


In [None]:
def plot_example(test_example, label, path=None):
    """Line-plot every channel of a single (time, channels) window.

    Args:
        test_example: Array-like of shape (time, channels). The windows used in
            this notebook come from X_test and are described as (512, 3), but
            any number of time steps and channels is plotted.
        label: Class name inserted into the figure title.
        path: Optional output file; when given, the figure is saved there
            before it is shown.

    Raises:
        ValueError: if ``test_example`` is not two-dimensional.
    """
    window = np.asarray(test_example)
    if window.ndim != 2:
        raise ValueError(f'expected a (time, channels) array, got shape {window.shape}')
    time = np.arange(window.shape[0])

    fig, ax = plt.subplots(figsize=(15, 10))
    # One labelled line per channel so the legend tells the channels apart.
    for ch in range(window.shape[1]):
        sns.lineplot(x=time, y=window[:, ch], ax=ax, label=f'channel_{ch + 1}')
    ax.set(title=f'Test Example for {label}', xlabel='time', ylabel='value')
    if path:
        fig.savefig(path)
    plt.show()
    # Release the figure explicitly; plt.show() may leave it open on non-interactive backends.
    plt.close(fig)
    

In [None]:
from capture24.patch_tst import C24_Dataset
from torch.utils.data import DataLoader
# Pick the first test window of ONE class, plot it, and run it through the model to
# capture the per-layer attention maps. To cover another class, change CLASS_INDEX
# and re-run this cell and the heatmap cell below.
CLASS_INDEX = 9  # position of the class in label_to_idx's key order
CLASS_NAME = list(label_to_idx.keys())[CLASS_INDEX]
class_idx = label_to_idx[CLASS_NAME]

# First test window labelled with this class.
# NOTE(review): assumes Y_test stores values comparable to label_to_idx's values — confirm.
matches = np.flatnonzero(Y_test == class_idx)
if matches.size == 0:
    raise ValueError(f'No test examples found for class {CLASS_NAME!r} (index {class_idx!r})')
test_idx = matches[0]

# Output folder for this class's plots.
path = MODEL_PATH + f'attention/{CLASS_NAME}/'
os.makedirs(path, exist_ok=True)

test_example = X_test[test_idx]
plot_example(test_example, CLASS_NAME, path + f'{CLASS_NAME}.png')

# Feed the window through C24_Dataset + DataLoader as a batch of one (leading axis added).
batch_example = test_example[np.newaxis, :, :]
ex = C24_Dataset(batch_example, np.array([Y_test[test_idx]]), idx_to_label, label_to_idx)
loader = DataLoader(ex, batch_size=1, shuffle=False)

# Inference only: disabling autograd keeps the hooked attention maps free of a
# gradient graph, so they can be converted with .numpy() afterwards.
with torch.no_grad():
    for x, _ in loader:
        outputs = model(x)

maps = model.attention_maps
print(maps[0].shape)

In [None]:
# Save one heatmap per (layer, head) for the selected class to
# MODEL_PATH/attention/<CLASS_NAME>/<layer>/layer_<layer>_head_<head>.png,
# the layout the grid-compilation cell reads back.
for layer_num, layer_maps in enumerate(maps, start=1):
    layer_dir = MODEL_PATH + f'attention/{CLASS_NAME}/{layer_num}/'
    os.makedirs(layer_dir, exist_ok=True)
    # squeeze() drops size-1 axes (e.g. a batch of 1) so the loop runs over heads.
    # NOTE(review): a map with a single head would also lose its head axis here — confirm shape.
    for head_num, head_map in enumerate(layer_maps.squeeze(), start=1):
        fig, ax = plt.subplots(figsize=(15, 10))
        # detach() keeps .numpy() working even if the maps were captured with autograd on.
        sns.heatmap(head_map.detach().cpu().numpy().squeeze(), cmap='viridis', ax=ax)
        ax.set_title(f'Attention Map for layer{layer_num} and head {head_num}')
        fig.savefig(layer_dir + f'layer_{layer_num}_head_{head_num}.png')
        # Close every figure: one per head per layer would otherwise accumulate,
        # tripping matplotlib's too-many-open-figures warning and holding memory.
        plt.close(fig)


In [8]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import numpy as np

# For each activity, tile its saved per-(layer, head) heatmaps into one image —
# one row per head, one column per layer — and save it as a single compiled PNG.

ACTIVITIES = list(label_to_idx.keys())
NUM_LAYERS = 3
NUM_HEADS = 16  # assumes shape [batch, heads, ...] as before; keep in sync with the model

for act in ACTIVITIES:
    # tiles[head][layer] holds the loaded image, or None when that file is missing.
    tiles = []
    for head in range(1, NUM_HEADS + 1):
        row = []
        for layer in range(1, NUM_LAYERS + 1):
            fname = MODEL_PATH + f"attention/{act}/{layer}/layer_{layer}_head_{head}.png"
            row.append(mpimg.imread(fname) if os.path.exists(fname) else None)
        tiles.append(row)

    found = [img for row in tiles for img in row if img is not None]
    if not found:
        print(f"No attention maps found for {act}; skipping.")
        continue

    # The placeholder must match a real tile's shape/dtype or np.concatenate fails;
    # ones in imread's float PNG range render as white.
    blank = np.ones_like(found[0])
    grid_rows = [
        np.concatenate([img if img is not None else blank for img in row], axis=1)
        for row in tiles
    ]
    grid_img = np.concatenate(grid_rows, axis=0)

    fig, ax = plt.subplots(figsize=(NUM_LAYERS * 3, NUM_HEADS * 3))
    ax.imshow(grid_img)
    ax.axis('off')
    ax.set_title(f"All Attention Maps for {act}")
    save_path = MODEL_PATH + f'attention/{act}_all_attention_grids.png'
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
