Skip to content

Commit

Permalink
Better condition
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasBeiske committed Nov 22, 2023
1 parent f8d5072 commit 1ad8b25
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,17 @@ def _read_input_data(self, tel_type):
n_events=n_background,
)
if n_events is None: # use as many events as possible (keeping signal_fraction)
if len(signal) < len(background) / (1 / self.signal_fraction - 1):
n_signal = len(signal)
n_background = int(n_signal * (1 / self.signal_fraction - 1))
n_signal = len(signal)
n_background = len(background)

if n_signal < (n_signal + n_background) * self.signal_fraction:
n_background = int(n_signal * (1 / self.signal_fraction - 1))
self.log.info("Sampling %d background events", n_background)
idx = self.rng.choice(len(background), n_background, replace=False)
idx.sort()
background = background[idx]
else:
n_background = len(background)
n_signal = int(n_background / (1 / self.signal_fraction - 1))

self.log.info("Sampling %d signal events", n_signal)
idx = self.rng.choice(len(signal), n_signal, replace=False)
idx.sort()
Expand Down

0 comments on commit 1ad8b25

Please sign in to comment.