Skip to content

Commit

Permalink
Add additional options helpful for LBA calibration (#111)
Browse files Browse the repository at this point in the history
* Add constraint that prevents large amplitude ratios between antennas
* Add some checks to avoid NaNs, etc.
* Add option to flag unconverged solutions
* Add option (propagateconvergedonly) to propagate converged solutions only
* Add option to flag only unconverged solutions that diverged
* Improve handling of diverged solutions
* Set diverged flags for all stations instead of station-by-station
* Add back reporting of new options
  • Loading branch information
darafferty authored and aroffringa committed Feb 5, 2019
1 parent 2f47827 commit 99c15d1
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 46 deletions.
20 changes: 10 additions & 10 deletions DDECal/Constraint.h
Expand Up @@ -11,7 +11,7 @@
* This class is the base class for classes that implement a constraint on
* calibration solutions. Constraints are used to increase
* the converge of calibration by applying these inside the solving step.
*
*
* The MultiDirSolver class uses this class for constrained calibration.
*/
class Constraint
Expand All @@ -32,16 +32,16 @@ class Constraint
_nAntennas(0), _nDirections(0), _nChannelBlocks(0),
_nThreads(0)
{ }

virtual ~Constraint() { }

/**
* Function that initializes the constraint for the next calibration iteration.
* It should be called each time all antenna solutions have been calculated,
* but before the constraint has been applied to all those antenna solutions.
*
*
* Unlike Apply(), this method is not thread safe.
*
*
* @param bool This can be used to specify whether the previous solution "step" is
* smaller than the requested precision, i.e. calibration with the constrained
* has converged. This allows a constraint to apply
Expand All @@ -50,7 +50,7 @@ class Constraint
* when hasReachedPrecision=true.
*/
virtual void PrepareIteration(bool /*hasReachedPrecision*/, size_t /*iteration*/, bool /*finalIter*/) { }

/**
* Whether the constraint has been satisfied. The calibration process will continue
* at least as long as Satisfied()=false, and performs at least one more iteration
Expand All @@ -59,7 +59,7 @@ class Constraint
* convergence.
*/
virtual bool Satisfied() const { return true; }

/**
* This method applies the constraints to the solutions.
* @param solutions is an array of array, such that:
Expand Down Expand Up @@ -133,7 +133,7 @@ class DiagonalConstraint : public Constraint
{
public:
DiagonalConstraint(size_t polsPerSolution) : _polsPerSolution(polsPerSolution) {};

virtual std::vector<Result> Apply(
std::vector<std::vector<dcomplex> >& solutions,
double time,
Expand All @@ -145,7 +145,7 @@ class DiagonalConstraint : public Constraint
/**
* This constraint averages the solutions of a given list of antennas,
* so that they have equal solutions.
*
*
* The DDE solver uses this constraint to average the solutions of the core
* antennas. Core antennas are determined by a given maximum distance from
* a reference antenna. The reference antenna is by default the first
Expand All @@ -161,12 +161,12 @@ class CoreConstraint : public Constraint
{
_coreAntennas = coreAntennas;
}

virtual std::vector<Result> Apply(
std::vector<std::vector<dcomplex> >& solutions,
double time,
std::ostream* statStream) final override;

private:
std::set<size_t> _coreAntennas;
};
Expand Down
127 changes: 97 additions & 30 deletions DDECal/DDECal.cc
Expand Up @@ -90,7 +90,13 @@ DDECal::DDECal (DPInput* input,
"/instrument.h5")),
itsH5Parm (itsH5ParmName, true),
itsPropagateSolutions (parset.getBool (prefix + "propagatesolutions",
false)),
false)),
itsPropagateConvergedOnly (parset.getBool (prefix + "propagateconvergedonly",
false)),
itsFlagUnconverged (parset.getBool (prefix + "flagunconverged",
false)),
itsFlagDivergedOnly (parset.getBool (prefix + "flagdivergedonly",
false)),
itsTimeStep (0),
itsSolInt (parset.getInt (prefix + "solint", 1)),
itsMinVisRatio (parset.getDouble (prefix + "minvisratio", 0.0)),
Expand Down Expand Up @@ -118,7 +124,7 @@ DDECal::DDECal (DPInput* input,

if(!itsStatFilename.empty())
itsStatStream.reset(new std::ofstream(itsStatFilename));

vector<string> strDirections;
if (itsUseModelColumn) {
itsModelData.resize(itsSolInt);
Expand Down Expand Up @@ -148,7 +154,7 @@ DDECal::DDECal (DPInput* input,
itsMode = GainCal::stringToCalType(
boost::to_lower_copy(parset.getString(prefix + "mode",
"complexgain")));

initializeConstraints(parset, prefix);
initializePredictSteps(parset, prefix);
}
Expand All @@ -169,7 +175,7 @@ void DDECal::initializeConstraints(const ParameterSet& parset, const string& pre
itsConstraints.emplace_back(new CoreConstraint());
}
if(itsSmoothnessConstraint != 0.0) {
itsConstraints.emplace_back(new SmoothnessConstraint(itsSmoothnessConstraint));
itsConstraints.emplace_back(new SmoothnessConstraint(itsSmoothnessConstraint));
}
switch(itsMode) {
case GainCal::DIAGONAL:
Expand Down Expand Up @@ -253,7 +259,7 @@ void DDECal::initializeConstraints(const ParameterSet& parset, const string& pre
itsFullMatrixMinimalization = true;
break;
default:
throw std::runtime_error("Unexpected mode: " +
throw std::runtime_error("Unexpected mode: " +
GainCal::calTypeToString(itsMode));
}
}
Expand Down Expand Up @@ -413,7 +419,7 @@ void DDECal::updateInfo (const DPInfo& infoIn)
}
coreConstraint->initialize(coreAntennaIndices);
}

#ifdef HAVE_ARMADILLO
ScreenConstraint* screenConstraint = dynamic_cast<ScreenConstraint*>(itsConstraints[i].get());
if(screenConstraint != 0)
Expand Down Expand Up @@ -445,7 +451,7 @@ void DDECal::updateInfo (const DPInfo& infoIn)
screenConstraint->setOtherAntennas(otherAntennaIndices);
}
#endif

TECConstraintBase* tecConstraint = dynamic_cast<TECConstraintBase*>(itsConstraints[i].get());
if(tecConstraint != nullptr)
{
Expand Down Expand Up @@ -484,7 +490,10 @@ void DDECal::show (std::ostream& os) const
os
<< " tolerance: " << itsMultiDirSolver.get_accuracy() << '\n'
<< " max iter: " << itsMultiDirSolver.max_iterations() << '\n'
<< " propagatesolutions: " << std::boolalpha << itsPropagateSolutions << '\n'
<< " flag unconverged: " << std::boolalpha << itsFlagUnconverged << '\n'
<< " diverged only: " << std::boolalpha << itsFlagDivergedOnly << '\n'
<< " propagate solutions: " << std::boolalpha << itsPropagateSolutions << '\n'
<< " converged only: " << std::boolalpha << itsPropagateConvergedOnly << '\n'
<< " detect stalling: " << std::boolalpha << itsMultiDirSolver.get_detect_stalling() << '\n'
<< " step size: " << itsMultiDirSolver.get_step_size() << '\n'
<< " mode (constraints): " << GainCal::calTypeToString(itsMode) << '\n'
Expand Down Expand Up @@ -534,8 +543,16 @@ void DDECal::showTimings (std::ostream& os, double duration) const

void DDECal::initializeScalarSolutions() {
if (itsTimeStep/itsSolInt>0 && itsPropagateSolutions) {
// initialize solutions with those of the previous step
itsSols[itsTimeStep/itsSolInt] = itsSols[itsTimeStep/itsSolInt-1];
if (itsNIter[itsTimeStep/itsSolInt-1]>itsMultiDirSolver.max_iterations() && itsPropagateConvergedOnly) {
// initialize solutions with 1.
size_t n = itsDirections.size()*info().antennaNames().size();
for (vector<DComplex>& solvec : itsSols[itsTimeStep/itsSolInt]) {
solvec.assign(n, 1.0);
}
} else {
// initialize solutions with those of the previous step
itsSols[itsTimeStep/itsSolInt] = itsSols[itsTimeStep/itsSolInt-1];
}
} else {
// initialize solutions with 1.
size_t n = itsDirections.size()*info().antennaNames().size();
Expand All @@ -547,8 +564,23 @@ void DDECal::initializeScalarSolutions() {

void DDECal::initializeFullMatrixSolutions() {
if (itsTimeStep/itsSolInt>0 && itsPropagateSolutions) {
// initialize solutions with those of the previous step
itsSols[itsTimeStep/itsSolInt] = itsSols[itsTimeStep/itsSolInt-1];
if (itsNIter[itsTimeStep/itsSolInt-1]>itsMultiDirSolver.max_iterations() && itsPropagateConvergedOnly) {
// initialize solutions with unity matrix [1 0 ; 0 1].
size_t n = itsDirections.size()*info().antennaNames().size();
for (vector<DComplex>& solvec : itsSols[itsTimeStep/itsSolInt]) {
solvec.resize(n*4);
for(size_t i=0; i!=n; ++i)
{
solvec[i*4 + 0] = 1.0;
solvec[i*4 + 1] = 0.0;
solvec[i*4 + 2] = 0.0;
solvec[i*4 + 3] = 1.0;
}
}
} else {
// initialize solutions with those of the previous step
itsSols[itsTimeStep/itsSolInt] = itsSols[itsTimeStep/itsSolInt-1];
}
} else {
// initialize solutions with unity matrix [1 0 ; 0 1].
size_t n = itsDirections.size()*info().antennaNames().size();
Expand Down Expand Up @@ -592,7 +624,7 @@ void DDECal::flagChannelBlock(size_t cbIndex)
// Set the antenna-based weights to zero
for(size_t bl=0; bl<nBl; ++bl)
{
size_t
size_t
ant1 = info().getAnt1()[bl],
ant2 = info().getAnt2()[bl];
for(size_t ch=itsChanBlockStart[cbIndex]; ch!=itsChanBlockStart[cbIndex+1]; ++ch)
Expand Down Expand Up @@ -628,11 +660,11 @@ void DDECal::checkMinimumVisibilities()
void DDECal::doSolve ()
{
checkMinimumVisibilities();

for (std::unique_ptr<Constraint>& constraint : itsConstraints) {
constraint->SetWeights(itsWeightsPerAntenna);
}

if(itsFullMatrixMinimalization)
initializeFullMatrixSolutions();
else
Expand All @@ -657,6 +689,41 @@ void DDECal::doSolve ()
itsNIter[itsTimeStep/itsSolInt] = solveResult.iterations;
itsNApproxIter[itsTimeStep/itsSolInt] = solveResult.constraintIterations;

// Check for nonconvergence and flag if desired. Unconverged solutions are
// identified by the number of iterations being one more than the max allowed
// number
if (solveResult.iterations > itsMultiDirSolver.max_iterations() && itsFlagUnconverged) {
for (size_t i=0; i!=solveResult._results.size(); ++i) {
for (size_t j=0; j!=solveResult._results[i].size(); ++j) {
if (itsFlagDivergedOnly) {
// Set weights with negative values (indicating unconverged
// solutions that diverged) to zero (all other unconverged
// solutions remain unflagged)
for (size_t k=0; k!=solveResult._results[i][j].weights.size(); ++k) {
if (solveResult._results[i][j].weights[k] < 0.) {
solveResult._results[i][j].weights[k] = 0.;
}
}
} else {
// Set all weights to zero
solveResult._results[i][j].weights.assign(solveResult._results[i][j].weights.size(), 0.);
}
}
}
} else {
// Set any negative weights (indicating unconverged solutions that diverged) to
// one (all other unconverged solutions are unflagged already)
for (size_t i=0; i!=solveResult._results.size(); ++i) {
for (size_t j=0; j!=solveResult._results[i].size(); ++j) {
for (size_t k=0; k!=solveResult._results[i][j].weights.size(); ++k) {
if (solveResult._results[i][j].weights[k] < 0.) {
solveResult._results[i][j].weights[k] = 1.;
}
}
}
}
}

// Store constraint solutions if any constaint has a non-empty result
bool someConstraintHasResult = false;
for (uint constraintnum=0; constraintnum<solveResult._results.size(); ++constraintnum) {
Expand All @@ -668,7 +735,7 @@ void DDECal::doSolve ()
if (someConstraintHasResult) {
itsConstraintSols[itsTimeStep/itsSolInt]=solveResult._results;
}

itsTimer.stop();

for(size_t time=0; time<=itsStepInSolInt; ++time)
Expand All @@ -679,7 +746,7 @@ void DDECal::doSolve ()
// Push data (possibly changed) to next step
getNextStep()->process(itsBufs[time]);
}

itsTimer.start();
}

Expand All @@ -695,10 +762,10 @@ bool DDECal::process (const DPBuffer& bufin)
itsBufs[itsStepInSolInt].copy(bufin);
itsOriginalFlags[itsStepInSolInt].assign( bufin.getFlags() );
itsOriginalWeights[itsStepInSolInt].assign( bufin.getWeights() );

itsDataPtrs[itsStepInSolInt] = itsBufs[itsStepInSolInt].getData().data();
itsWeightPtrs[itsStepInSolInt] = itsBufs[itsStepInSolInt].getWeights().data();

// UVW flagging happens on the copy of the buffer
// These flags are later restored and therefore not written
itsUVWFlagStep.process(itsBufs[itsStepInSolInt]);
Expand All @@ -715,7 +782,7 @@ bool DDECal::process (const DPBuffer& bufin)
std::mutex measuresMutex;
for(DP3::DPPP::Predict& predict : itsPredictSteps)
predict.setThreadData(*itsThreadPool, measuresMutex);

itsThreadPool->For(0, itsPredictSteps.size(), [&](size_t dir, size_t /*thread*/) {
itsPredictSteps[dir].process(itsBufs[itsStepInSolInt]);
itsModelDataPtrs[itsStepInSolInt][dir] =
Expand All @@ -727,13 +794,13 @@ bool DDECal::process (const DPBuffer& bufin)
const size_t nBl = info().nbaselines();
const size_t nCh = info().nchan();
const size_t nCr = 4;

size_t nchanblocks = itsChanBlockFreqs.size();

double weightFactor = 1./(nCh*(info().nantenna()-1)*nCr*itsSolInt);

for (size_t bl=0; bl<nBl; ++bl) {
size_t
size_t
chanblock = 0,
ant1 = info().getAnt1()[bl],
ant2 = info().getAnt2()[bl];
Expand Down Expand Up @@ -908,11 +975,11 @@ void DDECal::writeSolutions()
soltab = itsH5Parm.createSolTab(solTabName, "amplitude", axes);
soltab.setComplexValues(sols, vector<double>(), true, historyString);
break;
default:
default:
throw std::runtime_error("Constraint should have produced output");
}

// Tell H5Parm that all antennas and directions were used
// Tell H5Parm that all antennas and directions were used
std::vector<std::string> antennaNames(info().antennaNames().size());
for (uint i=0; i<info().antennaNames().size(); ++i) {
antennaNames[i]=info().antennaNames()[i];
Expand Down Expand Up @@ -988,13 +1055,13 @@ void DDECal::writeSolutions()
"step " + itsName + " in parset: \n" +
itsParsetString);

// Tell H5Parm that all antennas and directions were used
// Tell H5Parm that all antennas and directions were used
std::vector<std::string> antennaNames(info().antennaNames().size());
for (uint i=0; i<info().antennaNames().size(); ++i) {
antennaNames[i]=info().antennaNames()[i];
}
soltab.setAntennas(antennaNames);

soltab.setSources(getDirectionNames());

if (soltab.hasAxis("pol")) {
Expand All @@ -1011,7 +1078,7 @@ void DDECal::writeSolutions()
default:
throw std::runtime_error("No metadata for numpolarizations = " + std::to_string(soltab.getAxis("pol").size));
}

soltab.setPolarizations(polarizations);
}

Expand Down Expand Up @@ -1082,11 +1149,11 @@ void DDECal::subtractCorrectedModel(bool fullJones)
std::vector<std::complex<float>*>& modelData = itsModelDataPtrs[time];
for (size_t bl=0; bl<nBl; ++bl)
{
size_t
size_t
chanblock = 0,
ant1 = info().getAnt1()[bl],
ant2 = info().getAnt2()[bl];

for (size_t ch=0; ch<nCh; ++ch)
{
MC2x2 value(MC2x2::Zero());
Expand Down Expand Up @@ -1118,5 +1185,5 @@ void DDECal::subtractCorrectedModel(bool fullJones)
} //bl loop
} //time loop
}

} } //# end namespace

0 comments on commit 99c15d1

Please sign in to comment.