Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redo the ImageSetData class to use less memory #438

Merged
merged 14 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Asmit Bhowmick
Ben Williams
Billy K. Poon
Clemens Weninger
Daniel Paley
David Waterman
Derek Mendez
Dorothee Liebschner
Expand Down
1 change: 1 addition & 0 deletions newsfragments/438.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce, in some cases drastically, memory usage of ``ImageSet`` objects.
1 change: 1 addition & 0 deletions src/dxtbx/boost_python/imageset_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ namespace dxtbx { namespace boost_python {
.def("set_params", &ImageSetData_set_params)
.def("get_format_class", &ImageSetData_get_format)
.def("set_format_class", &ImageSetData_set_format)
.def("partial_data", &ImageSetData::partial_data)
.add_property(
"external_lookup",
make_function(&ImageSetData::external_lookup, return_internal_reference<>()))
Expand Down
2 changes: 1 addition & 1 deletion src/dxtbx/format/Format.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __len__(self):
return len(self._filenames)

def copy(self, filenames):
return Reader(self.format_class, filenames)
return Reader(self.format_class, filenames, **self._kwargs)

def is_single_file_reader(self):
return False
Expand Down
8 changes: 4 additions & 4 deletions src/dxtbx/format/FormatMultiImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def paths(self):
def __len__(self):
return self._num_images

def copy(self, filenames, indices=None):
return Reader(self.format_class, filenames, indices)
def copy(self, filenames, num_images=None):
return Reader(self.format_class, filenames, num_images, **self.kwargs)

def identifiers(self):
return ["%s-%d" % (self._filename, index) for index in range(len(self))]
Expand Down Expand Up @@ -146,8 +146,8 @@ def get_imageset(
# If get_num_images hasn't been implemented, we need indices for number of images
if cls.get_num_images == FormatMultiImage.get_num_images:
assert single_file_indices is not None
assert min(single_file_indices) >= 0
num_images = max(single_file_indices) + 1
assert len(single_file_indices) > 0
num_images = len(single_file_indices)
else:
num_images = None

Expand Down
65 changes: 45 additions & 20 deletions src/dxtbx/imageset.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,25 @@ class ImageSetData {
format_ = x;
}

ImageSetData partial_data(boost::python::object reader, std::size_t first, std::size_t last) const {
DXTBX_ASSERT(last > first);
ImageSetData partial = ImageSetData(reader, masker_);
for (size_t i = 0; i < last-first; i++) {
partial.beams_[i] = beams_[i+first];
partial.detectors_[i] = detectors_[i+first];
partial.goniometers_[i] = goniometers_[i+first];
partial.scans_[i] = scans_[i+first];
partial.reject_[i] = reject_[i+first];
}
partial.external_lookup_ = external_lookup_;
partial.template_ = template_;
partial.vendor_ = vendor_;
partial.params_ = params_;
partial.format_ = format_;

return partial;
}

protected:
ImageBuffer get_image_buffer_from_tuple(boost::python::tuple obj) {
// Get the class name
Expand Down Expand Up @@ -581,7 +600,7 @@ class ImageSet {
}
} else {
// Check indices
if (scitbx::af::max(indices) >= data.size()) {
if (indices.size() > data.size()) {
throw DXTBX_ERROR("Indices are not consistent with # of images");
}
}
Expand Down Expand Up @@ -903,7 +922,7 @@ class ImageSet {
*/
virtual beam_ptr get_beam_for_image(std::size_t index = 0) const {
DXTBX_ASSERT(index < indices_.size());
return data_.get_beam(indices_[index]);
return data_.get_beam(index);
}

/**
Expand All @@ -912,7 +931,7 @@ class ImageSet {
*/
virtual detector_ptr get_detector_for_image(std::size_t index = 0) const {
DXTBX_ASSERT(index < indices_.size());
return data_.get_detector(indices_[index]);
return data_.get_detector(index);
}

/**
Expand All @@ -921,7 +940,7 @@ class ImageSet {
*/
virtual goniometer_ptr get_goniometer_for_image(std::size_t index = 0) const {
DXTBX_ASSERT(index < indices_.size());
return data_.get_goniometer(indices_[index]);
return data_.get_goniometer(index);
}

/**
Expand All @@ -930,7 +949,7 @@ class ImageSet {
*/
virtual scan_ptr get_scan_for_image(std::size_t index = 0) const {
DXTBX_ASSERT(index < indices_.size());
return data_.get_scan(indices_[index]);
return data_.get_scan(index);
}

/**
Expand All @@ -940,7 +959,7 @@ class ImageSet {
*/
virtual void set_beam_for_image(const beam_ptr &beam, std::size_t index = 0) {
DXTBX_ASSERT(index < indices_.size());
data_.set_beam(beam, indices_[index]);
data_.set_beam(beam, index);
}

/**
Expand All @@ -951,7 +970,7 @@ class ImageSet {
virtual void set_detector_for_image(const detector_ptr &detector,
std::size_t index = 0) {
DXTBX_ASSERT(index < indices_.size());
data_.set_detector(detector, indices_[index]);
data_.set_detector(detector, index);
}

/**
Expand All @@ -962,7 +981,7 @@ class ImageSet {
virtual void set_goniometer_for_image(const goniometer_ptr &goniometer,
std::size_t index = 0) {
DXTBX_ASSERT(index < indices_.size());
data_.set_goniometer(goniometer, indices_[index]);
data_.set_goniometer(goniometer, index);
}

/**
Expand All @@ -973,7 +992,7 @@ class ImageSet {
virtual void set_scan_for_image(const scan_ptr &scan, std::size_t index = 0) {
DXTBX_ASSERT(scan == NULL || scan->get_num_images() == 1);
DXTBX_ASSERT(index < indices_.size());
data_.set_scan(scan, indices_[index]);
data_.set_scan(scan, index);
}

/**
Expand All @@ -996,7 +1015,7 @@ class ImageSet {
*/
std::string get_image_identifier(std::size_t index) const {
DXTBX_ASSERT(index < indices_.size());
return data_.get_image_identifier(indices_[index]);
return data_.get_image_identifier(index);
}

/**
Expand All @@ -1006,7 +1025,7 @@ class ImageSet {
*/
void mark_for_rejection(std::size_t index, bool reject) {
DXTBX_ASSERT(index < indices_.size());
data_.mark_for_rejection(indices_[index], reject);
data_.mark_for_rejection(index, reject);
}

/**
Expand All @@ -1015,7 +1034,7 @@ class ImageSet {
*/
bool is_marked_for_rejection(std::size_t index) const {
DXTBX_ASSERT(index < indices_.size());
return data_.is_marked_for_rejection(indices_[index]);
return data_.is_marked_for_rejection(index);
}

/**
Expand All @@ -1037,9 +1056,9 @@ class ImageSet {
* @param last The last slice index
* @returns The partial set
*/
virtual ImageSet partial_set(std::size_t first, std::size_t last) const {
virtual ImageSet partial_set(boost::python::object reader, std::size_t first, std::size_t last) const {
DXTBX_ASSERT(last > first);
return ImageSet(data_,
return ImageSet(data_.partial_data(reader, first, last),
scitbx::af::const_ref<std::size_t>(&indices_[first], last - first));
}

Expand Down Expand Up @@ -1476,21 +1495,27 @@ class ImageSequence : public ImageSet {
* @param last The last index
* @returns The partial sequence
*/
ImageSequence partial_sequence(std::size_t first, std::size_t last) const {
ImageSequence partial_sequence(boost::python::object reader, std::size_t first, std::size_t last) const {
// Check slice indices
DXTBX_ASSERT(last > first);

// Construct a partial scan
Scan scan = detail::safe_dereference(ImageSet::get_scan_for_image(first));
for (std::size_t i = first + 1; i < last; ++i) {
scan += detail::safe_dereference(ImageSet::get_scan_for_image(i));
// Construct a partial data
ImageSetData _partial_data = data_.partial_data(reader, first, last);


// Now we use the partial data to construct the partial scan
Scan scan = detail::safe_dereference(_partial_data.get_scan(0));
for (std::size_t i=1; i<last-first; ++i) {
scan_ptr temp_scan_ptr = _partial_data.get_scan(i);
Scan temp_scan = detail::safe_dereference(temp_scan_ptr);
scan += temp_scan;
}

// Construct the partial indices
scitbx::af::const_ref<std::size_t> indices(&indices_[first], last - first);

// Construct the partial sequence
ImageSequence result(data_,
ImageSequence result(_partial_data,
indices,
get_beam(),
get_detector(),
Expand Down
47 changes: 43 additions & 4 deletions src/dxtbx/imageset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class MemReader:
def __init__(self, images):
self._images = images

def copy(self, paths):
"""
Experimental implementation where a copy of the reader also copies all
the data
"""
return MemReader(self._images)

def paths(self):
return ["" for im in self._images]

Expand Down Expand Up @@ -96,7 +103,11 @@ def __getitem__(self, item):
stop = item.stop or len(self)
if item.step is not None and item.step != 1:
raise IndexError("Step must be 1")
return self.partial_set(start, stop)
if self.data().has_single_file_reader():
reader = self.reader().copy(self.reader().paths(), stop - start)
else:
reader = self.reader().copy(self.reader().paths())
return self.partial_set(reader, start, stop)
else:
return self.get_corrected_data(item)

Expand Down Expand Up @@ -162,7 +173,10 @@ def paths(self):
"""
Return the list of paths
"""
return [self.get_path(i) for i in range(len(self))]
if self.data().has_single_file_reader():
return [self.get_path(i) for i in range(len(self))]
else:
return [self.reader().paths()[i] for i in self.indices()]


class ImageSetLazy(ImageSet):
Expand Down Expand Up @@ -203,6 +217,15 @@ def get_detector(self, index=None):
def get_beam(self, index=None):
return self._get_item_from_parent_or_format("beam", index)

def get_mask(self, index=None):
"""
ImageSet::get_mask internally dereferences a pointer to the _detector
member of ImageSetData, so we ensure the detector gets populated first.
"""
if getattr(super(), "get_detector")(index) is None:
self._load_models(index)
return self._get_item_from_parent_or_format("mask", index)

def get_goniometer(self, index=None):
return self._get_item_from_parent_or_format("goniometer", index)

Expand All @@ -220,7 +243,18 @@ def _load_models(self, index):

def __getitem__(self, item):
if isinstance(item, slice):
return ImageSetLazy(self.data(), indices=self.indices()[item])
start = item.start or 0
stop = item.stop or len(self)
if item.step is not None and item.step != 1:
raise IndexError("Step must be 1")
if self.data().has_single_file_reader():
reader = self.reader().copy(self.reader().paths(), stop - start)
else:
reader = self.reader().copy(self.reader().paths())
return ImageSetLazy(
self.data().partial_data(reader, start, stop),
indices=self.indices()[item],
)
self._load_models(item)
return super().__getitem__(item)

Expand Down Expand Up @@ -268,11 +302,16 @@ def __getitem__(self, item):
stop = len(self)
else:
stop -= offset

return self.partial_set(start, stop)
else:
start = item.start or 0
stop = item.stop or (len(self) + offset)
return self.partial_set(start - offset, stop - offset)
if self.data().has_single_file_reader():
reader = self.reader().copy(self.reader().paths(), stop - start)
else:
reader = self.reader().copy(self.reader().paths())
return self.partial_set(reader, start - offset, stop - offset)
else:
return self.get_corrected_data(item)

Expand Down