
Validation Data is not generated properly for my dataset #1

Closed · RaviKumarAndroid opened this issue Jul 26, 2020 · 11 comments

RaviKumarAndroid commented Jul 26, 2020

Hi, I really appreciate your work on the TFT.

I am trying to use my own dataset, but there seems to be a bug that prevents the validation data from being generated properly. The train dataloader is fine, but the validation dataloader has only one batch, and the validation TimeSeriesDataSet contains only one entry.

Below is my complete code:

import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet

data = load_csv()  # my own helper that returns a DataFrame
data['date'] = pd.to_datetime(data['date'])
data.reset_index(inplace=True, drop=True)
data.reset_index(inplace=True)
data.rename(columns={'index': 'time_idx'}, inplace=True)  # I use the index as time_idx since my data has minute frequency

validation_len = int(len(data) * 0.1)
training_cutoff = int(len(data)) - validation_len

max_encode_length = 36
max_prediction_length = 6

print('Len of training data is : ',len(data[:training_cutoff]))
print('Len of val data is : ',len(data[training_cutoff:]))
training = TimeSeriesDataSet(
    data[:training_cutoff],
    time_idx="time_idx",
    target="T",
    group_ids=["Symbol"],
    max_encoder_length=max_encode_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["Symbol"],
    static_reals=[],
    time_varying_known_categoricals=[
        "hour_of_day",
        "day_of_week",
    ],
    time_varying_known_reals=[
        "time_idx",
    ],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["V1", "V2","V3", "T", "V4"],
    constant_fill_strategy={"T": 0},
    dropout_categoricals=[],
)
print('Max Prediction Index : ',training.index.time.max())
validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training.index.time.max()+1)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=1)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=1)

print(len(training), len(validation))
print(len(train_dataloader), len(val_dataloader))

This is what the code prints:

Len of training data is : 25920
Len of val data is : 2880
Max Prediction Index :  25919

25920 1
202 1

You can see that the training dataset is fine and its batches are okay, but the validation dataloader has only 1 batch and the validation dataset also has length 1.

One more thing: if I use predict=False, the validation data is generated correctly, but that leads to another bug (below).
If I use predict=True, only 1 batch and 1 sequence are produced.
If I use predict_mode=True on the training dataset, it also generates only 1 batch.

Here is a sample of my CSV
sample_data.csv.zip

Please help.

RaviKumarAndroid (Author) commented Jul 26, 2020

Please help me use the attached dataset with the forecasting library. And even when the batch is okay (with predict_mode=False), another error occurs:

Traceback (most recent call last):
  File "/Users/wilson/PycharmProjects/forecast_test/predictors/temporal_fusion_transformer/main_torch_forecasting.py", line 159, in <module>
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1044, in fit
    results = self.run_pretrain_routine(model)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1196, in run_pretrain_routine
    False)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 272, in _evaluate
    for batch_idx, batch in enumerate(dataloader):
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/wilson/PycharmProjects/forecast_test/pytorch_forecasting/data.py", line 385, in __getitem__
    assert decoder_length >= self.min_prediction_length
AssertionError

jdb78 (Collaborator) commented Jul 26, 2020

The first error is due to an unfortunate default in the from_dataset method. The default was predict=True, which causes the dataset to select only the last sample per timeseries for prediction - as you only have one timeseries, there is only one sample. I have changed the default to predict=False. You might now want to pass stop_randomization=True for the validation dataset.

The second error is a genuine (off-by-one) error, which I have just fixed. Thanks for bringing this to my attention! The fix is pushed to master, so installing from git should do the job.
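
For illustration, a minimal sketch of the updated validation-set construction with these explicit flags (the flags are the ones described above; exact signatures may differ between versions):

# Sketch only, not the exact code from this thread.
# predict=False keeps all possible validation samples instead of only the
# last one per timeseries; stop_randomization=True disables the random
# encoder-length augmentation used during training.
validation = TimeSeriesDataSet.from_dataset(
    training,
    data,
    min_prediction_idx=training.index.time.max() + 1,
    predict=False,
    stop_randomization=True,
)
val_dataloader = validation.to_dataloader(train=False, batch_size=128, num_workers=1)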

RaviKumarAndroid (Author) commented Jul 27, 2020

Thanks, man! I really appreciate your work.

After setting predict=False I got another error, which looks like another off-by-one error:

File "/Users/wilson/PycharmProjects/forecast_test/main_torch_forecasting.py", line 154, in <module>
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1044, in fit
    results = self.run_pretrain_routine(model)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1196, in run_pretrain_routine
    False)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 293, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 470, in evaluation_forward
    output = model.validation_step(*args)
  File "/Users/wilson/PycharmProjects/forecast_test/pytorch_forecasting/models/base_model.py", line 126, in validation_step
    log, _ = self.step(x, y, batch_idx, label="val")
  File "/Users/wilson/PycharmProjects/forecast_test/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py", line 544, in step
    attention_prediction_horizon=0,  # attention only for first prediction horizon
  File "/Users/wilson/PycharmProjects/forecast_test/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py", line 576, in interpret_output
    encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
  File "/Users/wilson/PycharmProjects/forecast_test/pytorch_forecasting/utils.py", line 27, in integer_histogram
    dim=0, index=uniques - min, src=counts
RuntimeError: index 37 is out of bounds for dimension 0 with size 37

Thanks again.

jdb78 (Collaborator) commented Jul 28, 2020

This is probably due to only having a single timeseries in the dataset. I will have a closer look later this week or on the weekend. In the meantime, you could run with log_interval=-1 and skip the interpretation step altogether.
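
A hedged sketch of that workaround (log_interval is the model argument mentioned above; the other hyperparameters are illustrative and not from this thread):

from pytorch_forecasting import TemporalFusionTransformer

# Sketch only: create the TFT with interpretation logging disabled, so the
# interpret_output call that raises the histogram error above is skipped.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,  # illustrative value
    hidden_size=16,      # illustrative value
    log_interval=-1,     # disable logging/interpretation during training
)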

RaviKumarAndroid (Author) commented Jul 28, 2020

Yeah, for now I commented out just the interpretation part, because I would like to be able to see prediction outputs in TensorBoard. But I found an issue in the visualisation: I think you are plotting the non-normalised target against the output (which has not yet been denormalised).

Screenshot 2020-07-29 at 1 24 55 AM

Look at the scale of the graph: the target value is about 300 and the prediction is about 0 point something, so the prediction is not being denormalised.

jdb78 (Collaborator) commented Jul 31, 2020

Sorry for taking a bit of time. The plotting functionality is not broken. However, I think there are two issues here:

  1. The network has probably not been trained for long enough.
  2. To train well, timeseries should be normalized (I think almost all networks require this) for faster convergence and better generalisation.

I am working on this normalization so you do not have to take care of it manually.
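
(For later readers: this eventually landed as a target_normalizer argument on TimeSeriesDataSet. A hedged sketch, assuming a release that includes it:)

from pytorch_forecasting.data import GroupNormalizer

# Sketch only: normalize the target per group so the network trains on a
# sensible scale and predictions are rescaled automatically for plotting.
training = TimeSeriesDataSet(
    data[:training_cutoff],
    time_idx="time_idx",
    target="T",
    group_ids=["Symbol"],
    max_encoder_length=36,
    max_prediction_length=6,
    target_normalizer=GroupNormalizer(groups=["Symbol"]),
    time_varying_unknown_reals=["V1", "V2", "V3", "T", "V4"],
)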

RaviKumarAndroid (Author) commented Aug 1, 2020

Thanks. I am not normalising manually; I read the code and saw that you are already normalising the inputs. But the plot still comes out like the one above.

Actually, when I printed x["encoder_target"][0] and x["decoder_target"][0], I found those values were not normalised:

for x, _ in test_dataloader:
    output = tft(x)
    print(x['encoder_target'][0])
    print(x['decoder_target'][0])
    print(output['prediction'][0])
    break

This is what gets printed:

tensor([309.5000, 309.5000, 309.4500, 309.5000, 309.4500, 309.4500, 309.4000,
        309.4000, 309.3500, 309.3000, 309.2500, 309.3000, 309.4000, 309.4000,
        309.4000, 309.3000, 309.1000, 309.2500, 309.3500, 309.4500, 309.5500,
        309.5500, 309.5000, 309.4500, 309.2500, 309.2500, 309.5500, 309.5000,
        309.4500, 309.6500, 309.6000, 309.6000, 309.5000, 309.5500, 309.5500,
        309.6500, 309.7500, 309.8000, 309.7000, 309.5000, 309.7000])
tensor([310.9000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([[-0.4033, -0.5201,  1.2264, -0.3546, -0.1055,  0.1809, -0.7248],
        [-0.2388,  0.1727, -0.3465, -0.3504,  0.4635,  0.3774, -0.0414],
        [-0.2694,  0.1961, -0.3326, -0.2950,  0.4649,  0.2898, -0.0930],
        [-0.2943,  0.1073, -0.2853, -0.4206,  0.6087,  0.3669, -0.2790],
        [-0.2427, -0.0640, -0.3871, -0.2410,  0.5516,  0.6015, -0.3814],
        [-0.3132,  0.1434, -0.1311, -0.4787,  0.5760,  0.4773, -0.2528]],
       grad_fn=<SelectBackward>)

The scale of the values is wrong here.

I believe this is happening because the target is saved unscaled in the dataframe?

See the batch:

{'encoder_cat': tensor([[[0, 5, 3],
         [0, 5, 3],
         [0, 5, 3],
         ...,
         [0, 6, 3],
         [0, 6, 3],
         [0, 6, 3]],

        [[0, 5, 3],
         [0, 5, 3],
         [0, 5, 3],
         ...,
         [0, 6, 3],
         [0, 6, 3],
         [0, 0, 0]],

        [[0, 5, 3],
         [0, 5, 3],
         [0, 5, 3],
         ...,
         [0, 6, 3],
         [0, 0, 0],
         [0, 0, 0]],

        ...,

        [[0, 1, 4],
         [0, 1, 4],
         [0, 1, 4],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 4],
         [0, 1, 4],
         [0, 1, 4],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 4],
         [0, 1, 4],
         [0, 1, 4],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]]), 'encoder_cont': tensor([[[ 0.9753, -1.1389, -1.3834,  ..., -1.3258, -1.3772, -0.2323],
         [ 0.9759, -1.1111, -1.3834,  ..., -1.3258, -1.3772, -0.3194],
         [ 0.9765, -1.0833, -1.3834,  ..., -1.3721, -1.4003,  0.1851],
         ...,
         [ 0.9988, -1.0000, -1.2444,  ..., -1.2334, -1.2846,  0.8774],
         [ 0.9994, -1.0000, -1.3139,  ..., -1.3721, -1.3772,  3.0574],
         [ 1.0000, -1.0000, -1.3834,  ..., -1.3258, -1.2846,  1.0244]],

        [[ 0.9759, -1.1111, -1.3834,  ..., -1.3258, -1.3772, -0.3194],
         [ 0.9765, -1.0833, -1.3834,  ..., -1.3721, -1.4003,  0.1851],
         [ 0.9772, -1.0556, -1.4066,  ..., -1.3952, -1.3772,  0.1334],
         ...,
         [ 0.9994, -1.0000, -1.3139,  ..., -1.3721, -1.3772,  3.0574],
         [ 1.0000, -1.0000, -1.3834,  ..., -1.3258, -1.2846,  1.0244],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.9765, -1.0833, -1.3834,  ..., -1.3721, -1.4003,  0.1851],
         [ 0.9772, -1.0556, -1.4066,  ..., -1.3952, -1.3772,  0.1334],
         [ 0.9778, -1.0278, -1.3834,  ..., -1.4183, -1.4003,  0.7037],
         ...,
         [ 1.0000, -1.0000, -1.3834,  ..., -1.3258, -1.2846,  1.0244],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        ...,

        [[ 1.0525, -1.0000,  0.1688,  ...,  0.1534,  0.1499, -0.4531],
         [ 1.0531, -1.0000,  0.1457,  ...,  0.1765,  0.1499, -0.4374],
         [ 1.0537, -1.0000,  0.1457,  ...,  0.1303,  0.0805, -0.5453],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 1.0531, -1.0000,  0.1457,  ...,  0.1765,  0.1499, -0.4374],
         [ 1.0537, -1.0000,  0.1457,  ...,  0.1303,  0.0805, -0.5453],
         [ 1.0543, -0.9722,  0.0993,  ...,  0.1303,  0.1499, -0.4982],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 1.0537, -1.0000,  0.1457,  ...,  0.1303,  0.0805, -0.5453],
         [ 1.0543, -0.9722,  0.0993,  ...,  0.1303,  0.1499, -0.4982],
         [ 1.0550, -0.9444,  0.1225,  ...,  0.1534,  0.1036, -0.5765],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]), 'encoder_target': tensor([[309.5000, 309.5000, 309.4500,  ..., 309.7000, 309.5000, 309.7000],
        [309.5000, 309.4500, 309.5000,  ..., 309.5000, 309.7000,   0.0000],
        [309.4500, 309.5000, 309.4500,  ..., 309.7000,   0.0000,   0.0000],
        ...,
        [312.8000, 312.8000, 312.6500,  ...,   0.0000,   0.0000,   0.0000],
        [312.8000, 312.6500, 312.8000,  ...,   0.0000,   0.0000,   0.0000],
        [312.6500, 312.8000, 312.7000,  ...,   0.0000,   0.0000,   0.0000]]), 'encoder_lengths': tensor([41, 40, 39, 38, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36]), 'decoder_cat': tensor([[[0, 0, 4],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 4],
         [0, 0, 4],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 4],
         [0, 0, 4],
         [0, 0, 4],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        ...,

        [[0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4]],

        [[0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4]],

        [[0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4],
         [0, 2, 4]]]), 'decoder_cont': tensor([[[ 1.0006, -1.0000, -1.2907,  ..., -1.2334, -0.7293,  3.8031],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 1.0006, -1.0000, -1.2907,  ..., -1.2334, -0.7293,  3.8031],
         [ 1.0012, -1.0000, -0.7347,  ..., -0.6787, -0.5905,  0.2554],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 1.0006, -1.0000, -1.2907,  ..., -1.2334, -0.7293,  3.8031],
         [ 1.0012, -1.0000, -0.7347,  ..., -0.6787, -0.5905,  0.2554],
         [ 1.0019, -1.0000, -0.5957,  ..., -0.6325, -0.5905,  0.9632],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        ...,

        [[ 1.0747, -0.0556,  0.4237,  ...,  0.4076,  0.4044, -0.5491],
         [ 1.0753, -0.0278,  0.4005,  ...,  0.4307,  0.4044, -0.1958],
         [ 1.0759,  0.0000,  0.3774,  ...,  0.3614,  0.3118, -0.5494],
         [ 1.0766,  0.0278,  0.2847,  ...,  0.3383,  0.3813, -0.4259],
         [ 1.0772,  0.0556,  0.3774,  ...,  0.4307,  0.3813, -0.6217],
         [ 1.0778,  0.0833,  0.4237,  ...,  0.4076,  0.3581, -0.7255]],

        [[ 1.0753, -0.0278,  0.4005,  ...,  0.4307,  0.4044, -0.1958],
         [ 1.0759,  0.0000,  0.3774,  ...,  0.3614,  0.3118, -0.5494],
         [ 1.0766,  0.0278,  0.2847,  ...,  0.3383,  0.3813, -0.4259],
         [ 1.0772,  0.0556,  0.3774,  ...,  0.4307,  0.3813, -0.6217],
         [ 1.0778,  0.0833,  0.4237,  ...,  0.4076,  0.3581, -0.7255],
         [ 1.0784,  0.1111,  0.4005,  ...,  0.3614,  0.3581, -0.4660]],

        [[ 1.0759,  0.0000,  0.3774,  ...,  0.3614,  0.3118, -0.5494],
         [ 1.0766,  0.0278,  0.2847,  ...,  0.3383,  0.3813, -0.4259],
         [ 1.0772,  0.0556,  0.3774,  ...,  0.4307,  0.3813, -0.6217],
         [ 1.0778,  0.0833,  0.4237,  ...,  0.4076,  0.3581, -0.7255],
         [ 1.0784,  0.1111,  0.4005,  ...,  0.3614,  0.3581, -0.4660],
         [ 1.0790,  0.1389,  0.3542,  ...,  0.4076,  0.3813, -0.4521]]]), 'decoder_target': tensor([[310.9000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000],
        [310.9000, 311.2000,   0.0000,   0.0000,   0.0000,   0.0000],
        [310.9000, 311.2000, 311.2000,   0.0000,   0.0000,   0.0000],
        [310.9000, 311.2000, 311.2000, 311.4000,   0.0000,   0.0000],
        [310.9000, 311.2000, 311.2000, 311.4000, 311.7500,   0.0000],
        [310.9000, 311.2000, 311.2000, 311.4000, 311.7500, 311.7500],
        [311.2000, 311.2000, 311.4000, 311.7500, 311.7500, 311.9000],
        [311.2000, 311.4000, 311.7500, 311.7500, 311.9000, 311.9000],
        [311.4000, 311.7500, 311.7500, 311.9000, 311.9000, 312.1000],
        [311.7500, 311.7500, 311.9000, 311.9000, 312.1000, 312.3000],
        [311.7500, 311.9000, 311.9000, 312.1000, 312.3000, 312.3000],
        [311.9000, 311.9000, 312.1000, 312.3000, 312.3000, 312.1500],
        [311.9000, 312.1000, 312.3000, 312.3000, 312.1500, 311.8500],
        [312.1000, 312.3000, 312.3000, 312.1500, 311.8500, 312.0500],
        [312.3000, 312.3000, 312.1500, 311.8500, 312.0500, 311.8500],
        [312.3000, 312.1500, 311.8500, 312.0500, 311.8500, 312.0500],
        [312.1500, 311.8500, 312.0500, 311.8500, 312.0500, 312.0000],
        [311.8500, 312.0500, 311.8500, 312.0500, 312.0000, 312.7500],
        [312.0500, 311.8500, 312.0500, 312.0000, 312.7500, 312.7500],
        [311.8500, 312.0500, 312.0000, 312.7500, 312.7500, 312.6500],
        [312.0500, 312.0000, 312.7500, 312.7500, 312.6500, 312.6000],
        [312.0000, 312.7500, 312.7500, 312.6500, 312.6000, 312.5000],
        [312.7500, 312.7500, 312.6500, 312.6000, 312.5000, 312.3500],
        [312.7500, 312.6500, 312.6000, 312.5000, 312.3500, 312.5000],
        [312.6500, 312.6000, 312.5000, 312.3500, 312.5000, 312.4000],
        [312.6000, 312.5000, 312.3500, 312.5000, 312.4000, 312.4500],
        [312.5000, 312.3500, 312.5000, 312.4000, 312.4500, 312.3500],
        [312.3500, 312.5000, 312.4000, 312.4500, 312.3500, 312.4000],
        [312.5000, 312.4000, 312.4500, 312.3500, 312.4000, 312.6000],
        [312.4000, 312.4500, 312.3500, 312.4000, 312.6000, 312.5500],
        [312.4500, 312.3500, 312.4000, 312.6000, 312.5500, 312.7000],
        [312.3500, 312.4000, 312.6000, 312.5500, 312.7000, 312.7000],
        [312.4000, 312.6000, 312.5500, 312.7000, 312.7000, 312.6500],
        [312.6000, 312.5500, 312.7000, 312.7000, 312.6500, 312.6000],
        [312.5500, 312.7000, 312.7000, 312.6500, 312.6000, 312.4500],
        [312.7000, 312.7000, 312.6500, 312.6000, 312.4500, 312.5500],
        [312.7000, 312.6500, 312.6000, 312.4500, 312.5500, 312.6500],
        [312.6500, 312.6000, 312.4500, 312.5500, 312.6500, 312.6500],
        [312.6000, 312.4500, 312.5500, 312.6500, 312.6500, 312.6000],
        [312.4500, 312.5500, 312.6500, 312.6500, 312.6000, 312.5000],
        [312.5500, 312.6500, 312.6500, 312.6000, 312.5000, 312.3500],
        [312.6500, 312.6500, 312.6000, 312.5000, 312.3500, 312.5000],
        [312.6500, 312.6000, 312.5000, 312.3500, 312.5000, 312.8000],
        [312.6000, 312.5000, 312.3500, 312.5000, 312.8000, 312.9000],
        [312.5000, 312.3500, 312.5000, 312.8000, 312.9000, 313.2500],
        [312.3500, 312.5000, 312.8000, 312.9000, 313.2500, 313.3000],
        [312.5000, 312.8000, 312.9000, 313.2500, 313.3000, 313.6000],
        [312.8000, 312.9000, 313.2500, 313.3000, 313.6000, 314.1000],
        [312.9000, 313.2500, 313.3000, 313.6000, 314.1000, 314.1000],
        [313.2500, 313.3000, 313.6000, 314.1000, 314.1000, 313.7500],
        [313.3000, 313.6000, 314.1000, 314.1000, 313.7500, 313.7000],
        [313.6000, 314.1000, 314.1000, 313.7500, 313.7000, 313.7500],
        [314.1000, 314.1000, 313.7500, 313.7000, 313.7500, 313.6500],
        [314.1000, 313.7500, 313.7000, 313.7500, 313.6500, 313.6000],
        [313.7500, 313.7000, 313.7500, 313.6500, 313.6000, 313.4000],
        [313.7000, 313.7500, 313.6500, 313.6000, 313.4000, 313.3000],
        [313.7500, 313.6500, 313.6000, 313.4000, 313.3000, 313.3000],
        [313.6500, 313.6000, 313.4000, 313.3000, 313.3000, 313.1000],
        [313.6000, 313.4000, 313.3000, 313.3000, 313.1000, 313.0000],
        [313.4000, 313.3000, 313.3000, 313.1000, 313.0000, 313.3000],
        [313.3000, 313.3000, 313.1000, 313.0000, 313.3000, 313.3500],
        [313.3000, 313.1000, 313.0000, 313.3000, 313.3500, 313.1500],
        [313.1000, 313.0000, 313.3000, 313.3500, 313.1500, 313.0000],
        [313.0000, 313.3000, 313.3500, 313.1500, 313.0000, 313.0000],
        [313.3000, 313.3500, 313.1500, 313.0000, 313.0000, 312.9500],
        [313.3500, 313.1500, 313.0000, 313.0000, 312.9500, 312.9000],
        [313.1500, 313.0000, 313.0000, 312.9500, 312.9000, 312.7000],
        [313.0000, 313.0000, 312.9500, 312.9000, 312.7000, 312.9500],
        [313.0000, 312.9500, 312.9000, 312.7000, 312.9500, 312.9000],
        [312.9500, 312.9000, 312.7000, 312.9500, 312.9000, 312.9000],
        [312.9000, 312.7000, 312.9500, 312.9000, 312.9000, 312.8500],
        [312.7000, 312.9500, 312.9000, 312.9000, 312.8500, 312.7500],
        [312.9500, 312.9000, 312.9000, 312.8500, 312.7500, 312.8500],
        [312.9000, 312.9000, 312.8500, 312.7500, 312.8500, 312.8500],
        [312.9000, 312.8500, 312.7500, 312.8500, 312.8500, 312.8000],
        [312.8500, 312.7500, 312.8500, 312.8500, 312.8000, 312.8000],
        [312.7500, 312.8500, 312.8500, 312.8000, 312.8000, 312.7000],
        [312.8500, 312.8500, 312.8000, 312.8000, 312.7000, 312.6000],
        [312.8500, 312.8000, 312.8000, 312.7000, 312.6000, 312.7000],
        [312.8000, 312.8000, 312.7000, 312.6000, 312.7000, 312.6000],
        [312.8000, 312.7000, 312.6000, 312.7000, 312.6000, 312.6500],
        [312.7000, 312.6000, 312.7000, 312.6000, 312.6500, 312.7500],
        [312.6000, 312.7000, 312.6000, 312.6500, 312.7500, 312.6000],
        [312.7000, 312.6000, 312.6500, 312.7500, 312.6000, 312.8500],
        [312.6000, 312.6500, 312.7500, 312.6000, 312.8500, 312.8000],
        [312.6500, 312.7500, 312.6000, 312.8500, 312.8000, 312.8000],
        [312.7500, 312.6000, 312.8500, 312.8000, 312.8000, 312.6500],
        [312.6000, 312.8500, 312.8000, 312.8000, 312.6500, 312.8000],
        [312.8500, 312.8000, 312.8000, 312.6500, 312.8000, 312.7000],
        [312.8000, 312.8000, 312.6500, 312.8000, 312.7000, 312.6500],
        [312.8000, 312.6500, 312.8000, 312.7000, 312.6500, 312.7500],
        [312.6500, 312.8000, 312.7000, 312.6500, 312.7500, 312.9000],
        [312.8000, 312.7000, 312.6500, 312.7500, 312.9000, 312.8500],
        [312.7000, 312.6500, 312.7500, 312.9000, 312.8500, 312.8000],
        [312.6500, 312.7500, 312.9000, 312.8500, 312.8000, 312.8500],
        [312.7500, 312.9000, 312.8500, 312.8000, 312.8500, 312.7500],
        [312.9000, 312.8500, 312.8000, 312.8500, 312.7500, 312.8000],
        [312.8500, 312.8000, 312.8500, 312.7500, 312.8000, 312.9000],
        [312.8000, 312.8500, 312.7500, 312.8000, 312.9000, 312.9500],
        [312.8500, 312.7500, 312.8000, 312.9000, 312.9500, 312.9500],
        [312.7500, 312.8000, 312.9000, 312.9500, 312.9500, 313.0000],
        [312.8000, 312.9000, 312.9500, 312.9500, 313.0000, 312.9500],
        [312.9000, 312.9500, 312.9500, 313.0000, 312.9500, 312.9000],
        [312.9500, 312.9500, 313.0000, 312.9500, 312.9000, 312.7500],
        [312.9500, 313.0000, 312.9500, 312.9000, 312.7500, 313.0000],
        [313.0000, 312.9500, 312.9000, 312.7500, 313.0000, 313.4000],
        [312.9500, 312.9000, 312.7500, 313.0000, 313.4000, 313.1500],
        [312.9000, 312.7500, 313.0000, 313.4000, 313.1500, 313.1500],
        [312.7500, 313.0000, 313.4000, 313.1500, 313.1500, 313.0500],
        [313.0000, 313.4000, 313.1500, 313.1500, 313.0500, 313.0500],
        [313.4000, 313.1500, 313.1500, 313.0500, 313.0500, 313.1500],
        [313.1500, 313.1500, 313.0500, 313.0500, 313.1500, 313.0000],
        [313.1500, 313.0500, 313.0500, 313.1500, 313.0000, 313.0500],
        [313.0500, 313.0500, 313.1500, 313.0000, 313.0500, 313.1000],
        [313.0500, 313.1500, 313.0000, 313.0500, 313.1000, 313.0500],
        [313.1500, 313.0000, 313.0500, 313.1000, 313.0500, 313.0000],
        [313.0000, 313.0500, 313.1000, 313.0500, 313.0000, 313.0500],
        [313.0500, 313.1000, 313.0500, 313.0000, 313.0500, 313.5500],
        [313.1000, 313.0500, 313.0000, 313.0500, 313.5500, 313.4500],
        [313.0500, 313.0000, 313.0500, 313.5500, 313.4500, 313.3000],
        [313.0000, 313.0500, 313.5500, 313.4500, 313.3000, 313.3500],
        [313.0500, 313.5500, 313.4500, 313.3000, 313.3500, 313.3500],
        [313.5500, 313.4500, 313.3000, 313.3500, 313.3500, 313.1500],
        [313.4500, 313.3000, 313.3500, 313.3500, 313.1500, 313.3000],
        [313.3000, 313.3500, 313.3500, 313.1500, 313.3000, 313.3000],
        [313.3500, 313.3500, 313.1500, 313.3000, 313.3000, 313.2500],
        [313.3500, 313.1500, 313.3000, 313.3000, 313.2500, 313.2500],
        [313.1500, 313.3000, 313.3000, 313.2500, 313.2500, 313.3000]]), 'decoder_lengths': tensor([1, 2, 3, 4, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6])}
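
As a rough, purely illustrative sanity check (not a library API): rescale the normalized prediction with the encoder window's own mean and standard deviation and see whether it lands back near the ~310 target scale.

# Hypothetical check, assuming the target is roughly standard-normalized:
# undo the normalization with statistics from the encoder window.
enc = x["encoder_target"][0]
mean, std = enc.mean(), enc.std()
print(output["prediction"][0] * std + mean)  # should land near ~310 if rescaled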

RaviKumarAndroid (Author) commented

Also, please try to fix this issue: #1 (comment)

jdb78 (Collaborator) commented Aug 16, 2020

Give it another try. I fixed a couple of issues, including target normalization. Keep in mind that the strength of the algorithm is forecasting multiple timeseries.

jdb78 (Collaborator) commented Aug 23, 2020

In 0.2.1 I fixed some issues with GPU support.

jdb78 (Collaborator) commented Aug 25, 2020

Closing this, as there have been numerous fixes to the package - thanks again for reporting this issue! Please feel encouraged to open a new issue in case you encounter any further bugs.
