In [49]:
!pip freeze | grep scikit-learn

scikit-learn==1.2.2


In [50]:
import os
import pickle

import numpy as np
import pandas as pd

## Parameters

In [51]:
# Month of yellow-taxi trips to score; both file locations are derived from it.
year = 2022
month = 2

period = f"{year:04d}-{month:02d}"
input_file = f"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_{period}.parquet"
output_file = f"output/yellow_tripdata_{period}.parquet"

In [52]:
def load_model(model_file_name):
    """Load the pickled ``(dv, model)`` pair stored in ``model_file_name``.

    The file must hold a pickled two-item iterable; it is unpacked and
    returned as a ``(dv, model)`` tuple (a ``ValueError`` is raised if it
    does not have exactly two items).

    Security note: ``pickle.load`` can execute arbitrary code, so only call
    this on model files you produced yourself or otherwise trust.
    """
    with open(model_file_name, 'rb') as model_fh:
        payload = pickle.load(model_fh)
    vectorizer, estimator = payload
    return vectorizer, estimator

In [53]:
def read_data(filename):
    """Read a yellow-taxi trip parquet file and prepare it for scoring.

    Steps:
      * adds a ``duration`` column: dropoff minus pickup time, in minutes;
      * keeps only trips lasting between 1 and 60 minutes (both inclusive);
      * fills missing ``PULocationID`` / ``DOLocationID`` with -1 and casts
        both columns to strings (via int, so ``1.0`` becomes ``'1'``).

    The row index of the raw file is kept (not reset) for the surviving rows.

    Returns a new DataFrame.
    """
    trips = pd.read_parquet(filename)

    elapsed = trips.tpep_dropoff_datetime - trips.tpep_pickup_datetime
    trips['duration'] = elapsed.dt.total_seconds() / 60

    # between() is inclusive on both ends by default: 1 <= duration <= 60.
    within_range = trips.duration.between(1, 60)
    trips = trips.loc[within_range].copy()

    location_cols = ['PULocationID', 'DOLocationID']
    trips[location_cols] = trips[location_cols].fillna(-1).astype('int').astype('str')

    return trips

In [54]:
def prepare_dictionary(df):
    """Turn the pickup/dropoff location columns of ``df`` into row dicts.

    Returns a list with one ``{'PULocationID': ..., 'DOLocationID': ...}``
    dict per row, in row order; other columns are ignored. This is the form
    ``apply_model`` passes to ``dv.transform``.
    """
    feature_columns = ['PULocationID', 'DOLocationID']
    feature_frame = df.loc[:, feature_columns]
    return feature_frame.to_dict(orient='records')

In [68]:
def apply_model(input_file, model, output_file, *, year=None, month=None):
    """Score one month of taxi trips and write the predictions to parquet.

    Parameters
    ----------
    input_file : str
        Path or URL of the trip parquet file, passed to ``read_data``.
    model : str
        Path to the pickled ``(dv, model)`` pair, passed to ``load_model``.
        Despite the name, this is a file path, not a fitted model.
    output_file : str or path-like
        Destination parquet file. For local paths, the parent directory is
        created if it does not exist yet.
    year, month : int, optional
        Used to build the ``ride_id`` prefix ``YYYY/MM_<row index>``. When
        omitted they fall back to the notebook-level ``year`` / ``month``
        variables (the previous, implicit behaviour). Pass them explicitly
        when scoring a month other than the configured one.

    Writes a parquet file with columns ``ride_id`` and ``prediction`` and
    prints progress along the way; returns None.
    """
    # Backward-compatible fallback to the notebook parameters cell.
    if year is None:
        year = globals()['year']
    if month is None:
        month = globals()['month']

    print("apply model")
    df = read_data(input_file)
    print(df.shape)
    dicts = prepare_dictionary(df)

    # `model` is the file path; keep the loaded estimator under its own name
    # instead of shadowing the argument.
    dv, regressor = load_model(model)

    X_val = dv.transform(dicts)
    y_pred = regressor.predict(X_val)
    print(f"Prediction stdiv: {np.std(y_pred)}")

    # read_data does not reset the index, so ids refer to rows of the raw file.
    ride_id = f'{year:04d}/{month:02d}_' + df.index.astype('str')
    df_result = pd.DataFrame({'ride_id': ride_id, 'prediction': y_pred})

    print(df_result.head(5))
    print(output_file)

    # to_parquet does not create missing directories. Create the parent for
    # local paths only; URL-style destinations (e.g. s3://...) are left alone.
    output_path = str(output_file)
    output_dir = os.path.dirname(output_path)
    if output_dir and '://' not in output_path:
        os.makedirs(output_dir, exist_ok=True)

    df_result.to_parquet(
        output_file,
        engine='pyarrow',
        compression=None,
        index=False
    )
    print("completed")

In [71]:
# Score the configured month using the pickled two-item (dv, model) file.
apply_model(
    input_file=input_file,
    model="model.bin",
    output_file=output_file,
)

apply model
(2918187, 20)
Prediction stdiv: 5.28140357655334
     ride_id  prediction
0  2022/02_0   18.527783
1  2022/02_1   23.065782
2  2022/02_2   33.686359
3  2022/02_3   23.757436
4  2022/02_4   21.492904
output/yellow_tripdata_2022-02.parquet
completed


In [72]:
# Bundle the run parameters into a single dict.
params = dict(year=year, month=month)

In [74]:
# Display the year stored in `params` (scratch sanity check of the dict built
# in the previous cell; NOTE(review): consider removing before sharing).
params['year']

2022