In [None]:
# Show the GPUs on this machine and their current usage before choosing one with select_gpu
!nvidia-smi

First, we import matplotlib and set the old "classic" style (the reason for this choice is known only to Rui).

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
# Switch matplotlib to its built-in 'classic' style sheet for every plot in this notebook
mpl.style.use('classic')

In [None]:
import numpy as np
import torch

These are the local imports. Make sure you import the correct model!

In [None]:
from model.models import AltCNN4Layer_D35_sp as Model
from model.collectdata import collect_data, collect_truth
from model.training import select_gpu
from model.plots import plot_ruiplot
from model.efficiency import pv_locations, efficiency
from model.core import modernize_state

Select a GPU here, using the same numbering as the `nvidia-smi` tool.

In [None]:
# Index follows the GPU numbering shown by nvidia-smi
gpu_index = 2
device = select_gpu(gpu_index)

Pick a file to load.

In [None]:
# Validation loader: one event per batch, masking enabled, limited by slice(100)
# (presumably the first 100 events), tensors placed on the selected device
validation = collect_data(
    'data/Oct03_20K_val.h5',
    batch_size=1,
    slice=slice(100),
    masking=True,
    device=device,
)

In [None]:
import h5py
XY_file = 'data/Oct03_20K_val.h5'

# Read the full 'Xmax' and 'Ymax' datasets into numpy arrays (indexed per event when plotting below)
with h5py.File(XY_file, mode='r') as XY:
    xmax, ymax = [np.asarray(XY[dataset_name]) for dataset_name in ('Xmax', 'Ymax')]

> Note: to get the true PV locations, use `collect_truth('file.h5', pvs=True)`; pass `pvs=False` to collect SVs instead.

Let's just see how many NaNs we have in the dataset.

In [None]:
# NaN count along axis 1 of the label tensor (dataset tensor index 1), one number per row/event
label_array = validation.dataset.tensors[1].cpu().numpy()
nan_counts = np.isnan(label_array).sum(axis=1)
print(*nan_counts)

In [None]:
PV = collect_truth('data/Oct03_20K_val.h5', pvs=True)

# Shapes first, then each truth field for event 0; labels are padded to a 16-character column
print('PV.n.shape =    ', PV.n.shape)
print('PV.n[0].shape = ', *PV.n[0].shape)
for field in ('x', 'y', 'z', 'n', 'cat'):
    print(f'PV.{field}[0] ='.ljust(16), *getattr(PV, field)[0])

In [None]:
import h5py, awkward

In [None]:
p = 'p'

# Read the '<p>v_*' jagged arrays via awkward's HDF5 persistence; each list holds the one array read here
with h5py.File('temp.h5', mode='r') as XY:
    afile = awkward.persist.hdf5(XY)
    x_list, y_list, z_list, n_list, c_list = (
        [afile[f"{p}v_{suffix}"]]
        for suffix in ('loc_x', 'loc_y', 'loc', 'ntracks', 'cat')
    )

# NOTE(review): prints the `.JaggedArray` attribute of the loaded array — confirm this is the intended check
print(x_list[0].JaggedArray)

In [None]:
import awkward
import h5py
import numpy as np

# Toy round trip: a jagged array with rows of length 2 and 3, written to HDF5 and read back
simple = awkward.JaggedArray.fromcounts([2, 3], [1., 2, 3, 4, 5])

with h5py.File('example.h5', mode='w') as f:
    af = awkward.persist.hdf5(f)
    af['example'] = simple

with h5py.File('example.h5', mode='r') as f:
    af = awkward.persist.hdf5(f)
    example = af['example']

print(example.JaggedArray)

In [None]:
# Read the JSON schema that awkward stored alongside 'pv_loc_x'.
# `Dataset.value` was removed in h5py 3.0; `[()]` reads the whole dataset on old and new h5py alike.
with h5py.File('temp.h5', mode='r') as XY:
    schema = XY['pv_loc_x']['schema.json'][()]

In [None]:
# np.string_ was removed in NumPy 2.0; np.bytes_ is the same type on every NumPy version
print(np.bytes_(schema))

In [None]:
# Largest track count of any single true PV across all events in the file
PV.n.flatten().max()

In [None]:
import awkward

In [None]:
# True PV x positions for all events (indexed per event elsewhere as PV.x[event])
PV.x

In [None]:
# Category ("type" in the printouts below) of each true PV in event 0
PV.cat[0]

In [None]:
SV = collect_truth('data/Oct03_20K_val.h5', pvs=False)

# Shapes first, then each truth field for event 0; labels are padded to a 16-character column
print('SV.n.shape =    ', SV.n.shape)
print('SV.n[0].shape = ', *SV.n[0].shape)
for field in ('x', 'y', 'z', 'n', 'cat'):
    print(f'SV.{field}[0] ='.ljust(16), *getattr(SV, field)[0])

In [None]:
# Build the imported architecture, then move it to the selected device (Module.to returns the module)
model = Model()
model = model.to(device)

Select a model to load. Make sure it matches the model you imported above.

> #### Mike note:
>
> If you use an old-style model, comment out the `state = modernize_state(model, state)` line — it converts the old model key names to the new format.

In [None]:
# Checkpoint layout: <MODEL_DIR><run_name>/<run_name>_<N>.pyt
# (N = 199 here — presumably the saved epoch; confirm).
# map_location='cpu' lets the file load even if the GPU it was saved from is absent on this machine;
# load_state_dict later copies the weights onto the model's device.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
MODEL_DIR = '/share/lazy/schreihf/PvFinder/models/'
run_name = '07Jan19_AltCNN4Layer_D35_sp_300epochs_240K_lr_3em5_bs256_Alt_Loss_A_5p5'
state = torch.load(f'{MODEL_DIR}{run_name}/{run_name}_199.pyt', map_location='cpu')

In [None]:
# Remove the fc1 parameters from the checkpoint before loading.
# NOTE(review): presumably the imported Model has no fc1 layer, so strict loading would reject these keys — confirm.
# pop(..., None) keeps the cell re-runnable; a second `del` would raise KeyError.
for key in ('fc1.weight', 'fc1.bias'):
    state.pop(key, None)

In [None]:
# Convert checkpoint key names to the format the current model expects (per the "Mike note" markdown);
# NOTE(review): exact behavior lives in model.core.modernize_state — confirm before relying on it.
state = modernize_state(model, state)

In [None]:

# Load the checkpoint weights, then switch to inference mode (affects layers such as dropout/batch-norm, if present)
model.load_state_dict(state)
model.eval()

Let's grab the outputs and labels as normal numpy arrays.

In [None]:
%%time
# Single forward pass over every validation input at once, without autograd tracking;
# results and labels are copied back to host numpy arrays.
with torch.no_grad():
    outputs = model(validation.dataset.tensors[0]).cpu().numpy()
    labels = validation.dataset.tensors[1].cpu().numpy()

In [None]:
from cycler import cycler

In [None]:
# Grayscale colour cycle (black to mid-grey), used as axes.prop_cycle in the plot style
grays = ['black', '#444444', '#888888']
cycle = cycler(color=grays)
hatch = ['/', '']  # NOTE(review): not referenced by any cell visible here

In [None]:
# Bold 18pt text: fontdict is passed to set_title, styledict to plt.style.context
fontdict = dict(size=18, weight='bold')
styledict = {f'font.{key}': value for key, value in fontdict.items()}
styledict['axes.prop_cycle'] = cycle

# Unfilled patches distinguished only by hatch pattern; passed to plot_ruiplot as `styles`
rui_styles = {
    name: dict(facecolor='none', hatch=pattern)
    for name, pattern in (
        ('kernel', ''),
        ('target', '/'),
        ('predicted', '\\'),
        ('masked', '*'),
    )
}


And here's Rui's plotting code.

In [None]:
# For the first few events: find true and predicted PV positions, plot each "point of interest",
# print the truth PVs/SVs near it, and collect a per-event efficiency summary printed at the end.

inputs = validation.dataset.tensors[0].cpu().numpy().squeeze()
zvals = np.linspace(-100, 300, 4000, endpoint=False) + 0.05
finalmsg = ''
internal_count = 0
# Set to None to skip saving; otherwise a str.format pattern with a {number} field,
# e.g. '120000_3layer_{number:02}.pdf'
output_filename = '07Jan19_AltCNN4Layer_D35_sp_{number:02}.pdf'

# Consistent peak-finding parameters for pv_locations / efficiency (loop-invariant)
parameters = {
    "threshold": 1e-2,
    "integral_threshold": .2,
    "min_width": 3
}


def _bin_to_z(bins):
    """Convert a bin index (or array of indices) to z in mm: 0.1 mm bins starting at -100 mm."""
    return (bins / 10) - 100


def _points_of_interest(truth_bins, computed_bins, min_separation=5):
    """Return the sorted union of true and predicted bins with near-duplicates removed.

    A bin is kept only if it lies more than ``min_separation`` bins above the previous
    bin in sorted order. Returns an empty array when both inputs are empty.
    """
    poi = np.sort(np.concatenate([truth_bins, computed_bins]))
    if poi.size == 0:
        # Indexing an empty array with the [True] mask below would raise IndexError
        return poi
    keep = np.concatenate([[True], np.fabs(np.diff(poi)) > min_separation])
    return poi[keep]


def _print_nearby_vertices(kind, vertices, event, center, half_width=2.5):
    """Print each vertex of `event` whose z lies strictly within center ± half_width (mm)."""
    for x, y, z, n, cat in zip(vertices.x[event], vertices.y[event], vertices.z[event],
                               vertices.n[event], vertices.cat[event]):
        if center - half_width < z < center + half_width:
            print()
            print(f'{kind}: {n} tracks (type {cat})')
            print(f'  x: {x*1000:5.0f} μm')
            print(f'  y: {y*1000:5.0f} μm')
            print(f'  z: {z:8.3f} mm')


for event in range(2):
    event_input = inputs[event]  # not named `input`, to avoid shadowing the builtin
    label = labels[event]
    output = outputs[event]

    # Compute the "actual" efficiencies and things
    ftruth = pv_locations(label, **parameters)
    fcomputed = pv_locations(output, **parameters)
    results = efficiency(label, output, difference=5.0, **parameters)

    # Add a line to the final results string (printed at the end)
    finalmsg += f"Event {event}: {results}\n"

    # Integer bin numbers; merge them into the "points of interest" we plot around
    truth = np.around(ftruth).astype(np.int32)
    computed = np.around(fcomputed).astype(np.int32)
    poi = _points_of_interest(truth, computed)

    print(f"\nEvent {event}:", results)

    for index, i in enumerate(poi):
        center = _bin_to_z(i)

        # True / predicted PVs within 5 bins of this point
        b_truth = np.fabs(ftruth - i) <= 5
        b_comp = np.fabs(fcomputed - i) <= 5
        in_truth = np.any(b_truth)
        in_comp = np.any(b_comp)

        if in_truth and in_comp:
            title_msg = 'PV found'
        elif in_truth:
            title_msg = 'PV not found'
        elif np.any(np.isnan(label[max(i - 3, 0):i + 3])):
            # Start is clamped at 0: a negative start would wrap to the end of the array
            title_msg = 'Masked'
        else:
            title_msg = 'False positive'

        with plt.style.context(styledict):
            fig, axs = plt.subplots(2, figsize=(12, 10), sharex=True,
                                    gridspec_kw={'height_ratios': [2, 1],
                                                 'hspace': 0.1})

            # ax1 is the axis tied to the left (density), ax2 the one tied to the right (probability)
            ax1, ax2 = plot_ruiplot(zvals, i, event_input, label, output, ax=axs[0], styles=rui_styles)
            ax1.set_title(f"Event {event} @ {center:.1f} mm: {title_msg}",
                          fontdict=fontdict)

            # Text box listing nearby true/predicted positions (and their difference when unambiguous)
            truth_centroid = _bin_to_z(ftruth[b_truth])
            comp_centroid = _bin_to_z(fcomputed[b_comp])
            summary_lines = [f"True: {value:.3f} mm\n" for value in truth_centroid]
            summary_lines += [f"Pred: {value:.3f} mm\n" for value in comp_centroid]
            if len(truth_centroid) == 1 and len(comp_centroid) == 1:
                diff = (comp_centroid[0] - truth_centroid[0]) * 1_000
                summary_lines.append(f"∆: {diff:.0f} µm\n")
            ax1.text(.02, .8, "".join(summary_lines),
                     transform=ax1.transAxes,
                     verticalalignment='top')

            print(f"\nEvent {event}.{index}:")

            # Plot (right axis) and print the truth PVs, then SVs, near this point
            ax2.scatter(PV.z[event], np.ones_like(PV.z[event])*.4, s=50, color='C0')
            _print_nearby_vertices('PV', PV, event, center)
            ax2.scatter(SV.z[event], np.ones_like(SV.z[event])*.6, s=50, color='C1')
            _print_nearby_vertices('SV', SV, event, center)

            # Lower panel: the x/y maximum arrays read from the file for this event.
            # NOTE(review): these are scaled by 1e6 to get μm while PV/SV x/y are scaled by 1e3 — confirm units.
            ax = axs[1]
            ax.plot(_bin_to_z(np.arange(len(xmax[event]))), xmax[event]*1000000, label="x")
            ax.plot(_bin_to_z(np.arange(len(ymax[event]))), ymax[event]*1000000, label="y")
            ax.set_xlim(ax1.get_xlim())
            ax.set_ylim(-150, 150)
            ax.grid(axis='y')
            ax.set_ylabel('xy maximum [μm]')
            ax.legend(loc='best')

            # Only the bottom panel carries the shared z label
            ax.set_xlabel(ax1.get_xlabel())
            ax1.set_xlabel("")

            # Save (optional) and show
            if output_filename:
                pdf_name = output_filename.format(number=internal_count)
                print()
                print(pdf_name)
                fig.savefig(pdf_name, transparent=True)
            plt.show()
            # Release the figure explicitly; this loop creates one figure per point of interest
            plt.close(fig)

            internal_count += 1

print(finalmsg)