OSCD RuntimeError: Input type (torch.cuda.LongTensor) and weight type (torch.cuda.FloatTensor) should be the same #1652

Closed
robmarkcole opened this issue Oct 12, 2023 · 1 comment · Fixed by #1656

@robmarkcole
Contributor

robmarkcole commented Oct 12, 2023

Description

Using the OSCD datamodule with a segmentation task results in the error:

RuntimeError: Input type (torch.cuda.LongTensor) and weight type (torch.cuda.FloatTensor) should be the same

Casting the image to float resolves this (x = batch["image"].float()), which I verified using a custom task. However, I'm fairly sure users are not meant to have to add this cast themselves.
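A minimal standalone sketch of the mismatch, using plain PyTorch on CPU with random data standing in for an OSCD batch (the shapes and value range are assumptions). A convolution with float32 weights rejects an int64 input, and casting with .float() makes the same call succeed. On CPU the exact error text may differ from the CUDA message quoted above.

```python
import torch
from torch import nn

# A 6-channel convolution, matching in_channels=6 in the report.
conv = nn.Conv2d(in_channels=6, out_channels=2, kernel_size=3, padding=1)

# Stand-in for batch["image"]: integer pixel values stored as int64.
image = torch.randint(0, 3000, (1, 6, 64, 64), dtype=torch.int64)

try:
    conv(image)  # weights are float32, input is int64
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting the input to the weight dtype makes the call succeed.
output = conv(image.float())
print(output.dtype, tuple(output.shape))  # torch.float32 (1, 2, 64, 64)
```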

Steps to reproduce

from lightning.pytorch import Trainer
from torchgeo.datamodules import OSCDDataModule
from torchgeo.trainers import SemanticSegmentationTask

num_workers = 4  # set to suit your machine

datamodule = OSCDDataModule(
    num_workers=num_workers,
    download=True,
    bands="rgb",
)

task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=True,
    in_channels=6,  # 3 RGB bands x 2 images (before and after)
    num_classes=2,
    loss="ce",
    ignore_index=None,
)

trainer = Trainer(
    min_epochs=5,
    max_epochs=25,
)

# Fails in the first forward pass with the RuntimeError above
trainer.fit(model=task, datamodule=datamodule)
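The in_channels=6 above comes from OSCD being a change detection dataset: each sample combines a pre-change and a post-change image. A minimal sketch, assuming (as in torchgeo 0.5's OSCD dataset) that the two 3-band images are concatenated along the channel dimension; the random tensors are hypothetical stand-ins for real imagery.

```python
import torch

# Hypothetical pre- and post-change RGB images, shaped (channels, height, width).
image_before = torch.randint(0, 3000, (3, 64, 64), dtype=torch.int64)
image_after = torch.randint(0, 3000, (3, 64, 64), dtype=torch.int64)

# Concatenating along the channel dimension yields a 6-band input;
# note that the dtype stays int64, which is what triggers the error.
image = torch.cat([image_before, image_after], dim=0)
print(image.shape, image.dtype)  # torch.Size([6, 64, 64]) torch.int64
```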

Note that the image data has dtype int64:

datamodule.train_dataset[0]["image"].dtype
torch.int64

Version

0.5.0

@adamjstewart
Collaborator

This relates to #985. We currently don't have a good way to test this. OSCD isn't really intended for use with SemanticSegmentationTask; we need a new trainer for change detection. In the meantime, I would be fine with a PR that simply casts the image to float32.
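A minimal sketch of the float32 cast suggested above, written as a sample-level transform. cast_image_to_float32 is a hypothetical helper, not torchgeo API, and this does not claim to show where the actual fix (#1656) applies the cast. The mask is deliberately left as int64, since cross-entropy loss expects integer class indices.

```python
from typing import Any

import torch


def cast_image_to_float32(sample: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``sample`` with its image cast to float32.

    Hypothetical helper: only the "image" entry is converted; the mask
    keeps its integer dtype for use with cross-entropy loss.
    """
    sample = dict(sample)  # shallow copy so the caller's dict is not mutated
    sample["image"] = sample["image"].to(torch.float32)
    return sample


sample = {
    "image": torch.randint(0, 3000, (6, 64, 64), dtype=torch.int64),
    "mask": torch.randint(0, 2, (64, 64), dtype=torch.int64),
}
out = cast_image_to_float32(sample)
print(out["image"].dtype, out["mask"].dtype)  # torch.float32 torch.int64
```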
