In [25]:
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_curve
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

import sys
sys.path.append("../lib/")
from pathlib import Path

In [26]:
dataset = "L_540_2022_C_E_R"
results_path = Path("../results/") / dataset
data = pd.read_csv(results_path / (dataset + "_preprocessed.csv"))
data = data.drop(columns=['file'], axis=1)

In [None]:
# Encode target variable
le = LabelEncoder()
data.iloc[:,0] = le.fit_transform(data.iloc[:,0])

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(data.iloc[:,1:], data.iloc[:,0], test_size=0.2, random_state=42)

In [27]:
class MultiClassification:
    def __init__(self, method):
        self.method = method
        self.plot_data = {"roc_curve": [], "accuracy_boxplot": []}
    
    def fit(self, X, y):
        if self.method == "PCA-LDA":
            pca = PCA(n_components=2)
            X = pca.fit_transform(X)
            lda = LinearDiscriminantAnalysis()
            X = lda.fit_transform(X, y)
            self.classifier = lda
        else:
            raise ValueError("Invalid method name")

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        self.classifier.fit(self.X_train, self.y_train)

    def predict(self, X):
        if self.method == "PCA-LDA":
            X = self.classifier.transform(X)
        else:
            raise ValueError("Invalid method name")

        return self.classifier.predict(X)
    
    def evaluate(self):
        y_pred = self.predict(self.X_test)
        accuracy = accuracy_score(self.y_test, y_pred)
        self.plot_data["accuracy_boxplot"].append(accuracy)
        fpr, tpr, _ = roc_curve(self.y_test, y_pred)
        self.plot_data["roc_curve"].append((fpr, tpr))

    def plot_roc_curve(self):
        for i, roc_data in enumerate(self.plot_data["roc_curve"]):
            plt.plot(roc_data[0], roc_data[1], label=f"ROC curve {i}")
        plt.title("ROC curves")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend()
        plt.show()

    def plot_accuracy_boxplot(self):
        plt.boxplot(self.plot_data["accuracy_boxplot"])
        plt.title("Accuracy boxplot")
        plt.show()

# Example usage
iris = load_iris()
X, y = iris.data, iris.target

clf = MultiClassification("PCA-LDA")
clf.fit(X, y)
clf.evaluate()
clf.plot_roc_curve()
clf.plot_accuracy_boxplot()

ValueError: multiclass format is not supported