Skip to content

Commit

Permalink
Remove Argmax Computation for torchmetrics in Classification and Segme…
Browse files Browse the repository at this point in the history
…ntation (#1777)

* remove y_hat_hard

* argmax vs softmax :)
  • Loading branch information
nilsleh committed Dec 15, 2023
1 parent 464f01f commit ece7350
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
11 changes: 4 additions & 7 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,9 @@ def training_step(
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics)

return loss
Expand All @@ -183,10 +182,9 @@ def validation_step(
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y)
self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics)

if (
Expand All @@ -198,7 +196,7 @@ def validation_step(
and hasattr(self.logger.experiment, "add_figure")
):
datamodule = self.trainer.datamodule
batch["prediction"] = y_hat_hard
batch["prediction"] = y_hat.argmax(dim=-1)
for key in ["image", "label", "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
Expand Down Expand Up @@ -227,10 +225,9 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y)
self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics)

def predict_step(
Expand Down
11 changes: 4 additions & 7 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,9 @@ def training_step(
x = batch["image"]
y = batch["mask"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics)
return loss

Expand All @@ -238,10 +237,9 @@ def validation_step(
x = batch["image"]
y = batch["mask"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y)
self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics)

if (
Expand All @@ -253,7 +251,7 @@ def validation_step(
and hasattr(self.logger.experiment, "add_figure")
):
datamodule = self.trainer.datamodule
batch["prediction"] = y_hat_hard
batch["prediction"] = y_hat.argmax(dim=1)
for key in ["image", "mask", "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
Expand Down Expand Up @@ -282,10 +280,9 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
x = batch["image"]
y = batch["mask"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y)
self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics)

def predict_step(
Expand Down

0 comments on commit ece7350

Please sign in to comment.