In [None]:
import os.path as osp
import os
import random
import pandas as pd
from torch_geometric.data import Data, Dataset, DataLoader
from tqdm import tqdm
import glob
import pandas
import torch
import numpy as np
import multiprocessing as mp
import random
import plotly.graph_objects as go
from torch_geometric.data import Data, Dataset, DataLoader
import matplotlib._color_data as mcd
"libraries for debugging"
import sys
import os.path as osp
"custom imports"
from torch_cmspepr.gravnet_model import GravnetModel
from torch_cmspepr.objectcondensation import calc_LV_Lbeta, formatted_loss_components_string

In [None]:
class ColorWheel:
    '''Returns a consistent color when given the same object.

    Colors are handed out from a shuffled palette; once the palette is
    exhausted it is refilled from the original shuffled order, so colors are
    reused after ``len(palette)`` distinct objects.

    Hashable objects are remembered by *value* (so ``5`` and ``np.int64(5)``,
    or two separately created equal ints, share a color). Unhashable objects
    fall back to being remembered by identity.

    Parameters
    ----------
    colors : sequence of color specs, optional
        Palette to draw from. Defaults to matplotlib's XKCD color table.
        The sequence is copied; the caller's object is never reordered.
    seed : int
        Seed for the palette shuffle.
    '''
    def __init__(self, colors=None, seed=44):
        if colors is None:
            palette = list(mcd.XKCD_COLORS.values())
        else:
            # Copy so the shuffle below cannot reorder the caller's own list.
            palette = list(colors)
        # A private RandomState keeps the shuffle reproducible without
        # resetting the process-wide NumPy random stream as a side effect.
        np.random.RandomState(seed).shuffle(palette)
        self.colors = palette
        self._original_colors = self.colors.copy()
        self.assigned_colors = {}

    @staticmethod
    def _key(thing):
        """Builds the dictionary key under which ``thing``'s color is stored."""
        try:
            hash(thing)
        except TypeError:
            # Unhashable (e.g. list, ndarray): identity is the best we can do.
            return ('id', id(thing))
        return ('value', thing)

    def __call__(self, thing):
        """Returns the color bound to ``thing``, assigning a new one on first use."""
        key = self._key(thing)
        if key in self.assigned_colors:
            return self.assigned_colors[key]
        color = self.colors.pop()
        self.assigned_colors[key] = color
        # Refill once exhausted so that the next new object can still get a color.
        if not self.colors:
            self.colors = self._original_colors.copy()
        return color

    def assign(self, thing, color):
        """Assigns a specific color to a thing"""
        self.assigned_colors[self._key(thing)] = color
        # Take the color out of the current pool so it is not handed out again
        # (it can reappear after the pool is refilled).
        if color in self.colors:
            self.colors.remove(color)


def get_plotly_clusterspace(event, 
                            cluster_space_coords, 
                            clustering=None,
                            size = 1.00):
    """Builds one 2-D plotly scatter trace per cluster in the learned cluster space.

    Parameters
    ----------
    event :
        Object whose ``y`` attribute holds the per-hit cluster labels; only
        read when ``clustering`` is None.
    cluster_space_coords :
        Tensor with exactly two columns (asserted) holding the per-hit
        cluster-space coordinates. ``.numpy()`` is called on it, so it is
        expected to live on the CPU.
    clustering : optional
        Per-hit cluster labels; defaults to ``event.y``.
    size : float
        Marker size.

    Returns
    -------
    list of ``go.Scatter`` traces, one per unique label. Labels 0 and -1 are
    drawn in grey.
    """
    assert cluster_space_coords.size(1) == 2

    colorwheel = ColorWheel()
    colorwheel.assign(0, '#bfbfbf')
    colorwheel.assign(-1, '#bfbfbf')

    if clustering is None:
        clustering = event.y

    traces = []
    for cluster_index in np.unique(clustering):
        coords = cluster_space_coords[clustering == cluster_index].numpy()
        traces.append(go.Scatter(
            x=coords[:,0], y=coords[:,1],
            mode='markers', 
            marker=dict(
                line=dict(width=0),
                size=size,
                color= colorwheel(int(cluster_index)),
                ),
            # A 2-D scatter only exposes x and y to the hover template; label
            # each value with the axis it actually comes from.
            hovertemplate=(
                f'x=%{{x:0.2f}}<br>y=%{{y:0.2f}}'
                f'<br>clusterindex={cluster_index}'
                f'<br>'
                ),
            name = f'cluster_{cluster_index}'
            ))
    return traces

def get_plotly_truth(event,
                    size = 1.00):
    """Builds one 3-D plotly trace per truth cluster of an event.

    Parameters
    ----------
    event :
        Object with ``x`` (per-hit feature tensor; the first three columns are
        plotted as x, y, z) and ``y`` (per-hit cluster labels). ``.numpy()``
        is called on the selected rows, so the tensors are expected to live on
        the CPU.
    size : float
        Marker size.

    Returns
    -------
    list of ``go.Scatter3d`` traces, one per unique label in ``event.y``.
    Label 0 is drawn in grey.
    """
    colorwheel = ColorWheel()
    colorwheel.assign(0, '#bfbfbf')

    traces = []
    for cluster_index in np.unique(event.y):
        feats = event.x[event.y == cluster_index].numpy()
        traces.append(go.Scatter3d(
            x=feats[:,0], y=feats[:,1], z=feats[:,2],
            mode='lines+markers', 
            marker=dict(
                line=dict(width=0),
                size=size,
                color= colorwheel(int(cluster_index)),
                ),
            # Label every hovered value with the axis it is plotted on.
            hovertemplate=(
                f'x=%{{x:0.2f}}<br>y=%{{y:0.2f}}<br>z=%{{z:0.2f}}'
                f'<br>clusterindex={cluster_index}'
                f'<br>'
                ),
            name = f'cluster_{cluster_index}'
            ))
    return traces

In [None]:
import os.path as osp
import glob

import multiprocessing as mp
from tqdm import tqdm
import random
import torch
import pandas
import numpy as np
import matplotlib.pyplot as plt
from torch_geometric.utils import is_undirected
from torch_geometric.data import Data, Dataset
import gzip
import pdb

class TrackMLParticleTrackingDataset(Dataset):
    """TrackML events served as labelled hit point clouds.

    Each item is built on the fly from the ``event*-hits.csv``,
    ``event*-particles.csv`` and ``event*-truth.csv`` files found in
    ``<root>/raw``. One event yields ``n_phi_sections * n_eta_sections``
    items, one per (phi, eta) detector section.

    Parameters
    ----------
    root : str
        Dataset root; CSV files are looked up in ``<root>/raw``.
    transform : callable, optional
        Forwarded to the ``Dataset`` base class.
    n_events : int
        If > 0, only the first ``n_events`` events (in sorted file order) are used.
    layer_pairs_plus, volume_layer_ids, layer_pairs :
        Stored on the instance (the latter two as tensors); not otherwise used
        by the methods of this class.
    pt_min : float
        Hits of particles with transverse momentum <= ``pt_min`` are labelled noise.
    eta_range : [float, float]
        Outer eta edges used when splitting into sections.
    n_phi_sections, n_eta_sections : int
        Number of equal-width sections in phi and eta.
    download_full_dataset : bool
        Selects which archive ``download`` fetches.
    directed, phi_slope_max, z0_max, augments, tracking, n_workers, n_tasks :
        Accepted for interface compatibility; currently unused.
    """
    def __init__(self, root, 
                 transform=None, 
                 n_events=0,
                 directed=False, 
                 layer_pairs_plus=False,
                 volume_layer_ids=[[8, 2], [8, 4], [8, 6], [8, 8]], #Layers Selected
                 layer_pairs=[[7, 8], [8, 9], [9, 10]],             #Connected Layers
                 pt_min=0.3, 
                 eta_range=[-5, 5],                     
                 phi_slope_max=0.0006, 
                 z0_max=150,                  
                 n_phi_sections=1, 
                 n_eta_sections=1,  
                 augments = False,
                 tracking=False,                   
                 n_workers=mp.cpu_count(), 
                 n_tasks=1,               
                 download_full_dataset=False                        
             ):
        self.layer_pairs_plus = layer_pairs_plus
        self.volume_layer_ids = torch.tensor(volume_layer_ids)
        self.layer_pairs      = torch.tensor(layer_pairs)
        self.pt_min           = pt_min
        self.eta_range        = eta_range
        self.n_phi_sections   = n_phi_sections
        self.n_eta_sections   = n_eta_sections
        self.full_dataset     = download_full_dataset
        self.n_events         = n_events

        # First scan: makes hits/particles/truth available before the base
        # class constructor runs.
        self._refresh_event_files(osp.join(root, 'raw'))

        super(TrackMLParticleTrackingDataset, self).__init__(root, transform)

        # Second scan: the base constructor may have triggered download()
        # (torch_geometric behaviour -- TODO confirm for the installed
        # version), in which case the lists collected above are stale.
        self._refresh_event_files(self.raw_dir)

    def _refresh_event_files(self, raw_dir):
        """Collects the sorted per-event CSV paths in ``raw_dir``, truncated to ``n_events``."""
        def _sorted_matches(pattern):
            return sorted(glob.glob(osp.join(raw_dir, pattern)))

        hits = _sorted_matches('event*-hits.csv')
        particles = _sorted_matches('event*-particles.csv')
        truth = _sorted_matches('event*-truth.csv')
        if self.n_events > 0:
            hits = hits[:self.n_events]
            particles = particles[:self.n_events]
            truth = truth[:self.n_events]
        self.hits = hits
        self.particles = particles
        self.truth = truth

    @property
    def raw_file_names(self):
        """Base names of every CSV file currently present in ``raw_dir``."""
        self.input_files = sorted(glob.glob(osp.join(self.raw_dir, '*.csv')))
        # basename instead of split('/') so the result is also right where
        # the path separator is not '/'.
        return [osp.basename(f) for f in self.input_files]

    def len(self):
        """Number of items: events times (phi sections) times (eta sections)."""
        return len(self.hits)*self.n_phi_sections*self.n_eta_sections

    def __len__(self):
        return self.len()

    def read_events(self,idx):
        """Reads the hits, particles and truth CSV files of event number ``idx``.

        Returns
        -------
        (hits, particles, truth) : tuple of pandas.DataFrame
            Restricted to the columns listed below, with fixed dtypes.
        """
        hits = pandas.read_csv(
            self.hits[idx],
            usecols=['hit_id', 'x', 'y', 'z', 'volume_id', 'layer_id', 'module_id'],
            dtype={
                'hit_id': np.int64,
                'x': np.float32,
                'y': np.float32,
                'z': np.float32,
                'volume_id': np.int64,
                'layer_id': np.int64,
                'module_id': np.int64
            })
        particles = pandas.read_csv(
            self.particles[idx],
            usecols=['particle_id', 'vx', 'vy', 'vz', 'px', 'py', 'pz', 'q', 'nhits'],
            dtype={
                'particle_id': np.int64,
                'vx': np.float32,
                'vy': np.float32,
                'vz': np.float32,
                'px': np.float32,
                'py': np.float32,
                'pz': np.float32,
                'q': np.int64,
                'nhits': np.int64
            })
        truth = pandas.read_csv(
            self.truth[idx],
            usecols=['hit_id', 'particle_id', 'tx', 'ty', 'tz', 'tpx', 'tpy', 'tpz', 'weight'],
            dtype={
                'hit_id': np.int64,
                'particle_id': np.int64,
                'tx': np.float32,
                'ty': np.float32,
                'tz': np.float32,
                'tpx': np.float32,
                'tpy': np.float32,
                'tpz': np.float32,
                'weight': np.float32
            })
        return hits,particles,truth


    def download(self):
        """Fetches TrackML data through the Kaggle API and unpacks CSVs into ``raw_dir``.

        With ``download_full_dataset`` the whole competition archive is
        downloaded and every training zipball in it is unpacked; otherwise
        only ``train_sample.zip`` is fetched. Afterwards ``self.events`` holds
        the sorted event numbers found in ``<root>/raw``.

        Raises
        ------
        RuntimeError
            If the ``kaggle`` package cannot be imported.
        Exception
            (full dataset only) if a file to be unpacked already exists.
        """
        import os        
        from zipfile import ZipFile
        try:
            from kaggle.api.kaggle_api_extended import KaggleApi
        except ImportError:
            raise RuntimeError('please install and setup the kaggle '
                               'competition api: https://github.com/Kaggle/kaggle-api')
        
        api = KaggleApi()
        api.authenticate()
        
        kgl_comp = 'trackml-particle-identification'
        test_file = 'train_sample.zip'

        if self.full_dataset:
            kgl_file = 'trackml-particle-identification.zip'
            print('Downloading full TrackML dataset (~80GB), this may take a while...')
            api.competition_download_files(kgl_comp, 
                                           path=self.root,
                                           quiet = False,
                                           force = False)
            with ZipFile(os.path.join(self.root,kgl_file), 'r') as zf:
                # Training zipballs only: skip the sample and the blacklist archives.
                training_samples = [
                    fname for fname in zf.namelist()
                    if 'train' in fname
                    and 'sample' not in fname
                    and 'blacklist' not in fname
                ]
                for name in tqdm(training_samples, desc='extracting zipballs'):
                    if not os.path.exists(os.path.join(self.root, name)):
                        zf.extract(name, path=self.root)
                        
            for sample in training_samples:
                with ZipFile(os.path.join(self.root,sample), 'r') as zf:
                    for name in tqdm(zf.namelist(), desc=f'unpacking {sample}'):
                        # Directory entries carry no data; their basename is
                        # empty, so they must never reach open() below.
                        if name.endswith('/'):
                            continue
                        outname = os.path.join(self.raw_dir, os.path.basename(name))
                        if os.path.exists(outname):
                            raise Exception(f'{outname} already exists!')
                        with open(outname, 'wb') as fout:
                            fout.write(zf.read(name))
                                          
        else:
            kgl_file = test_file
            print('Downloading training example from TrackML dataset, only 100 training events...')
            api.competition_download_file(kgl_comp, 
                                          test_file,
                                          path=self.root,
                                          quiet = False,
                                          force = False)
            with ZipFile(os.path.join(self.root,kgl_file), 'r') as zf:
                for name in tqdm(zf.namelist()):
                    # Skip directory entries (e.g. 'train_100_events/').
                    if name.endswith('/'):
                        continue
                    with open(os.path.join(self.raw_dir, os.path.basename(name)), 'wb') as fout:
                        fout.write(zf.read(name))

        events = glob.glob(osp.join(osp.join(self.root, 'raw'), 'event*-hits.csv'))
        # 'event000001000-hits.csv' -> '000001000'
        events = [e.split(osp.sep)[-1].split('-')[0][5:] for e in events]
        self.events = sorted(events)
        if (self.n_events > 0):
            self.events = self.events[:self.n_events]

    @staticmethod
    def _compact_labels(labels):
        """Relabels ``labels`` with dense integers, keeping 0 reserved for the value 0.

        Non-zero values are numbered 1, 2, ... in sorted order, so a real
        label can never collide with the noise label 0, even when no 0 is
        present in the input.

        Returns
        -------
        (dense, counts) : per-entry dense label and per-entry multiplicity of
        that label.
        """
        uniq, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
        is_zero = uniq == 0
        new_ids = np.empty(uniq.size, dtype=np.int64)
        new_ids[is_zero] = 0
        new_ids[~is_zero] = np.arange(1, np.count_nonzero(~is_zero) + 1)
        return new_ids[inverse], counts[inverse]

    def select_all_hits(self,hits,particles,truth, noise_label = -1,
                        n_noise = 100, n_signal_tracks = 50):
        """Labels hits by track, then picks a sample of signal tracks and noise hits.

        Hits are joined with their truth particle and its momentum. Hits whose
        particle has fewer than 2 hits or ``pt <= self.pt_min`` are relabelled
        as noise (label 0). The result contains every hit of the tracks
        labelled ``1..n_signal_tracks`` plus a random sample of at most
        ``n_noise`` noise hits (drawn with the ``random`` module).

        Parameters
        ----------
        hits, particles, truth : pandas.DataFrame
            As returned by ``read_events``.
        noise_label :
            Currently unused; noise hits always carry label 0.
        n_noise : int
            Maximum number of noise hits to keep.
        n_signal_tracks : int
            Number of (lowest-labelled) tracks to keep.

        Returns
        -------
        pos : tensor, one row per selected hit, columns (x, y, z, r, theta, phi)
        eta : tensor, pseudorapidity per selected hit
        particle_labels : tensor, dense track label per selected hit (0 = noise)
        """
        hits_truth = hits.merge(truth[["hit_id", "particle_id"]], on="hit_id", how="left")

        full_truth = hits_truth.merge(particles[["px","py","pz","particle_id"]], on="particle_id", how="left")
        # Hits without a matching particle row get zero momentum, hence pt = 0.
        full_truth[["px", "py", "pz"]] = full_truth[["px", "py", "pz"]].fillna(0)

        full_truth["pt"] = np.sqrt(full_truth["px"]**2 + full_truth["py"]**2)
        full_truth["r"] = np.sqrt(full_truth["x"].values**2 + full_truth["y"].values**2)
        full_truth["phi"] = np.arctan2(full_truth["y"].values, full_truth["x"].values)
        full_truth["theta"] = np.arctan2(full_truth["r"].values,full_truth["z"].values)
        full_truth["eta"] = -1*np.log(np.tan(full_truth["theta"]/2.))

        # Dense track labels and the reconstructed number of hits of each track.
        dense_pid, hit_counts = self._compact_labels(full_truth['particle_id'].values)
        full_truth["remapped_pid"] = dense_pid
        full_truth["nhits"] = hit_counts

        # Label tracks and their hits as noise based on hit count and pt.
        # Assign the result back instead of using inplace=True on the selected
        # column, which does not reliably write through to the frame.
        is_signal = (full_truth['nhits'] >= 2) & (full_truth['pt'] > self.pt_min)
        full_truth["remapped_pid"] = full_truth["remapped_pid"].where(is_signal, 0)

        # Re-compact the labels and recount now that tracks were merged into noise.
        dense_pid, hit_counts = self._compact_labels(full_truth['remapped_pid'].values)
        full_truth["remapped_pid"] = dense_pid
        full_truth["nhits"] = hit_counts

        # Select a subset: noise sample + tracks kept as signal.
        noise = full_truth[full_truth.remapped_pid == 0]
        # random.sample raises if asked for more items than exist.
        n_take = min(n_noise, len(noise))
        noise_idx = sorted(random.sample(noise.index.to_list(), n_take))
        noise = noise.loc[noise_idx]
        
        signal_tracks = list(range(1, n_signal_tracks + 1))
        signals = full_truth[full_truth.remapped_pid.isin(signal_tracks)]
        
        selected_hits = pd.concat([signals,noise])

        # Extract the features from the selected hits.
        def _column(name):
            return torch.from_numpy(selected_hits[name].values)

        eta = _column('eta')
        particle_labels = _column('remapped_pid')
        pos = torch.stack([_column(c) for c in ('x', 'y', 'z', 'r', 'theta', 'phi')], 1)

        return pos, eta, particle_labels
    
    def split_detector_sections(self,pos, eta, particle_labels, phi_edges, eta_edges):
        """Partitions hits into (phi, eta) sections.

        Sections are emitted phi-major: for every phi interval, all eta
        intervals in order. Interval bounds are exclusive on both sides.

        Returns
        -------
        (pos_sect, particle_label_sect) : two lists of tensors with
        ``(len(phi_edges) - 1) * (len(eta_edges) - 1)`` entries each.
        """
        pos_sect, particle_label_sect = [], []
        # phi is the last column of the pos tensor.
        phi_idx = -1
        for i in range(len(phi_edges) - 1):
            in_phi = (pos[:,phi_idx] > phi_edges[i]) & (pos[:,phi_idx] < phi_edges[i+1])
            phi_pos = pos[in_phi]
            phi_eta = eta[in_phi]
            phi_particle_label = particle_labels[in_phi]

            for j in range(len(eta_edges) - 1):
                in_eta = (phi_eta > eta_edges[j]) & (phi_eta < eta_edges[j+1])
                pos_sect.append(phi_pos[in_eta])
                particle_label_sect.append(phi_particle_label[in_eta])

        return pos_sect, particle_label_sect
    
    def get(self,idx):
        """Builds the ``Data`` object for flat item index ``idx``.

        ``idx`` enumerates (event, section) pairs event-major, matching
        ``len()``. With the default single section, ``idx`` is simply the
        event number.

        Returns
        -------
        Data with ``x`` (hit features), ``y`` (track labels), ``tracks``
        (empty 0x5 long tensor) and ``inpz`` (section number as a 1-element
        float tensor).
        """
        n_sections = self.n_phi_sections * self.n_eta_sections
        # Map the flat index back to its event and the section within it;
        # indexing events with idx directly overruns the file list as soon as
        # there is more than one section per event.
        event_idx, section_idx = divmod(idx, n_sections)

        hits,particles,truth = self.read_events(event_idx)
        pos, eta, particle_labels = self.select_all_hits(hits, 
                                                         particles, 
                                                         truth,noise_label=0)
        tracks = torch.empty(0, 5, dtype=torch.long)  
        phi_edges = np.linspace(-np.pi, np.pi, num=self.n_phi_sections+1)
        eta_edges = np.linspace(*self.eta_range, num=self.n_eta_sections+1)
        pos_sect, particle_label_sect = self.split_detector_sections(pos, 
                                                                    eta,
                                                                    particle_labels, 
                                                                    phi_edges, 
                                                                    eta_edges)
        return Data(x=pos_sect[section_idx],
                    y=particle_label_sect[section_idx],
                    tracks=tracks,
                    inpz = torch.Tensor([section_idx]))

def fetch_dataloader(data_dir, 
                     batch_size, 
                     validation_split,
                     n_events = 100,
                     pt_min = 0.3,
                     n_workers = 1,
                     generate_tracks = True,
                     full_dataset = False,
                     shuffle=False):
    """Builds train/validation loaders over a ``TrackMLParticleTrackingDataset``.

    Parameters
    ----------
    data_dir : str
        Dataset root passed to the dataset.
    batch_size : int
        Batch size of the training loader (the validation loader uses 1).
    validation_split : float
        Fraction of the items held out for validation. At least one item is
        always held out, and at least one is always left for training when
        the dataset has two or more items.
    n_events, pt_min, n_workers, generate_tracks, full_dataset :
        Forwarded to the dataset constructor.
    shuffle : bool
        Whether the *training* loader reshuffles every epoch. The validation
        loader keeps a fixed order so that per-batch outputs are comparable
        from one epoch to the next.

    Returns
    -------
    dict with keys ``'train'`` and ``'val'`` mapping to the two loaders.

    Raises
    ------
    ValueError
        If the dataset contains no items.
    """
    volume_layer_ids = [
        [8, 2], [8, 4], [8, 6], [8, 8], # barrel pixels
        [7, 2], [7, 4], [7, 6], [7, 8], [7, 10], [7, 12], [7, 14],# minus pixel endcap
        [9, 2], [9, 4], [9, 6], [9, 8], [9, 10], [9, 12], [9, 14], # plus pixel endcap
    ]
    dataset = TrackMLParticleTrackingDataset(root=data_dir,
                                             layer_pairs_plus=True, 
                                             pt_min= pt_min,
                                             volume_layer_ids=volume_layer_ids,
                                             n_events=n_events, 
                                             n_workers=n_workers, 
                                             tracking = generate_tracks,
                                             download_full_dataset=full_dataset)
    dataset_size = len(dataset)
    if dataset_size == 0:
        raise ValueError(f'no events found under {data_dir!r}; cannot build dataloaders')

    # Never hold out zero items (an empty validation loader makes averaging
    # the validation loss divide by zero) ...
    split = max(1, int(np.floor(validation_split * dataset_size)))
    # ... and never hold out everything when there is more than one item.
    if dataset_size > 1:
        split = min(split, dataset_size - 1)

    random_seed= 1001

    train_subset, val_subset = torch.utils.data.random_split(dataset, [dataset_size - split, split],
                                                             generator=torch.Generator().manual_seed(random_seed))
    print("train subset dim:", len(train_subset))
    print("validation subset dim", len(val_subset))
    dataloaders = {
        'train':  DataLoader(train_subset, batch_size=batch_size, shuffle=shuffle),
        'val':   DataLoader(val_subset, batch_size=1, shuffle=False)
        }
    print("train_dataloader dim:", len(dataloaders['train']))
    print("val dataloader dim:", len(dataloaders['val']))
    return dataloaders

In [None]:
# Constant added to every loss value returned by compute_oc_loss.
loss_offset: float = 0.0

def compute_oc_loss(out, data, s_c=1., return_components=True):
    """Object-condensation loss ``L_V + L_beta + loss_offset`` for one batch.

    Parameters
    ----------
    out : tensor
        Raw model output; column 0 is passed through a sigmoid to get the
        betas, the remaining columns are the cluster-space coordinates.
    data :
        Batch providing ``y`` (cluster labels) and ``batch`` (event index per
        hit), both on the same device as ``out`` (asserted).
    s_c : float
        Currently unused.
    return_components : bool
        When True (default) the formatted per-component loss breakdown is
        printed. The total loss is returned either way.

    Returns
    -------
    Scalar loss tensor.
    """
    device = out.device
    pred_betas = torch.sigmoid(out[:,0])
    pred_cluster_space_coords = out[:,1:]
    assert all(t.device == device for t in [
        pred_betas, pred_cluster_space_coords, data.y,
        data.batch,
        ])

    # The components are always requested: the total below is assembled from
    # the "L_V" and "L_beta" entries regardless of whether they get printed.
    out_oc = calc_LV_Lbeta(
        pred_betas,
        pred_cluster_space_coords,
        data.y.long(),
        data.batch,
        return_components=True,
        qmin = 0.1,
        )

    if return_components:
        print(formatted_loss_components_string(out_oc))

    return out_oc["L_V"] + out_oc["L_beta"] + loss_offset

In [None]:
def train(data_loader, model, epoch, optimizer,interval = 1):
    """Runs one optimisation pass over ``data_loader``.

    Parameters
    ----------
    data_loader :
        Iterable of batches exposing ``x``, ``batch`` and ``to(device)``.
    model :
        Called as ``model(batch.x, batch.batch)``; put into training mode.
    epoch : int
        Only used for the progress message.
    optimizer :
        Optimiser stepping the model's parameters.
    interval : int
        Currently unused.
    """
    print('Training epoch', epoch)
    model.train()

    # Run on whatever device the model lives on; fall back to 'cuda' (the
    # previous hard-coded choice) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    progress = tqdm(data_loader, total=len(data_loader))
    progress.set_postfix({'loss': '?'})
    for inputs in progress:
        # Use the returned object rather than relying on to() acting in place.
        inputs = inputs.to(device)
        optimizer.zero_grad()
        result = model(inputs.x, inputs.batch)
        loss = compute_oc_loss(result,inputs)
        loss.backward()
        optimizer.step()
        progress.set_postfix({'loss': float(loss)})

In [None]:
def test(data_loader, model, epoch, generate_plots = False):
    """Evaluates ``model`` on ``data_loader`` and prints the average loss.

    Parameters
    ----------
    data_loader :
        Iterable of batches exposing ``x``, ``y``, ``batch`` and ``to(device)``.
    model :
        Called as ``model(batch.x, batch.batch)``; put into eval mode.
    epoch : int
        Zero-based epoch number, used in the plot file names.
    generate_plots : bool
        When True, writes a truth plot and a cluster-space plot per batch as
        HTML files into the ``plots`` directory (created if missing).

    Returns
    -------
    float : average loss over the batches (NaN for an empty loader).
    """
    plot_dir = "plots"
    if generate_plots:
        # write_html does not create missing parent directories.
        os.makedirs(plot_dir, exist_ok=True)

    # Run on whatever device the model lives on; fall back to 'cuda' (the
    # previous hard-coded choice) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    with torch.no_grad():
        model.eval()
        total_loss = 0.
        progress = tqdm(data_loader, total=len(data_loader))
        for i,inputs in enumerate(progress):
            inputs = inputs.to(device)
            result = model(inputs.x, inputs.batch)
            total_loss += float(compute_oc_loss(result,inputs))
            if generate_plots:
                pred_cluster_space_coords = result[:,1:].to('cpu')
                cpu_inputs = inputs.to('cpu')
                fig = go.Figure(get_plotly_truth(cpu_inputs,size = 2.75))
                fig.write_html(plot_dir+"/truth_plot_epoch_"+str(epoch+1)+"_batch_"+str(i+1)+".html")
                pred_fig=go.Figure(get_plotly_clusterspace(cpu_inputs,pred_cluster_space_coords,size = 2.75))
                pred_fig.write_html(plot_dir+"/predictions_plot_epoch_"+str(epoch+1)+"_batch_"+str(i+1)+".html")

        n_batches = len(data_loader)
        # Guard the average against an empty validation loader.
        avg_loss = total_loss / n_batches if n_batches else float('nan')
        print(f'Avg test loss: {avg_loss} {n_batches}')
    return avg_loss

In [None]:
# Dataset root: must contain a `raw/` sub-directory holding the TrackML
# event*-hits/particles/truth CSV files. The default is a machine-specific
# absolute path; set the TRACKML_ROOT environment variable to use another
# location without editing this cell.
root = os.environ.get("TRACKML_ROOT", "/mnt/c/Users/linds/trackml/train_1")

# Seed every random source used below (noise-hit sampling uses `random`,
# loader shuffling and weight initialisation use torch) so that a fresh
# Restart & Run All reproduces the same training run.
seed = 1001
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

noise_pt_min = 0.3
batch_size = 10
validation_split = 0.1
events = 200
model_input_dim = 6    # x, y, z, r, theta, phi per hit
model_output_dim = 3   # 1 beta logit + 2 cluster-space coordinates
# NOTE: full_dataset=True selects the full (~80GB) archive should a download
# be needed.
data = fetch_dataloader(data_dir = root,
                        batch_size = batch_size,
                        validation_split=validation_split,
                        full_dataset = True,
                        n_events = events,
                        pt_min = noise_pt_min,
                        shuffle=True)
epochs = 100
train_loader,test_loader = data['train'],data['val']
model = GravnetModel(input_dim=model_input_dim,output_dim=model_output_dim).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

for i_epoch in range(epochs):
    train(train_loader, 
          model, 
          i_epoch, 
          optimizer,
          interval = 1)
    test(test_loader, 
        model, 
        i_epoch,
        generate_plots = True) 

In [None]:
#################### END ############################################################################################