Skip to content

Commit

Permalink
Merge pull request #69 from RasmusOrsoe/main
Browse files Browse the repository at this point in the history
Added binned resolution plot script
  • Loading branch information
asogaard committed Nov 12, 2021
2 parents b39887e + 4c3298e commit a622a18
Show file tree
Hide file tree
Showing 4 changed files with 617 additions and 2 deletions.
21 changes: 21 additions & 0 deletions examples/test_width_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from gnn_reco.plots.width_plot import width_plot
import numpy as np


predictions_path = '/groups/hep/pcs557/phd/results/dev_lvl7_robustness_muon_neutrino_0000/dynedge_zenith_9_test_set/results.csv'
database = '/groups/hep/pcs557/GNNReco/data/databases/dev_lvl7_robustness_muon_neutrino_0000/data/dev_lvl7_robustness_muon_neutrino_0000.db'

key_limits = {'bias':{'energy':{'x':[0,3], 'y':[-100,100]},
'zenith': {'x':[0,3], 'y':[-100,100]}},
'width':{'energy':{'x':[0,3], 'y':[-0.5,1.5]},
'zenith': {'x':[0,3], 'y':[-100,100]}},
'rel_imp':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'osc':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'distributions':{'energy':{'x':[0,4], 'y':[-0.75,0.75]}}}
keys = ['zenith']
key_bins = { 'energy': np.arange(0, 3.25, 0.25),
'zenith': np.arange(0, 180, 10) }

performance_figure = width_plot(key_limits, keys, key_bins, database, predictions_path, figsize = (10,8), include_retro = True, track_cascade = True)

performance_figure.savefig('test_performance_figure.png')
7 changes: 5 additions & 2 deletions src/gnn_reco/legacy/original.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def forward(self, data):
if self.predict == False:
if self.target == 'energy':
data[self.target] = torch.tensor(self.scalers['truth'][self.target].transform(np.log10(data[self.target].cpu().numpy()).reshape(-1,1))).to(device)
else:
if self.target == 'zenith':
data[self.target] = torch.tensor(self.scalers['truth'][self.target].transform(data[self.target].cpu().numpy().reshape(-1,1))).to(device)
if self.target == 'azimuth':
data[self.target] = torch.reshape(data[self.target], (-1,1))
edge_index = knn_graph(x=x[:,0:3],k=k,batch=batch).to(device)

h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
Expand Down Expand Up @@ -113,7 +115,8 @@ def forward(self, data):
else:
if self.target == 'zenith' or self.target == 'azimuth':
pred = np.arctan2(x[:,0].cpu().numpy(),x[:,1].cpu().numpy()).reshape(-1,1)
pred = torch.tensor(self.scalers['truth'][self.target].inverse_transform(pred),dtype = torch.float32)
if self.target == 'zenith':
pred = torch.tensor(self.scalers['truth'][self.target].inverse_transform(pred),dtype = torch.float32)
sigma = abs(1/x[:,2]).cpu()
return torch.cat((pred,sigma.reshape(-1,1)),dim = 1)
else:
Expand Down
Loading

0 comments on commit a622a18

Please sign in to comment.