# Wine Quality Classification using Machine Learning

#### Overview
The purpose of this project is to build and evaluate machine learning models that can accurately predict the quality of red wine. By analyzing various physicochemical properties, we aim to understand which features are most indicative of wine quality and to develop a robust classifier. This can be particularly useful for vintners seeking to assess and improve wine quality or for retailers recommending wines based on customer preferences.



### Select a dataset.
- dbutils.fs.ls('/databricks-datasets/')

In [None]:
dbutils.fs.ls('/databricks-datasets/')

[FileInfo(path='dbfs:/databricks-datasets/COVID/', name='COVID/', size=0, modificationTime=1721442662906),
 FileInfo(path='dbfs:/databricks-datasets/README.md', name='README.md', size=976, modificationTime=1532502324000),
 FileInfo(path='dbfs:/databricks-datasets/Rdatasets/', name='Rdatasets/', size=0, modificationTime=1721442662906),
 FileInfo(path='dbfs:/databricks-datasets/SPARK_README.md', name='SPARK_README.md', size=3359, modificationTime=1455505834000),
 FileInfo(path='dbfs:/databricks-datasets/adult/', name='adult/', size=0, modificationTime=1721442662906),
 FileInfo(path='dbfs:/databricks-datasets/airlines/', name='airlines/', size=0, modificationTime=1721442662906),
 FileInfo(path='dbfs:/databricks-datasets/amazon/', name='amazon/', size=0, modificationTime=1721442662906),
 FileInfo(path='dbfs:/databricks-datasets/asa/', name='asa/', size=0, modificationTime=1721442662906),
 FileInfo(path='dbfs:/databricks-datasets/atlas_higgs/', name='atlas_higgs/', size=0, modificationTime=

In [None]:
# List the files in the 'wine-quality' directory
wine_quality_files = dbutils.fs.ls('dbfs:/databricks-datasets/wine-quality/')
print("Files in wine-quality directory:")
for file in wine_quality_files:
    print(file.path)

Files in wine-quality directory:
dbfs:/databricks-datasets/wine-quality/README.md
dbfs:/databricks-datasets/wine-quality/winequality-red.csv
dbfs:/databricks-datasets/wine-quality/winequality-white.csv


In [None]:
# Path to the README file
readme_file_path = '/databricks-datasets/wine-quality/README.md'

# Read the README file
readme_df = spark.read.text(readme_file_path)

# Show the contents of the README file
readme_df.show(truncate=False)


+------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|value                                                                                                                                                                   |
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Wine Quality Data Set                                                                                                                                                   |
|Two datasets related to red and white variants of the Portuguese "Vinho Verde" wine.                                                                                    |
|                                                                                                                                                

### Question of interest.
My question of interest is whether we can classify a wine's quality based on its physicochemical properties. We use the red wine dataset of Vinho Verde wines from northwestern Portugal. This is a classification problem in which we try to predict a quality category (ranging from Poor to Gold) derived from the wine's numeric quality score.

### Perform EDA on your dataset.

In [None]:
# Set your user name in the widget in the upper left of the screen.
# This is required so that you can create a folder for yourself!

# Your User Name Here
username = dbutils.widgets.get("username")
save_path = f"dbfs:/tmp/w8/{username}"

silver_path = f"{save_path}/silver"

# View the paths
print(silver_path)

dbfs:/tmp/w8/cthirtee/silver


In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Read in the raw data from the CSV Source
red_wine = spark.read.csv(
    "/databricks-datasets/wine-quality/winequality-red.csv", 
    schema="`fixed_acidity` DOUBLE, `volatile_acidity` DOUBLE, `citric_acid` DOUBLE, `residual_sugar` DOUBLE, `chlorides` DOUBLE, `free_sulfur_dioxide` DOUBLE, `total_sulfur_dioxide` DOUBLE, `density` DOUBLE, `pH` DOUBLE, `sulphates` DOUBLE, `alcohol` DOUBLE, `quality` DOUBLE", header=True, sep=';'
)

# Define the categorization function that maps a quality score to a category
def categorize_quality(score):
    if score <= 4:
        return 'Poor'
    elif score <= 5:
        return 'Fair'
    elif score <= 6:
        return 'Commended'
    elif score <= 7:
        return 'Bronze'
    elif score <= 8:
        return 'Silver Medal'
    else:
        return 'Gold'
    
categorize_quality_udf = udf(categorize_quality, StringType())
red_wine = red_wine.withColumn('quality_category', categorize_quality_udf(red_wine['quality']))

# Display the updated DataFrame
#red_wine.display()

# Create our Delta table in the silver_path staging area
red_wine.write.format('delta').mode('overwrite').option("mergeSchema", "true").save(f"{silver_path}/wine_quality")




fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality,quality_category
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5.0,Fair
7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5.0,Fair
11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6.0,Commended
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.4,0.66,0.0,1.8,0.075,13.0,40.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.9,0.6,0.06,1.6,0.069,15.0,59.0,0.9964,3.3,0.46,9.4,5.0,Fair
7.3,0.65,0.0,1.2,0.065,15.0,21.0,0.9946,3.39,0.47,10.0,7.0,Bronze
7.8,0.58,0.02,2.0,0.073,9.0,18.0,0.9968,3.36,0.57,9.5,7.0,Bronze
7.5,0.5,0.36,6.1,0.071,17.0,102.0,0.9978,3.35,0.8,10.5,5.0,Fair
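
The score-to-category mapping used above can also be written as a lookup over sorted thresholds. The following is a minimal sketch in plain Python, not the notebook's implementation; the names `THRESHOLDS`, `LABELS`, and `categorize_quality_table` are hypothetical and introduced only for illustration. It reproduces the same boundaries as the if/elif chain in `categorize_quality`.

```python
from bisect import bisect_left

# Inclusive upper bounds for the first five categories, mirroring the
# if/elif chain in categorize_quality. These names are illustrative only.
THRESHOLDS = [4, 5, 6, 7, 8]
LABELS = ["Poor", "Fair", "Commended", "Bronze", "Silver Medal", "Gold"]


def categorize_quality_table(score: float) -> str:
    """Return the label of the first bucket whose upper bound is >= score."""
    # bisect_left returns the index of the first threshold >= score,
    # or len(THRESHOLDS) when the score exceeds every threshold ("Gold").
    return LABELS[bisect_left(THRESHOLDS, score)]


for s in (3.0, 5.0, 6.0, 7.0, 8.0, 9.0):
    print(s, categorize_quality_table(s))
```

Because the summary statistics in this notebook report quality scores between 3 and 8 for the red wine data, the Gold bucket (scores above 8) is never assigned, which is consistent with only five categories appearing in the grouped output.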


In [None]:
# Read the silver red wine data into a SparkDataFrame
red_wine_delta = spark.read.format("delta").load(f"{silver_path}/wine_quality")

display(red_wine_delta)

fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality,quality_category
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5.0,Fair
7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5.0,Fair
11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6.0,Commended
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.4,0.66,0.0,1.8,0.075,13.0,40.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.9,0.6,0.06,1.6,0.069,15.0,59.0,0.9964,3.3,0.46,9.4,5.0,Fair
7.3,0.65,0.0,1.2,0.065,15.0,21.0,0.9946,3.39,0.47,10.0,7.0,Bronze
7.8,0.58,0.02,2.0,0.073,9.0,18.0,0.9968,3.36,0.57,9.5,7.0,Bronze
7.5,0.5,0.36,6.1,0.071,17.0,102.0,0.9978,3.35,0.8,10.5,5.0,Fair


In [None]:
# Check schema and column names
red_wine_delta.printSchema()

root
 |-- fixed_acidity: double (nullable = true)
 |-- volatile_acidity: double (nullable = true)
 |-- citric_acid: double (nullable = true)
 |-- residual_sugar: double (nullable = true)
 |-- chlorides: double (nullable = true)
 |-- free_sulfur_dioxide: double (nullable = true)
 |-- total_sulfur_dioxide: double (nullable = true)
 |-- density: double (nullable = true)
 |-- pH: double (nullable = true)
 |-- sulphates: double (nullable = true)
 |-- alcohol: double (nullable = true)
 |-- quality: double (nullable = true)
 |-- quality_category: string (nullable = true)




#### Explanation of the variables
**fixed acidity**: The amount of non-volatile acids in the wine. These acids do not evaporate easily and contribute to the overall acidity of the wine. Higher levels of fixed acidity can contribute to a wine's crispness and tartness. However, excessively high levels can make the wine taste too sharp or sour. The optimal level depends on the wine style and balance with other components.

**volatile acidity**: The amount of acetic acid in wine, which can lead to an unpleasant vinegar-like taste if too high. Lower volatile acidity is typically preferred for higher-quality wines.

**citric acid**: Provides a fresh flavor to wines and is usually found in small quantities. Wines with higher citric acid might have a more refreshing taste.

**residual sugar**: Refers to the amount of sugar remaining in the wine after fermentation, measured in grams per liter. The perception of sweetness in wine is influenced by residual sugar levels. For dry wines, low residual sugar is preferred to maintain balance and allow other flavors to shine. In sweeter wines, higher residual sugar can contribute to a perceived fullness and roundness.

**chlorides**: Represents the amount of salt in the wine. Chloride levels are typically low in wine but can influence its taste and mouthfeel. Higher chlorides might contribute to a salty or briny taste, undesirable in excess.

**free sulfur dioxide**: Measures the free form of sulfur dioxide (SO2) in the wine, which acts as an antioxidant and antimicrobial agent. Adequate free sulfur dioxide levels help preserve wine freshness and prevent spoilage.

**total sulfur dioxide**: Indicates the total amount of sulfur dioxide (free + bound forms) in the wine. Too high total sulfur dioxide levels can lead to a pungent aroma and affect taste negatively.

**density**: Represents the density of the wine, which is close to that of water depending on the alcohol and sugar content. Density affects mouthfeel and body. Higher density wines may feel fuller-bodied, while lower density wines can feel lighter. It contributes to the overall texture and perceived quality of the wine.

**pH**: Measures the acidity or basicity of the wine on a scale from 0 to 14, with lower values indicating higher acidity. Wines with lower pH levels tend to be crisper and more acidic, enhancing freshness. Higher pH levels can lead to a flatter taste and may indicate microbial instability.

**sulphates**: Adds to the wine's antimicrobial and antioxidant properties. Proper levels of sulphates help maintain wine quality and stability.

**alcohol**: Indicates the alcohol content of the wine, typically measured in percent volume. Alcohol contributes to wine body, texture, and perceived warmth. Well-integrated alcohol levels enhance complexity and balance. High alcohol can dominate flavors, while low alcohol may lack depth.

**quality**: Subjective quality rating of the wine. Higher quality wines typically exhibit balanced acidity, complexity, harmony of flavors, and a pleasing mouthfeel.

In [None]:
# For visualizations
# Read the silver red wine data into a SparkDataFrame
red_wine_delta = spark.read.format("delta").load(f"{silver_path}/wine_quality")
display(red_wine_delta)



fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality,quality_category
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5.0,Fair
7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5.0,Fair
11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6.0,Commended
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.4,0.66,0.0,1.8,0.075,13.0,40.0,0.9978,3.51,0.56,9.4,5.0,Fair
7.9,0.6,0.06,1.6,0.069,15.0,59.0,0.9964,3.3,0.46,9.4,5.0,Fair
7.3,0.65,0.0,1.2,0.065,15.0,21.0,0.9946,3.39,0.47,10.0,7.0,Bronze
7.8,0.58,0.02,2.0,0.073,9.0,18.0,0.9968,3.36,0.57,9.5,7.0,Bronze
7.5,0.5,0.36,6.1,0.071,17.0,102.0,0.9978,3.35,0.8,10.5,5.0,Fair


Databricks visualizations (11 charts). Run in Databricks to view.

Analyzing the graphs together with our summary statistics, we see the following:
- 'fixed_acidity' is approximately normal with a slight right skew (a tail towards higher values). The mean is 8.32 and the standard deviation is 1.74, indicating moderate variability around the mean.
- 'volatile_acidity' is right skewed, with most values concentrated at the lower end. The mean is around 0.53 and the standard deviation is approximately 0.18, suggesting relatively low variability around the mean.
- 'citric_acid' is right skewed with three peaks. The mean is about 0.27 and the standard deviation is around 0.19, indicating high variability relative to the mean.
- 'residual_sugar' is right skewed with a long tail. The mean is approximately 2.54 and the standard deviation is about 1.41.
- 'chlorides' is right skewed with a long tail. The mean is around 0.09 and the standard deviation is approximately 0.05.
- 'free_sulfur_dioxide' is right skewed with a long tail as well. The mean is about 15.87 and the standard deviation is around 10.46, indicating substantial variability in free sulfur dioxide levels.
- 'total_sulfur_dioxide' is right skewed with a long tail and potential outliers at higher values. The mean is approximately 46.47 and the standard deviation is about 32.90, indicating substantial variability in total sulfur dioxide levels.
- 'density' is approximately normal with a slight right skew. The mean is around 0.9967 and the standard deviation is approximately 0.0019, indicating low variability around the mean.
- 'pH' is approximately normally distributed, with a mean of about 3.31 and a standard deviation of around 0.15.
- 'sulphates' is right skewed, with a mean of approximately 0.66 and a standard deviation of about 0.17.
- 'alcohol' is right skewed, with a mean of around 10.42 and a standard deviation of approximately 1.07.
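
The skew descriptions above come from reading the histograms. As a rough numerical cross-check, skewness can be computed directly. The sketch below uses plain Python and made-up values (not taken from the wine data); `moment_skewness` is a hypothetical helper name. In Spark, the aggregate `pyspark.sql.functions.skewness` could presumably be applied per column instead.

```python
from statistics import fmean, median


def moment_skewness(values):
    """Moment-based skewness: third central moment over variance**1.5."""
    n = len(values)
    mean = fmean(values)
    m2 = sum((v - mean) ** 2 for v in values) / n  # second central moment
    m3 = sum((v - mean) ** 3 for v in values) / n  # third central moment
    return m3 / m2 ** 1.5


# Hypothetical samples for illustration only
right_skewed = [1, 1, 2, 2, 2, 3, 3, 4, 8, 15]  # long tail towards high values
symmetric = [1, 2, 3, 4, 5, 6, 7]

print("right-skewed sample:", moment_skewness(right_skewed))
print("symmetric sample:", moment_skewness(symmetric))
print("mean vs median:", fmean(right_skewed), median(right_skewed))
```

A positive value indicates a right skew (tail towards higher values), and a right-skewed sample typically has a mean above its median, which is the pattern described for features such as 'residual_sugar' and 'total_sulfur_dioxide'.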


In [None]:
# Summary statistics
red_wine_delta.describe().display()

# Checking for missing values
from pyspark.sql.functions import col, isnan, when, count
red_wine_delta.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in red_wine_delta.columns]).display()



summary,fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality,quality_category
count,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599.0,1599
mean,8.319637273295838,0.5278205128205131,0.2709756097560964,2.538805503439652,0.0874665415884925,15.87492182614134,46.46779237023139,0.9967466791744832,3.311113195747343,0.6581488430268921,10.422983114446502,5.636022514071295,
stddev,1.7410963181276948,0.1790597041535352,0.1948011374053182,1.40992805950728,0.04706530201009,10.46015696980971,32.89532447829907,0.0018873339538427,0.1543864649035427,0.1695069795901101,1.0656675818473935,0.8075694397347051,
min,4.6,0.12,0.0,0.9,0.012,1.0,6.0,0.99007,2.74,0.33,8.4,3.0,Bronze
max,15.9,1.58,1.0,15.5,0.611,72.0,289.0,1.00369,4.01,2.0,14.9,8.0,Silver Medal


fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality,quality_category
0,0,0,0,0,0,0,0,0,0,0,0,0


In [None]:
# Grouping by 'quality_category' and aggregating mean values
red_wine_delta.groupBy('quality_category').agg({
    'fixed_acidity': 'mean', 'volatile_acidity': 'mean', 'citric_acid': 'mean', 'residual_sugar': 'mean', 
    'chlorides': 'mean', 'free_sulfur_dioxide': 'mean', 'total_sulfur_dioxide': 'mean', 'density': 'mean',
    'pH': 'mean', 'sulphates': 'mean', 'alcohol': 'mean'
}).display()

quality_category,avg(sulphates),avg(chlorides),avg(residual_sugar),avg(free_sulfur_dioxide),avg(density),avg(volatile_acidity),avg(pH),avg(citric_acid),avg(alcohol),avg(total_sulfur_dioxide),avg(fixed_acidity)
Fair,0.6209691629955947,0.0927356828193832,2.528854625550658,16.983847283406753,0.9971036270190888,0.5770411160058732,3.304948604992654,0.2436857562408219,9.899706314243751,56.51395007342144,8.167254038179149
Poor,0.5922222222222221,0.0957301587301587,2.684920634920635,12.063492063492063,0.9966887301587302,0.7242063492063486,3.384126984126985,0.1736507936507936,10.215873015873017,34.44444444444444,7.871428571428573
Silver Medal,0.7677777777777778,0.0684444444444444,2.577777777777777,13.27777777777778,0.9952122222222224,0.4233333333333334,3.2672222222222214,0.3911111111111111,12.094444444444443,33.44444444444444,8.566666666666665
Bronze,0.7412562814070353,0.0765879396984924,2.7206030150753797,14.045226130653266,0.9961042713567828,0.4039195979899498,3.290753768844219,0.3751758793969849,11.465912897822443,35.02010050251256,8.872361809045225
Commended,0.6753291536050158,0.0849561128526645,2.477194357366772,15.711598746081505,0.9966150626959256,0.4974843260188096,3.318072100313484,0.2738244514106587,10.629519331243465,40.86990595611285,8.347178683385575


### Model your data.
To address the question of whether we can classify wine quality based on its physicochemical properties, I explored multiple models and hyperparameters to determine the best approach for classification. Here are the steps taken:
1. Data Preparation:
   - Used VectorAssembler to combine the feature columns into a single features column.
   - Applied StringIndexer to convert the string labels into numeric indices for the classification models.
   - Split the dataset into training (70%) and testing (30%) subsets.
2. Model Experimentation:
   - Logistic Regression: Conducted initial experiments with Logistic Regression using a maximum of 1000 iterations.
   - Random Forest Classifier: Trained a Random Forest model with 225 trees.
   - Both models were trained and evaluated using PySpark and tracked with MLflow for experiment management.
3. Model Evaluation:
   - Employed several evaluation metrics: accuracy, precision, recall, and F1-score.
   - Used MulticlassClassificationEvaluator from PySpark to calculate accuracy, the primary metric.
   - Converted predictions to a Pandas DataFrame to use sklearn’s metrics functions for weighted precision, recall, and F1-score.
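
To make the evaluation step concrete, here is a minimal sketch of how support-weighted precision, recall, and F1 are computed for a multiclass problem. It is written in plain Python with made-up labels, and it is meant to mirror the behaviour of the `average='weighted'` option used in this notebook rather than to reproduce sklearn's code; `weighted_scores` is a hypothetical helper name.

```python
from collections import Counter


def weighted_scores(y_true, y_pred):
    """Return (accuracy, precision, recall, f1) with support-weighted averaging."""
    n = len(y_true)
    support = Counter(y_true)  # number of true instances per class
    precision = recall = f1 = 0.0
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        actual = support[label]
        # Undefined ratios are treated as 0, matching zero_division=0
        p_l = tp / predicted if predicted else 0.0
        r_l = tp / actual if actual else 0.0
        f_l = 2 * p_l * r_l / (p_l + r_l) if (p_l + r_l) else 0.0
        weight = actual / n  # each class is weighted by its share of true labels
        precision += weight * p_l
        recall += weight * r_l
        f1 += weight * f_l
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return accuracy, precision, recall, f1


# Hypothetical labels for illustration only
y_true = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 0, 1, 1, 0, 2, 0, 0]
accuracy, precision, recall, f1 = weighted_scores(y_true, y_pred)
print(accuracy, precision, recall, f1)
```

One consequence worth noting: with support weighting, recall reduces to the total number of correct predictions divided by the number of rows, so it equals accuracy. This is why the recall and accuracy columns in the MLflow run table carry identical values.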

In [None]:
import mlflow
import mlflow.sklearn
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql.functions import col
from pyspark.ml.feature import StringIndexer, VectorAssembler

# Import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# MLflow can automatically log your models.
# Support is provided for most of the popular libraries.
# Note: this call enables autologging for scikit-learn estimators only;
# metrics for the PySpark models trained below are logged explicitly.
mlflow.sklearn.autolog(log_models=True)

# Use our 11 measurements
feature_columns = [
    "fixed_acidity", "volatile_acidity", "citric_acid", "residual_sugar",
    "chlorides", "free_sulfur_dioxide", "total_sulfur_dioxide", "density",
    "pH", "sulphates", "alcohol",
]

# Assemble the feature columns into a single vector column
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(red_wine_delta).select("features", col("quality_category").alias("label"))

# StringIndexer to convert the string labels into numeric indices
label_indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")

# Apply StringIndexer to convert 'label' to numeric indices
data = label_indexer.fit(data).transform(data)

# split our dataset into test and training
train_data, test_data = data.randomSplit([0.7, 0.3], seed=1842)

train_data_count = train_data.count()
test_data_count = test_data.count()
print(f"Training data count: {train_data_count}")
print(f"Test data count: {test_data_count}")

Training data count: 1141
Test data count: 458
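
When interpreting the numeric labels produced by StringIndexer, it helps to know how indices are assigned. The sketch below imitates, in plain Python, the ordering I understand to be StringIndexer's default (`stringOrderType="frequencyDesc"`: most frequent label first, ties broken alphabetically). The label counts are made up and `frequency_desc_index` is a hypothetical helper, so treat it as an illustration rather than the actual mapping for this dataset.

```python
from collections import Counter


def frequency_desc_index(labels):
    """Map labels to indices: most frequent first, ties broken alphabetically."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}


# Hypothetical label column, not the real class distribution
sample = ["Fair"] * 4 + ["Commended"] * 3 + ["Bronze"] * 2 + ["Poor"]
mapping = frequency_desc_index(sample)
print(mapping)
```

To recover the real mapping in the notebook, the fitted StringIndexerModel would need to be kept in a variable (the cell above discards it after `transform`), since its `labels` attribute lists the original strings in index order.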


In [None]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix



# Set the experiment where I want to track my training
mlflow.set_experiment(experiment_id="2930558423655646")

# Start an MLflow run; the "with" keyword ensures we'll close the run even if this cell crashes
with mlflow.start_run() as run:

    # Train a Logistic Regression model using PySpark with 1000 iterations
    lr = LogisticRegression(featuresCol='features', labelCol='indexedLabel', maxIter=1000)
    lr_model = lr.fit(train_data)

    # Make predictions on the test data
    predictions = lr_model.transform(test_data)
    
    # Calculate metrics using PySpark
    evaluator = MulticlassClassificationEvaluator(labelCol="indexedLabel", predictionCol="prediction")
    accuracy = evaluator.evaluate(predictions, {evaluator.metricName: "accuracy"})
    
    # Log metrics
    mlflow.log_metric("accuracy", accuracy)
    
    # For confusion matrix and other metrics, you need to convert predictions to Pandas DataFrame first
    # and then use sklearn's metrics functions
    predictions_pd = predictions.select("indexedLabel", "prediction").toPandas()
    y_true = predictions_pd['indexedLabel']
    y_pred = predictions_pd['prediction']
    
    # Calculate sklearn metrics
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    
    # Log sklearn metrics
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)
    mlflow.log_metric("f1_score", f1)
    # No explicit mlflow.end_run() is needed: the "with" block closes the run on exit.

In [None]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import mlflow
import numpy as np


# Set the experiment where you want to track your training
mlflow.set_experiment(experiment_id="2930558423655646")

# Start an MLflow run; the "with" keyword ensures we'll close the run even if this cell crashes
with mlflow.start_run() as run:
    # Train a Random Forest model using PySpark
    rf = RandomForestClassifier(featuresCol='features', labelCol='indexedLabel', numTrees=225)
    rf_model = rf.fit(train_data)
    
    # Make predictions on the test data
    predictions = rf_model.transform(test_data)
    
    # Calculate metrics using PySpark
    evaluator = MulticlassClassificationEvaluator(labelCol="indexedLabel", predictionCol="prediction")
    accuracy = evaluator.evaluate(predictions, {evaluator.metricName: "accuracy"})
    
    # Log metrics
    mlflow.log_metric("accuracy", accuracy)
    
    # For confusion matrix and other metrics, you need to convert predictions to Pandas DataFrame first
    # and then use sklearn's metrics functions
    predictions_pd = predictions.select("indexedLabel", "prediction").toPandas()
    y_true = predictions_pd['indexedLabel']
    y_pred = predictions_pd['prediction']
    
    # Calculate sklearn metrics
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    
    # Log sklearn metrics
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)
    mlflow.log_metric("f1_score", f1)
    

    
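    # Calculate and log the confusion matrix inside the "with" block so the
    # artifact is attached to this run; logging it after the block would
    # implicitly start a new, metric-less run
    conf_matrix = confusion_matrix(y_true, y_pred)
    conf_matrix_file = "confusion_matrix.txt"
    np.savetxt(conf_matrix_file, conf_matrix, fmt='%d')
    mlflow.log_artifact(conf_matrix_file)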

In [None]:

# Use a descriptive name instead of shadowing the built-in id()
experiment_id = "2930558423655646"
max_results = 1000  # Maximum number of rows to retrieve
runs = mlflow.search_runs(experiment_ids=[experiment_id], max_results=max_results)

# Display dataframe in cell output
display(runs)


run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.recall,metrics.accuracy,metrics.precision,metrics.f1_score,tags.mlflow.databricks.cluster.id,tags.mlflow.databricks.notebookID,tags.mlflow.databricks.cluster.info,tags.mlflow.databricks.notebookPath,tags.mlflow.source.name,tags.mlflow.databricks.workspaceID,tags.mlflow.user,tags.mlflow.runName,tags.mlflow.databricks.workspaceURL,tags.mlflow.databricks.notebook.commandID,tags.mlflow.databricks.webappURL,tags.mlflow.databricks.notebookRevisionID,tags.sparkDatasourceInfo,tags.mlflow.source.type,tags.mlflow.databricks.cluster.libraries,tags.mlflow.autologging
53d1940069d840be862d871ec2a7522f,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/53d1940069d840be862d871ec2a7522f/artifacts,2024-07-20T05:18:24.343Z,2024-07-20T05:18:25.276Z,,,,,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,vaunted-goat-485,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_7740484311477635555_13748c7e954f4e2bbb6c2c5d416b65ac,https://nvirginia.cloud.databricks.com,1721452705477,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
effbd8d9372c4f01be72a7bd990f656c,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/effbd8d9372c4f01be72a7bd990f656c/artifacts,2024-07-20T05:18:15.463Z,2024-07-20T05:18:22.872Z,0.6266375545851528,0.6266375545851528,0.5970758630468073,0.6053727815018616,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,adventurous-pig-709,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_7740484311477635555_13748c7e954f4e2bbb6c2c5d416b65ac,https://nvirginia.cloud.databricks.com,1721452703113,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
1b5f0d99541e4be5ba2331302ac5c27e,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/1b5f0d99541e4be5ba2331302ac5c27e/artifacts,2024-07-20T05:18:06.393Z,2024-07-20T05:18:14.241Z,0.5829694323144105,0.5829694323144105,0.5754968787496891,0.5717880808396009,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,capricious-snake-42,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_8070824054951356846_40e59c22f86047b99f15f4afdbea9b3f,https://nvirginia.cloud.databricks.com,1721452694450,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
4c2bc703a27041b180b1c1efbb4c796d,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/4c2bc703a27041b180b1c1efbb4c796d/artifacts,2024-07-20T05:17:41.541Z,2024-07-20T05:17:50.132Z,0.5829694323144105,0.5829694323144105,0.5754968787496891,0.5717880808396009,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,sneaky-bear-489,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_7998946216182976800_ffb59bcf23644127879b0c2d2193e691,https://nvirginia.cloud.databricks.com,1721452670337,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
2962d7fcd08f484fa5ffa9d74a7622ca,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/2962d7fcd08f484fa5ffa9d74a7622ca/artifacts,2024-07-20T05:17:39.603Z,2024-07-20T05:17:40.575Z,,,,,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,calm-ray-642,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_5479743039677029233_369c5f63af3a4e48a92957abbb476349,https://nvirginia.cloud.databricks.com,1721452660743,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
7b8bae5961d145cba03efab4f36aaf77,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/7b8bae5961d145cba03efab4f36aaf77/artifacts,2024-07-20T05:17:34.048Z,2024-07-20T05:17:38.295Z,0.6200873362445415,0.6200873362445415,0.5881510833324122,0.5990803691625679,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,nosy-vole-982,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_5479743039677029233_369c5f63af3a4e48a92957abbb476349,https://nvirginia.cloud.databricks.com,1721452658486,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
47403c1565ee46d3a98ba05d745731f7,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/47403c1565ee46d3a98ba05d745731f7/artifacts,2024-07-20T05:17:25.242Z,2024-07-20T05:17:33.021Z,0.5829694323144105,0.5829694323144105,0.5754968787496891,0.5717880808396009,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,judicious-mink-559,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_7763270687311136877_778d1b1b4d3c477ab85e97dbd71e13cd,https://nvirginia.cloud.databricks.com,1721452653195,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
88aecdc2769b460bb05fce3271de3d3c,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/88aecdc2769b460bb05fce3271de3d3c/artifacts,2024-07-20T05:17:23.145Z,2024-07-20T05:17:24.159Z,,,,,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,rare-penguin-797,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_7208100797116719440_73ffb484ed6244fe8819cae89ca39d56,https://nvirginia.cloud.databricks.com,1721452644406,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
153d91f351344111a8bbc748c74cf01f,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/153d91f351344111a8bbc748c74cf01f/artifacts,2024-07-20T05:17:15.936Z,2024-07-20T05:17:21.768Z,0.6331877729257642,0.6331877729257642,0.6036915544111321,0.6108475104108292,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,grandiose-robin-150,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_7208100797116719440_73ffb484ed6244fe8819cae89ca39d56,https://nvirginia.cloud.databricks.com,1721452642008,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",
edb74d16c0554c8491ac3d56a41cd71f,2930558423655646,FINISHED,dbfs:/databricks/mlflow-tracking/2930558423655646/edb74d16c0554c8491ac3d56a41cd71f/artifacts,2024-07-20T05:17:06.884Z,2024-07-20T05:17:14.862Z,0.5829694323144105,0.5829694323144105,0.5754968787496891,0.5717880808396009,0629-212722-foxngdfx,2930558423655646,"{""cluster_name"":""Cristian Thirteen's Personal Compute Cluster"",""spark_version"":""15.3.x-cpu-ml-scala2.12"",""node_type_id"":""i3.xlarge"",""driver_node_type_id"":""i3.xlarge"",""autotermination_minutes"":120,""disk_spec"":{},""num_workers"":0}",/Users/cthirtee@nd.edu/Homework | Week Nine,/Users/cthirtee@nd.edu/Homework | Week Nine,305542976800088,cthirtee@nd.edu,dashing-ox-545,dbc-73c959e0-3c39.cloud.databricks.com,1721425699769_6529345752306571684_e1fbb6ec09f44f539ce0b08a9f99a8cd,https://nvirginia.cloud.databricks.com,1721452635058,"path=dbfs:/databricks-datasets/adult/adult.data,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=0,format=delta path=dbfs:/databricks-datasets/wine-quality/winequality-red.csv,format=csv path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=1,format=delta path=dbfs:/tmp/w8/cthirtee/silver/wine_quality,version=4,format=delta",NOTEBOOK,"{""installable"":[],""redacted"":[]}",


### Observations

For this experiment, I ran two model types: linear regression and random forest. I trained the random forest with 100, 300, and 400 trees, and I experimented with adding and removing features. Working with red wine only, I observed that the more features I removed, the worse the models performed.

I tracked the following metrics for the models: Recall, Accuracy, Precision, and F1 Score.
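
As a reference for how these four metrics relate when there are several quality classes, here is a minimal pure-Python sketch. It assumes support-weighted averaging across classes, which is not confirmed by the logged runs, and the labels in the example are made up for illustration. Under weighted averaging, recall equals accuracy, which would be consistent with runs where those two values are identical.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Return accuracy and support-weighted precision, recall, and F1."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")

    n = len(y_true)
    support = Counter(y_true)      # number of true samples per class
    predicted = Counter(y_pred)    # number of predictions per class
    true_pos = Counter(t for t, p in zip(y_true, y_pred) if t == p)

    accuracy = sum(true_pos.values()) / n
    precision = recall = f1 = 0.0
    for label, count in support.items():
        tp = true_pos[label]
        # A class that is never predicted contributes a precision of 0.
        p = tp / predicted[label] if predicted[label] else 0.0
        r = tp / count
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        weight = count / n         # weight each class by its share of samples
        precision += weight * p
        recall += weight * r
        f1 += weight * f

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Made-up quality labels, for illustration only.
y_true = [5, 5, 6, 6, 6, 7]
y_pred = [5, 6, 6, 6, 6, 6]
metrics = weighted_metrics(y_true, y_pred)
print(metrics)
```

In this toy example the weighted recall matches the accuracy, while precision and F1 differ because class 7 is never predicted.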

Based on the results, the top performers are:
- grandiose-robin-150: Accuracy: 0.633, F1 Score: 0.611
- capricious-sloth-19: Accuracy: 0.629, F1 Score: 0.606
- invincible-carp-66 and adventurous-pig-709: Accuracy: 0.627, F1 Score: 0.605
- nosy-vole-982: Accuracy: 0.620, F1 Score: 0.599
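
The ranking above can be reproduced with a short sort. This sketch hard-codes the rounded accuracy and F1 values from the list (in practice they would presumably be read from the exported runs table); the `Run` tuple and the `top_runs` helper are hypothetical names introduced for illustration.

```python
from typing import NamedTuple


class Run(NamedTuple):
    name: str
    accuracy: float
    f1: float


# Rounded values copied from the list of top performers.
runs = [
    Run("nosy-vole-982", 0.620, 0.599),
    Run("invincible-carp-66", 0.627, 0.605),
    Run("grandiose-robin-150", 0.633, 0.611),
    Run("adventurous-pig-709", 0.627, 0.605),
    Run("capricious-sloth-19", 0.629, 0.606),
]


def top_runs(candidates, k=3):
    """Return the k best runs, ranked by F1 and then by accuracy."""
    return sorted(candidates, key=lambda r: (r.f1, r.accuracy), reverse=True)[:k]


for run in top_runs(runs):
    print(f"{run.name}: accuracy={run.accuracy:.3f}, f1={run.f1:.3f}")
```

Sorting on F1 first and accuracy second is a design choice for this sketch; with these values either metric gives the same order.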

### How you would use this model in production.
This model could be used in one of two ways. As a vintner, we could use it to assess the quality of our own wines, ideally after expanding the dataset with more wines. Alternatively, a store or online shop could use it to drive wine recommendations based on preferences that customers enter. In a retail environment, we would probably want both batch and online inference.

Based on the current accuracy and use case, I would focus on batch inference and take the perspective of a winemaker. Because the model predicts wine quality from physicochemical properties, it would give a winemaker a quick first estimate of quality from lab measurements, although with accuracy around 0.63 its predictions should be treated as a guide rather than a final grade. We could run batch predictions at intervals, such as once a month or once a quarter, to assess new batches of wine. I would deploy the model through MLflow, which simplifies model management and tracking.
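
A periodic batch job along these lines might look like the sketch below. It assumes a model object that exposes a `predict` method, such as the object returned by `mlflow.pyfunc.load_model`. The `score_batch` helper, the `DummyModel` stand-in with its made-up rule, and the sample rows are all hypothetical and exist only so the example runs without a tracking server. A real MLflow model would typically expect a pandas DataFrame rather than a list of dicts.

```python
from datetime import date


class DummyModel:
    """Hypothetical stand-in for a trained classifier; not the real model."""

    def predict(self, rows):
        # Made-up rule for illustration only: higher alcohol, higher grade.
        return [6 if row["alcohol"] >= 10.5 else 5 for row in rows]


def score_batch(model, rows, run_date=None):
    """Attach a predicted quality and a scoring date to each wine record."""
    run_date = run_date or date.today()
    predictions = model.predict(rows)
    return [
        {**row, "predicted_quality": int(pred), "scored_on": run_date.isoformat()}
        for row, pred in zip(rows, predictions)
    ]


# Hypothetical measurements for two new batches of wine.
new_batches = [
    {"batch_id": "A1", "alcohol": 9.4, "pH": 3.51},
    {"batch_id": "B2", "alcohol": 11.2, "pH": 3.26},
]

# In Databricks, the model would instead be loaded from MLflow, for example
# with mlflow.pyfunc.load_model(model_uri), and fed a pandas DataFrame.
for record in score_batch(DummyModel(), new_batches, run_date=date(2024, 7, 20)):
    print(record)
```

Stamping each record with the scoring date keeps monthly or quarterly runs distinguishable when the results are appended to the same table.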