Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PixelwiseRegressionTask #1241

Merged
merged 28 commits into from
Apr 25, 2023

Conversation

isaaccorley
Copy link
Collaborator

@isaaccorley isaaccorley commented Apr 13, 2023

This PR adds a PixelwiseRegressionTask which can be used for regression on 2D imagery, e.g. height estimation or other continuous per-pixel variables. It's basically a mixture of the current RegressionTask and SemanticSegmentationTask in that it performs regression but uses smp models e.g. U-Net with a single output channel and L1 or L2 loss for training.

Unless I'm mistaken, we don't have any datasets to actually test this on.

Closes #849

@isaaccorley isaaccorley added this to the 0.5.0 milestone Apr 13, 2023
@isaaccorley isaaccorley self-assigned this Apr 13, 2023
@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Apr 13, 2023
@adamjstewart
Copy link
Collaborator

Not sure if this belongs in regression.py or segmentation.py or in a new file. Will try to review when I find time to sleep.

@isaaccorley
Copy link
Collaborator Author

isaaccorley commented Apr 14, 2023

I vote for regression.py because it's basically the RegressionTask but with a smp base model instead of a timm model.

@@ -35,8 +36,10 @@ class RegressionTask(LightningModule): # type: ignore[misc]
print(timm.list_models())
"""

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
target_key: str = "label"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows to not have to duplicate all the train/val/test steps just to change label to mask. Let me know if you have any other suggestions.

@github-actions github-actions bot added the testing Continuous integration testing label Apr 21, 2023
@isaaccorley isaaccorley changed the title Add DenseRegressionTask Add PixelwiseRegressionTask Apr 21, 2023
Comment on lines 233 to 234
("inria", InriaAerialImageLabelingDataModule, 1, "mse"),
("inria", InriaAerialImageLabelingDataModule, 2, "mae"),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing regression on Inria binary [0, 1] masks for now since we don't have a readily available pixelwise regression datamodule.

y_hat = self(x)

loss = F.mse_loss(y_hat, y)
if y_hat.ndim != y.ndim:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If num_outputs=1 the target variable ground truth is missing the necessary channel dim e.g.

  • (b,) instead of (b, 1)
  • (b, h, w) instead of (b, 1, h, w)

while the output of the models will be:

  • (b, 1)
  • (b, 1, h, w)

self.log("train_loss", loss) # logging to TensorBoard
self.train_metrics(y_hat, y)
self.train_metrics(y_hat, y.to(torch.float))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cast to float only for loss and metrics in case the plotting expects a different dtype

@adamjstewart
Copy link
Collaborator

Looks like pixel-wise regression won the poll.

Can you rebase and use the new hydra-style configs?

@isaaccorley isaaccorley merged commit 7678627 into microsoft:main Apr 25, 2023
@isaaccorley isaaccorley deleted the trainers/dense-regression branch April 25, 2023 14:05
@adamjstewart
Copy link
Collaborator

Dimension stuff looks super confusing. Wonder if there's a way to simplify that.

We should consider abstracting the segmentation model stuff into a shared utility.

Will update things to the new style in a separate PR.

@isaaccorley
Copy link
Collaborator Author

isaaccorley commented Apr 25, 2023

The dimension lines are essentially the same as the

.view(-1, 1)

except it only does it if necessary. It's a result of not adding a channel dimension to the output from the dataset targets.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

a version of the semantic segmentation task for regression
3 participants