Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use signal_fraction for training particle classifier #2465

Merged
merged 8 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes/2465.api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Replace ``n_signal`` and ``n_background`` options in ``ctapipe-train-particle-classifier``
with ``n_events`` and ``signal_fraction``, where ``signal_fraction`` = n_signal / (n_signal + n_background).
10 changes: 4 additions & 6 deletions src/ctapipe/resources/train_particle_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
# ======================================================================

TrainParticleClassifier:
random_seed: 0 # Seed used for sampling n_* events for training.
n_signal: # The number of signal events used for training that can be provided
random_seed: 0 # Seed used for sampling n_events for training.
n_events: # The number of events used for training that can be provided
# - [type, "LST*", 1000] # independently for each telescope type (e.g. "LST_LST_LSTCam").
# - [type, "MST*", 1000] # If not specified, all events in the file are used.
n_background: # Same as above, but for background events.
# - [type, "LST*", 1000]
# - [type, "MST*", 1000]
# - [type, "MST*", 1000] # If not specified, as many events as possible are used.
signal_fraction: 0.5 # signal_fraction = n_signal / n_events

CrossValidator:
n_cross_validations: 5
Expand Down
52 changes: 52 additions & 0 deletions src/ctapipe/tools/tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

from ctapipe.core import ToolConfigurationError, run_tool
Expand Down Expand Up @@ -63,6 +64,57 @@ def test_sampling(tmp_path, dl2_shower_geometry_file):
)


def test_signal_fraction(tmp_path, gamma_train_clf, proton_train_clf):
from ctapipe.tools.train_particle_classifier import TrainParticleClassifier

tool = TrainParticleClassifier()
config = resource_file("train_particle_classifier.yaml")
out_file = tmp_path / "particle_classifier_.pkl"
log_file = tmp_path / "train_particle.log"

with pytest.raises(
ToolConfigurationError,
match="The signal_fraction has to be between 0 and 1",
):
run_tool(
tool,
argv=[
f"--signal={gamma_train_clf}",
f"--background={proton_train_clf}",
f"--output={out_file}",
f"--config={config}",
"--signal-fraction=1.1",
"--log-level=INFO",
],
raises=True,
)

for frac in [0.7, 0.1]:
run_tool(
tool,
argv=[
f"--signal={gamma_train_clf}",
f"--background={proton_train_clf}",
f"--output={out_file}",
f"--config={config}",
f"--log-file={log_file}",
f"--signal-fraction={frac}",
"--log-level=INFO",
"--overwrite",
],
)

with open(log_file, "r") as f:
log = f.readlines()

for line in log[::-1]:
if "Train on" in line:
n_signal, n_background = [int(line.split(" ")[i]) for i in (7, 10)]
break

assert np.allclose(n_signal / (n_signal + n_background), frac, atol=1e-4)


def test_cross_validation_results(tmp_path, gamma_train_clf, proton_train_clf):
from ctapipe.tools.train_disp_reconstructor import TrainDispReconstructor
from ctapipe.tools.train_energy_regressor import TrainEnergyRegressor
Expand Down
68 changes: 49 additions & 19 deletions src/ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np
from astropy.table import vstack

from ctapipe.core.tool import Tool
from ctapipe.core.traits import Int, IntTelescopeParameter, Path
from ctapipe.core.tool import Tool, ToolConfigurationError
from ctapipe.core.traits import Float, Int, IntTelescopeParameter, Path
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, ParticleClassifier

Expand Down Expand Up @@ -60,21 +60,22 @@
),
).tag(config=True)

n_signal = IntTelescopeParameter(
n_events = IntTelescopeParameter(
default_value=None,
allow_none=True,
help=(
"Number of signal events to be used for training."
"Total number of events to be used for training."
" If not given, all available events will be used"
" (considering ``signal_fraction``)."
),
).tag(config=True)

n_background = IntTelescopeParameter(
default_value=None,
allow_none=True,
signal_fraction = Float(
default_value=0.5,
allow_none=False,
help=(
"Number of background events to be used for training."
" If not given, all available events will be used"
"Fraction of signal events in all events to be used for training."
" ``signal_fraction`` = n_signal / (n_signal + n_background)"
),
).tag(config=True)

Expand All @@ -83,7 +84,7 @@
allow_none=True,
help=(
"How many subarray events to load at once before training on"
" n_signal and n_background events."
" n_events (or all available) events."
),
).tag(config=True)

Expand All @@ -100,8 +101,8 @@
aliases = {
"signal": "TrainParticleClassifier.input_url_signal",
"background": "TrainParticleClassifier.input_url_background",
"n-signal": "TrainParticleClassifier.n_signal",
"n-background": "TrainParticleClassifier.n_background",
"n-events": "TrainParticleClassifier.n_events",
"signal-fraction": "TrainParticleClassifier.signal_fraction",
"n-jobs": "ParticleClassifier.n_jobs",
("o", "output"): "TrainParticleClassifier.output_path",
"cv-output": "CrossValidator.output_path",
Expand Down Expand Up @@ -132,11 +133,10 @@
if self.signal_loader.subarray != self.background_loader.subarray:
raise ValueError("Signal and background subarrays do not match")

self.subarray = self.signal_loader.subarray
self.n_signal.attach_subarray(self.subarray)
self.n_background.attach_subarray(self.subarray)
self.classifier = ParticleClassifier(subarray=self.subarray, parent=self)

self.n_events.attach_subarray(self.signal_loader.subarray)
self.classifier = ParticleClassifier(
subarray=self.signal_loader.subarray, parent=self
)
self.cross_validate = self.enter_context(
CrossValidator(
parent=self, model_component=self.classifier, overwrite=self.overwrite
Expand Down Expand Up @@ -166,11 +166,24 @@
self.log.info("done")

def _read_input_data(self, tel_type):
if self.signal_fraction < 0 or self.signal_fraction > 1:
raise ToolConfigurationError(
"The signal_fraction has to be between 0 and 1"
)

feature_names = self.classifier.features + [
self.classifier.target,
"true_energy",
"true_impact_distance",
]
n_events = self.n_events.tel[tel_type]
if n_events is not None:
n_signal = int(self.signal_fraction * n_events)
n_background = n_events - n_signal

Check warning on line 182 in src/ctapipe/tools/train_particle_classifier.py

View check run for this annotation

Codecov / codecov/patch

src/ctapipe/tools/train_particle_classifier.py#L181-L182

Added lines #L181 - L182 were not covered by tests
else:
n_signal = None
n_background = None

signal = read_training_events(
loader=self.signal_loader,
chunk_size=self.chunk_size,
Expand All @@ -179,7 +192,7 @@
feature_names=feature_names,
rng=self.rng,
log=self.log,
n_events=self.n_signal.tel[tel_type],
n_events=n_signal,
)
background = read_training_events(
loader=self.background_loader,
Expand All @@ -189,8 +202,25 @@
feature_names=feature_names,
rng=self.rng,
log=self.log,
n_events=self.n_background.tel[tel_type],
n_events=n_background,
)
if n_events is None: # use as many events as possible (keeping signal_fraction)
n_signal = len(signal)
n_background = len(background)

if n_signal < (n_signal + n_background) * self.signal_fraction:
n_background = int(n_signal * (1 / self.signal_fraction - 1))
self.log.info("Sampling %d background events", n_background)
idx = self.rng.choice(len(background), n_background, replace=False)
idx.sort()
background = background[idx]
else:
n_signal = int(n_background / (1 / self.signal_fraction - 1))
self.log.info("Sampling %d signal events", n_signal)
idx = self.rng.choice(len(signal), n_signal, replace=False)
idx.sort()
signal = signal[idx]

table = vstack([signal, background])
self.log.info(
"Train on %s signal and %s background events", len(signal), len(background)
Expand Down