In [None]:
import torch

# `snakemake` is not imported: Snakemake injects it when it executes this
# notebook for a rule. Cap torch's CPU intra-op thread pool at the number of
# threads the rule was allotted (presumably to avoid oversubscribing cores on
# shared nodes).
torch.set_num_threads(snakemake.threads)
import logging
import numpy as np
from cellwhisperer.jointemb.cellwhisperer_lightning import (
    TranscriptomeTextDualEncoderLightning,
)
# NOTE(review): TranscriptomeTextDualEncoderProcessor, get_path and
# model_path_from_name appear unused in the cells shown; candidates for removal.
from cellwhisperer.jointemb.processing import TranscriptomeTextDualEncoderProcessor
from cellwhisperer.config import get_path, model_path_from_name

# Loader that batches the rule's read-count table for the transcriptome encoder.
from cellwhisperer.jointemb.dataset.inference import CellxGenePreparationLoader

In [None]:
# Send this notebook's log records (and those of any library logging through
# the root logger) to the log file Snakemake assigned to this rule.
log_file = snakemake.log.log_file
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    level=logging.INFO,  # INFO and above are recorded
    format=LOG_FORMAT,
    filename=log_file,
    filemode="a",  # append instead of truncating an existing log
)

In [None]:
# Load the trained dual-encoder checkpoint for inference only, on GPU when
# one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto `device` during deserialization,
# so a checkpoint saved on GPU can be restored on a CPU-only node instead of
# failing before `.to(device)` is ever reached.
pl_model = TranscriptomeTextDualEncoderLightning.load_from_checkpoint(
    snakemake.input.model, map_location=device
)
pl_model.eval().to(device)
# NOTE(review): prepare_models(..., force_freeze=True) presumably freezes the
# transcriptome and text sub-models; confirm against its definition.
pl_model.model.prepare_models(
    pl_model.model.transcriptome_model, pl_model.model.text_model, force_freeze=True
)
# Lightning's freeze(): sets requires_grad=False on all params and enters eval mode.
pl_model.freeze()

In [None]:
from tqdm.auto import tqdm

# Pipeline: iterate over all transcriptomes, compute features and embeddings,
# and save them. This cell builds the batched loader over the rule's
# read-count table; the transcriptome model's `model_type` is what the loader
# receives as its `transcriptome_processor`.
BATCH_SIZE = 32
transcriptome_model_type = pl_model.model.transcriptome_model.config.model_type

dl = CellxGenePreparationLoader(
    read_count_table=snakemake.input.read_count_table,
    transcriptome_processor=transcriptome_model_type,
    batch_size=BATCH_SIZE,
)

In [None]:
# Smoke-test the loader on a single batch before the full inference pass so
# preprocessing errors surface early. Only the shape of each batch entry is
# displayed, rather than dumping whole tensors into the notebook output.
# Assumes every batch value is a torch tensor (each one is moved with
# `.to(device)` during inference).
first_batch = next(iter(dl))
{key: tuple(value.shape) for key, value in first_batch.items()}

In [None]:
# Run the frozen transcriptome encoder over every batch. torch.no_grad() makes
# the inference-only intent explicit and guarantees no autograd graph is
# recorded, even if some parameter were left with requires_grad=True.
results = []
with torch.no_grad():
    for batch in tqdm(dl, desc="embedding transcriptomes"):
        transcriptome_features, transcriptome_embeds = (
            pl_model.model.get_transcriptome_features(
                **{k: t.to(device) for k, t in batch.items()},
                normalize_embeds=True,
            )
        )
        # Copy each batch's outputs to CPU immediately so device memory does not
        # accumulate results across batches.
        results.append(
            {
                "transcriptome_features": transcriptome_features.detach().cpu(),
                "transcriptome_embeds": transcriptome_embeds.detach().cpu(),
            }
        )

In [None]:
# Concatenate the per-batch outputs along dim 0 into one tensor per key. Row
# order follows the loader's iteration order. Then attach the dataset's
# original sample identifiers.
if not results:
    raise ValueError(
        "The data loader produced no batches; there are no embeddings to save."
    )
aggregated_dict = {key: torch.cat([d[key] for d in results]) for key in results[0]}
aggregated_dict["orig_ids"] = dl.dataset.orig_ids

# Guard against silently misaligned ids/embeddings (e.g. dropped or duplicated
# batches). Assumes dim 0 of the embeddings is the sample dimension.
n_embedded = aggregated_dict["transcriptome_embeds"].shape[0]
n_ids = len(aggregated_dict["orig_ids"])
if n_embedded != n_ids:
    raise ValueError(
        f"Got {n_embedded} embeddings but {n_ids} orig_ids; rows would be misaligned."
    )

In [None]:
# Write every entry of aggregated_dict as a named array into one .npz archive
# at exactly the path Snakemake declared. Passing an open file handle instead
# of a path string stops numpy from appending ".npz" when the declared name
# lacks that suffix, which would leave Snakemake's expected output missing.
# NOTE(review): if orig_ids is an object/str-object array, loading requires
# np.load(..., allow_pickle=True); confirm its dtype.
with open(snakemake.output["model_outputs"], "wb") as output_file:
    np.savez(output_file, **aggregated_dict)