In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor

In [3]:
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet, GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

  from tqdm.autonotebook import tqdm


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
torch.set_float32_matmul_precision('medium')

In [6]:
# Define paths
data_path = "/home/naradaw/code/GCNTFT/data/processed/data_w_geo_v3.csv"
embeddings_path = "/home/naradaw/code/GCNTFT/data/embeddings_v2_lap_202503312035/tft_ready_embeddings.csv"

# Load the air quality data
air_quality_df = pd.read_csv(data_path)
air_quality_df['datetime'] = pd.to_datetime(air_quality_df['datetime'])

# Load the embeddings
embeddings_df = pd.read_csv(embeddings_path, index_col=0)


In [7]:

# Convert the datetime to proper format for TimeSeriesDataSet
air_quality_df['time_idx'] = (air_quality_df['datetime'] - air_quality_df['datetime'].min()).dt.total_seconds() // 3600
air_quality_df['time_idx'] = air_quality_df['time_idx'].astype(int)
air_quality_df = air_quality_df.sort_values(['station_loc', 'time_idx'])


In [8]:
import re

air_quality_df['station_loc'] = air_quality_df['station_loc'].apply(lambda x: x.split(",")[0].strip())
# Display the unique station names
unique_stations = air_quality_df['station_loc'].unique()
print(f"Number of unique stations: {len(unique_stations)}")
print("Sample station names:")
print(unique_stations[:5])

# Use regex to detect and remove parentheses from station names
air_quality_df['station_loc'] = air_quality_df['station_loc'].apply(lambda x: re.sub(r'\([^)]*\)', '', x).strip())

# Check if any changes were made
print("\nAfter removing parentheses:")
print(air_quality_df['station_loc'].unique()[:5])

# Count unique station names after cleaning
print(f"Number of unique stations after cleaning: {len(air_quality_df['station_loc'].unique())}")

Number of unique stations: 37
Sample station names:
['Alipur' 'Anand Vihar' 'Ashok Vihar' 'Aya Nagar' 'Bawana']

After removing parentheses:
['Alipur' 'Anand Vihar' 'Ashok Vihar' 'Aya Nagar' 'Bawana']
Number of unique stations after cleaning: 37


In [9]:

# Add a group_id column (required for pytorch-forecasting)
station_ids = air_quality_df['station_loc'].unique()
station_mapping = {station: idx for idx, station in enumerate(station_ids)}
air_quality_df['group_id'] = air_quality_df['station_loc'].map(station_mapping)

necessary_columns = ['datetime', 'time_idx', 'PM2.5 (ug/m3)', 'latitude', 'longitude', 'station_loc', 'group_id']
air_quality_df = air_quality_df[necessary_columns]
air_quality_df.rename(columns={'PM2.5 (ug/m3)': 'PM25'}, inplace=True)
# air_quality_df['station_loc'] = air_quality_df['station_loc'].astype('category')

In [10]:
air_quality_df.head()

Unnamed: 0,datetime,time_idx,PM25,latitude,longitude,station_loc,group_id
0,2022-03-31 23:00:00,0,122.0,28.797312,77.138667,Alipur,0
1,2022-04-01 00:00:00,1,85.0,28.797312,77.138667,Alipur,0
2,2022-04-01 01:00:00,2,85.0,28.797312,77.138667,Alipur,0
3,2022-04-01 02:00:00,3,81.0,28.797312,77.138667,Alipur,0
4,2022-04-01 03:00:00,4,68.0,28.797312,77.138667,Alipur,0


In [11]:
# Add time features
air_quality_df["hour"] = pd.to_datetime(air_quality_df["datetime"]).dt.hour.astype("category")
air_quality_df["day_of_week"] = pd.to_datetime(air_quality_df["datetime"]).dt.dayofweek.astype("category")
air_quality_df["month"] = pd.to_datetime(air_quality_df["datetime"]).dt.month.astype("category")

In [12]:
air_quality_df.dtypes

datetime       datetime64[ns]
time_idx                int64
PM25                  float64
latitude              float64
longitude             float64
station_loc            object
group_id                int64
hour                 category
day_of_week          category
month                category
dtype: object

In [13]:
air_quality_df.head()

Unnamed: 0,datetime,time_idx,PM25,latitude,longitude,station_loc,group_id,hour,day_of_week,month
0,2022-03-31 23:00:00,0,122.0,28.797312,77.138667,Alipur,0,23,3,3
1,2022-04-01 00:00:00,1,85.0,28.797312,77.138667,Alipur,0,0,4,4
2,2022-04-01 01:00:00,2,85.0,28.797312,77.138667,Alipur,0,1,4,4
3,2022-04-01 02:00:00,3,81.0,28.797312,77.138667,Alipur,0,2,4,4
4,2022-04-01 03:00:00,4,68.0,28.797312,77.138667,Alipur,0,3,4,4


In [14]:
air_quality_df['station_loc'].unique()

array(['Alipur', 'Anand Vihar', 'Ashok Vihar', 'Aya Nagar', 'Bawana',
       'Burari Crossing', 'CRRI Mathura Road', 'Chandni Chowk', 'DTU',
       'Dr. Karni Singh Shooting Range', 'Dwarka Sector 8',
       'East Arjun Nagar', 'IGI Airport', 'IHBAS', 'ITO', 'Jahangirpuri',
       'Jawaharlal Nehru Stadium', 'Lodhi Road',
       'Major Dhyan Chand National Stadium', 'Mandir Marg', 'Mundka',
       'NSIT Dwarka', 'Najafgarh', 'Narela', 'Nehru Nagar',
       'North Campus', 'Okhla Phase 2', 'Patparganj', 'Punjabi Bagh',
       'Pusa', 'R K Puram', 'Rohini', 'Shadipur', 'Sonia Vihar',
       'Sri Aurobindo Marg', 'Vivek Vihar', 'Wazirpur'], dtype=object)

In [15]:
air_quality_df['time_idx'].value_counts()

time_idx
8760    37
0       37
1       37
2       37
3       37
        ..
13      37
12      37
11      37
10      37
9       37
Name: count, Length: 8761, dtype: int64

In [16]:
# Check unique stations
print(f"Number of unique stations: {air_quality_df['station_loc'].nunique()}")
print(f"Station IDs: {air_quality_df['station_loc'].unique()}")

# Verify data distribution
for station in air_quality_df['station_loc'].unique():
    station_data = air_quality_df[air_quality_df['station_loc'] == station]
    print(f"Station {station}: {len(station_data)} records, PM25 mean: {station_data['PM25'].mean():.2f}")

Number of unique stations: 37
Station IDs: ['Alipur' 'Anand Vihar' 'Ashok Vihar' 'Aya Nagar' 'Bawana'
 'Burari Crossing' 'CRRI Mathura Road' 'Chandni Chowk' 'DTU'
 'Dr. Karni Singh Shooting Range' 'Dwarka Sector 8' 'East Arjun Nagar'
 'IGI Airport' 'IHBAS' 'ITO' 'Jahangirpuri' 'Jawaharlal Nehru Stadium'
 'Lodhi Road' 'Major Dhyan Chand National Stadium' 'Mandir Marg' 'Mundka'
 'NSIT Dwarka' 'Najafgarh' 'Narela' 'Nehru Nagar' 'North Campus'
 'Okhla Phase 2' 'Patparganj' 'Punjabi Bagh' 'Pusa' 'R K Puram' 'Rohini'
 'Shadipur' 'Sonia Vihar' 'Sri Aurobindo Marg' 'Vivek Vihar' 'Wazirpur']
Station Alipur: 8761 records, PM25 mean: 102.88
Station Anand Vihar: 8761 records, PM25 mean: 121.38
Station Ashok Vihar: 8761 records, PM25 mean: 98.88
Station Aya Nagar: 8761 records, PM25 mean: 71.89
Station Bawana: 8761 records, PM25 mean: 114.52
Station Burari Crossing: 8761 records, PM25 mean: 109.38
Station CRRI Mathura Road: 8761 records, PM25 mean: 90.10
Station Chandni Chowk: 8761 records, PM25 me

In [17]:
from lightning.pytorch.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(experiment_name="tft_only_forecasting", tracking_uri="file:./../mlflow_experiments")
# trainer = Trainer(logger=mlf_logger)

In [18]:
# Define max_prediction_length and max_encoder_length
max_prediction_length = 24  # predict 24 hours ahead
max_encoder_length = 168  # use 7 days of history (168 hours) for making predictions

# Define time variables for training/validation split
training_cutoff = air_quality_df["time_idx"].max() - max_prediction_length * 2

# Create the training dataset
training = TimeSeriesDataSet(
    data=air_quality_df[air_quality_df["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="PM25",
    group_ids=["station_loc"],
    min_encoder_length=max_encoder_length // 2,  # allow for some missing data
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["station_loc"],
    static_reals=["latitude", "longitude"],
    time_varying_known_categoricals=[],
    time_varying_known_reals=["time_idx","hour", "day_of_week", "month"], #"time_idx"
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["PM25"],
    target_normalizer=GroupNormalizer(
        groups=["station_loc"], transformation="log1p"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# Create the validation dataset
validation = TimeSeriesDataSet.from_dataset(training, air_quality_df, min_prediction_idx=training_cutoff + 1)

# Create data loaders
batch_size = 64
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=8, shuffle=True)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 2, num_workers=8)

# Define TFT model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=64,           # Increase from 64
    attention_head_size=4,     # Increase from 4
    dropout=0.3,
    hidden_continuous_size=32, # Increase from 32
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
    weight_decay = 0.001, #l2 regularization
)

# Create PyTorch Lightning trainer
early_stop_callback = EarlyStopping(
    monitor="val_loss", 
    min_delta=1e-4, 
    patience=10, 
    verbose=True, 
    mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    gradient_clip_val=0.1,
    callbacks=[
        early_stop_callback, 
        lr_monitor,
        # Add learning rate scheduler
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ],
    logger=mlf_logger
)

# Train the model
trainer.fit(
    tft, 
    train_dataloaders=train_dataloader, 
    val_dataloaders=val_dataloader,
)

# Save the trained model
trainer.save_checkpoint("air_quality_tft_model.ckpt")

/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
-----------------------------

Epoch 0: 100%|██████████| 5050/5050 [07:14<00:00, 11.63it/s, v_num=c6b7, train_loss_step=20.80, val_loss=18.30, train_loss_epoch=20.10]

Metric val_loss improved. New best score: 18.302


Epoch 9: 100%|██████████| 5050/5050 [07:30<00:00, 11.20it/s, v_num=c6b7, train_loss_step=15.00, val_loss=16.70, train_loss_epoch=20.20]

Metric val_loss improved by 1.586 >= min_delta = 0.0001. New best score: 16.716


Epoch 10: 100%|██████████| 5050/5050 [07:29<00:00, 11.24it/s, v_num=c6b7, train_loss_step=22.40, val_loss=16.30, train_loss_epoch=20.20]

Metric val_loss improved by 0.435 >= min_delta = 0.0001. New best score: 16.282


Epoch 20: 100%|██████████| 5050/5050 [07:30<00:00, 11.21it/s, v_num=c6b7, train_loss_step=18.00, val_loss=21.30, train_loss_epoch=20.10]

Monitored metric val_loss did not improve in the last 10 records. Best score: 16.282. Signaling Trainer to stop.


Epoch 20: 100%|██████████| 5050/5050 [07:35<00:00, 11.09it/s, v_num=c6b7, train_loss_step=18.00, val_loss=21.30, train_loss_epoch=20.10]


In [19]:
training_cutoff, air_quality_df["time_idx"].nunique()


(np.int64(8712), 8761)

In [20]:
last_data = air_quality_df.groupby('station_loc').apply(lambda x: x.iloc[-max_encoder_length:]).reset_index(drop=True)


  last_data = air_quality_df.groupby('station_loc').apply(lambda x: x.iloc[-max_encoder_length:]).reset_index(drop=True)


In [21]:
last_data.time_idx.nunique() * 37

6216

In [32]:
last_data

Unnamed: 0,datetime,time_idx,PM25,latitude,longitude,station_loc,group_id,hour,day_of_week,month
0,2023-03-25 00:00:00,8593,63.00,28.797312,77.138667,Alipur,0,0,5,3
1,2023-03-25 01:00:00,8594,44.00,28.797312,77.138667,Alipur,0,1,5,3
2,2023-03-25 02:00:00,8595,33.00,28.797312,77.138667,Alipur,0,2,5,3
3,2023-03-25 03:00:00,8596,31.00,28.797312,77.138667,Alipur,0,3,5,3
4,2023-03-25 04:00:00,8597,22.00,28.797312,77.138667,Alipur,0,4,5,3
...,...,...,...,...,...,...,...,...,...,...
6211,2023-03-31 19:00:00,8756,26.25,28.680084,77.170221,Wazirpur,36,19,4,3
6212,2023-03-31 20:00:00,8757,27.50,28.680084,77.170221,Wazirpur,36,20,4,3
6213,2023-03-31 21:00:00,8758,38.50,28.680084,77.170221,Wazirpur,36,21,4,3
6214,2023-03-31 22:00:00,8759,36.25,28.680084,77.170221,Wazirpur,36,22,4,3


In [33]:

pred_dataset = TimeSeriesDataSet(
    data=last_data,
    time_idx="time_idx",
    target="PM25",
    group_ids=["station_loc"],
    min_encoder_length=max_encoder_length // 2,  # allow for some missing data
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["station_loc"],
    static_reals=["latitude", "longitude"],
    time_varying_known_categoricals=[],
    time_varying_known_reals=["time_idx","hour", "day_of_week", "month"], #"time_idx"
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["PM25"],
    target_normalizer=GroupNormalizer(
        groups=["station_loc"], transformation="log1p"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)


In [34]:

# Make predictions
predictions = tft.predict(pred_dataset)


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


In [35]:
predictions.shape

torch.Size([6179, 24])

In [43]:
# Check prediction variation across stations
print(f"Predictions shape: {predictions.shape}")
print("Prediction samples for first 5 timestamps:")
for i, station in enumerate(list(station_ids)[:5]):
    station_idx = station_mapping[station]
    print(f"Station {station}: {predictions[station_idx][:5].detach().cpu().numpy()}")

Predictions shape: torch.Size([6179, 24])
Prediction samples for first 5 timestamps:
Station Alipur: [24.984472 25.007977 25.178965 25.487686 25.921286]
Station Anand Vihar: [24.984472 25.007977 25.178965 25.487686 25.921286]
Station Ashok Vihar: [24.984472 25.007977 25.178965 25.487686 25.921286]
Station Aya Nagar: [24.984472 25.007977 25.178965 25.487686 25.921286]
Station Bawana: [24.984472 25.007977 25.178965 25.487686 25.921286]


In [44]:
# Print station mapping to verify it's correct
print("Station mapping:")
for station, idx in station_mapping.items():
    print(f"Station {station} -> Index {idx}")

Station mapping:
Station Alipur -> Index 0
Station Anand Vihar -> Index 1
Station Ashok Vihar -> Index 2
Station Aya Nagar -> Index 3
Station Bawana -> Index 4
Station Burari Crossing -> Index 5
Station CRRI Mathura Road -> Index 6
Station Chandni Chowk -> Index 7
Station DTU -> Index 8
Station Dr. Karni Singh Shooting Range -> Index 9
Station Dwarka Sector 8 -> Index 10
Station East Arjun Nagar -> Index 11
Station IGI Airport -> Index 12
Station IHBAS -> Index 13
Station ITO -> Index 14
Station Jahangirpuri -> Index 15
Station Jawaharlal Nehru Stadium -> Index 16
Station Lodhi Road -> Index 17
Station Major Dhyan Chand National Stadium -> Index 18
Station Mandir Marg -> Index 19
Station Mundka -> Index 20
Station NSIT Dwarka -> Index 21
Station Najafgarh -> Index 22
Station Narela -> Index 23
Station Nehru Nagar -> Index 24
Station North Campus -> Index 25
Station Okhla Phase 2 -> Index 26
Station Patparganj -> Index 27
Station Punjabi Bagh -> Index 28
Station Pusa -> Index 29
Station

In [45]:
# Create separate prediction datasets for each station
all_predictions = []
station_preds_dict = {}

for station in station_ids:
    # Filter last data for just this station
    station_last_data = last_data[last_data['station_loc'] == station].reset_index(drop=True)
    
    # Create a station-specific prediction dataset
    station_pred_dataset = TimeSeriesDataSet(
        data=station_last_data,
        time_idx="time_idx",
        target="PM25",
        group_ids=["station_loc"],
        min_encoder_length=max_encoder_length // 2,
        max_encoder_length=max_encoder_length,
        min_prediction_length=1,
        max_prediction_length=max_prediction_length,
        static_categoricals=["station_loc"],
        static_reals=["latitude", "longitude"],
        time_varying_known_categoricals=[],
        time_varying_known_reals=["time_idx", "hour", "day_of_week", "month"],
        time_varying_unknown_categoricals=[],
        time_varying_unknown_reals=["PM25"],
        target_normalizer=GroupNormalizer(
            groups=["station_loc"], transformation="log1p"
        ),
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
    )
    
    # Create dataloader
    pred_dataloader = station_pred_dataset.to_dataloader(batch_size=1, train=False, shuffle=False)
    
    # Generate prediction for this station
    station_prediction = tft.predict(pred_dataloader)
    station_preds_dict[station] = station_prediction[0].detach().cpu().numpy()
    print(f"Station {station} prediction mean: {station_prediction[0].mean().item():.4f}")

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Alipur prediction mean: 25.7630


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Anand Vihar prediction mean: 38.6262


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Ashok Vihar prediction mean: 21.9389


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Aya Nagar prediction mean: 19.3584


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Bawana prediction mean: 30.8356


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Burari Crossing prediction mean: 24.0747


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station CRRI Mathura Road prediction mean: 19.3829


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Chandni Chowk prediction mean: 30.1853


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station DTU prediction mean: 28.3143


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Dr. Karni Singh Shooting Range prediction mean: 26.5865


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Dwarka Sector 8 prediction mean: 27.8226


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station East Arjun Nagar prediction mean: 27.9763


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station IGI Airport prediction mean: 26.3854


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station IHBAS prediction mean: 28.6942


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station ITO prediction mean: 29.2822


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Jahangirpuri prediction mean: 32.0560


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Jawaharlal Nehru Stadium prediction mean: 24.4023


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Lodhi Road prediction mean: 21.4182


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Major Dhyan Chand National Stadium prediction mean: 28.4186


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Mandir Marg prediction mean: 26.5525


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Mundka prediction mean: 28.7108


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station NSIT Dwarka prediction mean: 25.5814


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Najafgarh prediction mean: 19.0731


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Narela prediction mean: 29.4543


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Nehru Nagar prediction mean: 30.9204


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station North Campus prediction mean: 16.1365


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Okhla Phase 2 prediction mean: 26.0362


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Patparganj prediction mean: 29.5695


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Punjabi Bagh prediction mean: 30.1490


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Pusa prediction mean: 27.1588


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station R K Puram prediction mean: 26.5606


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Rohini prediction mean: 31.6524


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Shadipur prediction mean: 60.3552


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Sonia Vihar prediction mean: 30.9660


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Sri Aurobindo Marg prediction mean: 24.3410


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Station Vivek Vihar prediction mean: 32.9103


/home/naradaw/miniconda3/envs/gnns/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Station Wazirpur prediction mean: 32.3688


In [49]:
# Set up output directory
output_loc = f"/home/naradaw/code/GCNTFT/outputs/images_tft_only/{pd.Timestamp.now().strftime('%m%d%H%M')}_station_specific"

if not os.path.exists(output_loc):
    os.makedirs(output_loc)

# Plot predictions for each station using station-specific predictions
for station in station_ids:
    # Get the prediction for this station
    station_pred = station_preds_dict[station]
    
    plt.figure(figsize=(10, 6))
    
    # Plot the prediction
    plt.plot(range(max_prediction_length), station_pred, label='Prediction', color='blue')
    
    # Get historical data for comparison
    historical = air_quality_df[air_quality_df['station_loc'] == station].tail(max_encoder_length)['PM25'].values
    plt.plot(range(-len(historical), 0), historical, label='Historical', color='green')
    
    # Check if we have actual future data for this station
    station_data = air_quality_df[air_quality_df['station_loc'] == station]
    last_historical_idx = station_data[station_data['time_idx'] <= station_data['time_idx'].max() - max_prediction_length].tail(1)['time_idx'].item()
    
    # Get actual values for the prediction period if they exist in the dataset
    future_indices = range(last_historical_idx + 1, last_historical_idx + max_prediction_length + 1)
    future_data = station_data[station_data['time_idx'].isin(future_indices)]['PM25'].values
    
    # Plot actual future values if available
    if len(future_data) > 0:
        plt.plot(range(min(len(future_data), max_prediction_length)), future_data[:max_prediction_length], 
                 label='Actual', color='red', linestyle='-')
    
    plt.axvline(x=0, linestyle='--', color='gray')
    plt.title(f'24-Hour Air Quality Forecast for Station {station}')
    plt.xlabel('Hours')
    plt.ylabel('PM25')
    plt.legend()
    plt.grid(True)
    
    # Save the plot
    plt.savefig(os.path.join(output_loc, f'station_{station}_forecast.png'))
    plt.close()

# Plot summary comparison of all station predictions
plt.figure(figsize=(12, 8))
for station in station_ids:
    plt.plot(range(max_prediction_length), station_preds_dict[station], label=f'Station {station}')

plt.title('Comparison of Predictions Across All Stations')
plt.xlabel('Hours')
plt.ylabel('PM25')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_loc, 'all_stations_comparison.png'))
plt.close()

# Print prediction statistics
print("\nPrediction statistics:")
for station in station_ids:
    pred = station_preds_dict[station]
    print(f"Station {station}: Min={pred.min():.2f}, Max={pred.max():.2f}, Mean={pred.mean():.2f}, Std={pred.std():.2f}")

print(f"\nForecasting completed. Plots saved to {output_loc}")


Prediction statistics:
Station Alipur: Min=22.96, Max=28.34, Mean=25.76, Std=1.64
Station Anand Vihar: Min=34.68, Max=42.24, Mean=38.63, Std=2.31
Station Ashok Vihar: Min=19.51, Max=24.17, Mean=21.94, Std=1.42
Station Aya Nagar: Min=17.14, Max=21.40, Mean=19.36, Std=1.30
Station Bawana: Min=27.59, Max=33.81, Mean=30.84, Std=1.90
Station Burari Crossing: Min=21.40, Max=26.53, Mean=24.07, Std=1.57
Station CRRI Mathura Road: Min=16.81, Max=21.78, Mean=19.38, Std=1.52
Station Chandni Chowk: Min=28.03, Max=32.12, Mean=30.19, Std=1.25
Station DTU: Min=26.66, Max=29.79, Mean=28.31, Std=0.95
Station Dr. Karni Singh Shooting Range: Min=24.09, Max=28.86, Mean=26.59, Std=1.45
Station Dwarka Sector 8: Min=24.94, Max=30.46, Mean=27.82, Std=1.68
Station East Arjun Nagar: Min=27.97, Max=27.98, Mean=27.98, Std=0.00
Station IGI Airport: Min=23.89, Max=28.66, Mean=26.39, Std=1.46
Station IHBAS: Min=26.23, Max=30.93, Mean=28.69, Std=1.43
Station ITO: Min=26.69, Max=31.64, Mean=29.28, Std=1.51
Station Ja

In [48]:
# Check variable importance
raw_predictions, x = tft.predict(val_dataloader, return_x=True)


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


ValueError: too many values to unpack (expected 2)

In [None]:

# Generate interpretation
interpretation = tft.interpret_prediction(x, reduction="mean")

# Plot variable importance
plt.figure(figsize=(10, 8))
order = interpretation["static_variables"].mean(axis=0).argsort()
plt.barh(
    np.array(interpretation["static_variable_names"])[order],
    interpretation["static_variables"].mean(axis=0)[order]
)
plt.title("Static Variable Importance")
plt.tight_layout()
plt.savefig(os.path.join(output_loc, "static_variable_importance.png"))

In [36]:
station = station_ids[5]
station

'Burari Crossing'

In [37]:
station_idx = station_mapping[station]
station_idx

5

In [38]:
station_preds = predictions[station_idx].detach().cpu().numpy()
station_preds

array([24.984472, 25.007977, 25.178965, 25.487686, 25.921286, 26.458723,
       27.062433, 27.662863, 28.14237 , 28.3374  , 28.094805, 27.365152,
       26.261204, 25.020758, 23.92476 , 23.203539, 22.960508, 23.161558,
       23.692024, 24.422945, 25.249105, 26.096481, 26.921837, 27.700588],
      dtype=float32)

In [39]:
last_timestamp = air_quality_df[air_quality_df['station_loc'] == station]['time_idx'].max()
last_timestamp

np.int64(8760)

In [40]:
air_quality_df['station_loc']

0           Alipur
1           Alipur
2           Alipur
3           Alipur
4           Alipur
            ...   
324152    Wazirpur
324153    Wazirpur
324154    Wazirpur
324155    Wazirpur
324156    Wazirpur
Name: station_loc, Length: 324157, dtype: object

In [41]:
air_quality_df[air_quality_df['station_loc'] == station]

Unnamed: 0,datetime,time_idx,PM25,latitude,longitude,station_loc,group_id,hour,day_of_week,month
43805,2022-03-31 23:00:00,0,107.52,28.728594,77.199325,Burari Crossing,5,23,3,3
43806,2022-04-01 00:00:00,1,81.85,28.728594,77.199325,Burari Crossing,5,0,4,4
43807,2022-04-01 01:00:00,2,60.89,28.728594,77.199325,Burari Crossing,5,1,4,4
43808,2022-04-01 02:00:00,3,71.02,28.728594,77.199325,Burari Crossing,5,2,4,4
43809,2022-04-01 03:00:00,4,67.48,28.728594,77.199325,Burari Crossing,5,3,4,4
...,...,...,...,...,...,...,...,...,...,...
52561,2023-03-31 19:00:00,8756,12.63,28.728594,77.199325,Burari Crossing,5,19,4,3
52562,2023-03-31 20:00:00,8757,28.63,28.728594,77.199325,Burari Crossing,5,20,4,3
52563,2023-03-31 21:00:00,8758,19.34,28.728594,77.199325,Burari Crossing,5,21,4,3
52564,2023-03-31 22:00:00,8759,25.36,28.728594,77.199325,Burari Crossing,5,22,4,3


In [42]:
output_loc = f"/home/naradaw/code/GCNTFT/outputs/images_tft_only/{pd.Timestamp.now().strftime('%m%d%H%M')}"

if not os.path.exists(output_loc):
    os.makedirs(output_loc)

# Plot predictions for each station
for station in station_ids:
    station_idx = station_mapping[station]
    station_preds = predictions[station_idx].detach().cpu().numpy()
    
    # Get the future actual values for comparison if they exist
    last_timestamp = air_quality_df[air_quality_df['station_loc'] == station]['time_idx'].max()
    
    # Check if we have actual data for the prediction period
    station_data = air_quality_df[air_quality_df['station_loc'] == station]
    last_historical_idx = station_data.iloc[-1]['time_idx']
    
    # Get actual values for the prediction period if they exist in the dataset
    future_data = None
    if last_timestamp >= last_historical_idx + max_prediction_length:
        future_indices = range(last_historical_idx + 1, last_historical_idx + max_prediction_length + 1)
        future_data = station_data[station_data['time_idx'].isin(future_indices)]['PM25'].values
    plt.figure(figsize=(10, 6))
    plt.plot(range(max_prediction_length), station_preds, label='Prediction')
    
    # Get historical data for comparison
    historical = air_quality_df[air_quality_df['station_loc'] == station].tail(max_encoder_length)['PM25'].values
    plt.plot(range(-len(historical), 0), historical, label='Historical')
    
    plt.axvline(x=0, linestyle='--', color='gray')
    plt.title(f'24-Hour Air Quality Forecast for Station {station}')
    plt.xlabel('Hours')
    plt.ylabel('PM25')
    plt.legend()
    plt.grid(True)
    
    # Save the plot
    
    plt.savefig(os.path.join(output_loc, f'station_{station}_forecast.png'))
    plt.close()

print("Forecasting completed. Plots saved to outputs/images/")

Forecasting completed. Plots saved to outputs/images/
