## Test with Regression Activation Vector

In this notebook we calculate the sensitivity score and the Br score

In [1]:
import os
import time
import shutil
import zipfile
import tarfile
import urllib
import requests
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict, Counter
import h5py as h5
import argparse

In [2]:
import numpy as np
import awkward as ak
import uproot
import vector
vector.register_awkward()

In [3]:
import matplotlib.pyplot as plt
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from torch.utils.data import DataLoader
from weaver.utils.dataset import SimpleIterDataset
from weaver.train import test_load, model_setup
from weaver.utils.logger import _logger, _configLogger

## Mimic Args

Copied from train.py

In [5]:
parser = argparse.ArgumentParser()
parser.add_argument('--regression-mode', action='store_true', default=False,
                    help='run in regression mode if this flag is set; otherwise run in classification mode')
parser.add_argument('-c', '--data-config', type=str,
                    help='data config YAML file')
parser.add_argument('--extra-selection', type=str, default=None,
                    help='Additional selection requirement, will modify `selection` to `(selection) & (extra)` on-the-fly')
parser.add_argument('--extra-test-selection', type=str, default=None,
                    help='Additional test-time selection requirement, will modify `test_time_selection` to `(test_time_selection) & (extra)` on-the-fly')
parser.add_argument('-i', '--data-train', nargs='*', default=[],
                    help='training files; supported syntax:'
                         ' (a) plain list, `--data-train /path/to/a/* /path/to/b/*`;'
                         ' (b) (named) groups [Recommended], `--data-train a:/path/to/a/* b:/path/to/b/*`,'
                         ' the file splitting (for each dataloader worker) will be performed per group,'
                         ' and then mixed together, to ensure a uniform mixing from all groups for each worker.'
                    )
parser.add_argument('-l', '--data-val', nargs='*', default=[],
                    help='validation files; when not set, will use training files and split by `--train-val-split`')
parser.add_argument('-t', '--data-test', nargs='*', default=[],
                    help='testing files; supported syntax:'
                         ' (a) plain list, `--data-test /path/to/a/* /path/to/b/*`;'
                         ' (b) keyword-based, `--data-test a:/path/to/a/* b:/path/to/b/*`, will produce output_a, output_b;'
                         ' (c) split output per N input files, `--data-test a%%10:/path/to/a/*`, will split per 10 input files')
parser.add_argument('--data-fraction', type=float, default=1,
                    help='fraction of events to load from each file; for training, the events are randomly selected for each epoch')
parser.add_argument('--file-fraction', type=float, default=1,
                    help='fraction of files to load; for training, the files are randomly selected for each epoch')
parser.add_argument('--fetch-by-files', action='store_true', default=False,
                    help='When enabled, will load all events from a small number (set by ``--fetch-step``) of files for each data fetching. '
                         'Otherwise (default), load a small fraction of events from all files each time, which helps reduce variations in the sample composition.')
parser.add_argument('--fetch-step', type=float, default=0.01,
                    help='fraction of events to load each time from every file (when ``--fetch-by-files`` is disabled); '
                         'Or: number of files to load each time (when ``--fetch-by-files`` is enabled). Shuffling & sampling is done within these events, so set a large enough value.')
parser.add_argument('--in-memory', action='store_true', default=False,
                    help='load the whole dataset (and perform the preprocessing) only once and keep it in memory for the entire run')
parser.add_argument('--train-val-split', type=float, default=0.8,
                    help='training/validation split fraction')
parser.add_argument('--no-remake-weights', action='store_true', default=False,
                    help='do not remake weights for sampling (reweighting), use existing ones in the previous auto-generated data config YAML file')
parser.add_argument('--demo', action='store_true', default=False,
                    help='quickly test the setup by running over only a small number of events')
parser.add_argument('--lr-finder', type=str, default=None,
                    help='run learning rate finder instead of the actual training; format: ``start_lr, end_lr, num_iters``')
parser.add_argument('--tensorboard', type=str, default=None,
                    help='create a tensorboard summary writer with the given comment')
parser.add_argument('--tensorboard-custom-fn', type=str, default=None,
                    help='the path of the python script containing a user-specified function `get_tensorboard_custom_fn`, '
                         'to display custom information per mini-batch or per epoch, during the training, validation or test.')
parser.add_argument('-n', '--network-config', type=str,
                    help='network architecture configuration file; the path must be relative to the current dir')
parser.add_argument('-o', '--network-option', nargs=2, action='append', default=[],
                    help='options to pass to the model class constructor, e.g., `--network-option use_counts False`')
parser.add_argument('-m', '--model-prefix', type=str, default='models/{auto}/network',
                    help='path to save or load the model; for training, this will be used as a prefix, so model snapshots '
                         'will saved to `{model_prefix}_epoch-%%d_state.pt` after each epoch, and the one with the best '
                         'validation metric to `{model_prefix}_best_epoch_state.pt`; for testing, this should be the full path '
                         'including the suffix, otherwise the one with the best validation metric will be used; '
                         'for training, `{auto}` can be used as part of the path to auto-generate a name, '
                         'based on the timestamp and network configuration')
parser.add_argument('--load-model-weights', type=str, default=None,
                    help='initialize model with pre-trained weights')
parser.add_argument('--exclude-model-weights', type=str, default=None,
                    help='comma-separated regex to exclude matched weights from being loaded, e.g., `a.fc..+,b.fc..+`')
parser.add_argument('--freeze-model-weights', type=str, default=None,
                    help='comma-separated regex to freeze matched weights from being updated in the training, e.g., `a.fc..+,b.fc..+`')
parser.add_argument('--num-epochs', type=int, default=20,
                    help='number of epochs')
parser.add_argument('--steps-per-epoch', type=int, default=None,
                    help='number of steps (iterations) per epochs; '
                         'if neither of `--steps-per-epoch` or `--samples-per-epoch` is set, each epoch will run over all loaded samples')
parser.add_argument('--steps-per-epoch-val', type=int, default=None,
                    help='number of steps (iterations) per epochs for validation; '
                         'if neither of `--steps-per-epoch-val` or `--samples-per-epoch-val` is set, each epoch will run over all loaded samples')
parser.add_argument('--samples-per-epoch', type=int, default=None,
                    help='number of samples per epochs; '
                         'if neither of `--steps-per-epoch` or `--samples-per-epoch` is set, each epoch will run over all loaded samples')
parser.add_argument('--samples-per-epoch-val', type=int, default=None,
                    help='number of samples per epochs for validation; '
                         'if neither of `--steps-per-epoch-val` or `--samples-per-epoch-val` is set, each epoch will run over all loaded samples')
parser.add_argument('--optimizer', type=str, default='ranger', choices=['adam', 'adamW', 'radam', 'ranger'],  # TODO: add more
                    help='optimizer for the training')
parser.add_argument('--optimizer-option', nargs=2, action='append', default=[],
                    help='options to pass to the optimizer class constructor, e.g., `--optimizer-option weight_decay 1e-4`')
parser.add_argument('--lr-scheduler', type=str, default='flat+decay',
                    choices=['none', 'steps', 'flat+decay', 'flat+linear', 'flat+cos', 'one-cycle'],
                    help='learning rate scheduler')
parser.add_argument('--warmup-steps', type=int, default=0,
                    help='number of warm-up steps, only valid for `flat+linear` and `flat+cos` lr schedulers')
parser.add_argument('--load-epoch', type=int, default=None,
                    help='used to resume interrupted training, load model and optimizer state saved in the `epoch-%%d_state.pt` and `epoch-%%d_optimizer.pt` files')
parser.add_argument('--start-lr', type=float, default=5e-3,
                    help='start learning rate')
parser.add_argument('--batch-size', type=int, default=128,
                    help='batch size')
parser.add_argument('--use-amp', action='store_true', default=False,
                    help='use mixed precision training (fp16)')
parser.add_argument('--gpus', type=str, default='0',
                    help='device for the training/testing; to use CPU, set to empty string (""); to use multiple gpu, set it as a comma separated list, e.g., `1,2,3,4`')
parser.add_argument('--predict-gpus', type=str, default=None,
                    help='device for the testing; to use CPU, set to empty string (""); to use multiple gpu, set it as a comma separated list, e.g., `1,2,3,4`; if not set, use the same as `--gpus`')
parser.add_argument('--num-workers', type=int, default=1,
                    help='number of threads to load the dataset; memory consumption and disk access load increases (~linearly) with this numbers')
parser.add_argument('--predict', action='store_true', default=False,
                    help='run prediction instead of training')
parser.add_argument('--predict-output', type=str,
                    help='path to save the prediction output, support `.root` and `.parquet` format')
parser.add_argument('--export-onnx', type=str, default=None,
                    help='export the PyTorch model to ONNX model and save it at the given path (path must ends w/ .onnx); '
                         'needs to set `--data-config`, `--network-config`, and `--model-prefix` (requires the full model path)')
parser.add_argument('--onnx-opset', type=int, default=15,
                    help='ONNX opset version.')
parser.add_argument('--io-test', action='store_true', default=False,
                    help='test throughput of the dataloader')
parser.add_argument('--copy-inputs', action='store_true', default=False,
                    help='copy input files to the current dir (can help to speed up dataloading when running over remote files, e.g., from EOS)')
parser.add_argument('--log', type=str, default='',
                    help='path to the log file; `{auto}` can be used as part of the path to auto-generate a name, based on the timestamp and network configuration')
parser.add_argument('--print', action='store_true', default=False,
                    help='do not run training/prediction but only print model information, e.g., FLOPs and number of parameters of a model')
parser.add_argument('--profile', action='store_true', default=False,
                    help='run the profiler')
parser.add_argument('--backend', type=str, choices=['gloo', 'nccl', 'mpi'], default=None,
                    help='backend for distributed training')
parser.add_argument('--cross-validation', type=str, default=None,
                    help='enable k-fold cross validation; input format: `variable_name%%k`')
parser.add_argument('--disable-mps', action='store_true', default=False,
                    help='disable using mps device if it does not work for you')
parser.add_argument('--hidden-states-out', type=str, default="hidden_states_out.h5",
                    help='path to save hidden states as h5 file')
parser.add_argument('--hidden-states', action='store_true', default=False,
                    help='let ParT output hidden states with the logits')

_StoreTrueAction(option_strings=['--hidden-states'], dest='hidden_states', nargs=0, const=True, default=False, type=None, choices=None, help='let ParT output hidden states with the logits', metavar=None)

manually parse args

In [6]:
args = parser.parse_args([
    "--predict",
    "--data-test", "/Users/billyli/scope/JetClass/minimal/*.root",
    "--data-config", "/Users/billyli/scope/particle_transformer/data/JetClass/JetClass_full.yaml",
    "--data-fraction", "0.001",
    "--network-config", "/Users/billyli/scope/particle_transformer/networks/ParT_w_hidden_states.py",
    "--use-amp",
    "--model-prefix", "/Users/billyli/scope/particle_transformer/models/ParT_full.pt",
    "--batch-size", "128",
    "--predict-output", "/Users/billyli/scope/particle_transformer/tmp/test_output.root",
    "--gpus", "",
    "--num-workers", "0",
    "--disable-mps",
    "--hidden-states"
])

print(args)

Namespace(backend=None, batch_size=128, copy_inputs=False, cross_validation=None, data_config='/Users/billyli/scope/particle_transformer/data/JetClass/JetClass_full.yaml', data_fraction=0.001, data_test=['/Users/billyli/scope/JetClass/minimal/*.root'], data_train=[], data_val=[], demo=False, disable_mps=True, exclude_model_weights=None, export_onnx=None, extra_selection=None, extra_test_selection=None, fetch_by_files=False, fetch_step=0.01, file_fraction=1, freeze_model_weights=None, gpus='', hidden_states=True, hidden_states_out='hidden_states_out.h5', in_memory=False, io_test=False, load_epoch=None, load_model_weights=None, log='', lr_finder=None, lr_scheduler='flat+decay', model_prefix='/Users/billyli/scope/particle_transformer/models/ParT_full.pt', network_config='/Users/billyli/scope/particle_transformer/networks/ParT_w_hidden_states.py', network_option=[], no_remake_weights=False, num_epochs=20, num_workers=0, onnx_opset=15, optimizer='ranger', optimizer_option=[], predict=True, 

## Dataloader

In [7]:
test_loaders, data_config = test_load(args)

## Load Model

In [8]:
if args.gpus:
    # distributed training
    if args.backend is not None:
        local_rank = args.local_rank
        torch.cuda.set_device(local_rank)
        gpus = [local_rank]
        dev = torch.device(local_rank)
        torch.distributed.init_process_group(backend=args.backend)
        _logger.info(f'Using distributed PyTorch with {args.backend} backend')
    else:
        gpus = [int(i) for i in args.gpus.split(',')]
        dev = torch.device(gpus[0])
else:
    gpus = None
    dev = torch.device('cpu')
    if not args.disable_mps:
        try:
            if torch.backends.mps.is_available():
                dev = torch.device('mps')
        except AttributeError:
            pass

In [9]:
model, model_info, loss_func = model_setup(args, data_config, device=dev)
model = model.to(dev)
model_path = args.model_prefix if args.model_prefix.endswith(
                '.pt') else args.model_prefix + '_best_epoch_state.pt'
_logger.info('Loading model %s for eval' % model_path)
model.load_state_dict(torch.load(model_path, map_location=dev))
if gpus is not None and len(gpus) > 1:
    model = torch.nn.DataParallel(model, device_ids=gpus)
model = model.to(dev)

Loss function not defined in /Users/billyli/scope/particle_transformer/networks/ParT_w_hidden_states.py. Will use `torch.nn.CrossEntropyLoss()` by default.


## Calculte Gradient

In [10]:
for name, get_test_loader in test_loaders.items():
    test_loader = get_test_loader()

In [11]:
model.eval()
data_config = test_loader.dataset.config

label_counter = Counter()
total_loss = 0
num_batches = 0
total_correct = 0
entry_count = 0
count = 0
scores = []
labels = defaultdict(list)
labels_counts = []
observers = defaultdict(list)
hiddens = []
start_time = time.time()

In [12]:
model.register_tcav_layer("mod.norm") 

In [22]:

gs = []
with tqdm(test_loader) as tq:
    for X, y, Z in tq:
        # X, y: torch.Tensor; Z: ak.Array
        inputs = [X[k].to(dev) for k in data_config.input_names]
        label = y[data_config.label_names[0]].long().to(dev)
        entry_count += label.shape[0]

        try:
            mask = y[data_config.label_names[0] + '_mask'].bool().to(dev)
        except KeyError:
            mask = None
        model_output = model(*inputs)
        g = model.tcav_grads_from_batch(*inputs, class_idx=8) 
        gs.append(g.cpu().numpy()  )

0it [00:00, ?it/s]

=== Restarting DataIter test_, seed=None ===


1it [00:05,  5.03s/it]

grads tensor([[-1.0705e-03, -1.7459e-04, -1.0588e-04,  ..., -1.4244e-04,
         -2.2214e-04, -1.1570e-04],
        [-1.3032e-03, -1.4071e-04,  6.4780e-05,  ...,  4.7366e-06,
         -2.3196e-04, -7.1918e-05],
        [-1.2833e-03, -4.8746e-04, -5.7918e-04,  ..., -2.4301e-04,
         -3.7723e-04, -4.8216e-04],
        ...,
        [-1.4722e-03, -3.0158e-04, -3.0450e-04,  ..., -2.9749e-04,
         -3.0271e-04, -3.1144e-04],
        [-9.1181e-04, -9.7223e-05, -2.2547e-04,  ..., -8.6482e-05,
         -1.7573e-04, -1.0130e-04],
        [-8.1403e-04, -5.2182e-04, -3.5538e-04,  ..., -1.4816e-04,
         -5.4643e-04, -5.2116e-04]])
acts tensor([[ 0.8622,  0.6470, -2.0281,  ..., -0.6050,  2.4926, -1.6463],
        [-0.9107, -1.2412, -3.5839,  ..., -2.8996, -0.2032, -2.0256],
        [ 1.6379, -2.6126, -3.8140,  ...,  0.5892, -1.1663, -2.5431],
        ...,
        [ 4.0475, -1.1442, -1.6574,  ..., -0.4289, -1.3053, -2.8669],
        [-0.9764, -0.4170,  2.1545,  ..., -0.6325,  1.1542, -0.3

2it [00:11,  5.70s/it]

grads tensor([[-1.2117e-03, -2.5649e-04, -2.6270e-04,  ..., -2.4044e-04,
         -2.5798e-04, -2.5096e-04],
        [-8.4738e-04, -3.9467e-04, -8.7526e-04,  ..., -3.6555e-04,
         -5.0501e-04, -5.0358e-04],
        [-1.1464e-03, -2.6904e-04, -2.4613e-04,  ..., -1.7773e-04,
         -3.0074e-04, -2.6031e-04],
        ...,
        [-1.1131e-03, -3.4293e-04, -5.1087e-04,  ..., -2.1375e-04,
         -2.8746e-04, -3.4646e-04],
        [-1.1316e-03, -2.7434e-04, -3.6717e-04,  ..., -2.3015e-04,
         -2.9905e-04, -2.0400e-04],
        [-1.0235e-03, -3.2016e-04, -2.5542e-04,  ...,  9.4877e-06,
         -3.9651e-04, -3.5732e-04]])
acts tensor([[ 3.2385, -0.1857, -0.7922,  ...,  1.3820, -0.3139,  0.3554],
        [ 6.2643,  0.3387, -3.6791,  ...,  0.5822, -0.5821, -0.5718],
        [ 2.4885,  0.3706,  0.9692,  ...,  2.7578, -0.4536,  0.5991],
        ...,
        [ 1.0460,  0.1797, -1.4636,  ...,  1.4437,  0.7238,  0.1451],
        [ 3.7467,  0.9146, -0.4704,  ...,  1.5739,  0.5486,  1.9

3it [00:17,  6.16s/it]

grads tensor([[-1.1384e-03, -2.3432e-04, -2.5627e-04,  ..., -2.1030e-04,
         -2.3883e-04, -2.3603e-04],
        [-8.2101e-04, -3.2081e-04, -4.4179e-04,  ..., -1.4804e-04,
         -3.8474e-04, -3.1290e-04],
        [-8.9343e-04, -3.1853e-04, -3.7822e-04,  ..., -9.9574e-06,
         -3.9494e-04, -2.6040e-04],
        ...,
        [-2.1702e-03, -7.9525e-04, -9.8955e-04,  ..., -1.1938e-03,
         -7.7310e-04, -8.7940e-04],
        [-1.6848e-03, -4.9071e-04, -3.3351e-04,  ..., -4.5307e-04,
         -2.8467e-04, -6.5395e-04],
        [-1.7464e-03, -8.8669e-04, -8.3429e-04,  ..., -7.4235e-04,
         -4.8730e-04, -5.4423e-04]])
acts tensor([[ 1.3439,  0.8229, -0.5327,  ...,  2.3071,  0.5546,  0.7175],
        [ 4.9724,  0.6641, -0.4539,  ...,  2.2606,  0.0749,  0.7372],
        [ 3.9028,  0.1443, -0.4222,  ...,  3.0728, -0.5792,  0.6961],
        ...,
        [-4.4938, -3.5217, -4.4721,  ..., -5.4713, -3.4124, -3.9333],
        [-7.4546, -5.9526, -2.7414,  ..., -5.1836, -1.7395, -9.2

4it [00:26,  7.07s/it]

grads tensor([[-1.5303e-03, -4.4282e-04, -4.5502e-04,  ..., -4.3107e-04,
         -2.9678e-04, -3.0052e-04],
        [-1.8287e-03, -8.0757e-04, -7.0824e-04,  ..., -7.2970e-04,
         -3.7866e-04, -6.5044e-04],
        [-1.7380e-03, -7.2167e-04, -9.7322e-04,  ..., -6.7990e-04,
         -3.8810e-04, -4.0157e-04],
        ...,
        [-1.7846e-03, -4.5350e-05, -1.9220e-04,  ..., -1.2969e-04,
         -2.9415e-04, -3.9461e-04],
        [-1.2106e-03, -2.7111e-05, -1.7389e-04,  ..., -1.5323e-04,
         -2.5238e-04, -2.4419e-04],
        [-1.0274e-03, -8.0724e-06, -6.8440e-05,  ...,  9.3098e-05,
         -2.1462e-04, -2.2813e-04]])
acts tensor([[-8.8504, -5.5690, -5.8252,  ..., -5.3223, -2.5013, -2.5832],
        [-6.1937, -6.6242, -5.7490,  ..., -5.9380, -2.8430, -5.2395],
        [-3.7838, -4.4011, -6.4968,  ..., -4.0530, -1.6206, -1.7343],
        ...,
        [ 0.3944, -4.3864, -2.7241,  ..., -3.4319, -1.5729, -0.4333],
        [-5.5672, -5.9331, -3.5983,  ..., -3.9272, -2.3534, -2.4

5it [00:30,  6.12s/it]

grads tensor([[-1.8447e-03, -3.1782e-05, -1.8794e-04,  ..., -1.7229e-04,
         -3.1673e-04, -3.2691e-04],
        [-1.2343e-03, -1.4651e-06, -9.6475e-05,  ..., -5.4668e-05,
         -2.1769e-04, -2.2948e-04],
        [-8.8162e-04,  3.2153e-05, -8.0473e-05,  ...,  8.3961e-05,
         -2.5924e-04, -2.3172e-04],
        ...,
        [-1.8734e-03,  2.0496e-04,  6.6017e-04,  ...,  3.1096e-04,
         -4.0518e-04, -2.8746e-04],
        [-5.2548e-04,  3.0447e-04,  8.8507e-04,  ...,  3.6259e-04,
         -2.6790e-04,  3.5428e-04],
        [-9.3183e-04,  4.7312e-05,  3.2358e-04,  ..., -3.8832e-07,
         -1.0902e-04,  4.0282e-04]])
acts tensor([[ 1.5984, -4.5700, -2.7167,  ..., -2.9026, -1.1910, -1.0677],
        [-4.6488, -5.6464, -4.2205,  ..., -4.8482, -2.4049, -2.2250],
        [-5.6705, -4.2015, -2.6064,  ..., -4.9353, -0.0775, -0.4648],
        ...,
        [ 1.7451, -0.7579, -2.1885,  ..., -1.0911,  1.1589,  0.7897],
        [-1.4569, -0.4925, -2.2272,  ..., -0.6662,  1.2170, -0.6

6it [00:35,  5.63s/it]

grads tensor([[-6.7028e-04,  7.5813e-04, -6.9563e-04,  ..., -1.2923e-04,
         -1.3258e-04,  2.4723e-04],
        [ 9.9851e-04,  7.0538e-04, -2.3326e-04,  ..., -3.9395e-04,
          2.1434e-04,  1.9139e-03],
        [ 1.3239e-03,  1.1354e-03, -1.2123e-03,  ...,  1.7846e-04,
          1.1477e-03,  1.2545e-03],
        ...,
        [-2.0104e-03, -6.1570e-04,  2.3081e-05,  ..., -6.7576e-04,
         -5.3849e-04, -6.5050e-04],
        [-1.5772e-03, -4.1360e-04, -7.8476e-04,  ..., -5.0743e-04,
         -4.5198e-04, -5.2917e-04],
        [-1.6611e-03, -4.9042e-04, -8.9237e-04,  ..., -4.1519e-04,
         -4.5423e-04, -4.6239e-04]])
acts tensor([[-0.4681, -0.7121,  1.6705,  ...,  0.7422,  0.7473,  0.1252],
        [-2.9043, -0.2093,  1.3752,  ...,  1.6465,  0.6192, -2.2495],
        [-3.4313, -1.3289,  1.8035,  ..., -0.0521, -1.3456, -1.4878],
        ...,
        [-4.9403, -2.1747,  2.3835,  ..., -2.6032, -1.6225, -2.4229],
        [-3.0192, -0.5733, -3.8324,  ..., -1.3972, -0.9090, -1.5

7it [00:40,  5.28s/it]

grads tensor([[-0.0018, -0.0005, -0.0007,  ..., -0.0005, -0.0005, -0.0005],
        [-0.0012, -0.0003, -0.0004,  ..., -0.0001, -0.0003, -0.0004],
        [-0.0019, -0.0005, -0.0008,  ..., -0.0005, -0.0006, -0.0006],
        ...,
        [-0.0017, -0.0008, -0.0006,  ..., -0.0006, -0.0009, -0.0006],
        [-0.0020, -0.0007, -0.0007,  ..., -0.0005, -0.0008, -0.0007],
        [-0.0010, -0.0002, -0.0002,  ..., -0.0002, -0.0002, -0.0002]])
acts tensor([[-5.1392, -1.5579, -3.5484,  ..., -1.2735, -1.7209, -1.4107],
        [-1.3121,  0.2479, -0.7614,  ...,  2.2561,  0.3236, -1.0340],
        [-4.9306, -1.6131, -4.3081,  ..., -1.6173, -2.4113, -2.2028],
        ...,
        [-0.7816, -2.0408, -1.1250,  ..., -1.3274, -2.3371, -1.1918],
        [-2.9194, -2.0298, -2.2858,  ..., -1.0394, -2.6797, -2.4895],
        [-0.5788, -1.9769,  0.8191,  ...,  0.6353, -0.7042, -2.1818]],
       grad_fn=<SelectBackward0>)


8it [00:42,  5.34s/it]

grads tensor([[-0.0018, -0.0008, -0.0005,  ..., -0.0003, -0.0010, -0.0008],
        [-0.0019, -0.0008, -0.0005,  ..., -0.0003, -0.0010, -0.0009],
        [-0.0011, -0.0002, -0.0002,  ..., -0.0002, -0.0002, -0.0002],
        ...,
        [-0.0010, -0.0004, -0.0004,  ..., -0.0002, -0.0004, -0.0004],
        [-0.0012, -0.0004, -0.0006,  ..., -0.0004, -0.0003, -0.0004],
        [-0.0014, -0.0004, -0.0006,  ..., -0.0005, -0.0004, -0.0004]])
acts tensor([[-1.0064, -2.0154, -1.0625,  ..., -0.1655, -2.9232, -2.0782],
        [-1.3858, -1.9703, -0.7161,  ..., -0.0759, -2.8316, -2.3233],
        [-1.4903, -0.2969, -2.4878,  ..., -0.6575, -0.5126, -0.2503],
        ...,
        [ 2.3178,  0.1889, -0.4218,  ...,  1.5670,  0.2606,  0.3065],
        [ 0.6941, -0.5219, -4.8899,  ...,  0.0295,  1.2677, -1.2587],
        [-1.2034, -0.1063, -1.8450,  ..., -1.4266,  0.1330, -0.2339]],
       grad_fn=<SelectBackward0>)





In [23]:
gs = np.concatenate(gs, axis=0)

## Load cav

In [24]:
cav = h5.File('/Users/billyli/scope/particle_transformer/notebooks/rav_sdmass.h5', 'r')['RAV_jet_sdmass'][:]
cav = cav[:-1]

In [25]:
normalized_cav = cav / np.linalg.norm(cav)


In [26]:
cadir_deriv = gs @ normalized_cav  # shape (N,)
tcav = float((cadir_deriv  > 0).mean())
print(tcav)

1.0


In [27]:
gs[:, 0]

array([-1.0705204e-03, -1.3031547e-03, -1.2833124e-03, -1.0863956e-03,
       -1.2782087e-03, -1.2245905e-03, -1.1331652e-03, -8.7467348e-04,
       -1.4419832e-03, -1.1401846e-03, -1.1082630e-03, -1.1594278e-03,
       -1.3063160e-03, -1.0285192e-03, -1.3628390e-03, -8.8517158e-04,
       -1.2003947e-03, -1.3494641e-03, -1.2737323e-03, -1.1052706e-03,
       -1.0131560e-03, -9.6922024e-04, -1.3155468e-03, -1.1098369e-03,
       -1.3614972e-03, -1.3931772e-03, -1.1406268e-03, -1.2890535e-03,
       -1.6585081e-03, -1.2997587e-03, -1.1295517e-03, -9.7025395e-04,
       -1.4697411e-03, -1.1494958e-03, -1.6342890e-03, -1.3695704e-03,
       -1.5165720e-03, -1.2361638e-03, -1.2402132e-03, -1.1187964e-03,
       -1.4018831e-03, -1.3508550e-03, -1.3522693e-03, -1.3080584e-03,
       -1.4560594e-03, -1.2622641e-03, -1.4761935e-03, -1.1277983e-03,
       -1.2292753e-03, -1.3014469e-03, -1.1713859e-03, -1.1040382e-03,
       -1.3137029e-03, -1.4429954e-03, -8.0329261e-04, -2.2106231e-03,
      

In [28]:
cadir_deriv

array([[0.00061564],
       [0.00080163],
       [0.00076082],
       [0.00076636],
       [0.00074594],
       [0.00070806],
       [0.00061166],
       [0.00068055],
       [0.00075212],
       [0.00068854],
       [0.00058322],
       [0.00058469],
       [0.0007332 ],
       [0.00063378],
       [0.00077886],
       [0.00060006],
       [0.00073154],
       [0.00076531],
       [0.00074341],
       [0.00066724],
       [0.00075736],
       [0.00069177],
       [0.00078488],
       [0.00077366],
       [0.00068688],
       [0.00073133],
       [0.00071468],
       [0.0006862 ],
       [0.00071113],
       [0.00073873],
       [0.00067779],
       [0.0006057 ],
       [0.00080783],
       [0.00073351],
       [0.00078332],
       [0.00067745],
       [0.00064432],
       [0.00069073],
       [0.00075214],
       [0.00072202],
       [0.00073741],
       [0.00075096],
       [0.00083136],
       [0.00074581],
       [0.00076608],
       [0.00071389],
       [0.00077273],
       [0.000

In [20]:
model._tcav_acts.size()

torch.Size([1, 104, 128])

In [21]:
normalized_cav

array([[-0.03297416],
       [ 0.33101214],
       [-0.11765189],
       [-0.04690461],
       [ 0.03210919],
       [-0.04460833],
       [-0.02226003],
       [-0.04631099],
       [ 0.00991599],
       [-0.17100279],
       [-0.05036304],
       [-0.13745825],
       [ 0.00794413],
       [ 0.02099918],
       [-0.00250691],
       [ 0.22482662],
       [ 0.00728835],
       [-0.00522326],
       [-0.0212866 ],
       [-0.00739012],
       [ 0.01401818],
       [ 0.10765859],
       [-0.07550145],
       [ 0.07931605],
       [ 0.01503608],
       [ 0.07150803],
       [-0.07139318],
       [-0.05169635],
       [-0.00820409],
       [ 0.12700651],
       [-0.0331678 ],
       [ 0.09080516],
       [-0.02124833],
       [-0.06124515],
       [ 0.00797757],
       [-0.05143098],
       [ 0.16436235],
       [-0.05247383],
       [-0.04136527],
       [-0.16268531],
       [ 0.13618894],
       [ 0.01025466],
       [-0.03644122],
       [-0.14382278],
       [ 0.00599246],
       [-0