Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-39146: Memory efficiency for rbTransiNetTask/Interface #17

Merged
merged 4 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 40 additions & 6 deletions python/lsst/meas/transiNet/rbTransiNetInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ def init_model(self):
# Put the model in evaluation mode instead of training model.
self.model.eval()

def input_to_batches(self, inputs, batchSize):
"""Convert a list of inputs to a generator of batches.

Parameters
----------
inputs : `list` [`CutoutInputs`]
Inputs to be scored.

Returns
-------
batches : `generator`
Generator of batches of inputs.
"""
for i in range(0, len(inputs), batchSize):
yield inputs[i:i + batchSize]

def prepare_input(self, inputs):
"""
Convert inputs from numpy arrays, etc. to a torch.tensor blob.
Expand Down Expand Up @@ -103,8 +119,8 @@ def prepare_input(self, inputs):

labelsList.append(inp.label)

torchBlob = torch.stack(cutoutsList)
return torchBlob, labelsList
blob = torch.stack(cutoutsList)
return blob, labelsList

def infer(self, inputs):
"""Return the score of this cutout.
Expand All @@ -119,9 +135,27 @@ def infer(self, inputs):
scores : `numpy.array`
Float scores for each element of ``inputs``.
"""
blob, labels = self.prepare_input(inputs)
result = self.model(blob)
scores = torch.sigmoid(result)
npyScores = scores.detach().numpy().ravel()

# Convert the inputs to batches.
# TODO: The batch size is set to 64 for now. Later when
# deploying parallel instances of the task, memory limits
# should be taken into account, if necessary.
batches = self.input_to_batches(inputs, batchSize=64)

# Loop over the batches
for i, batch in enumerate(batches):
torchBlob, labelsList = self.prepare_input(batch)

# Run the model
with torch.no_grad():
output_ = self.model(torchBlob)
output = torch.sigmoid(output_)

# And append the results to the list
if i == 0:
scores = output
else:
scores = torch.cat((scores, output.cpu()), dim=0)

npyScores = scores.detach().numpy().ravel()
return npyScores
3 changes: 2 additions & 1 deletion python/lsst/meas/transiNet/rbTransiNetTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ def __init__(self, **kwargs):

def run(self, template, science, difference, diaSources):
cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources]
self.log.info("Extracted %d cutouts.", len(cutouts))
scores = self.interface.infer(cutouts)

self.log.info("Scored %d cutouts.", len(scores))
schema = lsst.afw.table.Schema()
schema.addField(diaSources.schema["id"].asField())
schema.addField("score", doc="real/bogus score of this source", type=float)
Expand Down
16 changes: 13 additions & 3 deletions tests/test_RBTransiNetInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,25 @@
from lsst.meas.transiNet import RBTransiNetInterface, CutoutInputs


class TestOneCutout(unittest.TestCase):
class TestInference(unittest.TestCase):
    def setUp(self):
        # Build a fresh interface per test; "dummy" with "local" loading
        # presumably selects a lightweight locally-stored test model --
        # TODO confirm against RBTransiNetInterface's constructor.
        self.interface = RBTransiNetInterface("dummy", "local")

def test_infer_empty(self):
"""Test running infer on images containing all zeros.
    def test_infer_single_empty(self):
        """Test running infer on a single blank triplet.

        The expected score is an empirical, otherwise meaningless value
        that this particular model emits for an all-zero input; the
        assertion only pins determinism of the loaded model.
        """
        data = np.zeros((256, 256), dtype=np.single)
        inputs = CutoutInputs(science=data, difference=data, template=data)
        result = self.interface.infer([inputs])
        self.assertTupleEqual(result.shape, (1,))
        self.assertAlmostEqual(result[0], 0.5011908)  # Empirical, meaningless value emitted by this very model

def test_infer_many(self):
"""Test running infer on a large number of images,
to make sure partitioning to batches works.
"""
data = np.zeros((256, 256), dtype=np.single)
inputs = [CutoutInputs(science=data, difference=data, template=data) for _ in range(100)]
result = self.interface.infer(inputs)
self.assertTupleEqual(result.shape, (100,))
self.assertAlmostEqual(result[0], 0.5011908)