In [1]:
import os

from dotenv import load_dotenv
from sqlalchemy import create_engine, text
from sqlalchemy.engine import URL
import webdataset as wds
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

import net.preprocessing as preprocessing
import gfs.fetch


In [2]:
# Output directories for the training / validation WebDataset shards
# (relative paths, resolved against the kernel's working directory).
TRAIN_DATA_PATH, VAL_DATA_PATH = (
    "../analytics/training/train_data",
    "../analytics/training/val_data",
)


In [3]:
load_dotenv()

# Fail fast with a clear message instead of building a URL that contains the
# literal string "None" when a setting is missing from the environment / .env.
_DB_ENV_VARS = ('DB_USER', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_NAME')
_missing_db_vars = [name for name in _DB_ENV_VARS if os.getenv(name) is None]
if _missing_db_vars:
    raise RuntimeError(f"Missing database settings: {', '.join(_missing_db_vars)}")

# URL.create escapes special characters (e.g. '@', ':', '/') in the credentials;
# plain string formatting would let them corrupt the connection URL.
db_url = URL.create(
    drivername="postgresql",
    username=os.environ['DB_USER'],
    password=os.environ['DB_PASSWORD'],
    host=os.environ['DB_HOST'],
    port=int(os.environ['DB_PORT']),
    database=os.environ['DB_NAME'],
)
# Textual form kept for any later cell that reads it (note: includes the password).
connection_string = db_url.render_as_string(hide_password=False)
engine = create_engine(db_url)

In [4]:
def _prepare_chunk(chunk_df, date_col, thresholds):
    """Drop null rows, parse ``date_col`` and add the date-feature and target columns."""
    # NOTE(review): dropna() considers every column of the table, not only the
    # exported ones — confirm that is intended.
    # .copy() so the in-place additions below never act on a flagged slice.
    chunk_df = chunk_df.dropna().copy()
    chunk_df[date_col] = pd.to_datetime(chunk_df[date_col])
    # Return values are ignored, so these presumably add columns in place.
    preprocessing.add_date_features(chunk_df)
    preprocessing.add_targets(chunk_df, thresholds=thresholds)
    return chunk_df


def _iter_samples(chunk_idx, chunk_df, weather_features, site_features,
                  site_id_col, date_col, date_cols, target_cols):
    """Yield one WebDataset sample dict per row of a prepared chunk."""
    # Convert each feature group to a float32 matrix once per chunk instead of
    # once per row (iterrows() materialises every row as a separate Series).
    weather_arrays = {
        str(hour): chunk_df[cols].to_numpy(dtype=np.float32)
        for hour, cols in weather_features.items()
    }
    site_array = chunk_df[site_features].to_numpy(dtype=np.float32)
    date_array = chunk_df[date_cols].to_numpy(dtype=np.float32)
    target_array = chunk_df[target_cols].to_numpy(dtype=np.float32)
    site_ids = chunk_df[site_id_col].to_numpy(dtype=np.int64)
    timestamps = [ts.timestamp() for ts in chunk_df[date_col]]

    for pos, idx in enumerate(chunk_df.index):
        # torch.tensor copies each row, so every saved tensor owns only its own data
        # (a view would serialise the whole chunk matrix).
        features = {
            'weather': {hour: torch.tensor(arr[pos]) for hour, arr in weather_arrays.items()},
            'site': torch.tensor(site_array[pos]),
            'site_id': torch.tensor(site_ids[pos]),
            'date': torch.tensor(date_array[pos]),
        }
        yield {
            "__key__": f"{chunk_idx}_{idx:08d}",
            "features.pth": features,
            "targets.pth": torch.tensor(target_array[pos]),
            # NOTE(review): with torch's default float32 dtype the timestamp is
            # stored at ~128 s resolution; kept as-is for existing readers.
            "date.pth": torch.tensor(timestamps[pos]),
        }


def create_webdataset(engine, output_path, weather_features, site_features,
                       site_id_col, date_col, chunk_size=1000, total_chunks=None,
                       is_validation=False,
                       thresholds=(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100)):
    """
    Convert rows of ``glideator_fs.features_with_target`` to WebDataset ``.tar`` shards.

    Rows whose ``is_validation`` column matches ``is_validation`` are streamed from
    the database with a single query and written ``chunk_size`` rows per shard
    (``shard_000000.tar``, ``shard_000001.tar``, ...). Rows with any null are
    dropped first, so a shard may hold fewer than ``chunk_size`` samples.

    Each sample holds:
        features.pth: dict with 'weather' ({str(hour): float32 tensor}),
            'site' (float32), 'site_id' (int64 scalar) and 'date' (float32 tensor
            of the 'weekend', 'year', 'day_of_year_sin', 'day_of_year_cos' columns)
        targets.pth: float32 tensor of the ``XC{threshold}`` columns
        date.pth: scalar tensor with ``date_col`` as a POSIX timestamp (seconds)

    Args:
        engine: SQLAlchemy engine
        output_path: directory where the .tar shards are written (created if missing)
        weather_features: mapping of forecast hour -> list of weather column names
        site_features: list of column names for site features
        site_id_col: column name for site ID
        date_col: column name for date
        chunk_size: number of database rows per shard
        total_chunks: optional limit on the number of shards; None exports all rows
        is_validation: export validation (True) or training (False) rows
        thresholds: thresholds passed to ``preprocessing.add_targets``; one
            ``XC{threshold}`` target column is exported per threshold
    """
    os.makedirs(output_path, exist_ok=True)

    # Loop-invariant setup, done once instead of once per chunk.
    thresholds = list(thresholds)
    target_cols = [f'XC{threshold}' for threshold in thresholds]
    date_cols = ['weekend', 'year', 'day_of_year_sin', 'day_of_year_cos']
    # Bound parameter instead of interpolating the flag into the SQL string.
    query_params = {"is_validation": bool(is_validation)}

    if total_chunks is None:
        count_query = text("""
        SELECT COUNT(*)
        FROM glideator_fs.features_with_target
        WHERE is_validation = :is_validation
        """)
        with engine.connect() as conn:
            total_rows = conn.execute(count_query, query_params).scalar()
        total_chunks = (total_rows + chunk_size - 1) // chunk_size

    rows_query = text("""
    SELECT *
    FROM glideator_fs.features_with_target
    WHERE is_validation = :is_validation
    """)

    # One query read through a server-side cursor replaces LIMIT/OFFSET paging:
    # without ORDER BY, PostgreSQL does not guarantee the same row order across
    # separate queries (pages could overlap or skip rows), and every OFFSET page
    # rescans all rows before it.
    with engine.connect() as conn:
        stream_conn = conn.execution_options(stream_results=True)
        chunks = pd.read_sql(rows_query, stream_conn, params=query_params,
                             chunksize=chunk_size)
        # zip stops after total_chunks shards without fetching an extra chunk.
        for chunk_idx, chunk_df in tqdm(zip(range(total_chunks), chunks), total=total_chunks):
            chunk_df = _prepare_chunk(chunk_df, date_col, thresholds)
            shard_name = os.path.join(output_path, f"shard_{chunk_idx:06d}.tar")
            with wds.TarWriter(shard_name) as sink:
                for sample in _iter_samples(chunk_idx, chunk_df, weather_features,
                                            site_features, site_id_col, date_col,
                                            date_cols, target_cols):
                    sink.write(sample)

In [5]:
# Base weather column names, in the order returned by gfs.fetch.get_col_order().
col_names = gfs.fetch.get_col_order()

# (run, delta) pairs — presumably GFS run hour and forecast offset (TODO confirm).
# Every base column is suffixed with run + delta, i.e. _9, _12 and _15 here.
references = (
    (6, 3),
    (12, 0),
    (12, 3),
)
col_names_full = [
    f'{col}_{run + delta}'
    for run, delta in references
    for col in col_names
]

In [6]:
# Group the suffixed weather columns by hour (run + delta from `references`).
# Matching the full "_<hour>" suffix avoids bare-digit filters such as
# endswith('9'), which would also catch a future "_19" column.
weather_features = {
    hour: [col for col in col_names_full if col.endswith(f'_{hour}')]
    for hour in sorted({run + delta for run, delta in references})
}
# Each (run, delta) reference contributes exactly one copy of every base column.
assert all(len(cols) == len(col_names) for cols in weather_features.values()), \
    "unexpected weather column grouping"

site_features = ['latitude', 'longitude', 'altitude']
site_id_col = 'site_id'
date_col = 'date'

In [7]:
# Export the training split first, then the validation split, with identical settings.
for split_output_path, split_is_validation in (
    (TRAIN_DATA_PATH, False),
    (VAL_DATA_PATH, True),
):
    create_webdataset(
        engine=engine,
        output_path=split_output_path,
        weather_features=weather_features,  # hour -> weather column names
        site_features=site_features,  # site feature column names
        site_id_col=site_id_col,  # site ID column
        date_col=date_col,  # date column
        chunk_size=1000,  # rows per shard; lower it if memory is tight
        total_chunks=None,  # None = export every row of the split
        is_validation=split_is_validation,
    )

100%|██████████████████████████████████████████████████████████████████████████████████████| 722/722 [2:10:35<00:00, 10.85s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████| 181/181 [28:51<00:00,  9.57s/it]
