In [1]:
# Import required libraries
import pandas as pd
from sklearn.datasets import load_wine
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Load the wine dataset
wine = load_wine()

# View the wine data keys
print(wine.keys())

# Inspect the target names
print(wine.target_names)

# Convert data and target into DataFrames (the single target column is 'wineType')
X = pd.DataFrame(data=wine.data, columns=wine.feature_names)
y = pd.DataFrame(data=wine.target, columns=['wineType'])

# Inspect the features
print(X.head(100))

# Inspect the target
print(y.head(100))

# Split the dataset into training and testing sets (80/20).
# stratify keeps the proportion of each wine class the same in both splits,
# so the held-out accuracy is not skewed by an unlucky class mix.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y['wineType']
)

# Create an MLP classifier preceded by feature standardization.
# The wine features are on very different scales (e.g. magnesium ~100 vs
# ash ~2 in the printed head), and MLPs are sensitive to feature scaling:
# unscaled inputs can slow or prevent convergence within max_iter.
# Putting StandardScaler inside the pipeline means the scaler is fit on the
# training split only (no test-set leakage) and predict() can still be called
# on raw features. The fitted MLPClassifier itself is available as mdlMLP[-1].
mdlMLP = make_pipeline(
    StandardScaler(),
    MLPClassifier(random_state=42, max_iter=300),
)

# Train the classifier; ravel() flattens the one-column target DataFrame
# into the 1-D array scikit-learn expects for y.
mdlMLP.fit(X_train, y_train.values.ravel())

# Make predictions on the test set (the pipeline applies the scaling itself)
y_pred = mdlMLP.predict(X_test)

# Calculate accuracy of the classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# Generate the Confusion Matrix (rows = true class, columns = predicted class)
print("MLPClassifier Confusion Matrix:\n", confusion_matrix(y_test, y_pred))

dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])
['class_0' 'class_1' 'class_2']
    alcohol  malic_acid   ash  alcalinity_of_ash  magnesium  total_phenols  \
0     14.23        1.71  2.43               15.6      127.0           2.80   
1     13.20        1.78  2.14               11.2      100.0           2.65   
2     13.16        2.36  2.67               18.6      101.0           2.80   
3     14.37        1.95  2.50               16.8      113.0           3.85   
4     13.24        2.59  2.87               21.0      118.0           2.80   
..      ...         ...   ...                ...        ...            ...   
95    12.47        1.52  2.20               19.0      162.0           2.50   
96    11.81        2.12  2.74               21.5      134.0           1.60   
97    12.29        1.41  1.98               16.0       85.0           2.55   
98    12.37        1.07  2.10               18.5       88.0           3.52   
99    12.29        3.17  2.21

