In [1]:
import data_prep

In [2]:
from data_prep import X_train_1, X_test_1, y_train_1, y_test_1, X_train_2, X_test_2, y_train_2, y_test_2

In [None]:
import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score

def train_model(
    X_train,
    X_test,
    y_train,
    y_test,
    *,
    run_name="Random Forest - Training on fraud_data",
    random_state=42,
    label="fraud_Data",
):
    """Train a random forest, evaluate it, and log the run to MLflow.

    Fits a ``RandomForestClassifier`` on the training split and scores it on
    the test split. Inside a single MLflow run, it logs the accuracy and F1
    score, then the fitted model under the artifact path ``"random_forest"``.
    The model is logged with an inferred signature and a five-row input
    example.

    Args:
        X_train: Training features. ``.head(5)`` is called on it for the
            MLflow input example, so it is presumably a pandas DataFrame
            (confirm against ``data_prep``).
        X_test: Held-out features used for evaluation.
        y_train: Training labels.
        y_test: Held-out labels.
        run_name: Name given to the MLflow run.
        random_state: Seed passed to the classifier so fits are reproducible.
        label: Dataset name shown in the printed summary line.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    with mlflow.start_run(run_name=run_name):
        rf_model = RandomForestClassifier(random_state=random_state)
        rf_model.fit(X_train, y_train)

        y_pred = rf_model.predict(X_test)

        accuracy = accuracy_score(y_test, y_pred)
        # f1_score defaults to average="binary", which expects a two-class
        # target; a multiclass target would need an explicit `average`.
        f1 = f1_score(y_test, y_pred)

        mlflow.log_metric("accuracy", accuracy)
        mlflow.log_metric("f1_score", f1)

        # The signature records only column names/dtypes, so pairing the
        # training inputs with test-set predictions is enough for the schema.
        signature = infer_signature(X_train, y_pred)
        mlflow.sklearn.log_model(
            rf_model,
            "random_forest",
            signature=signature,
            input_example=X_train.head(5),
        )

        print(f"Train on {label}: Accuracy = {accuracy}, F1 Score = {f1}")

    return rf_model

def finetune_model(
    initial_model,
    X_train,
    X_test,
    y_train,
    y_test,
    *,
    run_name="Random Forest - Fine-Tuning on credit_Dataset",
    label="credit_Data",
):
    """Retrain a copy of ``initial_model`` on a second dataset and log it to MLflow.

    Despite the name, this is not incremental fine-tuning. With the default
    ``warm_start=False`` (which ``train_model`` does not change),
    ``RandomForestClassifier.fit`` discards any previously grown trees and
    trains from scratch. Only ``initial_model``'s hyperparameters carry over.

    The estimator is cloned before fitting, so the caller's ``initial_model``
    keeps its existing fit instead of being overwritten in place.

    Metrics (accuracy, F1) and the fitted model are logged to one MLflow run
    under the artifact path ``"random_forest_finetuned_model"``.

    Args:
        initial_model: A scikit-learn estimator whose hyperparameters are
            reused. It is not modified.
        X_train: Training features. ``.head(5)`` is called on it for the
            MLflow input example, so it is presumably a pandas DataFrame
            (confirm against ``data_prep``).
        X_test: Held-out features used for evaluation.
        y_train: Training labels.
        y_test: Held-out labels.
        run_name: Name given to the MLflow run.
        label: Dataset name shown in the printed summary line.

    Returns:
        The newly fitted estimator (a fitted clone of ``initial_model``).
    """
    # clone() copies hyperparameters (including random_state) but no fitted
    # state, so the result matches an in-place refit without mutating the
    # caller's object.
    model = clone(initial_model)

    with mlflow.start_run(run_name=run_name):
        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)

        accuracy = accuracy_score(y_test, y_pred)
        # f1_score defaults to average="binary", which expects a two-class target.
        f1 = f1_score(y_test, y_pred)

        mlflow.log_metric("accuracy", accuracy)
        mlflow.log_metric("f1_score", f1)

        # Only column names/dtypes are captured in the signature.
        signature = infer_signature(X_train, y_pred)
        mlflow.sklearn.log_model(
            model,
            "random_forest_finetuned_model",
            signature=signature,
            input_example=X_train.head(5),
        )

        print(f"Fine-Tuning -{label} : Accuracy = {accuracy}, F1 Score = {f1}")

    return model


# Train on the first split (the fraud_data run), then fit again with the same
# hyperparameters on the second split (the credit_Dataset run).
# Both runs are logged to MLflow.
initial_model = train_model(
    X_train=X_train_1,
    X_test=X_test_1,
    y_train=y_train_1,
    y_test=y_test_1,
)
finetune_model(
    initial_model,
    X_train=X_train_2,
    X_test=X_test_2,
    y_train=y_train_2,
    y_test=y_test_2,
)


