Merge pull request #4 from uhecr-project:feature_TADataset
Implement TA dataset to code
kwat0308 authored Aug 4, 2021
2 parents 47358f6 + 68fca32 commit ea4053a
Showing 9 changed files with 256 additions and 44 deletions.
55 changes: 44 additions & 11 deletions fancy/analysis/analysis.py
@@ -5,14 +5,15 @@
from matplotlib import pyplot as plt
import h5py
from tqdm import tqdm as progress_bar
from multiprocessing import Pool, cpu_count

import stan_utility

from ..interfaces.integration import ExposureIntegralTable
from ..interfaces.stan import Direction, convert_scale
from ..interfaces.data import Uhecr
from ..plotting import AllSkyMap
from ..propagation.energy_loss import get_Eth_src, get_kappa_ex, get_Eex, get_Eth_sim, get_arrival_energy
from ..propagation.energy_loss import get_Eth_src, get_kappa_ex, get_Eex, get_Eth_sim, get_arrival_energy, get_arrival_energy_vec

__all__ = ['Analysis']

@@ -21,6 +22,9 @@ class Analysis():
"""
To manage the running of simulations and fits based on Data and Model objects.
"""

nthreads = int(cpu_count() * 0.75)

def __init__(self,
data,
model,
@@ -82,7 +86,10 @@ def __init__(self,
varpi = self.data.source.unit_vector
self.tables = ExposureIntegralTable(varpi=varpi, params=params)

def build_tables(self, num_points=50, sim_only=False, fit_only=False):
# cpu count


def build_tables(self, num_points=50, sim_only=False, fit_only=False, parallel=True):
"""
Build the necessary integral tables.
"""
@@ -99,8 +106,12 @@ def build_tables(self, num_points=50, sim_only=False, fit_only=False):
self.data.source.distance)
kappa_true = self.kappa_ex

self.tables.build_for_sim(kappa_true, self.model.alpha,
if parallel:
self.tables.build_for_sim_parallel(kappa_true, self.model.alpha,
self.model.B, self.data.source.distance)
else:
self.tables.build_for_sim(kappa_true, self.model.alpha,
self.model.B, self.data.source.distance)

if fit_only:

@@ -121,9 +132,13 @@ def build_tables(self, num_points=50, sim_only=False, fit_only=False):
(kappa_first, kappa_second[1:], kappa_third[1:]), axis=0)

# full table for fit
self.tables.build_for_fit(kappa)
if parallel:
# self.tables.build_for_fit(kappa)
self.tables.build_for_fit_parallel(kappa)
else:
self.tables.build_for_fit(kappa)

def build_energy_table(self, num_points=50, table_file=None):
def build_energy_table(self, num_points=50, table_file=None, parallel=True):
"""
Build the energy interpolation tables.
"""
@@ -134,11 +149,24 @@ def build_energy_table(self, num_points=50, table_file=None):
base=np.e)
self.Earr_grid = []

for i in progress_bar(range(len(self.data.source.distance)),
desc='Precomputing energy grids'):
d = self.data.source.distance[i]
self.Earr_grid.append(
[get_arrival_energy(e, d)[0] for e in self.E_grid])
if parallel:

args_list = [(self.E_grid, d) for d in self.data.source.distance]
# parallelize for each source distance
with Pool(self.nthreads) as mpool:
results = list(progress_bar(
mpool.imap(get_arrival_energy_vec, args_list), total=len(args_list),
desc='Precomputing exposure integral'
))

self.Earr_grid = results

else:
for i in progress_bar(range(len(self.data.source.distance)),
desc='Precomputing energy grids'):
d = self.data.source.distance[i]
self.Earr_grid.append(
[get_arrival_energy(e, d)[0] for e in self.E_grid])

if table_file:
with h5py.File(table_file, 'r+') as f:
@@ -181,7 +209,8 @@ def _simulate_zenith_angles(self):
Simulate zenith angles for a set of arrival_directions.
"""

start_time = 2004
# start_time = 2004
start_time = 2008

if len(self.arrival_direction.d.icrs) == 1:
c_icrs = self.arrival_direction.d.icrs[0]
@@ -356,6 +385,10 @@ def _prepare_fit_inputs(self):
E_grid = self.E_grid
Earr_grid = list(self.Earr_grid)

# KW: due to multiprocessing appending,
# collapse dimension from (1, 23, 50) -> (23, 50)
eps_fit.resize(self.Earr_grid.shape)

# handle selected sources
if (self.data.source.N < len(eps_fit)):
eps_fit = [eps_fit[i] for i in self.data.source.selection]
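The parallel branch added to build_energy_table above farms out one task per source distance to a multiprocessing Pool. Below is a minimal sketch of that Pool + imap pattern, with a toy arrival_energy_grid worker standing in for get_arrival_energy_vec from fancy.propagation.energy_loss; only the multiprocessing bookkeeping is meant to be representative, the worker's physics is a placeholder.

from multiprocessing import Pool, cpu_count

import numpy as np
from tqdm import tqdm as progress_bar


def arrival_energy_grid(args):
    """Toy stand-in for get_arrival_energy_vec: attenuate each grid energy with distance."""
    E_grid, d = args
    return [E * np.exp(-d / 100.0) for E in E_grid]  # placeholder physics, not the real loss model


if __name__ == '__main__':
    # Log-spaced energy grid, similar in spirit to the one built in build_energy_table
    E_grid = np.logspace(np.log(50), np.log(1.0e4), 50, base=np.e)
    distances = [3.6, 17.9, 56.7]  # Mpc, illustrative source distances

    args_list = [(E_grid, d) for d in distances]
    nthreads = max(1, int(cpu_count() * 0.75))  # guard against single-core machines

    # One worker task per source distance; imap preserves submission order,
    # so results[i] corresponds to distances[i], exactly as in the serial loop.
    with Pool(nthreads) as mpool:
        results = list(progress_bar(mpool.imap(arrival_energy_grid, args_list),
                                    total=len(args_list),
                                    desc='Precomputing energy grids'))

    Earr_grid = np.asarray(results)  # shape (n_sources, num_points)

Because imap returns results in submission order, the precomputed grid can still be indexed by source exactly as with the serial loop's Earr_grid.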
6 changes: 5 additions & 1 deletion fancy/detector/TA2015.py
@@ -85,6 +85,9 @@
# reconstruction uncertainty for energy
f_E = 0.20

# threshold energy [EeV]
Eth = 57

# For convenience
detector_properties = {}
detector_properties['label'] = 'TA'
@@ -95,4 +98,5 @@
detector_properties['kappa_d'] = kappa_d
detector_properties['f_E'] = f_E
detector_properties['A'] = A
detector_properties['alpha_T'] = alpha_T
detector_properties['alpha_T'] = alpha_T
detector_properties['Eth'] = Eth
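The new Eth entry exposes the detector's threshold energy [EeV] alongside the other constants, so downstream code can read it from detector_properties instead of hard-coding it. A hedged sketch of such a consumer follows, using a hypothetical select_above_threshold helper that is not part of the package; the event energies are invented.

import numpy as np

# Subset of the TA2015 properties relevant to this sketch; values as in the module above.
detector_properties = {
    'label': 'TA',
    'f_E': 0.20,   # energy reconstruction uncertainty
    'Eth': 57,     # threshold energy [EeV]
}


def select_above_threshold(energies, properties):
    """Hypothetical helper: keep only events at or above the detector threshold."""
    energies = np.asarray(energies)
    return energies[energies >= properties['Eth']]


uhecr_energies = [45.2, 57.0, 83.1, 101.4]  # EeV, made-up event energies
print(select_above_threshold(uhecr_energies, detector_properties))  # [ 57.   83.1 101.4]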
4 changes: 4 additions & 0 deletions fancy/detector/auger2010.py
@@ -72,6 +72,9 @@
# reconstruction uncertainty for energy
f_E = 0.12

# threshold energy [EeV]
Eth = 52

# For convenience
detector_properties = {}
detector_properties['label'] = 'PAO'
@@ -83,4 +86,5 @@
detector_properties['f_E'] = f_E
detector_properties['A'] = A
detector_properties['alpha_T'] = alpha_T
detector_properties['Eth'] = Eth

4 changes: 4 additions & 0 deletions fancy/detector/auger2014.py
@@ -90,6 +90,9 @@
# reconstruction uncertainty for energy
f_E = 0.12

# threshold energy [EeV]
Eth = 52

# For convenience
detector_properties = {}
detector_properties['label'] = 'PAO'
@@ -101,6 +104,7 @@
detector_properties['f_E'] = f_E
detector_properties['A'] = A
detector_properties['alpha_T'] = alpha_T
detector_properties['Eth'] = Eth



1 change: 0 additions & 1 deletion fancy/detector/detector.py
@@ -79,7 +79,6 @@ def exposure(self):
# normalise to a maximum at 1
# max value of exposure factor is normalization constant
self.exposure_factor = (m / np.max(m))
# self.exposure_factor = m

# find the point at which the exposure factor is 0
# indexing value depends on TA or PAO
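With the commented-out alternative removed above, the exposure factor computed in exposure() is always normalised to peak at exactly one. A short worked sketch of that normalisation on an invented exposure curve (the numbers are illustrative, not real TA or PAO exposures):

import numpy as np

# Invented relative exposure over a coarse declination grid.
m = np.array([0.0, 0.8, 2.4, 3.1, 1.5])

# normalise to a maximum at 1, as in the exposure() method above
exposure_factor = m / np.max(m)

print(exposure_factor)                    # [0.    0.258 0.774 1.    0.484] (rounded)
print(np.where(exposure_factor == 0)[0])  # indices where the exposure vanishes -> [0]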
81 changes: 52 additions & 29 deletions fancy/interfaces/data.py
@@ -411,13 +411,13 @@ def __init__(self):
self.properties = None
self.source_labels = None

def from_data_file(self, filename, label):
def from_data_file(self, filename, label, exp_factor = 1.):
"""
Define UHECR from data file of original information.
Handles calculation of observation periods and
effective areas assuming the UHECR are detected
by the Pierre Auger Observatory.
by the Pierre Auger Observatory or TA.
:param filename: name of the data file
:param label: reference label for the UHECR data set
@@ -440,7 +440,7 @@ def from_data_file(self, filename, label):

self.unit_vector = coord_to_uv(self.coord)
self.period = self._find_period()
self.A = self._find_area()
self.A = self._find_area(exp_factor)

def _get_properties(self):
"""
@@ -575,7 +575,7 @@ def save(self, file_handle):
for key, value in self.properties.items():
file_handle.create_dataset(key, data=value)

def _find_area(self):
def _find_area(self, exp_factor):
"""
Find the effective area of the observatory at
the time of detection.
@@ -587,9 +587,9 @@ def _find_area(self):
if self.label == 'auger2010':
from ..detector.auger2010 import A1, A2, A3
possible_areas = [A1, A2, A3]
area = [possible_areas[i - 1] for i in self.period]
area = [possible_areas[i - 1]* exp_factor for i in self.period]

if self.label == 'auger2014':
elif self.label == 'auger2014':
from ..detector.auger2014 import A1, A2, A3, A4, A1_incl, A2_incl, A3_incl, A4_incl
possible_areas_vert = [A1, A2, A3, A4]
possible_areas_incl = [A1_incl, A2_incl, A3_incl, A4_incl]
@@ -598,9 +598,14 @@ def _find_period(self):
area = []
for i, p in enumerate(self.period):
if self.zenith_angle[i] <= 60:
area.append(possible_areas_vert[p - 1])
area.append(possible_areas_vert[p - 1]* exp_factor )
if self.zenith_angle[i] > 60:
area.append(possible_areas_incl[p - 1])
area.append(possible_areas_incl[p - 1]* exp_factor )

elif self.label == 'TA2015':
from ..detector.TA2015 import A1, A2
possible_areas = [A1, A2]
area = [possible_areas[i - 1] * exp_factor for i in self.period]

else:
print('Error: effective areas and periods not defined')
@@ -613,28 +618,46 @@ def _find_period(self):
in table 1 in Abreu et al. (2010) or in Collaboration et al. 2014.
"""

from ..detector.auger2014 import (period_1_start, period_1_end,
period_2_start, period_2_end,
period_3_start, period_3_end,
period_4_start, period_4_end)

# check dates
period = []
for y, d in np.nditer([self.year, self.day]):
d = int(d)
test_date = date(y, 1, 1) + timedelta(d)

if period_1_start <= test_date <= period_1_end:
period.append(1)
elif period_2_start <= test_date <= period_2_end:
period.append(2)
elif period_3_start <= test_date <= period_3_end:
period.append(3)
elif test_date >= period_3_end:
period.append(4)
else:
print('Error: cannot determine period for year', y, 'and day',
d)
if self.label == "auger2014":
from ..detector.auger2014 import (period_1_start, period_1_end,
period_2_start, period_2_end,
period_3_start, period_3_end,
period_4_start, period_4_end)

# check dates
for y, d in np.nditer([self.year, self.day]):
d = int(d)
test_date = date(y, 1, 1) + timedelta(d)

if period_1_start <= test_date <= period_1_end:
period.append(1)
elif period_2_start <= test_date <= period_2_end:
period.append(2)
elif period_3_start <= test_date <= period_3_end:
period.append(3)
elif test_date >= period_3_end:
period.append(4)
else:
print('Error: cannot determine period for year', y, 'and day',
d)

elif self.label == "TA2015":
from ..detector.TA2015 import (period_1_start, period_1_end,
period_2_start, period_2_end)
for y, d in np.nditer([self.year, self.day]):
d = int(d)
test_date = date(y, 1, 1) + timedelta(d)

if period_1_start <= test_date <= period_1_end:
period.append(1)
elif period_2_start <= test_date <= period_2_end:
period.append(2)
elif test_date >= period_2_end:
period.append(2)
else:
print('Error: cannot determine period for year', y, 'and day',
d)

return period

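The label-specific branches above repeat the same date arithmetic for each dataset. Below is a compact sketch of the TA2015 branch of _find_period; the two period boundary dates are assumed placeholders, not the values defined in fancy.detector.TA2015.

from datetime import date, timedelta

# Placeholder period boundaries; the actual values live in fancy.detector.TA2015.
period_1_start, period_1_end = date(2008, 5, 11), date(2009, 5, 10)
period_2_start, period_2_end = date(2009, 5, 11), date(2015, 5, 11)


def find_period_ta(year, day_of_year):
    """Assign a TA observation period from an event's year and day of year."""
    test_date = date(year, 1, 1) + timedelta(int(day_of_year))
    if period_1_start <= test_date <= period_1_end:
        return 1
    if test_date >= period_2_start:
        # as in the diff, events beyond period_2_end are still counted in period 2
        return 2
    raise ValueError('cannot determine period for {}'.format(test_date))


print(find_period_ta(2008, 200))  # -> 1
print(find_period_ta(2012, 40))   # -> 2

The matching _find_area branch then takes possible_areas[period - 1] and scales it by the optional exp_factor argument of from_data_file, which defaults to 1.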
(Diffs for the remaining 3 changed files were not loaded.)
