In [2]:
import uproot
from tqdm import tqdm
import pandas as pd
import numpy as np
from pathlib import Path
import plotly
import plotly.express as px
import multiprocessing as mp
from functools import partial
from hgcal_dev.evaluation.experiments.simple_experiment import SimpleExperiment
from hgcal_dev.evaluation.studies.base_study import BaseStudy
from hgcal_dev.clustering.meanshift import MeanShift
from sklearn.manifold import TSNE

In [3]:
exp = SimpleExperiment('2p1ewhmz', 'epoch=14-step=59999.ckpt')

Global seed set to 1238354216
INFO:lightning:Global seed set to 1238354216


In [4]:
events = exp.get_events('train', 5)

100%|██████████| 5/5 [00:00<00:00, 456.53it/s]


In [5]:
study = BaseStudy(exp)

In [6]:
study.qualitative_cluster_study()

100%|██████████| 5/5 [00:00<00:00, 579.47it/s]


KeyboardInterrupt: 

In [6]:
train_events = exp.get_events('train', n=1000)

100%|██████████| 1000/1000 [00:02<00:00, 470.01it/s]


In [7]:
val_events = exp.get_events('val', n=1000)

100%|██████████| 1000/1000 [00:02<00:00, 470.83it/s]


In [12]:
norm_embedding = train_events[0].embedding
norm_embedding = norm_embedding / np.linalg.norm(norm_embedding, axis=0, ord=2)

In [15]:
train_events[0].embedding[0]

array([-0.8593804 , -0.42098802, -0.45050734, -0.5097168 ,  1.609824  ,
        0.29436257,  0.46049598,  0.3420815 ], dtype=float32)

In [20]:
X = TSNE(n_components=3).fit_transform(norm_embedding)

In [26]:
plot_df = pd.DataFrame({'x': X[:, 0], 'y': X[:, 1], 'z': X[:, 2], 'cluster': train_events[0].input_event['labels_i'].astype(str)})

In [27]:
px.scatter_3d(plot_df, x='x', y='y', z='z', color='cluster')

In [29]:
px.scatter_3d(train_events[0].input_event, x='x', y='y', z='z', color='labels_i')

In [32]:
train_events[0].input_event[train_events[0].input_event['labels_i']==15]

Unnamed: 0,x,y,z,labels_i
0,-6.232693,8.549379,-8.425092,15
2,-6.213314,8.822315,-8.346693,15
3,-6.18881,8.347054,-8.19017,15
14,-6.113827,8.544322,-8.191325,15
17,-6.112558,8.632979,-8.211396,15
24,-6.208022,8.398388,-8.324887,15
36,-6.214666,8.350923,-8.241852,15
44,-6.05415,8.539296,-8.102091,15
56,-6.181337,8.588432,-8.117681,15
60,-6.276959,8.588334,-8.337689,15


In [34]:
norm_embedding[train_events[0].input_event['labels_i']==15]

array([[-0.14147091, -0.05638906, -0.06271756, -0.08886344,  0.20143925,
         0.04364519,  0.06510681,  0.04625609],
       [-0.13023101, -0.06536559, -0.06180707, -0.07390454,  0.20273766,
         0.03791274,  0.03642772,  0.0315116 ],
       [-0.12721667, -0.06403159, -0.05488187, -0.07730442,  0.2078034 ,
         0.05057079,  0.05449379,  0.03300808],
       [-0.12721667, -0.06403159, -0.05488187, -0.07730442,  0.2078034 ,
         0.05057079,  0.05449379,  0.03300808],
       [-0.12721667, -0.06403159, -0.05488187, -0.07730442,  0.2078034 ,
         0.05057079,  0.05449379,  0.03300808],
       [-0.14147091, -0.05638906, -0.06271756, -0.08886344,  0.20143925,
         0.04364519,  0.06510681,  0.04625609],
       [-0.12721667, -0.06403159, -0.05488187, -0.07730442,  0.2078034 ,
         0.05057079,  0.05449379,  0.03300808],
       [-0.12721667, -0.06403159, -0.05488187, -0.07730442,  0.2078034 ,
         0.05057079,  0.05449379,  0.03300808],
       [-0.12721667, -0.06403159

In [None]:
from scipy.spatial import distance_matrix

In [35]:
clusterer = MeanShift(bandwidth=0.01)

In [36]:
pred_labels = clusterer.cluster(train_events[0].embedding)

In [37]:
pred_labels

array([ 0,  1,  0,  0,  7,  7,  9,  2,  6, 14,  2, 16,  4, 16,  0, 12, 15,
        0,  2, 15,  4, 10,  8,  1,  0,  4,  1,  3,  9, 11, 17,  5, 13,  9,
       14, 13,  0,  4,  2,  3,  3,  6,  8, 11,  0,  2,  2, 11,  7, 10,  8,
        1,  1,  6, 10,  7,  0, 12, 12,  6,  0,  3, 13,  1, 14,  4,  5,  3,
       17,  5,  8, 10, 15,  5,  0,  5,  1,  3,  9])

In [41]:
plot_df = train_events[0].input_event
plot_df['slabels_i'] = plot_df['labels_i'].astype(str)
plot_df['spred_labels_i'] = pred_labels.astype(str)
px.scatter_3d(plot_df, x='x', y='y', z='z', color='slabels_i')

In [42]:
px.scatter_3d(plot_df, x='x', y='y', z='z', color='spred_labels_i')