In [1]:
import os
from urllib.parse import quote

from dotenv import load_dotenv
from sqlalchemy import create_engine, text

import gfs.fetch

In [2]:
load_dotenv()

# Every connection setting must be present; otherwise os.getenv would return
# None and the URL would silently contain the text "None".
DB_ENV_VARS = ('DB_USER', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_NAME')
missing_db_vars = [name for name in DB_ENV_VARS if os.getenv(name) is None]
if missing_db_vars:
    raise RuntimeError(
        f"Missing database environment variables: {', '.join(missing_db_vars)}"
    )

# Percent-encode the credentials so URL-special characters ('@', ':', '/', '%', ...)
# in a user name or password cannot corrupt the connection URL.
# quote(..., safe='') also encodes '/' and spaces, which quote_plus would turn into '+'.
# NOTE: connection_string contains the password -- do not display it in cell output.
connection_string = "postgresql://{user}:{password}@{host}:{port}/{db}".format(
    user=quote(os.environ['DB_USER'], safe=''),
    password=quote(os.environ['DB_PASSWORD'], safe=''),
    host=os.environ['DB_HOST'],
    port=os.environ['DB_PORT'],
    db=os.environ['DB_NAME'],
)
engine = create_engine(connection_string)

In [3]:
col_names = gfs.fetch.get_col_order()

# (run, delta) pairs to pivot out of source.gfs. Every pair yields one output
# column per GFS variable, suffixed with run + delta (e.g. <col>_9 for (6, 3)).
references = (
    (6, 3),
    (12, 0),
    (12, 3),
)

# One MAX(CASE ...) aggregate per (pair, variable); the feature query evaluates
# these inside GROUP BY lat, lon, date to turn rows into columns.
query_template = """
MAX(CASE WHEN run = {run} AND delta = {delta} THEN {col} END) AS {col}_{suffix}
"""

# Outer loop over references, inner loop over variables -- this fixes the column order.
pivot_targets = [
    (run, delta, col, run + delta)
    for run, delta in references
    for col in col_names
]

# Output column names, in the same order as the SELECT expressions in `cols`.
col_names_full = [f'{col}_{suffix}' for _run, _delta, col, suffix in pivot_targets]

# Comma-separated SELECT list that gets interpolated into the feature query.
cols = ',\n'.join(
    query_template.format(run=run, delta=delta, col=col, suffix=suffix)
    for run, delta, col, suffix in pivot_targets
)

In [4]:
# Rebuild glideator_fs.features_with_target from scratch: pivoted GFS features
# (the `cols` SELECT list) joined to sites and daily flight stats, plus a
# per-site train/validation flag.
MAX_SITE_ID = 250          # only sites with site_id <= MAX_SITE_ID are included
VALIDATION_FRACTION = 0.2  # share of each site's rows flagged is_validation
SPLIT_SEED = 0             # change to draw a different, but still reproducible, split

# NOTE(review): the helper columns rn and total_site_rows are still written to the
# output table (schema unchanged); exclude them when selecting model features.
# NOTE(review): if `date` is a DATE column, its text form in the split hash follows the
# session DateStyle -- keep that setting stable for the split to be reproducible.
query = f"""
DROP TABLE IF EXISTS glideator_fs.features_with_target;
CREATE TABLE glideator_fs.features_with_target AS
WITH
sites AS (
    SELECT
        name AS site,
        site_id,
        latitude,
        longitude,
        altitude,
        lat_gfs,
        lon_gfs
    FROM glideator_mart.dim_sites
    WHERE site_id <= {MAX_SITE_ID}
),
stats AS (
    SELECT
        site,
        date,
        max_points
    FROM glideator_mart.mart_daily_flight_stats
),
gfs AS (
    SELECT
        lat,
        lon,
        date,
        {cols}
    FROM source.gfs
    GROUP BY lat, lon, date
),
features AS (
    SELECT
        sites.site,
        sites.site_id,
        sites.latitude,
        sites.longitude,
        sites.altitude,
        gfs.*
    FROM gfs
    JOIN sites
    ON gfs.lat = sites.lat_gfs
    AND gfs.lon = sites.lon_gfs
),
joined_features AS (
    SELECT
        features.*,
        stats.max_points
    FROM features
    JOIN stats
    ON features.site = stats.site
    AND features.date = stats.date
),
-- Rank rows within each site by a hash of (seed, site_id, date) instead of
-- RANDOM(), so rebuilding the table reproduces the same validation split.
-- site_id is part of the hash so different sites get different validation dates.
numbered_rows AS (
    SELECT
        *,
        ROW_NUMBER() OVER (
            PARTITION BY site_id
            ORDER BY md5(concat_ws('|', {SPLIT_SEED}, site_id, date)), date
        ) AS rn,
        COUNT(*) OVER (PARTITION BY site_id) AS total_site_rows
    FROM joined_features
),
features_with_target AS (
    SELECT
        *,
        -- Flag the first CEIL(fraction * rows) rows of each site as validation
        CASE
            WHEN rn <= CEIL({VALIDATION_FRACTION} * total_site_rows) THEN TRUE
            ELSE FALSE
        END AS is_validation
    FROM numbered_rows
)
-- Only shuffles the order rows are written in; the split above is fixed.
SELECT
    *
FROM features_with_target
ORDER BY RANDOM()
"""

# engine.begin() commits when the block succeeds and rolls back if it raises,
# so the DROP and the CREATE either both take effect or neither does.
with engine.begin() as conn:
    conn.execute(text(query))