In [None]:
import itertools
import os
import yaml

import awkward as ak
import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from source.dataset import JetClass, JetNet, JetTorchDataset
from source.models.part import AttentionBlock, ParticleTransformer

### `yaml` configurations

In [None]:
# Parse the YAML configuration file found in the current working directory.
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed the random number generators through Lightning using the configured
# seed; `rnd_seed` is reused later to locate training logs and name outputs.
rnd_seed = config['rnd_seed']
L.seed_everything(rnd_seed)

### Hook function for extracting intermediate values of the models

See ["How to extract intermediate values of torch models via hook function"](https://discuss.pytorch.org/t/how-can-i-extract-intermediate-layer-output-from-loaded-cnn-model/77301) for further details.

In [None]:
def torch_model_hook_function(name: str, intermediate_outputs: dict[str, torch.Tensor]):
    """Build a forward hook that records a module's output under `name`.

    Args:
        name: Key under which the output is stored.
        intermediate_outputs: Dictionary that the hook writes into; an existing
            entry with the same key is overwritten on every call of the hook.

    Returns:
        A callable with the `(module, input, output)` signature expected by
        `register_forward_hook`. It stores a detached CPU copy of `output`
        and returns `None`.
    """
    def _capture(_module, _inputs, result):
        # Detach from the autograd graph first, then move to host memory.
        detached = result.detach()
        intermediate_outputs[name] = detached.cpu()

    return _capture

def ParT_intermediate_outputs(model: ParticleTransformer) -> dict[str, torch.Tensor]:
    """Hook the attention softmax of the first 8 particle attention blocks.

    A forward hook is registered on `model.par_attn_blocks[i].attn.softmax`
    for `i` in `0..7`. Each hook writes the softmax output into the returned
    dictionary under the key `"par_{i}"` whenever the module runs.

    Note: the hooks are never removed, so every call to this function adds
    another set of hooks to `model`. Call it once per model instance.

    Args:
        model: Model exposing `par_attn_blocks` with at least 8 entries.

    Returns:
        The dictionary the hooks write into; it is empty until the hooked
        modules have been executed.
    """
    captured: dict[str, torch.Tensor] = {}

    num_blocks = 8
    for block_idx in range(num_blocks):
        softmax_module = model.par_attn_blocks[block_idx].attn.softmax
        softmax_module.register_forward_hook(
            torch_model_hook_function(f"par_{block_idx}", captured)
        )

    return captured

### Function for loading a model checkpoint

In [None]:
def load_model_checkpoint(dataset: str, model: str, rnd_seed: int, epoch: int) -> nn.Module:
    """Load a pickled model and restore its weights from an epoch checkpoint.

    The run directory is `training_logs/<dataset>/<model>_<rnd_seed>/lastest_run`.
    The model structure is read from `model.pt` in that directory and the
    weights from `checkpoints/epoch=<epoch>.ckpt`. If versioned files
    `epoch=<epoch>-v1.ckpt`, `-v2`, ... exist, the highest consecutive
    version is used instead.

    Args:
        dataset: Name of the dataset sub-directory.
        model: Model name used in the run directory name.
        rnd_seed: Random seed used in the run directory name.
        epoch: Epoch index used in the checkpoint file name.

    Returns:
        The loaded module, mapped to CPU and switched to evaluation mode.

    Errors raised by `torch.load` (e.g. a missing file) or by
    `load_state_dict` (mismatching keys) propagate to the caller.
    """
    # Path to the log directory.
    # NOTE(review): 'lastest_run' looks like a misspelling of 'latest_run' but
    # must match the directory name written during training, so it is kept.
    log_dir = os.path.join('training_logs', dataset, f"{model}_{rnd_seed}", 'lastest_run')

    # Load the initial model structure. `map_location='cpu'` makes this work
    # on machines without the device the model was saved from. The result is
    # bound to a new name so the `model` string argument is not shadowed.
    # `weights_only=False` unpickles arbitrary objects: only use trusted files.
    network = torch.load(os.path.join(log_dir, 'model.pt'), map_location='cpu', weights_only=False)

    # Path to the model checkpoint.
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    ckpt_path = os.path.join(ckpt_dir, f"epoch={epoch}.ckpt")

    # Checkpoints might have different versions due to shutdown during
    # training; walk -v1, -v2, ... and keep the last one that exists.
    ckpt_version = 1
    while True:
        candidate = os.path.join(ckpt_dir, f"epoch={epoch}-v{ckpt_version}.ckpt")
        if not os.path.exists(candidate):
            break
        ckpt_path = candidate
        ckpt_version += 1
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    # Strip only the leading 'model.' prefix from each key. A plain
    # `str.replace` would also remove later occurrences (for example inside
    # 'model.submodel.weight') and corrupt the parameter names.
    prefix = 'model.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt['state_dict'].items()
    }
    network.load_state_dict(state_dict)
    network.eval()

    return network

### Load dataset

In [None]:
# Look up the dataset class by the name given in the configuration; the name
# must match a global of this notebook (e.g. one of the imported dataset classes).
DATASET = globals()[config['dataset']]
MODEL = ParticleTransformer

# Dataset-specific padding setting and the list of channels to process.
pad_num_ptcs = config[DATASET.__name__]['pad_num_ptcs']
channels = DATASET.CHANNELS

# Build one event container per channel first ...
jet_events_list = [DATASET(channel=channel_name) for channel_name in channels]

# ... then wrap each container in a torch dataset, using the channel's
# position in `channels` as its label.
jet_dataset_list = [
    JetTorchDataset(events, label, pad_num_ptcs=pad_num_ptcs)
    for label, events in enumerate(jet_events_list)
]

# Field names are taken from the first dataset.
fields = jet_dataset_list[0].fields

# Number of data points taken from each channel.
num_data = 10

### Store the selected particle features

In [None]:
# One record per (channel, data point, field), later turned into a DataFrame.
particle_features_list = []

for jet_events in jet_events_list:
    for data_index in range(num_data):
        event = jet_events.data[data_index]

        # Store every field in `fields` for this event as a NumPy array.
        particle_features_list.extend(
            {
                'channel': jet_events.channel,
                'data_index': data_index,
                'feature': field,
                'array': ak.to_numpy(event[field]),
            }
            for field in fields
        )

### Store the intermediate values

In [None]:
# Outputs to be stored and loaded in dashboard: one record per
# (epoch, channel, data point, attention block).
intermediate_outputs_list = []

# Weights of the `par_embedding.embedding[1]` layer: one record per
# (epoch, field). Biases are not stored.
linear_weights = []

# Loop over each epoch checkpoint.
for epoch_index in range(config['num_epochs']):

    # Load the model checkpoint.
    model = load_model_checkpoint(
        dataset=DATASET.__name__,
        model=MODEL.__name__,
        rnd_seed=rnd_seed,
        epoch=epoch_index,
    )

    if MODEL == ParticleTransformer:
        # Store the embedding weights. `.cpu()` keeps this working even if the
        # loaded parameters live on another device.
        # Assumes the transposed weight has one row per entry of `fields`
        # (i.e. the layer's input features are the fields) — TODO confirm.
        w = model.par_embedding.embedding[1].weight.detach().cpu().numpy().T
        assert len(w.shape) == 2, "The weight matrix should have 2 dimensions."
        for i, field in enumerate(fields):
            linear_weights.append({
                'epoch_index': epoch_index,
                'field': field,
                'weights': w[i],
            })

        # Register the forward hooks ONCE per loaded model. Registering them
        # inside the per-sample loop would add 8 more permanent hooks for
        # every sample, so each forward pass would run an ever-growing number
        # of hooks. The hooks overwrite the entries of `output` on each
        # forward pass with freshly created tensors.
        output = ParT_intermediate_outputs(model)

    # Extract intermediate outputs for each channel and data index.
    for jet_dataset, data_index in itertools.product(jet_dataset_list, range(num_data)):

        x, _ = jet_dataset[data_index]

        # Remove the padded particles (rows containing non-finite values).
        x = x[torch.all(torch.isfinite(x), dim=-1)]

        # Add the batch dimension.
        x = x.unsqueeze(0)

        # Fetch intermediate outputs.
        if MODEL == ParticleTransformer:
            # Only the hooked (detached) outputs are needed, so skip autograd.
            with torch.no_grad():
                _ = model(x)

            # 8 particle attention blocks are hooked.
            for block_index in range(8):
                intermediate_outputs_list.append({
                    'channel': jet_dataset.channel,
                    'data_index': data_index,
                    'epoch_index': epoch_index,
                    'block_index': block_index,
                    'output': output[f"par_{block_index}"].squeeze(0)
                })

### Store in pandas DataFrame

In [None]:
# Bundle run metadata with the collected records, converting each record list
# into a DataFrame.
summary = dict(
    channels=channels,
    num_data=num_data,
    num_epochs=config['num_epochs'],
    particle_features=pd.DataFrame(particle_features_list),
    intermediate_outputs=pd.DataFrame(intermediate_outputs_list),
    linear_weights=pd.DataFrame(linear_weights),
)

# The dictionary is pickled by `np.save`, so it has to be read back with
# `np.load(..., allow_pickle=True)`.
output_dir = 'intermediate_outputs'
os.makedirs(output_dir, exist_ok=True)
file_name = f"{DATASET.__name__}_{MODEL.__name__}_{rnd_seed}.npy"
save_path = os.path.join(output_dir, file_name)
np.save(save_path, summary, allow_pickle=True)