In [None]:
import os
from typing import Tuple, List
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import confusion_matrix, recall_score, f1_score
import pennylane as qml
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import time
import matplotlib.pyplot as plt
import seaborn as sns
import csv
from datetime import datetime
import random
import joblib # Import joblib for saving/loading sklearn objects

# ---------------------------------------------------------------------------
# Module-level configuration
# ---------------------------------------------------------------------------

# Class folder name -> integer label used as the training target.
CLASS_MAP = dict(tumor=0, no_tumor=1)
CLASS_MAP_NUMBER = len(CLASS_MAP)

# Dataset root (relative path) and its training / testing sub-folders.
BASE_URL = "../data/dataset_binary/"
TRAINING_URL = f"{BASE_URL}Training/"
TESTING_URL = f"{BASE_URL}Testing/"

# Output locations: plots, CSV logs, model weights, fitted preprocessors.
RESULTS_GRAPHICS_URL = "../results/graphics/"
RESULTS_CSV_URL = "../results/csv/"
RESULTS_MODELS_URL = "../models"
RESULTS_PREPROCESSING_URL = "../models"

class PlotUtils:
    """Static helpers that draw the training diagnostics.

    Every ``plot_*`` method opens a new matplotlib figure and then either
    writes it to ``save_path`` (closing the figure afterwards) or, when
    ``save_path`` is falsy, shows it with ``plt.show()``.
    """

    @staticmethod
    def _ensure_parent_dir(save_path):
        """Create the directory that will contain ``save_path`` if it has one."""
        parent = os.path.dirname(save_path)
        # A bare filename has an empty dirname and os.makedirs('') raises,
        # so only create a directory when there is one to create.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _save_or_show(save_path):
        """Write the current figure to ``save_path`` and close it, or show it."""
        if save_path:
            PlotUtils._ensure_parent_dir(save_path)
            plt.savefig(save_path)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def plot_loss(loss_history, title='Loss function by epochs', save_path=None):
        """Plot the per-epoch loss curve.

        Args:
            loss_history: Sequence of loss values, one per epoch.
            title: Figure title.
            save_path: File to write the figure to; shown on screen if falsy.
        """
        plt.figure()
        plt.plot(loss_history, marker='o')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        PlotUtils._save_or_show(save_path)

    @staticmethod
    def plot_loss_vs_accuracy(loss_history, accuracy_history, title='Loss vs Accuracy by Epoch', save_path=None):
        """Plot loss (red) and accuracy (blue) on the same axes.

        Args:
            loss_history: Sequence of loss values, one per epoch.
            accuracy_history: Sequence of accuracy values, one per epoch.
            title: Figure title.
            save_path: File to write the figure to; shown on screen if falsy.
        """
        plt.figure()
        plt.plot(loss_history, label='Loss', color='red')
        plt.plot(accuracy_history, label='Accuracy', color='blue')
        plt.xlabel('Epoch')
        plt.ylabel('Value')
        plt.title(title)
        plt.legend()
        plt.grid(True)
        PlotUtils._save_or_show(save_path)

    @staticmethod
    def plot_confusion_matrix(cm, class_names=None, title='Confusion Matrix', save_path=None):
        """Draw a confusion matrix as an annotated seaborn heatmap.

        Args:
            cm: Confusion matrix; cells are annotated with the ``'d'`` format.
            class_names: Tick labels used on both axes; seaborn's ``'auto'``
                labelling is used when ``None``.
            title: Figure title.
            save_path: File to write the figure to; shown on screen if falsy.
        """
        plt.figure()
        tick_labels = class_names if class_names is not None else 'auto'
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=tick_labels, yticklabels=tick_labels)
        plt.title(title)
        plt.xlabel('Prediction')
        plt.ylabel('Real')
        PlotUtils._save_or_show(save_path)

def prepare_data_multiclass(
    data_dir: str = TRAINING_URL,
    image_size: int = 512,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Load, resize and label the images of every class in ``CLASS_MAP``.

    Each class is read from the sub-directory ``data_dir/<class_name>``; only
    files whose name ends in ``.jpg`` (case-insensitive) are used. Images are
    converted to 8-bit grayscale and resized to ``image_size x image_size``.

    Args:
        data_dir: Directory containing one sub-directory per class.
        image_size: Side length, in pixels, of the resized square images.
        seed: Seed passed to ``random.seed``.

    Returns:
        ``(X, y)`` where ``X`` stacks the image arrays along a new first axis
        and ``y`` holds the matching integer labels from ``CLASS_MAP``.

    Raises:
        ValueError: If no ``.jpg`` file is found in any class directory.
    """
    random.seed(seed)

    # Collect the file lists first so a missing class directory fails before
    # any image is decoded.
    files_by_class = {}
    for class_name in CLASS_MAP:
        class_dir = os.path.join(data_dir, class_name)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order reproducible across runs and platforms.
        files_by_class[class_name] = sorted(
            os.path.join(class_dir, name)
            for name in os.listdir(class_dir)
            if name.lower().endswith('.jpg')
        )

    images, labels = [], []
    for class_name, label in CLASS_MAP.items():
        for path in files_by_class[class_name]:
            # The context manager releases the file handle deterministically.
            with Image.open(path) as img:
                resized = img.convert('L').resize((image_size, image_size))
            images.append(np.array(resized))
            labels.append(label)

    if not images:
        raise ValueError(f"No .jpg images found under {data_dir!r} for classes {list(CLASS_MAP)}")

    return np.stack(images), np.array(labels)

class QuantumClassifier:
    """Hybrid classical/quantum binary classifier for the image dataset.

    Construction loads the training and testing images, fits the
    StandardScaler -> PCA -> MinMaxScaler pipeline on the training split,
    applies it to the testing split and builds the hybrid torch model.
    ``train_and_evaluate`` then trains the model and saves the best epoch
    together with the fitted preprocessing objects.
    """

    def __init__(self, n_qubits=16, pca_features=8, batch_size=16, epochs=20, lr=0.01, layers=3, seed=42):
        """Validate the configuration, prepare the data and build the model.

        Args:
            n_qubits: Number of wires of the simulator device; each quantum
                layer acts on ``n_qubits // 2`` of them.
            pca_features: Number of PCA components, which is also the width
                of the classical linear layers.
            batch_size: Mini-batch size used for training.
            epochs: Number of training epochs.
            lr: Learning rate of the SGD optimizer.
            layers: Number of ``BasicEntanglerLayers`` repetitions.
            seed: Seed for torch, numpy and the data loader helper.

        Raises:
            ValueError: If ``pca_features`` is odd or inconsistent with
                ``n_qubits`` (see the comment below).
        """
        # The model splits the features into two halves, sends each through a
        # quantum layer returning n_qubits // 2 values and concatenates them
        # into a Linear(pca_features, ...) layer. That only works when
        # 2 * (n_qubits // 2) == pca_features; fail here instead of at the
        # first forward pass, after all images have been loaded.
        if pca_features % 2 != 0 or n_qubits // 2 != pca_features // 2:
            raise ValueError(
                f"pca_features ({pca_features}) must be even and satisfy "
                f"n_qubits // 2 == pca_features // 2 (n_qubits={n_qubits})"
            )

        self.n_qubits = n_qubits
        self.pca_features = pca_features
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr
        self.layers = layers
        self.seed = seed

        # Preprocessing tools, fitted later on the training split only.
        self.scaler = StandardScaler()
        self.pca = PCA(n_components=self.pca_features)
        self.scaler_angle = MinMaxScaler(feature_range=(0, np.pi / 2))

        torch.manual_seed(seed)
        np.random.seed(seed)
        self._prepare_data_custom()
        self._build_model()

    def _prepare_data_custom(self):
        """Load both splits at 256x256 and run them through the preprocessing.

        Sets ``x_train``/``y_train`` and ``x_test``/``y_test``. The three
        transformers are fitted on the training split and only applied
        (``transform``) to the testing split.
        """
        X_train_raw, self.y_train = prepare_data_multiclass(data_dir=TRAINING_URL, image_size=256, seed=self.seed)
        # Flatten each image to one row and scale pixel values by 1/255.
        X_train_flat = X_train_raw.reshape((X_train_raw.shape[0], -1)) / 255.0
        X_train_scaled = self.scaler.fit_transform(X_train_flat)
        X_train_pca = self.pca.fit_transform(X_train_scaled)
        self.x_train = self.scaler_angle.fit_transform(X_train_pca)

        X_test_raw, self.y_test = prepare_data_multiclass(data_dir=TESTING_URL, image_size=256, seed=self.seed)
        X_test_flat = X_test_raw.reshape((X_test_raw.shape[0], -1)) / 255.0
        X_test_scaled = self.scaler.transform(X_test_flat)
        X_test_pca = self.pca.transform(X_test_scaled)
        self.x_test = self.scaler_angle.transform(X_test_pca)

    def _build_model(self):
        """Create the hybrid model and store it in ``self.model``."""
        n_wires = self.n_qubits // 2
        # NOTE(review): the device is created with n_qubits wires although the
        # circuit below only touches the first n_qubits // 2 of them.
        dev = qml.device("lightning.qubit", wires=self.n_qubits)

        @qml.qnode(dev)
        def qnode(inputs, weights):
            qml.AngleEmbedding(inputs, wires=range(n_wires))
            qml.BasicEntanglerLayers(weights, wires=range(n_wires))
            return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_wires)]

        weight_shapes = {"weights": (self.layers, n_wires)}

        class HybridModel(nn.Module):
            """Linear -> two parallel quantum layers -> Linear -> 1 logit."""

            def __init__(self, input_features=16):
                super().__init__()
                self.input_features = input_features
                self.clayer_1 = torch.nn.Linear(input_features, input_features)
                # Two TorchLayers over the same qnode: each owns its weights.
                self.qlayer_1 = qml.qnn.TorchLayer(qnode, weight_shapes)
                self.qlayer_2 = qml.qnn.TorchLayer(qnode, weight_shapes)
                self.clayer_2 = torch.nn.Linear(input_features, input_features)
                self.final_layer = torch.nn.Linear(input_features, 1)

            def forward(self, x):
                x = self.clayer_1(x)
                # Each half of the features goes through its own quantum layer.
                x_1, x_2 = torch.split(x, self.input_features // 2, dim=1)
                x_1 = self.qlayer_1(x_1)
                x_2 = self.qlayer_2(x_2)
                x = torch.cat([x_1, x_2], dim=1)
                x = self.clayer_2(x)
                return self.final_layer(x)

        self.model = HybridModel(input_features=self.pca_features)

    @staticmethod
    def _snapshot_state(model):
        """Return a copy of ``model.state_dict()`` detached from the live weights.

        ``state_dict()`` hands out references to the parameters, so keeping it
        as-is would make every stored "snapshot" follow later optimizer steps.
        """
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    def train_and_evaluate(self):
        """Train with SGD, evaluate after every epoch and persist the best one.

        The epoch with the highest accuracy (the earliest one on ties) is
        saved together with the preprocessing objects, and the loss,
        confusion-matrix and loss-vs-accuracy plots are written.

        Returns:
            dict with keys ``best_accuracy``, ``best_f1``, ``best_recall``,
            ``best_loss``, ``best_epoch``, ``confusion_matrix`` (nested list)
            and ``model_filename``.

        Raises:
            ValueError: If ``self.epochs`` is smaller than 1.
        """
        if self.epochs < 1:
            raise ValueError(f"epochs must be >= 1 to train and select a best model, got {self.epochs}")

        x_train_t = torch.tensor(self.x_train, dtype=torch.float32)
        y_train_t = torch.tensor(self.y_train, dtype=torch.float32).unsqueeze(1)
        train_loader = DataLoader(TensorDataset(x_train_t, y_train_t), batch_size=self.batch_size, shuffle=True)

        # The evaluation tensors never change, so build them once.
        x_test_t = torch.tensor(self.x_test, dtype=torch.float32)
        y_test_t = torch.tensor(self.y_test, dtype=torch.float32).unsqueeze(1)
        y_true_np = y_test_t.numpy()

        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        loss_fn = nn.BCEWithLogitsLoss()

        loss_history = []
        accuracy_history = []
        best = None
        print("--------------- EPOCHS --------------------")

        for epoch in range(self.epochs):
            running_loss = 0.0
            self.model.train()

            for xb, yb in train_loader:
                optimizer.zero_grad()
                loss = loss_fn(self.model(xb), yb)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            avg_loss = running_loss / len(train_loader)
            loss_history.append(avg_loss)
            print(f"Average loss over epoch {epoch + 1}: {avg_loss:.4f}")

            self.model.eval()
            with torch.no_grad():
                outputs = self.model(x_test_t)
                probs = torch.sigmoid(outputs)
                preds = (probs >= 0.5).int()
                correct = (preds == y_test_t.int()).sum().item()
                acc = correct / len(y_test_t)
                accuracy_history.append(acc)
                preds_np = preds.numpy()

                # NOTE(review): f1_score/recall_score use pos_label=1 by
                # default, which CLASS_MAP assigns to "no_tumor" - confirm
                # this is the class the metrics are meant to describe.
                f1 = f1_score(y_true_np, preds_np)
                recall = recall_score(y_true_np, preds_np)
                cm = confusion_matrix(y_true_np, preds_np)

                print(f"Accuracy: {acc:.4f} - F1: {f1:.4f} - Recall: {recall:.4f}")

                # NOTE(review): the best epoch is selected on the test split,
                # so the reported numbers are optimistic; a separate
                # validation split would avoid that.
                # Strict '>' keeps the earliest epoch on ties.
                if best is None or acc > best['accuracy']:
                    best = {
                        'accuracy': acc,
                        'f1': f1,
                        'recall': recall,
                        'epoch': epoch + 1,
                        'state_dict': self._snapshot_state(self.model),
                        'loss': avg_loss,
                        'cm': cm,
                    }
                print("---------------------------------------------------")

        print("\nSaving best model and preprocessing artifacts:")
        best_model_filename = self._save_model(best['state_dict'], best['epoch'], best['accuracy'], RESULTS_MODELS_URL)

        # Preprocessing artifacts share the model's base filename.
        self._save_preprocessing_artifacts(best_model_filename, RESULTS_PREPROCESSING_URL)

        PlotUtils.plot_loss(loss_history, save_path=os.path.join(RESULTS_GRAPHICS_URL, "loss_plot.png"))
        PlotUtils.plot_confusion_matrix(best['cm'], class_names=['Tumor', 'No Tumor'],
                                        save_path=os.path.join(RESULTS_GRAPHICS_URL, "confusion_matrix.png"))
        PlotUtils.plot_loss_vs_accuracy(
            loss_history,
            accuracy_history,
            save_path=os.path.join(RESULTS_GRAPHICS_URL, "loss_vs_accuracy.png")
        )
        return {
            'best_accuracy': best['accuracy'],
            'best_f1': best['f1'],
            'best_recall': best['recall'],
            'best_loss': best['loss'],
            'best_epoch': best['epoch'],
            'confusion_matrix': best['cm'].tolist(),
            'model_filename': best_model_filename
        }

    def _save_model(self, model_state_dict, epoch, accuracy, save_dir):
        """Save ``model_state_dict`` under ``save_dir`` and return the file name.

        The name encodes the epoch and the accuracy rounded to four decimals.
        """
        os.makedirs(save_dir, exist_ok=True)
        filename = f"BC_best_model_epoch_{epoch}_acc_{accuracy:.4f}.pt"
        filepath = os.path.join(save_dir, filename)
        torch.save(model_state_dict, filepath)
        print(f"Model saved: {filepath}")
        return filename

    def _save_preprocessing_artifacts(self, model_filename, save_dir):
        """Dump the PCA and both scalers next to the model with joblib.

        File names are ``<model base name>_pca.pkl``, ``_scaler.pkl`` and
        ``_scaler_angle.pkl`` so they can be matched to the model file.
        """
        os.makedirs(save_dir, exist_ok=True)
        base_filename = os.path.splitext(model_filename)[0]

        artifacts = (
            ("pca", self.pca, "PCA model"),
            ("scaler", self.scaler, "StandardScaler"),
            ("scaler_angle", self.scaler_angle, "MinMaxScaler (angle)"),
        )
        for suffix, obj, label in artifacts:
            path = os.path.join(save_dir, f"{base_filename}_{suffix}.pkl")
            joblib.dump(obj, path)
            print(f"{label} saved: {path}")


class ExperimentRunner:
    """Holds one run's hyper-parameters and provides result-logging helpers."""

    def __init__(self,n_qubits, epochs, lr, features, layers, batch_size, seed):
        """Store the run configuration unchanged on the instance."""
        self.n_qubits = n_qubits
        self.epochs = epochs
        self.lr = lr
        self.features = features
        self.layers = layers
        self.batch_size = batch_size
        self.seed = seed

    def csv_log(self, results, csv_file):
        """Append one row describing this run to ``csv_file``.

        The parent directory is created when missing, and the header row is
        written when the file does not exist yet or is empty.

        Args:
            results: Mapping that must contain ``best_accuracy`` and
                ``best_epoch``; ``best_f1``, ``best_recall``, ``best_loss``,
                ``model_filename`` and ``execution_time`` are optional.
            csv_file: Path of the CSV log file.

        Raises:
            KeyError: If ``best_accuracy`` or ``best_epoch`` is missing.
        """
        duration_seconds = results.get('execution_time', None)
        log_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        best_accuracy = results['best_accuracy']
        best_epoch = results['best_epoch']
        best_f1 = results.get('best_f1', '')
        best_recall = results.get('best_recall', '')
        loss = results.get('best_loss', '')
        model_filename = results.get('model_filename', '')

        header = [
            'date', 'execution_time_seconds', 'epochs', 'learning_rate', 'features', 'layers', 'batch_size',
            'loss', 'accuracy', 'recall', 'f1_score', 'epoch', 'model_filename'
        ]
        row = [
            log_date,
            f'{duration_seconds:.2f}' if duration_seconds is not None else 'no seconds',
            self.epochs, self.lr, self.features, self.layers, self.batch_size,
            loss, best_accuracy, best_recall, best_f1, best_epoch, model_filename
        ]

        # Without this, a missing results directory would raise only after
        # the whole training run has finished, losing its results.
        parent = os.path.dirname(csv_file)
        if parent:
            os.makedirs(parent, exist_ok=True)

        needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
        with open(csv_file, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(header)
            writer.writerow(row)

        print(f'Results saved to {csv_file}')
        print('Run summary:')
        print(row)

    def log_class_distribution(self,unique_classes, counts, data_type):
        """Print how many samples each class has in the split ``data_type``."""
        print(f"Distribution of classes in {data_type}:")
        for cls, count in zip(unique_classes, counts):
            print(f"  Class {cls}: {count} samples")

class QuantumRunner(ExperimentRunner):
    """Experiment runner that trains a ``QuantumClassifier`` and logs the outcome."""

    def run(self):
        """Build the classifier, report class balance, train, time it and log to CSV.

        The measured duration covers only ``train_and_evaluate``; it is stored
        on ``self.duration`` and added to the results as ``execution_time``.
        """
        print("\n--- Running QUANTUM QuantumClassifier ---")
        classifier = QuantumClassifier(
            n_qubits=self.n_qubits,
            pca_features=self.features,
            batch_size=self.batch_size,
            epochs=self.epochs,
            lr=self.lr,
            layers=self.layers,
            seed=self.seed,
        )

        # Report the label histogram of each split, training first.
        splits = (
            (classifier.y_train, "training data (y_train)"),
            (classifier.y_test, "test data (y_test)"),
        )
        for labels, description in splits:
            classes, counts = np.unique(labels, return_counts=True)
            self.log_class_distribution(classes, counts, description)

        started_at = time.time()
        results = classifier.train_and_evaluate()
        self.duration = time.time() - started_at
        results['execution_time'] = self.duration

        log_path = os.path.join(RESULTS_CSV_URL, 'BC_results_log.csv')
        self.csv_log(results, log_path)

if __name__ == "__main__":
    # ---------------------------------------------------------------
    # Hyper-parameters of this run
    # ---------------------------------------------------------------
    EPOCHS = 200
    LEARNING_RATE = 0.05
    FEATURES = 12
    N_QUBITS = FEATURES  # device wire count is tied to the PCA feature count
    LAYERS = 3
    BATCH_SIZE = 64
    SEED = 42

    # Build the runner from the constants above and launch the experiment.
    runner = QuantumRunner(
        n_qubits=N_QUBITS,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        features=FEATURES,
        layers=LAYERS,
        batch_size=BATCH_SIZE,
        seed=SEED,
    )
    runner.run()



--- Running QUANTUM QuantumClassifier ---
Distribution of classes in training data (y_train):
  Class 0: 1998 samples
  Class 1: 2000 samples
Distribution of classes in test data (y_test):
  Class 0: 249 samples
  Class 1: 250 samples
--------------- EPOCHS --------------------
Average loss over epoch 1: 0.6959
Accuracy: 0.4990 - F1: 0.0000 - Recall: 0.0000
---------------------------------------------------
Average loss over epoch 2: 0.6925
Accuracy: 0.5892 - F1: 0.4092 - Recall: 0.2840
---------------------------------------------------
Average loss over epoch 3: 0.6917
Accuracy: 0.6092 - F1: 0.5324 - Recall: 0.4440
---------------------------------------------------
Average loss over epoch 4: 0.6912
Accuracy: 0.5371 - F1: 0.6647 - Recall: 0.9160
---------------------------------------------------
Average loss over epoch 5: 0.6907
Accuracy: 0.6232 - F1: 0.6163 - Recall: 0.6040
---------------------------------------------------
Average loss over epoch 6: 0.6901
Accuracy: 0.6192 - F1