Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add simple test to compute photon energy deposition #794

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/spack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ spack:
- py-sphinxcontrib-bibtex
- "root@6.24: cxxstd=17"
- "swig@4.1:"
- "vecgeom@1.2: +gdml cxxstd=17"
- "vecgeom@1.2.2: +gdml cxxstd=17"
view: true
concretizer:
unify: true
Expand Down
64 changes: 55 additions & 9 deletions test/celeritas/user/CaloTestBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,65 @@ void CaloTestBase::RunResult::print_expected() const
"/*** END CODE ***/\n";
}

//---------------------------------------------------------------------------//
//! Initalize results
void CaloTestBase::initalize()
{
// Initialize RunResult vectors
size_t num_detectors = this->calo_->num_detectors();
result.edep = std::vector<double>(num_detectors, 0.);
result.edep_err = std::vector<double>(num_detectors, 0.);
}

//---------------------------------------------------------------------------//
//! Gather results during a batch
void CaloTestBase::gather_batch_results()
{
// Retrieve energies deposited this batch for each detector
auto edep = calo_->calc_total_energy_deposition();

// Update results for each detector
size_t num_detectors = this->calo_->num_detectors();
for (size_t i_det = 0; i_det < num_detectors; ++i_det)
{
auto edep_det = edep[i_det];
result.edep.at(i_det) += edep_det;
result.edep_err.at(i_det) += (edep_det * edep_det);
}
calo_->clear();
}

//---------------------------------------------------------------------------//
//! Finalize results
void CaloTestBase::finalize()
{
if (num_batches_ <= 1)
return;

// Compute the mean and relative_err over batches for each detector
double norm = 1.0 / double(num_batches_);
size_t num_detectors = this->calo_->num_detectors();
for (size_t i_det = 0; i_det < num_detectors; ++i_det)
{
auto mu = result.edep.at(i_det) * norm;
auto var = result.edep_err.at(i_det) * norm - mu * mu;
CELER_ASSERT(var > 0);
auto err = sqrt(var) / mu;
result.edep.at(i_det) = mu;
result.edep_err.at(i_det) = err;
}
}

//---------------------------------------------------------------------------//
/*!
* Run a number of tracks.
*/
template<MemSpace M>
auto CaloTestBase::run(size_type num_tracks, size_type num_steps) -> RunResult
auto CaloTestBase::run(size_type num_tracks,
size_type num_steps,
size_type num_batches) -> RunResult
{
this->run_impl<M>(num_tracks, num_steps);

RunResult result;
result.edep = calo_->calc_total_energy_deposition();
calo_->clear();

this->run_impl<M>(num_tracks, num_steps, num_batches);
return result;
}

Expand All @@ -86,9 +132,9 @@ std::string CaloTestBase::output() const

//---------------------------------------------------------------------------//
template CaloTestBase::RunResult
CaloTestBase::run<MemSpace::device>(size_type, size_type);
CaloTestBase::run<MemSpace::device>(size_type, size_type, size_type);
template CaloTestBase::RunResult
CaloTestBase::run<MemSpace::host>(size_type, size_type);
CaloTestBase::run<MemSpace::host>(size_type, size_type, size_type);

//---------------------------------------------------------------------------//
} // namespace test
Expand Down
11 changes: 10 additions & 1 deletion test/celeritas/user/CaloTestBase.hh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class CaloTestBase : virtual public StepCollectorTestBase
struct RunResult
{
std::vector<double> edep;
std::vector<double> edep_err;

void print_expected() const;
};
Expand All @@ -48,15 +49,23 @@ class CaloTestBase : virtual public StepCollectorTestBase
virtual VecString get_detector_names() const = 0;

template<MemSpace M>
RunResult run(size_type num_tracks, size_type num_steps);
RunResult
run(size_type num_tracks, size_type num_steps, size_type num_batches = 1);

// Get JSON output from the simple calo interface
std::string output() const;

protected:
virtual void gather_batch_results();
virtual void initalize();
virtual void finalize();

std::shared_ptr<SimpleCalo> calo_;
std::shared_ptr<StepCollector> collector_;
std::shared_ptr<OutputRegistry> output_;

private:
RunResult result;
};

//---------------------------------------------------------------------------//
Expand Down
10 changes: 6 additions & 4 deletions test/celeritas/user/DiagnosticTestBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,13 @@ auto DiagnosticTestBase::run(size_type num_tracks, size_type num_steps)
-> RunResult
{
this->run_impl<M>(num_tracks, num_steps);
return result;
}

RunResult result;

//---------------------------------------------------------------------------//
//! Initalize results
void DiagnosticTestBase::gather_batch_results()
{
// Save action diagnostic results
for (auto const& [label, count] : action_diagnostic_->calc_actions_map())
{
Expand All @@ -108,8 +112,6 @@ auto DiagnosticTestBase::run(size_type num_tracks, size_type num_steps)
{
result.steps.insert(result.steps.end(), vec.begin(), vec.end());
}

return result;
}

//---------------------------------------------------------------------------//
Expand Down
5 changes: 5 additions & 0 deletions test/celeritas/user/DiagnosticTestBase.hh
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ class DiagnosticTestBase : virtual public StepCollectorTestBase
void print_expected() const;

protected:
virtual void gather_batch_results();

std::shared_ptr<ActionDiagnostic> action_diagnostic_;
std::shared_ptr<StepDiagnostic> step_diagnostic_;

private:
RunResult result;
};

//---------------------------------------------------------------------------//
Expand Down
6 changes: 4 additions & 2 deletions test/celeritas/user/MctruthTestBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@ auto MctruthTestBase::run(size_type num_tracks, size_type num_steps)
-> RunResult
{
this->run_impl<MemSpace::host>(num_tracks, num_steps);
return result;
}

void MctruthTestBase::gather_batch_results()
{
example_mctruth_->sort();

RunResult result;
for (ExampleMctruth::Step const& s : example_mctruth_->steps())
{
result.event.push_back(s.event);
Expand All @@ -94,7 +97,6 @@ auto MctruthTestBase::run(size_type num_tracks, size_type num_steps)
result.pos.insert(result.pos.end(), std::begin(s.pos), std::end(s.pos));
result.dir.insert(result.dir.end(), std::begin(s.dir), std::end(s.dir));
}
return result;
}

//---------------------------------------------------------------------------//
Expand Down
5 changes: 5 additions & 0 deletions test/celeritas/user/MctruthTestBase.hh
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,13 @@ class MctruthTestBase : virtual public StepCollectorTestBase
RunResult run(size_type num_tracks, size_type num_steps);

protected:
virtual void gather_batch_results();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
virtual void gather_batch_results();
void gather_batch_results() override;


std::shared_ptr<ExampleMctruth> example_mctruth_;
std::shared_ptr<StepCollector> collector_;

private:
RunResult result;
};

//---------------------------------------------------------------------------//
Expand Down
61 changes: 61 additions & 0 deletions test/celeritas/user/StepCollector.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,31 @@ class TestEm3CollectorTestBase : public TestEm3Base,
}
};

class TestPhotonCollectorTestBase : public TestEm3CollectorTestBase
{
VecPrimary make_primaries(size_type count) override
{
// Get photo id
auto photon = this->particle()->find(pdg::gamma());
CELER_ASSERT(photon);

Primary p;
p.energy = MevEnergy{10.0};
p.position = {-22, 0, 0};
p.direction = {1, 0, 0};
p.time = 0;
p.particle_id = photon;
std::vector<Primary> result(count, p);

for (auto i : range(count))
{
result[i].event_id = EventId{0};
result[i].track_id = TrackId{i};
}
return result;
}
};

#define TestEm3MctruthTest TEST_IF_CELERITAS_GEANT(TestEm3MctruthTest)
class TestEm3MctruthTest : public TestEm3CollectorTestBase,
public MctruthTestBase
Expand All @@ -129,6 +154,16 @@ class TestEm3CaloTest : public TestEm3CollectorTestBase, public CaloTestBase
}
};

#define TestPhotonCaloTest TEST_IF_CELERITAS_GEANT(TestPhotonCaloTest)
class TestPhotonCaloTest : public TestPhotonCollectorTestBase,
public CaloTestBase
{
VecString get_detector_names() const final
{
return {"gap_lv_0", "gap_lv_1", "gap_lv_2"};
}
};

//---------------------------------------------------------------------------//
// ERROR CHECKING
//---------------------------------------------------------------------------//
Expand Down Expand Up @@ -313,6 +348,32 @@ TEST_F(TestEm3CaloTest, TEST_IF_CELER_DEVICE(step_device))
EXPECT_VEC_NEAR(expected_edep, result.edep, 0.5);
}

TEST_F(TestPhotonCaloTest, sixteen_batches)
{
auto result = this->run<MemSpace::host>(16, 32, 16);

static double const expected_edep[]
= {9.0653751813736, 17.177626720468, 12.691359768897};
static double const expected_edep_err[]
= {0.64823529758419, 0.42812745087497, 0.63485083267392};

EXPECT_VEC_NEAR(expected_edep, result.edep, 0.5);
EXPECT_VEC_NEAR(expected_edep_err, result.edep_err, 0.5);
}

TEST_F(TestPhotonCaloTest, TEST_IF_CELER_DEVICE(step_device))
{
auto result = this->run<MemSpace::device>(16, 32, 16);

static double const expected_edep[]
= {9.0653751813736, 17.177626720468, 12.691359768897};
static double const expected_edep_err[]
= {0.64823529758419, 0.42812745087497, 0.63485083267392};

EXPECT_VEC_NEAR(expected_edep, result.edep, 0.5);
EXPECT_VEC_NEAR(expected_edep_err, result.edep_err, 0.5);
}

//---------------------------------------------------------------------------//
} // namespace test
} // namespace celeritas
48 changes: 36 additions & 12 deletions test/celeritas/user/StepCollectorTestBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,55 @@ namespace test
* Run a stepping loop with the core data.
*/
template<MemSpace M>
void StepCollectorTestBase::run_impl(size_type num_tracks, size_type num_steps)
void StepCollectorTestBase::run_impl(size_type num_tracks_per_batch,
size_type num_steps,
size_type num_batches)
{
// Save number of batches internally for normalization
num_batches_ = num_batches;

// Compute total tracks
size_type num_tracks = num_tracks_per_batch * num_batches;

// Initialize stepper
StepperInput step_inp;
step_inp.params = this->core();
step_inp.stream_id = StreamId{0};
step_inp.num_track_slots = num_tracks;

Stepper<M> step(step_inp);
LogContextException log_context{this->output_reg().get()};

// Initial step
auto primaries = this->make_primaries(num_tracks);
StepperResult count;
CELER_TRY_HANDLE(count = step(make_span(primaries)), log_context);
// Initalize results
this->initalize();

while (count && --num_steps > 0)
// Loop over batches
for (size_type i_batch = 0; i_batch < num_batches; ++i_batch)
{
CELER_TRY_HANDLE(count = step(), log_context);
// Initialize primaries for this batch
auto primaries = this->make_primaries(num_tracks_per_batch);

// Take num_steps steps
StepperResult count;
CELER_TRY_HANDLE(count = step(make_span(primaries)), log_context);
while (count && --num_steps > 0)
{
CELER_TRY_HANDLE(count = step(), log_context);
}

// Gathering of results
this->gather_batch_results();
}

// Post-processing (e.g. normalization) of results
this->finalize();
}

template void
StepCollectorTestBase::run_impl<MemSpace::host>(size_type, size_type);
template void
StepCollectorTestBase::run_impl<MemSpace::device>(size_type, size_type);
template void StepCollectorTestBase::run_impl<MemSpace::host>(size_type,
size_type,
size_type);
template void StepCollectorTestBase::run_impl<MemSpace::device>(size_type,
size_type,
size_type);

//---------------------------------------------------------------------------//
} // namespace test
Expand Down
10 changes: 9 additions & 1 deletion test/celeritas/user/StepCollectorTestBase.hh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ class StepCollectorTestBase : virtual public GlobalTestBase

protected:
template<MemSpace M>
void run_impl(size_type num_tracks, size_type num_steps);
void run_impl(size_type num_tracks_per_batch,
size_type num_steps,
size_type num_batches = 1);

virtual void gather_batch_results(){};
virtual void initalize(){};
virtual void finalize(){};

size_t num_batches_;
};

//---------------------------------------------------------------------------//
Expand Down