In [1]:

from datetime import datetime, timedelta
import pandas as pd
from sqlalchemy import create_engine
import time
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def fetch_noaa_data(conn, cities, start_ts, end_ts):
    """Retrieve NOAA GSOD data for target cities"""
    tick = time.time()
    noaa_query = f"""
        SELECT 
            city_name AS city,
            ts,
            latitude,
            longitude,
            temp AS temperature,
            dewp AS dew_point,
            slp AS sea_level_pressure,
            wdsp AS wind_speed,
            mxspd AS max_wind_speed,           
            prcp AS precipitation,
            gust
        FROM 
            ods_noaa_addr
        WHERE 
            ts >= {start_ts} and ts < {end_ts}
            AND city_name IN ('{"','".join(cities)}')
    """
    
    
    # Execute queries and fetch data
    noaa_data = pd.read_sql(noaa_query, conn)
    print(f"completed with {len(noaa_data)} rows in {time.time() - tick:.6f}s")
    return noaa_data


def fetch_openaq_data(conn, cities, start_ts, end_ts):
    """Retrieve OpenAQ measurements with EPA AQI conversion"""
   
    # Fetch OpenAQ data
    tick = time.time()
    openaq_query = f"""
        SELECT 
            city_name AS city,
            ts,
            parameter,
            value
        FROM 
            ods_openaq_addr
        WHERE 
            ts >= {start_ts} and ts < {end_ts}
            AND city_name IN ('{"','".join(cities)}')
    """
    
    # Execute queries and fetch data
    openaq_data = pd.read_sql(openaq_query, conn)
    print(f"completed with {len(openaq_data)} rows in {time.time() - tick:.6f}s")
    return openaq_data


In [7]:
if __name__ == "__main__":
    # Configuration

    # conn = hive.Connection(host='localhost', port=10000, database='default')
    conn = create_engine('hive://localhost:10000/default')
    # cursor = conn.cursor()
    
    cities = [
        'Adams County',
        'Albany',
        'Amarillo',
        'Anchorage',
        'Austin',
        'Boston',
        'Buckeye',
        'Bullhead City',
        'Burke County',
        'Calera',
        'Casa Grande',
        'Charleston',
        'Chatham County',
        'Chester',
        'Chesterfield County',
        'Chicago',
        'Clark County',
        'Cleburne',
        # 'Cleveland',
        # 'Columbia',
        # 'Cornwall',     
    ]
    start_ts = int(pd.to_datetime('2023-01-01', utc=True).timestamp())
    end_ts = int(pd.to_datetime('2023-05-01', utc=True).timestamp())
    
    # Data Pipeline
    print(f"Fetching NOAA data ({start_ts} ~ {end_ts})...")
    noaa_data = fetch_noaa_data(conn, cities, start_ts, end_ts)
    display(noaa_data.head(1))
    
    print(f"Fetching OpenAQ data ({start_ts} ~ {end_ts})...")
    aqi_data = fetch_openaq_data(conn, cities, start_ts, end_ts)
    display(aqi_data.head(1))




Fetching NOAA data (1672531200 ~ 1682899200)...
completed with 3409 rows in 4.246039s


Unnamed: 0,city,ts,latitude,longitude,temperature,dew_point,sea_level_pressure,wind_speed,max_wind_speed,precipitation,gust
0,Anchorage,1672531200,60.78351,-148.84839,35.4,32.0,975.3,12.4,21.0,1.24,51.1


Fetching OpenAQ data (1672531200 ~ 1682899200)...
completed with 108605 rows in 5.296676s


Unnamed: 0,city,ts,parameter,value
0,Anchorage,1672567200,pm10,24.0


In [85]:
from sklearn.cluster import KMeans  # For spatial clustering

def resample_noaa(noaa_data):
    noaa_data['dt'] = pd.to_datetime(noaa_data['ts'], unit='s', utc=True)
    noaa_data = noaa_data.set_index('dt')
    noaa_data['gust'] = noaa_data['gust'].replace(999.9, np.nan)
    noaa_data['gust'] = noaa_data.groupby('city')['gust'].ffill()
    noaa_by_day = noaa_data.groupby(['city']).resample('d').agg(
        temperature_avg=('temperature', 'mean'),
        temperature_open=('temperature', 'first'),
        temperature_close=('temperature', 'last'),
        temperature_max=('temperature', 'max'),
        temperature_min=('temperature', 'min'),
    
        dew_point_avg=  ('dew_point', 'mean'),
        dew_point_open= ('dew_point', 'first'),
        dew_point_close=('dew_point', 'last'),
        dew_point_max=  ('dew_point', 'max'),
        dew_point_min=  ('dew_point', 'min'),
    
        sea_level_pressure_avg=  ('sea_level_pressure', 'mean'),
        sea_level_pressure_open= ('sea_level_pressure', 'first'),
        sea_level_pressure_close=('sea_level_pressure', 'last'),
        sea_level_pressure_max=  ('sea_level_pressure', 'max'),
        sea_level_pressure_min=  ('sea_level_pressure', 'min'),
    
        wind_speed_avg=  ('wind_speed', 'mean'),
        wind_speed_open= ('wind_speed', 'first'),
        wind_speed_close=('wind_speed', 'last'),
        wind_speed_max=  ('wind_speed', 'max'),
        wind_speed_min=  ('wind_speed', 'min'),
    
        max_wind_speed_avg=  ('max_wind_speed', 'mean'),
        max_wind_speed_open= ('max_wind_speed', 'first'),
        max_wind_speed_close=('max_wind_speed', 'last'),
        max_wind_speed_max=  ('max_wind_speed', 'max'),
        max_wind_speed_min=  ('max_wind_speed', 'min'),

        gust_avg=  ('gust', 'mean'),
        gust_open= ('gust', 'first'),
        gust_close=('gust', 'last'),
        gust_max=  ('gust', 'max'),
        gust_min=  ('gust', 'min'),
    
       precipitation=  ('precipitation', 'max'),
    ).reset_index()
    
    noaa_by_day['ts'] = noaa_by_day['dt'].astype('int64')
    return noaa_by_day.set_index('dt')
noaa_by_day = resample_noaa(noaa_data)
display(noaa_by_day)


Unnamed: 0_level_0,city,temperature_avg,temperature_open,temperature_close,temperature_max,temperature_min,dew_point_avg,dew_point_open,dew_point_close,dew_point_max,dew_point_min,sea_level_pressure_avg,sea_level_pressure_open,sea_level_pressure_close,sea_level_pressure_max,sea_level_pressure_min,wind_speed_avg,wind_speed_open,wind_speed_close,wind_speed_max,wind_speed_min,max_wind_speed_avg,max_wind_speed_open,max_wind_speed_close,max_wind_speed_max,max_wind_speed_min,gust_avg,gust_open,gust_close,gust_max,gust_min,precipitation,ts
dt,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1
2023-01-01 00:00:00+00:00,Cleveland,41.950000,41.9,42.0,42.0,41.9,38.400000,38.6,38.2,38.6,38.2,1015.100000,1014.6,1015.6,1015.6,1014.6,5.850000,5.1,6.6,6.6,5.1,14.000000,15.0,13.0,15.0,13.0,23.000000,19.0,27.0,27.0,19.0,0.20,1672531200000000000
2023-01-02 00:00:00+00:00,Cleveland,46.550000,45.2,47.9,47.9,45.2,43.950000,43.0,44.9,44.9,43.0,1018.450000,1017.7,1019.2,1019.2,1017.7,4.450000,5.1,3.8,5.1,3.8,9.050000,11.1,7.0,11.1,7.0,23.000000,19.0,27.0,27.0,19.0,0.00,1672617600000000000
2023-01-03 00:00:00+00:00,Cleveland,53.600000,53.0,54.2,54.2,53.0,50.900000,50.3,51.5,51.5,50.3,1010.850000,1008.8,1012.9,1012.9,1008.8,5.950000,6.1,5.8,6.1,5.8,10.950000,12.0,9.9,12.0,9.9,22.550000,18.1,27.0,27.0,18.1,0.26,1672704000000000000
2023-01-04 00:00:00+00:00,Cleveland,58.000000,58.7,57.3,58.7,57.3,54.200000,54.9,53.5,54.9,53.5,1006.700000,1005.9,1007.5,1007.5,1005.9,8.950000,9.7,8.2,9.7,8.2,17.000000,15.0,19.0,19.0,15.0,22.550000,21.0,24.1,24.1,21.0,0.35,1672790400000000000
2023-01-05 00:00:00+00:00,Cleveland,42.150000,42.6,41.7,42.6,41.7,31.600000,32.2,31.0,32.2,31.0,1013.150000,1012.7,1013.6,1013.6,1012.7,10.750000,11.6,9.9,11.6,9.9,17.000000,15.0,19.0,19.0,15.0,22.000000,22.0,22.0,22.0,22.0,1.01,1672876800000000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023-04-26 00:00:00+00:00,Cornwall,49.900000,51.3,49.9,51.3,48.3,44.950000,48.2,43.6,48.2,42.7,5507.225000,9999.9,9999.9,9999.9,1014.4,12.225000,14.5,13.3,14.5,10.0,15.550000,17.1,18.1,18.1,13.0,28.250000,33.0,26.0,33.0,26.0,0.01,1682467200000000000
2023-04-27 00:00:00+00:00,Cornwall,51.750000,51.6,52.3,52.3,50.8,50.775000,50.6,51.4,51.4,50.2,5505.050000,9999.9,9999.9,9999.9,1010.2,12.225000,14.8,12.2,14.8,10.5,19.550000,22.0,21.0,22.0,17.1,28.250000,33.0,26.0,33.0,26.0,99.99,1682553600000000000
2023-04-28 00:00:00+00:00,Cornwall,51.700000,51.6,51.4,52.1,51.4,50.675000,51.3,50.7,51.3,49.9,5507.200000,9999.9,9999.9,9999.9,1014.1,7.950000,8.5,8.4,8.5,6.6,13.975000,13.0,15.9,15.9,13.0,28.250000,33.0,26.0,33.0,26.0,99.99,1682640000000000000
2023-04-29 00:00:00+00:00,Cornwall,53.000000,52.0,53.8,53.8,52.0,50.850000,51.2,51.4,51.4,49.9,5509.950000,9999.9,9999.9,9999.9,1019.8,5.525000,6.2,5.8,6.2,4.9,10.200000,9.9,11.1,11.1,9.9,24.750000,33.0,26.0,33.0,13.0,99.99,1682726400000000000


In [101]:
import numpy as np
# display(noaa_data)
# display(aqi_data)
def resample_openaq(aqi_data):
    aqi_data['dt'] = pd.to_datetime(aqi_data['ts'], unit='s', utc=True)
    aqi_data = aqi_data.set_index('dt')
    
    aqi_by_day = aqi_data.groupby(['city', 'parameter']).resample('d').agg(
        avg=('value', 'mean'),
        open=('value', 'first'),
        close=('value', 'last'),
        max=('value', 'max'),
        min=('value', 'min')
    ).reset_index()
    aqi_by_day['ts'] = aqi_by_day['dt'].astype('int64')
    return aqi_by_day.set_index('dt')

def pivot_openaq(df):
    all_parameters = ['bc', 'co', 'no', 'no2', 'nox', 'o3', 'pm1', 'pm10', 'pm25', 'so2']
    reshaped_df = df.pivot_table(
        index=['dt', 'city'],
        columns=['parameter'],
        values=['avg', 'open', 'close', 'min', 'max'],
        aggfunc='first'  # Use 'first' to keep the original values
    )
    
    # Step 5: Flatten the MultiIndex columns
    reshaped_df.columns = [f"{param}_{stat}" for stat, param in reshaped_df.columns]
    
    # Step 6: Reset the index to make 'dt' a column again
    reshaped_df = reshaped_df.reset_index()
    
    # Step 7: Add the datetime index as a column
    # reshaped_df.insert(0, 'dt', df.index.unique())
    
    # Step 8: Ensure all required columns exist (even if they contain only NaN values)
    for param in all_parameters:
        for stat in ['avg', 'open', 'close', 'min', 'max']:
            col_name = f"{param}_{stat}"
            if col_name not in reshaped_df.columns:
                reshaped_df[col_name] = np.nan
    reshaped_df['ts'] = reshaped_df['dt'].astype('int64')
    return reshaped_df.set_index('dt')

aqi_by_day = resample_openaq(aqi_data)
# rehaped_aqi = pivot_openaq(aqi_by_day.reset_index())

display(aqi_by_day.head(5))

Unnamed: 0_level_0,city,parameter,avg,open,close,max,min,ts
dt,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2023-01-01 00:00:00+00:00,Cleveland,co,0.116667,0.1,0.1,0.3,0.1,1672531200000000000
2023-01-02 00:00:00+00:00,Cleveland,co,0.208333,0.2,0.5,0.5,0.1,1672617600000000000
2023-01-03 00:00:00+00:00,Cleveland,co,0.225,0.4,0.1,0.7,0.1,1672704000000000000
2023-01-04 00:00:00+00:00,Cleveland,co,0.173913,0.1,0.1,0.3,0.1,1672790400000000000
2023-01-05 00:00:00+00:00,Cleveland,co,0.120833,0.1,0.1,0.2,0.1,1672876800000000000


In [102]:

def merge_data(noaa_df, aqi_df):
    """Merge and engineer spatio-temporal features"""
    # Temporal alignment

    merged = pd.merge(
        noaa_df.reset_index(),
        aqi_df.loc[aqi_df['parameter'] == 'pm25'].reset_index(),
        how='inner',  # Change to 'left', 'right', or 'outer' as needed
        on=['dt', 'city']
    )
    # merged = merged.loc[~ merged['ts_x'] == np.nan]
    # merged = merged.loc[~ merged['ts_y'] == np.nan]
    
    # EPA-recommended feature engineering
    merged['temp_wind_interaction'] = merged['temperature_avg'] * merged['wind_speed_avg']
    merged['precip_accum_72h'] = merged.groupby('city')['precipitation'].transform(
        lambda x: x.rolling(3, min_periods=1).sum()
    )
    
    # Temporal features
    merged['day_of_week'] = merged['dt'].dt.dayofweek
    # merged['is_holiday'] = merged['ts'].dt.date.apply(check_holiday)
    merged['is_weekend'] = merged['day_of_week'].isin([5,6]).astype(int)
    merged['month'] = merged['dt'].dt.month
    merged = merged.drop(columns=['open', 'close', 'max', 'min'])

    # 3. Spatial Clustering (Example for 3 climate zones)    
    merged = merged.rename(columns={'avg': 'value'})
    # # Lag features
    for lag in [1, 2, 3]:
        merged[f'aqi_lag_{lag}d'] = merged.groupby('city')['value'].shift(lag)
    merged = merged.drop(columns=['ts_y'])
    merged = merged.rename(columns={'ts_x': 'ts'})

    for col in merged.select_dtypes(include=np.number).columns.tolist():
        merged[col] = merged[col].fillna(merged[col].mean())
    
    return merged
    # return merged.dropna()

pd.set_option("display.min_rows", 10)
pd.set_option("display.max_columns", 100)

merged_data = merge_data(noaa_by_day, aqi_by_day)
display(merged_data)

Unnamed: 0,dt,city,temperature_avg,temperature_open,temperature_close,temperature_max,temperature_min,dew_point_avg,dew_point_open,dew_point_close,dew_point_max,dew_point_min,sea_level_pressure_avg,sea_level_pressure_open,sea_level_pressure_close,sea_level_pressure_max,sea_level_pressure_min,wind_speed_avg,wind_speed_open,wind_speed_close,wind_speed_max,wind_speed_min,max_wind_speed_avg,max_wind_speed_open,max_wind_speed_close,max_wind_speed_max,max_wind_speed_min,gust_avg,gust_open,gust_close,gust_max,gust_min,precipitation,ts,parameter,value,temp_wind_interaction,precip_accum_72h,day_of_week,is_weekend,month,aqi_lag_1d,aqi_lag_2d,aqi_lag_3d
0,2023-01-01 00:00:00+00:00,Cleveland,41.950000,41.9,42.0,42.0,41.9,38.400000,38.6,38.2,38.6,38.2,1015.100000,1014.6,1015.6,1015.6,1014.6,5.850000,5.1,6.6,6.6,5.1,14.000,15.0,13.0,15.0,13.0,23.000000,19.0,27.0,27.0,19.0,0.20,1672531200000000000,pm25,11.945714,245.407500,0.20,6,1,1,7.915147,7.879845,7.832595
1,2023-01-02 00:00:00+00:00,Cleveland,46.550000,45.2,47.9,47.9,45.2,43.950000,43.0,44.9,44.9,43.0,1018.450000,1017.7,1019.2,1019.2,1017.7,4.450000,5.1,3.8,5.1,3.8,9.050,11.1,7.0,11.1,7.0,23.000000,19.0,27.0,27.0,19.0,0.00,1672617600000000000,pm25,12.702083,207.147500,0.20,0,0,1,11.945714,7.879845,7.832595
2,2023-01-03 00:00:00+00:00,Cleveland,53.600000,53.0,54.2,54.2,53.0,50.900000,50.3,51.5,51.5,50.3,1010.850000,1008.8,1012.9,1012.9,1008.8,5.950000,6.1,5.8,6.1,5.8,10.950,12.0,9.9,12.0,9.9,22.550000,18.1,27.0,27.0,18.1,0.26,1672704000000000000,pm25,10.847826,318.920000,0.46,1,0,1,12.702083,11.945714,7.832595
3,2023-01-04 00:00:00+00:00,Cleveland,58.000000,58.7,57.3,58.7,57.3,54.200000,54.9,53.5,54.9,53.5,1006.700000,1005.9,1007.5,1007.5,1005.9,8.950000,9.7,8.2,9.7,8.2,17.000,15.0,19.0,19.0,15.0,22.550000,21.0,24.1,24.1,21.0,0.35,1672790400000000000,pm25,7.072917,519.100000,0.61,2,0,1,10.847826,12.702083,11.945714
4,2023-01-05 00:00:00+00:00,Cleveland,42.150000,42.6,41.7,42.6,41.7,31.600000,32.2,31.0,32.2,31.0,1013.150000,1012.7,1013.6,1013.6,1012.7,10.750000,11.6,9.9,11.6,9.9,17.000,15.0,19.0,19.0,15.0,22.000000,22.0,22.0,22.0,22.0,1.01,1672876800000000000,pm25,5.235417,453.112500,1.62,3,0,1,7.072917,10.847826,12.702083
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
313,2023-04-12 00:00:00+00:00,Cornwall,44.950000,44.8,46.1,46.1,43.2,37.675000,38.0,38.2,38.2,36.6,5498.525000,9999.9,9999.9,9999.9,996.2,24.750000,30.6,26.3,30.6,20.3,35.000,39.0,40.0,40.0,28.0,40.750000,54.0,55.0,55.0,27.0,99.99,1681257600000000000,pm25,7.613043,1112.512500,299.97,2,0,4,6.247368,7.879845,3.400000
314,2023-04-13 00:00:00+00:00,Cornwall,46.150000,47.5,46.4,47.5,44.5,39.675000,40.5,40.5,40.5,38.3,5502.575000,9999.9,9999.9,9999.9,1004.8,11.550000,14.7,12.5,14.7,8.9,18.200,19.0,22.0,22.0,15.9,33.025000,54.0,24.1,54.0,24.1,99.99,1681344000000000000,pm25,9.704545,533.032500,299.97,3,0,4,7.613043,6.247368,7.832595
315,2023-04-14 00:00:00+00:00,Cornwall,47.875000,49.5,48.0,49.5,45.8,44.675000,46.9,44.7,46.9,42.7,5501.300000,9999.9,9999.9,9999.9,1002.4,13.625000,19.0,12.9,19.0,11.3,19.500,26.0,19.0,26.0,15.9,33.025000,54.0,24.1,54.0,24.1,99.99,1681430400000000000,pm25,11.041667,652.296875,299.97,4,0,4,9.704545,7.613043,6.247368
316,2023-04-15 00:00:00+00:00,Cornwall,49.775000,50.5,50.1,50.5,48.6,42.125000,43.0,41.9,43.0,40.6,5510.325000,9999.9,9999.9,9999.9,1020.5,6.225000,6.2,7.0,7.0,5.5,10.725,8.0,14.0,14.0,8.0,33.025000,54.0,24.1,54.0,24.1,0.04,1681516800000000000,pm25,13.387500,309.849375,200.02,5,1,4,11.041667,9.704545,7.613043


In [104]:
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import os
# Set this to force CPU if you're having GPU-related issues
# os.environ["CUDA_VISIBLE_DEVICES"] = ""


def train_autogluon_model(train_data, prediction_length=1):  # Reduced prediction length
    """Configure and train AutoGluon ensemble"""
    predictor = TimeSeriesPredictor(
        target="value",
        prediction_length=prediction_length,  # Shorter prediction horizon
        known_covariates_names=[
            'temperature_avg', 'dew_point_avg', 'wind_speed_avg', 'max_wind_speed_max', 'gust_avg',
            'temp_wind_interaction', 'precip_accum_72h'
        ],
        eval_metric="MASE",
        path="aqi_models_new",
        freq="D"
    )
    
    # Modified hyperparameters with device specification        "use_holidays": True

    hyperparameters = {
        "RecursiveTabular": {
            "max_depth": 8,
            "learning_rate": 0.05
        },
        # Specify CPU device explicitly for neural network models
        "DeepAR": {
            "device": "cuda", 
            "num_layers": 2,  # Simpler architecture
            "hidden_size": 256
        },
        "TemporalFusionTransformer": {
            "device": "cuda",
            "hidden_dim": 128,
            "dropout_rate": 0.1
        }
    }
    
    predictor.fit(
        train_data,
        presets="high_quality",
        hyperparameters=hyperparameters,
        time_limit=7200,
        enable_ensemble=True,
    )
    
    return predictor

# merged_data = merged_data.loc[~(merged_data['value'].isnull())]

ts_dataframe = TimeSeriesDataFrame.from_data_frame(
    merged_data,
    id_column="city",
    timestamp_column="ts",
)

predictor = train_autogluon_model(ts_dataframe)

# # Save artifacts
predictor.save()
# print(f"Model saved to {predictor.path}")

Beginning AutoGluon training... Time limit = 7200s
AutoGluon will save models to '/mnt/c/Users/rog-9/aws/aqi_models_new'
AutoGluon Version:  1.2
Python Version:     3.9.19
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Nov 5 00:21:55 UTC 2024
CPU Count:          20
GPU Count:          1
Memory Avail:       9.31 GB / 15.43 GB (60.3%)
Disk Space Avail:   111.57 GB / 928.35 GB (12.0%)
Setting presets to: high_quality

Fitting with arguments:
{'enable_ensemble': True,
 'eval_metric': MASE,
 'freq': 'D',
 'hyperparameters': {'DeepAR': {'device': 'cuda',
                                'hidden_size': 256,
                                'num_layers': 2},
                     'RecursiveTabular': {'learning_rate': 0.05,
                                          'max_depth': 8},
                     'TemporalFusionTransformer': {'device': 'cuda',
                                                   'dropout_rate': 0.1,
                                         

ValueError: At least some time series in train_data must have >= 6 observations. Please provide longer time series as train_data or reduce prediction_length, num_val_windows, or val_step_size.

In [103]:
merged_data['ts'] = merged_data['ts'] // 10**9
merged_data.to_csv('train_data.csv', index=False)
