In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import scipy as sp
import wandb

import warnings

warnings.filterwarnings("ignore")
sys.path.append("../../")
device = "cuda" if torch.cuda.is_available() else "cpu"

from lightning_modules.jetGNN.submodels.interaction_gnn import InteractionGNN
from lightning_modules.jetGNN.submodels.gravnet import GravNet

In [2]:
# Load the config file
config_file = "jet_tag_config.yaml"
with open(config_file, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [3]:
# 1. Load each dataset in [train, val, test]
# 2. For each data object in dataset, torch load
# 3. Check if each key in data object is not nan
# 4. If does not contain nan, save to new dataset

# Load the datasets
input_dir = config["input_dir"]
for subdir in ["train"]:
    print(f"Loading {subdir} dataset")
    subdir_files = os.listdir(os.path.join(input_dir, subdir))
    subdir_files = [f for f in subdir_files if f.endswith(".pt")]
    print(f"Found {len(subdir_files)} files")
    subdir_files = [os.path.join(input_dir, subdir, file) for file in subdir_files]
    for file in tqdm(subdir_files[8:]):
        need_to_save = False
        dataset = torch.load(file)
        # Check if any nan values in sample
        for data in dataset:
            for key in data.keys:
                if torch.isnan(data[key]).any():
                    print("Nan value found in sample")
                    print(data)
                    # Remove sample from dataset
                    dataset.remove(data)
                    need_to_save = True
        # Save the dataset
        if need_to_save:
            print(f"Saving dataset with no nan values, with length {len(dataset)}")
            torch.save(dataset, file)
        

Loading train dataset
Found 12 files


 50%|█████     | 2/4 [01:29<01:31, 45.62s/it]

Nan value found in sample
Data(y=0.0, pE=[1], px=[1], py=[1], pz=[1], log_pt=[1], log_E=[1], delta_pt=[1], log_delta_pt=[1], delta_E=[1], log_delta_E=[1], delta_R=[1], delta_eta=[1], delta_phi=[1], jet_pt=0.5918821692466736, jet_pE=0.7195183634757996, jet_px=-0.5479620695114136, jet_py=0.2237454652786255, jet_pz=0.4091237187385559, jet_mass=nan, jet_eta=0.20556142926216125, jet_phi=0.8770483732223511)
Saving dataset with no nan values, with length 99999


100%|██████████| 4/4 [03:58<00:00, 59.62s/it]
