Skip to content

Commit

Permalink
Merge pull request #2468 from LukasBeiske/impact_in_cv
Browse files Browse the repository at this point in the history
Add true_impact_distance to cross validation output
  • Loading branch information
maxnoe committed Nov 21, 2023
2 parents 596b783 + b754690 commit bf51ad3
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions ctapipe/reco/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,7 @@ def __call__(self, telescope_type, table):
"prediction": cv_prediction,
"truth": truth,
"true_energy": test["true_energy"],
"true_impact_distance": test["true_impact_distance"],
}
)
)
Expand Down
6 changes: 5 additions & 1 deletion ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def start(self):
self.log.info("Loading events for %s", tel_type)
feature_names = self.models.features + [
"true_energy",
"true_impact_distance",
"subarray_pointing_lat",
"subarray_pointing_lon",
"true_alt",
Expand All @@ -138,7 +139,10 @@ def start(self):
n_events=self.n_events.tel[tel_type],
)
table[self.models.target] = self._get_true_disp(table)
table = table[self.models.features + [self.models.target, "true_energy"]]
table = table[
self.models.features
+ [self.models.target, "true_energy", "true_impact_distance"]
]

self.log.info("Train models on %s events", len(table))
self.cross_validate(tel_type, table)
Expand Down
5 changes: 4 additions & 1 deletion ctapipe/tools/train_energy_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def start(self):
self.log.info("Training models for %d types", len(types))
for tel_type in types:
self.log.info("Loading events for %s", tel_type)
feature_names = self.regressor.features + [self.regressor.target]
feature_names = self.regressor.features + [
self.regressor.target,
"true_impact_distance",
]
table = read_training_events(
loader=self.loader,
chunk_size=self.chunk_size,
Expand Down
1 change: 1 addition & 0 deletions ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _read_input_data(self, tel_type):
feature_names = self.classifier.features + [
self.classifier.target,
"true_energy",
"true_impact_distance",
]
signal = read_training_events(
loader=self.signal_loader,
Expand Down
1 change: 1 addition & 0 deletions docs/changes/2468.optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``true_impact_distance`` to the output of ``CrossValidator``.

0 comments on commit bf51ad3

Please sign in to comment.