RuntimeError: stack expects each tensor to be equal size during multi-task prediction when keep_all_cls_pred=True #16

@Prachi-Priyam

Environment
- scprint: 2.3.8
- Python: 3.10.20
- torch: 2.2.0+cu121
- CUDA: 12.2
- GPU: NVIDIA A100 80GB

Describe the bug
When running the Embedder with keep_all_cls_pred=True and multiple labels in pred_embedding, a RuntimeError is raised because torch.stack() fails when classification heads have different output sizes.

To Reproduce
```python
from scprint import scPrint
from scprint.tasks import Embedder

model = scPrint.load_from_checkpoint("large-v1.ckpt", precpt_gene_emb=None)
embedder = Embedder(
    batch_size=64,
    keep_all_cls_pred=True,
    pred_embedding=[
        'cell_type_ontology_term_id',  # 424 classes
        'disease_ontology_term_id',  # 62 classes
    ],
    doclass=True,
)
# adata is an AnnData object loaded beforehand
result = embedder(model=model, adata=adata)
```

Expected behavior
- When keep_all_cls_pred=True, the Embedder should return probability distributions over all classes for each cell, stored correctly in adata.obs without requiring manual source code patches.
- When keep_all_cls_pred=False, the Embedder should return the top predicted class (argmax) for each label and store it correctly in adata.obs.
- In both cases, post-processing should handle CUDA tensors correctly by moving them to CPU before pandas DataFrame conversion.

Desktop (please complete the following information):
PRETTY_NAME="Ubuntu 22.04.4 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.4 LTS (Jammy Jellyfish)"

Additional context

Three separate bugs identified:

Bug 1: torch.stack fails with keep_all_cls_pred=True

Location: scprint/model/model.py, line 1572

When keep_all_cls_pred=True, each classification head returns a tensor of shape [batch, n_classes] where n_classes differs per label (e.g. 424 for cell_type_ontology_term_id vs 62 for disease_ontology_term_id). torch.stack() requires equal sizes across all entries, causing the crash.

**Error:** RuntimeError: stack expects each tensor to be equal size, but got [32, 424] at entry 0 and [32, 62] at entry 1
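The failure can be reproduced outside scprint. The following is a minimal sketch using random tensors as stand-ins for the two classification heads; the variable names are illustrative and not taken from the scprint source.

```python
import torch

# Stand-ins for two classification heads with different class counts
# (424 and 62 match the numbers in this report; the batch size is arbitrary).
cell_type_out = torch.randn(32, 424)
disease_out = torch.randn(32, 62)

try:
    # Stacking requires every tensor to have the same shape.
    torch.stack([cell_type_out, disease_out])
except RuntimeError as err:
    message = str(err)
    print(message)

# Taking argmax first yields one class index per cell, so the shapes agree.
stacked = torch.stack(
    [torch.argmax(t, dim=1) for t in (cell_type_out, disease_out)]
).transpose(0, 1)
print(tuple(stacked.shape))
```

This is why the argmax path works while the full-output path crashes as soon as two labels have different numbers of classes.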

Relevant code:

self.pred = (
    torch.stack(
        [
            (
                torch.argmax(output["cls_output_" + clsname], dim=1)
                if not self.keep_all_cls_pred
                else output["cls_output_" + clsname]  # different sizes
            )
            for clsname in self.classes
        ]
    ).transpose(0, 1)
)

Proposed fix:
if len(self.classes) > 0:
    if self.keep_all_cls_pred:
        self.pred = {
            clsname: output["cls_output_" + clsname]
            for clsname in self.classes
        }
    else:
        self.pred = torch.stack([
            torch.argmax(output["cls_output_" + clsname], dim=1)
            for clsname in self.classes
        ]).transpose(0, 1)
else:
    self.pred = None
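To check the branch logic of this proposal in isolation, here is a small sketch that moves it into a standalone function. The helper name collect_predictions and the fake output dict are hypothetical; only the "cls_output_" key prefix and the label names come from this report.

```python
import torch


def collect_predictions(output, classes, keep_all_cls_pred):
    """Hypothetical helper mirroring the branch logic of the proposed fix."""
    if len(classes) == 0:
        return None
    if keep_all_cls_pred:
        # Keep each head's full [batch, n_classes] tensor under its own key.
        return {name: output["cls_output_" + name] for name in classes}
    # One predicted class index per label, shape [batch, n_labels].
    return torch.stack(
        [torch.argmax(output["cls_output_" + name], dim=1) for name in classes]
    ).transpose(0, 1)


classes = ["cell_type_ontology_term_id", "disease_ontology_term_id"]
# Fake model output: random values with the class counts from this report.
output = {
    "cls_output_cell_type_ontology_term_id": torch.randn(8, 424),
    "cls_output_disease_ontology_term_id": torch.randn(8, 62),
}

full = collect_predictions(output, classes, keep_all_cls_pred=True)
top = collect_predictions(output, classes, keep_all_cls_pred=False)
print({name: tuple(t.shape) for name, t in full.items()})
print(tuple(top.shape))
```

Note that with this change the prediction is a dict in one mode and a tensor in the other, so any code that consumes self.pred would presumably need to handle both cases.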

Bug 2: CUDA tensor not moved to CPU before DataFrame conversion

Location: scprint/tasks/cell_emb.py, line ~208

The tensor allclspred remains on GPU when passed to pd.DataFrame(), causing a TypeError.

Error: TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Current code:

allclspred = pd.DataFrame( allclspred, columns=columns, index=adata.obs.index )

Fix: allclspred = pd.DataFrame( allclspred.cpu().numpy(), columns=columns, index=adata.obs.index )
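A device-agnostic version of this conversion might look like the sketch below. The helper tensor_to_frame and the column names are hypothetical; the added detach() is an extra precaution in case the tensor still tracks gradients, which may not apply inside the Embedder.

```python
import pandas as pd
import torch


def tensor_to_frame(tensor, columns, index):
    """Hypothetical helper: build a DataFrame from a tensor on any device."""
    # detach() drops autograd tracking; cpu() is a no-op for CPU tensors.
    values = tensor.detach().cpu().numpy()
    return pd.DataFrame(values, columns=columns, index=index)


# Use the GPU when one is present so the CUDA path is exercised.
device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.tensor([[3, 0], [1, 2]], device=device)
frame = tensor_to_frame(
    preds,
    columns=["pred_cell_type", "pred_disease"],  # illustrative names
    index=["cell_a", "cell_b"],
)
print(frame)
```

The same call works whether the predictions live on the GPU or the CPU, so the post-processing does not need a separate code path per device.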

Bug 3: pd.concat called with wrong syntax

Location: scprint/tasks/cell_emb.py, line ~211

pd.concat receives two positional arguments instead of a list, causing the second argument to be interpreted as the axis parameter.
Current code:
adata.obs = pd.concat(adata.obs, allclspred)

Fix: adata.obs = pd.concat([adata.obs, allclspred], axis=1)
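The difference between the two call forms can be seen with two toy frames. This is a sketch with made-up column names; as far as I can tell the positional form raises a TypeError in both pandas 1.x and 2.x, although the exact message differs between versions.

```python
import pandas as pd

# Toy stand-ins for adata.obs and the prediction frame (same index).
obs = pd.DataFrame({"n_genes": [1200, 950]}, index=["cell_a", "cell_b"])
preds = pd.DataFrame({"pred_cell_type": [3, 1]}, index=["cell_a", "cell_b"])

try:
    pd.concat(obs, preds)  # frames passed positionally, not as a list
    raised = False
except TypeError:
    raised = True
print("positional call raised TypeError:", raised)

# Passing a list with axis=1 appends the prediction columns side by side.
merged = pd.concat([obs, preds], axis=1)
print(list(merged.columns))
```

With axis=1 the frames are aligned on their index, so allclspred needs the same index as adata.obs, which the Bug 2 fix already sets.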

Workaround used to confirm embedding works correctly:

  1. Patched model.py to force argmax (keep_all_cls_pred=False behavior)
  2. Applied .cpu().numpy() fix in cell_emb.py
  3. Fixed pd.concat syntax in cell_emb.py
  4. Set model.keep_all_cls_pred = False before calling embedder

With these patches applied, the embedder ran successfully on a test subset of 18,000 endometrial cells (input adata of shape 18000 × 44733) in ~82 seconds on an A100 GPU. It produced correct embeddings in adata.obsm['scprint_emb'] with shape (18000, 512) and valid cell type predictions in adata.obs. The core embedding computation is unaffected by these bugs; only the post-processing of classification predictions is broken.
