
Unable to call metric from any step in Lightning module #102

Closed
SpirinEgor opened this issue Mar 17, 2021 · 9 comments
Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

@SpirinEgor

🐛 Bug

I implemented my own Metric class whose compute returns a data class with some aggregated metrics: precision, recall, and f1-score. But when I try to call the metric inside any *_step, I get an error from PyTorch internals.

The error happens in this line. If I call the validation metric (initialized with compute_on_step=False) during validation_step, I get:

TypeError: 'NoneType' object is not subscriptable

In the case of the training metric during training_step:

TypeError: 'ClassificationMetrics' object is not subscriptable

ClassificationMetrics is the name of my data class.

I also tried returning a float from compute, but that causes the same error. I assume PyTorch expects to receive a tensor and therefore tries to index into the result variable. An obvious solution is to return a tensor from compute, but that doesn't fix calling the validation metric, which doesn't return anything on step.
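To make the setup concrete, here is a stripped-down sketch of the kind of metric I mean (simplified names and placeholder counting logic, not my actual SequentialF1Score):

from dataclasses import dataclass

import torch
from torchmetrics import Metric


@dataclass
class ClassificationMetrics:
    precision: torch.Tensor
    recall: torch.Tensor
    f1_score: torch.Tensor


class SimpleF1Score(Metric):
    def __init__(self, compute_on_step: bool = True):
        super().__init__(compute_on_step=compute_on_step)
        # Accumulated counts; summed across processes in distributed mode.
        self.add_state("true_positive", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("false_positive", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("false_negative", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, predictions: torch.Tensor, labels: torch.Tensor):
        # Placeholder counting logic that treats label 0 as the negative class.
        self.true_positive += ((predictions == labels) & (labels != 0)).sum()
        self.false_positive += ((predictions != labels) & (predictions != 0)).sum()
        self.false_negative += ((predictions != labels) & (labels != 0)).sum()

    def compute(self) -> ClassificationMetrics:
        precision = self.true_positive / (self.true_positive + self.false_positive)
        recall = self.true_positive / (self.true_positive + self.false_negative)
        f1_score = 2 * precision * recall / (precision + recall)
        # Returning a data class (instead of a tensor) from compute is the whole point.
        return ClassificationMetrics(precision=precision, recall=recall, f1_score=f1_score)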

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0
  • OS (e.g., Linux): macOS Big Sur
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.9.2
  • CUDA/cuDNN version: -
  • GPU models and configuration: -
  • Any other relevant information: pytorch-lightning (1.1.7) / torchmetrics (0.2.0)
@SpirinEgor SpirinEgor added the help wanted label Mar 17, 2021
@github-actions

Hi! Thanks for your contribution, great first issue!

@Borda Borda added the bug / fix label Mar 17, 2021
@SkafteNicki
Member

Hi @SpirinEgor, would it be possible for you to share some reproducible code?

@SpirinEgor
Author

SpirinEgor commented Mar 17, 2021

You can find the metric code here. Since the repo also contains other things I find useful, I published it to PyPI as well, so you can install it in your test environment and try it inside a PyTorch Lightning module.
Btw, there are some simple tests, and they pass without any failures.

@SkafteNicki
Member

@SpirinEgor I must be doing something different from you, because if I call your metric inside training_step I cannot reproduce the error. Could you share what your training_step looks like?

@SpirinEgor
Author

Sure, here is what I'm doing.

class MyModel(LightningModule):
    def __init__(self, config: DictConfig, vocabulary: Vocabulary):
        super().__init__()
        ...
        pad_idx = vocabulary.label_to_id[PAD]
        ignore_idx = [vocabulary.label_to_id[i] for i in [UNK, EOS, SOS] if i in vocabulary.label_to_id]
        self._train_metrics = SequentialF1Score(mask_after_pad=True, pad_idx=pad_idx, ignore_idx=ignore_idx)
        self._val_metrics = SequentialF1Score(
            mask_after_pad=True, pad_idx=pad_idx, ignore_idx=ignore_idx, compute_on_step=False
        )
        self._test_metrics = SequentialF1Score(
            mask_after_pad=True, pad_idx=pad_idx, ignore_idx=ignore_idx, compute_on_step=False
        )
        ...

    ...

    def training_step(self, batch: Tuple[torch.Tensor, dgl.DGLGraph], batch_idx: int) -> Dict:  # type: ignore
        labels, graph = batch
        # [seq length; batch size; vocab size]
        logits = self(graph, labels.shape[0], labels)
        loss = self._calculate_loss(logits, labels)
        prediction = logits.argmax(-1)

        batch_metrics: ClassificationMetrics = self._train_metrics(prediction, labels)  # **Error happens here**
        log = {
            "train/loss": loss,
            "train/f1": batch_metrics.f1_score,
            "train/precision": batch_metrics.precision,
            "train/recall": batch_metrics.recall,
        }
        self.log_dict(log)
        self.log("f1", batch_metrics.f1_score, prog_bar=True, logger=False)

        return {"loss": loss}

    def validation_step(self, batch: Tuple[torch.Tensor, dgl.DGLGraph], batch_idx: int) -> Dict:  # type: ignore
        labels, graph = batch
        # [seq length; batch size; vocab size]
        logits = self(graph, labels.shape[0], labels)
        loss = self._calculate_loss(logits, labels)
        prediction = logits.argmax(-1)

        self._val_metrics(prediction, labels)  # **And here**
        return {"loss": loss}

This is very strange for me too, because when I wrote tests for the metric class it worked just fine...
I also tried using the latest version of PyTorch Lightning, but it doesn't change anything.

@SkafteNicki
Member

Strange, I still cannot reproduce your error. Both training and validation seem to run fine :/
Do you have any backward hooks registered in your code? It should only reach that location in the PyTorch code if you do.
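If you are not sure, a quick ad-hoc check like this lists every submodule that has a module-level backward hook registered (it inspects the private _backward_hooks attribute, so treat it purely as a debugging aid; hooks registered directly on tensors or parameters will not show up):

import torch


def find_backward_hooks(model: torch.nn.Module) -> None:
    # Print each submodule that has a backward hook registered on it.
    for name, module in model.named_modules():
        if module._backward_hooks:
            print(name or "<root>", "->", list(module._backward_hooks.values()))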

@SpirinEgor
Author

Oh, this is getting really tricky...
I didn't define any hooks. Is there any chance Lightning does it? Also, I use DGL to train on graphs; maybe it defines hooks.

@SpirinEgor
Author

I found the reason for all my troubles. In the main train function, I use WandbLogger, and I also call

wandb_logger.watch(model)

This call adds backward hooks to track gradients, their norms, etc. After I disabled it, the metrics started working as I expected.
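For reference, the workaround in my train function looks roughly like this (watch_model is just an illustrative flag, not my real config):

from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="my-project")  # hypothetical project name
# watch() registers backward hooks to track gradients, which breaks metrics
# whose forward does not return a tensor, so the call is now optional.
if config.watch_model:  # hypothetical config flag
    wandb_logger.watch(model)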

Anyway, big thanks for your response! The library is great, hope to see more cool features.

@SkafteNicki
Member

Thanks for letting me know, and glad that you like the library!
