Skip to content

Commit

Permalink
removed redundant code
Browse files Browse the repository at this point in the history
  • Loading branch information
Rasmus 0rs0e committed Nov 12, 2021
1 parent c2893cd commit 4c3298e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 82 deletions.
7 changes: 5 additions & 2 deletions src/gnn_reco/legacy/original.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def forward(self, data):
if self.predict == False:
if self.target == 'energy':
data[self.target] = torch.tensor(self.scalers['truth'][self.target].transform(np.log10(data[self.target].cpu().numpy()).reshape(-1,1))).to(device)
else:
if self.target == 'zenith':
data[self.target] = torch.tensor(self.scalers['truth'][self.target].transform(data[self.target].cpu().numpy().reshape(-1,1))).to(device)
if self.target == 'azimuth':
data[self.target] = torch.reshape(data[self.target], (-1,1))
edge_index = knn_graph(x=x[:,0:3],k=k,batch=batch).to(device)

h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
Expand Down Expand Up @@ -113,7 +115,8 @@ def forward(self, data):
else:
if self.target == 'zenith' or self.target == 'azimuth':
pred = np.arctan2(x[:,0].cpu().numpy(),x[:,1].cpu().numpy()).reshape(-1,1)
pred = torch.tensor(self.scalers['truth'][self.target].inverse_transform(pred),dtype = torch.float32)
if self.target == 'zenith':
pred = torch.tensor(self.scalers['truth'][self.target].inverse_transform(pred),dtype = torch.float32)
sigma = abs(1/x[:,2]).cpu()
return torch.cat((pred,sigma.reshape(-1,1)),dim = 1)

This comment has been minimized.

Copy link
@asogaard

asogaard Nov 12, 2021

Collaborator

@RasmusOrsoe I get an error here now, because pred is a numpy array if the target is azimuth, but (correctly) a torch tensor of the target is zenith. Have fixed locally but we should update the code in main (or revert).

else:
Expand Down
83 changes: 3 additions & 80 deletions src/gnn_reco/plots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def extract_statistics(data,keys, key_bins):
biases[key]['cascade']['84th'].append(np.percentile(bias_tmp,84))
return biases

def GetRetro(data, keys,db):
def get_retro(data, keys,db):
events = data['event_no']
key_count = 0
for key in keys:
Expand All @@ -279,11 +279,11 @@ def GetRetro(data, keys,db):
def calculate_statistics(data,keys, key_bins,db,include_retro = False):
biases = {'dynedge': extract_statistics(data, keys, key_bins)}
if include_retro:
retro = GetRetro(data,keys,db)
retro = get_retro(data,keys,db)
biases['retro'] = extract_statistics(retro, keys, key_bins)
return biases

def PlotBiases(key_limits, biases, is_retro = False):
def plot_biases(key_limits, biases, is_retro = False):
key_limits = key_limits['bias']
if is_retro:
prefix = 'RetroReco'
Expand Down Expand Up @@ -458,83 +458,6 @@ def PlotRelativeImprovement(key_limits, biases):
return fig




def WriteReport(archive,data_path, db, key_limits,keys, key_bins):
data = pd.read_csv(data_path)
data = AddEnergy(db, data)
data = AddPIDInteraction(db, data)

biases = CalculateStatistics(data,keys, key_bins,db,include_retro = True)
figures = []
figures.append(MakeSummaryWidthPlot(key_limits, biases,include_retro = True, track_cascade= True))
#figures.append(PlotBiases(key_limits, biases['dynedge']))
#figures.append(PlotBiases(key_limits, biases['retro'], is_retro = True))
#figures.append(PlotWidth(key_limits, biases))
return

def calculate_relative_improvement_error(relimp, w1, w1_sigma, w2, w2_sigma):
sigma = np.sqrt((np.array(w1_sigma)/np.array(w1))**2 + (np.array(w2_sigma)/np.array(w2))**2)
return sigma



def MakeEnergyReport(archive, result_path, db_name):
db_tag = db_name + '.db'
db = archive / 'data' / 'databases' / db_name / 'data' / db_tag
key_limits = {'bias':{'energy':{'x':[0,3], 'y':[-100,100]}},
'width':{'energy':{'x':[0,3.25], 'y':[-0.5,1.5]}},
'rel_imp':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'osc':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'distributions':{'energy':{'x':[0,4], 'y':[-0.75,0.75]}}}
keys = ['energy']
key_bins = { 'energy': np.arange(0, 3.25, 0.25)}
WriteReport(archive,result_path, db, key_limits,keys, key_bins)

def MakeZenithReport(archive, result_path, db_name):
db_tag = db_name + '.db'
db = archive / 'data' / 'databases' / db_name / 'data' / db_tag
key_limits = {'bias':{'energy':{'x':[0,3], 'y':[-100,100]},
'zenith': {'x':[0,3], 'y':[-100,100]}},
'width':{'energy':{'x':[0,3], 'y':[-0.5,1.5]},
'zenith': {'x':[0,3], 'y':[-100,100]}},
'rel_imp':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'osc':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'distributions':{'energy':{'x':[0,4], 'y':[-0.75,0.75]}}}
keys = ['zenith']
key_bins = { 'energy': np.arange(0, 3.25, 0.25),
'zenith': np.arange(0, 180, 10) }
WriteReport(archive,result_path, db, key_limits,keys, key_bins)

def MakeAzimuthReport(archive, result_path, db_name):
db_tag = db_name + '.db'
db = archive / 'data' / 'databases' / db_name / 'data' / db_tag
key_limits = {'bias':{'energy':{'x':[0,3], 'y':[-100,100]},
'zenith': {'x':[0,3], 'y':[-100,100]},
'azimuth': {'x':[0,3], 'y':[-100,100]}},
'width':{'energy':{'x':[0,3], 'y':[-0.5,1.5]},
'zenith': {'x':[0,3], 'y':[-100,100]},
'azimuth': {'x':[0,3], 'y':[-100,100]}},
'rel_imp':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'osc':{'energy':{'x':[0,3], 'y':[-0.75,0.75]}},
'distributions':{'energy':{'x':[0,4], 'y':[-0.75,0.75]}}}
keys = ['azimuth']
key_bins = { 'energy': np.arange(0, 3.25, 0.25),
'zenith': np.arange(0, 180, 10),
'azimuth': np.arange(0, 2*180, 20) }
WriteReport(archive,result_path, db, key_limits,keys, key_bins)

#archive = Path('/groups/hep/pcs557/GNNReco')
#database_name = 'dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3'
#MakeEnergyReport(archive= archive,
#result_path= '/groups/hep/pcs557/phd/results/dev_lvl7_robustness_muon_neutrino_0000/dynedge_energy_3_test/results.csv' ,
#db_name= 'dev_lvl7_robustness_muon_neutrino_0000')

#MakeZenithReport(archive= archive,
#result_path= '/groups/hep/pcs557/phd/results/dev_lvl7_robustness_muon_neutrino_0000/dynedge_zenith_9_test/results.csv' ,
#db_name= 'dev_lvl7_robustness_muon_neutrino_0000')

#MakeAzimuthReport(archive= archive,
#result_path= '/groups/hep/pcs557/GNNReco/results/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3/even_neutrino_types/lvl7_azimuth/valid_resultsv3.csv' ,
#db_name= 'dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3')

0 comments on commit 4c3298e

Please sign in to comment.