In [1]:
import itertools
import os
import yaml

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn

from source.data import jetclass, jetnet, topqcd
from source.data.datamodule import JetLightningDataModule
from source.models.litmodel import TorchLightningModule
from source.models.part import ParticleTransformer, AttentionBlock
from source.models.pnet import ParticleNet

# Dataset module under study. Later cells read `dataset.channels` and
# `dataset.JetEvents` from it and use the last component of its module name
# in the checkpoint directory and the output file name.
dataset = jetclass
# Model class to instantiate; its `__name__` also selects the YAML config
# file and the top-level key read from that file.
model_class = ParticleTransformer
# NOTE(review): not referenced in the visible cells -- presumably a switch for
# (re)training used elsewhere; confirm before relying on it.
train_model = False

# Number of events per channel that are passed through the model
# (indices 0 .. num_data - 1 of the first data list).
num_data = 10
# Number of checkpoints visited; checkpoint files are addressed by epoch
# index 0 .. num_epochs - 1.
num_epochs = 10

In [2]:
def get_hook_fn(name: str, intermediate_outputs: dict[str, torch.Tensor]):
    """Build a forward hook that records a module's output under `name`.

    Args:
        name: Key under which the captured output is stored.
        intermediate_outputs: Dictionary that receives the captured tensor.
            It is mutated in place every time the hook fires, so the entry
            always reflects the most recent forward pass.

    Returns:
        A callable with the `(module, input, output)` forward-hook signature.
        It returns `None`, so the hooked module's output is left unchanged.
    """
    def hook_fn(_module, _inputs, output):
        # Cut the tensor out of the autograd graph first, then make sure the
        # stored copy lives in host memory.
        captured = output.detach()
        intermediate_outputs[name] = captured.cpu()

    return hook_fn

def get_intermediate_outputs(model: ParticleTransformer, num_blocks: int = 8) -> dict[str, torch.Tensor]:
    """Attach forward hooks that capture the attention-softmax outputs.

    A hook is registered on `model.par_attn_blocks[i].attn.softmax` for each
    `i` in `range(num_blocks)`. After every forward pass of `model`, the
    returned dictionary holds the detached CPU output of each of those
    modules under the key `f"par_{i}"`.

    The function is safe to call repeatedly on the same model: hooks installed
    by a previous call are removed first, so they do not accumulate and slow
    down every later forward pass. As a consequence, only the dictionary
    returned by the most recent call keeps being updated.

    Args:
        model: Model whose particle attention blocks are hooked.
        num_blocks: Number of leading particle attention blocks to hook.
            Defaults to 8, the value that was previously hard-coded.

    Returns:
        A dictionary that is filled in place by the hooks on each forward pass.
        It is empty until the model has been called at least once.
    """
    # Remove hooks left over from an earlier call on this model. Without this,
    # each call would add another hook per softmax module.
    for stale_handle in getattr(model, '_intermediate_output_hook_handles', []):
        stale_handle.remove()

    intermediate_outputs: dict[str, torch.Tensor] = {}
    handles = []

    for i in range(num_blocks):
        attn_block: AttentionBlock = model.par_attn_blocks[i]
        attn_softmax = attn_block.attn.softmax

        hook_fn = get_hook_fn(f"par_{i}", intermediate_outputs)
        # `register_forward_hook` returns a handle whose `remove()` detaches
        # the hook again; keep it so the next call can clean up.
        handles.append(attn_softmax.register_forward_hook(hook_fn))

    model._intermediate_output_hook_handles = handles

    return intermediate_outputs

def load_ckpt(model: nn.Module, epoch: int) -> nn.Module:
    """Load the weights saved at `epoch` into `model` and return the model.

    The checkpoint is read from
    `training_logs/<dataset name>/<model class name>/lastest_run/checkpoints/epoch=<epoch>.ckpt`,
    where the dataset name is the last component of the module name of the
    notebook-level `dataset`.

    Args:
        model: Module that receives the weights (modified in place).
        epoch: Epoch index used in the checkpoint file name.

    Returns:
        The same `model` instance, with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint keys do not match the model
            (raised by `load_state_dict`).
    """
    dataset_name: str = dataset.__name__.split('.')[-1]
    # NOTE(review): 'lastest_run' looks like a typo for 'latest_run', but it
    # has to match the directory name on disk -- confirm before changing it.
    ckpt_dir = os.path.join('training_logs', dataset_name, model.__class__.__name__, 'lastest_run', 'checkpoints')
    ckpt_path = os.path.join(ckpt_dir, f"epoch={epoch}.ckpt")

    # Map to CPU so a checkpoint written on a GPU also loads on a CPU-only
    # machine; `load_state_dict` copies the values onto the model's own device.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    # The checkpoint keys carry a leading 'model.' from the wrapping module.
    # Strip only that prefix: a plain `replace` would also mangle keys that
    # merely contain 'model.' further inside (e.g. 'submodel.weight').
    state_dict = {k.removeprefix('model.'): v for k, v in ckpt['state_dict'].items()}
    model.load_state_dict(state_dict)

    return model

In [3]:
# Model setup (`yaml` configuration file)
# Only the parsing needs the open file; the handle is released before the
# model is constructed.
with open(f"configs/{model_class.__name__}.yaml", 'r') as file:
    hparams = yaml.safe_load(file)[model_class.__name__]

# One output score per channel of the dataset.
score_dim: int = len(dataset.channels)
model = model_class(score_dim=score_dim, parameters=hparams)

# (channel, data_index) -> list of (field name, square matrix carrying the
# per-particle feature values on its diagonal).
particle_features: dict[tuple[str, int], list[tuple[str, np.ndarray]]] = {}
# (channel, data_index, epoch_index, block_index) -> captured softmax output.
# The values are the detached CPU tensors stored by the forward hooks, so
# they are `torch.Tensor`, not NumPy arrays.
intermediate_outputs: dict[tuple[str, int, int, int], torch.Tensor] = {}

# Register the attention-softmax hooks a single time. Each call to
# `get_intermediate_outputs` adds forward hooks to the model, so calling it
# inside the loops below would install one more set of hooks per
# (channel, epoch) iteration.
_intermediate_outputs = get_intermediate_outputs(model)

# Per-particle input features that are stored next to the attention maps.
selected_fields = [
    'log_part_pt',
    'part_charge',
    'part_dR',
]

for channel in dataset.channels:
    # Create a batch_size == 1 data_module
    data_module = JetLightningDataModule(
        events_list=[dataset.JetEvents(channel=channel)],
        num_train=-1,
        num_valid=-1,
        num_test=-1,
        batch_size=1,
    )
    fields = data_module.fields

    for epoch_index in range(num_epochs):
        # Load the model from the checkpoint
        load_ckpt(model, epoch_index)
        model.eval()

        for data_index in range(num_data):
            # Load the data with index == data_index
            x, y_true = data_module.data_list[0][data_index]

            # Remove padded particles: keep only rows whose features are all
            # finite, then add a leading batch axis of size 1.
            x = x[torch.all(torch.isfinite(x), dim=-1)]
            x = x.unsqueeze(0)

            # Save the features. The inputs do not depend on the checkpoint,
            # so this is done for the first epoch only.
            if epoch_index == 0:
                _particle_features = []
                for i, field in enumerate(fields):
                    if field in selected_fields:
                        _feature = x[0, :, i].detach().cpu().numpy().reshape(-1)
                        _min_value = np.min(_feature)
                        _max_value = np.max(_feature)
                        # Background value one full value-range below the
                        # minimum, so off-diagonal cells differ from every
                        # real feature value (unless all values are equal).
                        _pad_value = 2 * _min_value - _max_value
                        _diagonal = np.full((len(_feature), len(_feature)), _pad_value)
                        np.fill_diagonal(_diagonal, _feature)
                        _particle_features.append((field, _diagonal))
                particle_features[(channel, data_index)] = _particle_features

            # Start from an empty capture dict so that every value read below
            # is guaranteed to come from this forward pass; a hook that did
            # not fire then surfaces as a KeyError instead of stale data.
            _intermediate_outputs.clear()

            # Inference only: the hooks detach their captures anyway, so there
            # is no reason to build an autograd graph.
            with torch.no_grad():
                y_pred = model(x)

            for block_index in range(8):
                # squeeze(0) drops the leading axis when it has size 1
                # (presumably the batch axis -- confirm against the model).
                intermediate_outputs[(channel, data_index, epoch_index, block_index)] = _intermediate_outputs[f"par_{block_index}"].squeeze(0)

# Bundle the run settings together with everything collected above.
summary = dict(
    num_data=num_data,
    num_epochs=num_epochs,
    particle_features=particle_features,
    intermediate_outputs=intermediate_outputs,
)

# The output file is named after the last component of the dataset module name.
# A dict is not an array, so it is written via pickle; reading it back needs
# `np.load(path, allow_pickle=True).item()`.
summary_path = f"{dataset.__name__.split('.')[-1]}.npy"
np.save(summary_path, summary, allow_pickle=True)

Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.63s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.84s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:01<00:00,  1.83s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.34s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.77s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.09s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.08s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.94s/it]
Creating JetLightningDataModule: 100%|██████████| 1/1 [00:02<00:00,  2.39s/it]
