In [24]:
# http://localhost:8888/tree?token=<JUPYTER_TOKEN>  (token redacted — do not commit live server tokens)

In [1]:
# ! gcloud auth login

In [2]:
# SERVICE_ACCOUNT = "flight-ml-demo-general@aia-ds-accelerator-flight-1.iam.gserviceaccount.com"  # @param {type:"string"}

In [15]:
# data_preprocessing.py
import json
from datetime import datetime, timedelta, timezone

import polars as pl
from deltalake import DeltaTable

class DataPreprocessor:
    """Load flight-summary events from a Delta table on GCS, derive a
    time-to-landing training target, and write the result back to GCS as CSV.

    Args:
        gcp_client: Project client object. This class reads its ``creds_json``
            attribute (JSON-serialisable credentials passed to deltalake) and
            its ``storage_client`` attribute (used via ``get_bucket``).
        bucket: GCS bucket name, without the ``gs://`` prefix.
        table_path: Path of the Delta table inside ``bucket``.
    """

    def __init__(self, gcp_client, bucket: str, table_path: str):
        self.gcp_client = gcp_client
        self.table_path = table_path
        self.bucket = bucket
        # Fully-qualified URI handed to pl.scan_delta.
        self.table_path_full = f'gs://{bucket}/{table_path}'

    def get_date_from_lookback(self, lookback_days: int, return_as_str=True):
        """Return the UTC date/time ``lookback_days`` days before now.

        Args:
            lookback_days: Number of days to subtract from the current UTC time.
            return_as_str: If True (default), return a ``'YYYY/MM/DD'`` string;
                otherwise return the datetime itself.

        Returns:
            A ``'YYYY/MM/DD'`` string, or a *naive* ``datetime`` expressed in
            UTC (kept naive for compatibility with the earlier
            ``datetime.utcnow()``-based behavior).
        """
        # datetime.utcnow() is deprecated since Python 3.12: take an aware UTC
        # "now" and drop tzinfo so callers still get a naive UTC datetime.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        target_date = now_utc - timedelta(days=lookback_days)

        if return_as_str:
            return target_date.strftime('%Y/%m/%d')
        return target_date

    def read_data(self, lookback_days: int = 365):
        """Lazily scan the Delta table, restricted to recent partitions.

        Args:
            lookback_days: Only partitions whose ``crt_ts_date`` is on or after
                the UTC date this many days ago are scanned. Defaults to 365,
                the previously hard-coded window.

        Returns:
            The polars frame produced by ``pl.scan_delta`` (lazy; rows are
            read only when it is collected).
        """
        # Credentials are passed to the object store as a serialised JSON
        # string under "service_account_key".
        gcp_creds_json_str = json.dumps(self.gcp_client.creds_json)
        storage_options = {"service_account_key": gcp_creds_json_str}

        start_date = self.get_date_from_lookback(lookback_days=lookback_days, return_as_str=True)

        # NOTE(review): start_date is a 'YYYY/MM/DD' string, so this filter
        # assumes crt_ts_date partition values use that exact zero-padded
        # format — confirm against the table's partition layout.
        df = pl.scan_delta(
            self.table_path_full,
            pyarrow_options={"partitions": [("crt_ts_date", ">=", start_date)]},
            storage_options=storage_options,
        )

        # Only the scan has been set up at this point; no rows are read yet.
        print(f"Data loaded from {self.table_path_full}")
        return df

    def create_target(self, df):
        """Add ``target``: hours from each event until its flight lands.

        A flight's landing time (``actual_in_filled``) is the maximum
        ``actual_in`` over all of its rows, joined back onto every row.

        Args:
            df: polars DataFrame or LazyFrame with at least ``fa_flight_id``,
                ``actual_in``, ``event_ts`` (datetimes) and ``crt_ts`` columns.

        Returns:
            ``df`` with ``actual_in_filled`` and ``target`` (float hours) added,
            sorted by ``actual_in_filled`` then ``crt_ts``. ``target`` is null
            when the flight has no ``actual_in``.
        """
        landing_times = df.group_by('fa_flight_id').agg(pl.max('actual_in').alias('actual_in_filled'))
        df = df.join(landing_times, on='fa_flight_id')
        # Difference in epoch milliseconds -> hours.
        df = df.with_columns(
            ((pl.col('actual_in_filled').dt.timestamp("ms") - pl.col('event_ts').dt.timestamp("ms")) / 1000 / 60 / 60)
            .alias('target')
        )

        df = df.sort(['actual_in_filled', 'crt_ts'])
        return df

    def remove_incomplete_flights(self, df):
        """Drop rows whose flight has not landed (null ``actual_in_filled``)."""
        return df.filter(pl.col('actual_in_filled').is_not_null())

    def removed_arrival_events(self, df):
        """Drop arrival events (``event_type == 'actual_in'``), which are not
        useful for training.

        Note: rows with a null ``event_type`` make the comparison null and are
        therefore also dropped by ``filter``.
        """
        return df.filter(pl.col('event_type') != 'actual_in')

    def process_data(self, df):
        """Apply the full preprocessing pipeline.

        Computes the target first (it needs every row of a flight, including
        the arrival event, to find the landing time), then removes unlanded
        flights and arrival events.
        """
        df = self.create_target(df)
        df = self.remove_incomplete_flights(df)
        df = self.removed_arrival_events(df)
        return df

    def write_data_to_gcs(self, df, path_out: str):
        """Write ``df`` as CSV to ``gs://<self.bucket>/<path_out>.csv``.

        Args:
            df: polars LazyFrame (collected here) or an already-collected
                DataFrame.
            path_out: Object path inside ``self.bucket``, without the ``.csv``
                suffix (it is appended here).
        """
        print(f"Writing data to {path_out}")
        # Accept both lazy and eager frames; only lazy ones need collecting.
        if hasattr(df, "collect"):
            df = df.collect()
        # polars serialises with write_csv(); with no target file it returns
        # the CSV text. (to_csv is a pandas API and does not exist on polars.)
        csv_text = df.write_csv()
        bucket = self.gcp_client.storage_client.get_bucket(self.bucket)
        blob = bucket.blob(f'{path_out}.csv')
        blob.upload_from_string(csv_text, content_type='text/csv')

def main():
    """Run the preprocessing pipeline end to end: load, process, write.

    Scans the ``flightsummary-delta-processed-stream`` Delta table in the
    ``datalake-flight-dev-1`` bucket, derives the time-to-landing target,
    and writes the *processed* rows to ``training/flightsummary-training``
    in the same bucket.
    """
    # Imported inside main, so importing this module does not require the
    # project-local ``gcp`` package.
    from gcp import GCPClient

    gcp_client = GCPClient()

    bucket = 'datalake-flight-dev-1'
    table_path_in = 'flightsummary-delta-processed-stream'
    table_path_out = 'training/flightsummary-training'

    data_preprocessor = DataPreprocessor(gcp_client, bucket=bucket, table_path=table_path_in)

    # Step 1: Load data. Kept lazy so the processing steps below are planned
    # together and only materialised once, at write time.
    flights_raw = data_preprocessor.read_data()

    # Step 2: Process data
    flights_processed = data_preprocessor.process_data(flights_raw)

    # Step 3: Write the processed data (previously the raw frame was written
    # and the processing result was discarded).
    data_preprocessor.write_data_to_gcs(flights_processed, path_out=table_path_out)

if __name__ == "__main__":
    main()
