Major: define a more butler-oriented wrapper around modelPackages
to be passed back and forth between the formatter and
storageAdapterButler.
NimSed committed Dec 12, 2023
1 parent bdee3c0 commit dc72b2b
Showing 3 changed files with 113 additions and 74 deletions.
31 changes: 24 additions & 7 deletions python/lsst/meas/transiNet/modelPackages/formatters.py
@@ -1,9 +1,11 @@
from lsst.daf.butler import Formatter
import torch
from io import BytesIO
from . import utils
#from . import utils
import dataclasses

__all__ = ["PytorchCheckpointFormatter", "PytorchCheckpointFormatter"]
__all__ = ["PytorchCheckpointFormatter", "PytorchCheckpointFormatter",
"NNModelPackageFormatter", "NNModelPackagePayload"]


class PytorchCheckpointFormatter(Formatter):
@@ -42,15 +44,30 @@ def write(self, inMemoryDataset):
path = self.fileDescriptor.location.path
self.writeFile(path, inMemoryDataset)

class BinaryFormatter(Formatter):
"""Formatter for binary files.
class NNModelPackagePayload:
"""A thin wrapper around the payload of an NNModelPackageFormatter,
which simply carries an in-memory file between the formatter and the
storage adapter of model packages.
"""
extension = ".bin"
def __init__(self):
self.bytes = BytesIO()

class NNModelPackageFormatter(Formatter):
"""Formatter for NN model packages.
"""
extension = ".zip"

def read(self, component=None):
payload = NNModelPackagePayload()
with open(self.fileDescriptor.location.path, "rb") as f:
return BytesIO(f.read())
payload.bytes = BytesIO(f.read())
return payload

def write(self, inMemoryDataset):
with open(self.fileDescriptor.location.path, "wb") as f:
f.write(inMemoryDataset.getvalue())
f.write(inMemoryDataset.bytes.getvalue())
print("Wrote model package to", self.fileDescriptor.location.path)


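To make the contract concrete, here is a self-contained sketch of what the formatter pair above does with the payload's bytes, using only the standard library (`Payload` is a hypothetical stand-in for `NNModelPackagePayload`, and the file path is illustrative):

```python
import io
import zipfile

class Payload:
    """Stand-in for NNModelPackagePayload: just an in-memory file."""
    def __init__(self):
        self.bytes = io.BytesIO()

# The storage-adapter side fills payload.bytes with a zip archive ...
payload = Payload()
with zipfile.ZipFile(payload.bytes, mode="w",
                     compression=zipfile.ZIP_DEFLATED) as zf:
    zf.writestr("checkpoint", b"\x00\x01")    # placeholder contents
    zf.writestr("architecture", b"\x02\x03")
    zf.writestr("metadata", b"name: demo")

# ... NNModelPackageFormatter.write() persists exactly those bytes ...
with open("/tmp/package.zip", "wb") as f:
    f.write(payload.bytes.getvalue())

# ... and read() wraps them back into a fresh payload.
restored = Payload()
with open("/tmp/package.zip", "rb") as f:
    restored.bytes = io.BytesIO(f.read())
```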
154 changes: 88 additions & 66 deletions python/lsst/meas/transiNet/modelPackages/storageAdapterButler.py
@@ -1,11 +1,10 @@
from .storageAdapterBase import StorageAdapterBase
from lsst.meas.transiNet.modelPackages.formatters import BinaryFormatter
from lsst.daf.butler import DatasetType, FileDataset, DatasetRef
from lsst.meas.transiNet.modelPackages.formatters import NNModelPackagePayload
from lsst.daf.butler import DatasetType
from . import utils

import torch
import zipfile
import tempfile
import io
import os
import yaml
@@ -43,43 +42,87 @@ def __init__(self, model_package_name, butler=None, butler_loaded_package=None):

self.model_package_name = model_package_name
self.butler = butler
self.butler_loaded_package = butler_loaded_package

self.model_file = self.checkpoint_file = self.metadata_file = None

self.fetch()
# butler and butler_loaded_package are mutually exclusive.
if butler is not None and butler_loaded_package is not None:
raise ValueError('butler and butler_loaded_package are mutually exclusive')

@staticmethod
def lookupFunction(config, dataSetType, registry, dataId, collections):
"""Lookup function that locates the pretrained weights of a model
package in the Butler repository.
# Use the butler_loaded_package if it is provided.
if butler_loaded_package is not None:
self.from_payload(butler_loaded_package)

# If the butler is provided, we are in the "offline" mode. Let's go
# and fetch the model package from the butler repository.
if butler is not None:
self.fetch()
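The two modes above are mutually exclusive by construction. A usage sketch (the package name, `my_butler`, and `payload` are illustrative, not from the diff):

```python
# "Offline" mode: a Butler is supplied, and __init__ fetches the
# package from the repository via fetch().
adapter = StorageAdapterButler("my_model_package", butler=my_butler)

# Pipeline mode: the package already arrived as an NNModelPackagePayload
# (e.g. through a PrerequisiteInput), so no Butler is needed.
adapter = StorageAdapterButler("my_model_package",
                               butler_loaded_package=payload)

# Supplying both raises ValueError, per the check above.
```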

@classmethod
def from_other(cls, other, use_name=None):
"""
Create a new instance of this class from another instance, which
can be of a different mode.
Parameters
----------
other : `StorageAdapterBase`
The instance to create a new instance from.
"""

instance = cls(model_package_name=use_name or other.model_package_name)

if hasattr(other, 'model_file'):
instance.model_file = other.model_file
instance.checkpoint_file = other.checkpoint_file
instance.metadata_file = other.metadata_file
else:
with open(other.model_filename, mode="rb") as f:
instance.model_file = io.BytesIO(f.read())
with open(other.checkpoint_filename, mode="rb") as f:
instance.checkpoint_file = io.BytesIO(f.read())
with open(other.metadata_filename, mode="rb") as f:
instance.metadata_file = io.BytesIO(f.read())

All parameters are automatically set by the graph builder, except
for the `config` parameter, which is manually set in the init()
method of the client task's `Connections` class.
return instance
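A sketch of promoting a file-based adapter to this class via `from_other` (assuming, as the file-reading branch above does, that the source adapter exposes `model_filename`, `checkpoint_filename`, and `metadata_filename`):

```python
# local_adapter is assumed to be a file-based adapter exposing the
# *_filename attributes that from_other() reads from disk.
butler_adapter = StorageAdapterButler.from_other(local_adapter,
                                                 use_name="my_model_package")
payload = butler_adapter.to_payload()  # defined below; ready for butler.put()
```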

def from_payload(self, payload):
"""
Decompress the payload into memory and save each component
as an in-memory file.

Parameters
----------
config : `lsst.pipe.base.PipelineTaskConfig`
The configuration of the client pipeline task.
dataSetType : `lsst.daf.butler.DatasetType`
The `DatasetType` being queried.
registry : `lsst.daf.butler.Registry`
The `Registry` to use to find datasets.
dataId : `dict`
The `DataId` to use to find datasets -- ignored.
collections : `str` or `list` of `str`
The collection or collections to search for datasets -- ignored.
payload : `NNModelPackagePayload`
The payload to create the instance from.
"""
with zipfile.ZipFile(payload.bytes, mode="r") as zf:
with zf.open('checkpoint') as f:
self.checkpoint_file = io.BytesIO(f.read())
with zf.open('architecture') as f:
self.model_file = io.BytesIO(f.read())
with zf.open('metadata') as f:
self.metadata_file = io.BytesIO(f.read())

def to_payload(self):
"""
Compress the model package into a payload.

Returns
-------
ref : `list` of `lsst.daf.butler.DatasetRef`
List of DatasetRefs for the requested dataset.
Ideally, there should be only one.
payload : `NNModelPackagePayload`
The payload containing the compressed model package.
"""

results = registry.queryDatasets(dataSetType,
collections=f'{StorageAdapterButler.packages_parent_collection}/{config.modelPackageName}')
return list(results)
payload = NNModelPackagePayload()

with zipfile.ZipFile(payload.bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
zf.writestr('checkpoint', self.checkpoint_file.read())
zf.writestr('architecture', self.model_file.read())
zf.writestr('metadata', self.metadata_file.read())

return payload
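A minimal round trip through `to_payload()`/`from_payload()`, assuming adapters may be constructed without a Butler (as `from_other` does) and with stand-in component data:

```python
import io

src = StorageAdapterButler("my_model_package")  # name illustrative
dst = StorageAdapterButler("my_model_package")

# Stand-in component buffers.
src.checkpoint_file = io.BytesIO(b"fake checkpoint")
src.model_file = io.BytesIO(b"fake architecture")
src.metadata_file = io.BytesIO(b"metadata: {}")

payload = src.to_payload()   # zips the three components into payload.bytes
dst.from_payload(payload)    # unzips them back into in-memory files

assert dst.checkpoint_file.read() == b"fake checkpoint"
```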

def fetch(self):
"""Fetch the model package from the butler repository, decompress it
@@ -89,29 +132,19 @@ def fetch(self):
butler repository is already done, which is the "normal" case.
"""

# Check if at least one of the properties is non-empty.
if self.butler is None and self.butler_loaded_package is None:
raise ValueError('Either butler or butler_loaded_package must be non-empty')

# If fetch() has already been called, do nothing.
# If we have already loaded the package, there's nothing left to do here.
if self.model_file is not None:
return

# Do the fetching from butler, if needed.
if self.butler_loaded_package is None: # We are not using a preloaded package. Use butler.
results = self.butler.registry.queryDatasets(StorageAdapterButler.dataset_type_name,
collections=f'{StorageAdapterButler.packages_parent_collection}/{self.model_package_name}')
# fetch the object using butler
self.butler_loaded_package = self.butler.get(list(results)[0])
# Fetching needs a butler object.
if self.butler is None:
raise ValueError('The `butler` object is required for fetching the model package')

# The object in the memory is a zip file. We need to decompress it.
with zipfile.ZipFile(self.butler_loaded_package, 'r') as zip_ref:
with zip_ref.open('checkpoint') as f:
self.checkpoint_file = io.BytesIO(f.read())
with zip_ref.open('architecture') as f:
self.model_file = io.BytesIO(f.read())
with zip_ref.open('metadata') as f:
self.metadata_file = io.BytesIO(f.read())
# Fetch the model package from the butler repository.
results = self.butler.registry.queryDatasets(StorageAdapterButler.dataset_type_name,
collections=f'{StorageAdapterButler.packages_parent_collection}/{self.model_package_name}')
payload = self.butler.get(list(results)[0])
self.from_payload(payload)
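The collection queried by `fetch()` follows the pattern `<packages_parent_collection>/<model_package_name>`; for instance (the parent-collection value and package name are assumed):

```python
# Assuming packages_parent_collection == "pretrained_models":
refs = butler.registry.queryDatasets(
    StorageAdapterButler.dataset_type_name,
    collections="pretrained_models/my_model_package",
)
payload = butler.get(list(refs)[0])  # an NNModelPackagePayload
```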

def load_arch(self, device):
"""
@@ -214,31 +247,20 @@ def ingest(model_package, butler, model_package_name=None):
data_id = {}
dataset_type = DatasetType(StorageAdapterButler.dataset_type_name,
dimensions=[],
storageClass="ModelPackage",
storageClass="NNModelPackagePayload",
universe=butler.registry.dimensions)

# Register the dataset type.
def register_dataset_type(butler, dataset_type_name, dataset_type):
try: # Do nothing if the dataset type is already registered
butler.registry.getDatasetType(dataset_type_name)
except KeyError:
butler.registry.registerDatasetType(dataset_type)

register_dataset_type(butler, StorageAdapterButler.dataset_type_name, dataset_type)

# Create a temporary file and Zip all the three components into it.
temp_fd, temp_path = tempfile.mkstemp(suffix='.bin')
with os.fdopen(temp_fd, 'wb') as tmp:
with zipfile.ZipFile(tmp, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
zf.write(model_package.adapter.checkpoint_filename, arcname="checkpoint")
zf.write(model_package.adapter.model_filename, arcname="architecture")
zf.write(model_package.adapter.metadata_filename, arcname="metadata")

file_dataset = FileDataset(path=temp_path,
refs=DatasetRef(dataset_type, data_id, run=run_collection),
formatter=BinaryFormatter)

# Ingest the file into the butler repo
butler.ingest(file_dataset, transfer='copy')

# Remove the temporary file
os.remove(temp_path)
# Create an instance of StorageAdapterButler, and ingest its payload.
payload = StorageAdapterButler.from_other(model_package.adapter).to_payload()
butler.put(payload,
dataset_type,
data_id,
run=run_collection)
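An end-to-end ingestion sketch (the repository path and the `model_package` object are illustrative):

```python
from lsst.daf.butler import Butler

# A writeable butler pointed at the target repository (path illustrative).
butler = Butler("/repo/main", writeable=True)

# model_package is assumed to wrap a file-based adapter; ingest()
# converts it via from_other(), zips it with to_payload(), and put()s it.
StorageAdapterButler.ingest(model_package, butler,
                            model_package_name="my_model_package")
```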
2 changes: 1 addition & 1 deletion python/lsst/meas/transiNet/rbTransiNetTask.py
@@ -62,7 +62,7 @@ class RBTransiNetConnections(lsst.pipe.base.PipelineTaskConnections,
pretrainedModel = lsst.pipe.base.connectionTypes.PrerequisiteInput(
doc="Pretrained neural network model (-package) for the RBClassifier.",
dimensions=(),
storageClass="ModelPackage",
storageClass="NNModelPackagePayload",
name=StorageAdapterButler.dataset_type_name,
)

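Downstream, the task now receives this prerequisite input as an `NNModelPackagePayload`, which can be handed straight to the adapter; a sketch (variable names assumed):

```python
# `payload` is the NNModelPackagePayload delivered for the
# pretrainedModel connection defined above.
adapter = StorageAdapterButler("my_model_package",  # name illustrative
                               butler_loaded_package=payload)
network = adapter.load_arch(device="cpu")  # load_arch as shown in the diff
```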
