
Type error when trying to run trainer.fit with tft #1288

Open
jzicker opened this issue Apr 12, 2023 · 14 comments

Comments

@jzicker

jzicker commented Apr 12, 2023

  • PyTorch-Forecasting version: 1.0
  • PyTorch version: 2.0
  • Python version:
  • Operating System: Google Colab

Expected behavior

I ran trainer.fit. It used to work, but now it raises a TypeError.

Actual behavior

I think it is related to the April 10 release of pytorch-forecasting, which moved to PyTorch 2.0.

Code to reproduce the problem

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-57-127e7bdaac70> in <cell line: 1>()
----> 1 trainer.fit(
      2     tft,
      3     train_dataloaders=train_dataloader,
      4     val_dataloaders=val_dataloader
      5 )

1 frames
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/compile.py in _maybe_unwrap_optimized(model)
    123     if isinstance(model, pl.LightningModule):
    124         return model
--> 125     raise TypeError(
    126         f"`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `{type(model).__qualname__}`"
    127     )

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`
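The traceback boils down to a plain isinstance check. A likely explanation, consistent with the fixes reported further down this thread, is that pytorch-forecasting 1.0 appears to build its models on `lightning.pytorch`, while the Trainer here comes from the standalone `pytorch_lightning` package (note the `pytorch_lightning/utilities/compile.py` path in the traceback). The two packages ship separate `LightningModule` class objects, so the check fails even though the class names match. The pure-Python sketch that follows reproduces the effect with stand-in modules; `make_fake_lightning` and `check_model` are hypothetical names, not part of either library.

```python
import types


def make_fake_lightning(name: str) -> types.ModuleType:
    """Build a stand-in module with its own LightningModule class (illustration only)."""
    module = types.ModuleType(name)

    class LightningModule:  # a fresh class object on every call
        pass

    module.LightningModule = LightningModule
    return module


# Two separately installed packages -> two distinct classes that share a name.
lightning_pytorch = make_fake_lightning("lightning.pytorch")
pytorch_lightning = make_fake_lightning("pytorch_lightning")


# The model library subclasses the LightningModule from lightning.pytorch.
class TemporalFusionTransformer(lightning_pytorch.LightningModule):
    pass


def check_model(model, pl_package):
    """Mimic the shape of the check in the traceback: a plain isinstance test."""
    if isinstance(model, pl_package.LightningModule):
        return model
    raise TypeError(
        f"`model` must be a `LightningModule`, got `{type(model).__qualname__}`"
    )


tft = TemporalFusionTransformer()
check_model(tft, lightning_pytorch)  # passes: same package as the model's base class

try:
    check_model(tft, pytorch_lightning)  # "Trainer" from the other package
except TypeError as err:
    print(err)  # `model` must be a `LightningModule`, got `TemporalFusionTransformer`
```

The fix is therefore not in the model itself but in making sure the Trainer (and Tuner, callbacks, loggers) are imported from the same package the model subclasses.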


@MorrisHsieh3059

Same problem!

@jzicker
Author

jzicker commented Apr 13, 2023

I can't post all the code or share the notebook, but here are the relevant portions.

early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=MIN_DELTA, patience=PATIENCE, verbose=False, mode="min"
)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
lr_logger = LearningRateMonitor()
logger = TensorBoardLogger("lightning_logs")

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    devices="auto",
    enable_model_summary=True,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    callbacks=[lr_logger, early_stop_callback, checkpoint_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=LR,
    hidden_size=HIDDEN_SIZE,
    attention_head_size=ATTENTION_HEAD_SIZE,
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN_CONTINUOUS_SIZE,
    output_size=7,  # there are 7 quantiles by default: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
    loss=QuantileLoss(),
    log_interval=10,
    log_val_interval=10,
    reduce_on_plateau_patience=REDUCE_ON_PLATEAU_PATIENCE,
    lstm_layers=2,
)
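To find out which of the two Lightning packages a model such as the tft above is actually built on, one can walk its method resolution order. A minimal sketch, assuming the base class is named `LightningModule` in both packages; the helper name `lightning_flavor` is hypothetical.

```python
def lightning_flavor(model) -> str:
    """Report which Lightning package the model's LightningModule base comes from.

    Scans the class MRO for a class named LightningModule and maps the module
    it was defined in to an importable package name.
    """
    for cls in type(model).__mro__:
        if cls.__name__ != "LightningModule":
            continue
        module = cls.__module__
        if module == "lightning.pytorch" or module.startswith("lightning.pytorch."):
            return "lightning.pytorch"
        if module == "pytorch_lightning" or module.startswith("pytorch_lightning."):
            return "pytorch_lightning"
        return module  # some other package that also defines a LightningModule
    raise TypeError(f"{type(model).__qualname__} has no LightningModule base class")
```

With a real model, `importlib.import_module(lightning_flavor(tft))` should then give the namespace to take `Trainer` and the callbacks from, so both sides of the isinstance check agree.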

@rmagesh148

Same error!

@pjwu1997

Same error here

@vikolss

vikolss commented Apr 14, 2023

I get the same error with a custom model subclassed from TFT

@shuya-li-wmg

shuya-li-wmg commented Apr 14, 2023

I had the same issue, but changing pytorch_lightning to lightning.pytorch resolved it.

@MariaSky7

I still have this problem, even with lightning.pytorch.

@Smendowski

Smendowski commented Apr 16, 2023

I suggest making the following changes:

import lightning.pytorch as pl # Instead of import pytorch_lightning as pl

# [...]
# Also the way callbacks are imported should be slightly modified.
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor

The usage examples section is fully up to date:
https://github.com/jdb78/pytorch-forecasting
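The point of these import changes is that the Trainer, Tuner, callbacks and loggers all come from the same package as the model. As a rough heuristic (an assumption on my part, not something pytorch-forecasting provides), one can check after the import cell whether both package families ended up loaded in the session; `loaded_lightning_packages` and `warn_if_mixed` are hypothetical helpers.

```python
import sys


def loaded_lightning_packages() -> set:
    """Return the Lightning package families already imported in this process."""
    found = set()
    for name in list(sys.modules):  # copy: imports elsewhere could mutate sys.modules
        if name == "pytorch_lightning" or name.startswith("pytorch_lightning."):
            found.add("pytorch_lightning")
        elif name == "lightning.pytorch" or name.startswith("lightning.pytorch."):
            found.add("lightning.pytorch")
    return found


def warn_if_mixed() -> bool:
    """Print a warning and return True if both families are loaded at once."""
    packages = loaded_lightning_packages()
    if len(packages) > 1:
        print(
            f"Both {sorted(packages)} are imported; make sure Trainer, Tuner, "
            "callbacks and loggers come from the same package as the model."
        )
        return True
    return False
```

A positive result is not proof of the bug (another library in the session could import pytorch_lightning on its own), but it is a quick hint to re-check the imports.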

@jzicker
Author

jzicker commented Apr 16, 2023

@Smendowski Thank you so much! You unblocked me.

@DevSoftChuck

Hey guys! I'm still having the same issue. What am I doing wrong?

import torch
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner

.
.
.

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateMonitor()
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto", 
    gradient_clip_val=0.1,
    limit_train_batches=30, 
    callbacks=[lr_logger, early_stop_callback],
    logger=TensorBoardLogger("lightning_logs")
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    loss=QuantileLoss(),
    log_interval=2,
    learning_rate=0.03,
    reduce_on_plateau_patience=4
)

res = Tuner(trainer).lr_find(
    tft, 
    train_dataloaders=train_dataloader, 
    val_dataloaders=val_dataloader, 
    early_stop_threshold=1000.0, 
    max_lr=10.0,
    min_lr=1e-6,
)

This is the error message:

TypeError: `model` must be a `LightningModule`, got `TemporalFusionTransformer`

@grosestq

grosestq commented Apr 18, 2023

@DevSoftChuck Could your packages (torch, etc.) be out of date?

@MBristle

MBristle commented Apr 18, 2023

I had the same problem. It must be related to the breaking changes in torch 2.0.0 and pytorch-forecasting 1.0.0 (as mentioned by @Smendowski above).

As a temporary workaround, I stayed on pytorch-forecasting==0.10.3 and torch==1.13.1, which works fine.

Upgrading may require further changes. For example, the Trainer's "gpus" argument is no longer valid (https://lightning.ai/docs/pytorch/latest/upgrade/from_1_9.html). So if you upgrade, you will most likely need to adapt parts of your code.
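As an illustration of the "gpus" change: in Lightning 2.0 the hardware is chosen with accelerator="gpu" and the GPU count or indices move to devices. The helper that follows, `migrate_trainer_kwargs`, is hypothetical and only handles gpus; the upgrade guide lists other removed arguments that it leaves untouched.

```python
def migrate_trainer_kwargs(kwargs: dict) -> dict:
    """Rewrite a Lightning 1.x-style `gpus=` argument into 2.0-style arguments.

    Hypothetical helper: only `gpus` is translated, everything else is copied as-is.
    """
    new_kwargs = dict(kwargs)  # don't mutate the caller's dict
    if "gpus" not in new_kwargs:
        return new_kwargs
    gpus = new_kwargs.pop("gpus")
    if gpus is None or gpus == 0:
        # gpus=0 / gpus=None meant "no GPU" in 1.x
        new_kwargs.setdefault("accelerator", "cpu")
    else:
        # e.g. gpus=1 -> devices=1, gpus=[0, 1] -> devices=[0, 1], gpus=-1 -> all GPUs
        new_kwargs.setdefault("accelerator", "gpu")
        new_kwargs.setdefault("devices", gpus)
    return new_kwargs


old = {"max_epochs": 30, "gpus": 1, "gradient_clip_val": 0.1}
print(migrate_trainer_kwargs(old))
# {'max_epochs': 30, 'gradient_clip_val': 0.1, 'accelerator': 'gpu', 'devices': 1}
```

The result can then be passed as `pl.Trainer(**migrate_trainer_kwargs(old))`; setdefault is used so that an explicitly given accelerator or devices value wins over the translated one.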

@DevSoftChuck

@grosestq I was using an outdated version of torch.

pytorch-forecasting==0.10.3
pytorch-lightning==1.9.4
torch==1.13.1

Then I upgraded these packages to the following versions, and that solved the problem:

pytorch-forecasting==1.0.0
pytorch-lightning==2.0.1.post0
torch==2.0.0

Thanks for your help guys!!
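To compare an environment against the versions reported in this thread, a small sketch using the standard library's importlib.metadata. The pins are only the one combination reported as working above, not an official compatibility matrix, and `installed_versions` is a hypothetical helper name. Distribution names (pytorch-forecasting) can differ from import names (pytorch_forecasting), so the lookup also tries the underscore spelling.

```python
from importlib import metadata

# One combination reported as working in this thread (not an official matrix).
REPORTED_WORKING = {
    "pytorch-forecasting": "1.0.0",
    "pytorch-lightning": "2.0.1.post0",
    "torch": "2.0.0",
}


def installed_versions(packages) -> dict:
    """Map each distribution name to its installed version, or None if not found."""
    found = {}
    for name in packages:
        for candidate in (name, name.replace("-", "_")):
            try:
                found[name] = metadata.version(candidate)
                break
            except metadata.PackageNotFoundError:
                continue
        else:  # no spelling matched
            found[name] = None
    return found


if __name__ == "__main__":
    for name, have in installed_versions(REPORTED_WORKING).items():
        want = REPORTED_WORKING[name]
        # Exact string match: local builds such as "2.0.0+cu118" will show as different.
        note = "matches" if have == want else "differs"
        print(f"{name}: installed={have}, reported working={want} ({note})")
```

Calling `installed_versions(["lightning"])` as well shows whether the unified lightning package, which provides lightning.pytorch, is installed alongside pytorch-lightning.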

@chaoss16

chaoss16 commented Apr 26, 2023

Changing all the lightning.pytorch imports to pytorch_lightning works for me, like this:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
