# Evaluation

In [None]:
import sys
import os

# Make the repository root importable so the `practical.*` modules resolve.
# The root is taken to be three directory levels above the current working
# directory. This presumably assumes the notebook runs from
# practical/DeepClustering/DeepECT, matching the `practical.DeepClustering.DeepECT`
# import path used below. TODO confirm if the notebook is moved.
repo_path = os.path.sep.join(os.getcwd().split(os.path.sep)[:-3])

# Guard against duplicates: without this check, every re-run of the cell
# appends the same entry to sys.path again.
if repo_path not in sys.path:
    sys.path.append(repo_path)

# Optional workaround recorded for at least one setup ("resolves problem for
# Niklas"): also switch the working directory to the repository root.
# os.chdir(repo_path)



In [None]:
import pandas as pd
from clustpy.deep.autoencoders import FeedforwardAutoencoder
from practical.DeepClustering.DeepECT.evaluation_pipeline import DatasetType, evaluate_multiple_seeds, calculate_flat_mean_for_multiple_seeds, calculate_hierarchical_mean_for_multiple_seeds




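Set the random seeds here. Every dataset below is evaluated once per seed, and the per-seed results are then aggregated into mean scores.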

In [None]:
# Each dataset below is evaluated once for every seed listed here (0 through 4).
seeds = list(range(5))

#### MNIST

In [None]:
# Run the evaluation pipeline on MNIST for every seed in `seeds`; the pipeline
# hands back the flat and the hierarchical results as two separate objects.
flat_mnist_multiple_seeds, hierarchical_mnist_multiple_seeds = evaluate_multiple_seeds(
    dataset_type=DatasetType.MNIST,
    init_autoencoder=FeedforwardAutoencoder,
    seeds=seeds,
)

# Collapse the per-seed results into mean scores for the summary tables below.
mean_flat_mnist = calculate_flat_mean_for_multiple_seeds(
    flat_mnist_multiple_seeds
)
mean_hierarchical_mnist = calculate_hierarchical_mean_for_multiple_seeds(
    hierarchical_mnist_multiple_seeds
)

#### USPS

In [None]:
# Run the evaluation pipeline on USPS for every seed in `seeds`; the pipeline
# hands back the flat and the hierarchical results as two separate objects.
flat_usps_multiple_seeds, hierarchical_usps_multiple_seeds = evaluate_multiple_seeds(
    dataset_type=DatasetType.USPS,
    init_autoencoder=FeedforwardAutoencoder,
    seeds=seeds,
)

# Collapse the per-seed results into mean scores for the summary tables below.
mean_flat_usps = calculate_flat_mean_for_multiple_seeds(
    flat_usps_multiple_seeds
)
mean_hierarchical_usps = calculate_hierarchical_mean_for_multiple_seeds(
    hierarchical_usps_multiple_seeds
)

#### FASHION MNIST

In [None]:
# Run the evaluation pipeline on Fashion-MNIST for every seed in `seeds`; the
# pipeline hands back the flat and the hierarchical results separately.
flat_fashion_multiple_seeds, hierarchical_fashion_multiple_seeds = evaluate_multiple_seeds(
    dataset_type=DatasetType.FASHION_MNIST,
    init_autoencoder=FeedforwardAutoencoder,
    seeds=seeds,
)

# Collapse the per-seed results into mean scores for the summary tables below.
mean_flat_fashion = calculate_flat_mean_for_multiple_seeds(
    flat_fashion_multiple_seeds
)
mean_hierarchical_fashion = calculate_hierarchical_mean_for_multiple_seeds(
    hierarchical_fashion_multiple_seeds
)

#### REUTERS

In [None]:
# Run the evaluation pipeline on Reuters for every seed in `seeds`; the
# pipeline hands back the flat and the hierarchical results separately.
flat_reuters_multiple_seeds, hierarchical_reuters_multiple_seeds = evaluate_multiple_seeds(
    dataset_type=DatasetType.REUTERS,
    init_autoencoder=FeedforwardAutoencoder,
    seeds=seeds,
)

# Collapse the per-seed results into mean scores for the summary tables below.
mean_flat_reuters = calculate_flat_mean_for_multiple_seeds(
    flat_reuters_multiple_seeds
)
mean_hierarchical_reuters = calculate_hierarchical_mean_for_multiple_seeds(
    hierarchical_reuters_multiple_seeds
)

## Flat Clustering

In [None]:
# Stack the seed-averaged flat-clustering results of all datasets into one
# long-format frame (one row per method/dataset pair, as pivoted below).
flat_combined_df = pd.concat(
    [mean_flat_mnist, mean_flat_usps, mean_flat_fashion, mean_flat_reuters],
    ignore_index=True,
)

# Metrics and dataset order of the results table. The column order is derived
# from these lists, so it cannot drift from what is actually pivoted.
flat_metrics = ["nmi", "acc", "ari"]
flat_datasets = [DatasetType.MNIST, DatasetType.USPS, DatasetType.FASHION_MNIST, DatasetType.REUTERS]

# One row per method; columns form a (metric, dataset) MultiIndex.
pivot_df = flat_combined_df.pivot(index="method", columns="dataset", values=flat_metrics)

# pivot() groups columns by metric first; regroup them so that each dataset's
# metrics sit side by side (NMI, ACC, ARI for MNIST, then for USPS, ...).
pivot_df = pivot_df.reindex(
    columns=[(metric, dataset.value) for dataset in flat_datasets for metric in flat_metrics]
)

# reindex() silently fills unknown column labels with NaN; fail loudly instead
# of rendering an empty column (e.g. on a dataset-label mismatch).
assert pivot_df.notna().any().all(), "some (metric, dataset) columns contain no results"

# Render with light header styling and zebra-striped rows.
pivot_df.style.set_table_styles([
    {'selector': 'thead th', 'props': [('background-color', '#f7f7f9'),
                                       ('color', '#333'),
                                       ('border', '1px solid #ddd')]},
    {'selector': 'tbody tr:nth-child(even)', 'props': [('background-color', '#f9f9f9'), ('color', '#333')]},
    {'selector': 'tbody tr:nth-child(odd)', 'props': [('background-color', '#fff'), ('color', '#333') ]}
]).set_caption("Results Table")

## Hierarchical Clustering

In [None]:
# Stack the seed-averaged hierarchical-clustering results of all datasets into
# one long-format frame. All four inputs must be the *hierarchical* means; USPS
# previously used mean_flat_usps by mistake, so the USPS dp/lp columns were
# taken from the wrong frame.
hierarchical_combined = pd.concat(
    [mean_hierarchical_mnist, mean_hierarchical_usps, mean_hierarchical_fashion, mean_hierarchical_reuters],
    ignore_index=True,
)

# Metrics and dataset order of the results table. "dp"/"lp" are presumably
# dendrogram purity and leaf purity (TODO confirm against evaluation_pipeline).
hierarchical_metrics = ["dp", "lp"]
hierarchical_datasets = [DatasetType.MNIST, DatasetType.USPS, DatasetType.FASHION_MNIST, DatasetType.REUTERS]

# One row per method; columns form a (metric, dataset) MultiIndex.
pivot_df = hierarchical_combined.pivot(index="method", columns="dataset", values=hierarchical_metrics)

# pivot() groups columns by metric first; regroup them so that each dataset's
# metrics sit side by side (dp, lp for MNIST, then for USPS, ...).
pivot_df = pivot_df.reindex(
    columns=[(metric, dataset.value) for dataset in hierarchical_datasets for metric in hierarchical_metrics]
)

# reindex() silently fills unknown column labels with NaN; fail loudly instead
# of rendering an empty column (this would have exposed the USPS mix-up above).
assert pivot_df.notna().any().all(), "some (metric, dataset) columns contain no results"

# Render with light header styling and zebra-striped rows.
pivot_df.style.set_table_styles([
    {'selector': 'thead th', 'props': [('background-color', '#f7f7f9'),
                                       ('color', '#333'),
                                       ('border', '1px solid #ddd')]},
    {'selector': 'tbody tr:nth-child(even)', 'props': [('background-color', '#f9f9f9'), ('color', '#333')]},
    {'selector': 'tbody tr:nth-child(odd)', 'props': [('background-color', '#fff'), ('color', '#333') ]}
]).set_caption("Results Table")