diff --git a/src/mokka/pulseshaping/torch.py b/src/mokka/pulseshaping/torch.py index 108dd47..6b16695 100644 --- a/src/mokka/pulseshaping/torch.py +++ b/src/mokka/pulseshaping/torch.py @@ -52,7 +52,7 @@ def forward(self, y, n_up): ) return y_shaped - def matched(self, r, n_down): + def matched(self, r, n_down, max_energy=False): """Perform matched filtering. This function assumes perfect timing sync. @@ -63,7 +63,24 @@ def matched(self, r, n_down): """ y_filt = functional.torch.convolve(r, self.impulse_response_conj / n_down) offset = self.impulse_response_conj.shape[0] - 1 - y = y_filt[::n_down][int(offset / n_down) : -int(offset / n_down)] + energy_index = 0 + if max_energy: + curr_energy = 0.0 + for nd in range(n_down): + energy = torch.sum( + torch.pow( + torch.abs( + y_filt[nd::n_down][ + int(offset / n_down) : -int(offset / n_down) + ] + ), + 2, + ) + ) + if energy > curr_energy: + curr_energy = energy + energy_index = nd + y = y_filt[energy_index::n_down][int(offset / n_down) : -int(offset / n_down)] return y def normalize_filter(self):