# Refactoring of the train-ml-model.ipynb notebook

In [5]:
import os
from dotenv import load_dotenv

import pandas as pd

import mlflow
from mlflow.tracking.client import MlflowClient

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error


### Extract and load

Rewrite to accept multiple months for retraining

In [6]:
year = 2021
month = 1
color = "green"
model_name="mlops-project"


In [24]:
from typing import Optional
# Download the data

def extract_data(file_name: str): 
    if not os.path.exists(f"./data/{file_name}"):
        os.system(f"wget -P ./data https://d37ci6vzurychx.cloudfront.net/trip-data/{file_name}")

def load_data(file_name: str,
        df2: Optional[pd.DataFrame] = None,
        ) -> pd.DataFrame:
    
    df = pd.read_parquet(f"./data/{file_name}")
    if df2 is None:
        df = pd.concat([df, df2], ignore_index=True)
    return df

#### Test with single month

In [13]:
file_name = f"{color}_tripdata_{year}-{month:02d}.parquet"

extract_data(file_name)
df = load_data(file_name)

In [22]:
df.head(1)

Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,extra,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge
0,2,2021-01-01 00:15:56,2021-01-01 00:19:52,N,1.0,43,151,1.0,1.01,5.5,0.5,0.5,0.0,0.0,,0.3,6.8,2.0,1.0,0.0


In [21]:
df.shape

(76518, 20)

### Test with multiple months

In [25]:
months = [1,2,3]
df = None
for month in months:
    file_name = f"{color}_tripdata_{year}-{month:02d}.parquet"

    extract_data(file_name)
    df = load_data(file_name, df2=df)

--2023-07-21 11:05:08--  https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-03.parquet
Auflösen des Hostnamens d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)… 2600:9000:2070:3800:b:20a5:b140:21, 2600:9000:2070:1c00:b:20a5:b140:21, 2600:9000:2070:7000:b:20a5:b140:21, ...
Verbindungsaufbau zu d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)|2600:9000:2070:3800:b:20a5:b140:21|:443 … verbunden.
HTTP-Anforderung gesendet, auf Antwort wird gewartet … 200 OK
Länge: 1474538 (1,4M) [binary/octet-stream]
Wird in »./data/green_tripdata_2021-03.parquet« gespeichert.

     0K .......... .......... .......... .......... ..........  3% 1,14M 1s
    50K .......... .......... .......... .......... ..........  6%  809K 1s
   100K .......... .......... .......... .......... .......... 10%  713K 2s
   150K .......... .......... .......... .......... .......... 13% 3,92M 1s
   200K .......... .......... .......... .......... .......... 17% 3,96M 1s
   250K .....

In [26]:
df.head(1)

Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,extra,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge
0,2,2021-03-01 00:05:42,2021-03-01 00:14:03,N,1.0,83,129,1.0,1.56,7.5,0.5,0.5,0.0,0.0,,0.3,8.8,1.0,1.0,0.0


In [27]:
df.shape

(83827, 20)

In [30]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83827 entries, 0 to 83826
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype         
---  ------                 --------------  -----         
 0   VendorID               83827 non-null  int64         
 1   lpep_pickup_datetime   83827 non-null  datetime64[ns]
 2   lpep_dropoff_datetime  83827 non-null  datetime64[ns]
 3   store_and_fwd_flag     43293 non-null  object        
 4   RatecodeID             43293 non-null  float64       
 5   PULocationID           83827 non-null  int64         
 6   DOLocationID           83827 non-null  int64         
 7   passenger_count        43293 non-null  float64       
 8   trip_distance          83827 non-null  float64       
 9   fare_amount            83827 non-null  float64       
 10  extra                  83827 non-null  float64       
 11  mta_tax                83827 non-null  float64       
 12  tip_amount             83827 non-null  float64       
 13  t

### Preprocess data

I decided to remove passenger count from the features because there are so many null values and this data doesn't seem relevant for trip duration.

In [36]:
def calculate_trip_duration_in_minutes(df: pd.DataFrame):
    features = ["PULocationID", 
                "DOLocationID", 
                "trip_distance", 
                "fare_amount", 
                "total_amount"]
    target = "duration"

    df[target] = (df["lpep_dropoff_datetime"] - df["lpep_pickup_datetime"]).dt.total_seconds() / 60
    df = df[(df["duration"] >= 1) & (df["duration"] <= 60)]
    df = df[features + [target]]
    return df

def split_data(df: pd.DataFrame):
    df = df.copy()
    target = "duration"
    y=df[target]
    X=df.drop(columns=[target])
    X_train, X_test, y_train, y_test = train_test_split(
        X, 
        y, 
        random_state=42, 
        test_size=0.2)
    return X_train, X_test, y_train, y_test

In [37]:
df = calculate_trip_duration_in_minutes(df)
X_train, X_test, y_train, y_test = split_data(df)

In [39]:
X_train.head(2)

Unnamed: 0,PULocationID,DOLocationID,trip_distance,fare_amount,total_amount
8196,42,75,2.32,10.0,10.8
75288,183,168,7.33,27.27,30.32


In [40]:
y_train.head(2)

8196     10.833333
75288    17.000000
Name: duration, dtype: float64

### Train and register model
I rewrote the code to update most recent model version to "Production" and archive previous version

In [42]:
features = ["PULocationID", 
                "DOLocationID", 
                "trip_distance", 
                "fare_amount", 
                "total_amount"]
target = "duration"


tags = {
    "model": "linear regression",
        "developer": "Chris Hedemann",
        "dataset": f"green-taxi",
        "year": year,
        "month": months,
        "features": features,
        "target": target
}

def train_model(model_name: str, 
                X_train: pd.DataFrame, 
                X_test: pd.DataFrame, 
                y_train: pd.Series, 
                y_test: pd.Series,
                tags: dict[str]):
    # Environment variables
    load_dotenv()
    MLFLOW_TRACKING_URI=os.getenv("MLFLOW_TRACKING_URI")

    # Get Google SA_KEY to access the MLFLow server
    SA_KEY= os.getenv("SA_KEY")
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = SA_KEY

    # Set up the connection to MLflow
    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)
    mlflow.set_experiment(model_name)

    with mlflow.start_run():

        # Train and evluate model
        lr = LinearRegression()
        lr.fit(X_train, y_train)
        y_pred = lr.predict(X_test)
        rmse = mean_squared_error(y_test, y_pred, squared=False)
        
        # Log model
        mlflow.set_tags(tags)
        mlflow.log_metric("rmse", rmse)
        mlflow.sklearn.log_model(lr, "model")

        # Register and transition model to production
        run_id = mlflow.active_run().info.run_id
        model_uri = f"runs:/{run_id}/model"
        mlflow.register_model(model_uri=model_uri, name=model_name)
        latest_versions = client.get_latest_versions(model_name)
        model_version = int(latest_versions[-1].version)
        new_stage = "Production"
        client.transition_model_version_stage(
        name=model_name,
        version=model_version,
        stage=new_stage,
        archive_existing_versions=True
        )

In [43]:
train_model(model_name,
            X_train,
            X_test, 
            y_train, 
            y_test,
            tags)

2023/07/21 11:30:07 INFO mlflow.tracking.fluent: Experiment with name 'mlops-project' does not exist. Creating a new experiment.
Successfully registered model 'mlops-project'.
2023/07/21 11:30:13 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: mlops-project, version 1
Created version '1' of model 'mlops-project'.
