# Mushroom

In [109]:
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from mlp import MLP, load_weights, save_weights
from sklearn.model_selection import train_test_split


In [110]:
def load_and_preprocess_data(filepath: str) -> Tuple[np.ndarray, pd.Series]:
    """Load the mushroom CSV and turn it into a model-ready feature matrix and target.

    Args:
        filepath: Path to a comma-separated file that contains a ``class``
            column plus categorical feature columns.

    Returns:
        A tuple ``(X, y)`` where ``X`` is a 2-D ``float64`` NumPy array of
        one-hot encoded features and ``y`` is a ``float64`` pandas Series
        holding ``1.0`` where ``class == "p"`` and ``0.0`` otherwise.
    """
    # Read from the path supplied by the caller instead of a hard-coded file name.
    data = pd.read_csv(filepath, delimiter=",")

    features = data.drop(columns="class")
    # Vectorised equivalent of mapping "p" -> 1.0 and everything else -> 0.0.
    y = (data["class"] == "p").astype(np.float64)

    encoded = pd.get_dummies(features, dtype=np.float64).fillna(0)
    # Presumably "?" marks an unknown stalk-root value, so its indicator column
    # is discarded; errors="ignore" keeps this working on data without it.
    X = encoded.drop(columns="stalk-root_?", errors="ignore").to_numpy()
    return X, y


In [111]:
# Fix the seed so the train/test split (and therefore the reported metrics)
# is repeatable on Restart & Run All.
SEED = 42
# NOTE(review): this only makes training deterministic if MLP draws its initial
# weights from NumPy's global RNG — confirm against mlp.py.
np.random.seed(SEED)

X, y = load_and_preprocess_data("mushrooms.csv")

# Hold out 20% of the samples for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED
)

# One hidden layer with roughly half as many units as there are input features.
input_size = X_train.shape[1]
mlp = MLP(input_size=input_size, hidden_sizes=[(input_size + 1) // 2], output_size=1)
# To reuse previously saved weights instead of training from scratch:
# mlp = load_weights("mushrooms_weights.pkl")

mlp.train(X_train, y_train, learning_rate=1e-3, epochs=20)
save_weights(mlp, "mushrooms_weights.pkl")

# Threshold the network output at 0.5 to obtain boolean class predictions.
predictions = mlp.predict(X_test) > 0.5

100%|██████████| 20/20 [00:06<00:00,  3.05it/s]


## Metrics

$$Accuracy = \frac{TP + TN}{TP + FN + TN + FP}$$
$$Precision = \frac{TP}{TP + FP}$$
$$Recall = \frac{TP}{TP + FN}$$
$$F1\text{-}score = \frac{2 \cdot Precision \cdot Recall}{Precision + Recall}$$

In [112]:
def calculate_metrics(y_true: list[float], y_pred: list[float]) -> Dict[str, float]:
    """Compute accuracy, precision, recall and F1-score for binary labels.

    Both arguments may be any array-like (list, NumPy array, pandas Series).
    They are compared element by element *by position*; a pandas index is
    ignored. Label ``1`` (or ``True``) is the positive class and ``0`` (or
    ``False``) the negative class; any other value is not counted.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, same number of elements as ``y_true``.

    Returns:
        A dict with the keys ``"Accuracy"``, ``"Precision"``, ``"Recall"`` and
        ``"F1-score"``. A metric whose denominator is zero is reported as 0.0.

    Raises:
        ValueError: If the two inputs do not contain the same number of elements.
    """
    # Convert to flat NumPy arrays so plain lists work and two Series with
    # different indexes are not silently re-aligned by label.
    actual = np.asarray(y_true).ravel()
    predicted = np.asarray(y_pred).ravel()
    if actual.shape != predicted.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {actual.size} and {predicted.size}"
        )

    actual_pos, actual_neg = actual == 1, actual == 0
    predicted_pos, predicted_neg = predicted == 1, predicted == 0

    tp = int(np.count_nonzero(actual_pos & predicted_pos))
    tn = int(np.count_nonzero(actual_neg & predicted_neg))
    fp = int(np.count_nonzero(actual_neg & predicted_pos))
    fn = int(np.count_nonzero(actual_pos & predicted_neg))

    total = tp + fn + tn + fp
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1_score = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    # Guard the empty-input case instead of raising ZeroDivisionError.
    accuracy = (tp + tn) / total if total else 0.0

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-score": f1_score,
    }


# Score the held-out predictions (flattened to 1-D) and print one metric per line.
flat_predictions = predictions.flatten()
metrics = calculate_metrics(y_test, flat_predictions)

for metric_name, metric_value in metrics.items():
    print(f"{metric_name}: {metric_value}")

Accuracy: 0.9870769230769231
Precision: 1.0
Recall: 0.9726918075422627
F1-score: 0.9861568885959131
