# Monsoon LPS Intensity Forecasting using TFT Model

This notebook performs the following steps:
- Loads and preprocesses the monsoon LPS track dataset
- Computes rolling averages for selected meteorological variables
- Limits each LPS track to a maximum length (5 days)
- Prepares the data for training
- Trains a Temporal Fusion Transformer (TFT) model
- Evaluates the trained model

---

In [None]:
# Import necessary libraries and modules
import os
import warnings
import copy
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# Import the custom model class
from models.IntensityForecasting import LPSIntensityForecasting

## Load and preprocess the data

In [None]:
# Read the main LPS track table and the companion table that carries the
# background fields; the 'Unnamed: *' columns (presumably saved index columns)
# are discarded
data = pd.read_csv("processed/all-lps-dataframe.csv")
data = data.drop(columns=['Unnamed: 0'])
bg_data = pd.read_csv("processed/data_with_bg.csv")
bg_data = bg_data.drop(columns=['Unnamed: 0', 'Unnamed: 0.1'])

# Parse the timestamp columns into pandas datetimes
for datetime_column in ('Genesis_Date', 'DateTime'):
    data[datetime_column] = pd.to_datetime(data[datetime_column])

# Copy the background variables into the main frame
# (Series assignment aligns on the row index of the two frames)
for bg_column in ("Q850_bg", "VS_bg"):
    data[bg_column] = bg_data[bg_column]

## Create time index for each timestep within each LPS track

In [None]:
# Number the rows of each track 0, 1, 2, ... in their existing order.
# cumcount() is evaluated per row, so it does not require the ids to be exactly
# 1..N, the tracks to be stored contiguously in id order, or Genesis_Date to be
# non-null, and it avoids re-running the groupby once per track.
time_idx = data.groupby("id").cumcount().tolist()
data["time_idx"] = time_idx

## Compute rolling averages for key variables (6-hour window)

In [None]:
# Columns that are replaced by their per-track rolling mean
rolling_columns = [
    'Latitude', 'Longitude', 'mslp', 'ls_ratio', 'VO550', 'VO750', 'VO850',
    'PV', 'Q850', 'Q850_grad', 'Q2', 'US_850', 'UN_850', 'VE_850', 'VW_850',
    'T2', 'Z_tilt', 'integrated_mse', 'Z250', 'Z550', 'Z850', 'RF'
]

# Number of consecutive rows (time steps) averaged together
window_size = 6

# Trailing mean within each track; min_periods=1 means the first rows of a
# track are averaged over however many rows are available so far.
# groupby().rolling() returns a MultiIndex of (id, original row label).
# Dropping only the id level keeps the original row labels, so the assignment
# below puts every smoothed value back on its own row even when the frame is
# not sorted by id.
data_rolling = (
    data.groupby('id')[rolling_columns]
    .rolling(window=window_size, min_periods=1)
    .mean()
    .reset_index(level=0, drop=True)
)

# Replace original columns with smoothed values (aligned on the row index)
data[rolling_columns] = data_rolling[rolling_columns]

## Crop tracks longer than `max_rows` time steps (8 × 24 = 192 rows)

In [None]:
# Maximum number of rows kept per track.
# NOTE(review): 8 * 24 = 192 rows, whereas a 5-day limit with hourly rows
# would be 5 * 24 = 120 — confirm which limit is intended.
max_rows = 8 * 24


def process_track(track):
    """Return at most the first ``max_rows`` rows of a single track."""
    # Slicing past the end is safe, so shorter tracks keep all their rows
    return track.iloc[:max_rows]


# Keep the first max_rows rows of every track and renumber the rows.
# GroupBy.head() is used instead of GroupBy.apply(process_track): apply() on the
# grouping column is deprecated, and where grouping columns are excluded from
# apply() the result would lose the 'id' column. head() leaves the rows in
# their current order.
data = data.groupby('id').head(max_rows).reset_index(drop=True)

## Initialize the LPS Intensity Forecasting model

In [None]:
# Forecast target and window lengths (in time steps)
target = "mslp"
max_encoder_length = 24  # Using past 24 hours as input
max_prediction_length = 5 * 24  # Forecasting 5 days ahead

# Columns passed to the model as `unknown_variables`
unknown_variables = [
    "Latitude",
    "Longitude",
    "mslp",
    "ls_ratio",
    "VO850",
    "PV",
    "T2",
    "Q850",
    "Q2",
    "UN_850",
    "US_850",
    "VE_850",
    "RF",
]

# Names of the background columns; note that this rebinds `bg_data`, which the
# data-loading cell used for the background DataFrame
bg_data = ["Q850_bg", "VS_bg"]

# Build the TFT-based forecasting model from the prepared frame and settings
intensity_tft_model = LPSIntensityForecasting(
    data=data,
    target=target,
    unknown_variables=unknown_variables,
    bg_data=bg_data,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
)

## Train the model

In [None]:
# Call train() on the LPSIntensityForecasting instance created above; the
# training routine itself is implemented in models.IntensityForecasting
# (not shown in this notebook)
intensity_tft_model.train()

## Evaluate model performance

In [None]:
# Call evaluate() on the same model instance; what is evaluated and how it is
# reported is defined in models.IntensityForecasting (not shown in this notebook)
intensity_tft_model.evaluate()