In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor

In [3]:
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet, GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

  from tqdm.autonotebook import tqdm


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
torch.set_float32_matmul_precision('medium')

In [6]:
# Define paths
data_path = "/home/naradaw/code/GCNTFT/data/processed/data_w_geo_v3.csv"
embeddings_path = "/home/naradaw/code/GCNTFT/data/embeddings_v2_lap_202503312035/tft_ready_embeddings.csv"

# Load the air quality data
air_quality_df = pd.read_csv(data_path)
air_quality_df['datetime'] = pd.to_datetime(air_quality_df['datetime'])

# Load the embeddings
embeddings_df = pd.read_csv(embeddings_path, index_col=0)


In [7]:

# Convert the datetime to proper format for TimeSeriesDataSet
air_quality_df['time_idx'] = (air_quality_df['datetime'] - air_quality_df['datetime'].min()).dt.total_seconds() // 3600
air_quality_df['time_idx'] = air_quality_df['time_idx'].astype(int)
air_quality_df = air_quality_df.sort_values(['station_loc', 'time_idx'])

# Add a group_id column (required for pytorch-forecasting)
station_ids = air_quality_df['station_loc'].unique()
station_mapping = {station: idx for idx, station in enumerate(station_ids)}
air_quality_df['group_id'] = air_quality_df['station_loc'].map(station_mapping)

necessary_columns = ['datetime', 'time_idx', 'PM2.5 (ug/m3)', 'latitude', 'longitude', 'station_loc', 'group_id']
air_quality_df = air_quality_df[necessary_columns]
air_quality_df.rename(columns={'PM2.5 (ug/m3)': 'PM25'}, inplace=True)
# air_quality_df['station_loc'] = air_quality_df['station_loc'].astype('category')

In [8]:
air_quality_df.head()

Unnamed: 0,datetime,time_idx,PM25,latitude,longitude,station_loc,group_id
0,2022-03-31 23:00:00,0,122.0,28.797312,77.138667,"Alipur, Delhi",0
1,2022-04-01 00:00:00,1,85.0,28.797312,77.138667,"Alipur, Delhi",0
2,2022-04-01 01:00:00,2,85.0,28.797312,77.138667,"Alipur, Delhi",0
3,2022-04-01 02:00:00,3,81.0,28.797312,77.138667,"Alipur, Delhi",0
4,2022-04-01 03:00:00,4,68.0,28.797312,77.138667,"Alipur, Delhi",0


In [9]:
air_quality_df['time_idx'].value_counts()

time_idx
8760    37
0       37
1       37
2       37
3       37
        ..
13      37
12      37
11      37
10      37
9       37
Name: count, Length: 8761, dtype: int64

In [None]:
# # Merge embeddings with the air quality data
# # First, create a column for the embedding index
# max_time_idx = air_quality_df['time_idx'].max()
# window_size = 24  # assuming 24-hour window size used in GNN

# # We need to match embeddings to the right time points
# processed_data = []

# for station in station_ids:
#     station_data = air_quality_df[air_quality_df['station_loc'] == station]
    
#     # For each time point with sufficient history
#     for t in range(window_size, int(max_time_idx) + 1):
#         if t in station_data['time_idx'].values:
#             curr_data = station_data[station_data['time_idx'] == t].iloc[0].to_dict()
            
#             # Find the embedding for this station at this time point
#             station_idx = station_mapping[station]
#             embedding_idx = (t - window_size) * len(station_ids) + station_idx
            
#             if embedding_idx < len(embeddings_df):
#                 # Add embedding features
#                 for i in range(1, embeddings_df.shape[1]):
#                     curr_data[f'embedding_{i}'] = embeddings_df.iloc[embedding_idx, i]
                
#                 processed_data.append(curr_data)

# # Create new dataframe with embeddings
# combined_df = pd.DataFrame(processed_data)

In [None]:
# combined_df.to_csv("/home/naradaw/code/GCNTFT/data/processed/data_w_geo_v4.csv", index=False)
combined_df = pd.read_csv("/home/naradaw/code/GCNTFT/data/processed/data_w_geo_v4.csv")

In [13]:
combined_df.head()

Unnamed: 0,datetime,time_idx,PM25,latitude,longitude,station_loc,group_id,embedding_1,embedding_2,embedding_3,...,embedding_55,embedding_56,embedding_57,embedding_58,embedding_59,embedding_60,embedding_61,embedding_62,embedding_63,embedding_64
0,2022-04-01 23:00:00,24,174.0,28.797312,77.138667,"Alipur, Delhi",0,0.98076,0.781723,0.089967,...,0.81844,1.003087,-1.53129,-0.277108,-0.400032,0.328314,1.042713,2.067984,1.964328,-1.641902
1,2022-04-02 00:00:00,25,160.0,28.797312,77.138667,"Alipur, Delhi",0,1.045634,0.744446,0.17416,...,0.736366,0.928042,-1.570557,-0.240158,-0.42621,0.207121,1.044672,1.831818,1.962974,-1.470275
2,2022-04-02 01:00:00,26,139.0,28.797312,77.138667,"Alipur, Delhi",0,1.185331,0.842519,0.276737,...,0.786741,1.080483,-1.695227,-0.291337,-0.514872,0.125565,1.194915,1.753948,2.026214,-1.416072
3,2022-04-02 02:00:00,27,156.0,28.797312,77.138667,"Alipur, Delhi",0,1.326631,0.923632,0.321435,...,0.840739,1.282769,-1.822931,-0.407805,-0.556706,0.075545,1.305144,1.692214,2.009205,-1.434708
4,2022-04-02 03:00:00,28,157.0,28.797312,77.138667,"Alipur, Delhi",0,1.429755,0.932156,0.28359,...,0.762491,1.442178,-1.737819,-0.491285,-0.474537,0.052086,1.235436,1.56024,1.808958,-1.356754


In [14]:
# Define prediction parameters
max_prediction_length = 24  # predict 24 hours into the future
max_encoder_length = 72     # use 72 hours of history

# Create training dataset
training_cutoff = combined_df["time_idx"].max() - max_prediction_length

In [15]:
training_cutoff

np.int64(8736)

In [13]:
combined_df['station_loc'] = combined_df['station_loc'].astype('category')
# combined_df.dtypes

In [14]:
station_groups = np.array_split(station_ids, 5)

In [15]:
station_groups

[array(['Alipur, Delhi ', 'Anand Vihar, Delhi ', 'Ashok Vihar, Delhi ',
        'Aya Nagar, Delhi ', 'Bawana, Delhi ', 'Burari Crossing, Delhi ',
        'CRRI Mathura Road, Delhi ', 'Chandni Chowk, Delhi '], dtype=object),
 array(['DTU, Delhi ', 'Dr. Karni Singh Shooting Range, Delhi ',
        'Dwarka Sector 8, Delhi ', 'East Arjun Nagar, Delhi ',
        'IGI Airport (T3), Delhi ', 'IHBAS, Dilshad Garden, Delhi ',
        'ITO, Delhi ', 'Jahangirpuri, Delhi '], dtype=object),
 array(['Jawaharlal Nehru Stadium, Delhi ', 'Lodhi Road, Delhi ',
        'Major Dhyan Chand National Stadium, Delhi ',
        'Mandir Marg, Delhi ', 'Mundka, Delhi ', 'NSIT Dwarka, Delhi ',
        'Najafgarh, Delhi '], dtype=object),
 array(['Narela, Delhi ', 'Nehru Nagar, Delhi ',
        'North Campus, DU, Delhi ', 'Okhla Phase 2, Delhi ',
        'Patparganj, Delhi ', 'Punjabi Bagh, Delhi ', 'Pusa, Delhi '],
       dtype=object),
 array(['R K Puram, Delhi ', 'Rohini, Delhi ', 'Shadipur, Delhi ',
        '

In [16]:
# Prepare the dataset
tft_dataset = TimeSeriesDataSet(
    data=combined_df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="PM25",
    group_ids=["group_id"],
    min_encoder_length=max_encoder_length // 2,  # allow for smaller encoder lengths
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["station_loc"],
    static_reals=["latitude", "longitude"],
    time_varying_known_categoricals=[],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "PM25",
    ] + [f'embedding_{i}' for i in range(1, embeddings_df.shape[1])],
    target_normalizer=GroupNormalizer(
        groups=["group_id"], transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

In [17]:
batch_size = 64
train_dataloader = tft_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=4)
val_dataloader = tft_dataset.to_dataloader(train=False, batch_size=batch_size, num_workers=4)

In [18]:
# Define the Temporal Fusion Transformer model
tft = TemporalFusionTransformer.from_dataset(
    tft_dataset,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

tft.to(device)

/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


TemporalFusionTransformer(
  	"attention_head_size":               1
  	"categorical_groups":                {}
  	"causal_attention":                  True
  	"dataset_parameters":                {'time_idx': 'time_idx', 'target': 'PM25', 'group_ids': ['group_id'], 'weight': None, 'max_encoder_length': 72, 'min_encoder_length': 36, 'min_prediction_idx': 24, 'min_prediction_length': 1, 'max_prediction_length': 24, 'static_categoricals': ['station_loc'], 'static_reals': ['latitude', 'longitude'], 'time_varying_known_categoricals': [], 'time_varying_known_reals': ['time_idx'], 'time_varying_unknown_categoricals': [], 'time_varying_unknown_reals': ['PM25', 'embedding_1', 'embedding_2', 'embedding_3', 'embedding_4', 'embedding_5', 'embedding_6', 'embedding_7', 'embedding_8', 'embedding_9', 'embedding_10', 'embedding_11', 'embedding_12', 'embedding_13', 'embedding_14', 'embedding_15', 'embedding_16', 'embedding_17', 'embedding_18', 'embedding_19', 'embedding_20', 'embedding_21', 'embedding_

In [19]:
# Verify that the model is a LightningModule
print(f"Model is LightningModule: {isinstance(tft, pl.LightningModule)}")

Model is LightningModule: True


In [None]:
default_root_dir = os.path.join(os.getcwd(), "lightning_logs")

In [20]:
# Configure trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_monitor = LearningRateMonitor(logging_interval="epoch")

# Create directory for checkpoints if it doesn't exist
import os
os.makedirs("models", exist_ok=True)

# Updated Trainer initialization for newer PyTorch Lightning versions
trainer = pl.Trainer(
    max_epochs=30,
    devices = 1, 
    accelerator="gpu", 
    precision= 32,
    gradient_clip_val=0.1,
    # limit_train_batches=50,
    callbacks=[lr_monitor, early_stop_callback],
    enable_checkpointing=True,
    default_root_dir="models/sharded"  # directory to save checkpoints
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [21]:
for epoch in range(5):  # 5 epochs total
    for i, stations_shard in enumerate(station_groups):
        print(f"Epoch {epoch}, Shard {i+1}/{len(station_groups)}")
        
        # Filter data for current stations
        shard_filter = combined_df['station_loc'].isin(stations_shard)
        shard_df = combined_df[shard_filter]
        
        # Create dataset for this shard
        shard_dataset = TimeSeriesDataSet(
            data=shard_df[lambda x: x.time_idx <= training_cutoff],
            time_idx="time_idx",
            target="PM25",
            group_ids=["group_id"],
            min_encoder_length=max_encoder_length // 2,
            max_encoder_length=max_encoder_length,
            min_prediction_length=1,
            max_prediction_length=max_prediction_length,
            static_categoricals=["station_loc"],
            static_reals=["latitude", "longitude"],
            time_varying_known_categoricals=[],
            time_varying_known_reals=["time_idx"],
            time_varying_unknown_categoricals=[],
            time_varying_unknown_reals=[
                "PM25",
            ] + [f'embedding_{i}' for i in range(1, embeddings_df.shape[1])],
            target_normalizer=GroupNormalizer(
                groups=["group_id"], transformation="softplus"
            ),
            add_relative_time_idx=True,
            add_target_scales=True,
            add_encoder_length=True,
        )
        
        # Create dataloader
        shard_train_dataloader = shard_dataset.to_dataloader(
            train=True, 
            batch_size=batch_size, 
            num_workers=4,
            pin_memory=True
        )
        
        # Train for one epoch on this shard

        # try :
        #     trainer.fit(
        #         tft,
        #         train_dataloaders=shard_train_dataloader,
        #         val_dataloaders=val_dataloader,
        #         ckpt_path=None
        #     )

        # except Exception as e:

        trainer.fit(
            tft,
            train_dataloaders=shard_train_dataloader,
            val_dataloaders=val_dataloader,
            ckpt_path="last"  # Continue from previous shard
        )
        
        # Clean up memory
        torch.cuda.empty_cache()
        del shard_dataset, shard_train_dataloader
        
    # Save checkpoint at end of each full epoch
    trainer.save_checkpoint(f"models/sharded/tft_air_quality_epoch_{epoch}.ckpt")

Epoch 0, Shard 1/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 0, Shard 2/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 0, Shard 3/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 0, Shard 4/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 0, Shard 5/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 1, Shard 1/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 1, Shard 2/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 1, Shard 3/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 1, Shard 4/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 1, Shard 5/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 2, Shard 1/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 2, Shard 2/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 2, Shard 3/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 2, Shard 4/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 2, Shard 5/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 3, Shard 1/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 3, Shard 2/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 3, Shard 3/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 3, Shard 4/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 3, Shard 5/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 4, Shard 1/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 4, Shard 2/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 4, Shard 3/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 4, Shard 4/5


/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory models/sharded/lightning_logs/version_1/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 4, Shard 5/5



   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 444    | train
3  | prescalers                         | ModuleDict                      | 1.3 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 9.9 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 86.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 3.8 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 4.3 K  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [22]:
import sys 
sys.exit()

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [23]:
# # Train the model
# trainer.fit(
#     tft,
#     train_dataloaders=train_dataloader,
#     val_dataloaders=val_dataloader,
# )

In [24]:
# # Save the trained model
# trainer.save_checkpoint("models/tft_air_quality_forecast_v1.ckpt")

In [25]:
# Get the last available data point for each station
last_data = combined_df.groupby('station_loc').apply(lambda x: x.iloc[-max_encoder_length:]).reset_index(drop=True)

# Create a prediction dataset
pred_dataset = TimeSeriesDataSet(
    data=last_data,
    time_idx="time_idx",
    target="PM25",
    group_ids=["group_id"],
    min_encoder_length=max_encoder_length // 2,  # allow for smaller encoder lengths
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["station_loc"],
    static_reals=["latitude", "longitude"],
    time_varying_known_categoricals=[],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "PM25",
    ] + [f'embedding_{i}' for i in range(1, embeddings_df.shape[1])],
    target_normalizer=GroupNormalizer(
        groups=["group_id"], transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

  last_data = combined_df.groupby('station_loc').apply(lambda x: x.iloc[-max_encoder_length:]).reset_index(drop=True)
  last_data = combined_df.groupby('station_loc').apply(lambda x: x.iloc[-max_encoder_length:]).reset_index(drop=True)


In [26]:
# Make predictions
predictions = tft.predict(pred_dataset)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/naradalinux/miniconda3/envs/graph-tft/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


In [27]:
output_loc = f"/home/naradalinux/dev/GCNTFT/outputs/images/{pd.Timestamp.now().strftime('%m%d%H%M')}"

if not os.path.exists(output_loc):
    os.makedirs(output_loc)

# Plot predictions for each station
for station in station_ids:
    station_idx = station_mapping[station]
    station_preds = predictions[station_idx].detach().cpu().numpy()
    
    plt.figure(figsize=(10, 6))
    plt.plot(range(max_prediction_length), station_preds, label='Prediction')
    
    # Get historical data for comparison
    historical = combined_df[combined_df['station_loc'] == station].tail(max_encoder_length)['PM25'].values
    plt.plot(range(-len(historical), 0), historical, label='Historical')
    
    plt.axvline(x=0, linestyle='--', color='gray')
    plt.title(f'24-Hour Air Quality Forecast for Station {station}')
    plt.xlabel('Hours')
    plt.ylabel('PM25')
    plt.legend()
    plt.grid(True)
    
    # Save the plot
    
    plt.savefig(os.path.join(output_loc, f'station_{station}_forecast.png'))
    plt.close()

print("Forecasting completed. Plots saved to outputs/images/")

Forecasting completed. Plots saved to outputs/images/
