In [5]:
import os
import tempfile

import torch
from google.cloud import storage

from cellarium.ml.core import CellariumModule, CellariumPipeline


def get_pretrained_model_as_pipeline(
    trained_model: str = "gs://dsp-cell-annotation-service/cellarium/trained_models/cerebras/lightning_logs/version_0/checkpoints/epoch=2-step=83250.ckpt",
    transforms: list[torch.nn.Module] = [],
    device: str = "cuda",
) -> CellariumPipeline:
    if trained_model.startswith("gs://"):
        # download the trained model
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_file = os.path.join(tmpdir, "model.ckpt")

            client = storage.Client()
            bucket_name = trained_model.split("/")[2]
            blob_name = "/".join(trained_model.split("/")[3:])
            bucket = client.get_bucket(bucket_name)
            blob = bucket.blob(blob_name)
            blob.download_to_filename(tmp_file)

            # load the model
            model = CellariumModule.load_from_checkpoint(tmp_file).model
    else:
        # load the model
        model = CellariumModule.load_from_checkpoint(trained_model).model

    # insert the trained model params
    model.to(device)
    model.eval()

    # construct the pipeline
    pipeline = CellariumPipeline(transforms + [model])

    return pipeline

In [7]:
pipeline = get_pretrained_model_as_pipeline(
    trained_model="../lightning_logs/version_65/checkpoints/epoch=1-step=8.ckpt",
)

/home/sfleming/miniforge3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [8]:
pipeline

CellariumPipeline(
  (0): NonNegativeMatrixFactorization()
)

In [10]:
pipeline[-1].factors_kg

tensor([[-3.3687e-04,  7.2682e-04, -1.6188e-04,  ..., -2.9623e-04,
         -4.4200e-04, -3.2031e-04],
        [-3.1955e-04,  8.8420e-04, -8.4525e-05,  ..., -2.5221e-04,
         -3.8733e-04, -2.8206e-04],
        [ 1.0151e-03,  2.5830e-03,  9.9171e-04,  ...,  8.7675e-04,
          1.2697e-03,  1.0073e-03],
        ...,
        [-2.0874e-04,  1.2259e-03,  2.9339e-05,  ..., -1.0489e-04,
         -2.3736e-04, -1.4860e-04],
        [-6.2403e-05,  1.1807e-03,  8.3635e-05,  ..., -2.2226e-05,
         -8.5726e-05, -3.9641e-05],
        [ 8.1264e-06,  1.0017e-03,  1.1781e-04,  ..., -8.2809e-07,
         -8.6355e-07, -2.0023e-06]], device='cuda:0')

In [12]:
pipeline[-1].factors_kg >= 0

tensor([[False,  True, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')