In [None]:
import os
import sys 
sys.path.append('../')

import torch
from trackml.dataset import load_event
import numpy as np
import pandas as pd
from torch_geometric.data import Dataset
from matplotlib import pyplot as plt
import mplhep as hep
# Apply the "CMS" style from mplhep to matplotlib.
hep.style.use("CMS")
# Import from the public `IPython.display` module; the `IPython.core.display`
# import path for `display` is deprecated.
from IPython.display import display, HTML
# Inject CSS so the notebook `.container` element spans the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Read the hits, particles and truth tables of one event.
evtid = '21512'
prefix = f'/tigress/jdezoort/codalab/train_1/event0000{evtid}'
hits, particles, truth = load_event(prefix, parts=['hits', 'particles', 'truth'])

In [None]:
import os
from os.path import join
import torch
from torch_geometric.data import Dataset, Data

def calc_eta(r, z):
    """Return ``-log(tan(theta / 2))`` with ``theta = arctan2(r, z)``.

    The computation uses NumPy ufuncs only, so `r` and `z` may be scalars,
    arrays or pandas Series (evaluated element-wise).

    Args:
        r: transverse component (e.g. hit radius or particle pt).
        z: longitudinal component (e.g. hit z or particle pz).

    Returns:
        The pseudorapidity value(s), same shape as the broadcast inputs.
    """
    theta = np.arctan2(r, z)
    # Previously the result was assigned to a local and dropped, so every
    # caller received None; return it explicitly.
    return -1.0 * np.log(np.tan(theta / 2.0))

def append_features(hits, particles, truth):
    """Turn one event's tables into a single `Data` point cloud.

    Side effects: derived columns are written onto the frames passed in --
    ``pt`` and ``eta_pt`` on `particles`; ``r``, ``phi``, ``eta_rz``, ``u``
    and ``v`` on `hits`.

    Args:
        hits: frame with at least hit_id, x, y, z, volume_id.
        particles: frame with at least particle_id, px, py, pz, q, vx, vy.
        truth: frame with at least hit_id, particle_id.

    Returns:
        `Data` whose ``x`` holds the columns
        (x, y, z, r, phi, eta_rz, u, v) of every hit that survives the
        joins, together with the matching ``particle_id`` and ``pt`` arrays.
    """
    # Particle kinematics.
    particles['pt'] = np.sqrt(particles.px**2 + particles.py**2)
    particles['eta_pt'] = calc_eta(particles.pt, particles.pz)

    # Attach particle properties to every (hit_id, particle_id) pair.
    particle_cols = ['particle_id', 'pt', 'eta_pt', 'q', 'vx', 'vy']
    hit_truth = truth[['hit_id', 'particle_id']].merge(
        particles[particle_cols], on='particle_id')

    # Hit geometry: cylindrical coordinates plus u = x/(x^2+y^2), v = y/(x^2+y^2).
    hits['r'] = np.sqrt(hits.x**2 + hits.y**2)
    hits['phi'] = np.arctan2(hits.y, hits.x)
    hits['eta_rz'] = calc_eta(hits.r, hits.z)
    hits['u'] = hits['x'] / (hits['x']**2 + hits['y']**2)
    hits['v'] = hits['y'] / (hits['x']**2 + hits['y']**2)

    # Join the per-hit geometry with the per-hit truth information.
    hit_cols = ['hit_id', 'r', 'phi', 'eta_rz',
                'x', 'y', 'z', 'u', 'v', 'volume_id']
    truth_cols = ['hit_id', 'particle_id', 'pt', 'eta_pt']
    labelled_hits = hits[hit_cols].merge(hit_truth[truth_cols], on='hit_id')

    feature_cols = ['x', 'y', 'z', 'r', 'phi', 'eta_rz', 'u', 'v']
    return Data(x=labelled_hits[feature_cols].values,
                particle_id=labelled_hits['particle_id'].values,
                pt=labelled_hits['pt'].values)

class TrackClouds(Dataset):
    """Dataset of per-event point clouds cached as ``data{evtid}_s{s}.pt`` files.

    Every ``*-hits.csv.gz`` file found in `root` defines one event; its event
    id is parsed from the last nine characters of the file prefix.

    NOTE(review): the base class is initialised with `processed_file_dir` as
    its root, while the cache check lists `processed_file_dir` itself and
    `get`/`process` use `self.processed_dir` -- confirm these resolve to the
    same directory for the installed torch_geometric version.
    NOTE(review): `idx_dict` enumerates ``range(1000)`` for its first tuple
    element, whereas files are named by event id -- confirm `get()` resolves
    to existing files.

    Args:
        root: directory containing the raw event CSV files.
        processed_file_dir: directory handed to the base class as its root
            and listed to detect already-processed events.
        n_sectors: number of sectors per event used to build `idx_dict`.
        pre_transform: callable ``(hits, particles, truth) -> Data`` applied
            to each event that is not cached yet.
    """

    def __init__(self, root: str, processed_file_dir: str,
                 n_sectors: int, pre_transform=None):
        self.root = root
        self.processed_file_dir = processed_file_dir
        self.raw_file_path = root
        self.processed_file_path = processed_file_dir
        self.n_sectors = n_sectors

        # Flat index -> (event index, sector). The counter must advance on
        # every entry; without the increment all pairs overwrote key 0.
        self.idx_dict = {}
        counter = 0
        for i in range(1000):
            for j in range(self.n_sectors):
                self.idx_dict[counter] = (i, j)
                counter += 1

        suffix = '-hits.csv.gz'
        self.prefixes, self.exists = [], {}
        # List the processed directory once instead of once per raw file.
        processed = set(self.processed_file_names)
        for p in os.listdir(self.raw_file_path):
            if str(p).endswith(suffix):
                prefix = str(p).replace(suffix, '')
                evtid = int(prefix[-9:])
                self.exists[evtid] = f'data{evtid}_s0.pt' in processed
                # Keep the directory in the prefix so load_event() finds the
                # files regardless of the current working directory.
                self.prefixes.append(join(self.raw_file_path, prefix))
        self.dataset = []

        super(TrackClouds, self).__init__(processed_file_dir,
                                          pre_transform=pre_transform)

    @property
    def raw_file_names(self):
        """Names of all entries in the raw directory."""
        return os.listdir(self.raw_file_path)

    @property
    def processed_file_names(self):
        """Names of all entries in the processed directory."""
        return os.listdir(self.processed_file_path)

    def len(self) -> int:
        """Number of events collected by `process`."""
        return len(self.dataset)

    def get(self, idx: int) -> Data:
        """Load the cached file that `idx_dict` associates with `idx`."""
        evtid, s = self.idx_dict[idx]
        name = f'data{evtid}_s{s}.pt'
        return torch.load(join(self.processed_dir, name))

    def process(self):
        """Load each event from cache, or build it with `pre_transform` and cache it."""
        for f in self.prefixes:
            s = 0  # only sector 0 is produced here
            evtid = int(f[-9:])
            name = f'data{evtid}_s{s}.pt'
            path = join(self.processed_dir, name)
            if self.exists[evtid]:
                print(path)
                data = torch.load(path)
                self.dataset.append(data)
                continue
            print('Processing', evtid)
            hits, particles, truth = load_event(
                f, parts=['hits', 'particles', 'truth']
            )
            data = self.pre_transform(hits, particles, truth)
            torch.save(data, path)
            self.dataset.append(data)


In [None]:
import os
from os.path import join
import torch
from torch_geometric.data import Data

class PointClouds():
    """Builds, or reloads from cache, one point cloud per TrackML event.

    Every ``*-hits.csv.gz`` file in `indir` defines one event. For each event
    the result is stored in `outdir` as ``data{evtid}_s0.pt``; an existing
    file is reloaded unless `redo` is true. All events are processed in the
    constructor and kept in `data_list`.

    Instances support ``len()``, indexing and iteration over `data_list`.

    Args:
        outdir: directory for the cached ``.pt`` files (created if missing).
        indir: directory containing the raw event CSV files.
        n_sectors: number of sectors per event used to build `idx_dict`.
        redo: if true, rebuild every event even when a cached file exists.
    """

    def __init__(self, outdir: str, indir: str,
                 n_sectors: int, redo=False):
        self.outdir = outdir
        self.indir = indir
        self.n_sectors = n_sectors
        self.redo = redo

        # Flat index -> (event index, sector). The counter must advance on
        # every entry; without the increment all pairs overwrote key 0.
        self.idx_dict = {}
        counter = 0
        for i in range(1000):
            for j in range(self.n_sectors):
                self.idx_dict[counter] = (i, j)
                counter += 1

        # Make sure the cache directory exists, then list it once instead of
        # once per input file.
        os.makedirs(outdir, exist_ok=True)
        cached = set(os.listdir(outdir))

        suffix = '-hits.csv.gz'
        self.prefixes, self.exists = [], {}
        # Sorted so the order of `data_list` does not depend on the filesystem.
        for p in sorted(os.listdir(self.indir)):
            if str(p).endswith(suffix):
                prefix = str(p).replace(suffix, '')
                evtid = int(prefix[-9:])
                self.exists[evtid] = f'data{evtid}_s0.pt' in cached
                self.prefixes.append(join(indir, prefix))

        self.data_list = []
        self.process()

    def __len__(self):
        """Number of events held in `data_list`."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the entry (or slice) of `data_list` at `idx`."""
        return self.data_list[idx]

    def __iter__(self):
        """Iterate over the entries of `data_list`."""
        return iter(self.data_list)

    def calc_eta(self, r, z):
        """Return ``-log(tan(theta / 2))`` with ``theta = arctan2(r, z)``."""
        theta = np.arctan2(r, z)
        return -1.0 * np.log(np.tan(theta / 2.0))

    def append_features(self, hits, particles, truth):
        """Turn one event's tables into a single `Data` point cloud.

        Derived columns are written onto `particles` (pt, eta_pt) and `hits`
        (r, phi, eta_rz, u, v) in place. The returned `Data` carries the
        columns (x, y, z, r, phi, eta_rz, u, v) as ``x`` plus the matching
        ``particle_id`` and ``pt`` arrays for every hit surviving the joins.
        """
        particles['pt'] = np.sqrt(particles.px**2 +
                                  particles.py**2)
        particles['eta_pt'] = self.calc_eta(particles.pt,
                                            particles.pz)
        # Attach particle properties to every (hit_id, particle_id) pair.
        truth = (truth[['hit_id', 'particle_id']]
                 .merge(particles[['particle_id', 'pt', 'eta_pt', 'q', 'vx', 'vy']],
                        on='particle_id'))
        hits['r'] = np.sqrt(hits.x**2 + hits.y**2)
        hits['phi'] = np.arctan2(hits.y, hits.x)
        hits['eta_rz'] = self.calc_eta(hits.r, hits.z)
        # u = x / (x^2 + y^2), v = y / (x^2 + y^2)
        hits['u'] = hits['x']/(hits['x']**2 + hits['y']**2)
        hits['v'] = hits['y']/(hits['x']**2 + hits['y']**2)
        hits = (hits[['hit_id', 'r', 'phi', 'eta_rz',
                      'x', 'y', 'z', 'u', 'v', 'volume_id']]
                .merge(truth[['hit_id', 'particle_id', 'pt', 'eta_pt']],
                       on='hit_id'))
        data = Data(x=hits[['x', 'y', 'z', 'r', 'phi', 'eta_rz', 'u', 'v']].values,
                    particle_id=hits['particle_id'].values,
                    pt=hits['pt'].values)
        return data

    def process(self):
        """Fill `data_list`, reusing cached files unless `redo` is set."""
        for f in self.prefixes:
            print(f)
            s = 0  # only sector 0 is produced here
            evtid = int(f[-9:])
            name = f'data{evtid}_s{s}.pt'
            path = join(self.outdir, name)
            if self.exists[evtid] and not self.redo:
                data = torch.load(path)
            else:
                hits, particles, truth = load_event(
                    f, parts=['hits', 'particles', 'truth']
                )
                data = self.append_features(hits, particles, truth)
                torch.save(data, path)
            self.data_list.append(data)

In [None]:
from torch_geometric.loader import DataListLoader
# Build (or reload from the cache directory) the point clouds of every event.
tc = PointClouds(indir='/tigress/jdezoort/codalab/train_1', outdir='../point_clouds/',
                 n_sectors=1)
# Iterate the list of loaded Data objects directly.
for i in tc.data_list:
    print(i)

In [None]:
# Feed the per-event Data objects through a DataListLoader, one event per batch,
# and print each batch.
loader = DataListLoader(tc.data_list, batch_size=1)
for l in loader:
    print(l)

In [None]:
import os
import torch
from torch_geometric.data import InMemoryDataset, download_url

class TrackClouds(InMemoryDataset):
    """In-memory dataset that loads collated data from the first processed file.

    NOTE(review): this redefines the `TrackClouds` name declared in an earlier
    cell; whichever cell runs last wins.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of all entries directly under `root`."""
        return os.listdir(self.root)

    @property
    def processed_file_names(self):
        """Expected processed files: ``data21000_s0.pt`` ... ``data21999_s0.pt``."""
        # The f-prefix is required; without it every entry was the literal
        # string 'data{e}_s0.pt'.
        return [f'data{e}_s0.pt'
                for e in range(21000, 22000)]

    def process(self):
        """Filter, transform, collate and save `data_list` to the first processed path."""
        # NOTE(review): data_list is never populated here, so the filter and
        # transform steps see no items and collate() receives an empty list --
        # confirm where the Data objects are supposed to come from.
        data_list = []
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
        #prefixes = [str(p).replace(suffix, '')
        #            for p in self.raw_file_path.iterdir()
        #            if str(p).endswith(suffix)]